1/* 2 * 3 * Copyright 2016 gRPC authors. 4 * 5 * Licensed under the Apache License, Version 2.0 (the "License"); 6 * you may not use this file except in compliance with the License. 7 * You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 * 17 */ 18 19package stats_test 20 21import ( 22 "context" 23 "fmt" 24 "io" 25 "net" 26 "reflect" 27 "sync" 28 "testing" 29 "time" 30 31 "github.com/golang/protobuf/proto" 32 "google.golang.org/grpc" 33 "google.golang.org/grpc/metadata" 34 "google.golang.org/grpc/stats" 35 testpb "google.golang.org/grpc/stats/grpc_testing" 36 "google.golang.org/grpc/status" 37) 38 39func init() { 40 grpc.EnableTracing = false 41} 42 43type connCtxKey struct{} 44type rpcCtxKey struct{} 45 46var ( 47 // For headers: 48 testMetadata = metadata.MD{ 49 "key1": []string{"value1"}, 50 "key2": []string{"value2"}, 51 } 52 // For trailers: 53 testTrailerMetadata = metadata.MD{ 54 "tkey1": []string{"trailerValue1"}, 55 "tkey2": []string{"trailerValue2"}, 56 } 57 // The id for which the service handler should return error. 
58 errorID int32 = 32202 59) 60 61type testServer struct { 62 testpb.UnimplementedTestServiceServer 63} 64 65func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) { 66 md, ok := metadata.FromIncomingContext(ctx) 67 if ok { 68 if err := grpc.SendHeader(ctx, md); err != nil { 69 return nil, status.Errorf(status.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", md, err) 70 } 71 if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil { 72 return nil, status.Errorf(status.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err) 73 } 74 } 75 76 if in.Id == errorID { 77 return nil, fmt.Errorf("got error id: %v", in.Id) 78 } 79 80 return &testpb.SimpleResponse{Id: in.Id}, nil 81} 82 83func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error { 84 md, ok := metadata.FromIncomingContext(stream.Context()) 85 if ok { 86 if err := stream.SendHeader(md); err != nil { 87 return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) 88 } 89 stream.SetTrailer(testTrailerMetadata) 90 } 91 for { 92 in, err := stream.Recv() 93 if err == io.EOF { 94 // read done. 95 return nil 96 } 97 if err != nil { 98 return err 99 } 100 101 if in.Id == errorID { 102 return fmt.Errorf("got error id: %v", in.Id) 103 } 104 105 if err := stream.Send(&testpb.SimpleResponse{Id: in.Id}); err != nil { 106 return err 107 } 108 } 109} 110 111func (s *testServer) ClientStreamCall(stream testpb.TestService_ClientStreamCallServer) error { 112 md, ok := metadata.FromIncomingContext(stream.Context()) 113 if ok { 114 if err := stream.SendHeader(md); err != nil { 115 return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) 116 } 117 stream.SetTrailer(testTrailerMetadata) 118 } 119 for { 120 in, err := stream.Recv() 121 if err == io.EOF { 122 // read done. 
123 return stream.SendAndClose(&testpb.SimpleResponse{Id: int32(0)}) 124 } 125 if err != nil { 126 return err 127 } 128 129 if in.Id == errorID { 130 return fmt.Errorf("got error id: %v", in.Id) 131 } 132 } 133} 134 135func (s *testServer) ServerStreamCall(in *testpb.SimpleRequest, stream testpb.TestService_ServerStreamCallServer) error { 136 md, ok := metadata.FromIncomingContext(stream.Context()) 137 if ok { 138 if err := stream.SendHeader(md); err != nil { 139 return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil) 140 } 141 stream.SetTrailer(testTrailerMetadata) 142 } 143 144 if in.Id == errorID { 145 return fmt.Errorf("got error id: %v", in.Id) 146 } 147 148 for i := 0; i < 5; i++ { 149 if err := stream.Send(&testpb.SimpleResponse{Id: in.Id}); err != nil { 150 return err 151 } 152 } 153 return nil 154} 155 156// test is an end-to-end test. It should be created with the newTest 157// func, modified as needed, and then started with its startServer method. 158// It should be cleaned up with the tearDown method. 159type test struct { 160 t *testing.T 161 compress string 162 clientStatsHandler stats.Handler 163 serverStatsHandler stats.Handler 164 165 testServer testpb.TestServiceServer // nil means none 166 // srv and srvAddr are set once startServer is called. 167 srv *grpc.Server 168 srvAddr string 169 170 cc *grpc.ClientConn // nil until requested via clientConn 171} 172 173func (te *test) tearDown() { 174 if te.cc != nil { 175 te.cc.Close() 176 te.cc = nil 177 } 178 te.srv.Stop() 179} 180 181type testConfig struct { 182 compress string 183} 184 185// newTest returns a new test using the provided testing.T and 186// environment. It is returned with default values. Tests should 187// modify it before calling its startServer and clientConn methods. 
188func newTest(t *testing.T, tc *testConfig, ch stats.Handler, sh stats.Handler) *test { 189 te := &test{ 190 t: t, 191 compress: tc.compress, 192 clientStatsHandler: ch, 193 serverStatsHandler: sh, 194 } 195 return te 196} 197 198// startServer starts a gRPC server listening. Callers should defer a 199// call to te.tearDown to clean up. 200func (te *test) startServer(ts testpb.TestServiceServer) { 201 te.testServer = ts 202 lis, err := net.Listen("tcp", "localhost:0") 203 if err != nil { 204 te.t.Fatalf("Failed to listen: %v", err) 205 } 206 var opts []grpc.ServerOption 207 if te.compress == "gzip" { 208 opts = append(opts, 209 grpc.RPCCompressor(grpc.NewGZIPCompressor()), 210 grpc.RPCDecompressor(grpc.NewGZIPDecompressor()), 211 ) 212 } 213 if te.serverStatsHandler != nil { 214 opts = append(opts, grpc.StatsHandler(te.serverStatsHandler)) 215 } 216 s := grpc.NewServer(opts...) 217 te.srv = s 218 if te.testServer != nil { 219 testpb.RegisterTestServiceServer(s, te.testServer) 220 } 221 222 go s.Serve(lis) 223 te.srvAddr = lis.Addr().String() 224} 225 226func (te *test) clientConn() *grpc.ClientConn { 227 if te.cc != nil { 228 return te.cc 229 } 230 opts := []grpc.DialOption{grpc.WithInsecure(), grpc.WithBlock()} 231 if te.compress == "gzip" { 232 opts = append(opts, 233 grpc.WithCompressor(grpc.NewGZIPCompressor()), 234 grpc.WithDecompressor(grpc.NewGZIPDecompressor()), 235 ) 236 } 237 if te.clientStatsHandler != nil { 238 opts = append(opts, grpc.WithStatsHandler(te.clientStatsHandler)) 239 } 240 241 var err error 242 te.cc, err = grpc.Dial(te.srvAddr, opts...) 243 if err != nil { 244 te.t.Fatalf("Dial(%q) = %v", te.srvAddr, err) 245 } 246 return te.cc 247} 248 249type rpcType int 250 251const ( 252 unaryRPC rpcType = iota 253 clientStreamRPC 254 serverStreamRPC 255 fullDuplexStreamRPC 256) 257 258type rpcConfig struct { 259 count int // Number of requests and responses for streaming RPCs. 260 success bool // Whether the RPC should succeed or return error. 
261 failfast bool 262 callType rpcType // Type of RPC. 263} 264 265func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) { 266 var ( 267 resp *testpb.SimpleResponse 268 req *testpb.SimpleRequest 269 err error 270 ) 271 tc := testpb.NewTestServiceClient(te.clientConn()) 272 if c.success { 273 req = &testpb.SimpleRequest{Id: errorID + 1} 274 } else { 275 req = &testpb.SimpleRequest{Id: errorID} 276 } 277 ctx := metadata.NewOutgoingContext(context.Background(), testMetadata) 278 279 resp, err = tc.UnaryCall(ctx, req, grpc.WaitForReady(!c.failfast)) 280 return req, resp, err 281} 282 283func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest, []*testpb.SimpleResponse, error) { 284 var ( 285 reqs []*testpb.SimpleRequest 286 resps []*testpb.SimpleResponse 287 err error 288 ) 289 tc := testpb.NewTestServiceClient(te.clientConn()) 290 stream, err := tc.FullDuplexCall(metadata.NewOutgoingContext(context.Background(), testMetadata), grpc.WaitForReady(!c.failfast)) 291 if err != nil { 292 return reqs, resps, err 293 } 294 var startID int32 295 if !c.success { 296 startID = errorID 297 } 298 for i := 0; i < c.count; i++ { 299 req := &testpb.SimpleRequest{ 300 Id: int32(i) + startID, 301 } 302 reqs = append(reqs, req) 303 if err = stream.Send(req); err != nil { 304 return reqs, resps, err 305 } 306 var resp *testpb.SimpleResponse 307 if resp, err = stream.Recv(); err != nil { 308 return reqs, resps, err 309 } 310 resps = append(resps, resp) 311 } 312 if err = stream.CloseSend(); err != nil && err != io.EOF { 313 return reqs, resps, err 314 } 315 if _, err = stream.Recv(); err != io.EOF { 316 return reqs, resps, err 317 } 318 319 return reqs, resps, nil 320} 321 322func (te *test) doClientStreamCall(c *rpcConfig) ([]*testpb.SimpleRequest, *testpb.SimpleResponse, error) { 323 var ( 324 reqs []*testpb.SimpleRequest 325 resp *testpb.SimpleResponse 326 err error 327 ) 328 tc := 
testpb.NewTestServiceClient(te.clientConn()) 329 stream, err := tc.ClientStreamCall(metadata.NewOutgoingContext(context.Background(), testMetadata), grpc.WaitForReady(!c.failfast)) 330 if err != nil { 331 return reqs, resp, err 332 } 333 var startID int32 334 if !c.success { 335 startID = errorID 336 } 337 for i := 0; i < c.count; i++ { 338 req := &testpb.SimpleRequest{ 339 Id: int32(i) + startID, 340 } 341 reqs = append(reqs, req) 342 if err = stream.Send(req); err != nil { 343 return reqs, resp, err 344 } 345 } 346 resp, err = stream.CloseAndRecv() 347 return reqs, resp, err 348} 349 350func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.SimpleRequest, []*testpb.SimpleResponse, error) { 351 var ( 352 req *testpb.SimpleRequest 353 resps []*testpb.SimpleResponse 354 err error 355 ) 356 357 tc := testpb.NewTestServiceClient(te.clientConn()) 358 359 var startID int32 360 if !c.success { 361 startID = errorID 362 } 363 req = &testpb.SimpleRequest{Id: startID} 364 stream, err := tc.ServerStreamCall(metadata.NewOutgoingContext(context.Background(), testMetadata), req, grpc.WaitForReady(!c.failfast)) 365 if err != nil { 366 return req, resps, err 367 } 368 for { 369 var resp *testpb.SimpleResponse 370 resp, err := stream.Recv() 371 if err == io.EOF { 372 return req, resps, nil 373 } else if err != nil { 374 return req, resps, err 375 } 376 resps = append(resps, resp) 377 } 378} 379 380type expectedData struct { 381 method string 382 serverAddr string 383 compression string 384 reqIdx int 385 requests []*testpb.SimpleRequest 386 respIdx int 387 responses []*testpb.SimpleResponse 388 err error 389 failfast bool 390} 391 392type gotData struct { 393 ctx context.Context 394 client bool 395 s interface{} // This could be RPCStats or ConnStats. 396} 397 398const ( 399 begin int = iota 400 end 401 inPayload 402 inHeader 403 inTrailer 404 outPayload 405 outHeader 406 // TODO: test outTrailer ? 
407 connBegin 408 connEnd 409) 410 411func checkBegin(t *testing.T, d *gotData, e *expectedData) { 412 var ( 413 ok bool 414 st *stats.Begin 415 ) 416 if st, ok = d.s.(*stats.Begin); !ok { 417 t.Fatalf("got %T, want Begin", d.s) 418 } 419 if d.ctx == nil { 420 t.Fatalf("d.ctx = nil, want <non-nil>") 421 } 422 if st.BeginTime.IsZero() { 423 t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime) 424 } 425 if d.client { 426 if st.FailFast != e.failfast { 427 t.Fatalf("st.FailFast = %v, want %v", st.FailFast, e.failfast) 428 } 429 } 430} 431 432func checkInHeader(t *testing.T, d *gotData, e *expectedData) { 433 var ( 434 ok bool 435 st *stats.InHeader 436 ) 437 if st, ok = d.s.(*stats.InHeader); !ok { 438 t.Fatalf("got %T, want InHeader", d.s) 439 } 440 if d.ctx == nil { 441 t.Fatalf("d.ctx = nil, want <non-nil>") 442 } 443 if !d.client { 444 if st.FullMethod != e.method { 445 t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method) 446 } 447 if st.LocalAddr.String() != e.serverAddr { 448 t.Fatalf("st.LocalAddr = %v, want %v", st.LocalAddr, e.serverAddr) 449 } 450 if st.Compression != e.compression { 451 t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression) 452 } 453 454 if connInfo, ok := d.ctx.Value(connCtxKey{}).(*stats.ConnTagInfo); ok { 455 if connInfo.RemoteAddr != st.RemoteAddr { 456 t.Fatalf("connInfo.RemoteAddr = %v, want %v", connInfo.RemoteAddr, st.RemoteAddr) 457 } 458 if connInfo.LocalAddr != st.LocalAddr { 459 t.Fatalf("connInfo.LocalAddr = %v, want %v", connInfo.LocalAddr, st.LocalAddr) 460 } 461 } else { 462 t.Fatalf("got context %v, want one with connCtxKey", d.ctx) 463 } 464 if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok { 465 if rpcInfo.FullMethodName != st.FullMethod { 466 t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod) 467 } 468 } else { 469 t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx) 470 } 471 } 472} 473 474func checkInPayload(t *testing.T, d 
*gotData, e *expectedData) { 475 var ( 476 ok bool 477 st *stats.InPayload 478 ) 479 if st, ok = d.s.(*stats.InPayload); !ok { 480 t.Fatalf("got %T, want InPayload", d.s) 481 } 482 if d.ctx == nil { 483 t.Fatalf("d.ctx = nil, want <non-nil>") 484 } 485 if d.client { 486 b, err := proto.Marshal(e.responses[e.respIdx]) 487 if err != nil { 488 t.Fatalf("failed to marshal message: %v", err) 489 } 490 if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) { 491 t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx]) 492 } 493 e.respIdx++ 494 if string(st.Data) != string(b) { 495 t.Fatalf("st.Data = %v, want %v", st.Data, b) 496 } 497 if st.Length != len(b) { 498 t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b)) 499 } 500 } else { 501 b, err := proto.Marshal(e.requests[e.reqIdx]) 502 if err != nil { 503 t.Fatalf("failed to marshal message: %v", err) 504 } 505 if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) { 506 t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx]) 507 } 508 e.reqIdx++ 509 if string(st.Data) != string(b) { 510 t.Fatalf("st.Data = %v, want %v", st.Data, b) 511 } 512 if st.Length != len(b) { 513 t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b)) 514 } 515 } 516 // Below are sanity checks that WireLength and RecvTime are populated. 517 // TODO: check values of WireLength and RecvTime. 
518 if len(st.Data) > 0 && st.WireLength == 0 { 519 t.Fatalf("st.WireLength = %v with non-empty data, want <non-zero>", 520 st.WireLength) 521 } 522 if st.RecvTime.IsZero() { 523 t.Fatalf("st.ReceivedTime = %v, want <non-zero>", st.RecvTime) 524 } 525} 526 527func checkInTrailer(t *testing.T, d *gotData, e *expectedData) { 528 var ( 529 ok bool 530 ) 531 if _, ok = d.s.(*stats.InTrailer); !ok { 532 t.Fatalf("got %T, want InTrailer", d.s) 533 } 534 if d.ctx == nil { 535 t.Fatalf("d.ctx = nil, want <non-nil>") 536 } 537} 538 539func checkOutHeader(t *testing.T, d *gotData, e *expectedData) { 540 var ( 541 ok bool 542 st *stats.OutHeader 543 ) 544 if st, ok = d.s.(*stats.OutHeader); !ok { 545 t.Fatalf("got %T, want OutHeader", d.s) 546 } 547 if d.ctx == nil { 548 t.Fatalf("d.ctx = nil, want <non-nil>") 549 } 550 if d.client { 551 if st.FullMethod != e.method { 552 t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method) 553 } 554 if st.RemoteAddr.String() != e.serverAddr { 555 t.Fatalf("st.RemoteAddr = %v, want %v", st.RemoteAddr, e.serverAddr) 556 } 557 if st.Compression != e.compression { 558 t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression) 559 } 560 561 if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok { 562 if rpcInfo.FullMethodName != st.FullMethod { 563 t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod) 564 } 565 } else { 566 t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx) 567 } 568 } 569} 570 571func checkOutPayload(t *testing.T, d *gotData, e *expectedData) { 572 var ( 573 ok bool 574 st *stats.OutPayload 575 ) 576 if st, ok = d.s.(*stats.OutPayload); !ok { 577 t.Fatalf("got %T, want OutPayload", d.s) 578 } 579 if d.ctx == nil { 580 t.Fatalf("d.ctx = nil, want <non-nil>") 581 } 582 if d.client { 583 b, err := proto.Marshal(e.requests[e.reqIdx]) 584 if err != nil { 585 t.Fatalf("failed to marshal message: %v", err) 586 } 587 if reflect.TypeOf(st.Payload) != 
reflect.TypeOf(e.requests[e.reqIdx]) { 588 t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx]) 589 } 590 e.reqIdx++ 591 if string(st.Data) != string(b) { 592 t.Fatalf("st.Data = %v, want %v", st.Data, b) 593 } 594 if st.Length != len(b) { 595 t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b)) 596 } 597 } else { 598 b, err := proto.Marshal(e.responses[e.respIdx]) 599 if err != nil { 600 t.Fatalf("failed to marshal message: %v", err) 601 } 602 if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) { 603 t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx]) 604 } 605 e.respIdx++ 606 if string(st.Data) != string(b) { 607 t.Fatalf("st.Data = %v, want %v", st.Data, b) 608 } 609 if st.Length != len(b) { 610 t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b)) 611 } 612 } 613 // Below are sanity checks that WireLength and SentTime are populated. 614 // TODO: check values of WireLength and SentTime. 615 if len(st.Data) > 0 && st.WireLength == 0 { 616 t.Fatalf("st.WireLength = %v with non-empty data, want <non-zero>", 617 st.WireLength) 618 } 619 if st.SentTime.IsZero() { 620 t.Fatalf("st.SentTime = %v, want <non-zero>", st.SentTime) 621 } 622} 623 624func checkOutTrailer(t *testing.T, d *gotData, e *expectedData) { 625 var ( 626 ok bool 627 st *stats.OutTrailer 628 ) 629 if st, ok = d.s.(*stats.OutTrailer); !ok { 630 t.Fatalf("got %T, want OutTrailer", d.s) 631 } 632 if d.ctx == nil { 633 t.Fatalf("d.ctx = nil, want <non-nil>") 634 } 635 if st.Client { 636 t.Fatalf("st IsClient = true, want false") 637 } 638} 639 640func checkEnd(t *testing.T, d *gotData, e *expectedData) { 641 var ( 642 ok bool 643 st *stats.End 644 ) 645 if st, ok = d.s.(*stats.End); !ok { 646 t.Fatalf("got %T, want End", d.s) 647 } 648 if d.ctx == nil { 649 t.Fatalf("d.ctx = nil, want <non-nil>") 650 } 651 if st.BeginTime.IsZero() { 652 t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime) 653 } 654 if st.EndTime.IsZero() { 655 
t.Fatalf("st.EndTime = %v, want <non-zero>", st.EndTime) 656 } 657 658 actual, ok := status.FromError(st.Error) 659 if !ok { 660 t.Fatalf("expected st.Error to be a statusError, got %v (type %T)", st.Error, st.Error) 661 } 662 663 expectedStatus, _ := status.FromError(e.err) 664 if actual.Code() != expectedStatus.Code() || actual.Message() != expectedStatus.Message() { 665 t.Fatalf("st.Error = %v, want %v", st.Error, e.err) 666 } 667 668 if st.Client { 669 if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) { 670 t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata) 671 } 672 } else { 673 if st.Trailer != nil { 674 t.Fatalf("st.Trailer = %v, want nil", st.Trailer) 675 } 676 } 677} 678 679func checkConnBegin(t *testing.T, d *gotData, e *expectedData) { 680 var ( 681 ok bool 682 st *stats.ConnBegin 683 ) 684 if st, ok = d.s.(*stats.ConnBegin); !ok { 685 t.Fatalf("got %T, want ConnBegin", d.s) 686 } 687 if d.ctx == nil { 688 t.Fatalf("d.ctx = nil, want <non-nil>") 689 } 690 st.IsClient() // TODO remove this. 691} 692 693func checkConnEnd(t *testing.T, d *gotData, e *expectedData) { 694 var ( 695 ok bool 696 st *stats.ConnEnd 697 ) 698 if st, ok = d.s.(*stats.ConnEnd); !ok { 699 t.Fatalf("got %T, want ConnEnd", d.s) 700 } 701 if d.ctx == nil { 702 t.Fatalf("d.ctx = nil, want <non-nil>") 703 } 704 st.IsClient() // TODO remove this. 
705} 706 707type statshandler struct { 708 mu sync.Mutex 709 gotRPC []*gotData 710 gotConn []*gotData 711} 712 713func (h *statshandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context { 714 return context.WithValue(ctx, connCtxKey{}, info) 715} 716 717func (h *statshandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context { 718 return context.WithValue(ctx, rpcCtxKey{}, info) 719} 720 721func (h *statshandler) HandleConn(ctx context.Context, s stats.ConnStats) { 722 h.mu.Lock() 723 defer h.mu.Unlock() 724 h.gotConn = append(h.gotConn, &gotData{ctx, s.IsClient(), s}) 725} 726 727func (h *statshandler) HandleRPC(ctx context.Context, s stats.RPCStats) { 728 h.mu.Lock() 729 defer h.mu.Unlock() 730 h.gotRPC = append(h.gotRPC, &gotData{ctx, s.IsClient(), s}) 731} 732 733func checkConnStats(t *testing.T, got []*gotData) { 734 if len(got) <= 0 || len(got)%2 != 0 { 735 for i, g := range got { 736 t.Errorf(" - %v, %T = %+v, ctx: %v", i, g.s, g.s, g.ctx) 737 } 738 t.Fatalf("got %v stats, want even positive number", len(got)) 739 } 740 // The first conn stats must be a ConnBegin. 741 checkConnBegin(t, got[0], nil) 742 // The last conn stats must be a ConnEnd. 
743 checkConnEnd(t, got[len(got)-1], nil) 744} 745 746func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) { 747 if len(got) != len(checkFuncs) { 748 for i, g := range got { 749 t.Errorf(" - %v, %T", i, g.s) 750 } 751 t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs)) 752 } 753 754 var rpcctx context.Context 755 for i := 0; i < len(got); i++ { 756 if _, ok := got[i].s.(stats.RPCStats); ok { 757 if rpcctx != nil && got[i].ctx != rpcctx { 758 t.Fatalf("got different contexts with stats %T", got[i].s) 759 } 760 rpcctx = got[i].ctx 761 } 762 } 763 764 for i, f := range checkFuncs { 765 f(t, got[i], expect) 766 } 767} 768 769func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) { 770 h := &statshandler{} 771 te := newTest(t, tc, nil, h) 772 te.startServer(&testServer{}) 773 defer te.tearDown() 774 775 var ( 776 reqs []*testpb.SimpleRequest 777 resps []*testpb.SimpleResponse 778 err error 779 method string 780 781 req *testpb.SimpleRequest 782 resp *testpb.SimpleResponse 783 e error 784 ) 785 786 switch cc.callType { 787 case unaryRPC: 788 method = "/grpc.testing.TestService/UnaryCall" 789 req, resp, e = te.doUnaryCall(cc) 790 reqs = []*testpb.SimpleRequest{req} 791 resps = []*testpb.SimpleResponse{resp} 792 err = e 793 case clientStreamRPC: 794 method = "/grpc.testing.TestService/ClientStreamCall" 795 reqs, resp, e = te.doClientStreamCall(cc) 796 resps = []*testpb.SimpleResponse{resp} 797 err = e 798 case serverStreamRPC: 799 method = "/grpc.testing.TestService/ServerStreamCall" 800 req, resps, e = te.doServerStreamCall(cc) 801 reqs = []*testpb.SimpleRequest{req} 802 err = e 803 case fullDuplexStreamRPC: 804 method = "/grpc.testing.TestService/FullDuplexCall" 805 reqs, resps, err = te.doFullDuplexCallRoundtrip(cc) 806 } 807 if cc.success != (err == nil) { 808 t.Fatalf("cc.success: %v, got error: %v", 
cc.success, err) 809 } 810 te.cc.Close() 811 te.srv.GracefulStop() // Wait for the server to stop. 812 813 for { 814 h.mu.Lock() 815 if len(h.gotRPC) >= len(checkFuncs) { 816 h.mu.Unlock() 817 break 818 } 819 h.mu.Unlock() 820 time.Sleep(10 * time.Millisecond) 821 } 822 823 for { 824 h.mu.Lock() 825 if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok { 826 h.mu.Unlock() 827 break 828 } 829 h.mu.Unlock() 830 time.Sleep(10 * time.Millisecond) 831 } 832 833 expect := &expectedData{ 834 serverAddr: te.srvAddr, 835 compression: tc.compress, 836 method: method, 837 requests: reqs, 838 responses: resps, 839 err: err, 840 } 841 842 h.mu.Lock() 843 checkConnStats(t, h.gotConn) 844 h.mu.Unlock() 845 checkServerStats(t, h.gotRPC, expect, checkFuncs) 846} 847 848func TestServerStatsUnaryRPC(t *testing.T) { 849 testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){ 850 checkInHeader, 851 checkBegin, 852 checkInPayload, 853 checkOutHeader, 854 checkOutPayload, 855 checkOutTrailer, 856 checkEnd, 857 }) 858} 859 860func TestServerStatsUnaryRPCError(t *testing.T) { 861 testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){ 862 checkInHeader, 863 checkBegin, 864 checkInPayload, 865 checkOutHeader, 866 checkOutTrailer, 867 checkEnd, 868 }) 869} 870 871func TestServerStatsClientStreamRPC(t *testing.T) { 872 count := 5 873 checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ 874 checkInHeader, 875 checkBegin, 876 checkOutHeader, 877 } 878 ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){ 879 checkInPayload, 880 } 881 for i := 0; i < count; i++ { 882 checkFuncs = append(checkFuncs, ioPayFuncs...) 
883 } 884 checkFuncs = append(checkFuncs, 885 checkOutPayload, 886 checkOutTrailer, 887 checkEnd, 888 ) 889 testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: clientStreamRPC}, checkFuncs) 890} 891 892func TestServerStatsClientStreamRPCError(t *testing.T) { 893 count := 1 894 testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: clientStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){ 895 checkInHeader, 896 checkBegin, 897 checkOutHeader, 898 checkInPayload, 899 checkOutTrailer, 900 checkEnd, 901 }) 902} 903 904func TestServerStatsServerStreamRPC(t *testing.T) { 905 count := 5 906 checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ 907 checkInHeader, 908 checkBegin, 909 checkInPayload, 910 checkOutHeader, 911 } 912 ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){ 913 checkOutPayload, 914 } 915 for i := 0; i < count; i++ { 916 checkFuncs = append(checkFuncs, ioPayFuncs...) 
917 } 918 checkFuncs = append(checkFuncs, 919 checkOutTrailer, 920 checkEnd, 921 ) 922 testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: serverStreamRPC}, checkFuncs) 923} 924 925func TestServerStatsServerStreamRPCError(t *testing.T) { 926 count := 5 927 testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: serverStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){ 928 checkInHeader, 929 checkBegin, 930 checkInPayload, 931 checkOutHeader, 932 checkOutTrailer, 933 checkEnd, 934 }) 935} 936 937func TestServerStatsFullDuplexRPC(t *testing.T) { 938 count := 5 939 checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){ 940 checkInHeader, 941 checkBegin, 942 checkOutHeader, 943 } 944 ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){ 945 checkInPayload, 946 checkOutPayload, 947 } 948 for i := 0; i < count; i++ { 949 checkFuncs = append(checkFuncs, ioPayFuncs...) 950 } 951 checkFuncs = append(checkFuncs, 952 checkOutTrailer, 953 checkEnd, 954 ) 955 testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: fullDuplexStreamRPC}, checkFuncs) 956} 957 958func TestServerStatsFullDuplexRPCError(t *testing.T) { 959 count := 5 960 testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: fullDuplexStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){ 961 checkInHeader, 962 checkBegin, 963 checkOutHeader, 964 checkInPayload, 965 checkOutTrailer, 966 checkEnd, 967 }) 968} 969 970type checkFuncWithCount struct { 971 f func(t *testing.T, d *gotData, e *expectedData) 972 c int // expected count 973} 974 975func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs map[int]*checkFuncWithCount) { 976 var expectLen int 977 for _, v := range checkFuncs { 978 expectLen += v.c 979 } 980 if len(got) != expectLen { 981 for i, g := range got { 982 
t.Errorf(" - %v, %T", i, g.s) 983 } 984 t.Fatalf("got %v stats, want %v stats", len(got), expectLen) 985 } 986 987 var tagInfoInCtx *stats.RPCTagInfo 988 for i := 0; i < len(got); i++ { 989 if _, ok := got[i].s.(stats.RPCStats); ok { 990 tagInfoInCtxNew, _ := got[i].ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo) 991 if tagInfoInCtx != nil && tagInfoInCtx != tagInfoInCtxNew { 992 t.Fatalf("got context containing different tagInfo with stats %T", got[i].s) 993 } 994 tagInfoInCtx = tagInfoInCtxNew 995 } 996 } 997 998 for _, s := range got { 999 switch s.s.(type) { 1000 case *stats.Begin: 1001 if checkFuncs[begin].c <= 0 { 1002 t.Fatalf("unexpected stats: %T", s.s) 1003 } 1004 checkFuncs[begin].f(t, s, expect) 1005 checkFuncs[begin].c-- 1006 case *stats.OutHeader: 1007 if checkFuncs[outHeader].c <= 0 { 1008 t.Fatalf("unexpected stats: %T", s.s) 1009 } 1010 checkFuncs[outHeader].f(t, s, expect) 1011 checkFuncs[outHeader].c-- 1012 case *stats.OutPayload: 1013 if checkFuncs[outPayload].c <= 0 { 1014 t.Fatalf("unexpected stats: %T", s.s) 1015 } 1016 checkFuncs[outPayload].f(t, s, expect) 1017 checkFuncs[outPayload].c-- 1018 case *stats.InHeader: 1019 if checkFuncs[inHeader].c <= 0 { 1020 t.Fatalf("unexpected stats: %T", s.s) 1021 } 1022 checkFuncs[inHeader].f(t, s, expect) 1023 checkFuncs[inHeader].c-- 1024 case *stats.InPayload: 1025 if checkFuncs[inPayload].c <= 0 { 1026 t.Fatalf("unexpected stats: %T", s.s) 1027 } 1028 checkFuncs[inPayload].f(t, s, expect) 1029 checkFuncs[inPayload].c-- 1030 case *stats.InTrailer: 1031 if checkFuncs[inTrailer].c <= 0 { 1032 t.Fatalf("unexpected stats: %T", s.s) 1033 } 1034 checkFuncs[inTrailer].f(t, s, expect) 1035 checkFuncs[inTrailer].c-- 1036 case *stats.End: 1037 if checkFuncs[end].c <= 0 { 1038 t.Fatalf("unexpected stats: %T", s.s) 1039 } 1040 checkFuncs[end].f(t, s, expect) 1041 checkFuncs[end].c-- 1042 case *stats.ConnBegin: 1043 if checkFuncs[connBegin].c <= 0 { 1044 t.Fatalf("unexpected stats: %T", s.s) 1045 } 1046 
checkFuncs[connBegin].f(t, s, expect)
			checkFuncs[connBegin].c--
		case *stats.ConnEnd:
			if checkFuncs[connEnd].c <= 0 {
				t.Fatalf("unexpected stats: %T", s.s)
			}
			checkFuncs[connEnd].f(t, s, expect)
			checkFuncs[connEnd].c--
		default:
			t.Fatalf("unexpected stats: %T", s.s)
		}
	}
}

