1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14package wire
15
16import (
17	"context"
18	"testing"
19
20	"cloud.google.com/go/pubsublite/internal/test"
21	"google.golang.org/grpc/codes"
22	"google.golang.org/grpc/status"
23)
24
25// testCommitter wraps a committer for ease of testing.
26type testCommitter struct {
27	cmt *committer
28	serviceTestProxy
29}
30
31func newTestCommitter(t *testing.T, subscription subscriptionPartition, acks *ackTracker) *testCommitter {
32	ctx := context.Background()
33	cursorClient, err := newCursorClient(ctx, "ignored", testServer.ClientConn())
34	if err != nil {
35		t.Fatal(err)
36	}
37
38	tc := &testCommitter{
39		cmt: newCommitter(ctx, cursorClient, testReceiveSettings(), subscription, acks, true),
40	}
41	tc.initAndStart(t, tc.cmt, "Committer", cursorClient)
42	return tc
43}
44
45// SendBatchCommit invokes the periodic background batch commit. Note that the
46// periodic task is disabled in tests.
47func (tc *testCommitter) SendBatchCommit() {
48	tc.cmt.commitOffsetToStream()
49}
50
51func (tc *testCommitter) Terminate() {
52	tc.cmt.Terminate()
53}
54
55func (tc *testCommitter) BlockingReset() error {
56	return tc.cmt.BlockingReset()
57}
58
// TestCommitterStreamReconnect verifies that when the commit stream fails with
// a retryable error (codes.Unavailable), the committer opens a new stream and
// sends the latest commit offset on it.
func TestCommitterStreamReconnect(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
	ack1 := newAckConsumer(33, 0, nil)
	ack2 := newAckConsumer(55, 0, nil)
	acks := newAckTracker()
	acks.Push(ack1)
	acks.Push(ack2)

	verifiers := test.NewVerifiers(t)

	// Simulate a transient error that results in a reconnect. The expected
	// commit offset is the acked message offset + 1 (33 -> 34). The error reply
	// is held behind a barrier so the test decides when the stream fails.
	stream1 := test.NewRPCVerifier(t)
	stream1.Push(initCommitReq(subscription), initCommitResp(), nil)
	barrier := stream1.PushWithBarrier(commitReq(34), nil, status.Error(codes.Unavailable, "server unavailable"))
	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream1)

	// When the stream reconnects, the latest commit offset should be sent to the
	// server (56, i.e. the second ack's offset 55 + 1).
	stream2 := test.NewRPCVerifier(t)
	stream2.Push(initCommitReq(subscription), initCommitResp(), nil)
	stream2.Push(commitReq(56), commitResp(1), nil)
	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream2)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	cmt := newTestCommitter(t, subscription, acks)
	if gotErr := cmt.StartError(); gotErr != nil {
		t.Errorf("Start() got err: (%v)", gotErr)
	}

	// Send 2 commits.
	ack1.Ack()
	cmt.SendBatchCommit()
	ack2.Ack()
	cmt.SendBatchCommit()

	// Then send the retryable error, which results in reconnect.
	barrier.Release()
	cmt.StopVerifyNoError()
}
100
101func TestCommitterStopFlushesCommits(t *testing.T) {
102	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
103	ack1 := newAckConsumer(33, 0, nil)
104	ack2 := newAckConsumer(55, 0, nil)
105	acks := newAckTracker()
106	acks.Push(ack1)
107	acks.Push(ack2)
108
109	verifiers := test.NewVerifiers(t)
110	stream := test.NewRPCVerifier(t)
111	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
112	stream.Push(commitReq(34), commitResp(1), nil)
113	stream.Push(commitReq(56), commitResp(1), nil)
114	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
115
116	mockServer.OnTestStart(verifiers)
117	defer mockServer.OnTestEnd()
118
119	cmt := newTestCommitter(t, subscription, acks)
120	if gotErr := cmt.StartError(); gotErr != nil {
121		t.Errorf("Start() got err: (%v)", gotErr)
122	}
123
124	ack1.Ack()
125	cmt.Stop() // Stop should flush the first offset
126	ack2.Ack() // Acks after Stop() are processed
127	cmt.SendBatchCommit()
128	// Committer terminates when all acks are processed.
129	if gotErr := cmt.FinalError(); gotErr != nil {
130		t.Errorf("Final err: (%v), want: <nil>", gotErr)
131	}
132}
133
134func TestCommitterTerminateDiscardsOutstandingAcks(t *testing.T) {
135	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
136	ack1 := newAckConsumer(33, 0, nil)
137	ack2 := newAckConsumer(55, 0, nil)
138	acks := newAckTracker()
139	acks.Push(ack1)
140	acks.Push(ack2)
141
142	verifiers := test.NewVerifiers(t)
143	stream := test.NewRPCVerifier(t)
144	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
145	stream.Push(commitReq(34), commitResp(1), nil)
146	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
147
148	mockServer.OnTestStart(verifiers)
149	defer mockServer.OnTestEnd()
150
151	cmt := newTestCommitter(t, subscription, acks)
152	if gotErr := cmt.StartError(); gotErr != nil {
153		t.Errorf("Start() got err: (%v)", gotErr)
154	}
155
156	ack1.Ack()
157	cmt.Terminate()       // Terminate should flush the first offset
158	ack2.Ack()            // Acks after Terminate() are discarded
159	cmt.SendBatchCommit() // Should do nothing (server does not expect second commit)
160	if gotErr := cmt.FinalError(); gotErr != nil {
161		t.Errorf("Final err: (%v), want: <nil>", gotErr)
162	}
163}
164
165func TestCommitterStopThenTerminateDiscardsOutstandingAcks(t *testing.T) {
166	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
167	ack := newAckConsumer(33, 0, nil)
168	acks := newAckTracker()
169	acks.Push(ack)
170
171	verifiers := test.NewVerifiers(t)
172	stream := test.NewRPCVerifier(t)
173	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
174	// No commits expected.
175	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
176
177	mockServer.OnTestStart(verifiers)
178	defer mockServer.OnTestEnd()
179
180	cmt := newTestCommitter(t, subscription, acks)
181	if gotErr := cmt.StartError(); gotErr != nil {
182		t.Errorf("Start() got err: (%v)", gotErr)
183	}
184
185	cmt.Stop()      // Stop waits for outstanding acks
186	cmt.Terminate() // Terminate should discard all outstanding acks
187	if gotErr := cmt.FinalError(); gotErr != nil {
188		t.Errorf("Final err: (%v), want: <nil>", gotErr)
189	}
190}
191
192func TestCommitterPermanentStreamError(t *testing.T) {
193	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
194	acks := newAckTracker()
195	wantErr := status.Error(codes.FailedPrecondition, "failed")
196
197	verifiers := test.NewVerifiers(t)
198	stream := test.NewRPCVerifier(t)
199	stream.Push(initCommitReq(subscription), nil, wantErr)
200	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
201
202	mockServer.OnTestStart(verifiers)
203	defer mockServer.OnTestEnd()
204
205	cmt := newTestCommitter(t, subscription, acks)
206	if gotErr := cmt.StartError(); !test.ErrorEqual(gotErr, wantErr) {
207		t.Errorf("Start() got err: (%v), want: (%v)", gotErr, wantErr)
208	}
209}
210
211func TestCommitterInvalidInitialResponse(t *testing.T) {
212	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
213	acks := newAckTracker()
214
215	verifiers := test.NewVerifiers(t)
216	stream := test.NewRPCVerifier(t)
217	stream.Push(initCommitReq(subscription), commitResp(1234), nil) // Invalid initial response
218	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
219
220	mockServer.OnTestStart(verifiers)
221	defer mockServer.OnTestEnd()
222
223	cmt := newTestCommitter(t, subscription, acks)
224
225	wantErr := errInvalidInitialCommitResponse
226	if gotErr := cmt.StartError(); !test.ErrorEqual(gotErr, wantErr) {
227		t.Errorf("Start() got err: (%v), want: (%v)", gotErr, wantErr)
228	}
229	if gotErr := cmt.FinalError(); !test.ErrorEqual(gotErr, wantErr) {
230		t.Errorf("Final err: (%v), want: (%v)", gotErr, wantErr)
231	}
232}
233
234func TestCommitterInvalidCommitResponse(t *testing.T) {
235	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
236	ack := newAckConsumer(33, 0, nil)
237	acks := newAckTracker()
238	acks.Push(ack)
239
240	verifiers := test.NewVerifiers(t)
241	stream := test.NewRPCVerifier(t)
242	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
243	stream.Push(commitReq(34), initCommitResp(), nil) // Invalid commit response
244	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
245
246	mockServer.OnTestStart(verifiers)
247	defer mockServer.OnTestEnd()
248
249	cmt := newTestCommitter(t, subscription, acks)
250	if gotErr := cmt.StartError(); gotErr != nil {
251		t.Errorf("Start() got err: (%v)", gotErr)
252	}
253
254	ack.Ack()
255	cmt.SendBatchCommit()
256
257	if gotErr, wantErr := cmt.FinalError(), errInvalidCommitResponse; !test.ErrorEqual(gotErr, wantErr) {
258		t.Errorf("Final err: (%v), want: (%v)", gotErr, wantErr)
259	}
260}
261
262func TestCommitterExcessConfirmedOffsets(t *testing.T) {
263	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
264	ack := newAckConsumer(33, 0, nil)
265	acks := newAckTracker()
266	acks.Push(ack)
267
268	verifiers := test.NewVerifiers(t)
269	stream := test.NewRPCVerifier(t)
270	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
271	stream.Push(commitReq(34), commitResp(2), nil) // More confirmed offsets than committed
272	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
273
274	mockServer.OnTestStart(verifiers)
275	defer mockServer.OnTestEnd()
276
277	cmt := newTestCommitter(t, subscription, acks)
278	if gotErr := cmt.StartError(); gotErr != nil {
279		t.Errorf("Start() got err: (%v)", gotErr)
280	}
281
282	ack.Ack()
283	cmt.SendBatchCommit()
284
285	wantMsg := "server acknowledged 2 cursor commits"
286	if gotErr := cmt.FinalError(); !test.ErrorHasMsg(gotErr, wantMsg) {
287		t.Errorf("Final err: (%v), want msg: (%v)", gotErr, wantMsg)
288	}
289}
290
291func TestCommitterZeroConfirmedOffsets(t *testing.T) {
292	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
293	ack := newAckConsumer(33, 0, nil)
294	acks := newAckTracker()
295	acks.Push(ack)
296
297	verifiers := test.NewVerifiers(t)
298	stream := test.NewRPCVerifier(t)
299	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
300	stream.Push(commitReq(34), commitResp(0), nil) // Zero confirmed offsets (invalid)
301	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
302
303	mockServer.OnTestStart(verifiers)
304	defer mockServer.OnTestEnd()
305
306	cmt := newTestCommitter(t, subscription, acks)
307	if gotErr := cmt.StartError(); gotErr != nil {
308		t.Errorf("Start() got err: (%v)", gotErr)
309	}
310
311	ack.Ack()
312	cmt.SendBatchCommit()
313
314	wantMsg := "server acknowledged an invalid commit count"
315	if gotErr := cmt.FinalError(); !test.ErrorHasMsg(gotErr, wantMsg) {
316		t.Errorf("Final err: (%v), want msg: (%v)", gotErr, wantMsg)
317	}
318}
319
320func TestCommitterBlockingResetNormalCompletion(t *testing.T) {
321	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
322	ack1 := newAckConsumer(33, 0, nil)
323	ack2 := newAckConsumer(55, 0, nil)
324	acks := newAckTracker()
325	acks.Push(ack1)
326	acks.Push(ack2)
327
328	verifiers := test.NewVerifiers(t)
329	stream := test.NewRPCVerifier(t)
330	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
331	stream.Push(commitReq(34), commitResp(1), nil)
332	barrier := stream.PushWithBarrier(commitReq(56), commitResp(1), nil)
333	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
334
335	mockServer.OnTestStart(verifiers)
336	defer mockServer.OnTestEnd()
337
338	cmt := newTestCommitter(t, subscription, acks)
339	if gotErr := cmt.StartError(); gotErr != nil {
340		t.Errorf("Start() got err: (%v)", gotErr)
341	}
342
343	complete := test.NewCondition("blocking reset complete")
344	go func() {
345		if err := cmt.BlockingReset(); err != nil {
346			t.Errorf("BlockingReset() got err: (%v), want: <nil>", err)
347		}
348		cmt.BlockingReset()
349		complete.SetDone()
350	}()
351	complete.VerifyNotDone(t)
352
353	ack1.Ack()
354	cmt.SendBatchCommit()
355	complete.VerifyNotDone(t)
356	ack2.Ack()
357	cmt.SendBatchCommit()
358
359	// Until the final commit response is received, committer.BlockingReset should
360	// not return.
361	barrier.ReleaseAfter(func() {
362		complete.VerifyNotDone(t)
363	})
364	complete.WaitUntilDone(t, serviceTestWaitTimeout)
365
366	// Ack tracker should be reset.
367	if got, want := acks.CommitOffset(), nilCursorOffset; got != want {
368		t.Errorf("ackTracker.CommitOffset() got %v, want %v", got, want)
369	}
370	if got, want := acks.Empty(), true; got != want {
371		t.Errorf("ackTracker.Empty() got %v, want %v", got, want)
372	}
373
374	// Calling committer.BlockingReset again should immediately return.
375	if err := cmt.BlockingReset(); err != nil {
376		t.Errorf("BlockingReset() got err: (%v), want: <nil>", err)
377	}
378
379	cmt.StopVerifyNoError()
380}
381
// TestCommitterBlockingResetCommitterStopped verifies that a pending
// BlockingReset returns ErrServiceStopped when the committer is stopped while
// an ack is still outstanding, and that the ack tracker is then left intact.
func TestCommitterBlockingResetCommitterStopped(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
	ack1 := newAckConsumer(33, 0, nil)
	ack2 := newAckConsumer(55, 0, nil)
	acks := newAckTracker()
	acks.Push(ack1)
	acks.Push(ack2)

	verifiers := test.NewVerifiers(t)
	stream := test.NewRPCVerifier(t)
	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
	// Only the commit for ack1 is expected; ack2 is never acked in this test.
	stream.Push(commitReq(34), commitResp(1), nil)
	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	cmt := newTestCommitter(t, subscription, acks)
	if gotErr := cmt.StartError(); gotErr != nil {
		t.Errorf("Start() got err: (%v)", gotErr)
	}

	complete := test.NewCondition("blocking reset complete")
	go func() {
		if got, want := cmt.BlockingReset(), ErrServiceStopped; !test.ErrorEqual(got, want) {
			t.Errorf("BlockingReset() got: (%v), want: (%v)", got, want)
		}
		complete.SetDone()
	}()
	complete.VerifyNotDone(t)

	// Committing ack1 alone must not complete the reset because ack2 is still
	// outstanding.
	ack1.Ack()
	cmt.SendBatchCommit()
	complete.VerifyNotDone(t)

	// committer.BlockingReset should return when the committer is stopped.
	cmt.Stop()
	complete.WaitUntilDone(t, serviceTestWaitTimeout)

	// Ack tracker should not be reset.
	if got, want := acks.Empty(), false; got != want {
		t.Errorf("ackTracker.Empty() got %v, want %v", got, want)
	}

	// Stop keeps waiting on the unacked ack2; Terminate lets the committer
	// finish without error.
	cmt.Terminate()
	if gotErr := cmt.FinalError(); gotErr != nil {
		t.Errorf("Final err: (%v), want: <nil>", gotErr)
	}
}
431
// TestCommitterBlockingResetFatalError verifies that a pending BlockingReset
// returns ErrServiceStopped when the committer terminates because the server
// fails a commit with a fatal error, and that the server's error becomes the
// committer's final error.
func TestCommitterBlockingResetFatalError(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
	ack1 := newAckConsumer(33, 0, nil)
	ack2 := newAckConsumer(55, 0, nil)
	acks := newAckTracker()
	acks.Push(ack1)
	acks.Push(ack2)
	serverErr := status.Error(codes.FailedPrecondition, "failed")

	verifiers := test.NewVerifiers(t)
	stream := test.NewRPCVerifier(t)
	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
	// The server rejects the first commit with a FailedPrecondition error.
	stream.Push(commitReq(34), nil, serverErr)
	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	cmt := newTestCommitter(t, subscription, acks)
	if gotErr := cmt.StartError(); gotErr != nil {
		t.Errorf("Start() got err: (%v)", gotErr)
	}

	complete := test.NewCondition("blocking reset complete")
	go func() {
		if got, want := cmt.BlockingReset(), ErrServiceStopped; !test.ErrorEqual(got, want) {
			t.Errorf("BlockingReset() got: (%v), want: (%v)", got, want)
		}
		complete.SetDone()
	}()
	complete.VerifyNotDone(t)

	// Trigger the commit that the server will reject.
	ack1.Ack()
	cmt.SendBatchCommit()

	// committer.BlockingReset should return when the committer terminates due to
	// fatal server error.
	complete.WaitUntilDone(t, serviceTestWaitTimeout)

	if gotErr := cmt.FinalError(); !test.ErrorEqual(gotErr, serverErr) {
		t.Errorf("Final err: (%v), want: (%v)", gotErr, serverErr)
	}
}
475