1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14package wire
15
16import (
17	"context"
18	"io"
19	"reflect"
20	"sync"
21	"time"
22
23	"golang.org/x/xerrors"
24	"google.golang.org/grpc"
25	"google.golang.org/grpc/codes"
26	"google.golang.org/grpc/status"
27
28	gax "github.com/googleapis/gax-go/v2"
29)
30
// streamStatus describes the lifecycle stage of a retryableStream.
//
// Every stream begins uninitialized. While it is active, retryable errors move
// it back and forth between reconnecting and connected. A permanent error (or
// an explicit Stop) moves it to terminated, after which it is never
// reconnected.
type streamStatus int

// The numeric values follow declaration order, starting at zero.
const (
	streamUninitialized streamStatus = 0 // Initial state; no connection attempt has begun.
	streamReconnecting  streamStatus = 1 // A connection attempt is in progress.
	streamResetState    streamStatus = 2 // Handler notification: reset state after a server RESET.
	streamConnected     streamStatus = 3 // A stream is established.
	streamTerminated    streamStatus = 4 // Permanently closed; it will not reconnect.
)
44
// defaultInitTimeout is the upper bound on a single stream initialization
// attempt. Abandoning a stalled attempt after this long mitigates delays.
const defaultInitTimeout = 120 * time.Second
47
48var errStreamInitTimeout = status.Error(codes.DeadlineExceeded, "pubsublite: stream initialization timed out")
49
50type initialResponseRequired bool
51type notifyReset bool
52
53// streamHandler provides hooks for different Pub/Sub Lite streaming APIs
54// (e.g. publish, subscribe, streaming cursor, etc.) to use retryableStream.
55// All Pub/Sub Lite streaming APIs implement a similar handshaking protocol,
56// where an initial request and response must be transmitted before other
57// requests can be sent over the stream.
58//
59// streamHandler methods must not be called while holding retryableStream.mu in
60// order to prevent the streamHandler calling back into the retryableStream and
61// deadlocking.
62//
63// If any streamHandler method implementations block, this will block the
64// retryableStream.connectStream goroutine processing the underlying stream.
65type streamHandler interface {
66	// newStream implementations must create the client stream with the given
67	// (cancellable) context.
68	newStream(context.Context) (grpc.ClientStream, error)
69	// initialRequest should return the initial request and whether an initial
70	// response is expected.
71	initialRequest() (interface{}, initialResponseRequired)
72	validateInitialResponse(interface{}) error
73
74	// onStreamStatusChange is used to notify stream handlers when the stream has
75	// changed state.
76	// - A `streamReconnecting` status change is fired before attempting to
77	//   connect a new stream.
78	// - A `streamResetState` status change may be fired if the stream should
79	//   reset its state (due to receipt of the RESET signal from the server).
80	// - A `streamConnected` status change is fired when the stream is
81	//   successfully connected.
82	// These are followed by onResponse() calls when responses are received from
83	// the server. These events are guaranteed to occur in this order.
84	//
85	// A final `streamTerminated` status change is fired when a permanent error
86	// occurs. retryableStream.Error() returns the error that caused the stream to
87	// terminate.
88	onStreamStatusChange(streamStatus)
89	// onResponse forwards a response received on the stream to the stream
90	// handler.
91	onResponse(interface{})
92}
93
94// retryableStream is a wrapper around a bidirectional gRPC client stream to
95// handle automatic reconnection when the stream breaks.
96//
97// The connectStream() goroutine handles each stream connection. terminate() can
98// be called at any time, either by the client to force stream closure, or as a
99// result of an unretryable error.
100//
101// Safe to call capitalized methods from multiple goroutines. All other methods
102// are private implementation.
103type retryableStream struct {
104	// Immutable after creation.
105	ctx            context.Context
106	handler        streamHandler
107	responseType   reflect.Type
108	connectTimeout time.Duration
109	initTimeout    time.Duration
110
111	// Guards access to fields below.
112	mu sync.Mutex
113
114	// The current connected stream.
115	stream grpc.ClientStream
116	// Function to cancel the current stream (which may be reconnecting).
117	cancelStream context.CancelFunc
118	status       streamStatus
119	finalErr     error
120}
121
122// newRetryableStream creates a new retryable stream wrapper. `timeout` is the
123// maximum duration for reconnection. `responseType` is the type of the response
124// proto received on the stream.
125func newRetryableStream(ctx context.Context, handler streamHandler, timeout time.Duration, responseType reflect.Type) *retryableStream {
126	initTimeout := defaultInitTimeout
127	if timeout < defaultInitTimeout {
128		initTimeout = timeout
129	}
130	return &retryableStream{
131		ctx:            ctx,
132		handler:        handler,
133		responseType:   responseType,
134		connectTimeout: timeout,
135		initTimeout:    initTimeout,
136	}
137}
138
139// Start establishes a stream connection. It is a no-op if the stream has
140// already started.
141func (rs *retryableStream) Start() {
142	rs.mu.Lock()
143	defer rs.mu.Unlock()
144
145	if rs.status == streamUninitialized {
146		go rs.connectStream(notifyReset(false))
147	}
148}
149
150// Stop gracefully closes the stream without error.
151func (rs *retryableStream) Stop() {
152	rs.terminate(nil)
153}
154
155// Send attempts to send the request to the underlying stream and returns true
156// if successfully sent. Returns false if an error occurred or a reconnection is
157// in progress.
158func (rs *retryableStream) Send(request interface{}) (sent bool) {
159	rs.mu.Lock()
160	defer rs.mu.Unlock()
161
162	if rs.stream != nil {
163		err := rs.stream.SendMsg(request)
164		// Note: if SendMsg returns an error, the stream is aborted.
165		switch {
166		case err == nil:
167			sent = true
168		case err == io.EOF:
169			// If SendMsg returns io.EOF, RecvMsg will return the status of the
170			// stream. Nothing to do here.
171			break
172		case isRetryableSendError(err):
173			go rs.connectStream(notifyReset(false))
174		default:
175			rs.unsafeTerminate(err)
176		}
177	}
178	return
179}
180
181// Status returns the current status of the retryable stream.
182func (rs *retryableStream) Status() streamStatus {
183	rs.mu.Lock()
184	defer rs.mu.Unlock()
185	return rs.status
186}
187
188// Error returns the error that caused the stream to terminate. Can be nil if it
189// was initiated by Stop().
190func (rs *retryableStream) Error() error {
191	rs.mu.Lock()
192	defer rs.mu.Unlock()
193	return rs.finalErr
194}
195
196func (rs *retryableStream) currentStream() grpc.ClientStream {
197	rs.mu.Lock()
198	defer rs.mu.Unlock()
199	return rs.stream
200}
201
202// unsafeClearStream must be called with the retryableStream.mu locked.
203func (rs *retryableStream) unsafeClearStream() {
204	if rs.cancelStream != nil {
205		// If the stream did not already abort due to error, this will abort it.
206		rs.cancelStream()
207		rs.cancelStream = nil
208	}
209	if rs.stream != nil {
210		rs.stream = nil
211	}
212}
213
214func (rs *retryableStream) newStreamContext() (ctx context.Context, cancel context.CancelFunc) {
215	rs.mu.Lock()
216	defer rs.mu.Unlock()
217
218	rs.unsafeClearStream()
219	ctx, cancel = context.WithCancel(rs.ctx)
220	rs.cancelStream = cancel
221	return
222}
223
224// connectStream attempts to establish a valid connection with the server. Due
225// to the potential high latency, initNewStream() should not be done while
226// holding retryableStream.mu. Hence we need to handle the stream being force
227// terminated during reconnection.
228//
229// Intended to be called in a goroutine. It ends once the client stream closes.
230func (rs *retryableStream) connectStream(notifyReset notifyReset) {
231	canReconnect := func() bool {
232		rs.mu.Lock()
233		defer rs.mu.Unlock()
234
235		if rs.status == streamReconnecting {
236			// There can only be 1 goroutine reconnecting.
237			return false
238		}
239		if rs.status == streamTerminated {
240			return false
241		}
242		rs.status = streamReconnecting
243		rs.unsafeClearStream()
244		return true
245	}
246	if !canReconnect() {
247		return
248	}
249
250	rs.handler.onStreamStatusChange(streamReconnecting)
251	if notifyReset {
252		rs.handler.onStreamStatusChange(streamResetState)
253	}
254	// Check whether handler terminated stream before reconnecting.
255	if rs.Status() == streamTerminated {
256		return
257	}
258
259	newStream, err := rs.initNewStream()
260	if err != nil {
261		rs.terminate(err)
262		return
263	}
264
265	connected := func() bool {
266		rs.mu.Lock()
267		defer rs.mu.Unlock()
268
269		if rs.status == streamTerminated {
270			rs.unsafeClearStream()
271			return false
272		}
273		rs.status = streamConnected
274		rs.stream = newStream
275		return true
276	}
277	if !connected() {
278		return
279	}
280
281	rs.handler.onStreamStatusChange(streamConnected)
282	rs.listen(newStream)
283}
284
285func (rs *retryableStream) newInitTimer(cancelFunc func()) *requestTimer {
286	return newRequestTimer(rs.initTimeout, cancelFunc, errStreamInitTimeout)
287}
288
289func (rs *retryableStream) initNewStream() (newStream grpc.ClientStream, err error) {
290	var cancelFunc context.CancelFunc
291	r := newStreamRetryer(rs.connectTimeout)
292	for {
293		backoff, shouldRetry := func() (time.Duration, bool) {
294			var cctx context.Context
295			cctx, cancelFunc = rs.newStreamContext()
296			// Bound the duration of the stream initialization attempt.
297			it := rs.newInitTimer(cancelFunc)
298			defer it.Stop()
299
300			newStream, err = rs.handler.newStream(cctx)
301			if err = it.ResolveError(err); err != nil {
302				return r.RetryRecv(err)
303			}
304			initReq, needsResponse := rs.handler.initialRequest()
305			if err = it.ResolveError(newStream.SendMsg(initReq)); err != nil {
306				return r.RetrySend(err)
307			}
308			if needsResponse {
309				response := reflect.New(rs.responseType).Interface()
310				if err = it.ResolveError(newStream.RecvMsg(response)); err != nil {
311					if isStreamResetSignal(err) {
312						rs.handler.onStreamStatusChange(streamResetState)
313					}
314					return r.RetryRecv(err)
315				}
316				if err = rs.handler.validateInitialResponse(response); err != nil {
317					// An unexpected initial response from the server is a permanent error.
318					cancelFunc()
319					return 0, false
320				}
321			}
322
323			// If the init timer fired due to a race, the stream would be unusable.
324			it.Stop()
325			if err = it.ResolveError(nil); err != nil {
326				return r.RetryRecv(err)
327			}
328
329			// We have a valid connection and should break from the outer loop.
330			return 0, false
331		}()
332
333		if (shouldRetry || err != nil) && cancelFunc != nil {
334			// Ensure that streams aren't leaked.
335			cancelFunc()
336			cancelFunc = nil
337			newStream = nil
338		}
339		if !shouldRetry || rs.Status() == streamTerminated {
340			break
341		}
342		if r.ExceededDeadline() {
343			err = xerrors.Errorf("%v: %w", err, ErrBackendUnavailable)
344			break
345		}
346		if err = gax.Sleep(rs.ctx, backoff); err != nil {
347			break
348		}
349	}
350	return
351}
352
353// listen receives responses from the current stream. It initiates reconnection
354// upon retryable errors or terminates the stream upon permanent error.
355func (rs *retryableStream) listen(recvStream grpc.ClientStream) {
356	for {
357		response := reflect.New(rs.responseType).Interface()
358		err := recvStream.RecvMsg(response)
359
360		// If the current stream has changed while listening, any errors or messages
361		// received now are obsolete. Discard and end the goroutine. Assume the
362		// stream has been cancelled elsewhere.
363		if rs.currentStream() != recvStream {
364			break
365		}
366		if err != nil {
367			if isRetryableRecvError(err) {
368				go rs.connectStream(notifyReset(isStreamResetSignal(err)))
369			} else {
370				rs.terminate(err)
371			}
372			break
373		}
374		rs.handler.onResponse(response)
375	}
376}
377
378// terminate forces the stream to terminate with the given error (can be nil)
379// Is a no-op if the stream has already terminated.
380func (rs *retryableStream) terminate(err error) {
381	rs.mu.Lock()
382	defer rs.mu.Unlock()
383	rs.unsafeTerminate(err)
384}
385
386func (rs *retryableStream) unsafeTerminate(err error) {
387	if rs.status == streamTerminated {
388		return
389	}
390	rs.status = streamTerminated
391	rs.finalErr = err
392	rs.unsafeClearStream()
393
394	// terminate can be called from within a streamHandler method with a lock
395	// held. So notify from a goroutine to prevent deadlock.
396	go rs.handler.onStreamStatusChange(streamTerminated)
397}
398