// testClientStats runs one RPC of the kind selected by cc.callType against a
// freshly started test server. The stats handler h is installed through
// newTest (presumably as the client-side handler — confirm against newTest).
// It then validates the connection stats and the RPC stats that h recorded.
//
// checkFuncs maps each expected stats event kind to the check to run for it
// and the number of times that event is expected. The sum of those counts is
// the number of RPC stats events waited for before validation begins.
func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map[int]*checkFuncWithCount) {
	h := &statshandler{}
	te := newTest(t, tc, h, nil)
	te.startServer(&testServer{})
	defer te.tearDown()

	var (
		reqs   []*testpb.SimpleRequest
		resps  []*testpb.SimpleResponse
		method string
		err    error

		req  *testpb.SimpleRequest
		resp *testpb.SimpleResponse
	)
	// Each helper's error is assigned directly to err; single-message
	// requests/responses are wrapped into slices for expectedData.
	switch cc.callType {
	case unaryRPC:
		method = "/grpc.testing.TestService/UnaryCall"
		req, resp, err = te.doUnaryCall(cc)
		reqs = []*testpb.SimpleRequest{req}
		resps = []*testpb.SimpleResponse{resp}
	case clientStreamRPC:
		method = "/grpc.testing.TestService/ClientStreamCall"
		reqs, resp, err = te.doClientStreamCall(cc)
		resps = []*testpb.SimpleResponse{resp}
	case serverStreamRPC:
		method = "/grpc.testing.TestService/ServerStreamCall"
		req, resps, err = te.doServerStreamCall(cc)
		reqs = []*testpb.SimpleRequest{req}
	case fullDuplexStreamRPC:
		method = "/grpc.testing.TestService/FullDuplexCall"
		reqs, resps, err = te.doFullDuplexCallRoundtrip(cc)
	}
	if cc.success != (err == nil) {
		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
	}
	te.cc.Close()
	te.srv.GracefulStop() // Wait for the server to stop.

	lenRPCStats := 0
	for _, v := range checkFuncs {
		lenRPCStats += v.c
	}

	// h may still be receiving stats after the calls above return, so poll.
	// The wait is bounded so that a missing event fails the test with a clear
	// message instead of hanging it. done is always evaluated with h.mu held.
	deadline := time.Now().Add(10 * time.Second)
	waitFor := func(what string, done func() bool) {
		for {
			h.mu.Lock()
			ok := done()
			h.mu.Unlock()
			if ok {
				return
			}
			if time.Now().After(deadline) {
				t.Fatalf("timed out waiting for %s", what)
			}
			time.Sleep(10 * time.Millisecond)
		}
	}
	waitFor("RPC stats", func() bool { return len(h.gotRPC) >= lenRPCStats })
	waitFor("ConnEnd stats", func() bool {
		// Guard the empty case: indexing the last element of an empty
		// slice would panic instead of continuing to wait.
		n := len(h.gotConn)
		if n == 0 {
			return false
		}
		_, isEnd := h.gotConn[n-1].s.(*stats.ConnEnd)
		return isEnd
	})

	expect := &expectedData{
		serverAddr:  te.srvAddr,
		compression: tc.compress,
		method:      method,
		requests:    reqs,
		responses:   resps,
		failfast:    cc.failfast,
		err:         err,
	}

	// Read h.gotRPC under h.mu, consistent with every other access above.
	h.mu.Lock()
	checkConnStats(t, h.gotConn)
	gotRPC := h.gotRPC
	h.mu.Unlock()
	checkClientStats(t, gotRPC, expect, checkFuncs)
}

