1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14package wire
15
16import (
17	"context"
18	"errors"
19	"reflect"
20	"sync"
21	"time"
22
23	"github.com/google/uuid"
24	"google.golang.org/api/option"
25	"google.golang.org/grpc"
26
27	vkit "cloud.google.com/go/pubsublite/apiv1"
28	pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
29)
30
// Protocol violations detected on the subscribe stream.
var (
	// errInvalidInitialSubscribeResponse is reported when the first response on
	// a newly opened subscribe stream is not an initial response.
	errInvalidInitialSubscribeResponse = errors.New("pubsublite: first response from server was not an initial response for subscribe")

	// errInvalidSubscribeResponse is reported when a subsequent response is not
	// a message response.
	errInvalidSubscribeResponse = errors.New("pubsublite: received unexpected subscribe response from server")

	// errServerNoMessages is reported when a message response carries an empty
	// batch.
	errServerNoMessages = errors.New("pubsublite: server delivered no messages")
)
36
// ReceivedMessage stores a received Pub/Sub message and AckConsumer for
// acknowledging the message.
type ReceivedMessage struct {
	// Msg is the message as delivered by the server; its cursor offset and
	// size are used to build Ack.
	Msg       *pb.SequencedMessage
	// Ack is used by the client to acknowledge Msg. Within this package it is
	// always created by newAckConsumer (deliverMessages type-asserts it to
	// *ackConsumer).
	Ack       AckConsumer
	// Partition is the topic partition the message was received from.
	Partition int
}
44
// MessageReceiverFunc receives a Pub/Sub message from a topic partition.
//
// Within one partition, messages are passed to the function sequentially by a
// single delivery goroutine (see messageDeliveryQueue). The same function is
// shared by every partition of a subscriber, so it may be invoked
// concurrently for different partitions.
type MessageReceiverFunc func(*ReceivedMessage)
47
// messageDeliveryQueue delivers received messages to the client-provided
// MessageReceiverFunc sequentially. It is only accessed by the subscribeStream.
//
// It is not internally synchronized: callers (subscribeStream) serialize
// Start, Stop and Add under their own mutex.
type messageDeliveryQueue struct {
	// Capacity of messagesC; set from ReceiveSettings.MaxOutstandingMessages.
	bufferSize int
	// Outstanding acks are registered here just before each delivery.
	acks       *ackTracker
	receiver   MessageReceiverFunc
	// Buffered channel of undelivered messages; nil while not started/stopped.
	messagesC  chan *ReceivedMessage
	// Closed to stop the delivery goroutine; nil while not started/stopped.
	stopC      chan struct{}
	// Tracks the delivery goroutine so Wait can block until it exits.
	active     sync.WaitGroup
}
58
59func newMessageDeliveryQueue(acks *ackTracker, receiver MessageReceiverFunc, bufferSize int) *messageDeliveryQueue {
60	return &messageDeliveryQueue{
61		bufferSize: bufferSize,
62		acks:       acks,
63		receiver:   receiver,
64	}
65}
66
67// Start the message delivery, if not already started.
68func (mq *messageDeliveryQueue) Start() {
69	if mq.stopC != nil {
70		return
71	}
72
73	mq.stopC = make(chan struct{})
74	mq.messagesC = make(chan *ReceivedMessage, mq.bufferSize)
75	mq.active.Add(1)
76	go mq.deliverMessages(mq.messagesC, mq.stopC)
77}
78
79// Stop message delivery and discard undelivered messages.
80func (mq *messageDeliveryQueue) Stop() {
81	if mq.stopC == nil {
82		return
83	}
84
85	close(mq.stopC)
86	mq.stopC = nil
87	mq.messagesC = nil
88}
89
90// Wait until the message delivery goroutine has terminated.
91func (mq *messageDeliveryQueue) Wait() {
92	mq.active.Wait()
93}
94
95func (mq *messageDeliveryQueue) Add(msg *ReceivedMessage) {
96	if mq.messagesC != nil {
97		mq.messagesC <- msg
98	}
99}
100
101func (mq *messageDeliveryQueue) deliverMessages(messagesC chan *ReceivedMessage, stopC chan struct{}) {
102	// Notify the wait group that the goroutine has terminated upon exit.
103	defer mq.active.Done()
104
105	for {
106		// stopC has higher priority.
107		select {
108		case <-stopC:
109			return // Ends the goroutine.
110		default:
111		}
112
113		select {
114		case <-stopC:
115			return // Ends the goroutine.
116		case msg := <-messagesC:
117			// Register outstanding acks, which are primarily handled by the
118			// `committer`.
119			mq.acks.Push(msg.Ack.(*ackConsumer))
120			mq.receiver(msg)
121		}
122	}
123}
124
// batchFlowControlPeriod is how often accumulated flow control tokens are
// flushed to the server by the background task (ten times per second).
const batchFlowControlPeriod = time.Second / 10
127
// Handles subscriber reset actions that are external to the subscribeStream
// (e.g. wait for the committer to flush commits).
//
// It is called with the subscribeStream's mutex released. A non-nil error
// makes the subscribeStream shut down instead of resuming delivery. The
// factory in this file wires it to the committer's BlockingReset.
type subscriberResetHandler func() error
131
// subscribeStream directly wraps the subscribe client stream. It passes
// messages to the message receiver and manages flow control. Flow control
// tokens are batched and sent to the stream via a periodic background task,
// although it can be expedited if the user is rapidly acking messages.
//
// Client-initiated seek is unsupported. Server-initiated resets are handled
// in onStreamStatusChange (streamResetState).
type subscribeStream struct {
	// Immutable after creation.
	subClient    *vkit.SubscriberClient
	settings     ReceiveSettings
	subscription subscriptionPartition
	handleReset  subscriberResetHandler
	metadata     pubsubMetadata

	// Fields below must be guarded with mu.
	messageQueue           *messageDeliveryQueue
	stream                 *retryableStream
	offsetTracker          subscriberOffsetTracker
	flowControl            flowControlBatcher
	pollFlowControl        *periodicTask
	// enableBatchFlowControl is false while the stream is reconnecting, so no
	// batched tokens are sent before the restart request on the new stream.
	enableBatchFlowControl bool

	// Provides mu, status and unsafeUpdateStatus used by the methods below.
	abstractService
}
156
157func newSubscribeStream(ctx context.Context, subClient *vkit.SubscriberClient, settings ReceiveSettings,
158	receiver MessageReceiverFunc, subscription subscriptionPartition, acks *ackTracker,
159	handleReset subscriberResetHandler, disableTasks bool) *subscribeStream {
160
161	s := &subscribeStream{
162		subClient:    subClient,
163		settings:     settings,
164		subscription: subscription,
165		handleReset:  handleReset,
166		messageQueue: newMessageDeliveryQueue(acks, receiver, settings.MaxOutstandingMessages),
167		metadata:     newPubsubMetadata(),
168	}
169	s.stream = newRetryableStream(ctx, s, settings.Timeout, reflect.TypeOf(pb.SubscribeResponse{}))
170	s.metadata.AddSubscriptionRoutingMetadata(s.subscription)
171	s.metadata.AddClientInfo(settings.Framework)
172
173	backgroundTask := s.sendBatchFlowControl
174	if disableTasks {
175		backgroundTask = func() {}
176	}
177	s.pollFlowControl = newPeriodicTask(batchFlowControlPeriod, backgroundTask)
178	return s
179}
180
181// Start establishes a subscribe stream connection and initializes flow control
182// tokens from ReceiveSettings.
183func (s *subscribeStream) Start() {
184	s.mu.Lock()
185	defer s.mu.Unlock()
186
187	if s.unsafeUpdateStatus(serviceStarting, nil) {
188		s.stream.Start()
189		s.pollFlowControl.Start()
190		s.messageQueue.Start()
191
192		s.flowControl.Reset(flowControlTokens{
193			Bytes:    int64(s.settings.MaxOutstandingBytes),
194			Messages: int64(s.settings.MaxOutstandingMessages),
195		})
196	}
197}
198
199// Stop immediately terminates the subscribe stream.
200func (s *subscribeStream) Stop() {
201	s.mu.Lock()
202	defer s.mu.Unlock()
203	s.unsafeInitiateShutdown(serviceTerminating, nil)
204}
205
206func (s *subscribeStream) newStream(ctx context.Context) (grpc.ClientStream, error) {
207	return s.subClient.Subscribe(s.metadata.AddToContext(ctx))
208}
209
210func (s *subscribeStream) initialRequest() (interface{}, initialResponseRequired) {
211	s.mu.Lock()
212	defer s.mu.Unlock()
213	initReq := &pb.SubscribeRequest{
214		Request: &pb.SubscribeRequest_Initial{
215			Initial: &pb.InitialSubscribeRequest{
216				Subscription:    s.subscription.Path,
217				Partition:       int64(s.subscription.Partition),
218				InitialLocation: s.offsetTracker.RequestForRestart(),
219			},
220		},
221	}
222	return initReq, initialResponseRequired(true)
223}
224
225func (s *subscribeStream) validateInitialResponse(response interface{}) error {
226	subscribeResponse, _ := response.(*pb.SubscribeResponse)
227	if subscribeResponse.GetInitial() == nil {
228		return errInvalidInitialSubscribeResponse
229	}
230	return nil
231}
232
233func (s *subscribeStream) onStreamStatusChange(status streamStatus) {
234	s.mu.Lock()
235	defer s.mu.Unlock()
236
237	switch status {
238	case streamConnected:
239		s.unsafeUpdateStatus(serviceActive, nil)
240
241		// Reinitialize the flow control tokens when a new subscribe stream instance
242		// is connected.
243		s.unsafeSendFlowControl(s.flowControl.RequestForRestart())
244		s.enableBatchFlowControl = true
245		s.pollFlowControl.Start()
246
247	case streamReconnecting:
248		// Ensure no batch flow control tokens are sent until the RequestForRestart
249		// is sent above when a new subscribe stream is initialized.
250		s.enableBatchFlowControl = false
251		s.pollFlowControl.Stop()
252
253	case streamResetState:
254		// Handle out-of-band seek notifications from the server. Committer and
255		// subscriber state are reset.
256
257		s.messageQueue.Stop()
258
259		// Wait for all message receiver callbacks to finish and the committer to
260		// flush pending commits and reset its state. Release the mutex while
261		// waiting.
262		s.mu.Unlock()
263		s.messageQueue.Wait()
264		err := s.handleReset()
265		s.mu.Lock()
266
267		if err != nil {
268			s.unsafeInitiateShutdown(serviceTerminating, nil)
269			return
270		}
271		s.messageQueue.Start()
272		s.offsetTracker.Reset()
273		s.flowControl.Reset(flowControlTokens{
274			Bytes:    int64(s.settings.MaxOutstandingBytes),
275			Messages: int64(s.settings.MaxOutstandingMessages),
276		})
277
278	case streamTerminated:
279		s.unsafeInitiateShutdown(serviceTerminated, s.stream.Error())
280	}
281}
282
283func (s *subscribeStream) onResponse(response interface{}) {
284	s.mu.Lock()
285	defer s.mu.Unlock()
286
287	if s.status >= serviceTerminating {
288		return
289	}
290
291	var err error
292	subscribeResponse, _ := response.(*pb.SubscribeResponse)
293	switch {
294	case subscribeResponse.GetMessages() != nil:
295		err = s.unsafeOnMessageResponse(subscribeResponse.GetMessages())
296	default:
297		err = errInvalidSubscribeResponse
298	}
299	if err != nil {
300		s.unsafeInitiateShutdown(serviceTerminated, err)
301	}
302}
303
304func (s *subscribeStream) unsafeOnMessageResponse(response *pb.MessageResponse) error {
305	if len(response.Messages) == 0 {
306		return errServerNoMessages
307	}
308	if err := s.offsetTracker.OnMessages(response.Messages); err != nil {
309		return err
310	}
311	if err := s.flowControl.OnMessages(response.Messages); err != nil {
312		return err
313	}
314
315	for _, msg := range response.Messages {
316		ack := newAckConsumer(msg.GetCursor().GetOffset(), msg.GetSizeBytes(), s.onAck)
317		s.messageQueue.Add(&ReceivedMessage{Msg: msg, Ack: ack, Partition: s.subscription.Partition})
318	}
319	return nil
320}
321
322func (s *subscribeStream) onAck(ac *ackConsumer) {
323	s.mu.Lock()
324	defer s.mu.Unlock()
325
326	if s.status == serviceActive {
327		s.unsafeAllowFlow(flowControlTokens{Bytes: ac.MsgBytes, Messages: 1})
328	}
329}
330
331// sendBatchFlowControl is called by the periodic background task.
332func (s *subscribeStream) sendBatchFlowControl() {
333	s.mu.Lock()
334	defer s.mu.Unlock()
335
336	if s.enableBatchFlowControl {
337		s.unsafeSendFlowControl(s.flowControl.ReleasePendingRequest())
338	}
339}
340
341func (s *subscribeStream) unsafeAllowFlow(allow flowControlTokens) {
342	s.flowControl.OnClientFlow(allow)
343	if s.flowControl.ShouldExpediteBatchRequest() && s.enableBatchFlowControl {
344		s.unsafeSendFlowControl(s.flowControl.ReleasePendingRequest())
345	}
346}
347
348func (s *subscribeStream) unsafeSendFlowControl(req *pb.FlowControlRequest) {
349	if req == nil {
350		return
351	}
352
353	// Note: If Send() returns false, the stream will be reconnected and
354	// flowControlBatcher.RequestForRestart() will be sent when the stream
355	// reconnects. So its return value is ignored.
356	s.stream.Send(&pb.SubscribeRequest{
357		Request: &pb.SubscribeRequest_FlowControl{FlowControl: req},
358	})
359}
360
361func (s *subscribeStream) unsafeInitiateShutdown(targetStatus serviceStatus, err error) {
362	if !s.unsafeUpdateStatus(targetStatus, wrapError("subscriber", s.subscription.String(), err)) {
363		return
364	}
365
366	// No data to send. Immediately terminate the stream.
367	s.messageQueue.Stop()
368	s.pollFlowControl.Stop()
369	s.stream.Stop()
370}
371
// singlePartitionSubscriber receives messages from a single topic partition.
// It requires 2 child services:
// - subscribeStream to receive messages from the subscribe stream.
// - committer to commit cursor offsets to the streaming commit cursor stream.
//
// Both children share one ackTracker (see singlePartitionSubscriberFactory.New).
type singlePartitionSubscriber struct {
	subscriber *subscribeStream
	committer  *committer

	// Manages the two child services (added via unsafeAddServices in New).
	compositeService
}
382
383// Terminate shuts down the singlePartitionSubscriber without waiting for
384// outstanding acks. Alternatively, Stop() will wait for outstanding acks.
385func (s *singlePartitionSubscriber) Terminate() {
386	s.subscriber.Stop()
387	s.committer.Terminate()
388}
389
// singlePartitionSubscriberFactory holds the dependencies shared by all
// partitions of a Subscriber and creates one singlePartitionSubscriber per
// partition via New.
type singlePartitionSubscriberFactory struct {
	ctx              context.Context
	subClient        *vkit.SubscriberClient
	cursorClient     *vkit.CursorClient
	settings         ReceiveSettings
	subscriptionPath string
	// receiver is shared by every partition created by this factory.
	receiver         MessageReceiverFunc
	// disableTasks turns the subscribe stream's periodic flow control task into
	// a no-op and is also forwarded to newCommitter (presumably for tests).
	disableTasks     bool
}
399
400func (f *singlePartitionSubscriberFactory) New(partition int) *singlePartitionSubscriber {
401	subscription := subscriptionPartition{Path: f.subscriptionPath, Partition: partition}
402	acks := newAckTracker()
403	commit := newCommitter(f.ctx, f.cursorClient, f.settings, subscription, acks, f.disableTasks)
404	sub := newSubscribeStream(f.ctx, f.subClient, f.settings, f.receiver, subscription, acks, commit.BlockingReset, f.disableTasks)
405	ps := &singlePartitionSubscriber{
406		subscriber: sub,
407		committer:  commit,
408	}
409	ps.init()
410	ps.unsafeAddServices(sub, commit)
411	return ps
412}
413
// multiPartitionSubscriber receives messages from a fixed set of topic
// partitions.
type multiPartitionSubscriber struct {
	// Immutable after creation.
	// One subscriber per entry in ReceiveSettings.Partitions.
	subscribers []*singlePartitionSubscriber

	// Owns the API clients; also provides mu, init and unsafeAddServices used by
	// the methods below.
	apiClientService
}
422
423func newMultiPartitionSubscriber(allClients apiClients, subFactory *singlePartitionSubscriberFactory) *multiPartitionSubscriber {
424	ms := &multiPartitionSubscriber{
425		apiClientService: apiClientService{clients: allClients},
426	}
427	ms.init()
428
429	for _, partition := range subFactory.settings.Partitions {
430		subscriber := subFactory.New(partition)
431		ms.unsafeAddServices(subscriber)
432		ms.subscribers = append(ms.subscribers, subscriber)
433	}
434	return ms
435}
436
437// Terminate shuts down all singlePartitionSubscribers without waiting for
438// outstanding acks. Alternatively, Stop() will wait for outstanding acks.
439func (ms *multiPartitionSubscriber) Terminate() {
440	ms.mu.Lock()
441	defer ms.mu.Unlock()
442
443	for _, sub := range ms.subscribers {
444		sub.Terminate()
445	}
446}
447
// assigningSubscriber uses the Pub/Sub Lite partition assignment service to
// listen to its assigned partition numbers and dynamically add/remove
// singlePartitionSubscribers.
type assigningSubscriber struct {
	// Immutable after creation.
	subFactory *singlePartitionSubscriberFactory
	// Delivers assignments to handleAssignment.
	assigner   *assigner

	// Fields below must be guarded with mu.
	// Subscribers keyed by partition number. Updated as assignments change.
	subscribers map[int]*singlePartitionSubscriber

	// Owns the API clients; also provides mu and child-service management.
	apiClientService
}
462
463func newAssigningSubscriber(allClients apiClients, assignmentClient *vkit.PartitionAssignmentClient, genUUID generateUUIDFunc, subFactory *singlePartitionSubscriberFactory) (*assigningSubscriber, error) {
464	as := &assigningSubscriber{
465		apiClientService: apiClientService{clients: allClients},
466		subFactory:       subFactory,
467		subscribers:      make(map[int]*singlePartitionSubscriber),
468	}
469	as.init()
470
471	assigner, err := newAssigner(subFactory.ctx, assignmentClient, genUUID, subFactory.settings, subFactory.subscriptionPath, as.handleAssignment)
472	if err != nil {
473		return nil, err
474	}
475	as.assigner = assigner
476	as.unsafeAddServices(assigner)
477	return as, nil
478}
479
480func (as *assigningSubscriber) handleAssignment(partitions partitionSet) error {
481	removedSubscribers, err := as.doHandleAssignment(partitions)
482	if err != nil {
483		return err
484	}
485
486	// Wait for removed subscribers to completely stop (which waits for commit
487	// acknowledgments from the server) before acking the assignment. This avoids
488	// commits racing with the new assigned client.
489	for _, subscriber := range removedSubscribers {
490		subscriber.WaitStopped()
491	}
492	return nil
493}
494
495func (as *assigningSubscriber) doHandleAssignment(partitions partitionSet) ([]*singlePartitionSubscriber, error) {
496	as.mu.Lock()
497	defer as.mu.Unlock()
498
499	// Handle new partitions.
500	for _, partition := range partitions.Ints() {
501		if _, exists := as.subscribers[partition]; !exists {
502			subscriber := as.subFactory.New(partition)
503			if err := as.unsafeAddServices(subscriber); err != nil {
504				// Occurs when the assigningSubscriber is stopping/stopped.
505				return nil, err
506			}
507			as.subscribers[partition] = subscriber
508		}
509	}
510
511	// Handle removed partitions.
512	var removedSubscribers []*singlePartitionSubscriber
513	for partition, subscriber := range as.subscribers {
514		if !partitions.Contains(partition) {
515			// Ignore unacked messages from this point on to avoid conflicting with
516			// the commits of the new subscriber that will be assigned this partition.
517			subscriber.Terminate()
518			removedSubscribers = append(removedSubscribers, subscriber)
519
520			as.unsafeRemoveService(subscriber)
521			// Safe to delete map entry during range loop:
522			// https://golang.org/ref/spec#For_statements
523			delete(as.subscribers, partition)
524		}
525	}
526	return removedSubscribers, nil
527}
528
529// Terminate shuts down all singlePartitionSubscribers without waiting for
530// outstanding acks. Alternatively, Stop() will wait for outstanding acks.
531func (as *assigningSubscriber) Terminate() {
532	as.mu.Lock()
533	defer as.mu.Unlock()
534
535	for _, sub := range as.subscribers {
536		sub.Terminate()
537	}
538}
539
// Subscriber is the client interface exported from this package for receiving
// messages. It is implemented by the fixed-partition and partition-assignment
// subscribers returned from NewSubscriber.
type Subscriber interface {
	// Start begins receiving messages.
	Start()
	// WaitStarted presumably blocks until startup completes or fails; TODO
	// confirm against the service implementation.
	WaitStarted() error
	// Stop shuts down the subscriber, waiting for outstanding acks.
	Stop()
	// WaitStopped presumably blocks until the subscriber has fully stopped and
	// returns its final error; TODO confirm.
	WaitStopped() error
	// Terminate shuts down the subscriber without waiting for outstanding acks.
	Terminate()
}
549
550// NewSubscriber creates a new client for receiving messages.
551func NewSubscriber(ctx context.Context, settings ReceiveSettings, receiver MessageReceiverFunc, region, subscriptionPath string, opts ...option.ClientOption) (Subscriber, error) {
552	if err := ValidateRegion(region); err != nil {
553		return nil, err
554	}
555	if err := validateReceiveSettings(settings); err != nil {
556		return nil, err
557	}
558
559	var allClients apiClients
560	subClient, err := newSubscriberClient(ctx, region, opts...)
561	if err != nil {
562		return nil, err
563	}
564	allClients = append(allClients, subClient)
565
566	cursorClient, err := newCursorClient(ctx, region, opts...)
567	if err != nil {
568		allClients.Close()
569		return nil, err
570	}
571	allClients = append(allClients, cursorClient)
572
573	subFactory := &singlePartitionSubscriberFactory{
574		ctx:              ctx,
575		subClient:        subClient,
576		cursorClient:     cursorClient,
577		settings:         settings,
578		subscriptionPath: subscriptionPath,
579		receiver:         receiver,
580	}
581
582	if len(settings.Partitions) > 0 {
583		return newMultiPartitionSubscriber(allClients, subFactory), nil
584	}
585	partitionClient, err := newPartitionAssignmentClient(ctx, region, opts...)
586	if err != nil {
587		allClients.Close()
588		return nil, err
589	}
590	allClients = append(allClients, partitionClient)
591	return newAssigningSubscriber(allClients, partitionClient, uuid.NewRandom, subFactory)
592}
593