1// Copyright 2020 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     https://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
13
14package wire
15
16import (
17	"context"
18	"errors"
19	"reflect"
20	"testing"
21	"time"
22
23	"cloud.google.com/go/internal/testutil"
24	"cloud.google.com/go/pubsublite/internal/test"
25	"golang.org/x/xerrors"
26	"google.golang.org/grpc"
27	"google.golang.org/grpc/codes"
28	"google.golang.org/grpc/status"
29
30	vkit "cloud.google.com/go/pubsublite/apiv1"
31	pb "google.golang.org/genproto/googleapis/cloud/pubsublite/v1"
32)
33
// defaultStreamTimeout is the stream timeout that most tests pass to
// newTestStreamHandler. It is also the longest that NextStatus and
// NextResponse wait before failing the test.
const defaultStreamTimeout = 30 * time.Second

// errInvalidInitialResponse is returned by
// testStreamHandler.validateInitialResponse when the response does not carry
// an initial publish response.
var errInvalidInitialResponse = errors.New("invalid initial response")
37
// testStreamHandler is a simplified publisher service that owns a
// retryableStream.
//
// Topic is the topic partition that tests register publish stream verifiers
// for, InitialReq is the request returned by initialRequest, and Stream is
// the retryableStream under test.
//
// The unexported fields hold the owning test (t), the buffered channels that
// carry status changes and responses from the stream callbacks to NextStatus
// and NextResponse, and the publisher client used to open new streams.
type testStreamHandler struct {
	Topic      topicPartition
	InitialReq *pb.PublishRequest
	Stream     *retryableStream

	t         *testing.T
	statuses  chan streamStatus
	responses chan interface{}
	pubClient *vkit.PublisherClient
}
50
51func newTestStreamHandler(t *testing.T, timeout time.Duration) *testStreamHandler {
52	ctx := context.Background()
53	pubClient, err := newPublisherClient(ctx, "ignored", testServer.ClientConn())
54	if err != nil {
55		t.Fatal(err)
56	}
57
58	topic := topicPartition{Path: "path/to/topic", Partition: 1}
59	sh := &testStreamHandler{
60		Topic:      topic,
61		InitialReq: initPubReq(topic),
62		t:          t,
63		statuses:   make(chan streamStatus, 3),
64		responses:  make(chan interface{}, 1),
65		pubClient:  pubClient,
66	}
67	sh.Stream = newRetryableStream(ctx, sh, timeout, reflect.TypeOf(pb.PublishResponse{}))
68	return sh
69}
70
71func (sh *testStreamHandler) NextStatus() streamStatus {
72	select {
73	case status := <-sh.statuses:
74		return status
75	case <-time.After(defaultStreamTimeout):
76		sh.t.Errorf("Stream did not change state within %v", defaultStreamTimeout)
77		return streamUninitialized
78	}
79}
80
81func (sh *testStreamHandler) NextResponse() interface{} {
82	select {
83	case response := <-sh.responses:
84		return response
85	case <-time.After(defaultStreamTimeout):
86		sh.t.Errorf("Stream did not receive response within %v", defaultStreamTimeout)
87		return nil
88	}
89}
90
91func (sh *testStreamHandler) newStream(ctx context.Context) (grpc.ClientStream, error) {
92	return sh.pubClient.Publish(ctx)
93}
94
95func (sh *testStreamHandler) validateInitialResponse(response interface{}) error {
96	pubResponse, _ := response.(*pb.PublishResponse)
97	if pubResponse.GetInitialResponse() == nil {
98		return errInvalidInitialResponse
99	}
100	return nil
101}
102
103func (sh *testStreamHandler) initialRequest() (interface{}, initialResponseRequired) {
104	return sh.InitialReq, initialResponseRequired(true)
105}
106
107func (sh *testStreamHandler) onStreamStatusChange(status streamStatus) {
108	sh.statuses <- status
109
110	// Close connections.
111	if status == streamTerminated {
112		sh.pubClient.Close()
113	}
114}
115
// onResponse queues response on the responses channel, where NextResponse
// picks it up.
func (sh *testStreamHandler) onResponse(response interface{}) {
	sh.responses <- response
}
119
120func TestRetryableStreamStartOnce(t *testing.T) {
121	pub := newTestStreamHandler(t, defaultStreamTimeout)
122
123	verifiers := test.NewVerifiers(t)
124	stream := test.NewRPCVerifier(t)
125	stream.Push(pub.InitialReq, initPubResp(), nil)
126	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream)
127
128	mockServer.OnTestStart(verifiers)
129	defer mockServer.OnTestEnd()
130
131	// Ensure that new streams are not opened if the publisher is started twice
132	// (note: only 1 stream verifier was added to the mock server above).
133	pub.Stream.Start()
134	pub.Stream.Start()
135	pub.Stream.Start()
136	if got, want := pub.NextStatus(), streamReconnecting; got != want {
137		t.Errorf("Stream status change: got %d, want %d", got, want)
138	}
139	if got, want := pub.NextStatus(), streamConnected; got != want {
140		t.Errorf("Stream status change: got %d, want %d", got, want)
141	}
142
143	pub.Stream.Stop()
144	if got, want := pub.NextStatus(), streamTerminated; got != want {
145		t.Errorf("Stream status change: got %d, want %d", got, want)
146	}
147	if gotErr := pub.Stream.Error(); gotErr != nil {
148		t.Errorf("Stream final err: got (%v), want <nil>", gotErr)
149	}
150}
151
152func TestRetryableStreamStopWhileConnecting(t *testing.T) {
153	pub := newTestStreamHandler(t, defaultStreamTimeout)
154
155	verifiers := test.NewVerifiers(t)
156	stream := test.NewRPCVerifier(t)
157	barrier := stream.PushWithBarrier(pub.InitialReq, initPubResp(), nil)
158	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream)
159
160	mockServer.OnTestStart(verifiers)
161	defer mockServer.OnTestEnd()
162
163	pub.Stream.Start()
164	if got, want := pub.NextStatus(), streamReconnecting; got != want {
165		t.Errorf("Stream status change: got %d, want %d", got, want)
166	}
167
168	barrier.Release()
169	pub.Stream.Stop()
170
171	// The stream should transition to terminated and the client stream should be
172	// discarded.
173	if got, want := pub.NextStatus(), streamTerminated; got != want {
174		t.Errorf("Stream status change: got %d, want %d", got, want)
175	}
176	if pub.Stream.currentStream() != nil {
177		t.Error("Client stream should be nil")
178	}
179	if gotErr := pub.Stream.Error(); gotErr != nil {
180		t.Errorf("Stream final err: got (%v), want <nil>", gotErr)
181	}
182}
183
184func TestRetryableStreamStopAbortsRetries(t *testing.T) {
185	pub := newTestStreamHandler(t, defaultStreamTimeout)
186
187	verifiers := test.NewVerifiers(t)
188	stream := test.NewRPCVerifier(t)
189	// Aborted is a retryable error, but the stream should not be retried because
190	// the publisher is stopped.
191	barrier := stream.PushWithBarrier(pub.InitialReq, nil, status.Error(codes.Aborted, "abort retry"))
192	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream)
193
194	mockServer.OnTestStart(verifiers)
195	defer mockServer.OnTestEnd()
196
197	pub.Stream.Start()
198	if got, want := pub.NextStatus(), streamReconnecting; got != want {
199		t.Errorf("Stream status change: got %d, want %d", got, want)
200	}
201
202	barrier.Release()
203	pub.Stream.Stop()
204
205	// The stream should transition to terminated and the client stream should be
206	// discarded.
207	if got, want := pub.NextStatus(), streamTerminated; got != want {
208		t.Errorf("Stream status change: got %d, want %d", got, want)
209	}
210	if pub.Stream.currentStream() != nil {
211		t.Error("Client stream should be nil")
212	}
213	if gotErr := pub.Stream.Error(); gotErr != nil {
214		t.Errorf("Stream final err: got (%v), want <nil>", gotErr)
215	}
216}
217
218func TestRetryableStreamConnectRetries(t *testing.T) {
219	pub := newTestStreamHandler(t, defaultStreamTimeout)
220
221	verifiers := test.NewVerifiers(t)
222
223	// First 2 errors are retryable.
224	stream1 := test.NewRPCVerifier(t)
225	stream1.Push(pub.InitialReq, nil, status.Error(codes.Unavailable, "server unavailable"))
226	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream1)
227
228	stream2 := test.NewRPCVerifier(t)
229	stream2.Push(pub.InitialReq, nil, status.Error(codes.Internal, "internal"))
230	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream2)
231
232	// Third stream should succeed.
233	stream3 := test.NewRPCVerifier(t)
234	stream3.Push(pub.InitialReq, initPubResp(), nil)
235	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream3)
236
237	mockServer.OnTestStart(verifiers)
238	defer mockServer.OnTestEnd()
239
240	pub.Stream.Start()
241	if got, want := pub.NextStatus(), streamReconnecting; got != want {
242		t.Errorf("Stream status change: got %d, want %d", got, want)
243	}
244	if got, want := pub.NextStatus(), streamConnected; got != want {
245		t.Errorf("Stream status change: got %d, want %d", got, want)
246	}
247
248	pub.Stream.Stop()
249	if got, want := pub.NextStatus(), streamTerminated; got != want {
250		t.Errorf("Stream status change: got %d, want %d", got, want)
251	}
252}
253
254func TestRetryableStreamConnectPermanentFailure(t *testing.T) {
255	pub := newTestStreamHandler(t, defaultStreamTimeout)
256	permanentErr := status.Error(codes.PermissionDenied, "denied")
257
258	verifiers := test.NewVerifiers(t)
259	// The stream connection results in a non-retryable error, so the publisher
260	// cannot start.
261	stream := test.NewRPCVerifier(t)
262	stream.Push(pub.InitialReq, nil, permanentErr)
263	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream)
264
265	mockServer.OnTestStart(verifiers)
266	defer mockServer.OnTestEnd()
267
268	pub.Stream.Start()
269	if got, want := pub.NextStatus(), streamReconnecting; got != want {
270		t.Errorf("Stream status change: got %d, want %d", got, want)
271	}
272	if got, want := pub.NextStatus(), streamTerminated; got != want {
273		t.Errorf("Stream status change: got %d, want %d", got, want)
274	}
275	if pub.Stream.currentStream() != nil {
276		t.Error("Client stream should be nil")
277	}
278	if gotErr := pub.Stream.Error(); !test.ErrorEqual(gotErr, permanentErr) {
279		t.Errorf("Stream final err: got (%v), want (%v)", gotErr, permanentErr)
280	}
281}
282
283func TestRetryableStreamConnectTimeout(t *testing.T) {
284	// Set a very low timeout to ensure no retries.
285	timeout := time.Millisecond
286	pub := newTestStreamHandler(t, timeout)
287	wantErr := status.Error(codes.DeadlineExceeded, "timeout")
288
289	verifiers := test.NewVerifiers(t)
290	stream := test.NewRPCVerifier(t)
291	barrier := stream.PushWithBarrier(pub.InitialReq, nil, wantErr)
292	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream)
293
294	mockServer.OnTestStart(verifiers)
295	defer mockServer.OnTestEnd()
296
297	pub.Stream.Start()
298	if got, want := pub.NextStatus(), streamReconnecting; got != want {
299		t.Errorf("Stream status change: got %d, want %d", got, want)
300	}
301
302	// Send the initial server response well after the timeout setting.
303	time.Sleep(10 * timeout)
304	barrier.Release()
305
306	if got, want := pub.NextStatus(), streamTerminated; got != want {
307		t.Errorf("Stream status change: got %d, want %d", got, want)
308	}
309	if pub.Stream.currentStream() != nil {
310		t.Error("Client stream should be nil")
311	}
312	if gotErr := pub.Stream.Error(); !xerrors.Is(gotErr, ErrBackendUnavailable) {
313		t.Errorf("Stream final err: got (%v), want (%v)", gotErr, ErrBackendUnavailable)
314	}
315}
316
317func TestRetryableStreamSendReceive(t *testing.T) {
318	pub := newTestStreamHandler(t, defaultStreamTimeout)
319	req := msgPubReq(&pb.PubSubMessage{Data: []byte("msg")})
320	wantResp := msgPubResp(5)
321
322	verifiers := test.NewVerifiers(t)
323	stream := test.NewRPCVerifier(t)
324	barrier := stream.PushWithBarrier(pub.InitialReq, initPubResp(), nil)
325	stream.Push(req, wantResp, nil)
326	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream)
327
328	mockServer.OnTestStart(verifiers)
329	defer mockServer.OnTestEnd()
330
331	pub.Stream.Start()
332	if got, want := pub.NextStatus(), streamReconnecting; got != want {
333		t.Errorf("Stream status change: got %d, want %d", got, want)
334	}
335
336	// While the stream is reconnecting, requests are discarded.
337	if got, want := pub.Stream.Send(req), false; got != want {
338		t.Errorf("Stream send: got %v, want %v", got, want)
339	}
340
341	barrier.Release()
342	if got, want := pub.NextStatus(), streamConnected; got != want {
343		t.Errorf("Stream status change: got %d, want %d", got, want)
344	}
345
346	if got, want := pub.Stream.Send(req), true; got != want {
347		t.Errorf("Stream send: got %v, want %v", got, want)
348	}
349	if gotResp := pub.NextResponse(); !testutil.Equal(gotResp, wantResp) {
350		t.Errorf("Stream response: got %v, want %v", gotResp, wantResp)
351	}
352
353	pub.Stream.Stop()
354	if got, want := pub.NextStatus(), streamTerminated; got != want {
355		t.Errorf("Stream status change: got %d, want %d", got, want)
356	}
357	if gotErr := pub.Stream.Error(); gotErr != nil {
358		t.Errorf("Stream final err: got (%v), want <nil>", gotErr)
359	}
360}
361
362func TestRetryableStreamConnectReceivesResetSignal(t *testing.T) {
363	pub := newTestStreamHandler(t, defaultStreamTimeout)
364
365	verifiers := test.NewVerifiers(t)
366
367	stream1 := test.NewRPCVerifier(t)
368	// Reset signal received during stream initialization.
369	stream1.Push(pub.InitialReq, nil, makeStreamResetSignal())
370	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream1)
371
372	stream2 := test.NewRPCVerifier(t)
373	stream2.Push(pub.InitialReq, initPubResp(), nil)
374	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream2)
375
376	mockServer.OnTestStart(verifiers)
377	defer mockServer.OnTestEnd()
378
379	pub.Stream.Start()
380	if got, want := pub.NextStatus(), streamReconnecting; got != want {
381		t.Errorf("Stream status change: got %d, want %d", got, want)
382	}
383	if got, want := pub.NextStatus(), streamResetState; got != want {
384		t.Errorf("Stream status change: got %d, want %d", got, want)
385	}
386	if got, want := pub.NextStatus(), streamConnected; got != want {
387		t.Errorf("Stream status change: got %d, want %d", got, want)
388	}
389
390	pub.Stream.Stop()
391	if got, want := pub.NextStatus(), streamTerminated; got != want {
392		t.Errorf("Stream status change: got %d, want %d", got, want)
393	}
394	if gotErr := pub.Stream.Error(); gotErr != nil {
395		t.Errorf("Stream final err: got (%v), want <nil>", gotErr)
396	}
397}
398
399func TestRetryableStreamDisconnectedWithResetSignal(t *testing.T) {
400	pub := newTestStreamHandler(t, defaultStreamTimeout)
401
402	verifiers := test.NewVerifiers(t)
403
404	stream1 := test.NewRPCVerifier(t)
405	stream1.Push(pub.InitialReq, initPubResp(), nil)
406	// Reset signal received after stream is connected.
407	stream1.Push(nil, nil, makeStreamResetSignal())
408	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream1)
409
410	stream2 := test.NewRPCVerifier(t)
411	stream2.Push(pub.InitialReq, initPubResp(), nil)
412	verifiers.AddPublishStream(pub.Topic.Path, pub.Topic.Partition, stream2)
413
414	mockServer.OnTestStart(verifiers)
415	defer mockServer.OnTestEnd()
416
417	pub.Stream.Start()
418	if got, want := pub.NextStatus(), streamReconnecting; got != want {
419		t.Errorf("Stream status change: got %d, want %d", got, want)
420	}
421	if got, want := pub.NextStatus(), streamConnected; got != want {
422		t.Errorf("Stream status change: got %d, want %d", got, want)
423	}
424	if got, want := pub.NextStatus(), streamReconnecting; got != want {
425		t.Errorf("Stream status change: got %d, want %d", got, want)
426	}
427	if got, want := pub.NextStatus(), streamResetState; got != want {
428		t.Errorf("Stream status change: got %d, want %d", got, want)
429	}
430	if got, want := pub.NextStatus(), streamConnected; got != want {
431		t.Errorf("Stream status change: got %d, want %d", got, want)
432	}
433
434	pub.Stream.Stop()
435	if got, want := pub.NextStatus(), streamTerminated; got != want {
436		t.Errorf("Stream status change: got %d, want %d", got, want)
437	}
438	if gotErr := pub.Stream.Error(); gotErr != nil {
439		t.Errorf("Stream final err: got (%v), want <nil>", gotErr)
440	}
441}
442