1/*
2 *
3 * Copyright 2016 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package stats_test
20
21import (
22	"context"
23	"fmt"
24	"io"
25	"net"
26	"reflect"
27	"sync"
28	"testing"
29	"time"
30
31	"github.com/golang/protobuf/proto"
32	"google.golang.org/grpc"
33	"google.golang.org/grpc/metadata"
34	"google.golang.org/grpc/stats"
35	testpb "google.golang.org/grpc/stats/grpc_testing"
36	"google.golang.org/grpc/status"
37)
38
// init clears grpc.EnableTracing when this test package is initialized, so
// the tests below run with gRPC tracing turned off.
func init() {
	grpc.EnableTracing = false
}
42
// Context key types used by this file's stats handler. Each is a distinct
// empty struct type so the two keys can never collide with each other or
// with keys from other packages.
type (
	// connCtxKey is the key under which TagConn stores the *stats.ConnTagInfo.
	connCtxKey struct{}
	// rpcCtxKey is the key under which TagRPC stores the *stats.RPCTagInfo.
	rpcCtxKey struct{}
)
45
46var (
47	// For headers:
48	testMetadata = metadata.MD{
49		"key1": []string{"value1"},
50		"key2": []string{"value2"},
51	}
52	// For trailers:
53	testTrailerMetadata = metadata.MD{
54		"tkey1": []string{"trailerValue1"},
55		"tkey2": []string{"trailerValue2"},
56	}
57	// The id for which the service handler should return error.
58	errorID int32 = 32202
59)
60
// testServer is the stateless TestService implementation that the tests in
// this file register with the server under test.
type testServer struct{}
62
63func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
64	md, ok := metadata.FromIncomingContext(ctx)
65	if ok {
66		if err := grpc.SendHeader(ctx, md); err != nil {
67			return nil, status.Errorf(status.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", md, err)
68		}
69		if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil {
70			return nil, status.Errorf(status.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err)
71		}
72	}
73
74	if in.Id == errorID {
75		return nil, fmt.Errorf("got error id: %v", in.Id)
76	}
77
78	return &testpb.SimpleResponse{Id: in.Id}, nil
79}
80
81func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
82	md, ok := metadata.FromIncomingContext(stream.Context())
83	if ok {
84		if err := stream.SendHeader(md); err != nil {
85			return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
86		}
87		stream.SetTrailer(testTrailerMetadata)
88	}
89	for {
90		in, err := stream.Recv()
91		if err == io.EOF {
92			// read done.
93			return nil
94		}
95		if err != nil {
96			return err
97		}
98
99		if in.Id == errorID {
100			return fmt.Errorf("got error id: %v", in.Id)
101		}
102
103		if err := stream.Send(&testpb.SimpleResponse{Id: in.Id}); err != nil {
104			return err
105		}
106	}
107}
108
109func (s *testServer) ClientStreamCall(stream testpb.TestService_ClientStreamCallServer) error {
110	md, ok := metadata.FromIncomingContext(stream.Context())
111	if ok {
112		if err := stream.SendHeader(md); err != nil {
113			return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
114		}
115		stream.SetTrailer(testTrailerMetadata)
116	}
117	for {
118		in, err := stream.Recv()
119		if err == io.EOF {
120			// read done.
121			return stream.SendAndClose(&testpb.SimpleResponse{Id: int32(0)})
122		}
123		if err != nil {
124			return err
125		}
126
127		if in.Id == errorID {
128			return fmt.Errorf("got error id: %v", in.Id)
129		}
130	}
131}
132
133func (s *testServer) ServerStreamCall(in *testpb.SimpleRequest, stream testpb.TestService_ServerStreamCallServer) error {
134	md, ok := metadata.FromIncomingContext(stream.Context())
135	if ok {
136		if err := stream.SendHeader(md); err != nil {
137			return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
138		}
139		stream.SetTrailer(testTrailerMetadata)
140	}
141
142	if in.Id == errorID {
143		return fmt.Errorf("got error id: %v", in.Id)
144	}
145
146	for i := 0; i < 5; i++ {
147		if err := stream.Send(&testpb.SimpleResponse{Id: in.Id}); err != nil {
148			return err
149		}
150	}
151	return nil
152}
153
// test is an end-to-end test. It should be created with the newTest
// func, modified as needed, and then started with its startServer method.
// It should be cleaned up with the tearDown method.
type test struct {
	// t is used to report fatal setup failures (listen and dial errors).
	t                  *testing.T
	// compress selects message compression; the value "gzip" makes both
	// startServer and clientConn install gzip compressor/decompressor options.
	compress           string
	// clientStatsHandler, when non-nil, is installed on the client connection.
	clientStatsHandler stats.Handler
	// serverStatsHandler, when non-nil, is installed on the server.
	serverStatsHandler stats.Handler

	testServer testpb.TestServiceServer // nil means none
	// srv and srvAddr are set once startServer is called.
	srv     *grpc.Server
	srvAddr string

	cc *grpc.ClientConn // nil until requested via clientConn
}
170
171func (te *test) tearDown() {
172	if te.cc != nil {
173		te.cc.Close()
174		te.cc = nil
175	}
176	te.srv.Stop()
177}
178
// testConfig holds the per-test settings that newTest copies into a test.
type testConfig struct {
	// compress is the compression to use; "gzip" enables gzip on both the
	// server and the client.
	compress string
}
182
183// newTest returns a new test using the provided testing.T and
184// environment.  It is returned with default values. Tests should
185// modify it before calling its startServer and clientConn methods.
186func newTest(t *testing.T, tc *testConfig, ch stats.Handler, sh stats.Handler) *test {
187	te := &test{
188		t:                  t,
189		compress:           tc.compress,
190		clientStatsHandler: ch,
191		serverStatsHandler: sh,
192	}
193	return te
194}
195
196// startServer starts a gRPC server listening. Callers should defer a
197// call to te.tearDown to clean up.
198func (te *test) startServer(ts testpb.TestServiceServer) {
199	te.testServer = ts
200	lis, err := net.Listen("tcp", "localhost:0")
201	if err != nil {
202		te.t.Fatalf("Failed to listen: %v", err)
203	}
204	var opts []grpc.ServerOption
205	if te.compress == "gzip" {
206		opts = append(opts,
207			grpc.RPCCompressor(grpc.NewGZIPCompressor()),
208			grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
209		)
210	}
211	if te.serverStatsHandler != nil {
212		opts = append(opts, grpc.StatsHandler(te.serverStatsHandler))
213	}
214	s := grpc.NewServer(opts...)
215	te.srv = s
216	if te.testServer != nil {
217		testpb.RegisterTestServiceServer(s, te.testServer)
218	}
219
220	go s.Serve(lis)
221	te.srvAddr = lis.Addr().String()
222}
223
224func (te *test) clientConn() *grpc.ClientConn {
225	if te.cc != nil {
226		return te.cc
227	}
228	opts := []grpc.DialOption{grpc.WithInsecure(), grpc.WithBlock()}
229	if te.compress == "gzip" {
230		opts = append(opts,
231			grpc.WithCompressor(grpc.NewGZIPCompressor()),
232			grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
233		)
234	}
235	if te.clientStatsHandler != nil {
236		opts = append(opts, grpc.WithStatsHandler(te.clientStatsHandler))
237	}
238
239	var err error
240	te.cc, err = grpc.Dial(te.srvAddr, opts...)
241	if err != nil {
242		te.t.Fatalf("Dial(%q) = %v", te.srvAddr, err)
243	}
244	return te.cc
245}
246
// rpcType identifies which of the TestService RPCs a test should issue.
type rpcType int

// The supported RPC kinds. testServerStats and testClientStats switch on
// these to pick the call helper and the expected full method name.
const (
	unaryRPC rpcType = iota
	clientStreamRPC
	serverStreamRPC
	fullDuplexStreamRPC
)
255
// rpcConfig describes the single RPC a test issues.
type rpcConfig struct {
	count    int  // Number of requests and responses for streaming RPCs.
	success  bool // Whether the RPC should succeed or return error.
	// failfast is passed to the call negated, as grpc.WaitForReady(!failfast),
	// and is what checkBegin expects to see in the client's Begin stats.
	failfast bool
	callType rpcType // Type of RPC.
}
262
263func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
264	var (
265		resp *testpb.SimpleResponse
266		req  *testpb.SimpleRequest
267		err  error
268	)
269	tc := testpb.NewTestServiceClient(te.clientConn())
270	if c.success {
271		req = &testpb.SimpleRequest{Id: errorID + 1}
272	} else {
273		req = &testpb.SimpleRequest{Id: errorID}
274	}
275	ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
276
277	resp, err = tc.UnaryCall(ctx, req, grpc.WaitForReady(!c.failfast))
278	return req, resp, err
279}
280
281func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest, []*testpb.SimpleResponse, error) {
282	var (
283		reqs  []*testpb.SimpleRequest
284		resps []*testpb.SimpleResponse
285		err   error
286	)
287	tc := testpb.NewTestServiceClient(te.clientConn())
288	stream, err := tc.FullDuplexCall(metadata.NewOutgoingContext(context.Background(), testMetadata), grpc.WaitForReady(!c.failfast))
289	if err != nil {
290		return reqs, resps, err
291	}
292	var startID int32
293	if !c.success {
294		startID = errorID
295	}
296	for i := 0; i < c.count; i++ {
297		req := &testpb.SimpleRequest{
298			Id: int32(i) + startID,
299		}
300		reqs = append(reqs, req)
301		if err = stream.Send(req); err != nil {
302			return reqs, resps, err
303		}
304		var resp *testpb.SimpleResponse
305		if resp, err = stream.Recv(); err != nil {
306			return reqs, resps, err
307		}
308		resps = append(resps, resp)
309	}
310	if err = stream.CloseSend(); err != nil && err != io.EOF {
311		return reqs, resps, err
312	}
313	if _, err = stream.Recv(); err != io.EOF {
314		return reqs, resps, err
315	}
316
317	return reqs, resps, nil
318}
319
320func (te *test) doClientStreamCall(c *rpcConfig) ([]*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
321	var (
322		reqs []*testpb.SimpleRequest
323		resp *testpb.SimpleResponse
324		err  error
325	)
326	tc := testpb.NewTestServiceClient(te.clientConn())
327	stream, err := tc.ClientStreamCall(metadata.NewOutgoingContext(context.Background(), testMetadata), grpc.WaitForReady(!c.failfast))
328	if err != nil {
329		return reqs, resp, err
330	}
331	var startID int32
332	if !c.success {
333		startID = errorID
334	}
335	for i := 0; i < c.count; i++ {
336		req := &testpb.SimpleRequest{
337			Id: int32(i) + startID,
338		}
339		reqs = append(reqs, req)
340		if err = stream.Send(req); err != nil {
341			return reqs, resp, err
342		}
343	}
344	resp, err = stream.CloseAndRecv()
345	return reqs, resp, err
346}
347
348func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.SimpleRequest, []*testpb.SimpleResponse, error) {
349	var (
350		req   *testpb.SimpleRequest
351		resps []*testpb.SimpleResponse
352		err   error
353	)
354
355	tc := testpb.NewTestServiceClient(te.clientConn())
356
357	var startID int32
358	if !c.success {
359		startID = errorID
360	}
361	req = &testpb.SimpleRequest{Id: startID}
362	stream, err := tc.ServerStreamCall(metadata.NewOutgoingContext(context.Background(), testMetadata), req, grpc.WaitForReady(!c.failfast))
363	if err != nil {
364		return req, resps, err
365	}
366	for {
367		var resp *testpb.SimpleResponse
368		resp, err := stream.Recv()
369		if err == io.EOF {
370			return req, resps, nil
371		} else if err != nil {
372			return req, resps, err
373		}
374		resps = append(resps, resp)
375	}
376}
377
// expectedData is what the check functions compare recorded stats against.
type expectedData struct {
	// method is the full method name expected in header stats.
	method      string
	// serverAddr is the address the test server listens on.
	serverAddr  string
	// compression is the compression name expected in header stats.
	compression string
	// reqIdx indexes the next entry of requests to compare; the payload
	// checks advance it each time they consume a request.
	reqIdx      int
	requests    []*testpb.SimpleRequest
	// respIdx indexes the next entry of responses to compare; the payload
	// checks advance it each time they consume a response.
	respIdx     int
	responses   []*testpb.SimpleResponse
	// err is the error the RPC returned to the client; checkEnd compares its
	// status code and message with the error in the End stats.
	err         error
	// failfast is the value checkBegin expects in the client's Begin stats.
	failfast    bool
}
389
// gotData is one stats event as recorded by statshandler. It is built with
// positional literals, so the field order matters.
type gotData struct {
	// ctx is the context the handler was invoked with.
	ctx    context.Context
	// client is the event's IsClient() result at the time it was recorded.
	client bool
	s      interface{} // This could be RPCStats or ConnStats.
}
395
// Identifiers for each kind of stats event. They are the keys of the
// checkFuncs map that checkClientStats consumes.
const (
	begin int = iota
	end
	inPayload
	inHeader
	inTrailer
	outPayload
	outHeader
	// TODO: test outTrailer ?
	connbegin
	connend
)
408
409func checkBegin(t *testing.T, d *gotData, e *expectedData) {
410	var (
411		ok bool
412		st *stats.Begin
413	)
414	if st, ok = d.s.(*stats.Begin); !ok {
415		t.Fatalf("got %T, want Begin", d.s)
416	}
417	if d.ctx == nil {
418		t.Fatalf("d.ctx = nil, want <non-nil>")
419	}
420	if st.BeginTime.IsZero() {
421		t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime)
422	}
423	if d.client {
424		if st.FailFast != e.failfast {
425			t.Fatalf("st.FailFast = %v, want %v", st.FailFast, e.failfast)
426		}
427	}
428}
429
430func checkInHeader(t *testing.T, d *gotData, e *expectedData) {
431	var (
432		ok bool
433		st *stats.InHeader
434	)
435	if st, ok = d.s.(*stats.InHeader); !ok {
436		t.Fatalf("got %T, want InHeader", d.s)
437	}
438	if d.ctx == nil {
439		t.Fatalf("d.ctx = nil, want <non-nil>")
440	}
441	if !d.client {
442		if st.FullMethod != e.method {
443			t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
444		}
445		if st.LocalAddr.String() != e.serverAddr {
446			t.Fatalf("st.LocalAddr = %v, want %v", st.LocalAddr, e.serverAddr)
447		}
448		if st.Compression != e.compression {
449			t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
450		}
451
452		if connInfo, ok := d.ctx.Value(connCtxKey{}).(*stats.ConnTagInfo); ok {
453			if connInfo.RemoteAddr != st.RemoteAddr {
454				t.Fatalf("connInfo.RemoteAddr = %v, want %v", connInfo.RemoteAddr, st.RemoteAddr)
455			}
456			if connInfo.LocalAddr != st.LocalAddr {
457				t.Fatalf("connInfo.LocalAddr = %v, want %v", connInfo.LocalAddr, st.LocalAddr)
458			}
459		} else {
460			t.Fatalf("got context %v, want one with connCtxKey", d.ctx)
461		}
462		if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok {
463			if rpcInfo.FullMethodName != st.FullMethod {
464				t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod)
465			}
466		} else {
467			t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx)
468		}
469	}
470}
471
472func checkInPayload(t *testing.T, d *gotData, e *expectedData) {
473	var (
474		ok bool
475		st *stats.InPayload
476	)
477	if st, ok = d.s.(*stats.InPayload); !ok {
478		t.Fatalf("got %T, want InPayload", d.s)
479	}
480	if d.ctx == nil {
481		t.Fatalf("d.ctx = nil, want <non-nil>")
482	}
483	if d.client {
484		b, err := proto.Marshal(e.responses[e.respIdx])
485		if err != nil {
486			t.Fatalf("failed to marshal message: %v", err)
487		}
488		if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) {
489			t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx])
490		}
491		e.respIdx++
492		if string(st.Data) != string(b) {
493			t.Fatalf("st.Data = %v, want %v", st.Data, b)
494		}
495		if st.Length != len(b) {
496			t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
497		}
498	} else {
499		b, err := proto.Marshal(e.requests[e.reqIdx])
500		if err != nil {
501			t.Fatalf("failed to marshal message: %v", err)
502		}
503		if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) {
504			t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx])
505		}
506		e.reqIdx++
507		if string(st.Data) != string(b) {
508			t.Fatalf("st.Data = %v, want %v", st.Data, b)
509		}
510		if st.Length != len(b) {
511			t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
512		}
513	}
514	// TODO check WireLength and ReceivedTime.
515	if st.RecvTime.IsZero() {
516		t.Fatalf("st.ReceivedTime = %v, want <non-zero>", st.RecvTime)
517	}
518}
519
520func checkInTrailer(t *testing.T, d *gotData, e *expectedData) {
521	var (
522		ok bool
523	)
524	if _, ok = d.s.(*stats.InTrailer); !ok {
525		t.Fatalf("got %T, want InTrailer", d.s)
526	}
527	if d.ctx == nil {
528		t.Fatalf("d.ctx = nil, want <non-nil>")
529	}
530}
531
532func checkOutHeader(t *testing.T, d *gotData, e *expectedData) {
533	var (
534		ok bool
535		st *stats.OutHeader
536	)
537	if st, ok = d.s.(*stats.OutHeader); !ok {
538		t.Fatalf("got %T, want OutHeader", d.s)
539	}
540	if d.ctx == nil {
541		t.Fatalf("d.ctx = nil, want <non-nil>")
542	}
543	if d.client {
544		if st.FullMethod != e.method {
545			t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
546		}
547		if st.RemoteAddr.String() != e.serverAddr {
548			t.Fatalf("st.RemoteAddr = %v, want %v", st.RemoteAddr, e.serverAddr)
549		}
550		if st.Compression != e.compression {
551			t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
552		}
553
554		if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok {
555			if rpcInfo.FullMethodName != st.FullMethod {
556				t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod)
557			}
558		} else {
559			t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx)
560		}
561	}
562}
563
564func checkOutPayload(t *testing.T, d *gotData, e *expectedData) {
565	var (
566		ok bool
567		st *stats.OutPayload
568	)
569	if st, ok = d.s.(*stats.OutPayload); !ok {
570		t.Fatalf("got %T, want OutPayload", d.s)
571	}
572	if d.ctx == nil {
573		t.Fatalf("d.ctx = nil, want <non-nil>")
574	}
575	if d.client {
576		b, err := proto.Marshal(e.requests[e.reqIdx])
577		if err != nil {
578			t.Fatalf("failed to marshal message: %v", err)
579		}
580		if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) {
581			t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx])
582		}
583		e.reqIdx++
584		if string(st.Data) != string(b) {
585			t.Fatalf("st.Data = %v, want %v", st.Data, b)
586		}
587		if st.Length != len(b) {
588			t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
589		}
590	} else {
591		b, err := proto.Marshal(e.responses[e.respIdx])
592		if err != nil {
593			t.Fatalf("failed to marshal message: %v", err)
594		}
595		if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) {
596			t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx])
597		}
598		e.respIdx++
599		if string(st.Data) != string(b) {
600			t.Fatalf("st.Data = %v, want %v", st.Data, b)
601		}
602		if st.Length != len(b) {
603			t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
604		}
605	}
606	// TODO check WireLength and ReceivedTime.
607	if st.SentTime.IsZero() {
608		t.Fatalf("st.SentTime = %v, want <non-zero>", st.SentTime)
609	}
610}
611
612func checkOutTrailer(t *testing.T, d *gotData, e *expectedData) {
613	var (
614		ok bool
615		st *stats.OutTrailer
616	)
617	if st, ok = d.s.(*stats.OutTrailer); !ok {
618		t.Fatalf("got %T, want OutTrailer", d.s)
619	}
620	if d.ctx == nil {
621		t.Fatalf("d.ctx = nil, want <non-nil>")
622	}
623	if st.Client {
624		t.Fatalf("st IsClient = true, want false")
625	}
626}
627
628func checkEnd(t *testing.T, d *gotData, e *expectedData) {
629	var (
630		ok bool
631		st *stats.End
632	)
633	if st, ok = d.s.(*stats.End); !ok {
634		t.Fatalf("got %T, want End", d.s)
635	}
636	if d.ctx == nil {
637		t.Fatalf("d.ctx = nil, want <non-nil>")
638	}
639	if st.BeginTime.IsZero() {
640		t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime)
641	}
642	if st.EndTime.IsZero() {
643		t.Fatalf("st.EndTime = %v, want <non-zero>", st.EndTime)
644	}
645
646	actual, ok := status.FromError(st.Error)
647	if !ok {
648		t.Fatalf("expected st.Error to be a statusError, got %v (type %T)", st.Error, st.Error)
649	}
650
651	expectedStatus, _ := status.FromError(e.err)
652	if actual.Code() != expectedStatus.Code() || actual.Message() != expectedStatus.Message() {
653		t.Fatalf("st.Error = %v, want %v", st.Error, e.err)
654	}
655
656	if st.Client {
657		if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) {
658			t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata)
659		}
660	} else {
661		if st.Trailer != nil {
662			t.Fatalf("st.Trailer = %v, want nil", st.Trailer)
663		}
664	}
665}
666
667func checkConnBegin(t *testing.T, d *gotData, e *expectedData) {
668	var (
669		ok bool
670		st *stats.ConnBegin
671	)
672	if st, ok = d.s.(*stats.ConnBegin); !ok {
673		t.Fatalf("got %T, want ConnBegin", d.s)
674	}
675	if d.ctx == nil {
676		t.Fatalf("d.ctx = nil, want <non-nil>")
677	}
678	st.IsClient() // TODO remove this.
679}
680
681func checkConnEnd(t *testing.T, d *gotData, e *expectedData) {
682	var (
683		ok bool
684		st *stats.ConnEnd
685	)
686	if st, ok = d.s.(*stats.ConnEnd); !ok {
687		t.Fatalf("got %T, want ConnEnd", d.s)
688	}
689	if d.ctx == nil {
690		t.Fatalf("d.ctx = nil, want <non-nil>")
691	}
692	st.IsClient() // TODO remove this.
693}
694
695type statshandler struct {
696	mu      sync.Mutex
697	gotRPC  []*gotData
698	gotConn []*gotData
699}
700
701func (h *statshandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
702	return context.WithValue(ctx, connCtxKey{}, info)
703}
704
705func (h *statshandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
706	return context.WithValue(ctx, rpcCtxKey{}, info)
707}
708
709func (h *statshandler) HandleConn(ctx context.Context, s stats.ConnStats) {
710	h.mu.Lock()
711	defer h.mu.Unlock()
712	h.gotConn = append(h.gotConn, &gotData{ctx, s.IsClient(), s})
713}
714
715func (h *statshandler) HandleRPC(ctx context.Context, s stats.RPCStats) {
716	h.mu.Lock()
717	defer h.mu.Unlock()
718	h.gotRPC = append(h.gotRPC, &gotData{ctx, s.IsClient(), s})
719}
720
721func checkConnStats(t *testing.T, got []*gotData) {
722	if len(got) <= 0 || len(got)%2 != 0 {
723		for i, g := range got {
724			t.Errorf(" - %v, %T = %+v, ctx: %v", i, g.s, g.s, g.ctx)
725		}
726		t.Fatalf("got %v stats, want even positive number", len(got))
727	}
728	// The first conn stats must be a ConnBegin.
729	checkConnBegin(t, got[0], nil)
730	// The last conn stats must be a ConnEnd.
731	checkConnEnd(t, got[len(got)-1], nil)
732}
733
734func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
735	if len(got) != len(checkFuncs) {
736		for i, g := range got {
737			t.Errorf(" - %v, %T", i, g.s)
738		}
739		t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs))
740	}
741
742	var rpcctx context.Context
743	for i := 0; i < len(got); i++ {
744		if _, ok := got[i].s.(stats.RPCStats); ok {
745			if rpcctx != nil && got[i].ctx != rpcctx {
746				t.Fatalf("got different contexts with stats %T", got[i].s)
747			}
748			rpcctx = got[i].ctx
749		}
750	}
751
752	for i, f := range checkFuncs {
753		f(t, got[i], expect)
754	}
755}
756
757func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
758	h := &statshandler{}
759	te := newTest(t, tc, nil, h)
760	te.startServer(&testServer{})
761	defer te.tearDown()
762
763	var (
764		reqs   []*testpb.SimpleRequest
765		resps  []*testpb.SimpleResponse
766		err    error
767		method string
768
769		req  *testpb.SimpleRequest
770		resp *testpb.SimpleResponse
771		e    error
772	)
773
774	switch cc.callType {
775	case unaryRPC:
776		method = "/grpc.testing.TestService/UnaryCall"
777		req, resp, e = te.doUnaryCall(cc)
778		reqs = []*testpb.SimpleRequest{req}
779		resps = []*testpb.SimpleResponse{resp}
780		err = e
781	case clientStreamRPC:
782		method = "/grpc.testing.TestService/ClientStreamCall"
783		reqs, resp, e = te.doClientStreamCall(cc)
784		resps = []*testpb.SimpleResponse{resp}
785		err = e
786	case serverStreamRPC:
787		method = "/grpc.testing.TestService/ServerStreamCall"
788		req, resps, e = te.doServerStreamCall(cc)
789		reqs = []*testpb.SimpleRequest{req}
790		err = e
791	case fullDuplexStreamRPC:
792		method = "/grpc.testing.TestService/FullDuplexCall"
793		reqs, resps, err = te.doFullDuplexCallRoundtrip(cc)
794	}
795	if cc.success != (err == nil) {
796		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
797	}
798	te.cc.Close()
799	te.srv.GracefulStop() // Wait for the server to stop.
800
801	for {
802		h.mu.Lock()
803		if len(h.gotRPC) >= len(checkFuncs) {
804			h.mu.Unlock()
805			break
806		}
807		h.mu.Unlock()
808		time.Sleep(10 * time.Millisecond)
809	}
810
811	for {
812		h.mu.Lock()
813		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
814			h.mu.Unlock()
815			break
816		}
817		h.mu.Unlock()
818		time.Sleep(10 * time.Millisecond)
819	}
820
821	expect := &expectedData{
822		serverAddr:  te.srvAddr,
823		compression: tc.compress,
824		method:      method,
825		requests:    reqs,
826		responses:   resps,
827		err:         err,
828	}
829
830	h.mu.Lock()
831	checkConnStats(t, h.gotConn)
832	h.mu.Unlock()
833	checkServerStats(t, h.gotRPC, expect, checkFuncs)
834}
835
836func TestServerStatsUnaryRPC(t *testing.T) {
837	testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){
838		checkInHeader,
839		checkBegin,
840		checkInPayload,
841		checkOutHeader,
842		checkOutPayload,
843		checkOutTrailer,
844		checkEnd,
845	})
846}
847
848func TestServerStatsUnaryRPCError(t *testing.T) {
849	testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){
850		checkInHeader,
851		checkBegin,
852		checkInPayload,
853		checkOutHeader,
854		checkOutTrailer,
855		checkEnd,
856	})
857}
858
859func TestServerStatsClientStreamRPC(t *testing.T) {
860	count := 5
861	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
862		checkInHeader,
863		checkBegin,
864		checkOutHeader,
865	}
866	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
867		checkInPayload,
868	}
869	for i := 0; i < count; i++ {
870		checkFuncs = append(checkFuncs, ioPayFuncs...)
871	}
872	checkFuncs = append(checkFuncs,
873		checkOutPayload,
874		checkOutTrailer,
875		checkEnd,
876	)
877	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: clientStreamRPC}, checkFuncs)
878}
879
880func TestServerStatsClientStreamRPCError(t *testing.T) {
881	count := 1
882	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: clientStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
883		checkInHeader,
884		checkBegin,
885		checkOutHeader,
886		checkInPayload,
887		checkOutTrailer,
888		checkEnd,
889	})
890}
891
892func TestServerStatsServerStreamRPC(t *testing.T) {
893	count := 5
894	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
895		checkInHeader,
896		checkBegin,
897		checkInPayload,
898		checkOutHeader,
899	}
900	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
901		checkOutPayload,
902	}
903	for i := 0; i < count; i++ {
904		checkFuncs = append(checkFuncs, ioPayFuncs...)
905	}
906	checkFuncs = append(checkFuncs,
907		checkOutTrailer,
908		checkEnd,
909	)
910	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: serverStreamRPC}, checkFuncs)
911}
912
913func TestServerStatsServerStreamRPCError(t *testing.T) {
914	count := 5
915	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: serverStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
916		checkInHeader,
917		checkBegin,
918		checkInPayload,
919		checkOutHeader,
920		checkOutTrailer,
921		checkEnd,
922	})
923}
924
925func TestServerStatsFullDuplexRPC(t *testing.T) {
926	count := 5
927	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
928		checkInHeader,
929		checkBegin,
930		checkOutHeader,
931	}
932	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
933		checkInPayload,
934		checkOutPayload,
935	}
936	for i := 0; i < count; i++ {
937		checkFuncs = append(checkFuncs, ioPayFuncs...)
938	}
939	checkFuncs = append(checkFuncs,
940		checkOutTrailer,
941		checkEnd,
942	)
943	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: fullDuplexStreamRPC}, checkFuncs)
944}
945
946func TestServerStatsFullDuplexRPCError(t *testing.T) {
947	count := 5
948	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: fullDuplexStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
949		checkInHeader,
950		checkBegin,
951		checkOutHeader,
952		checkInPayload,
953		checkOutTrailer,
954		checkEnd,
955	})
956}
957
// checkFuncWithCount pairs a check function with the number of stats events
// of that kind a client test expects. It is built with positional literals,
// so the field order matters. checkClientStats decrements c each time it
// runs f on a matching event.
type checkFuncWithCount struct {
	f func(t *testing.T, d *gotData, e *expectedData)
	c int // expected count
}
962
963func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs map[int]*checkFuncWithCount) {
964	var expectLen int
965	for _, v := range checkFuncs {
966		expectLen += v.c
967	}
968	if len(got) != expectLen {
969		for i, g := range got {
970			t.Errorf(" - %v, %T", i, g.s)
971		}
972		t.Fatalf("got %v stats, want %v stats", len(got), expectLen)
973	}
974
975	var tagInfoInCtx *stats.RPCTagInfo
976	for i := 0; i < len(got); i++ {
977		if _, ok := got[i].s.(stats.RPCStats); ok {
978			tagInfoInCtxNew, _ := got[i].ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo)
979			if tagInfoInCtx != nil && tagInfoInCtx != tagInfoInCtxNew {
980				t.Fatalf("got context containing different tagInfo with stats %T", got[i].s)
981			}
982			tagInfoInCtx = tagInfoInCtxNew
983		}
984	}
985
986	for _, s := range got {
987		switch s.s.(type) {
988		case *stats.Begin:
989			if checkFuncs[begin].c <= 0 {
990				t.Fatalf("unexpected stats: %T", s.s)
991			}
992			checkFuncs[begin].f(t, s, expect)
993			checkFuncs[begin].c--
994		case *stats.OutHeader:
995			if checkFuncs[outHeader].c <= 0 {
996				t.Fatalf("unexpected stats: %T", s.s)
997			}
998			checkFuncs[outHeader].f(t, s, expect)
999			checkFuncs[outHeader].c--
1000		case *stats.OutPayload:
1001			if checkFuncs[outPayload].c <= 0 {
1002				t.Fatalf("unexpected stats: %T", s.s)
1003			}
1004			checkFuncs[outPayload].f(t, s, expect)
1005			checkFuncs[outPayload].c--
1006		case *stats.InHeader:
1007			if checkFuncs[inHeader].c <= 0 {
1008				t.Fatalf("unexpected stats: %T", s.s)
1009			}
1010			checkFuncs[inHeader].f(t, s, expect)
1011			checkFuncs[inHeader].c--
1012		case *stats.InPayload:
1013			if checkFuncs[inPayload].c <= 0 {
1014				t.Fatalf("unexpected stats: %T", s.s)
1015			}
1016			checkFuncs[inPayload].f(t, s, expect)
1017			checkFuncs[inPayload].c--
1018		case *stats.InTrailer:
1019			if checkFuncs[inTrailer].c <= 0 {
1020				t.Fatalf("unexpected stats: %T", s.s)
1021			}
1022			checkFuncs[inTrailer].f(t, s, expect)
1023			checkFuncs[inTrailer].c--
1024		case *stats.End:
1025			if checkFuncs[end].c <= 0 {
1026				t.Fatalf("unexpected stats: %T", s.s)
1027			}
1028			checkFuncs[end].f(t, s, expect)
1029			checkFuncs[end].c--
1030		case *stats.ConnBegin:
1031			if checkFuncs[connbegin].c <= 0 {
1032				t.Fatalf("unexpected stats: %T", s.s)
1033			}
1034			checkFuncs[connbegin].f(t, s, expect)
1035			checkFuncs[connbegin].c--
1036		case *stats.ConnEnd:
1037			if checkFuncs[connend].c <= 0 {
1038				t.Fatalf("unexpected stats: %T", s.s)
1039			}
1040			checkFuncs[connend].f(t, s, expect)
1041			checkFuncs[connend].c--
1042		default:
1043			t.Fatalf("unexpected stats: %T", s.s)
1044		}
1045	}
1046}
1047
1048func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map[int]*checkFuncWithCount) {
1049	h := &statshandler{}
1050	te := newTest(t, tc, h, nil)
1051	te.startServer(&testServer{})
1052	defer te.tearDown()
1053
1054	var (
1055		reqs   []*testpb.SimpleRequest
1056		resps  []*testpb.SimpleResponse
1057		method string
1058		err    error
1059
1060		req  *testpb.SimpleRequest
1061		resp *testpb.SimpleResponse
1062		e    error
1063	)
1064	switch cc.callType {
1065	case unaryRPC:
1066		method = "/grpc.testing.TestService/UnaryCall"
1067		req, resp, e = te.doUnaryCall(cc)
1068		reqs = []*testpb.SimpleRequest{req}
1069		resps = []*testpb.SimpleResponse{resp}
1070		err = e
1071	case clientStreamRPC:
1072		method = "/grpc.testing.TestService/ClientStreamCall"
1073		reqs, resp, e = te.doClientStreamCall(cc)
1074		resps = []*testpb.SimpleResponse{resp}
1075		err = e
1076	case serverStreamRPC:
1077		method = "/grpc.testing.TestService/ServerStreamCall"
1078		req, resps, e = te.doServerStreamCall(cc)
1079		reqs = []*testpb.SimpleRequest{req}
1080		err = e
1081	case fullDuplexStreamRPC:
1082		method = "/grpc.testing.TestService/FullDuplexCall"
1083		reqs, resps, err = te.doFullDuplexCallRoundtrip(cc)
1084	}
1085	if cc.success != (err == nil) {
1086		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
1087	}
1088	te.cc.Close()
1089	te.srv.GracefulStop() // Wait for the server to stop.
1090
1091	lenRPCStats := 0
1092	for _, v := range checkFuncs {
1093		lenRPCStats += v.c
1094	}
1095	for {
1096		h.mu.Lock()
1097		if len(h.gotRPC) >= lenRPCStats {
1098			h.mu.Unlock()
1099			break
1100		}
1101		h.mu.Unlock()
1102		time.Sleep(10 * time.Millisecond)
1103	}
1104
1105	for {
1106		h.mu.Lock()
1107		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
1108			h.mu.Unlock()
1109			break
1110		}
1111		h.mu.Unlock()
1112		time.Sleep(10 * time.Millisecond)
1113	}
1114
1115	expect := &expectedData{
1116		serverAddr:  te.srvAddr,
1117		compression: tc.compress,
1118		method:      method,
1119		requests:    reqs,
1120		responses:   resps,
1121		failfast:    cc.failfast,
1122		err:         err,
1123	}
1124
1125	h.mu.Lock()
1126	checkConnStats(t, h.gotConn)
1127	h.mu.Unlock()
1128	checkClientStats(t, h.gotRPC, expect, checkFuncs)
1129}
1130
1131func TestClientStatsUnaryRPC(t *testing.T) {
1132	testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{
1133		begin:      {checkBegin, 1},
1134		outHeader:  {checkOutHeader, 1},
1135		outPayload: {checkOutPayload, 1},
1136		inHeader:   {checkInHeader, 1},
1137		inPayload:  {checkInPayload, 1},
1138		inTrailer:  {checkInTrailer, 1},
1139		end:        {checkEnd, 1},
1140	})
1141}
1142
1143func TestClientStatsUnaryRPCError(t *testing.T) {
1144	testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{
1145		begin:      {checkBegin, 1},
1146		outHeader:  {checkOutHeader, 1},
1147		outPayload: {checkOutPayload, 1},
1148		inHeader:   {checkInHeader, 1},
1149		inTrailer:  {checkInTrailer, 1},
1150		end:        {checkEnd, 1},
1151	})
1152}
1153
1154func TestClientStatsClientStreamRPC(t *testing.T) {
1155	count := 5
1156	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{
1157		begin:      {checkBegin, 1},
1158		outHeader:  {checkOutHeader, 1},
1159		inHeader:   {checkInHeader, 1},
1160		outPayload: {checkOutPayload, count},
1161		inTrailer:  {checkInTrailer, 1},
1162		inPayload:  {checkInPayload, 1},
1163		end:        {checkEnd, 1},
1164	})
1165}
1166
1167func TestClientStatsClientStreamRPCError(t *testing.T) {
1168	count := 1
1169	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{
1170		begin:      {checkBegin, 1},
1171		outHeader:  {checkOutHeader, 1},
1172		inHeader:   {checkInHeader, 1},
1173		outPayload: {checkOutPayload, 1},
1174		inTrailer:  {checkInTrailer, 1},
1175		end:        {checkEnd, 1},
1176	})
1177}
1178
1179func TestClientStatsServerStreamRPC(t *testing.T) {
1180	count := 5
1181	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{
1182		begin:      {checkBegin, 1},
1183		outHeader:  {checkOutHeader, 1},
1184		outPayload: {checkOutPayload, 1},
1185		inHeader:   {checkInHeader, 1},
1186		inPayload:  {checkInPayload, count},
1187		inTrailer:  {checkInTrailer, 1},
1188		end:        {checkEnd, 1},
1189	})
1190}
1191
1192func TestClientStatsServerStreamRPCError(t *testing.T) {
1193	count := 5
1194	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{
1195		begin:      {checkBegin, 1},
1196		outHeader:  {checkOutHeader, 1},
1197		outPayload: {checkOutPayload, 1},
1198		inHeader:   {checkInHeader, 1},
1199		inTrailer:  {checkInTrailer, 1},
1200		end:        {checkEnd, 1},
1201	})
1202}
1203
1204func TestClientStatsFullDuplexRPC(t *testing.T) {
1205	count := 5
1206	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{
1207		begin:      {checkBegin, 1},
1208		outHeader:  {checkOutHeader, 1},
1209		outPayload: {checkOutPayload, count},
1210		inHeader:   {checkInHeader, 1},
1211		inPayload:  {checkInPayload, count},
1212		inTrailer:  {checkInTrailer, 1},
1213		end:        {checkEnd, 1},
1214	})
1215}
1216
1217func TestClientStatsFullDuplexRPCError(t *testing.T) {
1218	count := 5
1219	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{
1220		begin:      {checkBegin, 1},
1221		outHeader:  {checkOutHeader, 1},
1222		outPayload: {checkOutPayload, 1},
1223		inHeader:   {checkInHeader, 1},
1224		inTrailer:  {checkInTrailer, 1},
1225		end:        {checkEnd, 1},
1226	})
1227}
1228
1229func TestTags(t *testing.T) {
1230	b := []byte{5, 2, 4, 3, 1}
1231	ctx := stats.SetTags(context.Background(), b)
1232	if tg := stats.OutgoingTags(ctx); !reflect.DeepEqual(tg, b) {
1233		t.Errorf("OutgoingTags(%v) = %v; want %v", ctx, tg, b)
1234	}
1235	if tg := stats.Tags(ctx); tg != nil {
1236		t.Errorf("Tags(%v) = %v; want nil", ctx, tg)
1237	}
1238
1239	ctx = stats.SetIncomingTags(context.Background(), b)
1240	if tg := stats.Tags(ctx); !reflect.DeepEqual(tg, b) {
1241		t.Errorf("Tags(%v) = %v; want %v", ctx, tg, b)
1242	}
1243	if tg := stats.OutgoingTags(ctx); tg != nil {
1244		t.Errorf("OutgoingTags(%v) = %v; want nil", ctx, tg)
1245	}
1246}
1247
1248func TestTrace(t *testing.T) {
1249	b := []byte{5, 2, 4, 3, 1}
1250	ctx := stats.SetTrace(context.Background(), b)
1251	if tr := stats.OutgoingTrace(ctx); !reflect.DeepEqual(tr, b) {
1252		t.Errorf("OutgoingTrace(%v) = %v; want %v", ctx, tr, b)
1253	}
1254	if tr := stats.Trace(ctx); tr != nil {
1255		t.Errorf("Trace(%v) = %v; want nil", ctx, tr)
1256	}
1257
1258	ctx = stats.SetIncomingTrace(context.Background(), b)
1259	if tr := stats.Trace(ctx); !reflect.DeepEqual(tr, b) {
1260		t.Errorf("Trace(%v) = %v; want %v", ctx, tr, b)
1261	}
1262	if tr := stats.OutgoingTrace(ctx); tr != nil {
1263		t.Errorf("OutgoingTrace(%v) = %v; want nil", ctx, tr)
1264	}
1265}
1266