1/* 2 * 3 * Copyright 2020 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19package test 20 21import ( 22 "context" 23 "io" 24 "testing" 25 26 "google.golang.org/grpc" 27 "google.golang.org/grpc/codes" 28 "google.golang.org/grpc/status" 29 testpb "google.golang.org/grpc/test/grpc_testing" 30) 31 32type ctxKey string 33 34func (s) TestChainUnaryServerInterceptor(t *testing.T) { 35 var ( 36 firstIntKey = ctxKey("firstIntKey") 37 secondIntKey = ctxKey("secondIntKey") 38 ) 39 40 firstInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 41 if ctx.Value(firstIntKey) != nil { 42 return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", firstIntKey) 43 } 44 if ctx.Value(secondIntKey) != nil { 45 return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", secondIntKey) 46 } 47 48 firstCtx := context.WithValue(ctx, firstIntKey, 0) 49 resp, err := handler(firstCtx, req) 50 if err != nil { 51 return nil, status.Errorf(codes.Internal, "failed to handle request at firstInt") 52 } 53 54 simpleResp, ok := resp.(*testpb.SimpleResponse) 55 if !ok { 56 return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at firstInt") 57 } 58 return &testpb.SimpleResponse{ 59 Payload: &testpb.Payload{ 60 Type: simpleResp.GetPayload().GetType(), 61 Body: append(simpleResp.GetPayload().GetBody(), '1'), 62 }, 63 }, nil 64 } 65 66 secondInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 67 if ctx.Value(firstIntKey) == nil { 68 return nil, status.Errorf(codes.Internal, "second interceptor should have %v in context", firstIntKey) 69 } 70 if ctx.Value(secondIntKey) != nil { 71 return nil, status.Errorf(codes.Internal, "second interceptor should not have %v in context", secondIntKey) 72 } 73 74 secondCtx := context.WithValue(ctx, secondIntKey, 1) 75 resp, err := handler(secondCtx, req) 76 if err != nil { 77 return nil, status.Errorf(codes.Internal, "failed to handle request at secondInt") 78 } 79 80 simpleResp, ok := resp.(*testpb.SimpleResponse) 81 if !ok { 82 return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at secondInt") 83 } 84 return &testpb.SimpleResponse{ 85 Payload: &testpb.Payload{ 86 Type: simpleResp.GetPayload().GetType(), 87 Body: append(simpleResp.GetPayload().GetBody(), '2'), 88 }, 89 }, nil 90 } 91 92 lastInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 93 if ctx.Value(firstIntKey) == nil { 94 return nil, status.Errorf(codes.Internal, "last interceptor should have %v in context", firstIntKey) 95 } 96 if ctx.Value(secondIntKey) == nil { 97 return nil, status.Errorf(codes.Internal, "last interceptor should not have %v in context", secondIntKey) 98 } 99 100 resp, err := handler(ctx, req) 101 if err != nil { 102 return nil, status.Errorf(codes.Internal, "failed to handle request at lastInt at lastInt") 103 } 104 105 simpleResp, ok := resp.(*testpb.SimpleResponse) 106 if !ok { 107 return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at lastInt") 108 } 109 return &testpb.SimpleResponse{ 110 Payload: &testpb.Payload{ 111 Type: simpleResp.GetPayload().GetType(), 112 Body: append(simpleResp.GetPayload().GetBody(), '3'), 113 }, 114 }, nil 115 } 116 117 sopts := []grpc.ServerOption{ 118 grpc.ChainUnaryInterceptor(firstInt, secondInt, lastInt), 119 } 120 121 ss := &stubServer{ 122 unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { 123 payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 0) 124 if err != nil { 125 return nil, status.Errorf(codes.Aborted, "failed to make payload: %v", err) 126 } 127 128 return &testpb.SimpleResponse{ 129 Payload: payload, 130 }, nil 131 }, 132 } 133 if err := ss.Start(sopts); err != nil { 134 t.Fatalf("Error starting endpoint server: %v", err) 135 } 136 defer ss.Stop() 137 138 resp, err := ss.client.UnaryCall(context.Background(), &testpb.SimpleRequest{}) 139 if s, ok := status.FromError(err); !ok || s.Code() != codes.OK { 140 t.Fatalf("ss.client.UnaryCall(context.Background(), _) = %v, %v; want nil, <status with Code()=OK>", resp, err) 141 } 142 143 respBytes := resp.Payload.GetBody() 144 if string(respBytes) != "321" { 145 t.Fatalf("invalid response: want=%s, but got=%s", "321", resp) 146 } 147} 148 149func (s) TestChainOnBaseUnaryServerInterceptor(t *testing.T) { 150 baseIntKey := ctxKey("baseIntKey") 151 152 baseInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 153 if ctx.Value(baseIntKey) != nil { 154 return nil, status.Errorf(codes.Internal, "base interceptor should not have %v in context", baseIntKey) 155 } 156 157 baseCtx := context.WithValue(ctx, baseIntKey, 1) 158 return handler(baseCtx, req) 159 } 160 161 chainInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) { 162 if ctx.Value(baseIntKey) == nil { 163 return nil, status.Errorf(codes.Internal, "chain interceptor should have %v in context", baseIntKey) 164 } 165 166 return handler(ctx, req) 167 } 168 169 sopts := []grpc.ServerOption{ 170 grpc.UnaryInterceptor(baseInt), 171 grpc.ChainUnaryInterceptor(chainInt), 172 } 173 174 ss := &stubServer{ 175 emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) { 176 return &testpb.Empty{}, nil 177 }, 178 } 179 if err := ss.Start(sopts); err != nil { 180 t.Fatalf("Error starting endpoint server: %v", err) 181 } 182 defer ss.Stop() 183 184 resp, err := ss.client.EmptyCall(context.Background(), &testpb.Empty{}) 185 if s, ok := status.FromError(err); !ok || s.Code() != codes.OK { 186 t.Fatalf("ss.client.EmptyCall(context.Background(), _) = %v, %v; want nil, <status with Code()=OK>", resp, err) 187 } 188} 189 190func (s) TestChainStreamServerInterceptor(t *testing.T) { 191 callCounts := make([]int, 4) 192 193 firstInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 194 if callCounts[0] != 0 { 195 return status.Errorf(codes.Internal, "callCounts[0] should be 0, but got=%d", callCounts[0]) 196 } 197 if callCounts[1] != 0 { 198 return status.Errorf(codes.Internal, "callCounts[1] should be 0, but got=%d", callCounts[1]) 199 } 200 if callCounts[2] != 0 { 201 return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2]) 202 } 203 if callCounts[3] != 0 { 204 return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3]) 205 } 206 callCounts[0]++ 207 return handler(srv, stream) 208 } 209 210 secondInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 211 if callCounts[0] != 1 { 212 return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0]) 213 } 214 if callCounts[1] != 0 { 215 return status.Errorf(codes.Internal, "callCounts[1] should be 0, but got=%d", callCounts[1]) 216 } 217 if callCounts[2] != 0 { 218 return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2]) 219 } 220 if callCounts[3] != 0 { 221 return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3]) 222 } 223 callCounts[1]++ 224 return handler(srv, stream) 225 } 226 227 lastInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { 228 if callCounts[0] != 1 { 229 return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0]) 230 } 231 if callCounts[1] != 1 { 232 return status.Errorf(codes.Internal, "callCounts[1] should be 1, but got=%d", callCounts[1]) 233 } 234 if callCounts[2] != 0 { 235 return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2]) 236 } 237 if callCounts[3] != 0 { 238 return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3]) 239 } 240 callCounts[2]++ 241 return handler(srv, stream) 242 } 243 244 sopts := []grpc.ServerOption{ 245 grpc.ChainStreamInterceptor(firstInt, secondInt, lastInt), 246 } 247 248 ss := &stubServer{ 249 fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error { 250 if callCounts[0] != 1 { 251 return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0]) 252 } 253 if callCounts[1] != 1 { 254 return status.Errorf(codes.Internal, "callCounts[1] should be 1, but got=%d", callCounts[1]) 255 } 256 if callCounts[2] != 1 { 257 return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2]) 258 } 259 if callCounts[3] != 0 { 260 return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3]) 261 } 262 callCounts[3]++ 263 return nil 264 }, 265 } 266 if err := ss.Start(sopts); err != nil { 267 t.Fatalf("Error starting endpoint server: %v", err) 268 } 269 defer ss.Stop() 270 271 stream, err := ss.client.FullDuplexCall(context.Background()) 272 if err != nil { 273 t.Fatalf("failed to FullDuplexCall: %v", err) 274 } 275 276 _, err = stream.Recv() 277 if err != io.EOF { 278 t.Fatalf("failed to recv from stream: %v", err) 279 } 280 281 if callCounts[3] != 1 { 282 t.Fatalf("callCounts[3] should be 1, but got=%d", callCounts[3]) 283 } 284} 285