// TestClientStatsUnaryRPC checks the client stats of a successful unary RPC
// without compression.
func TestClientStatsUnaryRPC(t *testing.T) {
	testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{
		begin:      {checkBegin, 1},
		outHeader:  {checkOutHeader, 1},
		outPayload: {checkOutPayload, 1},
		inHeader:   {checkInHeader, 1},
		inPayload:  {checkInPayload, 1},
		inTrailer:  {checkInTrailer, 1},
		end:        {checkEnd, 1},
	})
}

// TestClientStatsUnaryRPCError checks the client stats of a unary RPC that is
// expected to fail. No inbound payload event is expected.
func TestClientStatsUnaryRPCError(t *testing.T) {
	testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{
		begin:      {checkBegin, 1},
		outHeader:  {checkOutHeader, 1},
		outPayload: {checkOutPayload, 1},
		inHeader:   {checkInHeader, 1},
		inTrailer:  {checkInTrailer, 1},
		end:        {checkEnd, 1},
	})
}

// TestClientStatsClientStreamRPC checks the client stats of a successful
// gzip-compressed client-streaming RPC, with count outbound payloads.
func TestClientStatsClientStreamRPC(t *testing.T) {
	count := 5
	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{
		begin:      {checkBegin, 1},
		outHeader:  {checkOutHeader, 1},
		inHeader:   {checkInHeader, 1},
		outPayload: {checkOutPayload, count},
		inTrailer:  {checkInTrailer, 1},
		inPayload:  {checkInPayload, 1},
		end:        {checkEnd, 1},
	})
}

