1// Copyright 2018 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//      http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package pubsub
16
17import (
18	"context"
19	"io"
20	"sync"
21	"time"
22
23	gax "github.com/googleapis/gax-go/v2"
24	pb "google.golang.org/genproto/googleapis/pubsub/v1"
25	"google.golang.org/grpc"
26)
27
// A pullStream supports the methods of a StreamingPullClient, but re-opens
// the stream on a retryable error.
//
// Callers never hold the stream directly: every operation obtains it through
// get, which opens (or re-opens) streams while holding mu, and records a
// permanent error in err once the pullStream can no longer be used.
type pullStream struct {
	// ctx is the context every stream is opened with; newPullStream derives it
	// with context.WithCancel, and cancel is the matching CancelFunc.
	ctx    context.Context
	// open starts a new StreamingPull stream and sends its initial request.
	open   func() (pb.Subscriber_StreamingPullClient, error)
	cancel context.CancelFunc

	// mu guards spc and err.
	mu  sync.Mutex
	// spc points at the current stream. get allocates a new pointer for every
	// (re)open, so callers detect that a stream was replaced by comparing
	// pointers.
	spc *pb.Subscriber_StreamingPullClient
	err error // permanent error
}
39
// streamingPullFunc opens a StreamingPull stream. newPullStream accepts one as
// a parameter (rather than a client) so that tests can substitute a fake.
// NOTE(review): presumably production passes the generated client's
// StreamingPull method here — verify at the call site.
type streamingPullFunc func(context.Context, ...gax.CallOption) (pb.Subscriber_StreamingPullClient, error)
42
43func newPullStream(ctx context.Context, streamingPull streamingPullFunc, subName string, maxOutstandingMessages, maxOutstandingBytes int, maxDurationPerLeaseExtension time.Duration) *pullStream {
44	ctx = withSubscriptionKey(ctx, subName)
45	ctx, cancel := context.WithCancel(ctx)
46	return &pullStream{
47		ctx:    ctx,
48		cancel: cancel,
49		open: func() (pb.Subscriber_StreamingPullClient, error) {
50			spc, err := streamingPull(ctx, gax.WithGRPCOptions(grpc.MaxCallRecvMsgSize(maxSendRecvBytes)))
51			if err == nil {
52				recordStat(ctx, StreamRequestCount, 1)
53				streamAckDeadline := int32(maxDurationPerLeaseExtension / time.Second)
54				// By default, maxDurationPerLeaseExtension, aka MaxExtensionPeriod, is disabled,
55				// so in these cases, use a healthy default of 60 seconds.
56				if streamAckDeadline <= 0 {
57					streamAckDeadline = 60
58				}
59				err = spc.Send(&pb.StreamingPullRequest{
60					Subscription:             subName,
61					StreamAckDeadlineSeconds: streamAckDeadline,
62					MaxOutstandingMessages:   int64(maxOutstandingMessages),
63					MaxOutstandingBytes:      int64(maxOutstandingBytes),
64				})
65			}
66			if err != nil {
67				return nil, err
68			}
69			return spc, nil
70		},
71	}
72}
73
74// get returns either a valid *StreamingPullClient (SPC), or a permanent error.
75// If the argument is nil, this is the first call for an RPC, and the current
76// SPC will be returned (or a new one will be opened). Otherwise, this call is a
77// request to re-open the stream because of a retryable error, and the argument
78// is a pointer to the SPC that returned the error.
79func (s *pullStream) get(spc *pb.Subscriber_StreamingPullClient) (*pb.Subscriber_StreamingPullClient, error) {
80	s.mu.Lock()
81	defer s.mu.Unlock()
82	// A stored error is permanent.
83	if s.err != nil {
84		return nil, s.err
85	}
86	// If the context is done, so are we.
87	s.err = s.ctx.Err()
88	if s.err != nil {
89		return nil, s.err
90	}
91
92	// If the current and argument SPCs differ, return the current one. This subsumes two cases:
93	// 1. We have an SPC and the caller is getting the stream for the first time.
94	// 2. The caller wants to retry, but they have an older SPC; we've already retried.
95	if spc != s.spc {
96		return s.spc, nil
97	}
98	// Either this is the very first call on this stream (s.spc == nil), or we have a valid
99	// retry request. Either way, open a new stream.
100	// The lock is held here for a long time, but it doesn't matter because no callers could get
101	// anything done anyway.
102	s.spc = new(pb.Subscriber_StreamingPullClient)
103	*s.spc, s.err = s.openWithRetry() // Any error from openWithRetry is permanent.
104	return s.spc, s.err
105}
106
107func (s *pullStream) openWithRetry() (pb.Subscriber_StreamingPullClient, error) {
108	r := defaultRetryer{}
109	for {
110		recordStat(s.ctx, StreamOpenCount, 1)
111		spc, err := s.open()
112		bo, shouldRetry := r.Retry(err)
113		if err != nil && shouldRetry {
114			recordStat(s.ctx, StreamRetryCount, 1)
115			if err := gax.Sleep(s.ctx, bo); err != nil {
116				return nil, err
117			}
118			continue
119		}
120		return spc, err
121	}
122}
123
124func (s *pullStream) call(f func(pb.Subscriber_StreamingPullClient) error, opts ...gax.CallOption) error {
125	var settings gax.CallSettings
126	for _, opt := range opts {
127		opt.Resolve(&settings)
128	}
129	var r gax.Retryer = &defaultRetryer{}
130	if settings.Retry != nil {
131		r = settings.Retry()
132	}
133
134	var (
135		spc *pb.Subscriber_StreamingPullClient
136		err error
137	)
138	for {
139		spc, err = s.get(spc)
140		if err != nil {
141			return err
142		}
143		start := time.Now()
144		err = f(*spc)
145		if err != nil {
146			bo, shouldRetry := r.Retry(err)
147			if shouldRetry {
148				recordStat(s.ctx, StreamRetryCount, 1)
149				if time.Since(start) < 30*time.Second { // don't sleep if we've been blocked for a while
150					if err := gax.Sleep(s.ctx, bo); err != nil {
151						return err
152					}
153				}
154				continue
155			}
156			s.mu.Lock()
157			s.err = err
158			s.mu.Unlock()
159		}
160		return err
161	}
162}
163
164func (s *pullStream) Send(req *pb.StreamingPullRequest) error {
165	return s.call(func(spc pb.Subscriber_StreamingPullClient) error {
166		recordStat(s.ctx, AckCount, int64(len(req.AckIds)))
167		zeroes := 0
168		for _, mds := range req.ModifyDeadlineSeconds {
169			if mds == 0 {
170				zeroes++
171			}
172		}
173		recordStat(s.ctx, NackCount, int64(zeroes))
174		recordStat(s.ctx, ModAckCount, int64(len(req.ModifyDeadlineSeconds)-zeroes))
175		recordStat(s.ctx, StreamRequestCount, 1)
176		return spc.Send(req)
177	})
178}
179
180func (s *pullStream) Recv() (*pb.StreamingPullResponse, error) {
181	var res *pb.StreamingPullResponse
182	err := s.call(func(spc pb.Subscriber_StreamingPullClient) error {
183		var err error
184		recordStat(s.ctx, StreamResponseCount, 1)
185		res, err = spc.Recv()
186		return err
187	}, gax.WithRetry(func() gax.Retryer { return &streamingPullRetryer{defaultRetryer: &defaultRetryer{}} }))
188	return res, err
189}
190
191func (s *pullStream) CloseSend() error {
192	err := s.call(func(spc pb.Subscriber_StreamingPullClient) error {
193		return spc.CloseSend()
194	})
195	s.mu.Lock()
196	s.err = io.EOF // should not be retried
197	s.mu.Unlock()
198	return err
199}
200