1/* 2Copyright 2019 Google LLC 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17package spanner 18 19import ( 20 "context" 21 "sync" 22 "testing" 23 "time" 24 25 vkit "cloud.google.com/go/spanner/apiv1" 26 . "cloud.google.com/go/spanner/internal/testutil" 27 "google.golang.org/grpc/codes" 28 "google.golang.org/grpc/status" 29) 30 31type testSessionCreateError struct { 32 err error 33 num int32 34} 35 36type testConsumer struct { 37 numExpected int32 38 39 mu sync.Mutex 40 sessions []*session 41 errors []*testSessionCreateError 42 numErr int32 43 44 receivedAll chan struct{} 45} 46 47func (tc *testConsumer) sessionReady(s *session) { 48 tc.mu.Lock() 49 defer tc.mu.Unlock() 50 tc.sessions = append(tc.sessions, s) 51 tc.checkReceivedAll() 52} 53 54func (tc *testConsumer) sessionCreationFailed(err error, num int32) { 55 tc.mu.Lock() 56 defer tc.mu.Unlock() 57 tc.errors = append(tc.errors, &testSessionCreateError{ 58 err: err, 59 num: num, 60 }) 61 tc.numErr += num 62 tc.checkReceivedAll() 63} 64 65func (tc *testConsumer) checkReceivedAll() { 66 if int32(len(tc.sessions))+tc.numErr == tc.numExpected { 67 close(tc.receivedAll) 68 } 69} 70 71func newTestConsumer(numExpected int32) *testConsumer { 72 return &testConsumer{ 73 numExpected: numExpected, 74 receivedAll: make(chan struct{}), 75 } 76} 77 78func TestCreateAndCloseSession(t *testing.T) { 79 t.Parallel() 80 81 server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ 82 SessionPoolConfig: SessionPoolConfig{ 83 MinOpened: 0, 84 MaxOpened: 100, 85 }, 86 }) 87 defer teardown() 88 89 s, err := client.sc.createSession(context.Background()) 90 if err != nil { 91 t.Fatalf("batch.next() return error mismatch\ngot: %v\nwant: nil", err) 92 } 93 if s == nil { 94 t.Fatalf("batch.next() return value mismatch\ngot: %v\nwant: any session", s) 95 } 96 if server.TestSpanner.TotalSessionsCreated() != 1 { 97 t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", server.TestSpanner.TotalSessionsCreated(), 1) 98 } 99 s.delete(context.Background()) 100 if server.TestSpanner.TotalSessionsDeleted() != 1 { 101 t.Fatalf("number of sessions deleted mismatch\ngot: %v\nwant: %v", server.TestSpanner.TotalSessionsDeleted(), 1) 102 } 103} 104 105func TestBatchCreateAndCloseSession(t *testing.T) { 106 t.Parallel() 107 108 numSessions := int32(100) 109 server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t) 110 defer serverTeardown() 111 for numChannels := 1; numChannels <= 32; numChannels *= 2 { 112 prevCreated := server.TestSpanner.TotalSessionsCreated() 113 prevDeleted := server.TestSpanner.TotalSessionsDeleted() 114 client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{ 115 NumChannels: numChannels, 116 SessionPoolConfig: SessionPoolConfig{ 117 MinOpened: 0, 118 MaxOpened: 400, 119 }}, opts...) 120 if err != nil { 121 t.Fatal(err) 122 } 123 consumer := newTestConsumer(numSessions) 124 client.sc.batchCreateSessions(numSessions, consumer) 125 <-consumer.receivedAll 126 if len(consumer.sessions) != int(numSessions) { 127 t.Fatalf("returned number of sessions mismatch\ngot: %v\nwant: %v", len(consumer.sessions), numSessions) 128 } 129 created := server.TestSpanner.TotalSessionsCreated() - prevCreated 130 if created != uint(numSessions) { 131 t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", created, numSessions) 132 } 133 // Check that all channels are used evenly. 134 channelCounts := make(map[*vkit.Client]int32) 135 for _, s := range consumer.sessions { 136 channelCounts[s.client]++ 137 } 138 if len(channelCounts) != numChannels { 139 t.Fatalf("number of channels used mismatch\ngot: %v\nwant: %v", len(channelCounts), numChannels) 140 } 141 for _, c := range channelCounts { 142 if c < numSessions/int32(numChannels) || c > numSessions/int32(numChannels)+(numSessions%int32(numChannels)) { 143 t.Fatalf("channel used an unexpected number of times\ngot: %v\nwant between %v and %v", c, numSessions/int32(numChannels), numSessions/int32(numChannels)+1) 144 } 145 } 146 // Delete the sessions. 147 for _, s := range consumer.sessions { 148 s.delete(context.Background()) 149 } 150 deleted := server.TestSpanner.TotalSessionsDeleted() - prevDeleted 151 if deleted != uint(numSessions) { 152 t.Fatalf("number of sessions deleted mismatch\ngot: %v\nwant %v", deleted, numSessions) 153 } 154 client.Close() 155 } 156} 157 158func TestBatchCreateSessionsWithExceptions(t *testing.T) { 159 t.Parallel() 160 161 numSessions := int32(100) 162 server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t) 163 defer serverTeardown() 164 165 // Run the test with everything between 1 and numChannels errors. 166 for numErrors := int32(1); numErrors <= numChannels; numErrors++ { 167 // Make sure that the error is not always the first call. 168 for firstErrorAt := numErrors - 1; firstErrorAt < numChannels-numErrors+1; firstErrorAt++ { 169 client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{ 170 NumChannels: numChannels, 171 SessionPoolConfig: SessionPoolConfig{ 172 MinOpened: 0, 173 MaxOpened: 400, 174 }}, opts...) 175 if err != nil { 176 t.Fatal(err) 177 } 178 // Register the errors on the server. 179 errors := make([]error, numErrors+firstErrorAt) 180 for i := firstErrorAt; i < numErrors+firstErrorAt; i++ { 181 errors[i] = status.Errorf(codes.FailedPrecondition, "session creation failed") 182 } 183 server.TestSpanner.PutExecutionTime(MethodBatchCreateSession, SimulatedExecutionTime{ 184 Errors: errors, 185 }) 186 consumer := newTestConsumer(numSessions) 187 client.sc.batchCreateSessions(numSessions, consumer) 188 <-consumer.receivedAll 189 190 sessionsReturned := int32(len(consumer.sessions)) 191 if int32(len(consumer.errors)) != numErrors { 192 t.Fatalf("Error count mismatch\nGot: %d\nWant: %d", len(consumer.errors), numErrors) 193 } 194 for _, e := range consumer.errors { 195 if g, w := status.Code(e.err), codes.FailedPrecondition; g != w { 196 t.Fatalf("error code mismatch\ngot: %v\nwant: %v", g, w) 197 } 198 } 199 maxExpectedSessions := numSessions - numErrors*(numSessions/numChannels) 200 minExpectedSessions := numSessions - numErrors*(numSessions/numChannels+1) 201 if sessionsReturned < minExpectedSessions || sessionsReturned > maxExpectedSessions { 202 t.Fatalf("session count mismatch\ngot: %v\nwant between %v and %v", sessionsReturned, minExpectedSessions, maxExpectedSessions) 203 } 204 client.Close() 205 } 206 } 207} 208 209func TestBatchCreateSessions_ServerReturnsLessThanRequestedSessions(t *testing.T) { 210 t.Parallel() 211 212 numChannels := 4 213 server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ 214 NumChannels: numChannels, 215 SessionPoolConfig: SessionPoolConfig{ 216 MinOpened: 0, 217 MaxOpened: 100, 218 }, 219 }) 220 defer teardown() 221 // Ensure that the server will never return more than 10 sessions per batch 222 // create request. 223 server.TestSpanner.SetMaxSessionsReturnedByServerPerBatchRequest(10) 224 numSessions := int32(100) 225 // Request a batch of sessions that is larger than will be returned by the 226 // server in one request. The server will return at most 10 sessions per 227 // request. The sessionCreator will spread these requests over the 4 228 // channels that are available, i.e. do requests for 25 sessions in each 229 // request. The batch should still return 100 sessions. 230 consumer := newTestConsumer(numSessions) 231 client.sc.batchCreateSessions(numSessions, consumer) 232 <-consumer.receivedAll 233 if len(consumer.errors) > 0 { 234 t.Fatalf("Error count mismatch\nGot: %d\nWant: %d", len(consumer.errors), 0) 235 } 236 returnedSessionCount := int32(len(consumer.sessions)) 237 if returnedSessionCount != numSessions { 238 t.Fatalf("Returned sessions mismatch\nGot: %v\nWant: %v", returnedSessionCount, numSessions) 239 } 240} 241 242func TestBatchCreateSessions_ServerExhausted(t *testing.T) { 243 t.Parallel() 244 245 numChannels := 4 246 server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ 247 NumChannels: numChannels, 248 SessionPoolConfig: SessionPoolConfig{ 249 MinOpened: 0, 250 MaxOpened: 100, 251 }, 252 }) 253 defer teardown() 254 numSessions := int32(100) 255 maxSessions := int32(50) 256 // Ensure that the server will never return more than 50 sessions in total. 257 server.TestSpanner.SetMaxSessionsReturnedByServerInTotal(maxSessions) 258 consumer := newTestConsumer(numSessions) 259 client.sc.batchCreateSessions(numSessions, consumer) 260 <-consumer.receivedAll 261 // Session creation should end with at least one RESOURCE_EXHAUSTED error. 262 if len(consumer.errors) == 0 { 263 t.Fatalf("Error count mismatch\nGot: %d\nWant: > %d", len(consumer.errors), 0) 264 } 265 for _, e := range consumer.errors { 266 if g, w := status.Code(e.err), codes.ResourceExhausted; g != w { 267 t.Fatalf("Error code mismath\nGot: %v\nWant: %v", g, w) 268 } 269 } 270 // The number of returned sessions should be equal to the max of the 271 // server. 272 returnedSessionCount := int32(len(consumer.sessions)) 273 if returnedSessionCount != maxSessions { 274 t.Fatalf("Returned sessions mismatch\nGot: %v\nWant: %v", returnedSessionCount, maxSessions) 275 } 276 if consumer.numErr != (numSessions - maxSessions) { 277 t.Fatalf("Num errored sessions mismatch\nGot: %v\nWant: %v", consumer.numErr, numSessions-maxSessions) 278 } 279} 280 281func TestBatchCreateSessions_WithTimeout(t *testing.T) { 282 t.Parallel() 283 284 numSessions := int32(100) 285 server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t) 286 defer serverTeardown() 287 server.TestSpanner.PutExecutionTime(MethodBatchCreateSession, SimulatedExecutionTime{ 288 MinimumExecutionTime: time.Second, 289 }) 290 client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{ 291 SessionPoolConfig: SessionPoolConfig{ 292 MinOpened: 0, 293 MaxOpened: 400, 294 }}, opts...) 295 if err != nil { 296 t.Fatal(err) 297 } 298 299 client.sc.batchTimeout = 10 * time.Millisecond 300 consumer := newTestConsumer(numSessions) 301 client.sc.batchCreateSessions(numSessions, consumer) 302 <-consumer.receivedAll 303 if len(consumer.sessions) > 0 { 304 t.Fatalf("Returned number of sessions mismatch\ngot: %v\nwant: %v", len(consumer.sessions), 0) 305 } 306 if len(consumer.errors) != numChannels { 307 t.Fatalf("Returned number of errors mismatch\ngot: %v\nwant: %v", len(consumer.errors), numChannels) 308 } 309 for _, e := range consumer.errors { 310 if g, w := status.Code(e.err), codes.DeadlineExceeded; g != w { 311 t.Fatalf("Error code mismatch\ngot: %v (%s)\nwant: %v", g, e.err, w) 312 } 313 } 314 client.Close() 315} 316