1/*
2Copyright 2019 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spanner
18
19import (
20	"context"
21	"fmt"
22	"sync"
23	"testing"
24	"time"
25
26	vkit "cloud.google.com/go/spanner/apiv1"
27	. "cloud.google.com/go/spanner/internal/testutil"
28	gax "github.com/googleapis/gax-go/v2"
29	"google.golang.org/grpc"
30	"google.golang.org/grpc/codes"
31	"google.golang.org/grpc/status"
32)
33
34type testSessionCreateError struct {
35	err error
36	num int32
37}
38
39type testConsumer struct {
40	numExpected int32
41
42	mu       sync.Mutex
43	sessions []*session
44	errors   []*testSessionCreateError
45	numErr   int32
46
47	receivedAll chan struct{}
48}
49
50func (tc *testConsumer) sessionReady(s *session) {
51	tc.mu.Lock()
52	defer tc.mu.Unlock()
53	tc.sessions = append(tc.sessions, s)
54	tc.checkReceivedAll()
55}
56
57func (tc *testConsumer) sessionCreationFailed(err error, num int32) {
58	tc.mu.Lock()
59	defer tc.mu.Unlock()
60	tc.errors = append(tc.errors, &testSessionCreateError{
61		err: err,
62		num: num,
63	})
64	tc.numErr += num
65	tc.checkReceivedAll()
66}
67
68func (tc *testConsumer) checkReceivedAll() {
69	if int32(len(tc.sessions))+tc.numErr == tc.numExpected {
70		close(tc.receivedAll)
71	}
72}
73
74func newTestConsumer(numExpected int32) *testConsumer {
75	return &testConsumer{
76		numExpected: numExpected,
77		receivedAll: make(chan struct{}),
78	}
79}
80
81func TestNextClient(t *testing.T) {
82	t.Parallel()
83
84	n := 4
85	_, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
86		NumChannels: n,
87		SessionPoolConfig: SessionPoolConfig{
88			MinOpened: 0,
89			MaxOpened: 100,
90		},
91	})
92	defer teardown()
93	sc := client.idleSessions.sc
94	connections := make(map[*grpc.ClientConn]int)
95	for i := 0; i < n; i++ {
96		client, err := sc.nextClient()
97		if err != nil {
98			t.Fatalf("Error getting a gapic client from the session client\nGot: %v", err)
99		}
100		conn1 := client.Connection()
101		conn2 := client.Connection()
102		if conn1 != conn2 {
103			t.Fatalf("Client connection mismatch. Expected to get two equal connections.\nGot: %v and %v", conn1, conn2)
104		}
105		if index, ok := connections[conn1]; ok {
106			t.Fatalf("Same connection used multiple times for different clients.\nClient 1: %v\nClient 2: %v", index, i)
107		}
108		connections[conn1] = i
109	}
110	// Pass through all the clients once more. This time the exact same
111	// connections should be found.
112	for i := 0; i < n; i++ {
113		client, err := sc.nextClient()
114		if err != nil {
115			t.Fatalf("Error getting a gapic client from the session client\nGot: %v", err)
116		}
117		conn := client.Connection()
118		if _, ok := connections[conn]; !ok {
119			t.Fatalf("Connection not found for index %v", i)
120		}
121	}
122}
123
124func TestCreateAndCloseSession(t *testing.T) {
125	t.Parallel()
126
127	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
128		SessionPoolConfig: SessionPoolConfig{
129			MinOpened: 0,
130			MaxOpened: 100,
131		},
132	})
133	defer teardown()
134
135	s, err := client.sc.createSession(context.Background())
136	if err != nil {
137		t.Fatalf("batch.next() return error mismatch\ngot: %v\nwant: nil", err)
138	}
139	if s == nil {
140		t.Fatalf("batch.next() return value mismatch\ngot: %v\nwant: any session", s)
141	}
142	if server.TestSpanner.TotalSessionsCreated() != 1 {
143		t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", server.TestSpanner.TotalSessionsCreated(), 1)
144	}
145	s.delete(context.Background())
146	if server.TestSpanner.TotalSessionsDeleted() != 1 {
147		t.Fatalf("number of sessions deleted mismatch\ngot: %v\nwant: %v", server.TestSpanner.TotalSessionsDeleted(), 1)
148	}
149}
150
151func TestBatchCreateAndCloseSession(t *testing.T) {
152	t.Parallel()
153
154	numSessions := int32(100)
155	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
156	defer serverTeardown()
157	for numChannels := 1; numChannels <= 32; numChannels *= 2 {
158		prevCreated := server.TestSpanner.TotalSessionsCreated()
159		prevDeleted := server.TestSpanner.TotalSessionsDeleted()
160		client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{
161			NumChannels: numChannels,
162			SessionPoolConfig: SessionPoolConfig{
163				MinOpened: 0,
164				MaxOpened: 400,
165			}}, opts...)
166		if err != nil {
167			t.Fatal(err)
168		}
169		consumer := newTestConsumer(numSessions)
170		client.sc.batchCreateSessions(numSessions, true, consumer)
171		<-consumer.receivedAll
172		if len(consumer.sessions) != int(numSessions) {
173			t.Fatalf("returned number of sessions mismatch\ngot: %v\nwant: %v", len(consumer.sessions), numSessions)
174		}
175		created := server.TestSpanner.TotalSessionsCreated() - prevCreated
176		if created != uint(numSessions) {
177			t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", created, numSessions)
178		}
179		// Check that all channels are used evenly.
180		channelCounts := make(map[*vkit.Client]int32)
181		for _, s := range consumer.sessions {
182			channelCounts[s.client]++
183		}
184		if len(channelCounts) != numChannels {
185			t.Fatalf("number of channels used mismatch\ngot: %v\nwant: %v", len(channelCounts), numChannels)
186		}
187		for _, c := range channelCounts {
188			if c < numSessions/int32(numChannels) || c > numSessions/int32(numChannels)+(numSessions%int32(numChannels)) {
189				t.Fatalf("channel used an unexpected number of times\ngot: %v\nwant between %v and %v", c, numSessions/int32(numChannels), numSessions/int32(numChannels)+1)
190			}
191		}
192		// Delete the sessions.
193		for _, s := range consumer.sessions {
194			s.delete(context.Background())
195		}
196		deleted := server.TestSpanner.TotalSessionsDeleted() - prevDeleted
197		if deleted != uint(numSessions) {
198			t.Fatalf("number of sessions deleted mismatch\ngot: %v\nwant %v", deleted, numSessions)
199		}
200		client.Close()
201	}
202}
203
204func TestBatchCreateSessionsWithExceptions(t *testing.T) {
205	t.Parallel()
206
207	numSessions := int32(100)
208	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
209	defer serverTeardown()
210
211	// Run the test with everything between 1 and numChannels errors.
212	for numErrors := int32(1); numErrors <= numChannels; numErrors++ {
213		// Make sure that the error is not always the first call.
214		for firstErrorAt := numErrors - 1; firstErrorAt < numChannels-numErrors+1; firstErrorAt++ {
215			client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{
216				NumChannels: numChannels,
217				SessionPoolConfig: SessionPoolConfig{
218					MinOpened: 0,
219					MaxOpened: 400,
220				}}, opts...)
221			if err != nil {
222				t.Fatal(err)
223			}
224			// Register the errors on the server.
225			errors := make([]error, numErrors+firstErrorAt)
226			for i := firstErrorAt; i < numErrors+firstErrorAt; i++ {
227				errors[i] = status.Errorf(codes.FailedPrecondition, "session creation failed")
228			}
229			server.TestSpanner.PutExecutionTime(MethodBatchCreateSession, SimulatedExecutionTime{
230				Errors: errors,
231			})
232			consumer := newTestConsumer(numSessions)
233			client.sc.batchCreateSessions(numSessions, true, consumer)
234			<-consumer.receivedAll
235
236			sessionsReturned := int32(len(consumer.sessions))
237			if int32(len(consumer.errors)) != numErrors {
238				t.Fatalf("Error count mismatch\nGot: %d\nWant: %d", len(consumer.errors), numErrors)
239			}
240			for _, e := range consumer.errors {
241				if g, w := status.Code(e.err), codes.FailedPrecondition; g != w {
242					t.Fatalf("error code mismatch\ngot: %v\nwant: %v", g, w)
243				}
244			}
245			maxExpectedSessions := numSessions - numErrors*(numSessions/numChannels)
246			minExpectedSessions := numSessions - numErrors*(numSessions/numChannels+1)
247			if sessionsReturned < minExpectedSessions || sessionsReturned > maxExpectedSessions {
248				t.Fatalf("session count mismatch\ngot: %v\nwant between %v and %v", sessionsReturned, minExpectedSessions, maxExpectedSessions)
249			}
250			client.Close()
251		}
252	}
253}
254
255func TestBatchCreateSessions_ServerReturnsLessThanRequestedSessions(t *testing.T) {
256	t.Parallel()
257
258	numChannels := 4
259	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
260		NumChannels: numChannels,
261		SessionPoolConfig: SessionPoolConfig{
262			MinOpened: 0,
263			MaxOpened: 100,
264		},
265	})
266	defer teardown()
267	// Ensure that the server will never return more than 10 sessions per batch
268	// create request.
269	server.TestSpanner.SetMaxSessionsReturnedByServerPerBatchRequest(10)
270	numSessions := int32(100)
271	// Request a batch of sessions that is larger than will be returned by the
272	// server in one request. The server will return at most 10 sessions per
273	// request. The sessionCreator will spread these requests over the 4
274	// channels that are available, i.e. do requests for 25 sessions in each
275	// request. The batch should still return 100 sessions.
276	consumer := newTestConsumer(numSessions)
277	client.sc.batchCreateSessions(numSessions, true, consumer)
278	<-consumer.receivedAll
279	if len(consumer.errors) > 0 {
280		t.Fatalf("Error count mismatch\nGot: %d\nWant: %d", len(consumer.errors), 0)
281	}
282	returnedSessionCount := int32(len(consumer.sessions))
283	if returnedSessionCount != numSessions {
284		t.Fatalf("Returned sessions mismatch\nGot: %v\nWant: %v", returnedSessionCount, numSessions)
285	}
286}
287
288func TestBatchCreateSessions_ServerExhausted(t *testing.T) {
289	t.Parallel()
290
291	numChannels := 4
292	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
293		NumChannels: numChannels,
294		SessionPoolConfig: SessionPoolConfig{
295			MinOpened: 0,
296			MaxOpened: 100,
297		},
298	})
299	defer teardown()
300	numSessions := int32(100)
301	maxSessions := int32(50)
302	// Ensure that the server will never return more than 50 sessions in total.
303	server.TestSpanner.SetMaxSessionsReturnedByServerInTotal(maxSessions)
304	consumer := newTestConsumer(numSessions)
305	client.sc.batchCreateSessions(numSessions, true, consumer)
306	<-consumer.receivedAll
307	// Session creation should end with at least one RESOURCE_EXHAUSTED error.
308	if len(consumer.errors) == 0 {
309		t.Fatalf("Error count mismatch\nGot: %d\nWant: > %d", len(consumer.errors), 0)
310	}
311	for _, e := range consumer.errors {
312		if g, w := status.Code(e.err), codes.ResourceExhausted; g != w {
313			t.Fatalf("Error code mismath\nGot: %v\nWant: %v", g, w)
314		}
315	}
316	// The number of returned sessions should be equal to the max of the
317	// server.
318	returnedSessionCount := int32(len(consumer.sessions))
319	if returnedSessionCount != maxSessions {
320		t.Fatalf("Returned sessions mismatch\nGot: %v\nWant: %v", returnedSessionCount, maxSessions)
321	}
322	if consumer.numErr != (numSessions - maxSessions) {
323		t.Fatalf("Num errored sessions mismatch\nGot: %v\nWant: %v", consumer.numErr, numSessions-maxSessions)
324	}
325}
326
327func TestBatchCreateSessions_WithTimeout(t *testing.T) {
328	t.Parallel()
329
330	numSessions := int32(100)
331	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
332	defer serverTeardown()
333	server.TestSpanner.PutExecutionTime(MethodBatchCreateSession, SimulatedExecutionTime{
334		MinimumExecutionTime: time.Second,
335	})
336	client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{
337		SessionPoolConfig: SessionPoolConfig{
338			MinOpened: 0,
339			MaxOpened: 400,
340		}}, opts...)
341	if err != nil {
342		t.Fatal(err)
343	}
344
345	client.sc.batchTimeout = 10 * time.Millisecond
346	consumer := newTestConsumer(numSessions)
347	client.sc.batchCreateSessions(numSessions, true, consumer)
348	<-consumer.receivedAll
349	if len(consumer.sessions) > 0 {
350		t.Fatalf("Returned number of sessions mismatch\ngot: %v\nwant: %v", len(consumer.sessions), 0)
351	}
352	if len(consumer.errors) != numChannels {
353		t.Fatalf("Returned number of errors mismatch\ngot: %v\nwant: %v", len(consumer.errors), numChannels)
354	}
355	for _, e := range consumer.errors {
356		if g, w := status.Code(e.err), codes.DeadlineExceeded; g != w {
357			t.Fatalf("Error code mismatch\ngot: %v (%s)\nwant: %v", g, e.err, w)
358		}
359	}
360	client.Close()
361}
362
363func TestClientIDGenerator(t *testing.T) {
364	cidGen = newClientIDGenerator()
365	for _, tt := range []struct {
366		database string
367		clientID string
368	}{
369		{"db", "client-1"},
370		{"db-new", "client-1"},
371		{"db", "client-2"},
372	} {
373		if got, want := cidGen.nextID(tt.database), tt.clientID; got != want {
374			t.Fatalf("Generate wrong client ID: got %v, want %v", got, want)
375		}
376	}
377}
378
379func TestMergeCallOptions(t *testing.T) {
380	a := &vkit.CallOptions{
381		CreateSession: []gax.CallOption{
382			gax.WithRetry(func() gax.Retryer {
383				return gax.OnCodes([]codes.Code{
384					codes.Unavailable, codes.DeadlineExceeded,
385				}, gax.Backoff{
386					Initial:    100 * time.Millisecond,
387					Max:        16000 * time.Millisecond,
388					Multiplier: 1.0,
389				})
390			}),
391		},
392		GetSession: []gax.CallOption{
393			gax.WithRetry(func() gax.Retryer {
394				return gax.OnCodes([]codes.Code{
395					codes.Unavailable, codes.DeadlineExceeded,
396				}, gax.Backoff{
397					Initial:    250 * time.Millisecond,
398					Max:        32000 * time.Millisecond,
399					Multiplier: 1.30,
400				})
401			}),
402		},
403	}
404	b := &vkit.CallOptions{
405		CreateSession: []gax.CallOption{
406			gax.WithRetry(func() gax.Retryer {
407				return gax.OnCodes([]codes.Code{
408					codes.Unavailable,
409				}, gax.Backoff{
410					Initial:    250 * time.Millisecond,
411					Max:        32000 * time.Millisecond,
412					Multiplier: 1.30,
413				})
414			}),
415		},
416		BatchCreateSessions: []gax.CallOption{
417			gax.WithRetry(func() gax.Retryer {
418				return gax.OnCodes([]codes.Code{
419					codes.Unavailable,
420				}, gax.Backoff{
421					Initial:    250 * time.Millisecond,
422					Max:        32000 * time.Millisecond,
423					Multiplier: 1.30,
424				})
425			}),
426		}}
427
428	merged := mergeCallOptions(b, a)
429	cs := &gax.CallSettings{}
430	// We can't access the fields of Retryer so we have test the result by
431	// comparing strings.
432	merged.CreateSession[0].Resolve(cs)
433	if got, want := fmt.Sprintf("%v", cs.Retry()), "&{{250000000 32000000000 1.3 0} [14]}"; got != want {
434		t.Fatalf("merged CallOptions is incorrect: got %v, want %v", got, want)
435	}
436
437	merged.CreateSession[1].Resolve(cs)
438	if got, want := fmt.Sprintf("%v", cs.Retry()), "&{{100000000 16000000000 1 0} [14 4]}"; got != want {
439		t.Fatalf("merged CallOptions is incorrect: got %v, want %v", got, want)
440	}
441
442	merged.GetSession[0].Resolve(cs)
443	if got, want := fmt.Sprintf("%v", cs.Retry()), "&{{250000000 32000000000 1.3 0} [14 4]}"; got != want {
444		t.Fatalf("merged CallOptions is incorrect: got %v, want %v", got, want)
445	}
446
447	merged.BatchCreateSessions[0].Resolve(cs)
448	if got, want := fmt.Sprintf("%v", cs.Retry()), "&{{250000000 32000000000 1.3 0} [14]}"; got != want {
449		t.Fatalf("merged CallOptions is incorrect: got %v, want %v", got, want)
450	}
451}
452