1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14package wire
15
16import (
17	"context"
18	"io"
19	"reflect"
20	"sync"
21	"time"
22
23	"golang.org/x/xerrors"
24	"google.golang.org/grpc"
25
26	gax "github.com/googleapis/gax-go/v2"
27)
28
// streamStatus is the status of a retryableStream. A stream starts off
// uninitialized. While it is active, it can transition between reconnecting and
// connected due to retryable errors. When a permanent error occurs or Stop() is
// called, the stream is terminated and cannot be reconnected.
type streamStatus int

const (
	// streamUninitialized is the zero value, i.e. the status of a
	// retryableStream before connectStream() has run for the first time.
	streamUninitialized streamStatus = iota
	// streamReconnecting is set by connectStream() before it attempts to
	// establish a new underlying stream. Only one goroutine may be reconnecting
	// at a time.
	streamReconnecting
	// streamConnected is set once a new underlying stream has been initialized
	// and installed as the current stream.
	streamConnected
	// streamTerminated is the final status. Once set, it is never changed and
	// no further connection attempts are started.
	streamTerminated
)
41
// initialResponseRequired is returned by streamHandler.initialRequest to
// indicate whether the retryableStream must wait for (and validate) an initial
// response from the server before the stream is considered connected.
type initialResponseRequired bool
43
// streamHandler provides hooks for different Pub/Sub Lite streaming APIs
// (e.g. publish, subscribe, streaming cursor, etc.) to use retryableStream.
// All Pub/Sub Lite streaming APIs implement a similar handshaking protocol,
// where an initial request and response must be transmitted before other
// requests can be sent over the stream.
//
// streamHandler methods must not be called while holding retryableStream.mu in
// order to prevent the streamHandler calling back into the retryableStream and
// deadlocking.
//
// If any streamHandler method implementations block, this will block the
// retryableStream.connectStream goroutine processing the underlying stream.
type streamHandler interface {
	// newStream implementations must create the client stream with the given
	// (cancellable) context.
	newStream(context.Context) (grpc.ClientStream, error)
	// initialRequest should return the initial request and whether an initial
	// response is expected.
	initialRequest() (interface{}, initialResponseRequired)
	// validateInitialResponse is called with the first response received on a
	// new stream, but only when initialRequest reported that an initial
	// response is required. Returning a non-nil error is treated as a permanent
	// error: the connection attempt is not retried and the stream terminates.
	validateInitialResponse(interface{}) error

	// onStreamStatusChange is used to notify stream handlers when the stream has
	// changed state. A `streamReconnecting` status change is fired before
	// attempting to connect a new stream. A `streamConnected` status change is
	// fired when the stream is successfully connected. These are followed by
	// onResponse() calls when responses are received from the server. These
	// events are guaranteed to occur in this order.
	//
	// A final `streamTerminated` status change is fired when a permanent error
	// occurs. retryableStream.Error() returns the error that caused the stream to
	// terminate. The `streamTerminated` notification is delivered from its own
	// goroutine.
	onStreamStatusChange(streamStatus)
	// onResponse forwards a response received on the stream to the stream
	// handler.
	onResponse(interface{})
}
80
// retryableStream is a wrapper around a bidirectional gRPC client stream to
// handle automatic reconnection when the stream breaks.
//
// The connectStream() goroutine handles each stream connection. terminate() can
// be called at any time, either by the client to force stream closure, or as a
// result of an unretryable error.
//
// Safe to call capitalized methods from multiple goroutines. All other methods
// are private implementation.
type retryableStream struct {
	// Immutable after creation.
	//
	// ctx is the parent context; every connection attempt derives its own
	// cancellable context from it, and it also bounds the backoff sleep between
	// attempts.
	// handler receives the stream lifecycle callbacks and creates the
	// underlying client streams.
	// responseType is the type instantiated (via reflect.New) to receive each
	// response message.
	// timeout is the maximum duration for reconnection; it is handed to the
	// stream retryer for each (re)connection.
	ctx          context.Context
	handler      streamHandler
	responseType reflect.Type
	timeout      time.Duration

	// Guards access to fields below.
	mu sync.Mutex

	// The current connected stream. Nil while uninitialized, reconnecting or
	// terminated.
	stream grpc.ClientStream
	// Function to cancel the current stream (which may be reconnecting).
	// status is the current lifecycle state of the stream. finalErr is the error
	// passed to terminate() (nil when terminated via Stop()).
	cancelStream context.CancelFunc
	status       streamStatus
	finalErr     error
}
107
108// newRetryableStream creates a new retryable stream wrapper. `timeout` is the
109// maximum duration for reconnection. `responseType` is the type of the response
110// proto received on the stream.
111func newRetryableStream(ctx context.Context, handler streamHandler, timeout time.Duration, responseType reflect.Type) *retryableStream {
112	return &retryableStream{
113		ctx:          ctx,
114		handler:      handler,
115		responseType: responseType,
116		timeout:      timeout,
117	}
118}
119
120// Start establishes a stream connection. It is a no-op if the stream has
121// already started.
122func (rs *retryableStream) Start() {
123	rs.mu.Lock()
124	defer rs.mu.Unlock()
125
126	if rs.status == streamUninitialized {
127		go rs.connectStream()
128	}
129}
130
131// Stop gracefully closes the stream without error.
132func (rs *retryableStream) Stop() {
133	rs.terminate(nil)
134}
135
136// Send attempts to send the request to the underlying stream and returns true
137// if successfully sent. Returns false if an error occurred or a reconnection is
138// in progress.
139func (rs *retryableStream) Send(request interface{}) (sent bool) {
140	rs.mu.Lock()
141	defer rs.mu.Unlock()
142
143	if rs.stream != nil {
144		err := rs.stream.SendMsg(request)
145		// Note: if SendMsg returns an error, the stream is aborted.
146		switch {
147		case err == nil:
148			sent = true
149		case err == io.EOF:
150			// If SendMsg returns io.EOF, RecvMsg will return the status of the
151			// stream. Nothing to do here.
152			break
153		case isRetryableSendError(err):
154			go rs.connectStream()
155		default:
156			rs.unsafeTerminate(err)
157		}
158	}
159	return
160}
161
162// Status returns the current status of the retryable stream.
163func (rs *retryableStream) Status() streamStatus {
164	rs.mu.Lock()
165	defer rs.mu.Unlock()
166	return rs.status
167}
168
169// Error returns the error that caused the stream to terminate. Can be nil if it
170// was initiated by Stop().
171func (rs *retryableStream) Error() error {
172	rs.mu.Lock()
173	defer rs.mu.Unlock()
174	return rs.finalErr
175}
176
177func (rs *retryableStream) currentStream() grpc.ClientStream {
178	rs.mu.Lock()
179	defer rs.mu.Unlock()
180	return rs.stream
181}
182
183// unsafeClearStream must be called with the retryableStream.mu locked.
184func (rs *retryableStream) unsafeClearStream() {
185	if rs.cancelStream != nil {
186		// If the stream did not already abort due to error, this will abort it.
187		rs.cancelStream()
188		rs.cancelStream = nil
189	}
190	if rs.stream != nil {
191		rs.stream = nil
192	}
193}
194
195func (rs *retryableStream) setCancel(cancel context.CancelFunc) {
196	rs.mu.Lock()
197	defer rs.mu.Unlock()
198
199	rs.unsafeClearStream()
200	rs.cancelStream = cancel
201}
202
203// connectStream attempts to establish a valid connection with the server. Due
204// to the potential high latency, initNewStream() should not be done while
205// holding retryableStream.mu. Hence we need to handle the stream being force
206// terminated during reconnection.
207//
208// Intended to be called in a goroutine. It ends once the client stream closes.
209func (rs *retryableStream) connectStream() {
210	canReconnect := func() bool {
211		rs.mu.Lock()
212		defer rs.mu.Unlock()
213
214		if rs.status == streamReconnecting {
215			// There can only be 1 goroutine reconnecting.
216			return false
217		}
218		if rs.status == streamTerminated {
219			return false
220		}
221		rs.status = streamReconnecting
222		rs.unsafeClearStream()
223		return true
224	}
225	if !canReconnect() {
226		return
227	}
228	rs.handler.onStreamStatusChange(streamReconnecting)
229
230	newStream, cancelFunc, err := rs.initNewStream()
231	if err != nil {
232		rs.terminate(err)
233		return
234	}
235
236	connected := func() bool {
237		rs.mu.Lock()
238		defer rs.mu.Unlock()
239
240		if rs.status == streamTerminated {
241			rs.unsafeClearStream()
242			return false
243		}
244		rs.status = streamConnected
245		rs.stream = newStream
246		rs.cancelStream = cancelFunc
247		return true
248	}
249	if !connected() {
250		return
251	}
252
253	rs.handler.onStreamStatusChange(streamConnected)
254	rs.listen(newStream)
255}
256
257func (rs *retryableStream) initNewStream() (newStream grpc.ClientStream, cancelFunc context.CancelFunc, err error) {
258	r := newStreamRetryer(rs.timeout)
259	for {
260		backoff, shouldRetry := func() (time.Duration, bool) {
261			var cctx context.Context
262			cctx, cancelFunc = context.WithCancel(rs.ctx)
263			// Store the cancel func to quickly cancel reconnecting if the stream is
264			// terminated.
265			rs.setCancel(cancelFunc)
266
267			newStream, err = rs.handler.newStream(cctx)
268			if err != nil {
269				return r.RetryRecv(err)
270			}
271			initReq, needsResponse := rs.handler.initialRequest()
272			if err = newStream.SendMsg(initReq); err != nil {
273				return r.RetrySend(err)
274			}
275			if needsResponse {
276				response := reflect.New(rs.responseType).Interface()
277				if err = newStream.RecvMsg(response); err != nil {
278					return r.RetryRecv(err)
279				}
280				if err = rs.handler.validateInitialResponse(response); err != nil {
281					// An unexpected initial response from the server is a permanent error.
282					cancelFunc()
283					return 0, false
284				}
285			}
286
287			// We have a valid connection and should break from the outer loop.
288			return 0, false
289		}()
290
291		if (shouldRetry || err != nil) && cancelFunc != nil {
292			// Ensure that streams aren't leaked.
293			cancelFunc()
294			cancelFunc = nil
295			newStream = nil
296		}
297		if !shouldRetry || rs.Status() == streamTerminated {
298			break
299		}
300		if r.ExceededDeadline() {
301			err = xerrors.Errorf("%v: %w", err, ErrBackendUnavailable)
302			break
303		}
304		if err = gax.Sleep(rs.ctx, backoff); err != nil {
305			break
306		}
307	}
308	return
309}
310
311// listen receives responses from the current stream. It initiates reconnection
312// upon retryable errors or terminates the stream upon permanent error.
313func (rs *retryableStream) listen(recvStream grpc.ClientStream) {
314	for {
315		response := reflect.New(rs.responseType).Interface()
316		err := recvStream.RecvMsg(response)
317
318		// If the current stream has changed while listening, any errors or messages
319		// received now are obsolete. Discard and end the goroutine. Assume the
320		// stream has been cancelled elsewhere.
321		if rs.currentStream() != recvStream {
322			break
323		}
324		if err != nil {
325			if isRetryableRecvError(err) {
326				go rs.connectStream()
327			} else {
328				rs.terminate(err)
329			}
330			break
331		}
332		rs.handler.onResponse(response)
333	}
334}
335
336// terminate forces the stream to terminate with the given error (can be nil)
337// Is a no-op if the stream has already terminated.
338func (rs *retryableStream) terminate(err error) {
339	rs.mu.Lock()
340	defer rs.mu.Unlock()
341	rs.unsafeTerminate(err)
342}
343
344func (rs *retryableStream) unsafeTerminate(err error) {
345	if rs.status == streamTerminated {
346		return
347	}
348	rs.status = streamTerminated
349	rs.finalErr = err
350	rs.unsafeClearStream()
351
352	// terminate can be called from within a streamHandler method with a lock
353	// held. So notify from a goroutine to prevent deadlock.
354	go rs.handler.onStreamStatusChange(streamTerminated)
355}
356