1/* 2Copyright 2019 Google LLC 3 4Licensed under the Apache License, Version 2.0 (the "License"); 5you may not use this file except in compliance with the License. 6You may obtain a copy of the License at 7 8 http://www.apache.org/licenses/LICENSE-2.0 9 10Unless required by applicable law or agreed to in writing, software 11distributed under the License is distributed on an "AS IS" BASIS, 12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13See the License for the specific language governing permissions and 14limitations under the License. 15*/ 16 17package spanner 18 19import ( 20 "context" 21 "sync" 22 "testing" 23 "time" 24 25 vkit "cloud.google.com/go/spanner/apiv1" 26 . "cloud.google.com/go/spanner/internal/testutil" 27 "google.golang.org/grpc" 28 "google.golang.org/grpc/codes" 29 "google.golang.org/grpc/status" 30) 31 32type testSessionCreateError struct { 33 err error 34 num int32 35} 36 37type testConsumer struct { 38 numExpected int32 39 40 mu sync.Mutex 41 sessions []*session 42 errors []*testSessionCreateError 43 numErr int32 44 45 receivedAll chan struct{} 46} 47 48func (tc *testConsumer) sessionReady(s *session) { 49 tc.mu.Lock() 50 defer tc.mu.Unlock() 51 tc.sessions = append(tc.sessions, s) 52 tc.checkReceivedAll() 53} 54 55func (tc *testConsumer) sessionCreationFailed(err error, num int32) { 56 tc.mu.Lock() 57 defer tc.mu.Unlock() 58 tc.errors = append(tc.errors, &testSessionCreateError{ 59 err: err, 60 num: num, 61 }) 62 tc.numErr += num 63 tc.checkReceivedAll() 64} 65 66func (tc *testConsumer) checkReceivedAll() { 67 if int32(len(tc.sessions))+tc.numErr == tc.numExpected { 68 close(tc.receivedAll) 69 } 70} 71 72func newTestConsumer(numExpected int32) *testConsumer { 73 return &testConsumer{ 74 numExpected: numExpected, 75 receivedAll: make(chan struct{}), 76 } 77} 78 79func TestNextClient(t *testing.T) { 80 t.Parallel() 81 82 n := 4 83 _, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ 84 NumChannels: n, 85 SessionPoolConfig: SessionPoolConfig{ 86 MinOpened: 0, 87 MaxOpened: 100, 88 }, 89 }) 90 defer teardown() 91 sc := client.idleSessions.sc 92 connections := make(map[*grpc.ClientConn]int) 93 for i := 0; i < n; i++ { 94 client, err := sc.nextClient() 95 if err != nil { 96 t.Fatalf("Error getting a gapic client from the session client\nGot: %v", err) 97 } 98 conn1 := client.Connection() 99 conn2 := client.Connection() 100 if conn1 != conn2 { 101 t.Fatalf("Client connection mismatch. Expected to get two equal connections.\nGot: %v and %v", conn1, conn2) 102 } 103 if index, ok := connections[conn1]; ok { 104 t.Fatalf("Same connection used multiple times for different clients.\nClient 1: %v\nClient 2: %v", index, i) 105 } 106 connections[conn1] = i 107 } 108 // Pass through all the clients once more. This time the exact same 109 // connections should be found. 110 for i := 0; i < n; i++ { 111 client, err := sc.nextClient() 112 if err != nil { 113 t.Fatalf("Error getting a gapic client from the session client\nGot: %v", err) 114 } 115 conn := client.Connection() 116 if _, ok := connections[conn]; !ok { 117 t.Fatalf("Connection not found for index %v", i) 118 } 119 } 120} 121 122func TestCreateAndCloseSession(t *testing.T) { 123 t.Parallel() 124 125 server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ 126 SessionPoolConfig: SessionPoolConfig{ 127 MinOpened: 0, 128 MaxOpened: 100, 129 }, 130 }) 131 defer teardown() 132 133 s, err := client.sc.createSession(context.Background()) 134 if err != nil { 135 t.Fatalf("batch.next() return error mismatch\ngot: %v\nwant: nil", err) 136 } 137 if s == nil { 138 t.Fatalf("batch.next() return value mismatch\ngot: %v\nwant: any session", s) 139 } 140 if server.TestSpanner.TotalSessionsCreated() != 1 { 141 t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", server.TestSpanner.TotalSessionsCreated(), 1) 142 } 143 s.delete(context.Background()) 144 if server.TestSpanner.TotalSessionsDeleted() != 1 { 145 t.Fatalf("number of sessions deleted mismatch\ngot: %v\nwant: %v", server.TestSpanner.TotalSessionsDeleted(), 1) 146 } 147} 148 149func TestBatchCreateAndCloseSession(t *testing.T) { 150 t.Parallel() 151 152 numSessions := int32(100) 153 server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t) 154 defer serverTeardown() 155 for numChannels := 1; numChannels <= 32; numChannels *= 2 { 156 prevCreated := server.TestSpanner.TotalSessionsCreated() 157 prevDeleted := server.TestSpanner.TotalSessionsDeleted() 158 client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{ 159 NumChannels: numChannels, 160 SessionPoolConfig: SessionPoolConfig{ 161 MinOpened: 0, 162 MaxOpened: 400, 163 }}, opts...) 164 if err != nil { 165 t.Fatal(err) 166 } 167 consumer := newTestConsumer(numSessions) 168 client.sc.batchCreateSessions(numSessions, true, consumer) 169 <-consumer.receivedAll 170 if len(consumer.sessions) != int(numSessions) { 171 t.Fatalf("returned number of sessions mismatch\ngot: %v\nwant: %v", len(consumer.sessions), numSessions) 172 } 173 created := server.TestSpanner.TotalSessionsCreated() - prevCreated 174 if created != uint(numSessions) { 175 t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", created, numSessions) 176 } 177 // Check that all channels are used evenly. 178 channelCounts := make(map[*vkit.Client]int32) 179 for _, s := range consumer.sessions { 180 channelCounts[s.client]++ 181 } 182 if len(channelCounts) != numChannels { 183 t.Fatalf("number of channels used mismatch\ngot: %v\nwant: %v", len(channelCounts), numChannels) 184 } 185 for _, c := range channelCounts { 186 if c < numSessions/int32(numChannels) || c > numSessions/int32(numChannels)+(numSessions%int32(numChannels)) { 187 t.Fatalf("channel used an unexpected number of times\ngot: %v\nwant between %v and %v", c, numSessions/int32(numChannels), numSessions/int32(numChannels)+1) 188 } 189 } 190 // Delete the sessions. 191 for _, s := range consumer.sessions { 192 s.delete(context.Background()) 193 } 194 deleted := server.TestSpanner.TotalSessionsDeleted() - prevDeleted 195 if deleted != uint(numSessions) { 196 t.Fatalf("number of sessions deleted mismatch\ngot: %v\nwant %v", deleted, numSessions) 197 } 198 client.Close() 199 } 200} 201 202func TestBatchCreateSessionsWithExceptions(t *testing.T) { 203 t.Parallel() 204 205 numSessions := int32(100) 206 server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t) 207 defer serverTeardown() 208 209 // Run the test with everything between 1 and numChannels errors. 210 for numErrors := int32(1); numErrors <= numChannels; numErrors++ { 211 // Make sure that the error is not always the first call. 212 for firstErrorAt := numErrors - 1; firstErrorAt < numChannels-numErrors+1; firstErrorAt++ { 213 client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{ 214 NumChannels: numChannels, 215 SessionPoolConfig: SessionPoolConfig{ 216 MinOpened: 0, 217 MaxOpened: 400, 218 }}, opts...) 219 if err != nil { 220 t.Fatal(err) 221 } 222 // Register the errors on the server. 223 errors := make([]error, numErrors+firstErrorAt) 224 for i := firstErrorAt; i < numErrors+firstErrorAt; i++ { 225 errors[i] = status.Errorf(codes.FailedPrecondition, "session creation failed") 226 } 227 server.TestSpanner.PutExecutionTime(MethodBatchCreateSession, SimulatedExecutionTime{ 228 Errors: errors, 229 }) 230 consumer := newTestConsumer(numSessions) 231 client.sc.batchCreateSessions(numSessions, true, consumer) 232 <-consumer.receivedAll 233 234 sessionsReturned := int32(len(consumer.sessions)) 235 if int32(len(consumer.errors)) != numErrors { 236 t.Fatalf("Error count mismatch\nGot: %d\nWant: %d", len(consumer.errors), numErrors) 237 } 238 for _, e := range consumer.errors { 239 if g, w := status.Code(e.err), codes.FailedPrecondition; g != w { 240 t.Fatalf("error code mismatch\ngot: %v\nwant: %v", g, w) 241 } 242 } 243 maxExpectedSessions := numSessions - numErrors*(numSessions/numChannels) 244 minExpectedSessions := numSessions - numErrors*(numSessions/numChannels+1) 245 if sessionsReturned < minExpectedSessions || sessionsReturned > maxExpectedSessions { 246 t.Fatalf("session count mismatch\ngot: %v\nwant between %v and %v", sessionsReturned, minExpectedSessions, maxExpectedSessions) 247 } 248 client.Close() 249 } 250 } 251} 252 253func TestBatchCreateSessions_ServerReturnsLessThanRequestedSessions(t *testing.T) { 254 t.Parallel() 255 256 numChannels := 4 257 server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ 258 NumChannels: numChannels, 259 SessionPoolConfig: SessionPoolConfig{ 260 MinOpened: 0, 261 MaxOpened: 100, 262 }, 263 }) 264 defer teardown() 265 // Ensure that the server will never return more than 10 sessions per batch 266 // create request. 267 server.TestSpanner.SetMaxSessionsReturnedByServerPerBatchRequest(10) 268 numSessions := int32(100) 269 // Request a batch of sessions that is larger than will be returned by the 270 // server in one request. The server will return at most 10 sessions per 271 // request. The sessionCreator will spread these requests over the 4 272 // channels that are available, i.e. do requests for 25 sessions in each 273 // request. The batch should still return 100 sessions. 274 consumer := newTestConsumer(numSessions) 275 client.sc.batchCreateSessions(numSessions, true, consumer) 276 <-consumer.receivedAll 277 if len(consumer.errors) > 0 { 278 t.Fatalf("Error count mismatch\nGot: %d\nWant: %d", len(consumer.errors), 0) 279 } 280 returnedSessionCount := int32(len(consumer.sessions)) 281 if returnedSessionCount != numSessions { 282 t.Fatalf("Returned sessions mismatch\nGot: %v\nWant: %v", returnedSessionCount, numSessions) 283 } 284} 285 286func TestBatchCreateSessions_ServerExhausted(t *testing.T) { 287 t.Parallel() 288 289 numChannels := 4 290 server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{ 291 NumChannels: numChannels, 292 SessionPoolConfig: SessionPoolConfig{ 293 MinOpened: 0, 294 MaxOpened: 100, 295 }, 296 }) 297 defer teardown() 298 numSessions := int32(100) 299 maxSessions := int32(50) 300 // Ensure that the server will never return more than 50 sessions in total. 301 server.TestSpanner.SetMaxSessionsReturnedByServerInTotal(maxSessions) 302 consumer := newTestConsumer(numSessions) 303 client.sc.batchCreateSessions(numSessions, true, consumer) 304 <-consumer.receivedAll 305 // Session creation should end with at least one RESOURCE_EXHAUSTED error. 306 if len(consumer.errors) == 0 { 307 t.Fatalf("Error count mismatch\nGot: %d\nWant: > %d", len(consumer.errors), 0) 308 } 309 for _, e := range consumer.errors { 310 if g, w := status.Code(e.err), codes.ResourceExhausted; g != w { 311 t.Fatalf("Error code mismath\nGot: %v\nWant: %v", g, w) 312 } 313 } 314 // The number of returned sessions should be equal to the max of the 315 // server. 316 returnedSessionCount := int32(len(consumer.sessions)) 317 if returnedSessionCount != maxSessions { 318 t.Fatalf("Returned sessions mismatch\nGot: %v\nWant: %v", returnedSessionCount, maxSessions) 319 } 320 if consumer.numErr != (numSessions - maxSessions) { 321 t.Fatalf("Num errored sessions mismatch\nGot: %v\nWant: %v", consumer.numErr, numSessions-maxSessions) 322 } 323} 324 325func TestBatchCreateSessions_WithTimeout(t *testing.T) { 326 t.Parallel() 327 328 numSessions := int32(100) 329 server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t) 330 defer serverTeardown() 331 server.TestSpanner.PutExecutionTime(MethodBatchCreateSession, SimulatedExecutionTime{ 332 MinimumExecutionTime: time.Second, 333 }) 334 client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{ 335 SessionPoolConfig: SessionPoolConfig{ 336 MinOpened: 0, 337 MaxOpened: 400, 338 }}, opts...) 339 if err != nil { 340 t.Fatal(err) 341 } 342 343 client.sc.batchTimeout = 10 * time.Millisecond 344 consumer := newTestConsumer(numSessions) 345 client.sc.batchCreateSessions(numSessions, true, consumer) 346 <-consumer.receivedAll 347 if len(consumer.sessions) > 0 { 348 t.Fatalf("Returned number of sessions mismatch\ngot: %v\nwant: %v", len(consumer.sessions), 0) 349 } 350 if len(consumer.errors) != numChannels { 351 t.Fatalf("Returned number of errors mismatch\ngot: %v\nwant: %v", len(consumer.errors), numChannels) 352 } 353 for _, e := range consumer.errors { 354 if g, w := status.Code(e.err), codes.DeadlineExceeded; g != w { 355 t.Fatalf("Error code mismatch\ngot: %v (%s)\nwant: %v", g, e.err, w) 356 } 357 } 358 client.Close() 359} 360 361func TestClientIDGenerator(t *testing.T) { 362 cidGen = newClientIDGenerator() 363 for _, tt := range []struct { 364 database string 365 clientID string 366 }{ 367 {"db", "client-1"}, 368 {"db-new", "client-1"}, 369 {"db", "client-2"}, 370 } { 371 if got, want := cidGen.nextID(tt.database), tt.clientID; got != want { 372 t.Fatalf("Generate wrong client ID: got %v, want %v", got, want) 373 } 374 } 375} 376