// Copyright 2016 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package pubsub

import (
	"context"
	"io"
	"strings"
	"sync"
	"time"

	vkit "cloud.google.com/go/pubsub/apiv1"
	"cloud.google.com/go/pubsub/internal/distribution"
	gax "github.com/googleapis/gax-go/v2"
	pb "google.golang.org/genproto/googleapis/pubsub/v1"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/encoding/protowire"
)

// Between message receipt and ack (that is, the time spent processing a message) we want to extend the message
// deadline by way of modack. However, we don't want to wait until the deadline is about to expire before
// extending it; instead, we extend the deadline a little ahead of time. gracePeriod is that amount of time
// ahead of the actual deadline.
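//
// For example (illustrative): with a 10s ack deadline, sender schedules its
// next keep-alive tick at 10s - gracePeriod = 5s, so deadlines are extended
// with roughly 5s to spare.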
const gracePeriod = 5 * time.Second

type messageIterator struct {
	ctx        context.Context
	cancel     func() // the function that will cancel ctx; called in stop
	po         *pullOptions
	ps         *pullStream
	subc       *vkit.SubscriberClient
	subName    string
	kaTick     <-chan time.Time // keep-alive (deadline extensions)
	ackTicker  *time.Ticker     // message acks
	nackTicker *time.Ticker     // message nacks (more frequent than acks)
	pingTicker *time.Ticker     // sends to the stream to keep it open
	failed     chan struct{}    // closed on stream error
	drained    chan struct{}    // closed when stopped && no more pending messages
	wg         sync.WaitGroup

	mu          sync.Mutex
	ackTimeDist *distribution.D // dist uses seconds

	// keepAliveDeadlines is a map from ack ID to expiration time. This map is used in conjunction
	// with subscription.ReceiveSettings.MaxExtension to record the latest time (the deadline) up to
	// which we're willing to extend a message's ack deadline. As each message arrives, we record
	// now+MaxExtension in this table; whenever we have a chance to update ack deadlines (via
	// modack), we consult this table and only include IDs that are not beyond their deadline.
	keepAliveDeadlines map[string]time.Time
	pendingAcks        map[string]bool
	pendingNacks       map[string]bool
	pendingModAcks     map[string]bool // ack IDs whose ack deadline is to be modified
	err                error           // error from stream failure
}

// newMessageIterator starts and returns a new messageIterator.
// subName is the full name of the subscription to pull messages from.
// Stop must be called on the messageIterator when it is no longer needed.
// The iterator always uses the background context for acking messages and extending message deadlines.
func newMessageIterator(subc *vkit.SubscriberClient, subName string, po *pullOptions) *messageIterator {
	var ps *pullStream
	if !po.synchronous {
		maxMessages := po.maxOutstandingMessages
		maxBytes := po.maxOutstandingBytes
		if po.useLegacyFlowControl {
			maxMessages = 0
			maxBytes = 0
		}
		ps = newPullStream(context.Background(), subc.StreamingPull, subName, maxMessages, maxBytes, po.maxExtensionPeriod)
	}
	// The period will update each tick based on the distribution of acks. We'll start by arbitrarily sending
	// the first keepAlive halfway towards the minimum ack deadline.
	keepAlivePeriod := minAckDeadline / 2
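	// For example (illustrative): if minAckDeadline were 10s, the first
	// keep-alive tick would fire 5s after startup, before any ack-latency
	// data has accumulated.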

	// Ack promptly so users don't lose work if the client crashes.
	ackTicker := time.NewTicker(100 * time.Millisecond)
	nackTicker := time.NewTicker(100 * time.Millisecond)
	pingTicker := time.NewTicker(30 * time.Second)
	cctx, cancel := context.WithCancel(context.Background())
	it := &messageIterator{
		ctx:                cctx,
		cancel:             cancel,
		ps:                 ps,
		po:                 po,
		subc:               subc,
		subName:            subName,
		kaTick:             time.After(keepAlivePeriod),
		ackTicker:          ackTicker,
		nackTicker:         nackTicker,
		pingTicker:         pingTicker,
		failed:             make(chan struct{}),
		drained:            make(chan struct{}),
		ackTimeDist:        distribution.New(int(maxAckDeadline/time.Second) + 1),
		keepAliveDeadlines: map[string]time.Time{},
		pendingAcks:        map[string]bool{},
		pendingNacks:       map[string]bool{},
		pendingModAcks:     map[string]bool{},
	}
	it.wg.Add(1)
	go it.sender()
	return it
}
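
// Illustrative usage sketch (hypothetical values; in this package the real
// caller is the subscription's receive path):
//
//	po := &pullOptions{synchronous: false, maxExtension: 10 * time.Minute}
//	it := newMessageIterator(subc, "projects/p/subscriptions/s", po)
//	defer it.stop()
//	msgs, err := it.receive(100)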

// Subscription.receive will call stop on its messageIterator when finished with it.
// Stop will block until Done has been called on all Messages that have been
// returned by Next, or until the context with which the messageIterator was created
// is cancelled or exceeds its deadline.
func (it *messageIterator) stop() {
	it.cancel()
	it.mu.Lock()
	it.checkDrained()
	it.mu.Unlock()
	it.wg.Wait()
}

// checkDrained closes the drained channel if the iterator has been stopped and all
// pending messages have either been acked, nacked, or expired.
//
// Called with the lock held.
func (it *messageIterator) checkDrained() {
	select {
	case <-it.drained:
		return
	default:
	}
	select {
	case <-it.ctx.Done():
		if len(it.keepAliveDeadlines) == 0 {
			close(it.drained)
		}
	default:
	}
}

// Called when a message is acked/nacked.
func (it *messageIterator) done(ackID string, ack bool, receiveTime time.Time) {
	it.ackTimeDist.Record(int(time.Since(receiveTime) / time.Second))
	it.mu.Lock()
	defer it.mu.Unlock()
	delete(it.keepAliveDeadlines, ackID)
	if ack {
		it.pendingAcks[ackID] = true
	} else {
		it.pendingNacks[ackID] = true
	}
	it.checkDrained()
}

// fail is called when a stream method returns a permanent error.
// fail returns it.err. This may be err, or it may be the error
// set by an earlier call to fail.
func (it *messageIterator) fail(err error) error {
	it.mu.Lock()
	defer it.mu.Unlock()
	if it.err == nil {
		it.err = err
		close(it.failed)
	}
	return it.err
}

// receive makes a call to the stream's Recv method, or the Pull RPC, and returns
// its messages.
// maxToPull is the maximum number of messages for the Pull RPC.
func (it *messageIterator) receive(maxToPull int32) ([]*Message, error) {
	it.mu.Lock()
	ierr := it.err
	it.mu.Unlock()
	if ierr != nil {
		return nil, ierr
	}

	// Stop retrieving messages if the iterator's Stop method was called.
	select {
	case <-it.ctx.Done():
		it.wg.Wait()
		return nil, io.EOF
	default:
	}

	var rmsgs []*pb.ReceivedMessage
	var err error
	if it.po.synchronous {
		rmsgs, err = it.pullMessages(maxToPull)
	} else {
		rmsgs, err = it.recvMessages()
	}
	// Any error here is fatal.
	if err != nil {
		return nil, it.fail(err)
	}
	recordStat(it.ctx, PullCount, int64(len(rmsgs)))
	now := time.Now()
	msgs, err := convertMessages(rmsgs, now, it.done)
	if err != nil {
		return nil, it.fail(err)
	}
	// We received some messages. Remember them so we can keep them alive. Also,
	// do a receipt mod-ack when streaming.
	maxExt := time.Now().Add(it.po.maxExtension)
	ackIDs := map[string]bool{}
	it.mu.Lock()
	for _, m := range msgs {
		ackID := msgAckID(m)
		addRecv(m.ID, ackID, now)
		it.keepAliveDeadlines[ackID] = maxExt
		// Don't change the mod-ack if the message is going to be nacked. This is
		// possible if there are retries.
		if !it.pendingNacks[ackID] {
			ackIDs[ackID] = true
		}
	}
	deadline := it.ackDeadline()
	it.mu.Unlock()
	if len(ackIDs) > 0 {
		if !it.sendModAck(ackIDs, deadline) {
			return nil, it.err
		}
	}
	return msgs, nil
}

// Get messages using the Pull RPC.
// This may block indefinitely. It may also return zero messages after waiting
// for some time.
func (it *messageIterator) pullMessages(maxToPull int32) ([]*pb.ReceivedMessage, error) {
	// Use it.ctx as the RPC context, so that if the iterator is stopped, the call
	// will return immediately.
	res, err := it.subc.Pull(it.ctx, &pb.PullRequest{
		Subscription: it.subName,
		MaxMessages:  maxToPull,
	}, gax.WithGRPCOptions(grpc.MaxCallRecvMsgSize(maxSendRecvBytes)))
	switch {
	case err == context.Canceled:
		return nil, nil
	case status.Code(err) == codes.Canceled:
		return nil, nil
	case err != nil:
		return nil, err
	default:
		return res.ReceivedMessages, nil
	}
}

func (it *messageIterator) recvMessages() ([]*pb.ReceivedMessage, error) {
	res, err := it.ps.Recv()
	if err != nil {
		return nil, err
	}
	return res.ReceivedMessages, nil
}

// sender runs in a goroutine and handles all sends to the stream.
func (it *messageIterator) sender() {
	defer it.wg.Done()
	defer it.ackTicker.Stop()
	defer it.nackTicker.Stop()
	defer it.pingTicker.Stop()
	defer func() {
		if it.ps != nil {
			it.ps.CloseSend()
		}
	}()

	done := false
	for !done {
		sendAcks := false
		sendNacks := false
		sendModAcks := false
		sendPing := false

		dl := it.ackDeadline()

		select {
		case <-it.failed:
			// Stream failed: nothing to do, so stop immediately.
			return

		case <-it.drained:
			// All outstanding messages have been marked done:
			// nothing left to do except make the final calls.
			it.mu.Lock()
			sendAcks = (len(it.pendingAcks) > 0)
			sendNacks = (len(it.pendingNacks) > 0)
			// No point in sending modacks.
			done = true

		case <-it.kaTick:
			it.mu.Lock()
			it.handleKeepAlives()
			sendModAcks = (len(it.pendingModAcks) > 0)

			nextTick := dl - gracePeriod
			if nextTick <= 0 {
				// If the deadline is <= gracePeriod, let's tick again halfway to
				// the deadline.
				nextTick = dl / 2
			}
			it.kaTick = time.After(nextTick)

		case <-it.nackTicker.C:
			it.mu.Lock()
			sendNacks = (len(it.pendingNacks) > 0)

		case <-it.ackTicker.C:
			it.mu.Lock()
			sendAcks = (len(it.pendingAcks) > 0)

		case <-it.pingTicker.C:
			it.mu.Lock()
			// Ping only if we are processing messages via streaming.
			sendPing = !it.po.synchronous
		}
		// Lock is held here.
		var acks, nacks, modAcks map[string]bool
		if sendAcks {
			acks = it.pendingAcks
			it.pendingAcks = map[string]bool{}
		}
		if sendNacks {
			nacks = it.pendingNacks
			it.pendingNacks = map[string]bool{}
		}
		if sendModAcks {
			modAcks = it.pendingModAcks
			it.pendingModAcks = map[string]bool{}
		}
		it.mu.Unlock()
		// Make Ack and ModAck RPCs.
		if sendAcks {
			if !it.sendAck(acks) {
				return
			}
		}
		if sendNacks {
			// A nack is expressed as a modack with a deadline of zero.
			if !it.sendModAck(nacks, 0) {
				return
			}
		}
		if sendModAcks {
			if !it.sendModAck(modAcks, dl) {
				return
			}
		}
		if sendPing {
			it.pingStream()
		}
	}
}

// handleKeepAlives modifies the pending request to include deadline extensions
// for live messages. It also purges expired messages.
//
// Called with the lock held.
func (it *messageIterator) handleKeepAlives() {
	now := time.Now()
	for id, expiry := range it.keepAliveDeadlines {
		if expiry.Before(now) {
			// This delete will not result in skipping any map items, as implied by
			// the spec at https://golang.org/ref/spec#For_statements, "For
			// statements with range clause", note 3, and stated explicitly at
			// https://groups.google.com/forum/#!msg/golang-nuts/UciASUb03Js/pzSq5iVFAQAJ.
			delete(it.keepAliveDeadlines, id)
		} else {
			// This will not conflict with a nack, because nacking removes the ID from keepAliveDeadlines.
			it.pendingModAcks[id] = true
		}
	}
	it.checkDrained()
}

func (it *messageIterator) sendAck(m map[string]bool) bool {
	// Account for the Subscription field.
	overhead := calcFieldSizeString(it.subName)
	return it.sendAckIDRPC(m, maxPayload-overhead, func(ids []string) error {
		recordStat(it.ctx, AckCount, int64(len(ids)))
		addAcks(ids)
		bo := gax.Backoff{
			Initial:    100 * time.Millisecond,
			Max:        time.Second,
			Multiplier: 2,
		}
		cctx, cancel := context.WithTimeout(context.Background(), 3*time.Minute)
		defer cancel()
		for {
			// Use context.Background() as the call's context, not it.ctx. We don't
			// want to cancel this RPC when the iterator is stopped.
			cctx2, cancel2 := context.WithTimeout(context.Background(), 60*time.Second)
			err := it.subc.Acknowledge(cctx2, &pb.AcknowledgeRequest{
				Subscription: it.subName,
				AckIds:       ids,
			})
			// Cancel the per-call context as soon as the call returns; deferring
			// inside the loop would accumulate contexts across retries.
			cancel2()
			// Retry DeadlineExceeded errors a few times before giving up and
			// allowing the message to expire and be redelivered.
			// The underlying library handles other retries, currently only
			// codes.Unavailable.
			switch status.Code(err) {
			case codes.DeadlineExceeded:
				// Use the outer context with timeout here. Errors from gax, including
				// context deadline exceeded, should be transparent, as unacked messages
				// will be redelivered.
				if err := gax.Sleep(cctx, bo.Pause()); err != nil {
					return nil
				}
			default:
				if err == nil {
					return nil
				}
				// This handles the case where `context deadline exceeded` errors
				// not captured by the previous case would otherwise be fatal.
				// See https://github.com/googleapis/google-cloud-go/issues/3060
				if strings.Contains(err.Error(), "context deadline exceeded") {
					// Context deadline exceeded errors here should be transparent
					// to prevent the iterator from shutting down.
					if err := gax.Sleep(cctx, bo.Pause()); err != nil {
						return nil
					}
					continue
				}
				// Any other error is fatal.
				return err
			}
		}
	})
}

// The receipt mod-ack amount is derived from a percentile distribution based
// on the time it takes to process messages. The percentile chosen is the 99th
// percentile in order to capture the highest amount of time necessary without
// considering 1% outliers.
func (it *messageIterator) sendModAck(m map[string]bool, deadline time.Duration) bool {
	deadlineSec := int32(deadline / time.Second)
	// Account for the Subscription and AckDeadlineSeconds fields.
	overhead := calcFieldSizeString(it.subName) + calcFieldSizeInt(int(deadlineSec))
	return it.sendAckIDRPC(m, maxPayload-overhead, func(ids []string) error {
		if deadline == 0 {
			recordStat(it.ctx, NackCount, int64(len(ids)))
		} else {
			recordStat(it.ctx, ModAckCount, int64(len(ids)))
		}
		addModAcks(ids, deadlineSec)
		// Retry this RPC on Unavailable for a short amount of time, then give up
		// without returning a fatal error. The utility of this RPC is by nature
		// transient (since the deadline is relative to the current time) and it
		// isn't crucial for correctness (since expired messages will just be
		// resent).
		cctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
		defer cancel()
		bo := gax.Backoff{
			Initial:    100 * time.Millisecond,
			Max:        time.Second,
			Multiplier: 2,
		}
		for {
			err := it.subc.ModifyAckDeadline(cctx, &pb.ModifyAckDeadlineRequest{
				Subscription:       it.subName,
				AckDeadlineSeconds: deadlineSec,
				AckIds:             ids,
			})
			switch status.Code(err) {
			case codes.Unavailable:
				if err := gax.Sleep(cctx, bo.Pause()); err == nil {
					continue
				}
				// Treat sleep timeout like RPC timeout.
				fallthrough
			case codes.DeadlineExceeded:
				// Timeout. Not a fatal error, but note that it happened.
				recordStat(it.ctx, ModAckTimeoutCount, 1)
				return nil
			default:
				if err == nil {
					return nil
				}
				// This handles the case where `context deadline exceeded` errors
				// not captured by the previous case would otherwise be fatal.
				// See https://github.com/googleapis/google-cloud-go/issues/3060
				if strings.Contains(err.Error(), "context deadline exceeded") {
					recordStat(it.ctx, ModAckTimeoutCount, 1)
					return nil
				}
				// Any other error is fatal.
				return err
			}
		}
	})
}

func (it *messageIterator) sendAckIDRPC(ackIDSet map[string]bool, maxSize int, call func([]string) error) bool {
	ackIDs := make([]string, 0, len(ackIDSet))
	for k := range ackIDSet {
		ackIDs = append(ackIDs, k)
	}
	var toSend []string
	for len(ackIDs) > 0 {
		toSend, ackIDs = splitRequestIDs(ackIDs, maxSize)
		if err := call(toSend); err != nil {
			// The underlying client handles retries, so any error is fatal to the
			// iterator.
			it.fail(err)
			return false
		}
	}
	return true
}

// Send a message to the stream to keep it open. The stream will close if there's no
// traffic on it for a while. By keeping it open, we delay the start of the
// expiration timer on messages that are buffered by gRPC or elsewhere in the
// network. This matters if it takes a long time to process messages relative to the
// default ack deadline, and if the messages are small enough so that many can fit
// into the buffer.
func (it *messageIterator) pingStream() {
	// Ignore error; if the stream is broken, this doesn't matter anyway.
	_ = it.ps.Send(&pb.StreamingPullRequest{})
}

// calcFieldSizeString returns the number of bytes string fields
// will take up in an encoded proto message.
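//
// For example (illustrative, with a hypothetical name): a 27-byte
// subscription name encodes as one tag byte, a one-byte length varint,
// and the 27 name bytes:
//
//	calcFieldSizeString("projects/p/subscriptions/s1") // 1 + 1 + 27 = 29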
func calcFieldSizeString(fields ...string) int {
	overhead := 0
	for _, field := range fields {
		overhead += 1 + len(field) + protowire.SizeVarint(uint64(len(field)))
	}
	return overhead
}

// calcFieldSizeInt returns the number of bytes int fields
// will take up in an encoded proto message.
func calcFieldSizeInt(fields ...int) int {
	overhead := 0
	for _, field := range fields {
		overhead += 1 + protowire.SizeVarint(uint64(field))
	}
	return overhead
}

// splitRequestIDs takes a slice of ackIDs and returns two slices such that the first
// ackID slice can be used in a request where the payload does not exceed maxSize.
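//
// For example, sendAckIDRPC above peels off batches that fit under maxSize:
//
//	var toSend []string
//	for len(ids) > 0 {
//		toSend, ids = splitRequestIDs(ids, maxSize)
//		// toSend now fits within a single request payload.
//	}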
func splitRequestIDs(ids []string, maxSize int) (prefix, remainder []string) {
	size := 0
	i := 0
	// TODO(hongalex): Use binary search to find split index, since ackIDs are
	// fairly constant.
	for size < maxSize && i < len(ids) {
		size += calcFieldSizeString(ids[i])
		i++
	}
	if size > maxSize {
		i--
	}
	// Never return an empty prefix while IDs remain: if a single ID alone
	// exceeds maxSize, send it by itself rather than looping forever.
	if i == 0 && len(ids) > 0 {
		i = 1
	}
	return ids[:i], ids[i:]
}

// The deadline to ack is derived from a percentile distribution based
// on the time it takes to process messages. The percentile chosen is the 99th
// percentile - that is, processing times up to the 99th percentile should be
// safe, while the slowest 1% of messages may expire. This number was chosen
// as a way to cover most users' use cases without losing the value of
// expiration.
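//
// For example (illustrative): if 99% of messages are acked within 15s,
// Percentile(.99) yields 15 and the deadline is 15s. A result below
// minAckDeadline is clamped up, and one above maxAckDeadline (or
// maxExtensionPeriod, when set) is clamped down.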
func (it *messageIterator) ackDeadline() time.Duration {
	pt := time.Duration(it.ackTimeDist.Percentile(.99)) * time.Second

	if it.po.maxExtensionPeriod > 0 && pt > it.po.maxExtensionPeriod {
		return it.po.maxExtensionPeriod
	}
	if pt > maxAckDeadline {
		return maxAckDeadline
	}
	if pt < minAckDeadline {
		return minAckDeadline
	}
	return pt
}