1/*
2Copyright 2019 Google LLC
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8    http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17package spanner
18
19import (
20	"context"
21	"sync"
22	"testing"
23	"time"
24
25	vkit "cloud.google.com/go/spanner/apiv1"
26	. "cloud.google.com/go/spanner/internal/testutil"
27	"google.golang.org/grpc/codes"
28	"google.golang.org/grpc/status"
29)
30
// testSessionCreateError records one session-creation failure that was
// reported to a testConsumer through sessionCreationFailed.
type testSessionCreateError struct {
	// err is the error that was reported for the failed creation.
	err error

	// num is the number of sessions this failure accounts for; it is added to
	// testConsumer.numErr.
	num int32
}
35
36type testConsumer struct {
37	numExpected int32
38
39	mu       sync.Mutex
40	sessions []*session
41	errors   []*testSessionCreateError
42	numErr   int32
43
44	receivedAll chan struct{}
45}
46
47func (tc *testConsumer) sessionReady(s *session) {
48	tc.mu.Lock()
49	defer tc.mu.Unlock()
50	tc.sessions = append(tc.sessions, s)
51	tc.checkReceivedAll()
52}
53
54func (tc *testConsumer) sessionCreationFailed(err error, num int32) {
55	tc.mu.Lock()
56	defer tc.mu.Unlock()
57	tc.errors = append(tc.errors, &testSessionCreateError{
58		err: err,
59		num: num,
60	})
61	tc.numErr += num
62	tc.checkReceivedAll()
63}
64
65func (tc *testConsumer) checkReceivedAll() {
66	if int32(len(tc.sessions))+tc.numErr == tc.numExpected {
67		close(tc.receivedAll)
68	}
69}
70
71func newTestConsumer(numExpected int32) *testConsumer {
72	return &testConsumer{
73		numExpected: numExpected,
74		receivedAll: make(chan struct{}),
75	}
76}
77
78func TestCreateAndCloseSession(t *testing.T) {
79	t.Parallel()
80
81	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
82		SessionPoolConfig: SessionPoolConfig{
83			MinOpened: 0,
84			MaxOpened: 100,
85		},
86	})
87	defer teardown()
88
89	s, err := client.sc.createSession(context.Background())
90	if err != nil {
91		t.Fatalf("batch.next() return error mismatch\ngot: %v\nwant: nil", err)
92	}
93	if s == nil {
94		t.Fatalf("batch.next() return value mismatch\ngot: %v\nwant: any session", s)
95	}
96	if server.TestSpanner.TotalSessionsCreated() != 1 {
97		t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", server.TestSpanner.TotalSessionsCreated(), 1)
98	}
99	s.delete(context.Background())
100	if server.TestSpanner.TotalSessionsDeleted() != 1 {
101		t.Fatalf("number of sessions deleted mismatch\ngot: %v\nwant: %v", server.TestSpanner.TotalSessionsDeleted(), 1)
102	}
103}
104
105func TestBatchCreateAndCloseSession(t *testing.T) {
106	t.Parallel()
107
108	numSessions := int32(100)
109	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
110	defer serverTeardown()
111	for numChannels := 1; numChannels <= 32; numChannels *= 2 {
112		prevCreated := server.TestSpanner.TotalSessionsCreated()
113		prevDeleted := server.TestSpanner.TotalSessionsDeleted()
114		client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{
115			NumChannels: numChannels,
116			SessionPoolConfig: SessionPoolConfig{
117				MinOpened: 0,
118				MaxOpened: 400,
119			}}, opts...)
120		if err != nil {
121			t.Fatal(err)
122		}
123		consumer := newTestConsumer(numSessions)
124		client.sc.batchCreateSessions(numSessions, consumer)
125		<-consumer.receivedAll
126		if len(consumer.sessions) != int(numSessions) {
127			t.Fatalf("returned number of sessions mismatch\ngot: %v\nwant: %v", len(consumer.sessions), numSessions)
128		}
129		created := server.TestSpanner.TotalSessionsCreated() - prevCreated
130		if created != uint(numSessions) {
131			t.Fatalf("number of sessions created mismatch\ngot: %v\nwant: %v", created, numSessions)
132		}
133		// Check that all channels are used evenly.
134		channelCounts := make(map[*vkit.Client]int32)
135		for _, s := range consumer.sessions {
136			channelCounts[s.client]++
137		}
138		if len(channelCounts) != numChannels {
139			t.Fatalf("number of channels used mismatch\ngot: %v\nwant: %v", len(channelCounts), numChannels)
140		}
141		for _, c := range channelCounts {
142			if c < numSessions/int32(numChannels) || c > numSessions/int32(numChannels)+(numSessions%int32(numChannels)) {
143				t.Fatalf("channel used an unexpected number of times\ngot: %v\nwant between %v and %v", c, numSessions/int32(numChannels), numSessions/int32(numChannels)+1)
144			}
145		}
146		// Delete the sessions.
147		for _, s := range consumer.sessions {
148			s.delete(context.Background())
149		}
150		deleted := server.TestSpanner.TotalSessionsDeleted() - prevDeleted
151		if deleted != uint(numSessions) {
152			t.Fatalf("number of sessions deleted mismatch\ngot: %v\nwant %v", deleted, numSessions)
153		}
154		client.Close()
155	}
156}
157
158func TestBatchCreateSessionsWithExceptions(t *testing.T) {
159	t.Parallel()
160
161	numSessions := int32(100)
162	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
163	defer serverTeardown()
164
165	// Run the test with everything between 1 and numChannels errors.
166	for numErrors := int32(1); numErrors <= numChannels; numErrors++ {
167		// Make sure that the error is not always the first call.
168		for firstErrorAt := numErrors - 1; firstErrorAt < numChannels-numErrors+1; firstErrorAt++ {
169			client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{
170				NumChannels: numChannels,
171				SessionPoolConfig: SessionPoolConfig{
172					MinOpened: 0,
173					MaxOpened: 400,
174				}}, opts...)
175			if err != nil {
176				t.Fatal(err)
177			}
178			// Register the errors on the server.
179			errors := make([]error, numErrors+firstErrorAt)
180			for i := firstErrorAt; i < numErrors+firstErrorAt; i++ {
181				errors[i] = status.Errorf(codes.FailedPrecondition, "session creation failed")
182			}
183			server.TestSpanner.PutExecutionTime(MethodBatchCreateSession, SimulatedExecutionTime{
184				Errors: errors,
185			})
186			consumer := newTestConsumer(numSessions)
187			client.sc.batchCreateSessions(numSessions, consumer)
188			<-consumer.receivedAll
189
190			sessionsReturned := int32(len(consumer.sessions))
191			if int32(len(consumer.errors)) != numErrors {
192				t.Fatalf("Error count mismatch\nGot: %d\nWant: %d", len(consumer.errors), numErrors)
193			}
194			for _, e := range consumer.errors {
195				if g, w := status.Code(e.err), codes.FailedPrecondition; g != w {
196					t.Fatalf("error code mismatch\ngot: %v\nwant: %v", g, w)
197				}
198			}
199			maxExpectedSessions := numSessions - numErrors*(numSessions/numChannels)
200			minExpectedSessions := numSessions - numErrors*(numSessions/numChannels+1)
201			if sessionsReturned < minExpectedSessions || sessionsReturned > maxExpectedSessions {
202				t.Fatalf("session count mismatch\ngot: %v\nwant between %v and %v", sessionsReturned, minExpectedSessions, maxExpectedSessions)
203			}
204			client.Close()
205		}
206	}
207}
208
209func TestBatchCreateSessions_ServerReturnsLessThanRequestedSessions(t *testing.T) {
210	t.Parallel()
211
212	numChannels := 4
213	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
214		NumChannels: numChannels,
215		SessionPoolConfig: SessionPoolConfig{
216			MinOpened: 0,
217			MaxOpened: 100,
218		},
219	})
220	defer teardown()
221	// Ensure that the server will never return more than 10 sessions per batch
222	// create request.
223	server.TestSpanner.SetMaxSessionsReturnedByServerPerBatchRequest(10)
224	numSessions := int32(100)
225	// Request a batch of sessions that is larger than will be returned by the
226	// server in one request. The server will return at most 10 sessions per
227	// request. The sessionCreator will spread these requests over the 4
228	// channels that are available, i.e. do requests for 25 sessions in each
229	// request. The batch should still return 100 sessions.
230	consumer := newTestConsumer(numSessions)
231	client.sc.batchCreateSessions(numSessions, consumer)
232	<-consumer.receivedAll
233	if len(consumer.errors) > 0 {
234		t.Fatalf("Error count mismatch\nGot: %d\nWant: %d", len(consumer.errors), 0)
235	}
236	returnedSessionCount := int32(len(consumer.sessions))
237	if returnedSessionCount != numSessions {
238		t.Fatalf("Returned sessions mismatch\nGot: %v\nWant: %v", returnedSessionCount, numSessions)
239	}
240}
241
242func TestBatchCreateSessions_ServerExhausted(t *testing.T) {
243	t.Parallel()
244
245	numChannels := 4
246	server, client, teardown := setupMockedTestServerWithConfig(t, ClientConfig{
247		NumChannels: numChannels,
248		SessionPoolConfig: SessionPoolConfig{
249			MinOpened: 0,
250			MaxOpened: 100,
251		},
252	})
253	defer teardown()
254	numSessions := int32(100)
255	maxSessions := int32(50)
256	// Ensure that the server will never return more than 50 sessions in total.
257	server.TestSpanner.SetMaxSessionsReturnedByServerInTotal(maxSessions)
258	consumer := newTestConsumer(numSessions)
259	client.sc.batchCreateSessions(numSessions, consumer)
260	<-consumer.receivedAll
261	// Session creation should end with at least one RESOURCE_EXHAUSTED error.
262	if len(consumer.errors) == 0 {
263		t.Fatalf("Error count mismatch\nGot: %d\nWant: > %d", len(consumer.errors), 0)
264	}
265	for _, e := range consumer.errors {
266		if g, w := status.Code(e.err), codes.ResourceExhausted; g != w {
267			t.Fatalf("Error code mismath\nGot: %v\nWant: %v", g, w)
268		}
269	}
270	// The number of returned sessions should be equal to the max of the
271	// server.
272	returnedSessionCount := int32(len(consumer.sessions))
273	if returnedSessionCount != maxSessions {
274		t.Fatalf("Returned sessions mismatch\nGot: %v\nWant: %v", returnedSessionCount, maxSessions)
275	}
276	if consumer.numErr != (numSessions - maxSessions) {
277		t.Fatalf("Num errored sessions mismatch\nGot: %v\nWant: %v", consumer.numErr, numSessions-maxSessions)
278	}
279}
280
281func TestBatchCreateSessions_WithTimeout(t *testing.T) {
282	t.Parallel()
283
284	numSessions := int32(100)
285	server, opts, serverTeardown := NewMockedSpannerInMemTestServer(t)
286	defer serverTeardown()
287	server.TestSpanner.PutExecutionTime(MethodBatchCreateSession, SimulatedExecutionTime{
288		MinimumExecutionTime: time.Second,
289	})
290	client, err := NewClientWithConfig(context.Background(), "projects/p/instances/i/databases/d", ClientConfig{
291		SessionPoolConfig: SessionPoolConfig{
292			MinOpened: 0,
293			MaxOpened: 400,
294		}}, opts...)
295	if err != nil {
296		t.Fatal(err)
297	}
298
299	client.sc.batchTimeout = 10 * time.Millisecond
300	consumer := newTestConsumer(numSessions)
301	client.sc.batchCreateSessions(numSessions, consumer)
302	<-consumer.receivedAll
303	if len(consumer.sessions) > 0 {
304		t.Fatalf("Returned number of sessions mismatch\ngot: %v\nwant: %v", len(consumer.sessions), 0)
305	}
306	if len(consumer.errors) != numChannels {
307		t.Fatalf("Returned number of errors mismatch\ngot: %v\nwant: %v", len(consumer.errors), numChannels)
308	}
309	for _, e := range consumer.errors {
310		if g, w := status.Code(e.err), codes.DeadlineExceeded; g != w {
311			t.Fatalf("Error code mismatch\ngot: %v (%s)\nwant: %v", g, e.err, w)
312		}
313	}
314	client.Close()
315}
316