1/*
2 *
3 * Copyright 2019 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package test
20
21import (
22	"context"
23	"testing"
24	"time"
25
26	"google.golang.org/grpc"
27	"google.golang.org/grpc/codes"
28	"google.golang.org/grpc/encoding/gzip"
29	"google.golang.org/grpc/metadata"
30	"google.golang.org/grpc/status"
31	testpb "google.golang.org/grpc/test/grpc_testing"
32)
33
34func (s) TestContextCanceled(t *testing.T) {
35	ss := &stubServer{
36		fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
37			stream.SetTrailer(metadata.New(map[string]string{"a": "b"}))
38			return status.Error(codes.PermissionDenied, "perm denied")
39		},
40	}
41	if err := ss.Start(nil); err != nil {
42		t.Fatalf("Error starting endpoint server: %v", err)
43	}
44	defer ss.Stop()
45
46	// Runs 10 rounds of tests with the given delay and returns counts of status codes.
47	// Fails in case of trailer/status code inconsistency.
48	const cntRetry uint = 10
49	runTest := func(delay time.Duration) (cntCanceled, cntPermDenied uint) {
50		for i := uint(0); i < cntRetry; i++ {
51			ctx, cancel := context.WithTimeout(context.Background(), delay)
52			defer cancel()
53
54			str, err := ss.client.FullDuplexCall(ctx)
55			if err != nil {
56				continue
57			}
58
59			_, err = str.Recv()
60			if err == nil {
61				t.Fatalf("non-nil error expected from Recv()")
62			}
63
64			_, trlOk := str.Trailer()["a"]
65			switch status.Code(err) {
66			case codes.PermissionDenied:
67				if !trlOk {
68					t.Fatalf(`status err: %v; wanted key "a" in trailer but didn't get it`, err)
69				}
70				cntPermDenied++
71			case codes.DeadlineExceeded:
72				if trlOk {
73					t.Fatalf(`status err: %v; didn't want key "a" in trailer but got it`, err)
74				}
75				cntCanceled++
76			default:
77				t.Fatalf(`unexpected status err: %v`, err)
78			}
79		}
80		return cntCanceled, cntPermDenied
81	}
82
83	// Tries to find the delay that causes canceled/perm denied race.
84	canceledOk, permDeniedOk := false, false
85	for lower, upper := time.Duration(0), 2*time.Millisecond; lower <= upper; {
86		delay := lower + (upper-lower)/2
87		cntCanceled, cntPermDenied := runTest(delay)
88		if cntPermDenied > 0 && cntCanceled > 0 {
89			// Delay that causes the race is found.
90			return
91		}
92
93		// Set OK flags.
94		if cntCanceled > 0 {
95			canceledOk = true
96		}
97		if cntPermDenied > 0 {
98			permDeniedOk = true
99		}
100
101		if cntPermDenied == 0 {
102			// No perm denied, increase the delay.
103			lower += (upper-lower)/10 + 1
104		} else {
105			// All perm denied, decrease the delay.
106			upper -= (upper-lower)/10 + 1
107		}
108	}
109
110	if !canceledOk || !permDeniedOk {
111		t.Fatalf(`couldn't find the delay that causes canceled/perm denied race.`)
112	}
113}
114
115// To make sure that canceling a stream with compression enabled won't result in
116// internal error, compressed flag set with identity or empty encoding.
117//
118// The root cause is a select race on stream headerChan and ctx. Stream gets
119// whether compression is enabled and the compression type from two separate
120// functions, both include select with context. If the `case non-ctx:` wins the
121// first one, but `case ctx.Done()` wins the second one, the compression info
122// will be inconsistent, and it causes internal error.
123func (s) TestCancelWhileRecvingWithCompression(t *testing.T) {
124	ss := &stubServer{
125		fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
126			for {
127				if err := stream.Send(&testpb.StreamingOutputCallResponse{
128					Payload: nil,
129				}); err != nil {
130					return err
131				}
132			}
133		},
134	}
135	if err := ss.Start(nil); err != nil {
136		t.Fatalf("Error starting endpoint server: %v", err)
137	}
138	defer ss.Stop()
139
140	for i := 0; i < 10; i++ {
141		ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
142		s, err := ss.client.FullDuplexCall(ctx, grpc.UseCompressor(gzip.Name))
143		if err != nil {
144			t.Fatalf("failed to start bidi streaming RPC: %v", err)
145		}
146		// Cancel the stream while receiving to trigger the internal error.
147		time.AfterFunc(time.Millisecond, cancel)
148		for {
149			_, err := s.Recv()
150			if err != nil {
151				if status.Code(err) != codes.Canceled {
152					t.Fatalf("recv failed with %v, want Canceled", err)
153				}
154				break
155			}
156		}
157	}
158}
159