1/* 2 * 3 * Copyright 2019 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19package test 20 21import ( 22 "context" 23 "testing" 24 "time" 25 26 "google.golang.org/grpc" 27 "google.golang.org/grpc/codes" 28 "google.golang.org/grpc/encoding/gzip" 29 "google.golang.org/grpc/metadata" 30 "google.golang.org/grpc/status" 31 testpb "google.golang.org/grpc/test/grpc_testing" 32) 33 34func (s) TestContextCanceled(t *testing.T) { 35 ss := &stubServer{ 36 fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error { 37 stream.SetTrailer(metadata.New(map[string]string{"a": "b"})) 38 return status.Error(codes.PermissionDenied, "perm denied") 39 }, 40 } 41 if err := ss.Start(nil); err != nil { 42 t.Fatalf("Error starting endpoint server: %v", err) 43 } 44 defer ss.Stop() 45 46 // Runs 10 rounds of tests with the given delay and returns counts of status codes. 47 // Fails in case of trailer/status code inconsistency. 48 const cntRetry uint = 10 49 runTest := func(delay time.Duration) (cntCanceled, cntPermDenied uint) { 50 for i := uint(0); i < cntRetry; i++ { 51 ctx, cancel := context.WithTimeout(context.Background(), delay) 52 defer cancel() 53 54 str, err := ss.client.FullDuplexCall(ctx) 55 if err != nil { 56 continue 57 } 58 59 _, err = str.Recv() 60 if err == nil { 61 t.Fatalf("non-nil error expected from Recv()") 62 } 63 64 _, trlOk := str.Trailer()["a"] 65 switch status.Code(err) { 66 case codes.PermissionDenied: 67 if !trlOk { 68 t.Fatalf(`status err: %v; wanted key "a" in trailer but didn't get it`, err) 69 } 70 cntPermDenied++ 71 case codes.DeadlineExceeded: 72 if trlOk { 73 t.Fatalf(`status err: %v; didn't want key "a" in trailer but got it`, err) 74 } 75 cntCanceled++ 76 default: 77 t.Fatalf(`unexpected status err: %v`, err) 78 } 79 } 80 return cntCanceled, cntPermDenied 81 } 82 83 // Tries to find the delay that causes canceled/perm denied race. 84 canceledOk, permDeniedOk := false, false 85 for lower, upper := time.Duration(0), 2*time.Millisecond; lower <= upper; { 86 delay := lower + (upper-lower)/2 87 cntCanceled, cntPermDenied := runTest(delay) 88 if cntPermDenied > 0 && cntCanceled > 0 { 89 // Delay that causes the race is found. 90 return 91 } 92 93 // Set OK flags. 94 if cntCanceled > 0 { 95 canceledOk = true 96 } 97 if cntPermDenied > 0 { 98 permDeniedOk = true 99 } 100 101 if cntPermDenied == 0 { 102 // No perm denied, increase the delay. 103 lower += (upper-lower)/10 + 1 104 } else { 105 // All perm denied, decrease the delay. 106 upper -= (upper-lower)/10 + 1 107 } 108 } 109 110 if !canceledOk || !permDeniedOk { 111 t.Fatalf(`couldn't find the delay that causes canceled/perm denied race.`) 112 } 113} 114 115// To make sure that canceling a stream with compression enabled won't result in 116// internal error, compressed flag set with identity or empty encoding. 117// 118// The root cause is a select race on stream headerChan and ctx. Stream gets 119// whether compression is enabled and the compression type from two separate 120// functions, both include select with context. If the `case non-ctx:` wins the 121// first one, but `case ctx.Done()` wins the second one, the compression info 122// will be inconsistent, and it causes internal error. 123func (s) TestCancelWhileRecvingWithCompression(t *testing.T) { 124 ss := &stubServer{ 125 fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error { 126 for { 127 if err := stream.Send(&testpb.StreamingOutputCallResponse{ 128 Payload: nil, 129 }); err != nil { 130 return err 131 } 132 } 133 }, 134 } 135 if err := ss.Start(nil); err != nil { 136 t.Fatalf("Error starting endpoint server: %v", err) 137 } 138 defer ss.Stop() 139 140 for i := 0; i < 10; i++ { 141 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) 142 s, err := ss.client.FullDuplexCall(ctx, grpc.UseCompressor(gzip.Name)) 143 if err != nil { 144 t.Fatalf("failed to start bidi streaming RPC: %v", err) 145 } 146 // Cancel the stream while receiving to trigger the internal error. 147 time.AfterFunc(time.Millisecond, cancel) 148 for { 149 _, err := s.Recv() 150 if err != nil { 151 if status.Code(err) != codes.Canceled { 152 t.Fatalf("recv failed with %v, want Canceled", err) 153 } 154 break 155 } 156 } 157 } 158} 159