1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14package wire
15
16import (
17	"context"
18	"sort"
19	"sync"
20	"testing"
21	"time"
22
23	"cloud.google.com/go/internal/testutil"
24	"cloud.google.com/go/pubsublite/internal/test"
25	"github.com/google/go-cmp/cmp/cmpopts"
26	"google.golang.org/grpc/codes"
27	"google.golang.org/grpc/status"
28	"google.golang.org/protobuf/proto"
29
30	pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
31)
32
33func testSubscriberSettings() ReceiveSettings {
34	settings := testReceiveSettings()
35	settings.MaxOutstandingMessages = 10
36	settings.MaxOutstandingBytes = 1000
37	return settings
38}
39
40// initFlowControlReq returns the first expected flow control request when
41// testSubscriberSettings are used.
42func initFlowControlReq() *pb.SubscribeRequest {
43	return flowControlSubReq(flowControlTokens{Bytes: 1000, Messages: 10})
44}
45
46func partitionMsgs(partition int, msgs ...*pb.SequencedMessage) []*ReceivedMessage {
47	var received []*ReceivedMessage
48	for _, msg := range msgs {
49		received = append(received, &ReceivedMessage{Msg: msg, Partition: partition})
50	}
51	return received
52}
53
54func join(args ...[]*ReceivedMessage) []*ReceivedMessage {
55	var received []*ReceivedMessage
56	for _, msgs := range args {
57		received = append(received, msgs...)
58	}
59	return received
60}
61
62type testMessageReceiver struct {
63	t        *testing.T
64	received chan *ReceivedMessage
65}
66
67func newTestMessageReceiver(t *testing.T) *testMessageReceiver {
68	return &testMessageReceiver{
69		t:        t,
70		received: make(chan *ReceivedMessage, 5),
71	}
72}
73
74func (tr *testMessageReceiver) onMessage(msg *ReceivedMessage) {
75	tr.received <- msg
76}
77
78func (tr *testMessageReceiver) ValidateMsg(want *pb.SequencedMessage) AckConsumer {
79	select {
80	case <-time.After(serviceTestWaitTimeout):
81		tr.t.Errorf("Message (%v) not received within %v", want, serviceTestWaitTimeout)
82		return nil
83	case got := <-tr.received:
84		if !proto.Equal(got.Msg, want) {
85			tr.t.Errorf("Received message: got (%v), want (%v)", got.Msg, want)
86		}
87		return got.Ack
88	}
89}
90
91type ByMsgOffset []*ReceivedMessage
92
93func (m ByMsgOffset) Len() int      { return len(m) }
94func (m ByMsgOffset) Swap(i, j int) { m[i], m[j] = m[j], m[i] }
95func (m ByMsgOffset) Less(i, j int) bool {
96	return m[i].Msg.GetCursor().GetOffset() < m[j].Msg.GetCursor().GetOffset()
97}
98
99func (tr *testMessageReceiver) ValidateMsgs(want []*ReceivedMessage) {
100	var got []*ReceivedMessage
101	for count := 0; count < len(want); count++ {
102		select {
103		case <-time.After(serviceTestWaitTimeout):
104			tr.t.Errorf("Received messages count: got %d, want %d", count, len(want))
105		case received := <-tr.received:
106			received.Ack.Ack()
107			got = append(got, received)
108		}
109	}
110
111	sort.Sort(ByMsgOffset(want))
112	sort.Sort(ByMsgOffset(got))
113	if !testutil.Equal(got, want, cmpopts.IgnoreFields(ReceivedMessage{}, "Ack")) {
114		tr.t.Errorf("Received messages: got: %v\nwant: %v", got, want)
115	}
116}
117
118func (tr *testMessageReceiver) VerifyNoMsgs() {
119	select {
120	case got := <-tr.received:
121		tr.t.Errorf("Got unexpected message: %v", got.Msg)
122	case <-time.After(20 * time.Millisecond):
123		// Wait to ensure no messages received.
124	}
125}
126
127// testBlockingMessageReceiver can be used to simulate a client message receiver
128// func that is blocking due to slow message processing.
129type testBlockingMessageReceiver struct {
130	blockReceive chan struct{}
131
132	testMessageReceiver
133}
134
135func newTestBlockingMessageReceiver(t *testing.T) *testBlockingMessageReceiver {
136	return &testBlockingMessageReceiver{
137		testMessageReceiver: testMessageReceiver{
138			t:        t,
139			received: make(chan *ReceivedMessage, 5),
140		},
141		blockReceive: make(chan struct{}),
142	}
143}
144
145// onMessage is the message receiver func and blocks until there is a call to
146// Return().
147func (tr *testBlockingMessageReceiver) onMessage(msg *ReceivedMessage) {
148	tr.testMessageReceiver.onMessage(msg)
149	<-tr.blockReceive
150}
151
152// Return signals onMessage to return.
153func (tr *testBlockingMessageReceiver) Return() {
154	var void struct{}
155	tr.blockReceive <- void
156}
157
// TestMessageDeliveryQueueStartStop verifies the start/stop lifecycle of
// messageDeliveryQueue:
//   - messages added while the queue is stopped are never delivered;
//   - messages added while it is started are delivered in order;
//   - duplicate Start/Stop calls and a restart after Stop are handled.
//
// The subtests share one queue and receiver, so they must run in the order
// written. NOTE(review): receiver reports failures to the parent t, not to each
// subtest's t.
func TestMessageDeliveryQueueStartStop(t *testing.T) {
	acks := newAckTracker()
	receiver := newTestMessageReceiver(t)
	messageQueue := newMessageDeliveryQueue(acks, receiver.onMessage, 10)

	t.Run("Add before start", func(t *testing.T) {
		msg1 := seqMsgWithOffset(1)
		ack1 := newAckConsumer(1, 0, nil)
		messageQueue.Add(&ReceivedMessage{Msg: msg1, Ack: ack1})

		// The queue has not been started, so msg1 must not be delivered.
		receiver.VerifyNoMsgs()
	})

	t.Run("Add after start", func(t *testing.T) {
		msg2 := seqMsgWithOffset(2)
		ack2 := newAckConsumer(2, 0, nil)
		msg3 := seqMsgWithOffset(3)
		ack3 := newAckConsumer(3, 0, nil)

		messageQueue.Start()
		messageQueue.Start() // Check duplicate starts
		messageQueue.Add(&ReceivedMessage{Msg: msg2, Ack: ack2})
		messageQueue.Add(&ReceivedMessage{Msg: msg3, Ack: ack3})

		// msg1 from the previous subtest must not be delivered after Start:
		// the first message received is expected to be msg2.
		receiver.ValidateMsg(msg2)
		receiver.ValidateMsg(msg3)
	})

	t.Run("Add after stop", func(t *testing.T) {
		msg4 := seqMsgWithOffset(4)
		ack4 := newAckConsumer(4, 0, nil)

		messageQueue.Stop()
		messageQueue.Stop() // Check duplicate stop
		messageQueue.Add(&ReceivedMessage{Msg: msg4, Ack: ack4})
		// Wait for the queue to finish stopping before checking that msg4 was
		// not delivered.
		messageQueue.Wait()

		receiver.VerifyNoMsgs()
	})

	t.Run("Restart", func(t *testing.T) {
		msg5 := seqMsgWithOffset(5)
		ack5 := newAckConsumer(5, 0, nil)

		// A stopped queue can be started again and resumes delivery.
		messageQueue.Start()
		messageQueue.Add(&ReceivedMessage{Msg: msg5, Ack: ack5})

		receiver.ValidateMsg(msg5)
	})

	t.Run("Stop", func(t *testing.T) {
		messageQueue.Stop()
		messageQueue.Wait()

		receiver.VerifyNoMsgs()
	})
}
215
// TestMessageDeliveryQueueDiscardMessages verifies that stopping a
// messageDeliveryQueue discards messages that are still queued behind a
// receiver call in progress.
func TestMessageDeliveryQueueDiscardMessages(t *testing.T) {
	acks := newAckTracker()
	blockingReceiver := newTestBlockingMessageReceiver(t)
	messageQueue := newMessageDeliveryQueue(acks, blockingReceiver.onMessage, 10)

	msg1 := seqMsgWithOffset(1)
	ack1 := newAckConsumer(1, 0, nil)
	msg2 := seqMsgWithOffset(2)
	ack2 := newAckConsumer(2, 0, nil)

	messageQueue.Start()
	messageQueue.Add(&ReceivedMessage{Msg: msg1, Ack: ack1})
	messageQueue.Add(&ReceivedMessage{Msg: msg2, Ack: ack2})

	// The blocking receiver suspends after receiving msg1.
	blockingReceiver.ValidateMsg(msg1)
	// Stopping the message queue should discard undelivered msg2.
	messageQueue.Stop()

	// Unsuspend the blocking receiver and verify msg2 is not received.
	blockingReceiver.Return()
	messageQueue.Wait()
	blockingReceiver.VerifyNoMsgs()
	// NOTE(review): the expected count of 1 presumably reflects that only the
	// delivered msg1 was registered with the ackTracker, while the discarded
	// msg2 was not. Confirm against messageDeliveryQueue.
	if got, want := acks.outstandingAcks.Len(), 1; got != want {
		t.Errorf("ackTracker.outstandingAcks.Len() got %v, want %v", got, want)
	}
}
243
244// testSubscribeStream wraps a subscribeStream for ease of testing.
245type testSubscribeStream struct {
246	Receiver *testMessageReceiver
247	t        *testing.T
248	acks     *ackTracker
249	sub      *subscribeStream
250	mu       sync.Mutex
251	resetErr error
252	serviceTestProxy
253}
254
255func newTestSubscribeStream(t *testing.T, subscription subscriptionPartition, settings ReceiveSettings) *testSubscribeStream {
256	ctx := context.Background()
257	subClient, err := newSubscriberClient(ctx, "ignored", testServer.ClientConn())
258	if err != nil {
259		t.Fatal(err)
260	}
261
262	ts := &testSubscribeStream{
263		Receiver: newTestMessageReceiver(t),
264		t:        t,
265		acks:     newAckTracker(),
266	}
267	ts.sub = newSubscribeStream(ctx, subClient, settings, ts.Receiver.onMessage, subscription, ts.acks, ts.handleReset, true)
268	ts.initAndStart(t, ts.sub, "Subscriber", subClient)
269	return ts
270}
271
272// SendBatchFlowControl invokes the periodic background batch flow control. Note
273// that the periodic task is disabled in tests.
274func (ts *testSubscribeStream) SendBatchFlowControl() {
275	ts.sub.sendBatchFlowControl()
276}
277
278func (ts *testSubscribeStream) PendingFlowControlRequest() *pb.FlowControlRequest {
279	ts.sub.mu.Lock()
280	defer ts.sub.mu.Unlock()
281	return ts.sub.flowControl.pendingTokens.ToFlowControlRequest()
282}
283
284func (ts *testSubscribeStream) SetResetErr(err error) {
285	ts.mu.Lock()
286	defer ts.mu.Unlock()
287	ts.resetErr = err
288}
289
290func (ts *testSubscribeStream) handleReset() error {
291	ts.mu.Lock()
292	defer ts.mu.Unlock()
293	return ts.resetErr
294}
295
// TestSubscribeStreamReconnect verifies that a subscribeStream:
//   - reconnects after a retryable error;
//   - resumes from the offset after the last received message;
//   - re-sends flow control reduced by the tokens already consumed;
//   - terminates with the server's error on a permanent failure.
func TestSubscribeStreamReconnect(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}
	msg1 := seqMsgWithOffsetAndSize(67, 200)
	msg2 := seqMsgWithOffsetAndSize(68, 100)
	permanentErr := status.Error(codes.FailedPrecondition, "permanent failure")

	verifiers := test.NewVerifiers(t)

	stream1 := test.NewRPCVerifier(t)
	stream1.Push(initSubReqCommit(subscription), initSubResp(), nil)
	stream1.Push(initFlowControlReq(), msgSubResp(msg1), nil)
	// Unavailable is retryable, so the stream is expected to reconnect.
	stream1.Push(nil, nil, status.Error(codes.Unavailable, "server unavailable"))
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, stream1)

	// When reconnected, the subscribeStream should set initial cursor to msg2 and
	// have subtracted flow control tokens.
	stream2 := test.NewRPCVerifier(t)
	stream2.Push(initSubReqCursor(subscription, 68), initSubResp(), nil)
	// 1000-200 bytes and 10-1 messages remain after msg1 was received.
	stream2.Push(flowControlSubReq(flowControlTokens{Bytes: 800, Messages: 9}), msgSubResp(msg2), nil)
	// Subscriber should terminate on permanent error.
	stream2.Push(nil, nil, permanentErr)
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, stream2)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestSubscribeStream(t, subscription, testSubscriberSettings())
	if gotErr := sub.StartError(); gotErr != nil {
		t.Errorf("Start() got err: (%v)", gotErr)
	}
	sub.Receiver.ValidateMsg(msg1)
	sub.Receiver.ValidateMsg(msg2)
	if gotErr := sub.FinalError(); !test.ErrorEqual(gotErr, permanentErr) {
		t.Errorf("Final err: (%v), want: (%v)", gotErr, permanentErr)
	}
}
332
333func TestSubscribeStreamFlowControlBatching(t *testing.T) {
334	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}
335	msg1 := seqMsgWithOffsetAndSize(67, 200)
336	msg2 := seqMsgWithOffsetAndSize(68, 100)
337	serverErr := status.Error(codes.InvalidArgument, "verifies flow control received")
338
339	verifiers := test.NewVerifiers(t)
340	stream := test.NewRPCVerifier(t)
341	stream.Push(initSubReqCommit(subscription), initSubResp(), nil)
342	stream.Push(initFlowControlReq(), msgSubResp(msg1, msg2), nil)
343	// Batch flow control request expected.
344	stream.Push(flowControlSubReq(flowControlTokens{Bytes: 300, Messages: 2}), nil, serverErr)
345	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, stream)
346
347	mockServer.OnTestStart(verifiers)
348	defer mockServer.OnTestEnd()
349
350	sub := newTestSubscribeStream(t, subscription, testSubscriberSettings())
351	if gotErr := sub.StartError(); gotErr != nil {
352		t.Errorf("Start() got err: (%v)", gotErr)
353	}
354	sub.Receiver.ValidateMsg(msg1)
355	sub.Receiver.ValidateMsg(msg2)
356	sub.sub.onAck(&ackConsumer{MsgBytes: msg1.SizeBytes})
357	sub.sub.onAck(&ackConsumer{MsgBytes: msg2.SizeBytes})
358	sub.sub.sendBatchFlowControl()
359	if gotErr := sub.FinalError(); !test.ErrorEqual(gotErr, serverErr) {
360		t.Errorf("Final err: (%v), want: (%v)", gotErr, serverErr)
361	}
362}
363
// TestSubscribeStreamExpediteFlowControl verifies that once acked tokens exceed
// half of MaxOutstandingBytes, the flow control request is sent immediately
// without waiting for the batch flush.
func TestSubscribeStreamExpediteFlowControl(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}
	msg1 := seqMsgWithOffsetAndSize(67, 250)
	// MaxOutstandingBytes = 1000, so msg2 pushes the pending flow control bytes
	// over the expediteBatchRequestRatio=50% threshold in flowControlBatcher.
	msg2 := seqMsgWithOffsetAndSize(68, 251)
	serverErr := status.Error(codes.InvalidArgument, "verifies flow control received")

	verifiers := test.NewVerifiers(t)
	stream := test.NewRPCVerifier(t)
	stream.Push(initSubReqCommit(subscription), initSubResp(), nil)
	stream.Push(initFlowControlReq(), msgSubResp(msg1, msg2), nil)
	// Expedited flow control request expected. The test never flushes the
	// batch explicitly.
	stream.Push(flowControlSubReq(flowControlTokens{Bytes: 501, Messages: 2}), nil, serverErr)
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, stream)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestSubscribeStream(t, subscription, testSubscriberSettings())
	if gotErr := sub.StartError(); gotErr != nil {
		t.Errorf("Start() got err: (%v)", gotErr)
	}
	sub.Receiver.ValidateMsg(msg1)
	sub.Receiver.ValidateMsg(msg2)
	sub.sub.onAck(&ackConsumer{MsgBytes: msg1.SizeBytes})
	sub.sub.onAck(&ackConsumer{MsgBytes: msg2.SizeBytes})
	// Note: the ack for msg2 automatically triggers sending the flow control.
	if gotErr := sub.FinalError(); !test.ErrorEqual(gotErr, serverErr) {
		t.Errorf("Final err: (%v), want: (%v)", gotErr, serverErr)
	}
}
396
// TestSubscribeStreamDisableBatchFlowControl verifies that an ack arriving
// while the stream is reconnecting does not send flow control on the
// disconnected stream. The tokens are kept pending, and the reconnected stream
// then sends the full flow control allowance.
func TestSubscribeStreamDisableBatchFlowControl(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}
	// MaxOutstandingBytes = 1000, so this pushes the pending flow control bytes
	// over the expediteBatchRequestRatio=50% threshold in flowControlBatcher.
	msg := seqMsgWithOffsetAndSize(67, 800)
	retryableErr := status.Error(codes.Unavailable, "unavailable")
	serverErr := status.Error(codes.InvalidArgument, "verifies flow control received")

	verifiers := test.NewVerifiers(t)

	stream1 := test.NewRPCVerifier(t)
	stream1.Push(initSubReqCommit(subscription), initSubResp(), nil)
	stream1.Push(initFlowControlReq(), msgSubResp(msg), nil)
	// Break the stream immediately after sending the message.
	stream1.Push(nil, nil, retryableErr)
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, stream1)

	stream2 := test.NewRPCVerifier(t)
	// The barrier is used to pause in the middle of stream reconnection.
	barrier := stream2.PushWithBarrier(initSubReqCursor(subscription, 68), initSubResp(), nil)
	// Full flow control tokens should be sent after stream has connected.
	stream2.Push(initFlowControlReq(), nil, serverErr)
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, stream2)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestSubscribeStream(t, subscription, testSubscriberSettings())
	if gotErr := sub.StartError(); gotErr != nil {
		t.Errorf("Start() got err: (%v)", gotErr)
	}

	sub.Receiver.ValidateMsg(msg)
	barrier.ReleaseAfter(func() {
		// While the stream is not connected, the pending flow control request
		// should not be released and sent to the stream.
		sub.sub.onAck(&ackConsumer{MsgBytes: msg.SizeBytes})
		if sub.PendingFlowControlRequest() == nil {
			t.Errorf("Pending flow control request should not be cleared")
		}
	})

	if gotErr := sub.FinalError(); !test.ErrorEqual(gotErr, serverErr) {
		t.Errorf("Final err: (%v), want: (%v)", gotErr, serverErr)
	}
}
443
// TestSubscribeStreamInvalidInitialResponse verifies that startup fails with
// errInvalidInitialSubscribeResponse when the server's first response is not an
// initial subscribe response.
func TestSubscribeStreamInvalidInitialResponse(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}

	verifiers := test.NewVerifiers(t)
	stream := test.NewRPCVerifier(t)
	stream.Push(initSubReqCommit(subscription), seekResp(0), nil) // Seek instead of init response
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, stream)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestSubscribeStream(t, subscription, testSubscriberSettings())
	if gotErr, wantErr := sub.StartError(), errInvalidInitialSubscribeResponse; !test.ErrorEqual(gotErr, wantErr) {
		t.Errorf("Start got err: (%v), want: (%v)", gotErr, wantErr)
	}
}
460
// TestSubscribeStreamDuplicateInitialResponse verifies that a second initial
// response on an already-initialized stream terminates the stream with
// errInvalidSubscribeResponse.
func TestSubscribeStreamDuplicateInitialResponse(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}

	verifiers := test.NewVerifiers(t)
	stream := test.NewRPCVerifier(t)
	stream.Push(initSubReqCommit(subscription), initSubResp(), nil)
	stream.Push(initFlowControlReq(), initSubResp(), nil) // Second initial response
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, stream)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestSubscribeStream(t, subscription, testSubscriberSettings())
	if gotErr, wantErr := sub.FinalError(), errInvalidSubscribeResponse; !test.ErrorEqual(gotErr, wantErr) {
		t.Errorf("Final err: (%v), want: (%v)", gotErr, wantErr)
	}
}
478
// TestSubscribeStreamSpuriousSeekResponse verifies that a seek response with no
// corresponding seek request terminates the stream with
// errInvalidSubscribeResponse.
func TestSubscribeStreamSpuriousSeekResponse(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}

	verifiers := test.NewVerifiers(t)
	stream := test.NewRPCVerifier(t)
	stream.Push(initSubReqCommit(subscription), initSubResp(), nil)
	stream.Push(initFlowControlReq(), seekResp(1), nil) // Seek response with no seek request
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, stream)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestSubscribeStream(t, subscription, testSubscriberSettings())
	if gotErr, wantErr := sub.FinalError(), errInvalidSubscribeResponse; !test.ErrorEqual(gotErr, wantErr) {
		t.Errorf("Final err: (%v), want: (%v)", gotErr, wantErr)
	}
}
496
// TestSubscribeStreamNoMessages verifies that a message response containing no
// messages terminates the stream with errServerNoMessages.
func TestSubscribeStreamNoMessages(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}

	verifiers := test.NewVerifiers(t)
	stream := test.NewRPCVerifier(t)
	stream.Push(initSubReqCommit(subscription), initSubResp(), nil)
	stream.Push(initFlowControlReq(), msgSubResp(), nil) // No messages in response
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, stream)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestSubscribeStream(t, subscription, testSubscriberSettings())
	if gotErr, wantErr := sub.FinalError(), errServerNoMessages; !test.ErrorEqual(gotErr, wantErr) {
		t.Errorf("Final err: (%v), want: (%v)", gotErr, wantErr)
	}
}
514
// TestSubscribeStreamMessagesOutOfOrder verifies that the stream terminates
// with an error when the server sends a message whose offset precedes one
// already delivered.
func TestSubscribeStreamMessagesOutOfOrder(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}
	msg1 := seqMsgWithOffsetAndSize(56, 100)
	msg2 := seqMsgWithOffsetAndSize(55, 100) // Offset before msg1
	verifiers := test.NewVerifiers(t)
	stream := test.NewRPCVerifier(t)
	stream.Push(initSubReqCommit(subscription), initSubResp(), nil)
	stream.Push(initFlowControlReq(), msgSubResp(msg1), nil)
	stream.Push(nil, msgSubResp(msg2), nil)
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, stream)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestSubscribeStream(t, subscription, testSubscriberSettings())
	sub.Receiver.ValidateMsg(msg1)
	// After msg1 (offset 56), the next offset must be at least 57.
	if gotErr, msg := sub.FinalError(), "start offset = 55, expected >= 57"; !test.ErrorHasMsg(gotErr, msg) {
		t.Errorf("Final err: (%v), want msg: %q", gotErr, msg)
	}
}
536
// TestSubscribeStreamFlowControlOverflow verifies that the stream terminates
// with errTokenCounterBytesNegative when the server sends more bytes than the
// flow control allowance permits (900 + 101 > 1000).
func TestSubscribeStreamFlowControlOverflow(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}
	msg1 := seqMsgWithOffsetAndSize(56, 900)
	msg2 := seqMsgWithOffsetAndSize(57, 101) // Overflows ReceiveSettings.MaxOutstandingBytes = 1000

	verifiers := test.NewVerifiers(t)
	stream := test.NewRPCVerifier(t)
	stream.Push(initSubReqCommit(subscription), initSubResp(), nil)
	stream.Push(initFlowControlReq(), msgSubResp(msg1), nil)
	stream.Push(nil, msgSubResp(msg2), nil)
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, stream)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestSubscribeStream(t, subscription, testSubscriberSettings())
	sub.Receiver.ValidateMsg(msg1)
	if gotErr, wantErr := sub.FinalError(), errTokenCounterBytesNegative; !test.ErrorEqual(gotErr, wantErr) {
		t.Errorf("Final err: (%v), want: (%v)", gotErr, wantErr)
	}
}
558
// TestSubscribeStreamHandleResetError verifies that when the server sends the
// stream RESET signal and the reset handler returns an error, the stream does
// not reconnect. As asserted below, FinalError is expected to be nil in this
// case rather than the handler's error.
func TestSubscribeStreamHandleResetError(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}
	msg := seqMsgWithOffsetAndSize(67, 100)

	verifiers := test.NewVerifiers(t)
	stream := test.NewRPCVerifier(t)
	stream.Push(initSubReqCommit(subscription), initSubResp(), nil)
	stream.Push(initFlowControlReq(), msgSubResp(msg), nil)
	// The barrier holds back the RESET signal until the reset error is
	// configured and msg has been received.
	barrier := stream.PushWithBarrier(nil, nil, makeStreamResetSignal())
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, stream)
	// No reconnect expected because the reset handler func will fail.

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestSubscribeStream(t, subscription, testSubscriberSettings())
	sub.SetResetErr(status.Error(codes.FailedPrecondition, "reset handler failed"))
	if gotErr := sub.StartError(); gotErr != nil {
		t.Errorf("Start() got err: (%v)", gotErr)
	}
	sub.Receiver.ValidateMsg(msg)
	barrier.Release()
	if gotErr := sub.FinalError(); gotErr != nil {
		t.Errorf("Final err: (%v), want: <nil>", gotErr)
	}
}
585
586type testSinglePartitionSubscriber singlePartitionSubscriber
587
588func (t *testSinglePartitionSubscriber) WaitStopped() error {
589	err := t.compositeService.WaitStopped()
590	// Close connections.
591	t.committer.cursorClient.Close()
592	t.subscriber.subClient.Close()
593	return err
594}
595
596func newTestSinglePartitionSubscriber(t *testing.T, receiverFunc MessageReceiverFunc, subscription subscriptionPartition) *testSinglePartitionSubscriber {
597	ctx := context.Background()
598	subClient, err := newSubscriberClient(ctx, "ignored", testServer.ClientConn())
599	if err != nil {
600		t.Fatal(err)
601	}
602	cursorClient, err := newCursorClient(ctx, "ignored", testServer.ClientConn())
603	if err != nil {
604		t.Fatal(err)
605	}
606
607	f := &singlePartitionSubscriberFactory{
608		ctx:              ctx,
609		subClient:        subClient,
610		cursorClient:     cursorClient,
611		settings:         testSubscriberSettings(),
612		subscriptionPath: subscription.Path,
613		receiver:         receiverFunc,
614		disableTasks:     true, // Background tasks disabled to control event order
615	}
616	sub := f.New(subscription.Partition)
617	sub.Start()
618	return (*testSinglePartitionSubscriber)(sub)
619}
620
// TestSinglePartitionSubscriberStartStop verifies that a single partition
// subscriber starts and stops cleanly when no messages are ever received.
func TestSinglePartitionSubscriberStartStop(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}
	receiver := newTestMessageReceiver(t)

	verifiers := test.NewVerifiers(t)

	// Verifies the behavior of the subscribeStream and committer when they are
	// stopped before any messages are received.
	subStream := test.NewRPCVerifier(t)
	subStream.Push(initSubReqCommit(subscription), initSubResp(), nil)
	barrier := subStream.PushWithBarrier(initFlowControlReq(), nil, nil)
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, subStream)

	cmtStream := test.NewRPCVerifier(t)
	cmtStream.Push(initCommitReq(subscription), initCommitResp(), nil)
	verifiers.AddCommitStream(subscription.Path, subscription.Partition, cmtStream)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestSinglePartitionSubscriber(t, receiver.onMessage, subscription)
	if gotErr := sub.WaitStarted(); gotErr != nil {
		t.Errorf("Start() got err: (%v)", gotErr)
	}
	barrier.Release() // To ensure the test is deterministic (i.e. flow control req always received)
	sub.Stop()
	if gotErr := sub.WaitStopped(); gotErr != nil {
		t.Errorf("Stop() got err: (%v)", gotErr)
	}
}
651
// TestSinglePartitionSubscriberSimpleMsgAck verifies that acking all received
// messages results in a commit at the offset after the last message.
func TestSinglePartitionSubscriberSimpleMsgAck(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}
	receiver := newTestMessageReceiver(t)
	msg1 := seqMsgWithOffsetAndSize(22, 100)
	msg2 := seqMsgWithOffsetAndSize(23, 200)

	verifiers := test.NewVerifiers(t)

	subStream := test.NewRPCVerifier(t)
	subStream.Push(initSubReqCommit(subscription), initSubResp(), nil)
	subStream.Push(initFlowControlReq(), msgSubResp(msg1, msg2), nil)
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, subStream)

	cmtStream := test.NewRPCVerifier(t)
	cmtStream.Push(initCommitReq(subscription), initCommitResp(), nil)
	// Both msg1 (22) and msg2 (23) are acked, so the commit offset is 24.
	cmtStream.Push(commitReq(24), commitResp(1), nil)
	verifiers.AddCommitStream(subscription.Path, subscription.Partition, cmtStream)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestSinglePartitionSubscriber(t, receiver.onMessage, subscription)
	if gotErr := sub.WaitStarted(); gotErr != nil {
		t.Errorf("Start() got err: (%v)", gotErr)
	}
	receiver.ValidateMsg(msg1).Ack()
	receiver.ValidateMsg(msg2).Ack()
	sub.Stop()
	if gotErr := sub.WaitStopped(); gotErr != nil {
		t.Errorf("Stop() got err: (%v)", gotErr)
	}
}
684
// TestSinglePartitionSubscriberMessageQueue verifies two things with a blocking
// receiver. First, messages are delivered one at a time, each only after the
// previous receiver call returns. Second, delivery continues across a stream
// reconnect with correctly reduced flow control.
func TestSinglePartitionSubscriberMessageQueue(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}
	receiver := newTestBlockingMessageReceiver(t)
	msg1 := seqMsgWithOffsetAndSize(1, 100)
	msg2 := seqMsgWithOffsetAndSize(2, 100)
	msg3 := seqMsgWithOffsetAndSize(3, 100)
	retryableErr := status.Error(codes.Unavailable, "should retry")

	verifiers := test.NewVerifiers(t)

	subStream1 := test.NewRPCVerifier(t)
	subStream1.Push(initSubReqCommit(subscription), initSubResp(), nil)
	subStream1.Push(initFlowControlReq(), msgSubResp(msg1), nil)
	subStream1.Push(nil, msgSubResp(msg2), nil)
	subStream1.Push(nil, nil, retryableErr)
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, subStream1)

	// When reconnected, the subscribeStream should set initial cursor to msg3 and
	// have subtracted flow control tokens for msg1 and msg2.
	subStream2 := test.NewRPCVerifier(t)
	subStream2.Push(initSubReqCursor(subscription, 3), initSubResp(), nil)
	subStream2.Push(flowControlSubReq(flowControlTokens{Bytes: 800, Messages: 8}), msgSubResp(msg3), nil)
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, subStream2)

	cmtStream := test.NewRPCVerifier(t)
	cmtStream.Push(initCommitReq(subscription), initCommitResp(), nil)
	cmtStream.Push(commitReq(4), commitResp(1), nil)
	verifiers.AddCommitStream(subscription.Path, subscription.Partition, cmtStream)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestSinglePartitionSubscriber(t, receiver.onMessage, subscription)
	if gotErr := sub.WaitStarted(); gotErr != nil {
		t.Errorf("Start() got err: (%v)", gotErr)
	}

	// Verifies that messageDeliveryQueue delivers messages sequentially and waits
	// for the client message receiver func to return before delivering the next
	// message.
	var acks []AckConsumer
	for _, msg := range []*pb.SequencedMessage{msg1, msg2, msg3} {
		ack := receiver.ValidateMsg(msg)
		acks = append(acks, ack)
		// While the receiver is still blocked on msg, nothing else may arrive.
		receiver.VerifyNoMsgs()
		receiver.Return()
	}

	// Ack all messages so that the committer terminates.
	for _, ack := range acks {
		ack.Ack()
	}

	sub.Stop()
	if gotErr := sub.WaitStopped(); gotErr != nil {
		t.Errorf("Stop() got err: (%v)", gotErr)
	}
}
743
// TestSinglePartitionSubscriberStopDuringReceive verifies the behavior when the
// subscriber is stopped while the receiver func is still processing a message.
// Messages still queued behind it (msg2) must not be delivered, and the ack of
// the in-flight message is still committed.
func TestSinglePartitionSubscriberStopDuringReceive(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}
	receiver := newTestBlockingMessageReceiver(t)
	msg1 := seqMsgWithOffsetAndSize(1, 100)
	msg2 := seqMsgWithOffsetAndSize(2, 100)

	verifiers := test.NewVerifiers(t)

	subStream := test.NewRPCVerifier(t)
	subStream.Push(initSubReqCommit(subscription), initSubResp(), nil)
	subStream.Push(initFlowControlReq(), msgSubResp(msg1, msg2), nil)
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, subStream)

	cmtStream := test.NewRPCVerifier(t)
	cmtStream.Push(initCommitReq(subscription), initCommitResp(), nil)
	// Only msg1 (offset 1) is acked, so the commit offset is 2.
	cmtStream.Push(commitReq(2), commitResp(1), nil)
	verifiers.AddCommitStream(subscription.Path, subscription.Partition, cmtStream)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestSinglePartitionSubscriber(t, receiver.onMessage, subscription)
	if gotErr := sub.WaitStarted(); gotErr != nil {
		t.Errorf("Start() got err: (%v)", gotErr)
	}

	receiver.ValidateMsg(msg1).Ack()

	// Stop the subscriber before returning from the message receiver func.
	sub.Stop()
	receiver.Return()

	if gotErr := sub.WaitStopped(); gotErr != nil {
		t.Errorf("Stop() got err: (%v)", gotErr)
	}
	receiver.VerifyNoMsgs() // msg2 should not be received
}
781
// TestSinglePartitionSubscriberAdminSeekWhileConnected verifies the handling of
// a server RESET signal on a connected stream. The subscriber is expected to:
//   - reconnect from the committed cursor rather than its local position;
//   - reset its state so that an earlier offset (msg1) can be redelivered;
//   - commit that earlier offset again.
func TestSinglePartitionSubscriberAdminSeekWhileConnected(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}
	receiver := newTestMessageReceiver(t)
	msg1 := seqMsgWithOffsetAndSize(1, 100)
	msg2 := seqMsgWithOffsetAndSize(2, 100)
	msg3 := seqMsgWithOffsetAndSize(3, 100)

	verifiers := test.NewVerifiers(t)

	subStream1 := test.NewRPCVerifier(t)
	subStream1.Push(initSubReqCommit(subscription), initSubResp(), nil)
	subStream1.Push(initFlowControlReq(), msgSubResp(msg1, msg2, msg3), nil)
	// Server disconnects the stream with the RESET signal.
	barrier := subStream1.PushWithBarrier(nil, nil, makeStreamResetSignal())
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, subStream1)

	subStream2 := test.NewRPCVerifier(t)
	// Reconnected stream reads from commit cursor.
	subStream2.Push(initSubReqCommit(subscription), initSubResp(), nil)
	// Ensure that the subscriber resets state and can handle seeking back to
	// msg1.
	subStream2.Push(initFlowControlReq(), msgSubResp(msg1), nil)
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, subStream2)

	cmtStream := test.NewRPCVerifier(t)
	cmtStream.Push(initCommitReq(subscription), initCommitResp(), nil)
	// Commit after msg1..msg3 are acked on the first stream.
	cmtStream.Push(commitReq(4), commitResp(1), nil)
	// Commit after the redelivered msg1 is acked on the second stream.
	cmtStream.Push(commitReq(2), commitResp(1), nil)
	verifiers.AddCommitStream(subscription.Path, subscription.Partition, cmtStream)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestSinglePartitionSubscriber(t, receiver.onMessage, subscription)
	if gotErr := sub.WaitStarted(); gotErr != nil {
		t.Errorf("Start() got err: (%v)", gotErr)
	}

	receiver.ValidateMsg(msg1).Ack()
	receiver.ValidateMsg(msg2).Ack()
	receiver.ValidateMsg(msg3).Ack()
	barrier.Release()
	receiver.ValidateMsg(msg1).Ack()

	sub.Stop()
	if gotErr := sub.WaitStopped(); gotErr != nil {
		t.Errorf("Stop() got err: (%v)", gotErr)
	}
}
831
// TestSinglePartitionSubscriberAdminSeekWhileReconnecting verifies the
// subscriber's handling of a stream reset signal sent during stream
// initialization.
//
// Sequence:
//   - The first subscribe stream breaks with a retryable DeadlineExceeded error.
//   - The subscriber reconnects with a cursor-based init request (offset 4) and
//     receives the reset signal in response.
//   - The subscriber reconnects again from the commit cursor and accepts msg1
//     again.
//
// The ack for msg3, outstanding at the time of the reset, must be ignored.
func TestSinglePartitionSubscriberAdminSeekWhileReconnecting(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}
	receiver := newTestMessageReceiver(t)
	msg1 := seqMsgWithOffsetAndSize(1, 100)
	msg2 := seqMsgWithOffsetAndSize(2, 100)
	msg3 := seqMsgWithOffsetAndSize(3, 100)

	verifiers := test.NewVerifiers(t)

	subStream1 := test.NewRPCVerifier(t)
	subStream1.Push(initSubReqCommit(subscription), initSubResp(), nil)
	subStream1.Push(initFlowControlReq(), msgSubResp(msg1, msg2, msg3), nil)
	// Normal stream breakage.
	barrier := subStream1.PushWithBarrier(nil, nil, status.Error(codes.DeadlineExceeded, ""))
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, subStream1)

	subStream2 := test.NewRPCVerifier(t)
	// The server sends the RESET signal during stream initialization.
	subStream2.Push(initSubReqCursor(subscription, 4), nil, makeStreamResetSignal())
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, subStream2)

	subStream3 := test.NewRPCVerifier(t)
	// Reconnected stream reads from commit cursor.
	subStream3.Push(initSubReqCommit(subscription), initSubResp(), nil)
	// Ensure that the subscriber resets state and can handle seeking back to
	// msg1.
	subStream3.Push(initFlowControlReq(), msgSubResp(msg1), nil)
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, subStream3)

	// Expected commits, in order: offset 3, then offset 2. No commit is expected
	// for msg3, whose ack is issued only after the reset.
	cmtStream := test.NewRPCVerifier(t)
	cmtStream.Push(initCommitReq(subscription), initCommitResp(), nil)
	cmtStream.Push(commitReq(3), commitResp(1), nil)
	cmtStream.Push(commitReq(2), commitResp(1), nil)
	verifiers.AddCommitStream(subscription.Path, subscription.Partition, cmtStream)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestSinglePartitionSubscriber(t, receiver.onMessage, subscription)
	if gotErr := sub.WaitStarted(); gotErr != nil {
		t.Errorf("Start() got err: (%v)", gotErr)
	}

	receiver.ValidateMsg(msg1).Ack()
	receiver.ValidateMsg(msg2).Ack()
	ack := receiver.ValidateMsg(msg3) // Unacked message discarded
	barrier.Release()
	receiver.ValidateMsg(msg1).Ack()
	ack.Ack() // Should be ignored

	sub.Stop()
	if gotErr := sub.WaitStopped(); gotErr != nil {
		t.Errorf("Stop() got err: (%v)", gotErr)
	}
}
887
// TestSinglePartitionSubscriberStopDuringAdminSeek verifies that Stop can be
// called while the subscriber is handling a stream reset signal. It also checks
// that the subscribe stream is not reconnected afterwards: only one subscribe
// stream is registered with the mock server.
func TestSinglePartitionSubscriberStopDuringAdminSeek(t *testing.T) {
	subscription := subscriptionPartition{"projects/123456/locations/us-central1-b/subscriptions/my-sub", 0}
	receiver := newTestMessageReceiver(t)
	msg1 := seqMsgWithOffsetAndSize(1, 100)
	msg2 := seqMsgWithOffsetAndSize(2, 100)

	verifiers := test.NewVerifiers(t)

	subStream := test.NewRPCVerifier(t)
	subStream.Push(initSubReqCommit(subscription), initSubResp(), nil)
	subStream.Push(initFlowControlReq(), msgSubResp(msg1, msg2), nil)
	// Server disconnects the stream with the RESET signal.
	subBarrier := subStream.PushWithBarrier(nil, nil, makeStreamResetSignal())
	verifiers.AddSubscribeStream(subscription.Path, subscription.Partition, subStream)

	cmtStream := test.NewRPCVerifier(t)
	cmtStream.Push(initCommitReq(subscription), initCommitResp(), nil)
	// The commit response is gated by a barrier so that Stop can be invoked
	// while the reset is still in progress.
	cmtBarrier := cmtStream.PushWithBarrier(commitReq(3), commitResp(1), nil)
	verifiers.AddCommitStream(subscription.Path, subscription.Partition, cmtStream)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestSinglePartitionSubscriber(t, receiver.onMessage, subscription)
	if gotErr := sub.WaitStarted(); gotErr != nil {
		t.Errorf("Start() got err: (%v)", gotErr)
	}

	receiver.ValidateMsg(msg1).Ack()
	receiver.ValidateMsg(msg2).Ack()
	subBarrier.Release()

	// Ensure that the user is able to call Stop while a reset is in progress.
	// Verifies that the subscribeStream is not holding mutexes while waiting and
	// that the subscribe stream is not reconnected.
	cmtBarrier.ReleaseAfter(func() {
		sub.Stop()
	})

	if gotErr := sub.WaitStopped(); gotErr != nil {
		t.Errorf("Stop() got err: (%v)", gotErr)
	}
}
931
932func newTestMultiPartitionSubscriber(t *testing.T, receiverFunc MessageReceiverFunc, subscriptionPath string, partitions []int) *multiPartitionSubscriber {
933	ctx := context.Background()
934	subClient, err := newSubscriberClient(ctx, "ignored", testServer.ClientConn())
935	if err != nil {
936		t.Fatal(err)
937	}
938	cursorClient, err := newCursorClient(ctx, "ignored", testServer.ClientConn())
939	if err != nil {
940		t.Fatal(err)
941	}
942	allClients := apiClients{subClient, cursorClient}
943
944	f := &singlePartitionSubscriberFactory{
945		ctx:              ctx,
946		subClient:        subClient,
947		cursorClient:     cursorClient,
948		settings:         testSubscriberSettings(),
949		subscriptionPath: subscriptionPath,
950		receiver:         receiverFunc,
951		disableTasks:     true, // Background tasks disabled to control event order
952	}
953	f.settings.Partitions = partitions
954	sub := newMultiPartitionSubscriber(allClients, f)
955	sub.Start()
956	return sub
957}
958
// TestMultiPartitionSubscriberMultipleMessages verifies message delivery from a
// multiPartitionSubscriber assigned partitions 1 and 2. Messages from both
// partitions must reach the receiver, and each partition must commit to its own
// commit stream.
func TestMultiPartitionSubscriberMultipleMessages(t *testing.T) {
	const subscription = "projects/123456/locations/us-central1-b/subscriptions/my-sub"
	receiver := newTestMessageReceiver(t)
	msg1 := seqMsgWithOffsetAndSize(22, 100)
	msg2 := seqMsgWithOffsetAndSize(23, 200)
	msg3 := seqMsgWithOffsetAndSize(44, 100)
	msg4 := seqMsgWithOffsetAndSize(45, 200)

	verifiers := test.NewVerifiers(t)

	// Partition 1
	subStream1 := test.NewRPCVerifier(t)
	subStream1.Push(initSubReqCommit(subscriptionPartition{Path: subscription, Partition: 1}), initSubResp(), nil)
	subStream1.Push(initFlowControlReq(), msgSubResp(msg1), nil)
	subStream1.Push(nil, msgSubResp(msg2), nil)
	verifiers.AddSubscribeStream(subscription, 1, subStream1)

	cmtStream1 := test.NewRPCVerifier(t)
	cmtStream1.Push(initCommitReq(subscriptionPartition{Path: subscription, Partition: 1}), initCommitResp(), nil)
	cmtStream1.Push(commitReq(24), commitResp(1), nil)
	verifiers.AddCommitStream(subscription, 1, cmtStream1)

	// Partition 2
	subStream2 := test.NewRPCVerifier(t)
	subStream2.Push(initSubReqCommit(subscriptionPartition{Path: subscription, Partition: 2}), initSubResp(), nil)
	subStream2.Push(initFlowControlReq(), msgSubResp(msg3), nil)
	subStream2.Push(nil, msgSubResp(msg4), nil)
	verifiers.AddSubscribeStream(subscription, 2, subStream2)

	cmtStream2 := test.NewRPCVerifier(t)
	cmtStream2.Push(initCommitReq(subscriptionPartition{Path: subscription, Partition: 2}), initCommitResp(), nil)
	cmtStream2.Push(commitReq(46), commitResp(1), nil)
	verifiers.AddCommitStream(subscription, 2, cmtStream2)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestMultiPartitionSubscriber(t, receiver.onMessage, subscription, []int{1, 2})
	if gotErr := sub.WaitStarted(); gotErr != nil {
		t.Errorf("Start() got err: (%v)", gotErr)
	}
	// NOTE(review): ValidateMsgs is defined elsewhere in this file. Presumably it
	// tolerates interleaving across partitions; confirm.
	receiver.ValidateMsgs(join(partitionMsgs(1, msg1, msg2), partitionMsgs(2, msg3, msg4)))
	sub.Stop()
	if gotErr := sub.WaitStopped(); gotErr != nil {
		t.Errorf("Stop() got err: (%v)", gotErr)
	}
}
1006
// TestMultiPartitionSubscriberPermanentError verifies how a permanent error on
// one partition affects the whole multiPartitionSubscriber. A
// FailedPrecondition error on partition 2's subscribe stream must terminate the
// subscriber with that same error. After termination, no further messages are
// delivered from the other partition (msg2 on partition 1).
func TestMultiPartitionSubscriberPermanentError(t *testing.T) {
	const subscription = "projects/123456/locations/us-central1-b/subscriptions/my-sub"
	receiver := newTestMessageReceiver(t)
	msg1 := seqMsgWithOffsetAndSize(22, 100)
	msg2 := seqMsgWithOffsetAndSize(23, 200)
	msg3 := seqMsgWithOffsetAndSize(44, 100)
	serverErr := status.Error(codes.FailedPrecondition, "failed")

	verifiers := test.NewVerifiers(t)

	// Partition 1
	subStream1 := test.NewRPCVerifier(t)
	subStream1.Push(initSubReqCommit(subscriptionPartition{Path: subscription, Partition: 1}), initSubResp(), nil)
	subStream1.Push(initFlowControlReq(), msgSubResp(msg1), nil)
	// msg2 is held back until after the subscriber has terminated.
	msg2Barrier := subStream1.PushWithBarrier(nil, msgSubResp(msg2), nil)
	verifiers.AddSubscribeStream(subscription, 1, subStream1)

	cmtStream1 := test.NewRPCVerifier(t)
	cmtStream1.Push(initCommitReq(subscriptionPartition{Path: subscription, Partition: 1}), initCommitResp(), nil)
	cmtStream1.Push(commitReq(23), commitResp(1), nil)
	verifiers.AddCommitStream(subscription, 1, cmtStream1)

	// Partition 2
	subStream2 := test.NewRPCVerifier(t)
	subStream2.Push(initSubReqCommit(subscriptionPartition{Path: subscription, Partition: 2}), initSubResp(), nil)
	subStream2.Push(initFlowControlReq(), msgSubResp(msg3), nil)
	errorBarrier := subStream2.PushWithBarrier(nil, nil, serverErr)
	verifiers.AddSubscribeStream(subscription, 2, subStream2)

	cmtStream2 := test.NewRPCVerifier(t)
	cmtStream2.Push(initCommitReq(subscriptionPartition{Path: subscription, Partition: 2}), initCommitResp(), nil)
	cmtStream2.Push(commitReq(45), commitResp(1), nil)
	verifiers.AddCommitStream(subscription, 2, cmtStream2)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestMultiPartitionSubscriber(t, receiver.onMessage, subscription, []int{1, 2})
	if gotErr := sub.WaitStarted(); gotErr != nil {
		t.Errorf("Start() got err: (%v)", gotErr)
	}
	receiver.ValidateMsgs(join(partitionMsgs(1, msg1), partitionMsgs(2, msg3)))
	errorBarrier.Release() // Release server error now to ensure test is deterministic
	if gotErr := sub.WaitStopped(); !test.ErrorEqual(gotErr, serverErr) {
		t.Errorf("Final error got: (%v), want: (%v)", gotErr, serverErr)
	}

	// Verify msg2 never received as subscriber has terminated.
	msg2Barrier.Release()
	receiver.VerifyNoMsgs()
}
1058
1059func (as *assigningSubscriber) Partitions() []int {
1060	as.mu.Lock()
1061	defer as.mu.Unlock()
1062
1063	var partitions []int
1064	for p := range as.subscribers {
1065		partitions = append(partitions, p)
1066	}
1067	sort.Ints(partitions)
1068	return partitions
1069}
1070
1071func (as *assigningSubscriber) Subscribers() []*singlePartitionSubscriber {
1072	as.mu.Lock()
1073	defer as.mu.Unlock()
1074
1075	var subscribers []*singlePartitionSubscriber
1076	for _, s := range as.subscribers {
1077		subscribers = append(subscribers, s)
1078	}
1079	return subscribers
1080}
1081
1082func (as *assigningSubscriber) FlushCommits() {
1083	as.mu.Lock()
1084	defer as.mu.Unlock()
1085
1086	for _, sub := range as.subscribers {
1087		sub.committer.commitOffsetToStream()
1088	}
1089}
1090
1091func newTestAssigningSubscriber(t *testing.T, receiverFunc MessageReceiverFunc, subscriptionPath string) *assigningSubscriber {
1092	ctx := context.Background()
1093	subClient, err := newSubscriberClient(ctx, "ignored", testServer.ClientConn())
1094	if err != nil {
1095		t.Fatal(err)
1096	}
1097	cursorClient, err := newCursorClient(ctx, "ignored", testServer.ClientConn())
1098	if err != nil {
1099		t.Fatal(err)
1100	}
1101	assignmentClient, err := newPartitionAssignmentClient(ctx, "ignored", testServer.ClientConn())
1102	if err != nil {
1103		t.Fatal(err)
1104	}
1105	allClients := apiClients{subClient, cursorClient, assignmentClient}
1106
1107	f := &singlePartitionSubscriberFactory{
1108		ctx:              ctx,
1109		subClient:        subClient,
1110		cursorClient:     cursorClient,
1111		settings:         testSubscriberSettings(),
1112		subscriptionPath: subscriptionPath,
1113		receiver:         receiverFunc,
1114		disableTasks:     true, // Background tasks disabled to control event order
1115	}
1116	sub, err := newAssigningSubscriber(allClients, assignmentClient, fakeGenerateUUID, f)
1117	if err != nil {
1118		t.Fatal(err)
1119	}
1120	sub.Start()
1121	return sub
1122}
1123
// TestAssigningSubscriberAddRemovePartitions verifies how the assigningSubscriber
// handles a partition assignment change from {3, 6} to {3, 8}:
//   - A subscriber for the newly assigned partition 8 is started.
//   - Partition 3 keeps delivering messages (msg2).
//   - Messages from the removed partition 6 (msg4) are not delivered.
func TestAssigningSubscriberAddRemovePartitions(t *testing.T) {
	const subscription = "projects/123456/locations/us-central1-b/subscriptions/my-sub"
	receiver := newTestMessageReceiver(t)
	msg1 := seqMsgWithOffsetAndSize(33, 100)
	msg2 := seqMsgWithOffsetAndSize(34, 200)
	msg3 := seqMsgWithOffsetAndSize(66, 100)
	msg4 := seqMsgWithOffsetAndSize(67, 100)
	msg5 := seqMsgWithOffsetAndSize(88, 100)

	verifiers := test.NewVerifiers(t)

	// Assignment stream
	// Initial assignment {3, 6}; the second assignment {3, 8} is held back by
	// assignmentBarrier until the test releases it.
	asnStream := test.NewRPCVerifier(t)
	asnStream.Push(initAssignmentReq(subscription, fakeUUID[:]), assignmentResp([]int64{3, 6}), nil)
	assignmentBarrier := asnStream.PushWithBarrier(assignmentAckReq(), assignmentResp([]int64{3, 8}), nil)
	asnStream.Push(assignmentAckReq(), nil, nil)
	verifiers.AddAssignmentStream(subscription, asnStream)

	// Partition 3
	subStream3 := test.NewRPCVerifier(t)
	subStream3.Push(initSubReqCommit(subscriptionPartition{Path: subscription, Partition: 3}), initSubResp(), nil)
	subStream3.Push(initFlowControlReq(), msgSubResp(msg1), nil)
	msg2Barrier := subStream3.PushWithBarrier(nil, msgSubResp(msg2), nil)
	verifiers.AddSubscribeStream(subscription, 3, subStream3)

	// Two commits are expected for partition 3: one from FlushCommits and one
	// at Stop.
	cmtStream3 := test.NewRPCVerifier(t)
	cmtStream3.Push(initCommitReq(subscriptionPartition{Path: subscription, Partition: 3}), initCommitResp(), nil)
	cmtStream3.Push(commitReq(34), commitResp(1), nil)
	cmtStream3.Push(commitReq(35), commitResp(1), nil)
	verifiers.AddCommitStream(subscription, 3, cmtStream3)

	// Partition 6
	subStream6 := test.NewRPCVerifier(t)
	subStream6.Push(initSubReqCommit(subscriptionPartition{Path: subscription, Partition: 6}), initSubResp(), nil)
	subStream6.Push(initFlowControlReq(), msgSubResp(msg3), nil)
	// msg4 should not be received.
	msg4Barrier := subStream6.PushWithBarrier(nil, msgSubResp(msg4), nil)
	verifiers.AddSubscribeStream(subscription, 6, subStream6)

	cmtStream6 := test.NewRPCVerifier(t)
	cmtStream6.Push(initCommitReq(subscriptionPartition{Path: subscription, Partition: 6}), initCommitResp(), nil)
	cmtStream6.Push(commitReq(67), commitResp(1), nil)
	verifiers.AddCommitStream(subscription, 6, cmtStream6)

	// Partition 8
	subStream8 := test.NewRPCVerifier(t)
	subStream8.Push(initSubReqCommit(subscriptionPartition{Path: subscription, Partition: 8}), initSubResp(), nil)
	subStream8.Push(initFlowControlReq(), msgSubResp(msg5), nil)
	verifiers.AddSubscribeStream(subscription, 8, subStream8)

	cmtStream8 := test.NewRPCVerifier(t)
	cmtStream8.Push(initCommitReq(subscriptionPartition{Path: subscription, Partition: 8}), initCommitResp(), nil)
	cmtStream8.Push(commitReq(89), commitResp(1), nil)
	verifiers.AddCommitStream(subscription, 8, cmtStream8)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestAssigningSubscriber(t, receiver.onMessage, subscription)
	if gotErr := sub.WaitStarted(); gotErr != nil {
		t.Errorf("Start() got err: (%v)", gotErr)
	}

	// Partition assignments are initially {3, 6}.
	receiver.ValidateMsgs(join(partitionMsgs(3, msg1), partitionMsgs(6, msg3)))
	if got, want := sub.Partitions(), []int{3, 6}; !testutil.Equal(got, want) {
		t.Errorf("subscriber partitions: got %d, want %d", got, want)
	}

	// Partition assignments will now be {3, 8}.
	assignmentBarrier.Release()
	receiver.ValidateMsgs(partitionMsgs(8, msg5))
	if got, want := sub.Partitions(), []int{3, 8}; !testutil.Equal(got, want) {
		t.Errorf("subscriber partitions: got %d, want %d", got, want)
	}

	// msg2 is from partition 3 and should be received. msg4 is from partition 6
	// (removed) and should be discarded.
	sub.FlushCommits()
	msg2Barrier.Release()
	msg4Barrier.Release()
	receiver.ValidateMsgs(partitionMsgs(3, msg2))

	// Stop should flush all commit cursors.
	sub.Stop()
	if gotErr := sub.WaitStopped(); gotErr != nil {
		t.Errorf("Stop() got err: (%v)", gotErr)
	}
}
1213
// TestAssigningSubscriberPermanentError verifies how a permanent error on the
// partition assignment stream affects the assigningSubscriber. A
// FailedPrecondition error must terminate the subscriber with that same error.
// The per-partition commit streams still expect their final commits.
func TestAssigningSubscriberPermanentError(t *testing.T) {
	const subscription = "projects/123456/locations/us-central1-b/subscriptions/my-sub"
	receiver := newTestMessageReceiver(t)
	msg1 := seqMsgWithOffsetAndSize(11, 100)
	msg2 := seqMsgWithOffsetAndSize(22, 200)
	serverErr := status.Error(codes.FailedPrecondition, "failed")

	verifiers := test.NewVerifiers(t)

	// Assignment stream
	// The assignment ack is answered with serverErr once errBarrier is released.
	asnStream := test.NewRPCVerifier(t)
	asnStream.Push(initAssignmentReq(subscription, fakeUUID[:]), assignmentResp([]int64{1, 2}), nil)
	errBarrier := asnStream.PushWithBarrier(assignmentAckReq(), nil, serverErr)
	verifiers.AddAssignmentStream(subscription, asnStream)

	// Partition 1
	subStream1 := test.NewRPCVerifier(t)
	subStream1.Push(initSubReqCommit(subscriptionPartition{Path: subscription, Partition: 1}), initSubResp(), nil)
	subStream1.Push(initFlowControlReq(), msgSubResp(msg1), nil)
	verifiers.AddSubscribeStream(subscription, 1, subStream1)

	cmtStream1 := test.NewRPCVerifier(t)
	cmtStream1.Push(initCommitReq(subscriptionPartition{Path: subscription, Partition: 1}), initCommitResp(), nil)
	cmtStream1.Push(commitReq(12), commitResp(1), nil)
	verifiers.AddCommitStream(subscription, 1, cmtStream1)

	// Partition 2
	subStream2 := test.NewRPCVerifier(t)
	subStream2.Push(initSubReqCommit(subscriptionPartition{Path: subscription, Partition: 2}), initSubResp(), nil)
	subStream2.Push(initFlowControlReq(), msgSubResp(msg2), nil)
	verifiers.AddSubscribeStream(subscription, 2, subStream2)

	cmtStream2 := test.NewRPCVerifier(t)
	cmtStream2.Push(initCommitReq(subscriptionPartition{Path: subscription, Partition: 2}), initCommitResp(), nil)
	cmtStream2.Push(commitReq(23), commitResp(1), nil)
	verifiers.AddCommitStream(subscription, 2, cmtStream2)

	mockServer.OnTestStart(verifiers)
	defer mockServer.OnTestEnd()

	sub := newTestAssigningSubscriber(t, receiver.onMessage, subscription)
	if gotErr := sub.WaitStarted(); gotErr != nil {
		t.Errorf("Start() got err: (%v)", gotErr)
	}
	receiver.ValidateMsgs(join(partitionMsgs(1, msg1), partitionMsgs(2, msg2)))

	// Permanent assignment stream error should terminate subscriber. Commits are
	// still flushed.
	errBarrier.Release()
	if gotErr := sub.WaitStopped(); !test.ErrorEqual(gotErr, serverErr) {
		t.Errorf("Final error got: (%v), want: (%v)", gotErr, serverErr)
	}
}
1267
1268func TestAssigningSubscriberIgnoreOutstandingAcks(t *testing.T) {
1269	const subscription = "projects/123456/locations/us-central1-b/subscriptions/my-sub"
1270	receiver := newTestMessageReceiver(t)
1271	msg1 := seqMsgWithOffsetAndSize(11, 100)
1272	msg2 := seqMsgWithOffsetAndSize(22, 200)
1273
1274	verifiers := test.NewVerifiers(t)
1275
1276	// Assignment stream
1277	asnStream := test.NewRPCVerifier(t)
1278	asnStream.Push(initAssignmentReq(subscription, fakeUUID[:]), assignmentResp([]int64{1}), nil)
1279	assignmentBarrier1 := asnStream.PushWithBarrier(assignmentAckReq(), assignmentResp([]int64{}), nil)
1280	assignmentBarrier2 := asnStream.PushWithBarrier(assignmentAckReq(), nil, nil)
1281	verifiers.AddAssignmentStream(subscription, asnStream)
1282
1283	// Partition 1
1284	subStream := test.NewRPCVerifier(t)
1285	subStream.Push(initSubReqCommit(subscriptionPartition{Path: subscription, Partition: 1}), initSubResp(), nil)
1286	subStream.Push(initFlowControlReq(), msgSubResp(msg1, msg2), nil)
1287	verifiers.AddSubscribeStream(subscription, 1, subStream)
1288
1289	cmtStream := test.NewRPCVerifier(t)
1290	cmtStream.Push(initCommitReq(subscriptionPartition{Path: subscription, Partition: 1}), initCommitResp(), nil)
1291	cmtStream.Push(commitReq(12), commitResp(1), nil)
1292	verifiers.AddCommitStream(subscription, 1, cmtStream)
1293
1294	mockServer.OnTestStart(verifiers)
1295	defer mockServer.OnTestEnd()
1296
1297	sub := newTestAssigningSubscriber(t, receiver.onMessage, subscription)
1298	if gotErr := sub.WaitStarted(); gotErr != nil {
1299		t.Errorf("Start() got err: (%v)", gotErr)
1300	}
1301
1302	// Partition assignments are initially {1}.
1303	receiver.ValidateMsg(msg1).Ack()
1304	ack2 := receiver.ValidateMsg(msg2)
1305	subscribers := sub.Subscribers()
1306
1307	// Partition assignments will now be {}.
1308	assignmentBarrier1.Release()
1309	assignmentBarrier2.ReleaseAfter(func() {
1310		// Verify that the assignment is acked after the subscriber has terminated.
1311		if got, want := len(subscribers), 1; got != want {
1312			t.Errorf("singlePartitionSubcriber count: got %d, want %d", got, want)
1313			return
1314		}
1315		if got, want := subscribers[0].Status(), serviceTerminated; got != want {
1316			t.Errorf("singlePartitionSubcriber status: got %v, want %v", got, want)
1317		}
1318	})
1319
1320	// Partition 1 has already been unassigned, so this ack is discarded.
1321	ack2.Ack()
1322
1323	sub.Stop()
1324	if gotErr := sub.WaitStopped(); gotErr != nil {
1325		t.Errorf("Stop() got err: (%v)", gotErr)
1326	}
1327}
1328
1329func TestNewSubscriberValidatesSettings(t *testing.T) {
1330	const subscription = "projects/123456/locations/us-central1-b/subscriptions/my-sub"
1331	const region = "us-central1"
1332	receiver := newTestMessageReceiver(t)
1333
1334	settings := DefaultReceiveSettings
1335	settings.MaxOutstandingMessages = 0
1336	if _, err := NewSubscriber(context.Background(), settings, receiver.onMessage, region, subscription); err == nil {
1337		t.Error("NewSubscriber() did not return error")
1338	}
1339}
1340