1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14package wire
15
16import (
17	"context"
18	"io"
19	"reflect"
20	"sync"
21	"time"
22
23	"google.golang.org/grpc"
24
25	gax "github.com/googleapis/gax-go/v2"
26)
27
// streamStatus is the lifecycle state of a retryableStream. A new stream is
// uninitialized. While active, retryable errors move it back and forth between
// reconnecting and connected. A permanent error terminates it, after which it
// can never be reconnected.
type streamStatus int

const (
	// streamUninitialized is the initial state, before any connection attempt.
	streamUninitialized streamStatus = iota
	// streamReconnecting means a new connection is being established.
	streamReconnecting
	// streamConnected means a live stream is available.
	streamConnected
	// streamTerminated is final; the stream cannot be reconnected.
	streamTerminated
)
40
41// streamHandler provides hooks for different Pub/Sub Lite streaming APIs
42// (e.g. publish, subscribe, streaming cursor, etc.) to use retryableStream.
43// All Pub/Sub Lite streaming APIs implement a similar handshaking protocol,
44// where an initial request and response must be transmitted before other
45// requests can be sent over the stream.
46//
47// streamHandler methods must not be called while holding retryableStream.mu in
48// order to prevent the streamHandler calling back into the retryableStream and
49// deadlocking.
50type streamHandler interface {
51	// newStream implementations must create the client stream with the given
52	// (cancellable) context.
53	newStream(context.Context) (grpc.ClientStream, error)
54	initialRequest() interface{}
55	validateInitialResponse(interface{}) error
56
57	// onStreamStatusChange is used to notify stream handlers when the stream has
58	// changed state. In particular, the `streamTerminated` state must be handled.
59	// retryableStream.Error() returns the error that caused the stream to
60	// terminate. Stream handlers should perform any necessary reset of state upon
61	// `streamConnected`.
62	onStreamStatusChange(streamStatus)
63	// onResponse forwards a response received on the stream to the stream
64	// handler.
65	onResponse(interface{})
66}
67
68// retryableStream is a wrapper around a bidirectional gRPC client stream to
69// handle automatic reconnection when the stream breaks.
70//
71// A retryableStream cycles between the following goroutines:
72//   Start() --> reconnect() <--> listen()
73// terminate() can be called at any time, either by the client to force stream
74// closure, or as a result of an unretryable error.
75//
76// Safe to call capitalized methods from multiple goroutines. All other methods
77// are private implementation.
78type retryableStream struct {
79	// Immutable after creation.
80	ctx          context.Context
81	handler      streamHandler
82	responseType reflect.Type
83	timeout      time.Duration
84
85	// Guards access to fields below.
86	mu sync.Mutex
87
88	// The current connected stream.
89	stream grpc.ClientStream
90	// Function to cancel the current stream (which may be reconnecting).
91	cancelStream context.CancelFunc
92	status       streamStatus
93	finalErr     error
94}
95
96// newRetryableStream creates a new retryable stream wrapper. `timeout` is the
97// maximum duration for reconnection. `responseType` is the type of the response
98// proto received on the stream.
99func newRetryableStream(ctx context.Context, handler streamHandler, timeout time.Duration, responseType reflect.Type) *retryableStream {
100	return &retryableStream{
101		ctx:          ctx,
102		handler:      handler,
103		responseType: responseType,
104		timeout:      timeout,
105	}
106}
107
108// Start establishes a stream connection. It is a no-op if the stream has
109// already started.
110func (rs *retryableStream) Start() {
111	rs.mu.Lock()
112	defer rs.mu.Unlock()
113
114	if rs.status != streamUninitialized {
115		return
116	}
117	go rs.reconnect()
118}
119
120// Stop gracefully closes the stream without error.
121func (rs *retryableStream) Stop() {
122	rs.terminate(nil)
123}
124
125// Send attempts to send the request to the underlying stream and returns true
126// if successfully sent. Returns false if an error occurred or a reconnection is
127// in progress.
128func (rs *retryableStream) Send(request interface{}) (sent bool) {
129	rs.mu.Lock()
130
131	if rs.stream != nil {
132		err := rs.stream.SendMsg(request)
133		// Note: if SendMsg returns an error, the stream is aborted.
134		switch {
135		case err == nil:
136			sent = true
137		case err == io.EOF:
138			// If SendMsg returns io.EOF, RecvMsg will return the status of the
139			// stream. Nothing to do here.
140			break
141		case isRetryableSendError(err):
142			go rs.reconnect()
143		default:
144			rs.mu.Unlock() // terminate acquires the mutex.
145			rs.terminate(err)
146			return
147		}
148	}
149
150	rs.mu.Unlock()
151	return
152}
153
154// Status returns the current status of the retryable stream.
155func (rs *retryableStream) Status() streamStatus {
156	rs.mu.Lock()
157	defer rs.mu.Unlock()
158	return rs.status
159}
160
161// Error returns the error that caused the stream to terminate. Can be nil if it
162// was initiated by Stop().
163func (rs *retryableStream) Error() error {
164	rs.mu.Lock()
165	defer rs.mu.Unlock()
166	return rs.finalErr
167}
168
169func (rs *retryableStream) currentStream() grpc.ClientStream {
170	rs.mu.Lock()
171	defer rs.mu.Unlock()
172	return rs.stream
173}
174
175// unsafeClearStream must be called with the retryableStream.mu locked.
176func (rs *retryableStream) unsafeClearStream() {
177	if rs.cancelStream != nil {
178		// If the stream did not already abort due to error, this will abort it.
179		rs.cancelStream()
180		rs.cancelStream = nil
181	}
182	if rs.stream != nil {
183		rs.stream = nil
184	}
185}
186
187func (rs *retryableStream) setCancel(cancel context.CancelFunc) {
188	rs.mu.Lock()
189	defer rs.mu.Unlock()
190	rs.cancelStream = cancel
191}
192
193// reconnect attempts to establish a valid connection with the server. Due to
194// the potential high latency, initNewStream() should not be done while holding
195// retryableStream.mu. Hence we need to handle the stream being force terminated
196// during reconnection.
197//
198// Intended to be called in a goroutine. It ends once the connection has been
199// established or the stream terminated.
200func (rs *retryableStream) reconnect() {
201	canReconnect := func() bool {
202		rs.mu.Lock()
203		defer rs.mu.Unlock()
204
205		if rs.status == streamReconnecting {
206			// There can only be 1 goroutine reconnecting.
207			return false
208		}
209		if rs.status == streamTerminated {
210			return false
211		}
212		rs.status = streamReconnecting
213		rs.unsafeClearStream()
214		return true
215	}
216	if !canReconnect() {
217		return
218	}
219	rs.handler.onStreamStatusChange(streamReconnecting)
220
221	newStream, cancelFunc, err := rs.initNewStream()
222	if err != nil {
223		rs.terminate(err)
224		return
225	}
226
227	connected := func() bool {
228		rs.mu.Lock()
229		defer rs.mu.Unlock()
230
231		if rs.status == streamTerminated {
232			rs.unsafeClearStream()
233			return false
234		}
235		rs.status = streamConnected
236		rs.stream = newStream
237		rs.cancelStream = cancelFunc
238		go rs.listen(newStream)
239		return true
240	}
241	if !connected() {
242		return
243	}
244	rs.handler.onStreamStatusChange(streamConnected)
245}
246
247func (rs *retryableStream) initNewStream() (newStream grpc.ClientStream, cancelFunc context.CancelFunc, err error) {
248	r := newStreamRetryer(rs.timeout)
249	for {
250		backoff, shouldRetry := func() (time.Duration, bool) {
251			defer func() {
252				if err != nil && cancelFunc != nil {
253					cancelFunc()
254					cancelFunc = nil
255					newStream = nil
256				}
257			}()
258
259			var cctx context.Context
260			cctx, cancelFunc = context.WithCancel(rs.ctx)
261			// Store the cancel func to quickly cancel reconnecting if the stream is
262			// terminated.
263			rs.setCancel(cancelFunc)
264
265			newStream, err = rs.handler.newStream(cctx)
266			if err != nil {
267				return r.RetryRecv(err)
268			}
269			if err = newStream.SendMsg(rs.handler.initialRequest()); err != nil {
270				return r.RetrySend(err)
271			}
272			response := reflect.New(rs.responseType).Interface()
273			if err = newStream.RecvMsg(response); err != nil {
274				return r.RetryRecv(err)
275			}
276			if err = rs.handler.validateInitialResponse(response); err != nil {
277				// An unexpected initial response from the server is a permanent error.
278				return 0, false
279			}
280
281			// We have a valid connection and should break from the outer loop.
282			return 0, false
283		}()
284
285		if !shouldRetry {
286			break
287		}
288		if err = gax.Sleep(rs.ctx, backoff); err != nil {
289			break
290		}
291	}
292	return
293}
294
295// listen receives responses from the current stream. It initiates reconnection
296// upon retryable errors or terminates the stream upon permanent error.
297//
298// Intended to be called in a goroutine. It ends when recvStream has closed.
299func (rs *retryableStream) listen(recvStream grpc.ClientStream) {
300	for {
301		response := reflect.New(rs.responseType).Interface()
302		err := recvStream.RecvMsg(response)
303
304		// If the current stream has changed while listening, any errors or messages
305		// received now are obsolete. Discard and end the goroutine. Assume the
306		// stream has been cancelled elsewhere.
307		if rs.currentStream() != recvStream {
308			break
309		}
310		if err != nil {
311			if isRetryableRecvError(err) {
312				go rs.reconnect()
313			} else {
314				rs.terminate(err)
315			}
316			break
317		}
318		rs.handler.onResponse(response)
319	}
320}
321
322// terminate forces the stream to terminate with the given error (can be nil)
323// Is a no-op if the stream has already terminated.
324func (rs *retryableStream) terminate(err error) {
325	rs.mu.Lock()
326	defer rs.mu.Unlock()
327
328	if rs.status == streamTerminated {
329		return
330	}
331	rs.status = streamTerminated
332	rs.finalErr = err
333	rs.unsafeClearStream()
334
335	// terminate can be called from within a streamHandler method with a lock
336	// held. So notify from a goroutine to prevent deadlock.
337	go rs.handler.onStreamStatusChange(streamTerminated)
338}
339