1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14package wire
15
16import (
17	"context"
18	"testing"
19
20	"cloud.google.com/go/pubsublite/internal/test"
21	"google.golang.org/grpc/codes"
22	"google.golang.org/grpc/status"
23)
24
25// testCommitter wraps a committer for ease of testing.
26type testCommitter struct {
27	cmt *committer
28	serviceTestProxy
29}
30
31func newTestCommitter(t *testing.T, subscription subscriptionPartition, acks *ackTracker) *testCommitter {
32	ctx := context.Background()
33	cursorClient, err := newCursorClient(ctx, "ignored", testServer.ClientConn())
34	if err != nil {
35		t.Fatal(err)
36	}
37
38	tc := &testCommitter{
39		cmt: newCommitter(ctx, cursorClient, testReceiveSettings(), subscription, acks, true),
40	}
41	tc.initAndStart(t, tc.cmt, "Committer", cursorClient)
42	return tc
43}
44
45// SendBatchCommit invokes the periodic background batch commit. Note that the
46// periodic task is disabled in tests.
47func (tc *testCommitter) SendBatchCommit() {
48	tc.cmt.commitOffsetToStream()
49}
50
51func (tc *testCommitter) Terminate() {
52	tc.cmt.Terminate()
53}
54
55func (tc *testCommitter) BlockingReset() error {
56	return tc.cmt.BlockingReset()
57}
58
59func TestCommitterStreamReconnect(t *testing.T) {
60	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
61	ack1 := newAckConsumer(33, 0, nil)
62	ack2 := newAckConsumer(55, 0, nil)
63	acks := newAckTracker()
64	acks.Push(ack1)
65	acks.Push(ack2)
66
67	verifiers := test.NewVerifiers(t)
68
69	// Simulate a transient error that results in a reconnect.
70	stream1 := test.NewRPCVerifier(t)
71	stream1.Push(initCommitReq(subscription), initCommitResp(), nil)
72	barrier := stream1.PushWithBarrier(commitReq(34), nil, status.Error(codes.Unavailable, "server unavailable"))
73	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream1)
74
75	// When the stream reconnects, the latest commit offset should be sent to the
76	// server.
77	stream2 := test.NewRPCVerifier(t)
78	stream2.Push(initCommitReq(subscription), initCommitResp(), nil)
79	stream2.Push(commitReq(56), commitResp(1), nil)
80	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream2)
81
82	mockServer.OnTestStart(verifiers)
83	defer mockServer.OnTestEnd()
84
85	cmt := newTestCommitter(t, subscription, acks)
86	if gotErr := cmt.StartError(); gotErr != nil {
87		t.Errorf("Start() got err: (%v)", gotErr)
88	}
89
90	// Send 2 commits.
91	ack1.Ack()
92	cmt.SendBatchCommit()
93	ack2.Ack()
94	cmt.SendBatchCommit()
95
96	// Then send the retryable error, which results in reconnect.
97	barrier.Release()
98	cmt.StopVerifyNoError()
99}
100
101func TestCommitterStopFlushesCommits(t *testing.T) {
102	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
103	ack1 := newAckConsumer(33, 0, nil)
104	ack2 := newAckConsumer(55, 0, nil)
105	acks := newAckTracker()
106	acks.Push(ack1)
107	acks.Push(ack2)
108
109	verifiers := test.NewVerifiers(t)
110	stream := test.NewRPCVerifier(t)
111	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
112	stream.Push(commitReq(34), commitResp(1), nil)
113	stream.Push(commitReq(56), commitResp(1), nil)
114	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
115
116	mockServer.OnTestStart(verifiers)
117	defer mockServer.OnTestEnd()
118
119	cmt := newTestCommitter(t, subscription, acks)
120	if gotErr := cmt.StartError(); gotErr != nil {
121		t.Errorf("Start() got err: (%v)", gotErr)
122	}
123
124	ack1.Ack()
125	cmt.Stop() // Stop should flush the first offset
126	ack2.Ack() // Acks after Stop() are processed
127	cmt.SendBatchCommit()
128	// Committer terminates when all acks are processed.
129	if gotErr := cmt.FinalError(); gotErr != nil {
130		t.Errorf("Final err: (%v), want: <nil>", gotErr)
131	}
132}
133
134func TestCommitterTerminateDiscardsOutstandingAcks(t *testing.T) {
135	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
136	ack1 := newAckConsumer(33, 0, nil)
137	ack2 := newAckConsumer(55, 0, nil)
138	acks := newAckTracker()
139	acks.Push(ack1)
140	acks.Push(ack2)
141
142	verifiers := test.NewVerifiers(t)
143	stream := test.NewRPCVerifier(t)
144	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
145	stream.Push(commitReq(34), commitResp(1), nil)
146	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
147
148	mockServer.OnTestStart(verifiers)
149	defer mockServer.OnTestEnd()
150
151	cmt := newTestCommitter(t, subscription, acks)
152	if gotErr := cmt.StartError(); gotErr != nil {
153		t.Errorf("Start() got err: (%v)", gotErr)
154	}
155
156	ack1.Ack()
157	cmt.Terminate()       // Terminate should flush the first offset
158	ack2.Ack()            // Acks after Terminate() are discarded
159	cmt.SendBatchCommit() // Should do nothing (server does not expect second commit)
160	if gotErr := cmt.FinalError(); gotErr != nil {
161		t.Errorf("Final err: (%v), want: <nil>", gotErr)
162	}
163}
164
165func TestCommitterStopThenTerminateDiscardsOutstandingAcks(t *testing.T) {
166	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
167	ack := newAckConsumer(33, 0, nil)
168	acks := newAckTracker()
169	acks.Push(ack)
170
171	verifiers := test.NewVerifiers(t)
172	stream := test.NewRPCVerifier(t)
173	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
174	// No commits expected.
175	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
176
177	mockServer.OnTestStart(verifiers)
178	defer mockServer.OnTestEnd()
179
180	cmt := newTestCommitter(t, subscription, acks)
181	if gotErr := cmt.StartError(); gotErr != nil {
182		t.Errorf("Start() got err: (%v)", gotErr)
183	}
184
185	cmt.Stop()      // Stop waits for outstanding acks
186	cmt.Terminate() // Terminate should discard all outstanding acks
187	if gotErr := cmt.FinalError(); gotErr != nil {
188		t.Errorf("Final err: (%v), want: <nil>", gotErr)
189	}
190}
191
192func TestCommitterPermanentStreamError(t *testing.T) {
193	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
194	acks := newAckTracker()
195	wantErr := status.Error(codes.FailedPrecondition, "failed")
196
197	verifiers := test.NewVerifiers(t)
198	stream := test.NewRPCVerifier(t)
199	stream.Push(initCommitReq(subscription), nil, wantErr)
200	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
201
202	mockServer.OnTestStart(verifiers)
203	defer mockServer.OnTestEnd()
204
205	cmt := newTestCommitter(t, subscription, acks)
206	if gotErr := cmt.StartError(); !test.ErrorEqual(gotErr, wantErr) {
207		t.Errorf("Start() got err: (%v), want: (%v)", gotErr, wantErr)
208	}
209}
210
211func TestCommitterInvalidInitialResponse(t *testing.T) {
212	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
213	acks := newAckTracker()
214
215	verifiers := test.NewVerifiers(t)
216	stream := test.NewRPCVerifier(t)
217	stream.Push(initCommitReq(subscription), commitResp(1234), nil) // Invalid initial response
218	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
219
220	mockServer.OnTestStart(verifiers)
221	defer mockServer.OnTestEnd()
222
223	cmt := newTestCommitter(t, subscription, acks)
224
225	wantErr := errInvalidInitialCommitResponse
226	if gotErr := cmt.StartError(); !test.ErrorEqual(gotErr, wantErr) {
227		t.Errorf("Start() got err: (%v), want: (%v)", gotErr, wantErr)
228	}
229	if gotErr := cmt.FinalError(); !test.ErrorEqual(gotErr, wantErr) {
230		t.Errorf("Final err: (%v), want: (%v)", gotErr, wantErr)
231	}
232}
233
234func TestCommitterInvalidCommitResponse(t *testing.T) {
235	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
236	ack := newAckConsumer(33, 0, nil)
237	acks := newAckTracker()
238	acks.Push(ack)
239
240	verifiers := test.NewVerifiers(t)
241	stream := test.NewRPCVerifier(t)
242	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
243	stream.Push(commitReq(34), initCommitResp(), nil) // Invalid commit response
244	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
245
246	mockServer.OnTestStart(verifiers)
247	defer mockServer.OnTestEnd()
248
249	cmt := newTestCommitter(t, subscription, acks)
250	if gotErr := cmt.StartError(); gotErr != nil {
251		t.Errorf("Start() got err: (%v)", gotErr)
252	}
253
254	ack.Ack()
255	cmt.SendBatchCommit()
256
257	if gotErr, wantErr := cmt.FinalError(), errInvalidCommitResponse; !test.ErrorEqual(gotErr, wantErr) {
258		t.Errorf("Final err: (%v), want: (%v)", gotErr, wantErr)
259	}
260}
261
262func TestCommitterExcessConfirmedOffsets(t *testing.T) {
263	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
264	ack := newAckConsumer(33, 0, nil)
265	acks := newAckTracker()
266	acks.Push(ack)
267
268	verifiers := test.NewVerifiers(t)
269	stream := test.NewRPCVerifier(t)
270	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
271	stream.Push(commitReq(34), commitResp(2), nil) // More confirmed offsets than committed
272	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
273
274	mockServer.OnTestStart(verifiers)
275	defer mockServer.OnTestEnd()
276
277	cmt := newTestCommitter(t, subscription, acks)
278	if gotErr := cmt.StartError(); gotErr != nil {
279		t.Errorf("Start() got err: (%v)", gotErr)
280	}
281
282	ack.Ack()
283	cmt.SendBatchCommit()
284
285	wantMsg := "server acknowledged 2 cursor commits"
286	if gotErr := cmt.FinalError(); !test.ErrorHasMsg(gotErr, wantMsg) {
287		t.Errorf("Final err: (%v), want msg: (%v)", gotErr, wantMsg)
288	}
289}
290
291func TestCommitterZeroConfirmedOffsets(t *testing.T) {
292	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
293	ack := newAckConsumer(33, 0, nil)
294	acks := newAckTracker()
295	acks.Push(ack)
296
297	verifiers := test.NewVerifiers(t)
298	stream := test.NewRPCVerifier(t)
299	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
300	stream.Push(commitReq(34), commitResp(0), nil) // Zero confirmed offsets (invalid)
301	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
302
303	mockServer.OnTestStart(verifiers)
304	defer mockServer.OnTestEnd()
305
306	cmt := newTestCommitter(t, subscription, acks)
307	if gotErr := cmt.StartError(); gotErr != nil {
308		t.Errorf("Start() got err: (%v)", gotErr)
309	}
310
311	ack.Ack()
312	cmt.SendBatchCommit()
313
314	wantMsg := "server acknowledged an invalid commit count"
315	if gotErr := cmt.FinalError(); !test.ErrorHasMsg(gotErr, wantMsg) {
316		t.Errorf("Final err: (%v), want msg: (%v)", gotErr, wantMsg)
317	}
318}
319
320func TestCommitterBlockingResetNormalCompletion(t *testing.T) {
321	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
322	ack1 := newAckConsumer(33, 0, nil)
323	ack2 := newAckConsumer(55, 0, nil)
324	acks := newAckTracker()
325	acks.Push(ack1)
326	acks.Push(ack2)
327
328	verifiers := test.NewVerifiers(t)
329	stream := test.NewRPCVerifier(t)
330	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
331	barrier := stream.PushWithBarrier(commitReq(34), commitResp(1), nil)
332	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
333
334	mockServer.OnTestStart(verifiers)
335	defer mockServer.OnTestEnd()
336
337	cmt := newTestCommitter(t, subscription, acks)
338	if gotErr := cmt.StartError(); gotErr != nil {
339		t.Errorf("Start() got err: (%v)", gotErr)
340	}
341
342	ack1.Ack()
343
344	complete := test.NewCondition("blocking reset complete")
345	go func() {
346		if err := cmt.BlockingReset(); err != nil {
347			t.Errorf("BlockingReset() got err: (%v), want: <nil>", err)
348		}
349		cmt.BlockingReset()
350		complete.SetDone()
351	}()
352	complete.VerifyNotDone(t)
353
354	// Until the commit response is received, committer.BlockingReset should not
355	// return.
356	barrier.ReleaseAfter(func() {
357		complete.VerifyNotDone(t)
358	})
359	complete.WaitUntilDone(t, serviceTestWaitTimeout)
360
361	// Ack tracker should be reset.
362	if got, want := acks.CommitOffset(), nilCursorOffset; got != want {
363		t.Errorf("ackTracker.CommitOffset() got %v, want %v", got, want)
364	}
365	if got, want := acks.Empty(), true; got != want {
366		t.Errorf("ackTracker.Empty() got %v, want %v", got, want)
367	}
368
369	// This ack should have been discarded.
370	ack2.Ack()
371
372	// Calling committer.BlockingReset again should immediately return.
373	if err := cmt.BlockingReset(); err != nil {
374		t.Errorf("BlockingReset() got err: (%v), want: <nil>", err)
375	}
376
377	cmt.StopVerifyNoError()
378}
379
380func TestCommitterBlockingResetCommitterStopped(t *testing.T) {
381	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
382	ack1 := newAckConsumer(33, 0, nil)
383	ack2 := newAckConsumer(55, 0, nil)
384	acks := newAckTracker()
385	acks.Push(ack1)
386	acks.Push(ack2)
387
388	verifiers := test.NewVerifiers(t)
389	stream := test.NewRPCVerifier(t)
390	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
391	barrier := stream.PushWithBarrier(commitReq(34), commitResp(1), nil)
392	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
393
394	mockServer.OnTestStart(verifiers)
395	defer mockServer.OnTestEnd()
396
397	cmt := newTestCommitter(t, subscription, acks)
398	if gotErr := cmt.StartError(); gotErr != nil {
399		t.Errorf("Start() got err: (%v)", gotErr)
400	}
401
402	ack1.Ack()
403
404	complete := test.NewCondition("blocking reset complete")
405	go func() {
406		if got, want := cmt.BlockingReset(), ErrServiceStopped; !test.ErrorEqual(got, want) {
407			t.Errorf("BlockingReset() got: (%v), want: (%v)", got, want)
408		}
409		complete.SetDone()
410	}()
411	complete.VerifyNotDone(t)
412
413	// committer.BlockingReset should return when the committer is stopped.
414	barrier.ReleaseAfter(func() {
415		cmt.Stop()
416		complete.WaitUntilDone(t, serviceTestWaitTimeout)
417	})
418
419	cmt.Terminate()
420	if gotErr := cmt.FinalError(); gotErr != nil {
421		t.Errorf("Final err: (%v), want: <nil>", gotErr)
422	}
423}
424
425func TestCommitterBlockingResetFatalError(t *testing.T) {
426	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
427	ack1 := newAckConsumer(33, 0, nil)
428	ack2 := newAckConsumer(55, 0, nil)
429	acks := newAckTracker()
430	acks.Push(ack1)
431	acks.Push(ack2)
432	serverErr := status.Error(codes.FailedPrecondition, "failed")
433
434	verifiers := test.NewVerifiers(t)
435	stream := test.NewRPCVerifier(t)
436	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
437	stream.Push(commitReq(34), nil, serverErr)
438	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
439
440	mockServer.OnTestStart(verifiers)
441	defer mockServer.OnTestEnd()
442
443	cmt := newTestCommitter(t, subscription, acks)
444	if gotErr := cmt.StartError(); gotErr != nil {
445		t.Errorf("Start() got err: (%v)", gotErr)
446	}
447
448	ack1.Ack()
449
450	complete := test.NewCondition("blocking reset complete")
451	go func() {
452		if got, want := cmt.BlockingReset(), ErrServiceStopped; !test.ErrorEqual(got, want) {
453			t.Errorf("BlockingReset() got: (%v), want: (%v)", got, want)
454		}
455		complete.SetDone()
456	}()
457
458	// committer.BlockingReset should return when the committer terminates due to
459	// fatal server error.
460	complete.WaitUntilDone(t, serviceTestWaitTimeout)
461
462	if gotErr := cmt.FinalError(); !test.ErrorEqual(gotErr, serverErr) {
463		t.Errorf("Final err: (%v), want: (%v)", gotErr, serverErr)
464	}
465}
466