1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14package wire
15
16import (
17	"context"
18	"testing"
19
20	"cloud.google.com/go/pubsublite/internal/test"
21	"google.golang.org/grpc/codes"
22	"google.golang.org/grpc/status"
23)
24
25// testCommitter wraps a committer for ease of testing.
26type testCommitter struct {
27	cmt *committer
28	serviceTestProxy
29}
30
31func newTestCommitter(t *testing.T, subscription subscriptionPartition, acks *ackTracker) *testCommitter {
32	ctx := context.Background()
33	cursorClient, err := newCursorClient(ctx, "ignored", testServer.ClientConn())
34	if err != nil {
35		t.Fatal(err)
36	}
37
38	tc := &testCommitter{
39		cmt: newCommitter(ctx, cursorClient, testReceiveSettings(), subscription, acks, true),
40	}
41	tc.initAndStart(t, tc.cmt, "Committer", cursorClient)
42	return tc
43}
44
45// SendBatchCommit invokes the periodic background batch commit. Note that the
46// periodic task is disabled in tests.
47func (tc *testCommitter) SendBatchCommit() {
48	tc.cmt.commitOffsetToStream()
49}
50
51func (tc *testCommitter) Terminate() {
52	tc.cmt.Terminate()
53}
54
55func TestCommitterStreamReconnect(t *testing.T) {
56	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
57	ack1 := newAckConsumer(33, 0, nil)
58	ack2 := newAckConsumer(55, 0, nil)
59	acks := newAckTracker()
60	acks.Push(ack1)
61	acks.Push(ack2)
62
63	verifiers := test.NewVerifiers(t)
64
65	// Simulate a transient error that results in a reconnect.
66	stream1 := test.NewRPCVerifier(t)
67	stream1.Push(initCommitReq(subscription), initCommitResp(), nil)
68	barrier := stream1.PushWithBarrier(commitReq(34), nil, status.Error(codes.Unavailable, "server unavailable"))
69	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream1)
70
71	// When the stream reconnects, the latest commit offset should be sent to the
72	// server.
73	stream2 := test.NewRPCVerifier(t)
74	stream2.Push(initCommitReq(subscription), initCommitResp(), nil)
75	stream2.Push(commitReq(56), commitResp(1), nil)
76	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream2)
77
78	mockServer.OnTestStart(verifiers)
79	defer mockServer.OnTestEnd()
80
81	cmt := newTestCommitter(t, subscription, acks)
82	if gotErr := cmt.StartError(); gotErr != nil {
83		t.Errorf("Start() got err: (%v)", gotErr)
84	}
85
86	// Send 2 commits.
87	ack1.Ack()
88	cmt.SendBatchCommit()
89	ack2.Ack()
90	cmt.SendBatchCommit()
91
92	// Then send the retryable error, which results in reconnect.
93	barrier.Release()
94	cmt.StopVerifyNoError()
95}
96
97func TestCommitterStopFlushesCommits(t *testing.T) {
98	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
99	ack1 := newAckConsumer(33, 0, nil)
100	ack2 := newAckConsumer(55, 0, nil)
101	acks := newAckTracker()
102	acks.Push(ack1)
103	acks.Push(ack2)
104
105	verifiers := test.NewVerifiers(t)
106	stream := test.NewRPCVerifier(t)
107	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
108	stream.Push(commitReq(34), commitResp(1), nil)
109	stream.Push(commitReq(56), commitResp(1), nil)
110	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
111
112	mockServer.OnTestStart(verifiers)
113	defer mockServer.OnTestEnd()
114
115	cmt := newTestCommitter(t, subscription, acks)
116	if gotErr := cmt.StartError(); gotErr != nil {
117		t.Errorf("Start() got err: (%v)", gotErr)
118	}
119
120	ack1.Ack()
121	cmt.Stop() // Stop should flush the first offset
122	ack2.Ack() // Acks after Stop() are processed
123	cmt.SendBatchCommit()
124	// Committer terminates when all acks are processed.
125	if gotErr := cmt.FinalError(); gotErr != nil {
126		t.Errorf("Final err: (%v), want: <nil>", gotErr)
127	}
128}
129
130func TestCommitterTerminateDiscardsOutstandingAcks(t *testing.T) {
131	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
132	ack1 := newAckConsumer(33, 0, nil)
133	ack2 := newAckConsumer(55, 0, nil)
134	acks := newAckTracker()
135	acks.Push(ack1)
136	acks.Push(ack2)
137
138	verifiers := test.NewVerifiers(t)
139	stream := test.NewRPCVerifier(t)
140	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
141	stream.Push(commitReq(34), commitResp(1), nil)
142	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
143
144	mockServer.OnTestStart(verifiers)
145	defer mockServer.OnTestEnd()
146
147	cmt := newTestCommitter(t, subscription, acks)
148	if gotErr := cmt.StartError(); gotErr != nil {
149		t.Errorf("Start() got err: (%v)", gotErr)
150	}
151
152	ack1.Ack()
153	cmt.Terminate()       // Terminate should flush the first offset
154	ack2.Ack()            // Acks after Terminate() are discarded
155	cmt.SendBatchCommit() // Should do nothing (server does not expect second commit)
156	if gotErr := cmt.FinalError(); gotErr != nil {
157		t.Errorf("Final err: (%v), want: <nil>", gotErr)
158	}
159}
160
161func TestCommitterPermanentStreamError(t *testing.T) {
162	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
163	acks := newAckTracker()
164	wantErr := status.Error(codes.FailedPrecondition, "failed")
165
166	verifiers := test.NewVerifiers(t)
167	stream := test.NewRPCVerifier(t)
168	stream.Push(initCommitReq(subscription), nil, wantErr)
169	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
170
171	mockServer.OnTestStart(verifiers)
172	defer mockServer.OnTestEnd()
173
174	cmt := newTestCommitter(t, subscription, acks)
175	if gotErr := cmt.StartError(); !test.ErrorEqual(gotErr, wantErr) {
176		t.Errorf("Start() got err: (%v), want: (%v)", gotErr, wantErr)
177	}
178}
179
180func TestCommitterInvalidInitialResponse(t *testing.T) {
181	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
182	acks := newAckTracker()
183
184	verifiers := test.NewVerifiers(t)
185	stream := test.NewRPCVerifier(t)
186	stream.Push(initCommitReq(subscription), commitResp(1234), nil) // Invalid initial response
187	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
188
189	mockServer.OnTestStart(verifiers)
190	defer mockServer.OnTestEnd()
191
192	cmt := newTestCommitter(t, subscription, acks)
193
194	wantErr := errInvalidInitialCommitResponse
195	if gotErr := cmt.StartError(); !test.ErrorEqual(gotErr, wantErr) {
196		t.Errorf("Start() got err: (%v), want: (%v)", gotErr, wantErr)
197	}
198	if gotErr := cmt.FinalError(); !test.ErrorEqual(gotErr, wantErr) {
199		t.Errorf("Final err: (%v), want: (%v)", gotErr, wantErr)
200	}
201}
202
203func TestCommitterInvalidCommitResponse(t *testing.T) {
204	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
205	ack := newAckConsumer(33, 0, nil)
206	acks := newAckTracker()
207	acks.Push(ack)
208
209	verifiers := test.NewVerifiers(t)
210	stream := test.NewRPCVerifier(t)
211	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
212	stream.Push(commitReq(34), initCommitResp(), nil) // Invalid commit response
213	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
214
215	mockServer.OnTestStart(verifiers)
216	defer mockServer.OnTestEnd()
217
218	cmt := newTestCommitter(t, subscription, acks)
219	if gotErr := cmt.StartError(); gotErr != nil {
220		t.Errorf("Start() got err: (%v)", gotErr)
221	}
222
223	ack.Ack()
224	cmt.SendBatchCommit()
225
226	if gotErr, wantErr := cmt.FinalError(), errInvalidCommitResponse; !test.ErrorEqual(gotErr, wantErr) {
227		t.Errorf("Final err: (%v), want: (%v)", gotErr, wantErr)
228	}
229}
230
231func TestCommitterExcessConfirmedOffsets(t *testing.T) {
232	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
233	ack := newAckConsumer(33, 0, nil)
234	acks := newAckTracker()
235	acks.Push(ack)
236
237	verifiers := test.NewVerifiers(t)
238	stream := test.NewRPCVerifier(t)
239	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
240	stream.Push(commitReq(34), commitResp(2), nil) // More confirmed offsets than committed
241	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
242
243	mockServer.OnTestStart(verifiers)
244	defer mockServer.OnTestEnd()
245
246	cmt := newTestCommitter(t, subscription, acks)
247	if gotErr := cmt.StartError(); gotErr != nil {
248		t.Errorf("Start() got err: (%v)", gotErr)
249	}
250
251	ack.Ack()
252	cmt.SendBatchCommit()
253
254	wantMsg := "server acknowledged 2 cursor commits"
255	if gotErr := cmt.FinalError(); !test.ErrorHasMsg(gotErr, wantMsg) {
256		t.Errorf("Final err: (%v), want msg: (%v)", gotErr, wantMsg)
257	}
258}
259
260func TestCommitterZeroConfirmedOffsets(t *testing.T) {
261	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-subs", 0}
262	ack := newAckConsumer(33, 0, nil)
263	acks := newAckTracker()
264	acks.Push(ack)
265
266	verifiers := test.NewVerifiers(t)
267	stream := test.NewRPCVerifier(t)
268	stream.Push(initCommitReq(subscription), initCommitResp(), nil)
269	stream.Push(commitReq(34), commitResp(0), nil) // Zero confirmed offsets (invalid)
270	verifiers.AddCommitStream(subscription.Path, subscription.Partition, stream)
271
272	mockServer.OnTestStart(verifiers)
273	defer mockServer.OnTestEnd()
274
275	cmt := newTestCommitter(t, subscription, acks)
276	if gotErr := cmt.StartError(); gotErr != nil {
277		t.Errorf("Start() got err: (%v)", gotErr)
278	}
279
280	ack.Ack()
281	cmt.SendBatchCommit()
282
283	wantMsg := "server acknowledged an invalid commit count"
284	if gotErr := cmt.FinalError(); !test.ErrorHasMsg(gotErr, wantMsg) {
285		t.Errorf("Final err: (%v), want msg: (%v)", gotErr, wantMsg)
286	}
287}
288