1// Copyright 2020 Google LLC 2// 3// Licensed under the Apache License, Version 2.0 (the "License"); 4// you may not use this file except in compliance with the License. 5// You may obtain a copy of the License at 6// 7// https://www.apache.org/licenses/LICENSE-2.0 8// 9// Unless required by applicable law or agreed to in writing, software 10// distributed under the License is distributed on an "AS IS" BASIS, 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12// See the License for the specific language governing permissions and 13 14package wire 15 16import ( 17 "context" 18 "io" 19 "reflect" 20 "sync" 21 "time" 22 23 "golang.org/x/xerrors" 24 "google.golang.org/grpc" 25 26 gax "github.com/googleapis/gax-go/v2" 27) 28 29// streamStatus is the status of a retryableStream. A stream starts off 30// uninitialized. While it is active, it can transition between reconnecting and 31// connected due to retryable errors. When a permanent error occurs, the stream 32// is terminated and cannot be reconnected. 33type streamStatus int 34 35const ( 36 streamUninitialized streamStatus = iota 37 streamReconnecting 38 streamConnected 39 streamTerminated 40) 41 42type initialResponseRequired bool 43 44// streamHandler provides hooks for different Pub/Sub Lite streaming APIs 45// (e.g. publish, subscribe, streaming cursor, etc.) to use retryableStream. 46// All Pub/Sub Lite streaming APIs implement a similar handshaking protocol, 47// where an initial request and response must be transmitted before other 48// requests can be sent over the stream. 49// 50// streamHandler methods must not be called while holding retryableStream.mu in 51// order to prevent the streamHandler calling back into the retryableStream and 52// deadlocking. 53// 54// If any streamHandler method implementations block, this will block the 55// retryableStream.connectStream goroutine processing the underlying stream. 56type streamHandler interface { 57 // newStream implementations must create the client stream with the given 58 // (cancellable) context. 59 newStream(context.Context) (grpc.ClientStream, error) 60 // initialRequest should return the initial request and whether an initial 61 // response is expected. 62 initialRequest() (interface{}, initialResponseRequired) 63 validateInitialResponse(interface{}) error 64 65 // onStreamStatusChange is used to notify stream handlers when the stream has 66 // changed state. A `streamReconnecting` status change is fired before 67 // attempting to connect a new stream. A `streamConnected` status change is 68 // fired when the stream is successfully connected. These are followed by 69 // onResponse() calls when responses are received from the server. These 70 // events are guaranteed to occur in this order. 71 // 72 // A final `streamTerminated` status change is fired when a permanent error 73 // occurs. retryableStream.Error() returns the error that caused the stream to 74 // terminate. 75 onStreamStatusChange(streamStatus) 76 // onResponse forwards a response received on the stream to the stream 77 // handler. 78 onResponse(interface{}) 79} 80 81// retryableStream is a wrapper around a bidirectional gRPC client stream to 82// handle automatic reconnection when the stream breaks. 83// 84// The connectStream() goroutine handles each stream connection. terminate() can 85// be called at any time, either by the client to force stream closure, or as a 86// result of an unretryable error. 87// 88// Safe to call capitalized methods from multiple goroutines. All other methods 89// are private implementation. 90type retryableStream struct { 91 // Immutable after creation. 92 ctx context.Context 93 handler streamHandler 94 responseType reflect.Type 95 timeout time.Duration 96 97 // Guards access to fields below. 98 mu sync.Mutex 99 100 // The current connected stream. 101 stream grpc.ClientStream 102 // Function to cancel the current stream (which may be reconnecting). 103 cancelStream context.CancelFunc 104 status streamStatus 105 finalErr error 106} 107 108// newRetryableStream creates a new retryable stream wrapper. `timeout` is the 109// maximum duration for reconnection. `responseType` is the type of the response 110// proto received on the stream. 111func newRetryableStream(ctx context.Context, handler streamHandler, timeout time.Duration, responseType reflect.Type) *retryableStream { 112 return &retryableStream{ 113 ctx: ctx, 114 handler: handler, 115 responseType: responseType, 116 timeout: timeout, 117 } 118} 119 120// Start establishes a stream connection. It is a no-op if the stream has 121// already started. 122func (rs *retryableStream) Start() { 123 rs.mu.Lock() 124 defer rs.mu.Unlock() 125 126 if rs.status == streamUninitialized { 127 go rs.connectStream() 128 } 129} 130 131// Stop gracefully closes the stream without error. 132func (rs *retryableStream) Stop() { 133 rs.terminate(nil) 134} 135 136// Send attempts to send the request to the underlying stream and returns true 137// if successfully sent. Returns false if an error occurred or a reconnection is 138// in progress. 139func (rs *retryableStream) Send(request interface{}) (sent bool) { 140 rs.mu.Lock() 141 defer rs.mu.Unlock() 142 143 if rs.stream != nil { 144 err := rs.stream.SendMsg(request) 145 // Note: if SendMsg returns an error, the stream is aborted. 146 switch { 147 case err == nil: 148 sent = true 149 case err == io.EOF: 150 // If SendMsg returns io.EOF, RecvMsg will return the status of the 151 // stream. Nothing to do here. 152 break 153 case isRetryableSendError(err): 154 go rs.connectStream() 155 default: 156 rs.unsafeTerminate(err) 157 } 158 } 159 return 160} 161 162// Status returns the current status of the retryable stream. 163func (rs *retryableStream) Status() streamStatus { 164 rs.mu.Lock() 165 defer rs.mu.Unlock() 166 return rs.status 167} 168 169// Error returns the error that caused the stream to terminate. Can be nil if it 170// was initiated by Stop(). 171func (rs *retryableStream) Error() error { 172 rs.mu.Lock() 173 defer rs.mu.Unlock() 174 return rs.finalErr 175} 176 177func (rs *retryableStream) currentStream() grpc.ClientStream { 178 rs.mu.Lock() 179 defer rs.mu.Unlock() 180 return rs.stream 181} 182 183// unsafeClearStream must be called with the retryableStream.mu locked. 184func (rs *retryableStream) unsafeClearStream() { 185 if rs.cancelStream != nil { 186 // If the stream did not already abort due to error, this will abort it. 187 rs.cancelStream() 188 rs.cancelStream = nil 189 } 190 if rs.stream != nil { 191 rs.stream = nil 192 } 193} 194 195func (rs *retryableStream) setCancel(cancel context.CancelFunc) { 196 rs.mu.Lock() 197 defer rs.mu.Unlock() 198 199 rs.unsafeClearStream() 200 rs.cancelStream = cancel 201} 202 203// connectStream attempts to establish a valid connection with the server. Due 204// to the potential high latency, initNewStream() should not be done while 205// holding retryableStream.mu. Hence we need to handle the stream being force 206// terminated during reconnection. 207// 208// Intended to be called in a goroutine. It ends once the client stream closes. 209func (rs *retryableStream) connectStream() { 210 canReconnect := func() bool { 211 rs.mu.Lock() 212 defer rs.mu.Unlock() 213 214 if rs.status == streamReconnecting { 215 // There can only be 1 goroutine reconnecting. 216 return false 217 } 218 if rs.status == streamTerminated { 219 return false 220 } 221 rs.status = streamReconnecting 222 rs.unsafeClearStream() 223 return true 224 } 225 if !canReconnect() { 226 return 227 } 228 rs.handler.onStreamStatusChange(streamReconnecting) 229 230 newStream, cancelFunc, err := rs.initNewStream() 231 if err != nil { 232 rs.terminate(err) 233 return 234 } 235 236 connected := func() bool { 237 rs.mu.Lock() 238 defer rs.mu.Unlock() 239 240 if rs.status == streamTerminated { 241 rs.unsafeClearStream() 242 return false 243 } 244 rs.status = streamConnected 245 rs.stream = newStream 246 rs.cancelStream = cancelFunc 247 return true 248 } 249 if !connected() { 250 return 251 } 252 253 rs.handler.onStreamStatusChange(streamConnected) 254 rs.listen(newStream) 255} 256 257func (rs *retryableStream) initNewStream() (newStream grpc.ClientStream, cancelFunc context.CancelFunc, err error) { 258 r := newStreamRetryer(rs.timeout) 259 for { 260 backoff, shouldRetry := func() (time.Duration, bool) { 261 var cctx context.Context 262 cctx, cancelFunc = context.WithCancel(rs.ctx) 263 // Store the cancel func to quickly cancel reconnecting if the stream is 264 // terminated. 265 rs.setCancel(cancelFunc) 266 267 newStream, err = rs.handler.newStream(cctx) 268 if err != nil { 269 return r.RetryRecv(err) 270 } 271 initReq, needsResponse := rs.handler.initialRequest() 272 if err = newStream.SendMsg(initReq); err != nil { 273 return r.RetrySend(err) 274 } 275 if needsResponse { 276 response := reflect.New(rs.responseType).Interface() 277 if err = newStream.RecvMsg(response); err != nil { 278 return r.RetryRecv(err) 279 } 280 if err = rs.handler.validateInitialResponse(response); err != nil { 281 // An unexpected initial response from the server is a permanent error. 282 cancelFunc() 283 return 0, false 284 } 285 } 286 287 // We have a valid connection and should break from the outer loop. 288 return 0, false 289 }() 290 291 if (shouldRetry || err != nil) && cancelFunc != nil { 292 // Ensure that streams aren't leaked. 293 cancelFunc() 294 cancelFunc = nil 295 newStream = nil 296 } 297 if !shouldRetry || rs.Status() == streamTerminated { 298 break 299 } 300 if r.ExceededDeadline() { 301 err = xerrors.Errorf("%v: %w", err, ErrBackendUnavailable) 302 break 303 } 304 if err = gax.Sleep(rs.ctx, backoff); err != nil { 305 break 306 } 307 } 308 return 309} 310 311// listen receives responses from the current stream. It initiates reconnection 312// upon retryable errors or terminates the stream upon permanent error. 313func (rs *retryableStream) listen(recvStream grpc.ClientStream) { 314 for { 315 response := reflect.New(rs.responseType).Interface() 316 err := recvStream.RecvMsg(response) 317 318 // If the current stream has changed while listening, any errors or messages 319 // received now are obsolete. Discard and end the goroutine. Assume the 320 // stream has been cancelled elsewhere. 321 if rs.currentStream() != recvStream { 322 break 323 } 324 if err != nil { 325 if isRetryableRecvError(err) { 326 go rs.connectStream() 327 } else { 328 rs.terminate(err) 329 } 330 break 331 } 332 rs.handler.onResponse(response) 333 } 334} 335 336// terminate forces the stream to terminate with the given error (can be nil) 337// Is a no-op if the stream has already terminated. 338func (rs *retryableStream) terminate(err error) { 339 rs.mu.Lock() 340 defer rs.mu.Unlock() 341 rs.unsafeTerminate(err) 342} 343 344func (rs *retryableStream) unsafeTerminate(err error) { 345 if rs.status == streamTerminated { 346 return 347 } 348 rs.status = streamTerminated 349 rs.finalErr = err 350 rs.unsafeClearStream() 351 352 // terminate can be called from within a streamHandler method with a lock 353 // held. So notify from a goroutine to prevent deadlock. 354 go rs.handler.onStreamStatusChange(streamTerminated) 355} 356