1/*
2Copyright 2017 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spanner
18
19import (
20	"context"
21	"fmt"
22	"io"
23	"io/ioutil"
24	"log"
25	"os"
26	"strings"
27	"testing"
28	"time"
29
30	"cloud.google.com/go/civil"
31	itestutil "cloud.google.com/go/internal/testutil"
32	. "cloud.google.com/go/spanner/internal/testutil"
33	"github.com/golang/protobuf/proto"
34	structpb "github.com/golang/protobuf/ptypes/struct"
35	"google.golang.org/api/iterator"
36	"google.golang.org/api/option"
37	instancepb "google.golang.org/genproto/googleapis/spanner/admin/instance/v1"
38	sppb "google.golang.org/genproto/googleapis/spanner/v1"
39	"google.golang.org/grpc/codes"
40	"google.golang.org/grpc/status"
41)
42
43func setupMockedTestServer(t *testing.T) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) {
44	return setupMockedTestServerWithConfig(t, ClientConfig{})
45}
46
47func setupMockedTestServerWithConfig(t *testing.T, config ClientConfig) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) {
48	return setupMockedTestServerWithConfigAndClientOptions(t, config, []option.ClientOption{})
49}
50
51func setupMockedTestServerWithConfigAndClientOptions(t *testing.T, config ClientConfig, clientOptions []option.ClientOption) (server *MockedSpannerInMemTestServer, client *Client, teardown func()) {
52	grpcHeaderChecker := &itestutil.HeadersEnforcer{
53		OnFailure: t.Fatalf,
54		Checkers: []*itestutil.HeaderChecker{
55			{
56				Key: "x-goog-api-client",
57				ValuesValidator: func(token ...string) error {
58					if len(token) != 1 {
59						return status.Errorf(codes.Internal, "unexpected number of api client token headers: %v", len(token))
60					}
61					if !strings.HasPrefix(token[0], "gl-go/") {
62						return status.Errorf(codes.Internal, "unexpected api client token: %v", token[0])
63					}
64					return nil
65				},
66			},
67		},
68	}
69	clientOptions = append(clientOptions, grpcHeaderChecker.CallOptions()...)
70	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
71	opts = append(opts, clientOptions...)
72	ctx := context.Background()
73	formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "[PROJECT]", "[INSTANCE]", "[DATABASE]")
74	client, err := NewClientWithConfig(ctx, formattedDatabase, config, opts...)
75	if err != nil {
76		t.Fatal(err)
77	}
78	return server, client, func() {
79		client.Close()
80		serverTeardown()
81	}
82}
83
84// Test validDatabaseName()
85func TestValidDatabaseName(t *testing.T) {
86	validDbURI := "projects/spanner-cloud-test/instances/foo/databases/foodb"
87	invalidDbUris := []string{
88		// Completely wrong DB URI.
89		"foobarDB",
90		// Project ID contains "/".
91		"projects/spanner-cloud/test/instances/foo/databases/foodb",
92		// No instance ID.
93		"projects/spanner-cloud-test/instances//databases/foodb",
94	}
95	if err := validDatabaseName(validDbURI); err != nil {
96		t.Errorf("validateDatabaseName(%q) = %v, want nil", validDbURI, err)
97	}
98	for _, d := range invalidDbUris {
99		if err, wantErr := validDatabaseName(d), "should conform to pattern"; !strings.Contains(err.Error(), wantErr) {
100			t.Errorf("validateDatabaseName(%q) = %q, want error pattern %q", validDbURI, err, wantErr)
101		}
102	}
103}
104
105// Test getInstanceName()
106func TestGetInstanceName(t *testing.T) {
107	validDbURI := "projects/spanner-cloud-test/instances/foo/databases/foodb"
108	invalidDbUris := []string{
109		// Completely wrong DB URI.
110		"foobarDB",
111		// Project ID contains "/".
112		"projects/spanner-cloud/test/instances/foo/databases/foodb",
113		// No instance ID.
114		"projects/spanner-cloud-test/instances//databases/foodb",
115	}
116	want := "projects/spanner-cloud-test/instances/foo"
117	got, err := getInstanceName(validDbURI)
118	if err != nil {
119		t.Errorf("getInstanceName(%q) has an error: %q, want nil", validDbURI, err)
120	}
121	if got != want {
122		t.Errorf("getInstanceName(%q) = %q, want %q", validDbURI, got, want)
123	}
124	for _, d := range invalidDbUris {
125		wantErr := "Failed to retrieve instance name"
126		_, err = getInstanceName(d)
127		if !strings.Contains(err.Error(), wantErr) {
128			t.Errorf("getInstanceName(%q) has an error: %q, want error pattern %q", validDbURI, err, wantErr)
129		}
130	}
131}
132
133func TestReadOnlyTransactionClose(t *testing.T) {
134	// Closing a ReadOnlyTransaction shouldn't panic.
135	c := &Client{}
136	tx := c.ReadOnlyTransaction()
137	tx.Close()
138}
139
140func TestClient_Single(t *testing.T) {
141	t.Parallel()
142	err := testSingleQuery(t, nil)
143	if err != nil {
144		t.Fatal(err)
145	}
146}
147
148func TestClient_Single_Unavailable(t *testing.T) {
149	t.Parallel()
150	err := testSingleQuery(t, status.Error(codes.Unavailable, "Temporary unavailable"))
151	if err != nil {
152		t.Fatal(err)
153	}
154}
155
156func TestClient_Single_InvalidArgument(t *testing.T) {
157	t.Parallel()
158	err := testSingleQuery(t, status.Error(codes.InvalidArgument, "Invalid argument"))
159	if status.Code(err) != codes.InvalidArgument {
160		t.Fatalf("got: %v, want: %v", err, codes.InvalidArgument)
161	}
162}
163
164func TestClient_Single_SessionNotFound(t *testing.T) {
165	t.Parallel()
166
167	server, client, teardown := setupMockedTestServer(t)
168	defer teardown()
169	server.TestSpanner.PutExecutionTime(
170		MethodExecuteStreamingSql,
171		SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
172	)
173	ctx := context.Background()
174	iter := client.Single().Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
175	defer iter.Stop()
176	rowCount := int64(0)
177	for {
178		_, err := iter.Next()
179		if err == iterator.Done {
180			break
181		}
182		if err != nil {
183			t.Fatal(err)
184		}
185		rowCount++
186	}
187	if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
188		t.Fatalf("row count mismatch\nGot: %v\nWant: %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
189	}
190}
191
192func TestClient_Single_RetryableErrorOnPartialResultSet(t *testing.T) {
193	t.Parallel()
194	server, client, teardown := setupMockedTestServer(t)
195	defer teardown()
196
197	// Add two errors that will be returned by the mock server when the client
198	// is trying to fetch a partial result set. Both errors are retryable.
199	// The errors are not 'sticky' on the mocked server, i.e. once the error
200	// has been returned once, the next call for the same partial result set
201	// will succeed.
202
203	// When the client is fetching the partial result set with resume token 2,
204	// the mock server will respond with an internal error with the message
205	// 'stream terminated by RST_STREAM'. The client will retry the call to get
206	// this partial result set.
207	server.TestSpanner.AddPartialResultSetError(
208		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
209		PartialResultSetExecutionTime{
210			ResumeToken: EncodeResumeToken(2),
211			Err:         status.Errorf(codes.Internal, "stream terminated by RST_STREAM"),
212		},
213	)
214	// When the client is fetching the partial result set with resume token 3,
215	// the mock server will respond with a 'Unavailable' error. The client will
216	// retry the call to get this partial result set.
217	server.TestSpanner.AddPartialResultSetError(
218		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
219		PartialResultSetExecutionTime{
220			ResumeToken: EncodeResumeToken(3),
221			Err:         status.Errorf(codes.Unavailable, "server is unavailable"),
222		},
223	)
224	ctx := context.Background()
225	if err := executeSingerQuery(ctx, client.Single()); err != nil {
226		t.Fatal(err)
227	}
228}
229
230func TestClient_Single_NonRetryableErrorOnPartialResultSet(t *testing.T) {
231	t.Parallel()
232	server, client, teardown := setupMockedTestServer(t)
233	defer teardown()
234
235	// Add two errors that will be returned by the mock server when the client
236	// is trying to fetch a partial result set. The first error is retryable,
237	// the second is not.
238
239	// This error will automatically be retried.
240	server.TestSpanner.AddPartialResultSetError(
241		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
242		PartialResultSetExecutionTime{
243			ResumeToken: EncodeResumeToken(2),
244			Err:         status.Errorf(codes.Internal, "stream terminated by RST_STREAM"),
245		},
246	)
247	// 'Session not found' is not retryable and the error will be returned to
248	// the user.
249	server.TestSpanner.AddPartialResultSetError(
250		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
251		PartialResultSetExecutionTime{
252			ResumeToken: EncodeResumeToken(3),
253			Err:         newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s"),
254		},
255	)
256	ctx := context.Background()
257	err := executeSingerQuery(ctx, client.Single())
258	if status.Code(err) != codes.NotFound {
259		t.Fatalf("Error mismatch:\ngot: %v\nwant: %v", err, codes.NotFound)
260	}
261}
262
263func TestClient_Single_DeadlineExceeded_NoErrors(t *testing.T) {
264	t.Parallel()
265	server, client, teardown := setupMockedTestServer(t)
266	defer teardown()
267	server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql,
268		SimulatedExecutionTime{
269			MinimumExecutionTime: 50 * time.Millisecond,
270		})
271	ctx := context.Background()
272	ctx, cancel := context.WithDeadline(ctx, time.Now().Add(5*time.Millisecond))
273	defer cancel()
274	err := executeSingerQuery(ctx, client.Single())
275	if status.Code(err) != codes.DeadlineExceeded {
276		t.Fatalf("Error mismatch:\ngot: %v\nwant: %v", err, codes.DeadlineExceeded)
277	}
278}
279
280func TestClient_Single_DeadlineExceeded_WithErrors(t *testing.T) {
281	t.Parallel()
282	server, client, teardown := setupMockedTestServer(t)
283	defer teardown()
284	server.TestSpanner.AddPartialResultSetError(
285		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
286		PartialResultSetExecutionTime{
287			ResumeToken: EncodeResumeToken(2),
288			Err:         status.Errorf(codes.Internal, "stream terminated by RST_STREAM"),
289		},
290	)
291	server.TestSpanner.AddPartialResultSetError(
292		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
293		PartialResultSetExecutionTime{
294			ResumeToken:   EncodeResumeToken(3),
295			Err:           status.Errorf(codes.Unavailable, "server is unavailable"),
296			ExecutionTime: 50 * time.Millisecond,
297		},
298	)
299	ctx := context.Background()
300	ctx, cancel := context.WithDeadline(ctx, time.Now().Add(25*time.Millisecond))
301	defer cancel()
302	err := executeSingerQuery(ctx, client.Single())
303	if status.Code(err) != codes.DeadlineExceeded {
304		t.Fatalf("got unexpected error %v, expected DeadlineExceeded", err)
305	}
306}
307
308func TestClient_Single_ContextCanceled_noDeclaredServerErrors(t *testing.T) {
309	t.Parallel()
310	_, client, teardown := setupMockedTestServer(t)
311	defer teardown()
312	ctx := context.Background()
313	ctx, cancel := context.WithCancel(ctx)
314	cancel()
315	err := executeSingerQuery(ctx, client.Single())
316	if status.Code(err) != codes.Canceled {
317		t.Fatalf("got unexpected error %v, expected Canceled", err)
318	}
319}
320
321func TestClient_Single_ContextCanceled_withDeclaredServerErrors(t *testing.T) {
322	t.Parallel()
323	server, client, teardown := setupMockedTestServer(t)
324	defer teardown()
325	server.TestSpanner.AddPartialResultSetError(
326		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
327		PartialResultSetExecutionTime{
328			ResumeToken: EncodeResumeToken(2),
329			Err:         status.Errorf(codes.Internal, "stream terminated by RST_STREAM"),
330		},
331	)
332	server.TestSpanner.AddPartialResultSetError(
333		SelectSingerIDAlbumIDAlbumTitleFromAlbums,
334		PartialResultSetExecutionTime{
335			ResumeToken: EncodeResumeToken(3),
336			Err:         status.Errorf(codes.Unavailable, "server is unavailable"),
337		},
338	)
339	ctx := context.Background()
340	ctx, cancel := context.WithCancel(ctx)
341	defer cancel()
342	f := func(rowCount int64) error {
343		if rowCount == 2 {
344			cancel()
345		}
346		return nil
347	}
348	iter := client.Single().Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
349	defer iter.Stop()
350	err := executeSingerQueryWithRowFunc(ctx, client.Single(), f)
351	if status.Code(err) != codes.Canceled {
352		t.Fatalf("got unexpected error %v, expected Canceled", err)
353	}
354}
355
356func TestClient_ResourceBasedRouting_WithEndpointsReturned(t *testing.T) {
357	os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "true")
358	defer os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "")
359
360	// Create two servers. The base server receives the GetInstance request and
361	// returns the instance endpoint of the target server. The client should contact
362	// the target server after getting the instance endpoint.
363	serverBase, optsBase, serverTeardownBase := NewMockedSpannerInMemTestServerWithAddr(t, "localhost:8081")
364	defer serverTeardownBase()
365	serverTarget, optsTarget, serverTeardownTarget := NewMockedSpannerInMemTestServerWithAddr(t, "localhost:8082")
366	defer serverTeardownTarget()
367
368	// Return the instance endpoint.
369	instanceEndpoint := fmt.Sprintf("%s", optsTarget[0])
370	resps := []proto.Message{&instancepb.Instance{
371		EndpointUris: []string{instanceEndpoint},
372	}}
373	serverBase.TestInstanceAdmin.SetResps(resps)
374
375	ctx := context.Background()
376	formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "some-project", "some-instance", "some-database")
377	client, err := NewClientWithConfig(ctx, formattedDatabase, ClientConfig{}, optsBase...)
378	if err != nil {
379		t.Fatal(err)
380	}
381
382	if err := executeSingerQuery(ctx, client.Single()); err != nil {
383		t.Fatal(err)
384	}
385
386	// The base server should not receive any requests.
387	if _, err := shouldHaveReceived(serverBase.TestSpanner, []interface{}{}); err != nil {
388		t.Fatal(err)
389	}
390
391	// The target server should receive requests.
392	if _, err = shouldHaveReceived(serverTarget.TestSpanner, []interface{}{
393		&sppb.CreateSessionRequest{},
394		&sppb.ExecuteSqlRequest{},
395	}); err != nil {
396		t.Fatal(err)
397	}
398}
399
400func TestClient_ResourceBasedRouting_WithoutEndpointsReturned(t *testing.T) {
401	os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "true")
402	defer os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "")
403
404	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
405	defer serverTeardown()
406
407	// Return an empty list of endpoints.
408	resps := []proto.Message{&instancepb.Instance{
409		EndpointUris: []string{},
410	}}
411	server.TestInstanceAdmin.SetResps(resps)
412
413	ctx := context.Background()
414	formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "some-project", "some-instance", "some-database")
415	client, err := NewClientWithConfig(ctx, formattedDatabase, ClientConfig{}, opts...)
416	if err != nil {
417		t.Fatal(err)
418	}
419
420	if err := executeSingerQuery(ctx, client.Single()); err != nil {
421		t.Fatal(err)
422	}
423
424	// Check if the request goes to the default endpoint.
425	if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{
426		&sppb.CreateSessionRequest{},
427		&sppb.ExecuteSqlRequest{},
428	}); err != nil {
429		t.Fatal(err)
430	}
431}
432
433func TestClient_ResourceBasedRouting_WithPermissionDeniedError(t *testing.T) {
434	os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "true")
435	defer os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "")
436
437	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
438	defer serverTeardown()
439
440	server.TestInstanceAdmin.SetErr(status.Error(codes.PermissionDenied, "Permission Denied"))
441
442	ctx := context.Background()
443	formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "some-project", "some-instance", "some-database")
444	// `PermissionDeniedError` causes a warning message to be logged, which is expected.
445	// We set the output to be discarded to avoid spamming the log.
446	logger := log.New(ioutil.Discard, "", log.LstdFlags)
447	client, err := NewClientWithConfig(ctx, formattedDatabase, ClientConfig{logger: logger}, opts...)
448	if err != nil {
449		t.Fatal(err)
450	}
451
452	if err := executeSingerQuery(ctx, client.Single()); err != nil {
453		t.Fatal(err)
454	}
455
456	// Fallback to use the default endpoint when calling GetInstance() returns
457	// a PermissionDenied error.
458	if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{
459		&sppb.CreateSessionRequest{},
460		&sppb.ExecuteSqlRequest{},
461	}); err != nil {
462		t.Fatal(err)
463	}
464}
465
466func TestClient_ResourceBasedRouting_WithUnavailableError(t *testing.T) {
467	os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "true")
468	defer os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "")
469
470	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
471	defer serverTeardown()
472
473	resps := []proto.Message{&instancepb.Instance{
474		EndpointUris: []string{},
475	}}
476	server.TestInstanceAdmin.SetResps(resps)
477	server.TestInstanceAdmin.SetErr(status.Error(codes.Unavailable, "Temporary unavailable"))
478
479	ctx := context.Background()
480	formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "some-project", "some-instance", "some-database")
481	_, err := NewClientWithConfig(ctx, formattedDatabase, ClientConfig{}, opts...)
482	// The first request will get an error and the server resets the error to nil,
483	// so the next request will be fine. Due to retrying, there is no errors.
484	if err != nil {
485		t.Fatal(err)
486	}
487}
488
489func TestClient_ResourceBasedRouting_WithInvalidArgumentError(t *testing.T) {
490	os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "true")
491	defer os.Setenv("GOOGLE_CLOUD_SPANNER_ENABLE_RESOURCE_BASED_ROUTING", "")
492
493	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
494	defer serverTeardown()
495
496	server.TestInstanceAdmin.SetErr(status.Error(codes.InvalidArgument, "Invalid argument"))
497
498	ctx := context.Background()
499	formattedDatabase := fmt.Sprintf("projects/%s/instances/%s/databases/%s", "some-project", "some-instance", "some-database")
500	_, err := NewClientWithConfig(ctx, formattedDatabase, ClientConfig{}, opts...)
501
502	if status.Code(err) != codes.InvalidArgument {
503		t.Fatalf("got unexpected exception %v, expected InvalidArgument", err)
504	}
505}
506
507func testSingleQuery(t *testing.T, serverError error) error {
508	ctx := context.Background()
509	server, client, teardown := setupMockedTestServer(t)
510	defer teardown()
511	if serverError != nil {
512		server.TestSpanner.SetError(serverError)
513	}
514	return executeSingerQuery(ctx, client.Single())
515}
516
517func executeSingerQuery(ctx context.Context, tx *ReadOnlyTransaction) error {
518	return executeSingerQueryWithRowFunc(ctx, tx, nil)
519}
520
521func executeSingerQueryWithRowFunc(ctx context.Context, tx *ReadOnlyTransaction, f func(rowCount int64) error) error {
522	iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
523	defer iter.Stop()
524	rowCount := int64(0)
525	for {
526		row, err := iter.Next()
527		if err == iterator.Done {
528			break
529		}
530		if err != nil {
531			return err
532		}
533		var singerID, albumID int64
534		var albumTitle string
535		if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
536			return err
537		}
538		rowCount++
539		if f != nil {
540			if err := f(rowCount); err != nil {
541				return err
542			}
543		}
544	}
545	if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
546		return status.Errorf(codes.Internal, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
547	}
548	return nil
549}
550
551func createSimulatedExecutionTimeWithTwoUnavailableErrors(method string) map[string]SimulatedExecutionTime {
552	errors := make([]error, 2)
553	errors[0] = status.Error(codes.Unavailable, "Temporary unavailable")
554	errors[1] = status.Error(codes.Unavailable, "Temporary unavailable")
555	executionTimes := make(map[string]SimulatedExecutionTime)
556	executionTimes[method] = SimulatedExecutionTime{
557		Errors: errors,
558	}
559	return executionTimes
560}
561
562func TestClient_ReadOnlyTransaction(t *testing.T) {
563	t.Parallel()
564	if err := testReadOnlyTransaction(t, make(map[string]SimulatedExecutionTime)); err != nil {
565		t.Fatal(err)
566	}
567}
568
569func TestClient_ReadOnlyTransaction_UnavailableOnSessionCreate(t *testing.T) {
570	t.Parallel()
571	if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(MethodCreateSession)); err != nil {
572		t.Fatal(err)
573	}
574}
575
576func TestClient_ReadOnlyTransaction_UnavailableOnBeginTransaction(t *testing.T) {
577	t.Parallel()
578	if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(MethodBeginTransaction)); err != nil {
579		t.Fatal(err)
580	}
581}
582
583func TestClient_ReadOnlyTransaction_UnavailableOnExecuteStreamingSql(t *testing.T) {
584	t.Parallel()
585	if err := testReadOnlyTransaction(t, createSimulatedExecutionTimeWithTwoUnavailableErrors(MethodExecuteStreamingSql)); err != nil {
586		t.Fatal(err)
587	}
588}
589
590func TestClient_ReadOnlyTransaction_SessionNotFoundOnExecuteStreamingSql(t *testing.T) {
591	t.Parallel()
592	// Session not found is not retryable for a query on a multi-use read-only
593	// transaction, as we would need to start a new transaction on a new
594	// session.
595	err := testReadOnlyTransaction(t, map[string]SimulatedExecutionTime{
596		MethodExecuteStreamingSql: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
597	})
598	want := toSpannerError(newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s"))
599	if err == nil {
600		t.Fatalf("missing expected error\nGot: nil\nWant: %v", want)
601	}
602	if status.Code(err) != status.Code(want) || !strings.Contains(err.Error(), want.Error()) {
603		t.Fatalf("error mismatch\nGot: %v\nWant: %v", err, want)
604	}
605}
606
607func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndBeginTransaction(t *testing.T) {
608	t.Parallel()
609	exec := map[string]SimulatedExecutionTime{
610		MethodCreateSession:    {Errors: []error{status.Error(codes.Unavailable, "Temporary unavailable")}},
611		MethodBeginTransaction: {Errors: []error{status.Error(codes.Unavailable, "Temporary unavailable")}},
612	}
613	if err := testReadOnlyTransaction(t, exec); err != nil {
614		t.Fatal(err)
615	}
616}
617
618func TestClient_ReadOnlyTransaction_UnavailableOnCreateSessionAndInvalidArgumentOnBeginTransaction(t *testing.T) {
619	t.Parallel()
620	exec := map[string]SimulatedExecutionTime{
621		MethodCreateSession:    {Errors: []error{status.Error(codes.Unavailable, "Temporary unavailable")}},
622		MethodBeginTransaction: {Errors: []error{status.Error(codes.InvalidArgument, "Invalid argument")}},
623	}
624	if err := testReadOnlyTransaction(t, exec); err == nil {
625		t.Fatalf("Missing expected exception")
626	} else if status.Code(err) != codes.InvalidArgument {
627		t.Fatalf("Got unexpected exception: %v", err)
628	}
629}
630
631func TestClient_ReadOnlyTransaction_SessionNotFoundOnBeginTransaction(t *testing.T) {
632	t.Parallel()
633	if err := testReadOnlyTransaction(
634		t,
635		map[string]SimulatedExecutionTime{
636			MethodBeginTransaction: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
637		},
638	); err != nil {
639		t.Fatal(err)
640	}
641}
642
643func TestClient_ReadOnlyTransaction_SessionNotFoundOnBeginTransaction_WithMaxOneSession(t *testing.T) {
644	t.Parallel()
645	server, client, teardown := setupMockedTestServerWithConfig(
646		t,
647		ClientConfig{
648			SessionPoolConfig: SessionPoolConfig{
649				MinOpened: 0,
650				MaxOpened: 1,
651			},
652		})
653	defer teardown()
654	server.TestSpanner.PutExecutionTime(
655		MethodBeginTransaction,
656		SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
657	)
658	tx := client.ReadOnlyTransaction()
659	defer tx.Close()
660	ctx := context.Background()
661	if err := executeSingerQuery(ctx, tx); err != nil {
662		t.Fatal(err)
663	}
664}
665
666func testReadOnlyTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime) error {
667	server, client, teardown := setupMockedTestServer(t)
668	defer teardown()
669	for method, exec := range executionTimes {
670		server.TestSpanner.PutExecutionTime(method, exec)
671	}
672	tx := client.ReadOnlyTransaction()
673	defer tx.Close()
674	ctx := context.Background()
675	return executeSingerQuery(ctx, tx)
676}
677
678func TestClient_ReadWriteTransaction(t *testing.T) {
679	t.Parallel()
680	if err := testReadWriteTransaction(t, make(map[string]SimulatedExecutionTime), 1); err != nil {
681		t.Fatal(err)
682	}
683}
684
685func TestClient_ReadWriteTransactionCommitAborted(t *testing.T) {
686	t.Parallel()
687	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
688		MethodCommitTransaction: {Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}},
689	}, 2); err != nil {
690		t.Fatal(err)
691	}
692}
693
694func TestClient_ReadWriteTransaction_SessionNotFoundOnCommit(t *testing.T) {
695	t.Parallel()
696	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
697		MethodCommitTransaction: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
698	}, 2); err != nil {
699		t.Fatal(err)
700	}
701}
702
703func TestClient_ReadWriteTransaction_SessionNotFoundOnBeginTransaction(t *testing.T) {
704	t.Parallel()
705	// We expect only 1 attempt, as the 'Session not found' error is already
706	//handled in the session pool where the session is prepared.
707	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
708		MethodBeginTransaction: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
709	}, 1); err != nil {
710		t.Fatal(err)
711	}
712}
713
714func TestClient_ReadWriteTransaction_SessionNotFoundOnBeginTransactionWithEmptySessionPool(t *testing.T) {
715	t.Parallel()
716	// There will be no prepared sessions in the pool, so the error will occur
717	// when the transaction tries to get a session from the pool. This will
718	// also be handled by the session pool, so the transaction itself does not
719	// need to retry, hence the expectedAttempts == 1.
720	if err := testReadWriteTransactionWithConfig(t, ClientConfig{
721		SessionPoolConfig: SessionPoolConfig{WriteSessions: 0.0},
722	}, map[string]SimulatedExecutionTime{
723		MethodBeginTransaction: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
724	}, 1); err != nil {
725		t.Fatal(err)
726	}
727}
728
729func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteStreamingSql(t *testing.T) {
730	t.Parallel()
731	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
732		MethodExecuteStreamingSql: {Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
733	}, 2); err != nil {
734		t.Fatal(err)
735	}
736}
737
738func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteUpdate(t *testing.T) {
739	t.Parallel()
740
741	server, client, teardown := setupMockedTestServer(t)
742	defer teardown()
743	server.TestSpanner.PutExecutionTime(
744		MethodExecuteSql,
745		SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
746	)
747	ctx := context.Background()
748	var attempts int
749	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
750		attempts++
751		rowCount, err := tx.Update(ctx, NewStatement(UpdateBarSetFoo))
752		if err != nil {
753			return err
754		}
755		if g, w := rowCount, int64(UpdateBarSetFooRowCount); g != w {
756			return status.Errorf(codes.FailedPrecondition, "Row count mismatch\nGot: %v\nWant: %v", g, w)
757		}
758		return nil
759	})
760	if err != nil {
761		t.Fatal(err)
762	}
763	if g, w := attempts, 2; g != w {
764		t.Fatalf("number of attempts mismatch:\nGot%d\nWant:%d", g, w)
765	}
766}
767
768func TestClient_ReadWriteTransaction_SessionNotFoundOnExecuteBatchUpdate(t *testing.T) {
769	t.Parallel()
770
771	server, client, teardown := setupMockedTestServer(t)
772	defer teardown()
773	server.TestSpanner.PutExecutionTime(
774		MethodExecuteBatchDml,
775		SimulatedExecutionTime{Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")}},
776	)
777	ctx := context.Background()
778	var attempts int
779	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
780		attempts++
781		rowCounts, err := tx.BatchUpdate(ctx, []Statement{NewStatement(UpdateBarSetFoo)})
782		if err != nil {
783			return err
784		}
785		if g, w := len(rowCounts), 1; g != w {
786			return status.Errorf(codes.FailedPrecondition, "Row counts length mismatch\nGot: %v\nWant: %v", g, w)
787		}
788		if g, w := rowCounts[0], int64(UpdateBarSetFooRowCount); g != w {
789			return status.Errorf(codes.FailedPrecondition, "Row count mismatch\nGot: %v\nWant: %v", g, w)
790		}
791		return nil
792	})
793	if err != nil {
794		t.Fatal(err)
795	}
796	if g, w := attempts, 2; g != w {
797		t.Fatalf("number of attempts mismatch:\nGot%d\nWant:%d", g, w)
798	}
799}
800
801func TestClient_SessionNotFound(t *testing.T) {
802	// Ensure we always have at least one session in the pool.
803	sc := SessionPoolConfig{
804		MinOpened: 1,
805	}
806	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{SessionPoolConfig: sc})
807	defer teardown()
808	ctx := context.Background()
809	for {
810		client.idleSessions.mu.Lock()
811		numSessions := client.idleSessions.idleList.Len()
812		client.idleSessions.mu.Unlock()
813		if numSessions > 0 {
814			break
815		}
816		time.After(time.Millisecond)
817	}
818	// Remove the session from the server without the pool knowing it.
819	_, err := server.TestSpanner.DeleteSession(ctx, &sppb.DeleteSessionRequest{Name: client.idleSessions.idleList.Front().Value.(*session).id})
820	if err != nil {
821		t.Fatalf("Failed to delete session unexpectedly: %v", err)
822	}
823
824	_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
825		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
826		defer iter.Stop()
827		rowCount := int64(0)
828		for {
829			row, err := iter.Next()
830			if err == iterator.Done {
831				break
832			}
833			if err != nil {
834				return err
835			}
836			var singerID, albumID int64
837			var albumTitle string
838			if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
839				return err
840			}
841			rowCount++
842		}
843		if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
844			return spannerErrorf(codes.FailedPrecondition, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
845		}
846		return nil
847	})
848	if err != nil {
849		t.Fatalf("Unexpected error during transaction: %v", err)
850	}
851}
852
853func TestClient_ReadWriteTransactionExecuteStreamingSqlAborted(t *testing.T) {
854	t.Parallel()
855	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
856		MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Aborted, "Transaction aborted")}},
857	}, 2); err != nil {
858		t.Fatal(err)
859	}
860}
861
862func TestClient_ReadWriteTransaction_UnavailableOnBeginTransaction(t *testing.T) {
863	t.Parallel()
864	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
865		MethodBeginTransaction: {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}},
866	}, 1); err != nil {
867		t.Fatal(err)
868	}
869}
870
871func TestClient_ReadWriteTransaction_UnavailableOnBeginAndAbortOnCommit(t *testing.T) {
872	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
873		MethodBeginTransaction:  {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}},
874		MethodCommitTransaction: {Errors: []error{status.Error(codes.Aborted, "Aborted")}},
875	}, 2); err != nil {
876		t.Fatal(err)
877	}
878}
879
880func TestClient_ReadWriteTransaction_UnavailableOnExecuteStreamingSql(t *testing.T) {
881	t.Parallel()
882	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
883		MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}},
884	}, 1); err != nil {
885		t.Fatal(err)
886	}
887}
888
889func TestClient_ReadWriteTransaction_UnavailableOnBeginAndExecuteStreamingSqlAndTwiceAbortOnCommit(t *testing.T) {
890	t.Parallel()
891	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
892		MethodBeginTransaction:    {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}},
893		MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Unavailable, "Unavailable")}},
894		MethodCommitTransaction:   {Errors: []error{status.Error(codes.Aborted, "Aborted"), status.Error(codes.Aborted, "Aborted")}},
895	}, 3); err != nil {
896		t.Fatal(err)
897	}
898}
899
900func TestClient_ReadWriteTransaction_CommitAborted(t *testing.T) {
901	t.Parallel()
902	server, client, teardown := setupMockedTestServer(t)
903	server.TestSpanner.PutExecutionTime(MethodCommitTransaction, SimulatedExecutionTime{
904		Errors: []error{status.Error(codes.Aborted, "Aborted")},
905	})
906	defer teardown()
907	ctx := context.Background()
908	attempts := 0
909	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
910		attempts++
911		_, err := tx.Update(ctx, Statement{SQL: UpdateBarSetFoo})
912		if err != nil {
913			return err
914		}
915		return nil
916	})
917	if err != nil {
918		t.Fatal(err)
919	}
920	if g, w := attempts, 2; g != w {
921		t.Fatalf("attempt count mismatch:\nWant: %v\nGot: %v", w, g)
922	}
923}
924
925func TestClient_ReadWriteTransaction_DMLAborted(t *testing.T) {
926	t.Parallel()
927	server, client, teardown := setupMockedTestServer(t)
928	server.TestSpanner.PutExecutionTime(MethodExecuteSql, SimulatedExecutionTime{
929		Errors: []error{status.Error(codes.Aborted, "Aborted")},
930	})
931	defer teardown()
932	ctx := context.Background()
933	attempts := 0
934	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
935		attempts++
936		_, err := tx.Update(ctx, Statement{SQL: UpdateBarSetFoo})
937		if err != nil {
938			return err
939		}
940		return nil
941	})
942	if err != nil {
943		t.Fatal(err)
944	}
945	if g, w := attempts, 2; g != w {
946		t.Fatalf("attempt count mismatch:\nWant: %v\nGot: %v", w, g)
947	}
948}
949
950func TestClient_ReadWriteTransaction_BatchDMLAborted(t *testing.T) {
951	t.Parallel()
952	server, client, teardown := setupMockedTestServer(t)
953	server.TestSpanner.PutExecutionTime(MethodExecuteBatchDml, SimulatedExecutionTime{
954		Errors: []error{status.Error(codes.Aborted, "Aborted")},
955	})
956	defer teardown()
957	ctx := context.Background()
958	attempts := 0
959	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
960		attempts++
961		_, err := tx.BatchUpdate(ctx, []Statement{{SQL: UpdateBarSetFoo}})
962		if err != nil {
963			return err
964		}
965		return nil
966	})
967	if err != nil {
968		t.Fatal(err)
969	}
970	if g, w := attempts, 2; g != w {
971		t.Fatalf("attempt count mismatch:\nWant: %v\nGot: %v", w, g)
972	}
973}
974
975func TestClient_ReadWriteTransaction_QueryAborted(t *testing.T) {
976	t.Parallel()
977	server, client, teardown := setupMockedTestServer(t)
978	server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql, SimulatedExecutionTime{
979		Errors: []error{status.Error(codes.Aborted, "Aborted")},
980	})
981	defer teardown()
982	ctx := context.Background()
983	attempts := 0
984	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
985		attempts++
986		iter := tx.Query(ctx, Statement{SQL: SelectFooFromBar})
987		defer iter.Stop()
988		for {
989			_, err := iter.Next()
990			if err == iterator.Done {
991				break
992			}
993			if err != nil {
994				return err
995			}
996		}
997		return nil
998	})
999	if err != nil {
1000		t.Fatal(err)
1001	}
1002	if g, w := attempts, 2; g != w {
1003		t.Fatalf("attempt count mismatch:\nWant: %v\nGot: %v", w, g)
1004	}
1005}
1006
1007func TestClient_ReadWriteTransaction_AbortedOnExecuteStreamingSqlAndCommit(t *testing.T) {
1008	t.Parallel()
1009	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
1010		MethodExecuteStreamingSql: {Errors: []error{status.Error(codes.Aborted, "Aborted")}},
1011		MethodCommitTransaction:   {Errors: []error{status.Error(codes.Aborted, "Aborted"), status.Error(codes.Aborted, "Aborted")}},
1012	}, 4); err != nil {
1013		t.Fatal(err)
1014	}
1015}
1016
1017func TestClient_ReadWriteTransactionCommitAbortedAndUnavailable(t *testing.T) {
1018	t.Parallel()
1019	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
1020		MethodCommitTransaction: {
1021			Errors: []error{
1022				status.Error(codes.Aborted, "Transaction aborted"),
1023				status.Error(codes.Unavailable, "Unavailable"),
1024			},
1025		},
1026	}, 2); err != nil {
1027		t.Fatal(err)
1028	}
1029}
1030
1031func TestClient_ReadWriteTransactionCommitAlreadyExists(t *testing.T) {
1032	t.Parallel()
1033	if err := testReadWriteTransaction(t, map[string]SimulatedExecutionTime{
1034		MethodCommitTransaction: {Errors: []error{status.Error(codes.AlreadyExists, "A row with this key already exists")}},
1035	}, 1); err != nil {
1036		if status.Code(err) != codes.AlreadyExists {
1037			t.Fatalf("Got unexpected error %v, expected %v", err, codes.AlreadyExists)
1038		}
1039	} else {
1040		t.Fatalf("Missing expected exception")
1041	}
1042}
1043
1044func testReadWriteTransaction(t *testing.T, executionTimes map[string]SimulatedExecutionTime, expectedAttempts int) error {
1045	return testReadWriteTransactionWithConfig(t, ClientConfig{SessionPoolConfig: DefaultSessionPoolConfig}, executionTimes, expectedAttempts)
1046}
1047
1048func testReadWriteTransactionWithConfig(t *testing.T, config ClientConfig, executionTimes map[string]SimulatedExecutionTime, expectedAttempts int) error {
1049	server, client, teardown := setupMockedTestServer(t)
1050	defer teardown()
1051	for method, exec := range executionTimes {
1052		server.TestSpanner.PutExecutionTime(method, exec)
1053	}
1054	ctx := context.Background()
1055	var attempts int
1056	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1057		attempts++
1058		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
1059		defer iter.Stop()
1060		rowCount := int64(0)
1061		for {
1062			row, err := iter.Next()
1063			if err == iterator.Done {
1064				break
1065			}
1066			if err != nil {
1067				return err
1068			}
1069			var singerID, albumID int64
1070			var albumTitle string
1071			if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
1072				return err
1073			}
1074			rowCount++
1075		}
1076		if rowCount != SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount {
1077			return status.Errorf(codes.FailedPrecondition, "Row count mismatch, got %v, expected %v", rowCount, SelectSingerIDAlbumIDAlbumTitleFromAlbumsRowCount)
1078		}
1079		return nil
1080	})
1081	if err != nil {
1082		return err
1083	}
1084	if expectedAttempts != attempts {
1085		t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, expectedAttempts)
1086	}
1087	return nil
1088}
1089
1090func TestClient_ApplyAtLeastOnce(t *testing.T) {
1091	t.Parallel()
1092	server, client, teardown := setupMockedTestServer(t)
1093	defer teardown()
1094	ms := []*Mutation{
1095		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
1096		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
1097	}
1098	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
1099		SimulatedExecutionTime{
1100			Errors: []error{status.Error(codes.Aborted, "Transaction aborted")},
1101		})
1102	_, err := client.Apply(context.Background(), ms, ApplyAtLeastOnce())
1103	if err != nil {
1104		t.Fatal(err)
1105	}
1106}
1107
1108func TestClient_ApplyAtLeastOnceReuseSession(t *testing.T) {
1109	t.Parallel()
1110	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
1111		SessionPoolConfig: SessionPoolConfig{
1112			MinOpened:           0,
1113			WriteSessions:       0.0,
1114			TrackSessionHandles: true,
1115		},
1116	})
1117	defer teardown()
1118	ms := []*Mutation{
1119		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
1120		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
1121	}
1122	for i := 0; i < 10; i++ {
1123		_, err := client.Apply(context.Background(), ms, ApplyAtLeastOnce())
1124		if err != nil {
1125			t.Fatal(err)
1126		}
1127		if g, w := client.idleSessions.idleList.Len(), 1; g != w {
1128			t.Fatalf("idle session count mismatch:\nGot: %v\nWant: %v", g, w)
1129		}
1130		if g, w := len(server.TestSpanner.DumpSessions()), 1; g != w {
1131			t.Fatalf("server session count mismatch:\nGot: %v\nWant: %v", g, w)
1132		}
1133	}
1134	// There should be no sessions marked as checked out.
1135	client.idleSessions.mu.Lock()
1136	g, w := client.idleSessions.trackedSessionHandles.Len(), 0
1137	client.idleSessions.mu.Unlock()
1138	if g != w {
1139		t.Fatalf("checked out sessions count mismatch:\nGot: %v\nWant: %v", g, w)
1140	}
1141}
1142
1143func TestClient_ApplyAtLeastOnceInvalidArgument(t *testing.T) {
1144	t.Parallel()
1145	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
1146		SessionPoolConfig: SessionPoolConfig{
1147			MinOpened:           0,
1148			WriteSessions:       0.0,
1149			TrackSessionHandles: true,
1150		},
1151	})
1152	defer teardown()
1153	ms := []*Mutation{
1154		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(1), "Foo", int64(50)}),
1155		Insert("Accounts", []string{"AccountId", "Nickname", "Balance"}, []interface{}{int64(2), "Bar", int64(1)}),
1156	}
1157	for i := 0; i < 10; i++ {
1158		server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
1159			SimulatedExecutionTime{
1160				Errors: []error{status.Error(codes.InvalidArgument, "Invalid data")},
1161			})
1162		_, err := client.Apply(context.Background(), ms, ApplyAtLeastOnce())
1163		if status.Code(err) != codes.InvalidArgument {
1164			t.Fatal(err)
1165		}
1166		if g, w := client.idleSessions.idleList.Len(), 1; g != w {
1167			t.Fatalf("idle session count mismatch:\nGot: %v\nWant: %v", g, w)
1168		}
1169		if g, w := len(server.TestSpanner.DumpSessions()), 1; g != w {
1170			t.Fatalf("server session count mismatch:\nGot: %v\nWant: %v", g, w)
1171		}
1172	}
1173	// There should be no sessions marked as checked out.
1174	client.idleSessions.mu.Lock()
1175	g, w := client.idleSessions.trackedSessionHandles.Len(), 0
1176	client.idleSessions.mu.Unlock()
1177	if g != w {
1178		t.Fatalf("checked out sessions count mismatch:\nGot: %v\nWant: %v", g, w)
1179	}
1180}
1181
1182func TestReadWriteTransaction_ErrUnexpectedEOF(t *testing.T) {
1183	t.Parallel()
1184	_, client, teardown := setupMockedTestServer(t)
1185	defer teardown()
1186	ctx := context.Background()
1187	var attempts int
1188	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1189		attempts++
1190		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
1191		defer iter.Stop()
1192		for {
1193			row, err := iter.Next()
1194			if err == iterator.Done {
1195				break
1196			}
1197			if err != nil {
1198				return err
1199			}
1200			var singerID, albumID int64
1201			var albumTitle string
1202			if err := row.Columns(&singerID, &albumID, &albumTitle); err != nil {
1203				return err
1204			}
1205		}
1206		return io.ErrUnexpectedEOF
1207	})
1208	if err != io.ErrUnexpectedEOF {
1209		t.Fatalf("Missing expected error %v, got %v", io.ErrUnexpectedEOF, err)
1210	}
1211	if attempts != 1 {
1212		t.Fatalf("unexpected number of attempts: %d, expected %d", attempts, 1)
1213	}
1214}
1215
1216func TestReadWriteTransaction_WrapError(t *testing.T) {
1217	t.Parallel()
1218	server, client, teardown := setupMockedTestServer(t)
1219	defer teardown()
1220	// Abort the transaction on both the query as well as commit.
1221	// The first abort error will be wrapped. The client will unwrap the cause
1222	// of the error and retry the transaction. The aborted error on commit
1223	// will not be wrapped, but will also be recognized by the client as an
1224	// abort that should be retried.
1225	server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql,
1226		SimulatedExecutionTime{
1227			Errors: []error{status.Error(codes.Aborted, "Transaction aborted")},
1228		})
1229	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
1230		SimulatedExecutionTime{
1231			Errors: []error{status.Error(codes.Aborted, "Transaction aborted")},
1232		})
1233	msg := "query failed"
1234	numAttempts := 0
1235	ctx := context.Background()
1236	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1237		numAttempts++
1238		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
1239		defer iter.Stop()
1240		for {
1241			_, err := iter.Next()
1242			if err == iterator.Done {
1243				break
1244			}
1245			if err != nil {
1246				// Wrap the error in another error that implements the
1247				// (xerrors|errors).Wrapper interface.
1248				return &wrappedTestError{err, msg}
1249			}
1250		}
1251		return nil
1252	})
1253	if err != nil {
1254		t.Fatalf("Unexpected error\nGot: %v\nWant: nil", err)
1255	}
1256	if g, w := numAttempts, 3; g != w {
1257		t.Fatalf("Number of transaction attempts mismatch\nGot: %d\nWant: %d", w, w)
1258	}
1259
1260	// Execute a transaction that returns a non-retryable error that is
1261	// wrapped in a custom error. The transaction should return the custom
1262	// error.
1263	server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql,
1264		SimulatedExecutionTime{
1265			Errors: []error{status.Error(codes.NotFound, "Table not found")},
1266		})
1267	_, err = client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1268		numAttempts++
1269		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
1270		defer iter.Stop()
1271		for {
1272			_, err := iter.Next()
1273			if err == iterator.Done {
1274				break
1275			}
1276			if err != nil {
1277				// Wrap the error in another error that implements the
1278				// (xerrors|errors).Wrapper interface.
1279				return &wrappedTestError{err, msg}
1280			}
1281		}
1282		return nil
1283	})
1284	if err == nil || err.Error() != msg {
1285		t.Fatalf("Unexpected error\nGot: %v\nWant: %v", err, msg)
1286	}
1287}
1288
1289func TestReadWriteTransaction_WrapSessionNotFoundError(t *testing.T) {
1290	t.Parallel()
1291	server, client, teardown := setupMockedTestServer(t)
1292	defer teardown()
1293	server.TestSpanner.PutExecutionTime(MethodBeginTransaction,
1294		SimulatedExecutionTime{
1295			Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")},
1296		})
1297	server.TestSpanner.PutExecutionTime(MethodExecuteStreamingSql,
1298		SimulatedExecutionTime{
1299			Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")},
1300		})
1301	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
1302		SimulatedExecutionTime{
1303			Errors: []error{newSessionNotFoundError("projects/p/instances/i/databases/d/sessions/s")},
1304		})
1305	msg := "query failed"
1306	numAttempts := 0
1307	ctx := context.Background()
1308	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1309		numAttempts++
1310		iter := tx.Query(ctx, NewStatement(SelectSingerIDAlbumIDAlbumTitleFromAlbums))
1311		defer iter.Stop()
1312		for {
1313			_, err := iter.Next()
1314			if err == iterator.Done {
1315				break
1316			}
1317			if err != nil {
1318				// Wrap the error in another error that implements the
1319				// (xerrors|errors).Wrapper interface.
1320				return &wrappedTestError{err, msg}
1321			}
1322		}
1323		return nil
1324	})
1325	if err != nil {
1326		t.Fatalf("Unexpected error\nGot: %v\nWant: nil", err)
1327	}
1328	// We want 3 attempts. The 'Session not found' error on BeginTransaction
1329	// will not retry the entire transaction, which means that we will have two
1330	// failed attempts and then a successful attempt.
1331	if g, w := numAttempts, 3; g != w {
1332		t.Fatalf("Number of transaction attempts mismatch\nGot: %d\nWant: %d", g, w)
1333	}
1334}
1335
1336func TestClient_WriteStructWithPointers(t *testing.T) {
1337	t.Parallel()
1338	server, client, teardown := setupMockedTestServer(t)
1339	defer teardown()
1340	type T struct {
1341		ID    int64
1342		Col1  *string
1343		Col2  []*string
1344		Col3  *bool
1345		Col4  []*bool
1346		Col5  *int64
1347		Col6  []*int64
1348		Col7  *float64
1349		Col8  []*float64
1350		Col9  *time.Time
1351		Col10 []*time.Time
1352		Col11 *civil.Date
1353		Col12 []*civil.Date
1354	}
1355	t1 := T{
1356		ID:    1,
1357		Col2:  []*string{nil},
1358		Col4:  []*bool{nil},
1359		Col6:  []*int64{nil},
1360		Col8:  []*float64{nil},
1361		Col10: []*time.Time{nil},
1362		Col12: []*civil.Date{nil},
1363	}
1364	s := "foo"
1365	b := true
1366	i := int64(100)
1367	f := 3.14
1368	tm := time.Now()
1369	d := civil.DateOf(time.Now())
1370	t2 := T{
1371		ID:    2,
1372		Col1:  &s,
1373		Col2:  []*string{&s},
1374		Col3:  &b,
1375		Col4:  []*bool{&b},
1376		Col5:  &i,
1377		Col6:  []*int64{&i},
1378		Col7:  &f,
1379		Col8:  []*float64{&f},
1380		Col9:  &tm,
1381		Col10: []*time.Time{&tm},
1382		Col11: &d,
1383		Col12: []*civil.Date{&d},
1384	}
1385	m1, err := InsertStruct("Tab", &t1)
1386	if err != nil {
1387		t.Fatal(err)
1388	}
1389	m2, err := InsertStruct("Tab", &t2)
1390	if err != nil {
1391		t.Fatal(err)
1392	}
1393	_, err = client.Apply(context.Background(), []*Mutation{m1, m2})
1394	if err != nil {
1395		t.Fatal(err)
1396	}
1397	requests := drainRequestsFromServer(server.TestSpanner)
1398	for _, req := range requests {
1399		if commit, ok := req.(*sppb.CommitRequest); ok {
1400			if g, w := len(commit.Mutations), 2; w != g {
1401				t.Fatalf("mutation count mismatch\nGot: %v\nWant: %v", g, w)
1402			}
1403			insert := commit.Mutations[0].GetInsert()
1404			// The first insert should contain NULL values and arrays
1405			// containing exactly one NULL element.
1406			for i := 1; i < len(insert.Values[0].Values); i += 2 {
1407				// The non-array columns should contain NULL values.
1408				g, w := insert.Values[0].Values[i].GetKind(), &structpb.Value_NullValue{}
1409				if _, ok := g.(*structpb.Value_NullValue); !ok {
1410					t.Fatalf("type mismatch\nGot: %v\nWant: %v", g, w)
1411				}
1412				// The array columns should not be NULL.
1413				g, wList := insert.Values[0].Values[i+1].GetKind(), &structpb.Value_ListValue{}
1414				if _, ok := g.(*structpb.Value_ListValue); !ok {
1415					t.Fatalf("type mismatch\nGot: %v\nWant: %v", g, wList)
1416				}
1417				// The array should contain 1 NULL value.
1418				if gLength, wLength := len(insert.Values[0].Values[i+1].GetListValue().Values), 1; gLength != wLength {
1419					t.Fatalf("list value length mismatch\nGot: %v\nWant: %v", gLength, wLength)
1420				}
1421				g, w = insert.Values[0].Values[i+1].GetListValue().Values[0].GetKind(), &structpb.Value_NullValue{}
1422				if _, ok := g.(*structpb.Value_NullValue); !ok {
1423					t.Fatalf("type mismatch\nGot: %v\nWant: %v", g, w)
1424				}
1425			}
1426
1427			// The second insert should contain all non-NULL values.
1428			insert = commit.Mutations[1].GetInsert()
1429			for i := 1; i < len(insert.Values[0].Values); i += 2 {
1430				// The non-array columns should contain non-NULL values.
1431				g := insert.Values[0].Values[i].GetKind()
1432				if _, ok := g.(*structpb.Value_NullValue); ok {
1433					t.Fatalf("type mismatch\nGot: %v\nWant: non-NULL value", g)
1434				}
1435				// The array columns should also be non-NULL.
1436				g, wList := insert.Values[0].Values[i+1].GetKind(), &structpb.Value_ListValue{}
1437				if _, ok := g.(*structpb.Value_ListValue); !ok {
1438					t.Fatalf("type mismatch\nGot: %v\nWant: %v", g, wList)
1439				}
1440				// The array should contain exactly 1 non-NULL value.
1441				if gLength, wLength := len(insert.Values[0].Values[i+1].GetListValue().Values), 1; gLength != wLength {
1442					t.Fatalf("list value length mismatch\nGot: %v\nWant: %v", gLength, wLength)
1443				}
1444				g = insert.Values[0].Values[i+1].GetListValue().Values[0].GetKind()
1445				if _, ok := g.(*structpb.Value_NullValue); ok {
1446					t.Fatalf("type mismatch\nGot: %v\nWant: non-NULL value", g)
1447				}
1448			}
1449		}
1450	}
1451}
1452
1453func TestReadWriteTransaction_ContextTimeoutDuringDuringCommit(t *testing.T) {
1454	t.Parallel()
1455	server, client, teardown := setupMockedTestServer(t)
1456	defer teardown()
1457	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
1458		SimulatedExecutionTime{
1459			MinimumExecutionTime: time.Minute,
1460		})
1461	ctx, cancel := context.WithTimeout(context.Background(), 200*time.Millisecond)
1462	defer cancel()
1463	_, err := client.ReadWriteTransaction(ctx, func(ctx context.Context, tx *ReadWriteTransaction) error {
1464		tx.BufferWrite([]*Mutation{Insert("FOO", []string{"ID", "NAME"}, []interface{}{int64(1), "bar"})})
1465		return nil
1466	})
1467	errContext, _ := context.WithTimeout(context.Background(), -time.Second)
1468	w := toSpannerErrorWithCommitInfo(errContext.Err(), true).(*Error)
1469	var se *Error
1470	if !errorAs(err, &se) {
1471		t.Fatalf("Error mismatch\nGot: %v\nWant: %v", err, w)
1472	}
1473	if se.GRPCStatus().Code() != w.GRPCStatus().Code() {
1474		t.Fatalf("Error status mismatch:\nGot: %v\nWant: %v", se.GRPCStatus(), w.GRPCStatus())
1475	}
1476	if se.Error() != w.Error() {
1477		t.Fatalf("Error message mismatch:\nGot %s\nWant: %s", se.Error(), w.Error())
1478	}
1479	var outcome *TransactionOutcomeUnknownError
1480	if !errorAs(err, &outcome) {
1481		t.Fatalf("Missing wrapped TransactionOutcomeUnknownError error")
1482	}
1483}
1484
1485func TestFailedCommit_NoRollback(t *testing.T) {
1486	t.Parallel()
1487	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
1488		SessionPoolConfig: SessionPoolConfig{
1489			MinOpened:     0,
1490			MaxOpened:     1,
1491			WriteSessions: 0,
1492		},
1493	})
1494	defer teardown()
1495	server.TestSpanner.PutExecutionTime(MethodCommitTransaction,
1496		SimulatedExecutionTime{
1497			Errors: []error{status.Errorf(codes.InvalidArgument, "Invalid mutations")},
1498		})
1499	_, err := client.Apply(context.Background(), []*Mutation{
1500		Insert("FOO", []string{"ID", "BAR"}, []interface{}{1, "value"}),
1501	})
1502	if got, want := status.Convert(err).Code(), codes.InvalidArgument; got != want {
1503		t.Fatalf("Error mismatch\nGot: %v\nWant: %v", got, want)
1504	}
1505	// The failed commit should not trigger a rollback after the commit.
1506	if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{
1507		&sppb.CreateSessionRequest{},
1508		&sppb.BeginTransactionRequest{},
1509		&sppb.CommitRequest{},
1510	}); err != nil {
1511		t.Fatalf("Received RPCs mismatch: %v", err)
1512	}
1513}
1514
1515func TestFailedUpdate_ShouldRollback(t *testing.T) {
1516	t.Parallel()
1517	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
1518		SessionPoolConfig: SessionPoolConfig{
1519			MinOpened:     0,
1520			MaxOpened:     1,
1521			WriteSessions: 0,
1522		},
1523	})
1524	defer teardown()
1525	server.TestSpanner.PutExecutionTime(MethodExecuteSql,
1526		SimulatedExecutionTime{
1527			Errors: []error{status.Errorf(codes.InvalidArgument, "Invalid update")},
1528		})
1529	_, err := client.ReadWriteTransaction(context.Background(), func(ctx context.Context, tx *ReadWriteTransaction) error {
1530		_, err := tx.Update(ctx, NewStatement("UPDATE FOO SET BAR='value' WHERE ID=1"))
1531		return err
1532	})
1533	if got, want := status.Convert(err).Code(), codes.InvalidArgument; got != want {
1534		t.Fatalf("Error mismatch\nGot: %v\nWant: %v", got, want)
1535	}
1536	// The failed update should trigger a rollback.
1537	if _, err := shouldHaveReceived(server.TestSpanner, []interface{}{
1538		&sppb.CreateSessionRequest{},
1539		&sppb.BeginTransactionRequest{},
1540		&sppb.ExecuteSqlRequest{},
1541		&sppb.RollbackRequest{},
1542	}); err != nil {
1543		t.Fatalf("Received RPCs mismatch: %v", err)
1544	}
1545}
1546