1/*
2Copyright 2019 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spanner
18
19import (
20	"context"
21	"sync"
22	"testing"
23	"time"
24
25	vkit "cloud.google.com/go/spanner/apiv1"
26	. "cloud.google.com/go/spanner/internal/testutil"
27	"google.golang.org/grpc"
28	"google.golang.org/grpc/codes"
29	"google.golang.org/grpc/status"
30)
31
// testSessionCreateError records one failure callback received by a
// testConsumer through sessionCreationFailed.
type testSessionCreateError struct {
	// err is the error that was reported for the failed creation.
	err error
	// num is the session count reported together with the failure; it is
	// added to testConsumer.numErr.
	num int32
}
36
// testConsumer collects the results of a batch session creation in tests.
// It closes receivedAll as soon as the number of created sessions plus the
// number of sessions reported as failed equals numExpected.
type testConsumer struct {
	// numExpected is the number of results (successes plus failures) to
	// wait for before receivedAll is closed.
	numExpected int32

	// mu guards sessions, errors and numErr.
	mu       sync.Mutex
	sessions []*session
	errors   []*testSessionCreateError
	numErr   int32

	// receivedAll is closed once all expected results have been recorded.
	receivedAll chan struct{}
}
47
48func (tc *testConsumer) sessionReady(s *session) {
49	tc.mu.Lock()
50	defer tc.mu.Unlock()
51	tc.sessions = append(tc.sessions, s)
52	tc.checkReceivedAll()
53}
54
55func (tc *testConsumer) sessionCreationFailed(err error, num int32) {
56	tc.mu.Lock()
57	defer tc.mu.Unlock()
58	tc.errors = append(tc.errors, &testSessionCreateError{
59		err: err,
60		num: num,
61	})
62	tc.numErr += num
63	tc.checkReceivedAll()
64}
65
66func (tc *testConsumer) checkReceivedAll() {
67	if int32(len(tc.sessions))+tc.numErr == tc.numExpected {
68		close(tc.receivedAll)
69	}
70}
71
72func newTestConsumer(numExpected int32) *testConsumer {
73	return &testConsumer{
74		numExpected: numExpected,
75		receivedAll: make(chan struct{}),
76	}
77}
78
79func TestNextClient(t *testing.T) {
80	t.Parallel()
81
82	n := 4
83	_, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
84		NumChannels: n,
85		SessionPoolConfig: SessionPoolConfig{
86			MinOpened: 0,
87			MaxOpened: 100,
88		},
89	})
90	defer teardown()
91	sc := client.idleSessions.sc
92	connections := make(map[*grpc.ClientConn]int)
93	for i := 0; i < n; i++ {
94		client, err := sc.nextClient()
95		if err != nil {
96			t.Fatalf("Error getting a gapic client from the session client\nGot: %v", err)
97		}
98		conn1 := client.Connection()
99		conn2 := client.Connection()
100		if conn1 != conn2 {
101			t.Fatalf("Client connection mismatch. Expected to get two equal connections.\nGot: %v and %v", conn1, conn2)
102		}
103		if index, ok := connections[conn1]; ok {
104			t.Fatalf("Same connection used multiple times for different clients.\nClient 1: %v\nClient 2: %v", index, i)
105		}
106		connections[conn1] = i
107	}
108	// Pass through all the clients once more. This time the exact same
109	// connections should be found.
110	for i := 0; i < n; i++ {
111		client, err := sc.nextClient()
112		if err != nil {
113			t.Fatalf("Error getting a gapic client from the session client\nGot: %v", err)
114		}
115		conn := client.Connection()
116		if _, ok := connections[conn]; !ok {
117			t.Fatalf("Connection not found for index %v", i)
118		}
119	}
120}
121
122func TestCreateAndCloseSession(t *testing.T) {
123	t.Parallel()
124
125	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
126		SessionPoolConfig: SessionPoolConfig{
127			MinOpened: 0,
128			MaxOpened: 100,
129		},
130	})
131	defer teardown()
132
133	s, err := client.sc.createSession(context.Background())
134	if err != nil {
135		t.Fatalf("batch.next() return error mismatch\ngot: %v\nwant: nil", err)
136	}
137	if s == nil {
138		t.Fatalf("batch.next() return value mismatch\ngot: %v\nwant: any session", s)
139	}
140	if server.TestSpanner.TotalSessionsCreated() != 1 {
141		t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", server.TestSpanner.TotalSessionsCreated(), 1)
142	}
143	s.delete(context.Background())
144	if server.TestSpanner.TotalSessionsDeleted() != 1 {
145		t.Fatalf("number of sessions deleted mismatch\ngot: %v\nwant: %v", server.TestSpanner.TotalSessionsDeleted(), 1)
146	}
147}
148
149func TestBatchCreateAndCloseSession(t *testing.T) {
150	t.Parallel()
151
152	numSessions := int32(100)
153	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
154	defer serverTeardown()
155	for numChannels := 1; numChannels <= 32; numChannels *= 2 {
156		prevCreated := server.TestSpanner.TotalSessionsCreated()
157		prevDeleted := server.TestSpanner.TotalSessionsDeleted()
158		client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{
159			NumChannels: numChannels,
160			SessionPoolConfig: SessionPoolConfig{
161				MinOpened: 0,
162				MaxOpened: 400,
163			}}, opts...)
164		if err != nil {
165			t.Fatal(err)
166		}
167		consumer := newTestConsumer(numSessions)
168		client.sc.batchCreateSessions(numSessions, consumer)
169		<-consumer.receivedAll
170		if len(consumer.sessions) != int(numSessions) {
171			t.Fatalf("returned number of sessions mismatch\ngot: %v\nwant: %v", len(consumer.sessions), numSessions)
172		}
173		created := server.TestSpanner.TotalSessionsCreated() - prevCreated
174		if created != uint(numSessions) {
175			t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", created, numSessions)
176		}
177		// Check that all channels are used evenly.
178		channelCounts := make(map[*vkit.Client]int32)
179		for _, s := range consumer.sessions {
180			channelCounts[s.client]++
181		}
182		if len(channelCounts) != numChannels {
183			t.Fatalf("number of channels used mismatch\ngot: %v\nwant: %v", len(channelCounts), numChannels)
184		}
185		for _, c := range channelCounts {
186			if c < numSessions/int32(numChannels) || c > numSessions/int32(numChannels)+(numSessions%int32(numChannels)) {
187				t.Fatalf("channel used an unexpected number of times\ngot: %v\nwant between %v and %v", c, numSessions/int32(numChannels), numSessions/int32(numChannels)+1)
188			}
189		}
190		// Delete the sessions.
191		for _, s := range consumer.sessions {
192			s.delete(context.Background())
193		}
194		deleted := server.TestSpanner.TotalSessionsDeleted() - prevDeleted
195		if deleted != uint(numSessions) {
196			t.Fatalf("number of sessions deleted mismatch\ngot: %v\nwant %v", deleted, numSessions)
197		}
198		client.Close()
199	}
200}
201
202func TestBatchCreateSessionsWithExceptions(t *testing.T) {
203	t.Parallel()
204
205	numSessions := int32(100)
206	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
207	defer serverTeardown()
208
209	// Run the test with everything between 1 and numChannels errors.
210	for numErrors := int32(1); numErrors <= numChannels; numErrors++ {
211		// Make sure that the error is not always the first call.
212		for firstErrorAt := numErrors - 1; firstErrorAt < numChannels-numErrors+1; firstErrorAt++ {
213			client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{
214				NumChannels: numChannels,
215				SessionPoolConfig: SessionPoolConfig{
216					MinOpened: 0,
217					MaxOpened: 400,
218				}}, opts...)
219			if err != nil {
220				t.Fatal(err)
221			}
222			// Register the errors on the server.
223			errors := make([]error, numErrors+firstErrorAt)
224			for i := firstErrorAt; i < numErrors+firstErrorAt; i++ {
225				errors[i] = status.Errorf(codes.FailedPrecondition, "session creation failed")
226			}
227			server.TestSpanner.PutExecutionTime(MethodBatchCreateSession, SimulatedExecutionTime{
228				Errors: errors,
229			})
230			consumer := newTestConsumer(numSessions)
231			client.sc.batchCreateSessions(numSessions, consumer)
232			<-consumer.receivedAll
233
234			sessionsReturned := int32(len(consumer.sessions))
235			if int32(len(consumer.errors)) != numErrors {
236				t.Fatalf("Error count mismatch\nGot: %d\nWant: %d", len(consumer.errors), numErrors)
237			}
238			for _, e := range consumer.errors {
239				if g, w := status.Code(e.err), codes.FailedPrecondition; g != w {
240					t.Fatalf("error code mismatch\ngot: %v\nwant: %v", g, w)
241				}
242			}
243			maxExpectedSessions := numSessions - numErrors*(numSessions/numChannels)
244			minExpectedSessions := numSessions - numErrors*(numSessions/numChannels+1)
245			if sessionsReturned < minExpectedSessions || sessionsReturned > maxExpectedSessions {
246				t.Fatalf("session count mismatch\ngot: %v\nwant between %v and %v", sessionsReturned, minExpectedSessions, maxExpectedSessions)
247			}
248			client.Close()
249		}
250	}
251}
252
253func TestBatchCreateSessions_ServerReturnsLessThanRequestedSessions(t *testing.T) {
254	t.Parallel()
255
256	numChannels := 4
257	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
258		NumChannels: numChannels,
259		SessionPoolConfig: SessionPoolConfig{
260			MinOpened: 0,
261			MaxOpened: 100,
262		},
263	})
264	defer teardown()
265	// Ensure that the server will never return more than 10 sessions per batch
266	// create request.
267	server.TestSpanner.SetMaxSessionsReturnedByServerPerBatchRequest(10)
268	numSessions := int32(100)
269	// Request a batch of sessions that is larger than will be returned by the
270	// server in one request. The server will return at most 10 sessions per
271	// request. The sessionCreator will spread these requests over the 4
272	// channels that are available, i.e. do requests for 25 sessions in each
273	// request. The batch should still return 100 sessions.
274	consumer := newTestConsumer(numSessions)
275	client.sc.batchCreateSessions(numSessions, consumer)
276	<-consumer.receivedAll
277	if len(consumer.errors) > 0 {
278		t.Fatalf("Error count mismatch\nGot: %d\nWant: %d", len(consumer.errors), 0)
279	}
280	returnedSessionCount := int32(len(consumer.sessions))
281	if returnedSessionCount != numSessions {
282		t.Fatalf("Returned sessions mismatch\nGot: %v\nWant: %v", returnedSessionCount, numSessions)
283	}
284}
285
286func TestBatchCreateSessions_ServerExhausted(t *testing.T) {
287	t.Parallel()
288
289	numChannels := 4
290	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
291		NumChannels: numChannels,
292		SessionPoolConfig: SessionPoolConfig{
293			MinOpened: 0,
294			MaxOpened: 100,
295		},
296	})
297	defer teardown()
298	numSessions := int32(100)
299	maxSessions := int32(50)
300	// Ensure that the server will never return more than 50 sessions in total.
301	server.TestSpanner.SetMaxSessionsReturnedByServerInTotal(maxSessions)
302	consumer := newTestConsumer(numSessions)
303	client.sc.batchCreateSessions(numSessions, consumer)
304	<-consumer.receivedAll
305	// Session creation should end with at least one RESOURCE_EXHAUSTED error.
306	if len(consumer.errors) == 0 {
307		t.Fatalf("Error count mismatch\nGot: %d\nWant: > %d", len(consumer.errors), 0)
308	}
309	for _, e := range consumer.errors {
310		if g, w := status.Code(e.err), codes.ResourceExhausted; g != w {
311			t.Fatalf("Error code mismath\nGot: %v\nWant: %v", g, w)
312		}
313	}
314	// The number of returned sessions should be equal to the max of the
315	// server.
316	returnedSessionCount := int32(len(consumer.sessions))
317	if returnedSessionCount != maxSessions {
318		t.Fatalf("Returned sessions mismatch\nGot: %v\nWant: %v", returnedSessionCount, maxSessions)
319	}
320	if consumer.numErr != (numSessions - maxSessions) {
321		t.Fatalf("Num errored sessions mismatch\nGot: %v\nWant: %v", consumer.numErr, numSessions-maxSessions)
322	}
323}
324
325func TestBatchCreateSessions_WithTimeout(t *testing.T) {
326	t.Parallel()
327
328	numSessions := int32(100)
329	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
330	defer serverTeardown()
331	server.TestSpanner.PutExecutionTime(MethodBatchCreateSession, SimulatedExecutionTime{
332		MinimumExecutionTime: time.Second,
333	})
334	client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{
335		SessionPoolConfig: SessionPoolConfig{
336			MinOpened: 0,
337			MaxOpened: 400,
338		}}, opts...)
339	if err != nil {
340		t.Fatal(err)
341	}
342
343	client.sc.batchTimeout = 10 * time.Millisecond
344	consumer := newTestConsumer(numSessions)
345	client.sc.batchCreateSessions(numSessions, consumer)
346	<-consumer.receivedAll
347	if len(consumer.sessions) > 0 {
348		t.Fatalf("Returned number of sessions mismatch\ngot: %v\nwant: %v", len(consumer.sessions), 0)
349	}
350	if len(consumer.errors) != numChannels {
351		t.Fatalf("Returned number of errors mismatch\ngot: %v\nwant: %v", len(consumer.errors), numChannels)
352	}
353	for _, e := range consumer.errors {
354		if g, w := status.Code(e.err), codes.DeadlineExceeded; g != w {
355			t.Fatalf("Error code mismatch\ngot: %v (%s)\nwant: %v", g, e.err, w)
356		}
357	}
358	client.Close()
359}
360
361func TestClientIDGenerator(t *testing.T) {
362	cidGen = newClientIDGenerator()
363	for _, tt := range []struct {
364		database string
365		clientID string
366	}{
367		{"db", "client-1"},
368		{"db-new", "client-1"},
369		{"db", "client-2"},
370	} {
371		if got, want := cidGen.nextID(tt.database), tt.clientID; got != want {
372			t.Fatalf("Generate wrong client ID: got %v, want %v", got, want)
373		}
374	}
375}
376