1/* 2 * 3 * Copyright 2014 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19package grpc 20 21import ( 22 "context" 23 "fmt" 24 "io" 25 "math" 26 "net" 27 "strconv" 28 "strings" 29 "sync" 30 "testing" 31 "time" 32 33 "google.golang.org/grpc/codes" 34 "google.golang.org/grpc/internal/transport" 35 "google.golang.org/grpc/status" 36) 37 38var ( 39 expectedRequest = "ping" 40 expectedResponse = "pong" 41 weirdError = "format verbs: %v%s" 42 sizeLargeErr = 1024 * 1024 43 canceled = 0 44) 45 46type testCodec struct { 47} 48 49func (testCodec) Marshal(v interface{}) ([]byte, error) { 50 return []byte(*(v.(*string))), nil 51} 52 53func (testCodec) Unmarshal(data []byte, v interface{}) error { 54 *(v.(*string)) = string(data) 55 return nil 56} 57 58func (testCodec) String() string { 59 return "test" 60} 61 62type testStreamHandler struct { 63 port string 64 t transport.ServerTransport 65} 66 67func (h *testStreamHandler) handleStream(t *testing.T, s *transport.Stream) { 68 p := &parser{r: s} 69 for { 70 pf, req, err := p.recvMsg(math.MaxInt32) 71 if err == io.EOF { 72 break 73 } 74 if err != nil { 75 return 76 } 77 if pf != compressionNone { 78 t.Errorf("Received the mistaken message format %d, want %d", pf, compressionNone) 79 return 80 } 81 var v string 82 codec := testCodec{} 83 if err := codec.Unmarshal(req, &v); err != nil { 84 t.Errorf("Failed to unmarshal the received message: %v", err) 85 return 86 } 87 if v == "weird error" { 88 h.t.WriteStatus(s, status.New(codes.Internal, weirdError)) 89 return 90 } 91 if v == "canceled" { 92 canceled++ 93 h.t.WriteStatus(s, status.New(codes.Internal, "")) 94 return 95 } 96 if v == "port" { 97 h.t.WriteStatus(s, status.New(codes.Internal, h.port)) 98 return 99 } 100 101 if v != expectedRequest { 102 h.t.WriteStatus(s, status.New(codes.Internal, strings.Repeat("A", sizeLargeErr))) 103 return 104 } 105 } 106 // send a response back to end the stream. 107 data, err := encode(testCodec{}, &expectedResponse) 108 if err != nil { 109 t.Errorf("Failed to encode the response: %v", err) 110 return 111 } 112 hdr, payload := msgHeader(data, nil) 113 h.t.Write(s, hdr, payload, &transport.Options{}) 114 h.t.WriteStatus(s, status.New(codes.OK, "")) 115} 116 117type server struct { 118 lis net.Listener 119 port string 120 addr string 121 startedErr chan error // sent nil or an error after server starts 122 mu sync.Mutex 123 conns map[transport.ServerTransport]bool 124} 125 126type ctxKey string 127 128func newTestServer() *server { 129 return &server{startedErr: make(chan error, 1)} 130} 131 132// start starts server. Other goroutines should block on s.startedErr for further operations. 133func (s *server) start(t *testing.T, port int, maxStreams uint32) { 134 var err error 135 if port == 0 { 136 s.lis, err = net.Listen("tcp", "localhost:0") 137 } else { 138 s.lis, err = net.Listen("tcp", "localhost:"+strconv.Itoa(port)) 139 } 140 if err != nil { 141 s.startedErr <- fmt.Errorf("failed to listen: %v", err) 142 return 143 } 144 s.addr = s.lis.Addr().String() 145 _, p, err := net.SplitHostPort(s.addr) 146 if err != nil { 147 s.startedErr <- fmt.Errorf("failed to parse listener address: %v", err) 148 return 149 } 150 s.port = p 151 s.conns = make(map[transport.ServerTransport]bool) 152 s.startedErr <- nil 153 for { 154 conn, err := s.lis.Accept() 155 if err != nil { 156 return 157 } 158 config := &transport.ServerConfig{ 159 MaxStreams: maxStreams, 160 } 161 st, err := transport.NewServerTransport("http2", conn, config) 162 if err != nil { 163 continue 164 } 165 s.mu.Lock() 166 if s.conns == nil { 167 s.mu.Unlock() 168 st.Close() 169 return 170 } 171 s.conns[st] = true 172 s.mu.Unlock() 173 h := &testStreamHandler{ 174 port: s.port, 175 t: st, 176 } 177 go st.HandleStreams(func(s *transport.Stream) { 178 go h.handleStream(t, s) 179 }, func(ctx context.Context, method string) context.Context { 180 return ctx 181 }) 182 } 183} 184 185func (s *server) wait(t *testing.T, timeout time.Duration) { 186 select { 187 case err := <-s.startedErr: 188 if err != nil { 189 t.Fatal(err) 190 } 191 case <-time.After(timeout): 192 t.Fatalf("Timed out after %v waiting for server to be ready", timeout) 193 } 194} 195 196func (s *server) stop() { 197 s.lis.Close() 198 s.mu.Lock() 199 for c := range s.conns { 200 c.Close() 201 } 202 s.conns = nil 203 s.mu.Unlock() 204} 205 206func setUp(t *testing.T, port int, maxStreams uint32) (*server, *ClientConn) { 207 return setUpWithOptions(t, port, maxStreams) 208} 209 210func setUpWithOptions(t *testing.T, port int, maxStreams uint32, dopts ...DialOption) (*server, *ClientConn) { 211 server := newTestServer() 212 go server.start(t, port, maxStreams) 213 server.wait(t, 2*time.Second) 214 addr := "localhost:" + server.port 215 dopts = append(dopts, WithBlock(), WithInsecure(), WithCodec(testCodec{})) 216 cc, err := Dial(addr, dopts...) 217 if err != nil { 218 t.Fatalf("Failed to create ClientConn: %v", err) 219 } 220 return server, cc 221} 222 223func (s) TestUnaryClientInterceptor(t *testing.T) { 224 parentKey := ctxKey("parentKey") 225 226 interceptor := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { 227 if ctx.Value(parentKey) == nil { 228 t.Fatalf("interceptor should have %v in context", parentKey) 229 } 230 return invoker(ctx, method, req, reply, cc, opts...) 231 } 232 233 server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(interceptor)) 234 defer func() { 235 cc.Close() 236 server.stop() 237 }() 238 239 var reply string 240 ctx := context.Background() 241 parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0) 242 if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse { 243 t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err) 244 } 245} 246 247func (s) TestChainUnaryClientInterceptor(t *testing.T) { 248 var ( 249 parentKey = ctxKey("parentKey") 250 firstIntKey = ctxKey("firstIntKey") 251 secondIntKey = ctxKey("secondIntKey") 252 ) 253 254 firstInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { 255 if ctx.Value(parentKey) == nil { 256 t.Fatalf("first interceptor should have %v in context", parentKey) 257 } 258 if ctx.Value(firstIntKey) != nil { 259 t.Fatalf("first interceptor should not have %v in context", firstIntKey) 260 } 261 if ctx.Value(secondIntKey) != nil { 262 t.Fatalf("first interceptor should not have %v in context", secondIntKey) 263 } 264 firstCtx := context.WithValue(ctx, firstIntKey, 1) 265 err := invoker(firstCtx, method, req, reply, cc, opts...) 266 *(reply.(*string)) += "1" 267 return err 268 } 269 270 secondInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { 271 if ctx.Value(parentKey) == nil { 272 t.Fatalf("second interceptor should have %v in context", parentKey) 273 } 274 if ctx.Value(firstIntKey) == nil { 275 t.Fatalf("second interceptor should have %v in context", firstIntKey) 276 } 277 if ctx.Value(secondIntKey) != nil { 278 t.Fatalf("second interceptor should not have %v in context", secondIntKey) 279 } 280 secondCtx := context.WithValue(ctx, secondIntKey, 2) 281 err := invoker(secondCtx, method, req, reply, cc, opts...) 282 *(reply.(*string)) += "2" 283 return err 284 } 285 286 lastInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { 287 if ctx.Value(parentKey) == nil { 288 t.Fatalf("last interceptor should have %v in context", parentKey) 289 } 290 if ctx.Value(firstIntKey) == nil { 291 t.Fatalf("last interceptor should have %v in context", firstIntKey) 292 } 293 if ctx.Value(secondIntKey) == nil { 294 t.Fatalf("last interceptor should have %v in context", secondIntKey) 295 } 296 err := invoker(ctx, method, req, reply, cc, opts...) 297 *(reply.(*string)) += "3" 298 return err 299 } 300 301 server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainUnaryInterceptor(firstInt, secondInt, lastInt)) 302 defer func() { 303 cc.Close() 304 server.stop() 305 }() 306 307 var reply string 308 ctx := context.Background() 309 parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0) 310 if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse+"321" { 311 t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err) 312 } 313} 314 315func (s) TestChainOnBaseUnaryClientInterceptor(t *testing.T) { 316 var ( 317 parentKey = ctxKey("parentKey") 318 baseIntKey = ctxKey("baseIntKey") 319 ) 320 321 baseInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { 322 if ctx.Value(parentKey) == nil { 323 t.Fatalf("base interceptor should have %v in context", parentKey) 324 } 325 if ctx.Value(baseIntKey) != nil { 326 t.Fatalf("base interceptor should not have %v in context", baseIntKey) 327 } 328 baseCtx := context.WithValue(ctx, baseIntKey, 1) 329 return invoker(baseCtx, method, req, reply, cc, opts...) 330 } 331 332 chainInt := func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error { 333 if ctx.Value(parentKey) == nil { 334 t.Fatalf("chain interceptor should have %v in context", parentKey) 335 } 336 if ctx.Value(baseIntKey) == nil { 337 t.Fatalf("chain interceptor should have %v in context", baseIntKey) 338 } 339 return invoker(ctx, method, req, reply, cc, opts...) 340 } 341 342 server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithUnaryInterceptor(baseInt), WithChainUnaryInterceptor(chainInt)) 343 defer func() { 344 cc.Close() 345 server.stop() 346 }() 347 348 var reply string 349 ctx := context.Background() 350 parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0) 351 if err := cc.Invoke(parentCtx, "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse { 352 t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err) 353 } 354} 355 356func (s) TestChainStreamClientInterceptor(t *testing.T) { 357 var ( 358 parentKey = ctxKey("parentKey") 359 firstIntKey = ctxKey("firstIntKey") 360 secondIntKey = ctxKey("secondIntKey") 361 ) 362 363 firstInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) { 364 if ctx.Value(parentKey) == nil { 365 t.Fatalf("first interceptor should have %v in context", parentKey) 366 } 367 if ctx.Value(firstIntKey) != nil { 368 t.Fatalf("first interceptor should not have %v in context", firstIntKey) 369 } 370 if ctx.Value(secondIntKey) != nil { 371 t.Fatalf("first interceptor should not have %v in context", secondIntKey) 372 } 373 firstCtx := context.WithValue(ctx, firstIntKey, 1) 374 return streamer(firstCtx, desc, cc, method, opts...) 375 } 376 377 secondInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) { 378 if ctx.Value(parentKey) == nil { 379 t.Fatalf("second interceptor should have %v in context", parentKey) 380 } 381 if ctx.Value(firstIntKey) == nil { 382 t.Fatalf("second interceptor should have %v in context", firstIntKey) 383 } 384 if ctx.Value(secondIntKey) != nil { 385 t.Fatalf("second interceptor should not have %v in context", secondIntKey) 386 } 387 secondCtx := context.WithValue(ctx, secondIntKey, 2) 388 return streamer(secondCtx, desc, cc, method, opts...) 389 } 390 391 lastInt := func(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, streamer Streamer, opts ...CallOption) (ClientStream, error) { 392 if ctx.Value(parentKey) == nil { 393 t.Fatalf("last interceptor should have %v in context", parentKey) 394 } 395 if ctx.Value(firstIntKey) == nil { 396 t.Fatalf("last interceptor should have %v in context", firstIntKey) 397 } 398 if ctx.Value(secondIntKey) == nil { 399 t.Fatalf("last interceptor should have %v in context", secondIntKey) 400 } 401 return streamer(ctx, desc, cc, method, opts...) 402 } 403 404 server, cc := setUpWithOptions(t, 0, math.MaxUint32, WithChainStreamInterceptor(firstInt, secondInt, lastInt)) 405 defer func() { 406 cc.Close() 407 server.stop() 408 }() 409 410 ctx := context.Background() 411 parentCtx := context.WithValue(ctx, ctxKey("parentKey"), 0) 412 _, err := cc.NewStream(parentCtx, &StreamDesc{}, "/foo/bar") 413 if err != nil { 414 t.Fatalf("grpc.NewStream(_, _, _) = %v, want <nil>", err) 415 } 416} 417 418func (s) TestInvoke(t *testing.T) { 419 server, cc := setUp(t, 0, math.MaxUint32) 420 var reply string 421 if err := cc.Invoke(context.Background(), "/foo/bar", &expectedRequest, &reply); err != nil || reply != expectedResponse { 422 t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want <nil>", err) 423 } 424 cc.Close() 425 server.stop() 426} 427 428func (s) TestInvokeLargeErr(t *testing.T) { 429 server, cc := setUp(t, 0, math.MaxUint32) 430 var reply string 431 req := "hello" 432 err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply) 433 if _, ok := status.FromError(err); !ok { 434 t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") 435 } 436 if status.Code(err) != codes.Internal || len(errorDesc(err)) != sizeLargeErr { 437 t.Fatalf("grpc.Invoke(_, _, _, _, _) = %v, want an error of code %d and desc size %d", err, codes.Internal, sizeLargeErr) 438 } 439 cc.Close() 440 server.stop() 441} 442 443// TestInvokeErrorSpecialChars checks that error messages don't get mangled. 444func (s) TestInvokeErrorSpecialChars(t *testing.T) { 445 server, cc := setUp(t, 0, math.MaxUint32) 446 var reply string 447 req := "weird error" 448 err := cc.Invoke(context.Background(), "/foo/bar", &req, &reply) 449 if _, ok := status.FromError(err); !ok { 450 t.Fatalf("grpc.Invoke(_, _, _, _, _) receives non rpc error.") 451 } 452 if got, want := errorDesc(err), weirdError; got != want { 453 t.Fatalf("grpc.Invoke(_, _, _, _, _) error = %q, want %q", got, want) 454 } 455 cc.Close() 456 server.stop() 457} 458 459// TestInvokeCancel checks that an Invoke with a canceled context is not sent. 460func (s) TestInvokeCancel(t *testing.T) { 461 server, cc := setUp(t, 0, math.MaxUint32) 462 var reply string 463 req := "canceled" 464 for i := 0; i < 100; i++ { 465 ctx, cancel := context.WithCancel(context.Background()) 466 cancel() 467 cc.Invoke(ctx, "/foo/bar", &req, &reply) 468 } 469 if canceled != 0 { 470 t.Fatalf("received %d of 100 canceled requests", canceled) 471 } 472 cc.Close() 473 server.stop() 474} 475 476// TestInvokeCancelClosedNonFail checks that a canceled non-failfast RPC 477// on a closed client will terminate. 478func (s) TestInvokeCancelClosedNonFailFast(t *testing.T) { 479 server, cc := setUp(t, 0, math.MaxUint32) 480 var reply string 481 cc.Close() 482 req := "hello" 483 ctx, cancel := context.WithCancel(context.Background()) 484 cancel() 485 if err := cc.Invoke(ctx, "/foo/bar", &req, &reply, WaitForReady(true)); err == nil { 486 t.Fatalf("canceled invoke on closed connection should fail") 487 } 488 server.stop() 489} 490