/*
Copyright 2019 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

package spanner

import (
	"context"
	"sync"
	"testing"
	"time"

	vkit "cloud.google.com/go/spanner/apiv1"
	. "cloud.google.com/go/spanner/internal/testutil"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

type testSessionCreateError struct {
	err error
	num int32
}

type testConsumer struct {
	numExpected int32

	mu       sync.Mutex
	sessions []*session
	errors   []*testSessionCreateError
	numErr   int32

	receivedAll chan struct{}
}

func (tc *testConsumer) sessionReady(s *session) {
	tc.mu.Lock()
	defer tc.mu.Unlock()
	tc.sessions = append(tc.sessions, s)
	tc.checkReceivedAll()
}

func (tc *testConsumer) sessionCreationFailed(err error, num int32) {
	tc.mu.Lock()
	defer tc.mu.Unlock()
	tc.errors = append(tc.errors, &testSessionCreateError{
		err: err,
		num: num,
	})
	tc.numErr += num
	tc.checkReceivedAll()
}

func (tc *testConsumer) checkReceivedAll() {
	if int32(len(tc.sessions))+tc.numErr == tc.numExpected {
		close(tc.receivedAll)
	}
}

func newTestConsumer(numExpected int32) *testConsumer {
	return &testConsumer{
		numExpected: numExpected,
		receivedAll: make(chan struct{}),
	}
}

func TestCreateAndCloseSession(t *testing.T) {
	t.Parallel()

	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
		SessionPoolConfig: SessionPoolConfig{
			MinOpened: 0,
			MaxOpened: 100,
		},
	})
	defer teardown()

	s, err := client.sc.createSession(context.Background())
	if err != nil {
		t.Fatalf("createSession() return error mismatch\ngot: %v\nwant: nil", err)
	}
	if s == nil {
		t.Fatalf("createSession() return value mismatch\ngot: %v\nwant: any session", s)
	}
	if server.TestSpanner.TotalSessionsCreated() != 1 {
		t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", server.TestSpanner.TotalSessionsCreated(), 1)
	}
	s.delete(context.Background())
	if server.TestSpanner.TotalSessionsDeleted() != 1 {
		t.Fatalf("number of sessions deleted mismatch\ngot: %v\nwant: %v", server.TestSpanner.TotalSessionsDeleted(), 1)
	}
}

func TestBatchCreateAndCloseSession(t *testing.T) {
	t.Parallel()

	numSessions := int32(100)
	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
	defer serverTeardown()
	for numChannels := 1; numChannels <= 32; numChannels *= 2 {
		prevCreated := server.TestSpanner.TotalSessionsCreated()
		prevDeleted := server.TestSpanner.TotalSessionsDeleted()
		client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{
			NumChannels: numChannels,
			SessionPoolConfig: SessionPoolConfig{
				MinOpened: 0,
				MaxOpened: 400,
			}}, opts...)
		if err != nil {
			t.Fatal(err)
		}
		consumer := newTestConsumer(numSessions)
		client.sc.batchCreateSessions(numSessions, consumer)
		<-consumer.receivedAll
		if len(consumer.sessions) != int(numSessions) {
			t.Fatalf("returned number of sessions mismatch\ngot: %v\nwant: %v", len(consumer.sessions), numSessions)
		}
		created := server.TestSpanner.TotalSessionsCreated() - prevCreated
		if created != uint(numSessions) {
			t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", created, numSessions)
		}
		// Check that all channels are used evenly.
		channelCounts := make(map[*vkit.Client]int32)
		for _, s := range consumer.sessions {
			channelCounts[s.client]++
		}
		if len(channelCounts) != numChannels {
			t.Fatalf("number of channels used mismatch\ngot: %v\nwant: %v", len(channelCounts), numChannels)
		}
		for _, c := range channelCounts {
			if c < numSessions/int32(numChannels) || c > numSessions/int32(numChannels)+(numSessions%int32(numChannels)) {
				t.Fatalf("channel used an unexpected number of times\ngot: %v\nwant between %v and %v", c, numSessions/int32(numChannels), numSessions/int32(numChannels)+(numSessions%int32(numChannels)))
			}
		}
		// Delete the sessions.
		for _, s := range consumer.sessions {
			s.delete(context.Background())
		}
		deleted := server.TestSpanner.TotalSessionsDeleted() - prevDeleted
		if deleted != uint(numSessions) {
			t.Fatalf("number of sessions deleted mismatch\ngot: %v\nwant: %v", deleted, numSessions)
		}
		client.Close()
	}
}

func TestBatchCreateSessionsWithExceptions(t *testing.T) {
	t.Parallel()

	numSessions := int32(100)
	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
	defer serverTeardown()

	// Run the test with everything between 1 and numChannels errors.
	for numErrors := int32(1); numErrors <= numChannels; numErrors++ {
		// Make sure that the error is not always returned on the first call.
		for firstErrorAt := numErrors - 1; firstErrorAt < numChannels-numErrors+1; firstErrorAt++ {
			client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{
				NumChannels: numChannels,
				SessionPoolConfig: SessionPoolConfig{
					MinOpened: 0,
					MaxOpened: 400,
				}}, opts...)
			if err != nil {
				t.Fatal(err)
			}
			// Register the errors on the server.
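			// The errors slice has firstErrorAt leading nil entries, so the
			// first firstErrorAt BatchCreateSessions calls succeed and the
			// registered errors are returned on the calls after that.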
			errors := make([]error, numErrors+firstErrorAt)
			for i := firstErrorAt; i < numErrors+firstErrorAt; i++ {
				errors[i] = spannerErrorf(codes.FailedPrecondition, "session creation failed")
			}
			server.TestSpanner.PutExecutionTime(MethodBatchCreateSession, SimulatedExecutionTime{
				Errors: errors,
			})
			consumer := newTestConsumer(numSessions)
			client.sc.batchCreateSessions(numSessions, consumer)
			<-consumer.receivedAll

			sessionsReturned := int32(len(consumer.sessions))
			if int32(len(consumer.errors)) != numErrors {
				t.Fatalf("Error count mismatch\nGot: %d\nWant: %d", len(consumer.errors), numErrors)
			}
			for _, e := range consumer.errors {
				if g, w := status.Code(e.err), codes.FailedPrecondition; g != w {
					t.Fatalf("error code mismatch\ngot: %v\nwant: %v", g, w)
				}
			}
			maxExpectedSessions := numSessions - numErrors*(numSessions/numChannels)
			minExpectedSessions := numSessions - numErrors*(numSessions/numChannels+1)
			if sessionsReturned < minExpectedSessions || sessionsReturned > maxExpectedSessions {
				t.Fatalf("session count mismatch\ngot: %v\nwant between %v and %v", sessionsReturned, minExpectedSessions, maxExpectedSessions)
			}
			client.Close()
		}
	}
}

func TestBatchCreateSessions_ServerReturnsLessThanRequestedSessions(t *testing.T) {
	t.Parallel()

	numChannels := 4
	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
		NumChannels: numChannels,
		SessionPoolConfig: SessionPoolConfig{
			MinOpened: 0,
			MaxOpened: 100,
		},
	})
	defer teardown()
	// Ensure that the server will never return more than 10 sessions per
	// batch create request.
	server.TestSpanner.SetMaxSessionsReturnedByServerPerBatchRequest(10)
	numSessions := int32(100)
	// Request a batch of sessions that is larger than the server will return
	// in one request (at most 10 per request). The sessionCreator will spread
	// the request over the 4 available channels, i.e. request 25 sessions per
	// channel. The batch should still return 100 sessions in total.
	consumer := newTestConsumer(numSessions)
	client.sc.batchCreateSessions(numSessions, consumer)
	<-consumer.receivedAll

	if len(consumer.errors) > 0 {
		t.Fatalf("Error count mismatch\nGot: %d\nWant: %d", len(consumer.errors), 0)
	}
	returnedSessionCount := int32(len(consumer.sessions))
	if returnedSessionCount != numSessions {
		t.Fatalf("Returned sessions mismatch\nGot: %v\nWant: %v", returnedSessionCount, numSessions)
	}
}

func TestBatchCreateSessions_ServerExhausted(t *testing.T) {
	t.Parallel()

	numChannels := 4
	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
		NumChannels: numChannels,
		SessionPoolConfig: SessionPoolConfig{
			MinOpened: 0,
			MaxOpened: 100,
		},
	})
	defer teardown()
	numSessions := int32(100)
	maxSessions := int32(50)
	// Ensure that the server will never return more than 50 sessions in total.
	server.TestSpanner.SetMaxSessionsReturnedByServerInTotal(maxSessions)
	consumer := newTestConsumer(numSessions)
	client.sc.batchCreateSessions(numSessions, consumer)
	<-consumer.receivedAll

	// Session creation should end with at least one RESOURCE_EXHAUSTED error.
	if len(consumer.errors) == 0 {
		t.Fatalf("Error count mismatch\nGot: %d\nWant: > %d", len(consumer.errors), 0)
	}
	for _, e := range consumer.errors {
		if g, w := status.Code(e.err), codes.ResourceExhausted; g != w {
			t.Fatalf("Error code mismatch\nGot: %v\nWant: %v", g, w)
		}
	}
	// The number of returned sessions should be equal to the maximum number
	// of sessions the server will return in total.
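	// The remaining (numSessions - maxSessions) sessions should have been
	// reported to the consumer through sessionCreationFailed, which is
	// verified through consumer.numErr below.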
	returnedSessionCount := int32(len(consumer.sessions))
	if returnedSessionCount != maxSessions {
		t.Fatalf("Returned sessions mismatch\nGot: %v\nWant: %v", returnedSessionCount, maxSessions)
	}
	if consumer.numErr != (numSessions - maxSessions) {
		t.Fatalf("Num errored sessions mismatch\nGot: %v\nWant: %v", consumer.numErr, numSessions-maxSessions)
	}
}

func TestBatchCreateSessions_WithTimeout(t *testing.T) {
	t.Parallel()

	numSessions := int32(100)
	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
	defer serverTeardown()
	server.TestSpanner.PutExecutionTime(MethodBatchCreateSession, SimulatedExecutionTime{
		MinimumExecutionTime: time.Second,
	})
	client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{
		SessionPoolConfig: SessionPoolConfig{
			MinOpened: 0,
			MaxOpened: 400,
		}}, opts...)
	if err != nil {
		t.Fatal(err)
	}
	client.sc.batchTimeout = 10 * time.Millisecond
	consumer := newTestConsumer(numSessions)
	client.sc.batchCreateSessions(numSessions, consumer)
	<-consumer.receivedAll
	if len(consumer.sessions) > 0 {
		t.Fatalf("Returned number of sessions mismatch\ngot: %v\nwant: %v", len(consumer.sessions), 0)
	}
	if len(consumer.errors) != numChannels {
		t.Fatalf("Returned number of errors mismatch\ngot: %v\nwant: %v", len(consumer.errors), numChannels)
	}
	for _, e := range consumer.errors {
		if g, w := status.Code(e.err), codes.DeadlineExceeded; g != w {
			t.Fatalf("Error code mismatch\ngot: %v\nwant: %v", g, w)
		}
	}
	client.Close()
}