1/* 2 * 3 * Copyright 2016 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19package stats_test 20 21import ( 22 "context" 23 "fmt" 24 "io" 25 "net" 26 "reflect" 27 "sync" 28 "testing" 29 "time" 30 31 "github.com/golang/protobuf/proto" 32 "google.golang.org/grpc" 33 "google.golang.org/grpc/metadata" 34 "google.golang.org/grpc/stats" 35 testpb "google.golang.org/grpc/stats/grpc_testing" 36 "google.golang.org/grpc/status" 37) 38 39func init() { 40 grpc.EnableTracing = false 41} 42 43type connCtxKey struct{} 44type rpcCtxKey struct{} 45 46var ( 47 // For headers: 48 testMetadata = metadata.MD{ 49 "key1": []string{"value1"}, 50 "key2": []string{"value2"}, 51 } 52 // For trailers: 53 testTrailerMetadata = metadata.MD{ 54 "tkey1": []string{"trailerValue1"}, 55 "tkey2": []string{"trailerValue2"}, 56 } 57 // The id for which the service handler should return error. 
58 errorID int32 = 32202 59) 60 61type testServer struct{} 62 63func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { 64 md, ok := metadata.FromIncomingContext(ctx) 65 if ok { 66 if err := grpc.SendHeader(ctx, md); err != nil { 67 return nil, status.Errorf(status.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", md, err) 68 } 69 if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil { 70 return nil, status.Errorf(status.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err) 71 } 72 } 73 74 if in.Id == errorID { 75 return nil, fmt.Errorf("got error id: %v", in.Id) 76 } 77 78 return &testpb.SimpleResponse{Id: in.Id}, nil 79} 80 81func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error { 82 md, ok := metadata.FromIncomingContext(stream.Context()) 83 if ok { 84 if err := stream.SendHeader(md); err != nil { 85 return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) 86 } 87 stream.SetTrailer(testTrailerMetadata) 88 } 89 for { 90 in, err := stream.Recv() 91 if err == io.EOF { 92 // read done. 93 return nil 94 } 95 if err != nil { 96 return err 97 } 98 99 if in.Id == errorID { 100 return fmt.Errorf("got error id: %v", in.Id) 101 } 102 103 if err := stream.Send(&testpb.SimpleResponse{Id: in.Id}); err != nil { 104 return err 105 } 106 } 107} 108 109func (s *testServer) ClientStreamCall(stream testpb.TestService_ClientStreamCallServer) error { 110 md, ok := metadata.FromIncomingContext(stream.Context()) 111 if ok { 112 if err := stream.SendHeader(md); err != nil { 113 return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) 114 } 115 stream.SetTrailer(testTrailerMetadata) 116 } 117 for { 118 in, err := stream.Recv() 119 if err == io.EOF { 120 // read done. 
121 return stream.SendAndClose(&testpb.SimpleResponse{Id: int32(0)}) 122 } 123 if err != nil { 124 return err 125 } 126 127 if in.Id == errorID { 128 return fmt.Errorf("got error id: %v", in.Id) 129 } 130 } 131} 132 133func (s *testServer) ServerStreamCall(in *testpb.SimpleRequest, stream testpb.TestService_ServerStreamCallServer) error { 134 md, ok := metadata.FromIncomingContext(stream.Context()) 135 if ok { 136 if err := stream.SendHeader(md); err != nil { 137 return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) 138 } 139 stream.SetTrailer(testTrailerMetadata) 140 } 141 142 if in.Id == errorID { 143 return fmt.Errorf("got error id: %v", in.Id) 144 } 145 146 for i := 0; i < 5; i++ { 147 if err := stream.Send(&testpb.SimpleResponse{Id: in.Id}); err != nil { 148 return err 149 } 150 } 151 return nil 152} 153 154// test is an end-to-end test. It should be created with the newTest 155// func, modified as needed, and then started with its startServer method. 156// It should be cleaned up with the tearDown method. 157type test struct { 158 t *testing.T 159 compress string 160 clientStatsHandler stats.Handler 161 serverStatsHandler stats.Handler 162 163 testServer testpb.TestServiceServer // nil means none 164 // srv and srvAddr are set once startServer is called. 165 srv *grpc.Server 166 srvAddr string 167 168 cc *grpc.ClientConn // nil until requested via clientConn 169} 170 171func (te *test) tearDown() { 172 if te.cc != nil { 173 te.cc.Close() 174 te.cc = nil 175 } 176 te.srv.Stop() 177} 178 179type testConfig struct { 180 compress string 181} 182 183// newTest returns a new test using the provided testing.T and 184// environment. It is returned with default values. Tests should 185// modify it before calling its startServer and clientConn methods. 
186func newTest(t *testing.T, tc *testConfig, ch stats.Handler, sh stats.Handler) *test { 187 te := &test{ 188 t: t, 189 compress: tc.compress, 190 clientStatsHandler: ch, 191 serverStatsHandler: sh, 192 } 193 return te 194} 195 196// startServer starts a gRPC server listening. Callers should defer a 197// call to te.tearDown to clean up. 198func (te *test) startServer(ts testpb.TestServiceServer) { 199 te.testServer = ts 200 lis, err := net.Listen("tcp", "localhost:0") 201 if err != nil { 202 te.t.Fatalf("Failed to listen: %v", err) 203 } 204 var opts []grpc.ServerOption 205 if te.compress == "gzip" { 206 opts = append(opts, 207 grpc.RPCCompressor(grpc.NewGZIPCompressor()), 208 grpc.RPCDecompressor(grpc.NewGZIPDecompressor()), 209 ) 210 } 211 if te.serverStatsHandler != nil { 212 opts = append(opts, grpc.StatsHandler(te.serverStatsHandler)) 213 } 214 s := grpc.NewServer(opts...) 215 te.srv = s 216 if te.testServer != nil { 217 testpb.RegisterTestServiceServer(s, te.testServer) 218 } 219 220 go s.Serve(lis) 221 te.srvAddr = lis.Addr().String() 222} 223 224func (te *test) clientConn() *grpc.ClientConn { 225 if te.cc != nil { 226 return te.cc 227 } 228 opts := []grpc.DialOption{grpc.WithInsecure(), grpc.WithBlock()} 229 if te.compress == "gzip" { 230 opts = append(opts, 231 grpc.WithCompressor(grpc.NewGZIPCompressor()), 232 grpc.WithDecompressor(grpc.NewGZIPDecompressor()), 233 ) 234 } 235 if te.clientStatsHandler != nil { 236 opts = append(opts, grpc.WithStatsHandler(te.clientStatsHandler)) 237 } 238 239 var err error 240 te.cc, err = grpc.Dial(te.srvAddr, opts...) 241 if err != nil { 242 te.t.Fatalf("Dial(%q) = %v", te.srvAddr, err) 243 } 244 return te.cc 245} 246 247type rpcType int 248 249const ( 250 unaryRPC rpcType = iota 251 clientStreamRPC 252 serverStreamRPC 253 fullDuplexStreamRPC 254) 255 256type rpcConfig struct { 257 count int // Number of requests and responses for streaming RPCs. 258 success bool // Whether the RPC should succeed or return error. 
259 failfast bool 260 callType rpcType // Type of RPC. 261} 262 263func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) { 264 var ( 265 resp *testpb.SimpleResponse 266 req *testpb.SimpleRequest 267 err error 268 ) 269 tc := testpb.NewTestServiceClient(te.clientConn()) 270 if c.success { 271 req = &testpb.SimpleRequest{Id: errorID + 1} 272 } else { 273 req = &testpb.SimpleRequest{Id: errorID} 274 } 275 ctx := metadata.NewOutgoingContext(context.Background(), testMetadata) 276 277 resp, err = tc.UnaryCall(ctx, req, grpc.WaitForReady(!c.failfast)) 278 return req, resp, err 279} 280 281func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest, []*testpb.SimpleResponse, error) { 282 var ( 283 reqs []*testpb.SimpleRequest 284 resps []*testpb.SimpleResponse 285 err error 286 ) 287 tc := testpb.NewTestServiceClient(te.clientConn()) 288 stream, err := tc.FullDuplexCall(metadata.NewOutgoingContext(context.Background(), testMetadata), grpc.WaitForReady(!c.failfast)) 289 if err != nil { 290 return reqs, resps, err 291 } 292 var startID int32 293 if !c.success { 294 startID = errorID 295 } 296 for i := 0; i < c.count; i++ { 297 req := &testpb.SimpleRequest{ 298 Id: int32(i) + startID, 299 } 300 reqs = append(reqs, req) 301 if err = stream.Send(req); err != nil { 302 return reqs, resps, err 303 } 304 var resp *testpb.SimpleResponse 305 if resp, err = stream.Recv(); err != nil { 306 return reqs, resps, err 307 } 308 resps = append(resps, resp) 309 } 310 if err = stream.CloseSend(); err != nil && err != io.EOF { 311 return reqs, resps, err 312 } 313 if _, err = stream.Recv(); err != io.EOF { 314 return reqs, resps, err 315 } 316 317 return reqs, resps, nil 318} 319 320func (te *test) doClientStreamCall(c *rpcConfig) ([]*testpb.SimpleRequest, *testpb.SimpleResponse, error) { 321 var ( 322 reqs []*testpb.SimpleRequest 323 resp *testpb.SimpleResponse 324 err error 325 ) 326 tc := 
testpb.NewTestServiceClient(te.clientConn()) 327 stream, err := tc.ClientStreamCall(metadata.NewOutgoingContext(context.Background(), testMetadata), grpc.WaitForReady(!c.failfast)) 328 if err != nil { 329 return reqs, resp, err 330 } 331 var startID int32 332 if !c.success { 333 startID = errorID 334 } 335 for i := 0; i < c.count; i++ { 336 req := &testpb.SimpleRequest{ 337 Id: int32(i) + startID, 338 } 339 reqs = append(reqs, req) 340 if err = stream.Send(req); err != nil { 341 return reqs, resp, err 342 } 343 } 344 resp, err = stream.CloseAndRecv() 345 return reqs, resp, err 346} 347 348func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.SimpleRequest, []*testpb.SimpleResponse, error) { 349 var ( 350 req *testpb.SimpleRequest 351 resps []*testpb.SimpleResponse 352 err error 353 ) 354 355 tc := testpb.NewTestServiceClient(te.clientConn()) 356 357 var startID int32 358 if !c.success { 359 startID = errorID 360 } 361 req = &testpb.SimpleRequest{Id: startID} 362 stream, err := tc.ServerStreamCall(metadata.NewOutgoingContext(context.Background(), testMetadata), req, grpc.WaitForReady(!c.failfast)) 363 if err != nil { 364 return req, resps, err 365 } 366 for { 367 var resp *testpb.SimpleResponse 368 resp, err := stream.Recv() 369 if err == io.EOF { 370 return req, resps, nil 371 } else if err != nil { 372 return req, resps, err 373 } 374 resps = append(resps, resp) 375 } 376} 377 378type expectedData struct { 379 method string 380 serverAddr string 381 compression string 382 reqIdx int 383 requests []*testpb.SimpleRequest 384 respIdx int 385 responses []*testpb.SimpleResponse 386 err error 387 failfast bool 388} 389 390type gotData struct { 391 ctx context.Context 392 client bool 393 s interface{} // This could be RPCStats or ConnStats. 394} 395 396const ( 397 begin int = iota 398 end 399 inPayload 400 inHeader 401 inTrailer 402 outPayload 403 outHeader 404 // TODO: test outTrailer ? 
405 connbegin 406 connend 407) 408 409func checkBegin(t *testing.T, d *gotData, e *expectedData) { 410 var ( 411 ok bool 412 st *stats.Begin 413 ) 414 if st, ok = d.s.(*stats.Begin); !ok { 415 t.Fatalf("got %T, want Begin", d.s) 416 } 417 if d.ctx == nil { 418 t.Fatalf("d.ctx = nil, want <non-nil>") 419 } 420 if st.BeginTime.IsZero() { 421 t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime) 422 } 423 if d.client { 424 if st.FailFast != e.failfast { 425 t.Fatalf("st.FailFast = %v, want %v", st.FailFast, e.failfast) 426 } 427 } 428} 429 430func checkInHeader(t *testing.T, d *gotData, e *expectedData) { 431 var ( 432 ok bool 433 st *stats.InHeader 434 ) 435 if st, ok = d.s.(*stats.InHeader); !ok { 436 t.Fatalf("got %T, want InHeader", d.s) 437 } 438 if d.ctx == nil { 439 t.Fatalf("d.ctx = nil, want <non-nil>") 440 } 441 if !d.client { 442 if st.FullMethod != e.method { 443 t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method) 444 } 445 if st.LocalAddr.String() != e.serverAddr { 446 t.Fatalf("st.LocalAddr = %v, want %v", st.LocalAddr, e.serverAddr) 447 } 448 if st.Compression != e.compression { 449 t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression) 450 } 451 452 if connInfo, ok := d.ctx.Value(connCtxKey{}).(*stats.ConnTagInfo); ok { 453 if connInfo.RemoteAddr != st.RemoteAddr { 454 t.Fatalf("connInfo.RemoteAddr = %v, want %v", connInfo.RemoteAddr, st.RemoteAddr) 455 } 456 if connInfo.LocalAddr != st.LocalAddr { 457 t.Fatalf("connInfo.LocalAddr = %v, want %v", connInfo.LocalAddr, st.LocalAddr) 458 } 459 } else { 460 t.Fatalf("got context %v, want one with connCtxKey", d.ctx) 461 } 462 if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok { 463 if rpcInfo.FullMethodName != st.FullMethod { 464 t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod) 465 } 466 } else { 467 t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx) 468 } 469 } 470} 471 472func checkInPayload(t *testing.T, d 
*gotData, e *expectedData) { 473 var ( 474 ok bool 475 st *stats.InPayload 476 ) 477 if st, ok = d.s.(*stats.InPayload); !ok { 478 t.Fatalf("got %T, want InPayload", d.s) 479 } 480 if d.ctx == nil { 481 t.Fatalf("d.ctx = nil, want <non-nil>") 482 } 483 if d.client { 484 b, err := proto.Marshal(e.responses[e.respIdx]) 485 if err != nil { 486 t.Fatalf("failed to marshal message: %v", err) 487 } 488 if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) { 489 t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx]) 490 } 491 e.respIdx++ 492 if string(st.Data) != string(b) { 493 t.Fatalf("st.Data = %v, want %v", st.Data, b) 494 } 495 if st.Length != len(b) { 496 t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b)) 497 } 498 } else { 499 b, err := proto.Marshal(e.requests[e.reqIdx]) 500 if err != nil { 501 t.Fatalf("failed to marshal message: %v", err) 502 } 503 if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) { 504 t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx]) 505 } 506 e.reqIdx++ 507 if string(st.Data) != string(b) { 508 t.Fatalf("st.Data = %v, want %v", st.Data, b) 509 } 510 if st.Length != len(b) { 511 t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b)) 512 } 513 } 514 // TODO check WireLength and ReceivedTime. 
515 if st.RecvTime.IsZero() { 516 t.Fatalf("st.ReceivedTime = %v, want <non-zero>", st.RecvTime) 517 } 518} 519 520func checkInTrailer(t *testing.T, d *gotData, e *expectedData) { 521 var ( 522 ok bool 523 ) 524 if _, ok = d.s.(*stats.InTrailer); !ok { 525 t.Fatalf("got %T, want InTrailer", d.s) 526 } 527 if d.ctx == nil { 528 t.Fatalf("d.ctx = nil, want <non-nil>") 529 } 530} 531 532func checkOutHeader(t *testing.T, d *gotData, e *expectedData) { 533 var ( 534 ok bool 535 st *stats.OutHeader 536 ) 537 if st, ok = d.s.(*stats.OutHeader); !ok { 538 t.Fatalf("got %T, want OutHeader", d.s) 539 } 540 if d.ctx == nil { 541 t.Fatalf("d.ctx = nil, want <non-nil>") 542 } 543 if d.client { 544 if st.FullMethod != e.method { 545 t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method) 546 } 547 if st.RemoteAddr.String() != e.serverAddr { 548 t.Fatalf("st.RemoteAddr = %v, want %v", st.RemoteAddr, e.serverAddr) 549 } 550 if st.Compression != e.compression { 551 t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression) 552 } 553 554 if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok { 555 if rpcInfo.FullMethodName != st.FullMethod { 556 t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod) 557 } 558 } else { 559 t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx) 560 } 561 } 562} 563 564func checkOutPayload(t *testing.T, d *gotData, e *expectedData) { 565 var ( 566 ok bool 567 st *stats.OutPayload 568 ) 569 if st, ok = d.s.(*stats.OutPayload); !ok { 570 t.Fatalf("got %T, want OutPayload", d.s) 571 } 572 if d.ctx == nil { 573 t.Fatalf("d.ctx = nil, want <non-nil>") 574 } 575 if d.client { 576 b, err := proto.Marshal(e.requests[e.reqIdx]) 577 if err != nil { 578 t.Fatalf("failed to marshal message: %v", err) 579 } 580 if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) { 581 t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx]) 582 } 583 e.reqIdx++ 584 if string(st.Data) 
!= string(b) { 585 t.Fatalf("st.Data = %v, want %v", st.Data, b) 586 } 587 if st.Length != len(b) { 588 t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b)) 589 } 590 } else { 591 b, err := proto.Marshal(e.responses[e.respIdx]) 592 if err != nil { 593 t.Fatalf("failed to marshal message: %v", err) 594 } 595 if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) { 596 t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx]) 597 } 598 e.respIdx++ 599 if string(st.Data) != string(b) { 600 t.Fatalf("st.Data = %v, want %v", st.Data, b) 601 } 602 if st.Length != len(b) { 603 t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b)) 604 } 605 } 606 // TODO check WireLength and ReceivedTime. 607 if st.SentTime.IsZero() { 608 t.Fatalf("st.SentTime = %v, want <non-zero>", st.SentTime) 609 } 610} 611 612func checkOutTrailer(t *testing.T, d *gotData, e *expectedData) { 613 var ( 614 ok bool 615 st *stats.OutTrailer 616 ) 617 if st, ok = d.s.(*stats.OutTrailer); !ok { 618 t.Fatalf("got %T, want OutTrailer", d.s) 619 } 620 if d.ctx == nil { 621 t.Fatalf("d.ctx = nil, want <non-nil>") 622 } 623 if st.Client { 624 t.Fatalf("st IsClient = true, want false") 625 } 626} 627 628func checkEnd(t *testing.T, d *gotData, e *expectedData) { 629 var ( 630 ok bool 631 st *stats.End 632 ) 633 if st, ok = d.s.(*stats.End); !ok { 634 t.Fatalf("got %T, want End", d.s) 635 } 636 if d.ctx == nil { 637 t.Fatalf("d.ctx = nil, want <non-nil>") 638 } 639 if st.BeginTime.IsZero() { 640 t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime) 641 } 642 if st.EndTime.IsZero() { 643 t.Fatalf("st.EndTime = %v, want <non-zero>", st.EndTime) 644 } 645 646 actual, ok := status.FromError(st.Error) 647 if !ok { 648 t.Fatalf("expected st.Error to be a statusError, got %v (type %T)", st.Error, st.Error) 649 } 650 651 expectedStatus, _ := status.FromError(e.err) 652 if actual.Code() != expectedStatus.Code() || actual.Message() != expectedStatus.Message() { 653 t.Fatalf("st.Error 
= %v, want %v", st.Error, e.err) 654 } 655 656 if st.Client { 657 if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) { 658 t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata) 659 } 660 } else { 661 if st.Trailer != nil { 662 t.Fatalf("st.Trailer = %v, want nil", st.Trailer) 663 } 664 } 665} 666 667func checkConnBegin(t *testing.T, d *gotData, e *expectedData) { 668 var ( 669 ok bool 670 st *stats.ConnBegin 671 ) 672 if st, ok = d.s.(*stats.ConnBegin); !ok { 673 t.Fatalf("got %T, want ConnBegin", d.s) 674 } 675 if d.ctx == nil { 676 t.Fatalf("d.ctx = nil, want <non-nil>") 677 } 678 st.IsClient() // TODO remove this. 679} 680 681func checkConnEnd(t *testing.T, d *gotData, e *expectedData) { 682 var ( 683 ok bool 684 st *stats.ConnEnd 685 ) 686 if st, ok = d.s.(*stats.ConnEnd); !ok { 687 t.Fatalf("got %T, want ConnEnd", d.s) 688 } 689 if d.ctx == nil { 690 t.Fatalf("d.ctx = nil, want <non-nil>") 691 } 692 st.IsClient() // TODO remove this. 693} 694 695type statshandler struct { 696 mu sync.Mutex 697 gotRPC []*gotData 698 gotConn []*gotData 699} 700 701func (h *statshandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context { 702 return context.WithValue(ctx, connCtxKey{}, info) 703} 704 705func (h *statshandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context { 706 return context.WithValue(ctx, rpcCtxKey{}, info) 707} 708 709func (h *statshandler) HandleConn(ctx context.Context, s stats.ConnStats) { 710 h.mu.Lock() 711 defer h.mu.Unlock() 712 h.gotConn = append(h.gotConn, &gotData{ctx, s.IsClient(), s}) 713} 714 715func (h *statshandler) HandleRPC(ctx context.Context, s stats.RPCStats) { 716 h.mu.Lock() 717 defer h.mu.Unlock() 718 h.gotRPC = append(h.gotRPC, &gotData{ctx, s.IsClient(), s}) 719} 720 721func checkConnStats(t *testing.T, got []*gotData) { 722 if len(got) <= 0 || len(got)%2 != 0 { 723 for i, g := range got { 724 t.Errorf(" - %v, %T = %+v, ctx: %v", i, g.s, g.s, g.ctx) 725 } 726 t.Fatalf("got %v 
stats, want even positive number", len(got)) 727 } 728 // The first conn stats must be a ConnBegin. 729 checkConnBegin(t, got[0], nil) 730 // The last conn stats must be a ConnEnd. 731 checkConnEnd(t, got[len(got)-1], nil) 732} 733 734func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) { 735 if len(got) != len(checkFuncs) { 736 for i, g := range got { 737 t.Errorf(" - %v, %T", i, g.s) 738 } 739 t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs)) 740 } 741 742 var rpcctx context.Context 743 for i := 0; i < len(got); i++ { 744 if _, ok := got[i].s.(stats.RPCStats); ok { 745 if rpcctx != nil && got[i].ctx != rpcctx { 746 t.Fatalf("got different contexts with stats %T", got[i].s) 747 } 748 rpcctx = got[i].ctx 749 } 750 } 751 752 for i, f := range checkFuncs { 753 f(t, got[i], expect) 754 } 755} 756 757func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) { 758 h := &statshandler{} 759 te := newTest(t, tc, nil, h) 760 te.startServer(&testServer{}) 761 defer te.tearDown() 762 763 var ( 764 reqs []*testpb.SimpleRequest 765 resps []*testpb.SimpleResponse 766 err error 767 method string 768 769 req *testpb.SimpleRequest 770 resp *testpb.SimpleResponse 771 e error 772 ) 773 774 switch cc.callType { 775 case unaryRPC: 776 method = "/grpc.testing.TestService/UnaryCall" 777 req, resp, e = te.doUnaryCall(cc) 778 reqs = []*testpb.SimpleRequest{req} 779 resps = []*testpb.SimpleResponse{resp} 780 err = e 781 case clientStreamRPC: 782 method = "/grpc.testing.TestService/ClientStreamCall" 783 reqs, resp, e = te.doClientStreamCall(cc) 784 resps = []*testpb.SimpleResponse{resp} 785 err = e 786 case serverStreamRPC: 787 method = "/grpc.testing.TestService/ServerStreamCall" 788 req, resps, e = te.doServerStreamCall(cc) 789 reqs = []*testpb.SimpleRequest{req} 790 err = e 791 case fullDuplexStreamRPC: 792 method = 
"/grpc.testing.TestService/FullDuplexCall" 793 reqs, resps, err = te.doFullDuplexCallRoundtrip(cc) 794 } 795 if cc.success != (err == nil) { 796 t.Fatalf("cc.success: %v, got error: %v", cc.success, err) 797 } 798 te.cc.Close() 799 te.srv.GracefulStop() // Wait for the server to stop. 800 801 for { 802 h.mu.Lock() 803 if len(h.gotRPC) >= len(checkFuncs) { 804 h.mu.Unlock() 805 break 806 } 807 h.mu.Unlock() 808 time.Sleep(10 * time.Millisecond) 809 } 810 811 for { 812 h.mu.Lock() 813 if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok { 814 h.mu.Unlock() 815 break 816 } 817 h.mu.Unlock() 818 time.Sleep(10 * time.Millisecond) 819 } 820 821 expect := &expectedData{ 822 serverAddr: te.srvAddr, 823 compression: tc.compress, 824 method: method, 825 requests: reqs, 826 responses: resps, 827 err: err, 828 } 829 830 h.mu.Lock() 831 checkConnStats(t, h.gotConn) 832 h.mu.Unlock() 833 checkServerStats(t, h.gotRPC, expect, checkFuncs) 834} 835 836func TestServerStatsUnaryRPC(t *testing.T) { 837 testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){ 838 checkInHeader, 839 checkBegin, 840 checkInPayload, 841 checkOutHeader, 842 checkOutPayload, 843 checkOutTrailer, 844 checkEnd, 845 }) 846} 847 848func TestServerStatsUnaryRPCError(t *testing.T) { 849 testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){ 850 checkInHeader, 851 checkBegin, 852 checkInPayload, 853 checkOutHeader, 854 checkOutTrailer, 855 checkEnd, 856 }) 857} 858 859func TestServerStatsClientStreamRPC(t *testing.T) { 860 count := 5 861 checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ 862 checkInHeader, 863 checkBegin, 864 checkOutHeader, 865 } 866 ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){ 867 checkInPayload, 868 } 869 for i := 0; i < count; i++ { 870 checkFuncs = append(checkFuncs, 
ioPayFuncs...) 871 } 872 checkFuncs = append(checkFuncs, 873 checkOutPayload, 874 checkOutTrailer, 875 checkEnd, 876 ) 877 testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: clientStreamRPC}, checkFuncs) 878} 879 880func TestServerStatsClientStreamRPCError(t *testing.T) { 881 count := 1 882 testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: clientStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){ 883 checkInHeader, 884 checkBegin, 885 checkOutHeader, 886 checkInPayload, 887 checkOutTrailer, 888 checkEnd, 889 }) 890} 891 892func TestServerStatsServerStreamRPC(t *testing.T) { 893 count := 5 894 checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ 895 checkInHeader, 896 checkBegin, 897 checkInPayload, 898 checkOutHeader, 899 } 900 ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){ 901 checkOutPayload, 902 } 903 for i := 0; i < count; i++ { 904 checkFuncs = append(checkFuncs, ioPayFuncs...) 
905 } 906 checkFuncs = append(checkFuncs, 907 checkOutTrailer, 908 checkEnd, 909 ) 910 testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: serverStreamRPC}, checkFuncs) 911} 912 913func TestServerStatsServerStreamRPCError(t *testing.T) { 914 count := 5 915 testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: serverStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){ 916 checkInHeader, 917 checkBegin, 918 checkInPayload, 919 checkOutHeader, 920 checkOutTrailer, 921 checkEnd, 922 }) 923} 924 925func TestServerStatsFullDuplexRPC(t *testing.T) { 926 count := 5 927 checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ 928 checkInHeader, 929 checkBegin, 930 checkOutHeader, 931 } 932 ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){ 933 checkInPayload, 934 checkOutPayload, 935 } 936 for i := 0; i < count; i++ { 937 checkFuncs = append(checkFuncs, ioPayFuncs...) 938 } 939 checkFuncs = append(checkFuncs, 940 checkOutTrailer, 941 checkEnd, 942 ) 943 testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: fullDuplexStreamRPC}, checkFuncs) 944} 945 946func TestServerStatsFullDuplexRPCError(t *testing.T) { 947 count := 5 948 testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: fullDuplexStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){ 949 checkInHeader, 950 checkBegin, 951 checkOutHeader, 952 checkInPayload, 953 checkOutTrailer, 954 checkEnd, 955 }) 956} 957 958type checkFuncWithCount struct { 959 f func(t *testing.T, d *gotData, e *expectedData) 960 c int // expected count 961} 962 963func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs map[int]*checkFuncWithCount) { 964 var expectLen int 965 for _, v := range checkFuncs { 966 expectLen += v.c 967 } 968 if len(got) != expectLen { 969 for i, g := range got { 970 
t.Errorf(" - %v, %T", i, g.s) 971 } 972 t.Fatalf("got %v stats, want %v stats", len(got), expectLen) 973 } 974 975 var tagInfoInCtx *stats.RPCTagInfo 976 for i := 0; i < len(got); i++ { 977 if _, ok := got[i].s.(stats.RPCStats); ok { 978 tagInfoInCtxNew, _ := got[i].ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo) 979 if tagInfoInCtx != nil && tagInfoInCtx != tagInfoInCtxNew { 980 t.Fatalf("got context containing different tagInfo with stats %T", got[i].s) 981 } 982 tagInfoInCtx = tagInfoInCtxNew 983 } 984 } 985 986 for _, s := range got { 987 switch s.s.(type) { 988 case *stats.Begin: 989 if checkFuncs[begin].c <= 0 { 990 t.Fatalf("unexpected stats: %T", s.s) 991 } 992 checkFuncs[begin].f(t, s, expect) 993 checkFuncs[begin].c-- 994 case *stats.OutHeader: 995 if checkFuncs[outHeader].c <= 0 { 996 t.Fatalf("unexpected stats: %T", s.s) 997 } 998 checkFuncs[outHeader].f(t, s, expect) 999 checkFuncs[outHeader].c-- 1000 case *stats.OutPayload: 1001 if checkFuncs[outPayload].c <= 0 { 1002 t.Fatalf("unexpected stats: %T", s.s) 1003 } 1004 checkFuncs[outPayload].f(t, s, expect) 1005 checkFuncs[outPayload].c-- 1006 case *stats.InHeader: 1007 if checkFuncs[inHeader].c <= 0 { 1008 t.Fatalf("unexpected stats: %T", s.s) 1009 } 1010 checkFuncs[inHeader].f(t, s, expect) 1011 checkFuncs[inHeader].c-- 1012 case *stats.InPayload: 1013 if checkFuncs[inPayload].c <= 0 { 1014 t.Fatalf("unexpected stats: %T", s.s) 1015 } 1016 checkFuncs[inPayload].f(t, s, expect) 1017 checkFuncs[inPayload].c-- 1018 case *stats.InTrailer: 1019 if checkFuncs[inTrailer].c <= 0 { 1020 t.Fatalf("unexpected stats: %T", s.s) 1021 } 1022 checkFuncs[inTrailer].f(t, s, expect) 1023 checkFuncs[inTrailer].c-- 1024 case *stats.End: 1025 if checkFuncs[end].c <= 0 { 1026 t.Fatalf("unexpected stats: %T", s.s) 1027 } 1028 checkFuncs[end].f(t, s, expect) 1029 checkFuncs[end].c-- 1030 case *stats.ConnBegin: 1031 if checkFuncs[connbegin].c <= 0 { 1032 t.Fatalf("unexpected stats: %T", s.s) 1033 } 1034 checkFuncs[connbegin].f(t, 
s, expect)
			checkFuncs[connbegin].c--
		case *stats.ConnEnd:
			if checkFuncs[connend].c <= 0 {
				t.Fatalf("unexpected stats: %T", s.s)
			}
			checkFuncs[connend].f(t, s, expect)
			checkFuncs[connend].c--
		default:
			t.Fatalf("unexpected stats: %T", s.s)
		}
	}
}

// testClientStats starts a test server, dials it with h installed as the
// client-side stats handler, performs the RPC described by cc, shuts both
// sides down, and then verifies the collected connection and RPC stats.
//
// checkFuncs maps each expected RPC stats kind to its check function and the
// number of times that kind is expected; the sum of the counts is the number
// of RPC stats the test waits for before checking.
func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map[int]*checkFuncWithCount) {
	h := &statshandler{}
	te := newTest(t, tc, h, nil)
	te.startServer(&testServer{})
	defer te.tearDown()

	var (
		reqs   []*testpb.SimpleRequest
		resps  []*testpb.SimpleResponse
		method string
		err    error

		req  *testpb.SimpleRequest
		resp *testpb.SimpleResponse
	)
	switch cc.callType {
	case unaryRPC:
		method = "/grpc.testing.TestService/UnaryCall"
		req, resp, err = te.doUnaryCall(cc)
		reqs = []*testpb.SimpleRequest{req}
		resps = []*testpb.SimpleResponse{resp}
	case clientStreamRPC:
		method = "/grpc.testing.TestService/ClientStreamCall"
		reqs, resp, err = te.doClientStreamCall(cc)
		resps = []*testpb.SimpleResponse{resp}
	case serverStreamRPC:
		method = "/grpc.testing.TestService/ServerStreamCall"
		req, resps, err = te.doServerStreamCall(cc)
		reqs = []*testpb.SimpleRequest{req}
	case fullDuplexStreamRPC:
		method = "/grpc.testing.TestService/FullDuplexCall"
		reqs, resps, err = te.doFullDuplexCallRoundtrip(cc)
	default:
		// Without this, method would stay empty and the stats checks below
		// would fail with a far less obvious message.
		t.Fatalf("unknown call type: %v", cc.callType)
	}
	if cc.success != (err == nil) {
		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
	}
	te.cc.Close()
	te.srv.GracefulStop() // Wait for the server to stop.

	// waitFor polls cond, evaluated with h.mu held, every 10ms until it
	// reports true. It fails the test after waitTimeout rather than hanging
	// until the global `go test` timeout if the stats never show up.
	const waitTimeout = 10 * time.Second
	waitFor := func(desc string, cond func() bool) {
		deadline := time.Now().Add(waitTimeout)
		for {
			h.mu.Lock()
			done := cond()
			h.mu.Unlock()
			if done {
				return
			}
			if time.Now().After(deadline) {
				t.Fatalf("timed out after %v waiting for %s", waitTimeout, desc)
			}
			time.Sleep(10 * time.Millisecond)
		}
	}

	lenRPCStats := 0
	for _, v := range checkFuncs {
		lenRPCStats += v.c
	}
	waitFor("RPC stats", func() bool {
		return len(h.gotRPC) >= lenRPCStats
	})
	waitFor("ConnEnd stats", func() bool {
		// Guard the index: with no conn stats recorded yet, indexing the last
		// element would panic.
		if len(h.gotConn) == 0 {
			return false
		}
		_, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd)
		return ok
	})

	expect := &expectedData{
		serverAddr:  te.srvAddr,
		compression: tc.compress,
		method:      method,
		requests:    reqs,
		responses:   resps,
		failfast:    cc.failfast,
		err:         err,
	}

	h.mu.Lock()
	checkConnStats(t, h.gotConn)
	// Read the slice header under the lock so it does not race with the
	// stats handler appending to it.
	gotRPC := h.gotRPC
	h.mu.Unlock()
	checkClientStats(t, gotRPC, expect, checkFuncs)
}

// TestClientStatsUnaryRPC checks the client-side stats of a successful unary RPC.
func TestClientStatsUnaryRPC(t *testing.T) {
	testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{
		begin:      {checkBegin, 1},
		outHeader:  {checkOutHeader, 1},
		outPayload: {checkOutPayload, 1},
		inHeader:   {checkInHeader, 1},
		inPayload:  {checkInPayload, 1},
		inTrailer:  {checkInTrailer, 1},
		end:        {checkEnd, 1},
	})
}

// TestClientStatsUnaryRPCError checks the client-side stats of a failing unary
// RPC; no inbound payload is expected.
func TestClientStatsUnaryRPCError(t *testing.T) {
	testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{
		begin:      {checkBegin, 1},
		outHeader:  {checkOutHeader, 1},
		outPayload: {checkOutPayload, 1},
		inHeader:   {checkInHeader, 1},
		inTrailer:  {checkInTrailer, 1},
		end:        {checkEnd, 1},
	})
}

// TestClientStatsClientStreamRPC checks the client-side stats of a successful
// gzip-compressed client-streaming RPC sending count messages.
func TestClientStatsClientStreamRPC(t *testing.T) {
	count := 5
	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{
		begin:      {checkBegin, 1},
		outHeader:  {checkOutHeader, 1},
		inHeader:   {checkInHeader, 1},
		outPayload: {checkOutPayload, count},
		inTrailer:  {checkInTrailer, 1},
		inPayload:  {checkInPayload, 1},
		end:        {checkEnd, 1},
	})
}

// TestClientStatsClientStreamRPCError checks the client-side stats of a failing
// gzip-compressed client-streaming RPC.
func TestClientStatsClientStreamRPCError(t *testing.T) {
	count := 1
	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{
		begin:      {checkBegin, 1},
		outHeader:  {checkOutHeader, 1},
		inHeader:   {checkInHeader, 1},
		outPayload: {checkOutPayload, 1},
		inTrailer:  {checkInTrailer, 1},
		end:        {checkEnd, 1},
	})
}

// TestClientStatsServerStreamRPC checks the client-side stats of a successful
// gzip-compressed server-streaming RPC receiving count messages.
func TestClientStatsServerStreamRPC(t *testing.T) {
	count := 5
	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{
		begin:      {checkBegin, 1},
		outHeader:  {checkOutHeader, 1},
		outPayload: {checkOutPayload, 1},
		inHeader:   {checkInHeader, 1},
		inPayload:  {checkInPayload, count},
		inTrailer:  {checkInTrailer, 1},
		end:        {checkEnd, 1},
	})
}

// TestClientStatsServerStreamRPCError checks the client-side stats of a failing
// gzip-compressed server-streaming RPC; no inbound payload is expected.
func TestClientStatsServerStreamRPCError(t *testing.T) {
	count := 5
	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{
		begin:      {checkBegin, 1},
		outHeader:  {checkOutHeader, 1},
		outPayload: {checkOutPayload, 1},
		inHeader:   {checkInHeader, 1},
		inTrailer:  {checkInTrailer, 1},
		end:        {checkEnd, 1},
	})
}

// TestClientStatsFullDuplexRPC checks the client-side stats of a successful
// gzip-compressed bidirectional-streaming RPC exchanging count messages each way.
func TestClientStatsFullDuplexRPC(t *testing.T) {
	count := 5
	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{
		begin:      {checkBegin, 1},
		outHeader:  {checkOutHeader, 1},
		outPayload: {checkOutPayload, count},
		inHeader:   {checkInHeader, 1},
		inPayload:  {checkInPayload, count},
		inTrailer:  {checkInTrailer, 1},
		end:        {checkEnd, 1},
	})
}

// TestClientStatsFullDuplexRPCError checks the client-side stats of a failing
// gzip-compressed bidirectional-streaming RPC; no inbound payload is expected.
func TestClientStatsFullDuplexRPCError(t *testing.T) {
	count := 5
	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{
		begin:      {checkBegin, 1},
		outHeader:  {checkOutHeader, 1},
		outPayload: {checkOutPayload, 1},
		inHeader:   {checkInHeader, 1},
		inTrailer:  {checkInTrailer, 1},
		end:        {checkEnd, 1},
	})
}

// TestTags checks that SetTags populates only the outgoing tags and
// SetIncomingTags populates only the incoming tags.
func TestTags(t *testing.T) {
	b := []byte{5, 2, 4, 3, 1}
	ctx := stats.SetTags(context.Background(), b)
	if tg := stats.OutgoingTags(ctx); !reflect.DeepEqual(tg, b) {
		t.Errorf("OutgoingTags(%v) = %v; want %v", ctx, tg, b)
	}
	if tg := stats.Tags(ctx); tg != nil {
		t.Errorf("Tags(%v) = %v; want nil", ctx, tg)
	}

	ctx = stats.SetIncomingTags(context.Background(), b)
	if tg := stats.Tags(ctx); !reflect.DeepEqual(tg, b) {
		t.Errorf("Tags(%v) = %v; want %v", ctx, tg, b)
	}
	if tg := stats.OutgoingTags(ctx); tg != nil {
		t.Errorf("OutgoingTags(%v) = %v; want nil", ctx, tg)
	}
}

// TestTrace checks that SetTrace populates only the outgoing trace and
// SetIncomingTrace populates only the incoming trace.
func TestTrace(t *testing.T) {
	b := []byte{5, 2, 4, 3, 1}
	ctx := stats.SetTrace(context.Background(), b)
	if tr := stats.OutgoingTrace(ctx); !reflect.DeepEqual(tr, b) {
		t.Errorf("OutgoingTrace(%v) = %v; want %v", ctx, tr, b)
	}
	if tr := stats.Trace(ctx); tr != nil {
		t.Errorf("Trace(%v) = %v; want nil", ctx, tr)
	}

	ctx = stats.SetIncomingTrace(context.Background(), b)
	if tr := stats.Trace(ctx); !reflect.DeepEqual(tr, b) {
		t.Errorf("Trace(%v) = %v; want %v", ctx, tr, b)
	}
	if tr := stats.OutgoingTrace(ctx); tr != nil {
		t.Errorf("OutgoingTrace(%v) = %v; want nil", ctx, tr)
	}
}