// TestClientStatsClientStreamRPCError checks the client stats of a
// gzip-compressed client-streaming RPC that is expected to fail.
func TestClientStatsClientStreamRPCError(t *testing.T) {
	count := 1
	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{
		begin:      {checkBegin, 1},
		outHeader:  {checkOutHeader, 1},
		inHeader:   {checkInHeader, 1},
		outPayload: {checkOutPayload, 1},
		inTrailer:  {checkInTrailer, 1},
		end:        {checkEnd, 1},
	})
}

// TestClientStatsServerStreamRPC checks the client stats of a successful
// gzip-compressed server-streaming RPC, with count inbound payloads.
func TestClientStatsServerStreamRPC(t *testing.T) {
	count := 5
	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{
		begin:      {checkBegin, 1},
		outHeader:  {checkOutHeader, 1},
		outPayload: {checkOutPayload, 1},
		inHeader:   {checkInHeader, 1},
		inPayload:  {checkInPayload, count},
		inTrailer:  {checkInTrailer, 1},
		end:        {checkEnd, 1},
	})
}

// TestClientStatsServerStreamRPCError checks the client stats of a
// gzip-compressed server-streaming RPC that is expected to fail. No inbound
// payload event is expected.
func TestClientStatsServerStreamRPCError(t *testing.T) {
	count := 5
	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{
		begin:      {checkBegin, 1},
		outHeader:  {checkOutHeader, 1},
		outPayload: {checkOutPayload, 1},
		inHeader:   {checkInHeader, 1},
		inTrailer:  {checkInTrailer, 1},
		end:        {checkEnd, 1},
	})
}

