1/*
2 *
3 * Copyright 2018 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package test
20
21import (
22	"context"
23	"fmt"
24	"io"
25	"os"
26	"reflect"
27	"strconv"
28	"strings"
29	"testing"
30	"time"
31
32	"github.com/golang/protobuf/proto"
33	"google.golang.org/grpc"
34	"google.golang.org/grpc/codes"
35	"google.golang.org/grpc/internal/envconfig"
36	"google.golang.org/grpc/metadata"
37	"google.golang.org/grpc/status"
38	testpb "google.golang.org/grpc/test/grpc_testing"
39)
40
41func enableRetry() func() {
42	old := envconfig.Retry
43	envconfig.Retry = true
44	return func() { envconfig.Retry = old }
45}
46
47func (s) TestRetryUnary(t *testing.T) {
48	defer enableRetry()()
49	i := -1
50	ss := &stubServer{
51		emptyCall: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
52			i++
53			switch i {
54			case 0, 2, 5:
55				return &testpb.Empty{}, nil
56			case 6, 8, 11:
57				return nil, status.New(codes.Internal, "non-retryable error").Err()
58			}
59			return nil, status.New(codes.AlreadyExists, "retryable error").Err()
60		},
61	}
62	if err := ss.Start([]grpc.ServerOption{}); err != nil {
63		t.Fatalf("Error starting endpoint server: %v", err)
64	}
65	defer ss.Stop()
66	ss.newServiceConfig(`{
67    "methodConfig": [{
68      "name": [{"service": "grpc.testing.TestService"}],
69      "waitForReady": true,
70      "retryPolicy": {
71        "MaxAttempts": 4,
72        "InitialBackoff": ".01s",
73        "MaxBackoff": ".01s",
74        "BackoffMultiplier": 1.0,
75        "RetryableStatusCodes": [ "ALREADY_EXISTS" ]
76      }
77    }]}`)
78	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
79	for {
80		if ctx.Err() != nil {
81			t.Fatalf("Timed out waiting for service config update")
82		}
83		if ss.cc.GetMethodConfig("/grpc.testing.TestService/EmptyCall").WaitForReady != nil {
84			break
85		}
86		time.Sleep(time.Millisecond)
87	}
88	cancel()
89
90	testCases := []struct {
91		code  codes.Code
92		count int
93	}{
94		{codes.OK, 0},
95		{codes.OK, 2},
96		{codes.OK, 5},
97		{codes.Internal, 6},
98		{codes.Internal, 8},
99		{codes.Internal, 11},
100		{codes.AlreadyExists, 15},
101	}
102	for _, tc := range testCases {
103		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
104		_, err := ss.client.EmptyCall(ctx, &testpb.Empty{})
105		cancel()
106		if status.Code(err) != tc.code {
107			t.Fatalf("EmptyCall(_, _) = _, %v; want _, <Code() = %v>", err, tc.code)
108		}
109		if i != tc.count {
110			t.Fatalf("i = %v; want %v", i, tc.count)
111		}
112	}
113}
114
115func (s) TestRetryDisabledByDefault(t *testing.T) {
116	if strings.EqualFold(os.Getenv("GRPC_GO_RETRY"), "on") {
117		return
118	}
119	i := -1
120	ss := &stubServer{
121		emptyCall: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
122			i++
123			switch i {
124			case 0:
125				return nil, status.New(codes.AlreadyExists, "retryable error").Err()
126			}
127			return &testpb.Empty{}, nil
128		},
129	}
130	if err := ss.Start([]grpc.ServerOption{}); err != nil {
131		t.Fatalf("Error starting endpoint server: %v", err)
132	}
133	defer ss.Stop()
134	ss.newServiceConfig(`{
135    "methodConfig": [{
136      "name": [{"service": "grpc.testing.TestService"}],
137      "waitForReady": true,
138      "retryPolicy": {
139        "MaxAttempts": 4,
140        "InitialBackoff": ".01s",
141        "MaxBackoff": ".01s",
142        "BackoffMultiplier": 1.0,
143        "RetryableStatusCodes": [ "ALREADY_EXISTS" ]
144      }
145    }]}`)
146	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
147	for {
148		if ctx.Err() != nil {
149			t.Fatalf("Timed out waiting for service config update")
150		}
151		if ss.cc.GetMethodConfig("/grpc.testing.TestService/EmptyCall").WaitForReady != nil {
152			break
153		}
154		time.Sleep(time.Millisecond)
155	}
156	cancel()
157
158	testCases := []struct {
159		code  codes.Code
160		count int
161	}{
162		{codes.AlreadyExists, 0},
163	}
164	for _, tc := range testCases {
165		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
166		_, err := ss.client.EmptyCall(ctx, &testpb.Empty{})
167		cancel()
168		if status.Code(err) != tc.code {
169			t.Fatalf("EmptyCall(_, _) = _, %v; want _, <Code() = %v>", err, tc.code)
170		}
171		if i != tc.count {
172			t.Fatalf("i = %v; want %v", i, tc.count)
173		}
174	}
175}
176
177func (s) TestRetryThrottling(t *testing.T) {
178	defer enableRetry()()
179	i := -1
180	ss := &stubServer{
181		emptyCall: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
182			i++
183			switch i {
184			case 0, 3, 6, 10, 11, 12, 13, 14, 16, 18:
185				return &testpb.Empty{}, nil
186			}
187			return nil, status.New(codes.Unavailable, "retryable error").Err()
188		},
189	}
190	if err := ss.Start([]grpc.ServerOption{}); err != nil {
191		t.Fatalf("Error starting endpoint server: %v", err)
192	}
193	defer ss.Stop()
194	ss.newServiceConfig(`{
195    "methodConfig": [{
196      "name": [{"service": "grpc.testing.TestService"}],
197      "waitForReady": true,
198      "retryPolicy": {
199        "MaxAttempts": 4,
200        "InitialBackoff": ".01s",
201        "MaxBackoff": ".01s",
202        "BackoffMultiplier": 1.0,
203        "RetryableStatusCodes": [ "UNAVAILABLE" ]
204      }
205    }],
206    "retryThrottling": {
207      "maxTokens": 10,
208      "tokenRatio": 0.5
209    }
210  }`)
211	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
212	for {
213		if ctx.Err() != nil {
214			t.Fatalf("Timed out waiting for service config update")
215		}
216		if ss.cc.GetMethodConfig("/grpc.testing.TestService/EmptyCall").WaitForReady != nil {
217			break
218		}
219		time.Sleep(time.Millisecond)
220	}
221	cancel()
222
223	testCases := []struct {
224		code  codes.Code
225		count int
226	}{
227		{codes.OK, 0},           // tokens = 10
228		{codes.OK, 3},           // tokens = 8.5 (10 - 2 failures + 0.5 success)
229		{codes.OK, 6},           // tokens = 6
230		{codes.Unavailable, 8},  // tokens = 5 -- first attempt is retried; second aborted.
231		{codes.Unavailable, 9},  // tokens = 4
232		{codes.OK, 10},          // tokens = 4.5
233		{codes.OK, 11},          // tokens = 5
234		{codes.OK, 12},          // tokens = 5.5
235		{codes.OK, 13},          // tokens = 6
236		{codes.OK, 14},          // tokens = 6.5
237		{codes.OK, 16},          // tokens = 5.5
238		{codes.Unavailable, 17}, // tokens = 4.5
239	}
240	for _, tc := range testCases {
241		ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
242		_, err := ss.client.EmptyCall(ctx, &testpb.Empty{})
243		cancel()
244		if status.Code(err) != tc.code {
245			t.Errorf("EmptyCall(_, _) = _, %v; want _, <Code() = %v>", err, tc.code)
246		}
247		if i != tc.count {
248			t.Errorf("i = %v; want %v", i, tc.count)
249		}
250	}
251}
252
253func (s) TestRetryStreaming(t *testing.T) {
254	defer enableRetry()()
255	req := func(b byte) *testpb.StreamingOutputCallRequest {
256		return &testpb.StreamingOutputCallRequest{Payload: &testpb.Payload{Body: []byte{b}}}
257	}
258	res := func(b byte) *testpb.StreamingOutputCallResponse {
259		return &testpb.StreamingOutputCallResponse{Payload: &testpb.Payload{Body: []byte{b}}}
260	}
261
262	largePayload, _ := newPayload(testpb.PayloadType_COMPRESSABLE, 500)
263
264	type serverOp func(stream testpb.TestService_FullDuplexCallServer) error
265	type clientOp func(stream testpb.TestService_FullDuplexCallClient) error
266
267	// Server Operations
268	sAttempts := func(n int) serverOp {
269		return func(stream testpb.TestService_FullDuplexCallServer) error {
270			const key = "grpc-previous-rpc-attempts"
271			md, ok := metadata.FromIncomingContext(stream.Context())
272			if !ok {
273				return status.Errorf(codes.Internal, "server: no header metadata received")
274			}
275			if got := md[key]; len(got) != 1 || got[0] != strconv.Itoa(n) {
276				return status.Errorf(codes.Internal, "server: metadata = %v; want <contains %q: %q>", md, key, n)
277			}
278			return nil
279		}
280	}
281	sReq := func(b byte) serverOp {
282		return func(stream testpb.TestService_FullDuplexCallServer) error {
283			want := req(b)
284			if got, err := stream.Recv(); err != nil || !proto.Equal(got, want) {
285				return status.Errorf(codes.Internal, "server: Recv() = %v, %v; want %v, <nil>", got, err, want)
286			}
287			return nil
288		}
289	}
290	sReqPayload := func(p *testpb.Payload) serverOp {
291		return func(stream testpb.TestService_FullDuplexCallServer) error {
292			want := &testpb.StreamingOutputCallRequest{Payload: p}
293			if got, err := stream.Recv(); err != nil || !proto.Equal(got, want) {
294				return status.Errorf(codes.Internal, "server: Recv() = %v, %v; want %v, <nil>", got, err, want)
295			}
296			return nil
297		}
298	}
299	sRes := func(b byte) serverOp {
300		return func(stream testpb.TestService_FullDuplexCallServer) error {
301			msg := res(b)
302			if err := stream.Send(msg); err != nil {
303				return status.Errorf(codes.Internal, "server: Send(%v) = %v; want <nil>", msg, err)
304			}
305			return nil
306		}
307	}
308	sErr := func(c codes.Code) serverOp {
309		return func(stream testpb.TestService_FullDuplexCallServer) error {
310			return status.New(c, "").Err()
311		}
312	}
313	sCloseSend := func() serverOp {
314		return func(stream testpb.TestService_FullDuplexCallServer) error {
315			if msg, err := stream.Recv(); msg != nil || err != io.EOF {
316				return status.Errorf(codes.Internal, "server: Recv() = %v, %v; want <nil>, io.EOF", msg, err)
317			}
318			return nil
319		}
320	}
321	sPushback := func(s string) serverOp {
322		return func(stream testpb.TestService_FullDuplexCallServer) error {
323			stream.SetTrailer(metadata.MD{"grpc-retry-pushback-ms": []string{s}})
324			return nil
325		}
326	}
327
328	// Client Operations
329	cReq := func(b byte) clientOp {
330		return func(stream testpb.TestService_FullDuplexCallClient) error {
331			msg := req(b)
332			if err := stream.Send(msg); err != nil {
333				return fmt.Errorf("client: Send(%v) = %v; want <nil>", msg, err)
334			}
335			return nil
336		}
337	}
338	cReqPayload := func(p *testpb.Payload) clientOp {
339		return func(stream testpb.TestService_FullDuplexCallClient) error {
340			msg := &testpb.StreamingOutputCallRequest{Payload: p}
341			if err := stream.Send(msg); err != nil {
342				return fmt.Errorf("client: Send(%v) = %v; want <nil>", msg, err)
343			}
344			return nil
345		}
346	}
347	cRes := func(b byte) clientOp {
348		return func(stream testpb.TestService_FullDuplexCallClient) error {
349			want := res(b)
350			if got, err := stream.Recv(); err != nil || !proto.Equal(got, want) {
351				return fmt.Errorf("client: Recv() = %v, %v; want %v, <nil>", got, err, want)
352			}
353			return nil
354		}
355	}
356	cErr := func(c codes.Code) clientOp {
357		return func(stream testpb.TestService_FullDuplexCallClient) error {
358			want := status.New(c, "").Err()
359			if c == codes.OK {
360				want = io.EOF
361			}
362			res, err := stream.Recv()
363			if res != nil ||
364				((err == nil) != (want == nil)) ||
365				(want != nil && !reflect.DeepEqual(err, want)) {
366				return fmt.Errorf("client: Recv() = %v, %v; want <nil>, %v", res, err, want)
367			}
368			return nil
369		}
370	}
371	cCloseSend := func() clientOp {
372		return func(stream testpb.TestService_FullDuplexCallClient) error {
373			if err := stream.CloseSend(); err != nil {
374				return fmt.Errorf("client: CloseSend() = %v; want <nil>", err)
375			}
376			return nil
377		}
378	}
379	var curTime time.Time
380	cGetTime := func() clientOp {
381		return func(_ testpb.TestService_FullDuplexCallClient) error {
382			curTime = time.Now()
383			return nil
384		}
385	}
386	cCheckElapsed := func(d time.Duration) clientOp {
387		return func(_ testpb.TestService_FullDuplexCallClient) error {
388			if elapsed := time.Since(curTime); elapsed < d {
389				return fmt.Errorf("elapsed time: %v; want >= %v", elapsed, d)
390			}
391			return nil
392		}
393	}
394	cHdr := func() clientOp {
395		return func(stream testpb.TestService_FullDuplexCallClient) error {
396			_, err := stream.Header()
397			return err
398		}
399	}
400	cCtx := func() clientOp {
401		return func(stream testpb.TestService_FullDuplexCallClient) error {
402			stream.Context()
403			return nil
404		}
405	}
406
407	testCases := []struct {
408		desc      string
409		serverOps []serverOp
410		clientOps []clientOp
411	}{{
412		desc:      "Non-retryable error code",
413		serverOps: []serverOp{sReq(1), sErr(codes.Internal)},
414		clientOps: []clientOp{cReq(1), cErr(codes.Internal)},
415	}, {
416		desc:      "One retry necessary",
417		serverOps: []serverOp{sReq(1), sErr(codes.Unavailable), sReq(1), sAttempts(1), sRes(1)},
418		clientOps: []clientOp{cReq(1), cRes(1), cErr(codes.OK)},
419	}, {
420		desc: "Exceed max attempts (4); check attempts header on server",
421		serverOps: []serverOp{
422			sReq(1), sErr(codes.Unavailable),
423			sReq(1), sAttempts(1), sErr(codes.Unavailable),
424			sAttempts(2), sReq(1), sErr(codes.Unavailable),
425			sAttempts(3), sReq(1), sErr(codes.Unavailable),
426		},
427		clientOps: []clientOp{cReq(1), cErr(codes.Unavailable)},
428	}, {
429		desc: "Multiple requests",
430		serverOps: []serverOp{
431			sReq(1), sReq(2), sErr(codes.Unavailable),
432			sReq(1), sReq(2), sRes(5),
433		},
434		clientOps: []clientOp{cReq(1), cReq(2), cRes(5), cErr(codes.OK)},
435	}, {
436		desc: "Multiple successive requests",
437		serverOps: []serverOp{
438			sReq(1), sErr(codes.Unavailable),
439			sReq(1), sReq(2), sErr(codes.Unavailable),
440			sReq(1), sReq(2), sReq(3), sRes(5),
441		},
442		clientOps: []clientOp{cReq(1), cReq(2), cReq(3), cRes(5), cErr(codes.OK)},
443	}, {
444		desc: "No retry after receiving",
445		serverOps: []serverOp{
446			sReq(1), sErr(codes.Unavailable),
447			sReq(1), sRes(3), sErr(codes.Unavailable),
448		},
449		clientOps: []clientOp{cReq(1), cRes(3), cErr(codes.Unavailable)},
450	}, {
451		desc:      "No retry after header",
452		serverOps: []serverOp{sReq(1), sErr(codes.Unavailable)},
453		clientOps: []clientOp{cReq(1), cHdr(), cErr(codes.Unavailable)},
454	}, {
455		desc:      "No retry after context",
456		serverOps: []serverOp{sReq(1), sErr(codes.Unavailable)},
457		clientOps: []clientOp{cReq(1), cCtx(), cErr(codes.Unavailable)},
458	}, {
459		desc: "Replaying close send",
460		serverOps: []serverOp{
461			sReq(1), sReq(2), sCloseSend(), sErr(codes.Unavailable),
462			sReq(1), sReq(2), sCloseSend(), sRes(1), sRes(3), sRes(5),
463		},
464		clientOps: []clientOp{cReq(1), cReq(2), cCloseSend(), cRes(1), cRes(3), cRes(5), cErr(codes.OK)},
465	}, {
466		desc:      "Negative server pushback - no retry",
467		serverOps: []serverOp{sReq(1), sPushback("-1"), sErr(codes.Unavailable)},
468		clientOps: []clientOp{cReq(1), cErr(codes.Unavailable)},
469	}, {
470		desc:      "Non-numeric server pushback - no retry",
471		serverOps: []serverOp{sReq(1), sPushback("xxx"), sErr(codes.Unavailable)},
472		clientOps: []clientOp{cReq(1), cErr(codes.Unavailable)},
473	}, {
474		desc:      "Multiple server pushback values - no retry",
475		serverOps: []serverOp{sReq(1), sPushback("100"), sPushback("10"), sErr(codes.Unavailable)},
476		clientOps: []clientOp{cReq(1), cErr(codes.Unavailable)},
477	}, {
478		desc:      "1s server pushback - delayed retry",
479		serverOps: []serverOp{sReq(1), sPushback("1000"), sErr(codes.Unavailable), sReq(1), sRes(2)},
480		clientOps: []clientOp{cGetTime(), cReq(1), cRes(2), cCheckElapsed(time.Second), cErr(codes.OK)},
481	}, {
482		desc:      "Overflowing buffer - no retry",
483		serverOps: []serverOp{sReqPayload(largePayload), sErr(codes.Unavailable)},
484		clientOps: []clientOp{cReqPayload(largePayload), cErr(codes.Unavailable)},
485	}}
486
487	var serverOpIter int
488	var serverOps []serverOp
489	ss := &stubServer{
490		fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
491			for serverOpIter < len(serverOps) {
492				op := serverOps[serverOpIter]
493				serverOpIter++
494				if err := op(stream); err != nil {
495					return err
496				}
497			}
498			return nil
499		},
500	}
501	if err := ss.Start([]grpc.ServerOption{}, grpc.WithDefaultCallOptions(grpc.MaxRetryRPCBufferSize(200))); err != nil {
502		t.Fatalf("Error starting endpoint server: %v", err)
503	}
504	defer ss.Stop()
505	ss.newServiceConfig(`{
506    "methodConfig": [{
507      "name": [{"service": "grpc.testing.TestService"}],
508      "waitForReady": true,
509      "retryPolicy": {
510          "MaxAttempts": 4,
511          "InitialBackoff": ".01s",
512          "MaxBackoff": ".01s",
513          "BackoffMultiplier": 1.0,
514          "RetryableStatusCodes": [ "UNAVAILABLE" ]
515      }
516    }]}`)
517	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
518	for {
519		if ctx.Err() != nil {
520			t.Fatalf("Timed out waiting for service config update")
521		}
522		if ss.cc.GetMethodConfig("/grpc.testing.TestService/FullDuplexCall").WaitForReady != nil {
523			break
524		}
525		time.Sleep(time.Millisecond)
526	}
527	cancel()
528
529	for _, tc := range testCases {
530		func() {
531			serverOpIter = 0
532			serverOps = tc.serverOps
533
534			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
535			defer cancel()
536			stream, err := ss.client.FullDuplexCall(ctx)
537			if err != nil {
538				t.Fatalf("%v: Error while creating stream: %v", tc.desc, err)
539			}
540			for _, op := range tc.clientOps {
541				if err := op(stream); err != nil {
542					t.Errorf("%v: %v", tc.desc, err)
543					break
544				}
545			}
546			if serverOpIter != len(serverOps) {
547				t.Errorf("%v: serverOpIter = %v; want %v", tc.desc, serverOpIter, len(serverOps))
548			}
549		}()
550	}
551}
552