1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14package wire
15
16import (
17	"context"
18	"io"
19	"reflect"
20	"sync"
21	"time"
22
23	"golang.org/x/xerrors"
24	"google.golang.org/grpc"
25
26	gax "github.com/googleapis/gax-go/v2"
27)
28
// streamStatus describes the lifecycle state of a retryableStream.
//
// A stream begins in streamUninitialized. Once started it moves between
// streamReconnecting and streamConnected as retryable errors occur. A
// permanent error (or Stop) moves it to streamTerminated, after which it can
// never be reconnected. streamResetState is delivered to stream handlers via
// onStreamStatusChange when they must reset their state.
type streamStatus int

const (
	// streamUninitialized is the zero value: the stream has not been started.
	streamUninitialized streamStatus = iota
	// streamReconnecting means a new underlying stream is being established.
	streamReconnecting
	// streamResetState tells handlers to reset their state before reconnecting.
	streamResetState
	// streamConnected means the underlying stream is connected.
	streamConnected
	// streamTerminated is final; the stream cannot be reconnected.
	streamTerminated
)
42
// Named boolean types that make call sites self-describing.
type (
	// initialResponseRequired reports whether the server is expected to reply
	// to a stream's initial request.
	initialResponseRequired bool
	// notifyReset reports whether handlers should receive a streamResetState
	// notification when the stream reconnects.
	notifyReset bool
)
45
46// streamHandler provides hooks for different Pub/Sub Lite streaming APIs
47// (e.g. publish, subscribe, streaming cursor, etc.) to use retryableStream.
48// All Pub/Sub Lite streaming APIs implement a similar handshaking protocol,
49// where an initial request and response must be transmitted before other
50// requests can be sent over the stream.
51//
52// streamHandler methods must not be called while holding retryableStream.mu in
53// order to prevent the streamHandler calling back into the retryableStream and
54// deadlocking.
55//
56// If any streamHandler method implementations block, this will block the
57// retryableStream.connectStream goroutine processing the underlying stream.
58type streamHandler interface {
59	// newStream implementations must create the client stream with the given
60	// (cancellable) context.
61	newStream(context.Context) (grpc.ClientStream, error)
62	// initialRequest should return the initial request and whether an initial
63	// response is expected.
64	initialRequest() (interface{}, initialResponseRequired)
65	validateInitialResponse(interface{}) error
66
67	// onStreamStatusChange is used to notify stream handlers when the stream has
68	// changed state.
69	// - A `streamReconnecting` status change is fired before attempting to
70	//   connect a new stream.
71	// - A `streamResetState` status change may be fired if the stream should
72	//   reset its state (due to receipt of the RESET signal from the server).
73	// - A `streamConnected` status change is fired when the stream is
74	//   successfully connected.
75	// These are followed by onResponse() calls when responses are received from
76	// the server. These events are guaranteed to occur in this order.
77	//
78	// A final `streamTerminated` status change is fired when a permanent error
79	// occurs. retryableStream.Error() returns the error that caused the stream to
80	// terminate.
81	onStreamStatusChange(streamStatus)
82	// onResponse forwards a response received on the stream to the stream
83	// handler.
84	onResponse(interface{})
85}
86
87// retryableStream is a wrapper around a bidirectional gRPC client stream to
88// handle automatic reconnection when the stream breaks.
89//
90// The connectStream() goroutine handles each stream connection. terminate() can
91// be called at any time, either by the client to force stream closure, or as a
92// result of an unretryable error.
93//
94// Safe to call capitalized methods from multiple goroutines. All other methods
95// are private implementation.
96type retryableStream struct {
97	// Immutable after creation.
98	ctx          context.Context
99	handler      streamHandler
100	responseType reflect.Type
101	timeout      time.Duration
102
103	// Guards access to fields below.
104	mu sync.Mutex
105
106	// The current connected stream.
107	stream grpc.ClientStream
108	// Function to cancel the current stream (which may be reconnecting).
109	cancelStream context.CancelFunc
110	status       streamStatus
111	finalErr     error
112}
113
114// newRetryableStream creates a new retryable stream wrapper. `timeout` is the
115// maximum duration for reconnection. `responseType` is the type of the response
116// proto received on the stream.
117func newRetryableStream(ctx context.Context, handler streamHandler, timeout time.Duration, responseType reflect.Type) *retryableStream {
118	return &retryableStream{
119		ctx:          ctx,
120		handler:      handler,
121		responseType: responseType,
122		timeout:      timeout,
123	}
124}
125
126// Start establishes a stream connection. It is a no-op if the stream has
127// already started.
128func (rs *retryableStream) Start() {
129	rs.mu.Lock()
130	defer rs.mu.Unlock()
131
132	if rs.status == streamUninitialized {
133		go rs.connectStream(notifyReset(false))
134	}
135}
136
137// Stop gracefully closes the stream without error.
138func (rs *retryableStream) Stop() {
139	rs.terminate(nil)
140}
141
142// Send attempts to send the request to the underlying stream and returns true
143// if successfully sent. Returns false if an error occurred or a reconnection is
144// in progress.
145func (rs *retryableStream) Send(request interface{}) (sent bool) {
146	rs.mu.Lock()
147	defer rs.mu.Unlock()
148
149	if rs.stream != nil {
150		err := rs.stream.SendMsg(request)
151		// Note: if SendMsg returns an error, the stream is aborted.
152		switch {
153		case err == nil:
154			sent = true
155		case err == io.EOF:
156			// If SendMsg returns io.EOF, RecvMsg will return the status of the
157			// stream. Nothing to do here.
158			break
159		case isRetryableSendError(err):
160			go rs.connectStream(notifyReset(false))
161		default:
162			rs.unsafeTerminate(err)
163		}
164	}
165	return
166}
167
168// Status returns the current status of the retryable stream.
169func (rs *retryableStream) Status() streamStatus {
170	rs.mu.Lock()
171	defer rs.mu.Unlock()
172	return rs.status
173}
174
175// Error returns the error that caused the stream to terminate. Can be nil if it
176// was initiated by Stop().
177func (rs *retryableStream) Error() error {
178	rs.mu.Lock()
179	defer rs.mu.Unlock()
180	return rs.finalErr
181}
182
183func (rs *retryableStream) currentStream() grpc.ClientStream {
184	rs.mu.Lock()
185	defer rs.mu.Unlock()
186	return rs.stream
187}
188
189// unsafeClearStream must be called with the retryableStream.mu locked.
190func (rs *retryableStream) unsafeClearStream() {
191	if rs.cancelStream != nil {
192		// If the stream did not already abort due to error, this will abort it.
193		rs.cancelStream()
194		rs.cancelStream = nil
195	}
196	if rs.stream != nil {
197		rs.stream = nil
198	}
199}
200
201func (rs *retryableStream) setCancel(cancel context.CancelFunc) {
202	rs.mu.Lock()
203	defer rs.mu.Unlock()
204
205	rs.unsafeClearStream()
206	rs.cancelStream = cancel
207}
208
209// connectStream attempts to establish a valid connection with the server. Due
210// to the potential high latency, initNewStream() should not be done while
211// holding retryableStream.mu. Hence we need to handle the stream being force
212// terminated during reconnection.
213//
214// Intended to be called in a goroutine. It ends once the client stream closes.
215func (rs *retryableStream) connectStream(notifyReset notifyReset) {
216	canReconnect := func() bool {
217		rs.mu.Lock()
218		defer rs.mu.Unlock()
219
220		if rs.status == streamReconnecting {
221			// There can only be 1 goroutine reconnecting.
222			return false
223		}
224		if rs.status == streamTerminated {
225			return false
226		}
227		rs.status = streamReconnecting
228		rs.unsafeClearStream()
229		return true
230	}
231	if !canReconnect() {
232		return
233	}
234
235	rs.handler.onStreamStatusChange(streamReconnecting)
236	if notifyReset {
237		rs.handler.onStreamStatusChange(streamResetState)
238	}
239	// Check whether handler terminated stream before reconnecting.
240	if rs.Status() == streamTerminated {
241		return
242	}
243
244	newStream, cancelFunc, err := rs.initNewStream()
245	if err != nil {
246		rs.terminate(err)
247		return
248	}
249
250	connected := func() bool {
251		rs.mu.Lock()
252		defer rs.mu.Unlock()
253
254		if rs.status == streamTerminated {
255			rs.unsafeClearStream()
256			return false
257		}
258		rs.status = streamConnected
259		rs.stream = newStream
260		rs.cancelStream = cancelFunc
261		return true
262	}
263	if !connected() {
264		return
265	}
266
267	rs.handler.onStreamStatusChange(streamConnected)
268	rs.listen(newStream)
269}
270
271func (rs *retryableStream) initNewStream() (newStream grpc.ClientStream, cancelFunc context.CancelFunc, err error) {
272	r := newStreamRetryer(rs.timeout)
273	for {
274		backoff, shouldRetry := func() (time.Duration, bool) {
275			var cctx context.Context
276			cctx, cancelFunc = context.WithCancel(rs.ctx)
277			// Store the cancel func to quickly cancel reconnecting if the stream is
278			// terminated.
279			rs.setCancel(cancelFunc)
280
281			newStream, err = rs.handler.newStream(cctx)
282			if err != nil {
283				return r.RetryRecv(err)
284			}
285			initReq, needsResponse := rs.handler.initialRequest()
286			if err = newStream.SendMsg(initReq); err != nil {
287				return r.RetrySend(err)
288			}
289			if needsResponse {
290				response := reflect.New(rs.responseType).Interface()
291				if err = newStream.RecvMsg(response); err != nil {
292					if isStreamResetSignal(err) {
293						rs.handler.onStreamStatusChange(streamResetState)
294					}
295					return r.RetryRecv(err)
296				}
297				if err = rs.handler.validateInitialResponse(response); err != nil {
298					// An unexpected initial response from the server is a permanent error.
299					cancelFunc()
300					return 0, false
301				}
302			}
303
304			// We have a valid connection and should break from the outer loop.
305			return 0, false
306		}()
307
308		if (shouldRetry || err != nil) && cancelFunc != nil {
309			// Ensure that streams aren't leaked.
310			cancelFunc()
311			cancelFunc = nil
312			newStream = nil
313		}
314		if !shouldRetry || rs.Status() == streamTerminated {
315			break
316		}
317		if r.ExceededDeadline() {
318			err = xerrors.Errorf("%v: %w", err, ErrBackendUnavailable)
319			break
320		}
321		if err = gax.Sleep(rs.ctx, backoff); err != nil {
322			break
323		}
324	}
325	return
326}
327
328// listen receives responses from the current stream. It initiates reconnection
329// upon retryable errors or terminates the stream upon permanent error.
330func (rs *retryableStream) listen(recvStream grpc.ClientStream) {
331	for {
332		response := reflect.New(rs.responseType).Interface()
333		err := recvStream.RecvMsg(response)
334
335		// If the current stream has changed while listening, any errors or messages
336		// received now are obsolete. Discard and end the goroutine. Assume the
337		// stream has been cancelled elsewhere.
338		if rs.currentStream() != recvStream {
339			break
340		}
341		if err != nil {
342			if isRetryableRecvError(err) {
343				go rs.connectStream(notifyReset(isStreamResetSignal(err)))
344			} else {
345				rs.terminate(err)
346			}
347			break
348		}
349		rs.handler.onResponse(response)
350	}
351}
352
353// terminate forces the stream to terminate with the given error (can be nil)
354// Is a no-op if the stream has already terminated.
355func (rs *retryableStream) terminate(err error) {
356	rs.mu.Lock()
357	defer rs.mu.Unlock()
358	rs.unsafeTerminate(err)
359}
360
361func (rs *retryableStream) unsafeTerminate(err error) {
362	if rs.status == streamTerminated {
363		return
364	}
365	rs.status = streamTerminated
366	rs.finalErr = err
367	rs.unsafeClearStream()
368
369	// terminate can be called from within a streamHandler method with a lock
370	// held. So notify from a goroutine to prevent deadlock.
371	go rs.handler.onStreamStatusChange(streamTerminated)
372}
373