// TestClientStatsFullDuplexRPC checks the client stats of a successful
// gzip-compressed bidirectional-streaming RPC, with count payloads in each
// direction.
func TestClientStatsFullDuplexRPC(t *testing.T) {
	count := 5
	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{
		begin:      {checkBegin, 1},
		outHeader:  {checkOutHeader, 1},
		outPayload: {checkOutPayload, count},
		inHeader:   {checkInHeader, 1},
		inPayload:  {checkInPayload, count},
		inTrailer:  {checkInTrailer, 1},
		end:        {checkEnd, 1},
	})
}

// TestClientStatsFullDuplexRPCError checks the client stats of a
// gzip-compressed bidirectional-streaming RPC that is expected to fail. Only
// one outbound payload and no inbound payload are expected.
func TestClientStatsFullDuplexRPCError(t *testing.T) {
	count := 5
	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{
		begin:      {checkBegin, 1},
		outHeader:  {checkOutHeader, 1},
		outPayload: {checkOutPayload, 1},
		inHeader:   {checkInHeader, 1},
		inTrailer:  {checkInTrailer, 1},
		end:        {checkEnd, 1},
	})
}

// TestTags verifies that stats.SetTags is visible only through OutgoingTags
// and stats.SetIncomingTags is visible only through Tags.
func TestTags(t *testing.T) {
	b := []byte{5, 2, 4, 3, 1}
	ctx := stats.SetTags(context.Background(), b)
	if tg := stats.OutgoingTags(ctx); !reflect.DeepEqual(tg, b) {
		t.Errorf("OutgoingTags(%v) = %v; want %v", ctx, tg, b)
	}
	if tg := stats.Tags(ctx); tg != nil {
		t.Errorf("Tags(%v) = %v; want nil", ctx, tg)
	}

	ctx = stats.SetIncomingTags(context.Background(), b)
	if tg := stats.Tags(ctx); !reflect.DeepEqual(tg, b) {
		t.Errorf("Tags(%v) = %v; want %v", ctx, tg, b)
	}
	if tg := stats.OutgoingTags(ctx); tg != nil {
		t.Errorf("OutgoingTags(%v) = %v; want nil", ctx, tg)
	}
}

// TestTrace verifies that stats.SetTrace is visible only through
// OutgoingTrace and stats.SetIncomingTrace is visible only through Trace.
func TestTrace(t *testing.T) {
	b := []byte{5, 2, 4, 3, 1}
	ctx := stats.SetTrace(context.Background(), b)
	if tr := stats.OutgoingTrace(ctx); !reflect.DeepEqual(tr, b) {
		t.Errorf("OutgoingTrace(%v) = %v; want %v", ctx, tr, b)
	}
	if tr := stats.Trace(ctx); tr != nil {
		t.Errorf("Trace(%v) = %v; want nil", ctx, tr)
	}

	ctx = stats.SetIncomingTrace(context.Background(), b)
	if tr := stats.Trace(ctx); !reflect.DeepEqual(tr, b) {
		t.Errorf("Trace(%v) = %v; want %v", ctx, tr, b)
	}
	if tr := stats.OutgoingTrace(ctx); tr != nil {
		t.Errorf("OutgoingTrace(%v) = %v; want nil", ctx, tr)
	}
}