1/*
2 *
3 * Copyright 2016 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package stats_test
20
21import (
22	"context"
23	"fmt"
24	"io"
25	"net"
26	"reflect"
27	"sync"
28	"testing"
29	"time"
30
31	"github.com/golang/protobuf/proto"
32	"google.golang.org/grpc"
33	"google.golang.org/grpc/metadata"
34	"google.golang.org/grpc/stats"
35	testpb "google.golang.org/grpc/stats/grpc_testing"
36	"google.golang.org/grpc/status"
37)
38
39func init() {
40	grpc.EnableTracing = false
41}
42
// Context keys used by statshandler: TagConn stores the *stats.ConnTagInfo
// under connCtxKey{} and TagRPC stores the *stats.RPCTagInfo under rpcCtxKey{}.
type (
	connCtxKey struct{}
	rpcCtxKey  struct{}
)
45
46var (
47	// For headers:
48	testMetadata = metadata.MD{
49		"key1": []string{"value1"},
50		"key2": []string{"value2"},
51	}
52	// For trailers:
53	testTrailerMetadata = metadata.MD{
54		"tkey1": []string{"trailerValue1"},
55		"tkey2": []string{"trailerValue2"},
56	}
57	// The id for which the service handler should return error.
58	errorID int32 = 32202
59)
60
61type testServer struct {
62	testpb.UnimplementedTestServiceServer
63}
64
65func (s *testServer) UnaryCall(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
66	md, ok := metadata.FromIncomingContext(ctx)
67	if ok {
68		if err := grpc.SendHeader(ctx, md); err != nil {
69			return nil, status.Errorf(status.Code(err), "grpc.SendHeader(_, %v) = %v, want <nil>", md, err)
70		}
71		if err := grpc.SetTrailer(ctx, testTrailerMetadata); err != nil {
72			return nil, status.Errorf(status.Code(err), "grpc.SetTrailer(_, %v) = %v, want <nil>", testTrailerMetadata, err)
73		}
74	}
75
76	if in.Id == errorID {
77		return nil, fmt.Errorf("got error id: %v", in.Id)
78	}
79
80	return &testpb.SimpleResponse{Id: in.Id}, nil
81}
82
83func (s *testServer) FullDuplexCall(stream testpb.TestService_FullDuplexCallServer) error {
84	md, ok := metadata.FromIncomingContext(stream.Context())
85	if ok {
86		if err := stream.SendHeader(md); err != nil {
87			return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
88		}
89		stream.SetTrailer(testTrailerMetadata)
90	}
91	for {
92		in, err := stream.Recv()
93		if err == io.EOF {
94			// read done.
95			return nil
96		}
97		if err != nil {
98			return err
99		}
100
101		if in.Id == errorID {
102			return fmt.Errorf("got error id: %v", in.Id)
103		}
104
105		if err := stream.Send(&testpb.SimpleResponse{Id: in.Id}); err != nil {
106			return err
107		}
108	}
109}
110
111func (s *testServer) ClientStreamCall(stream testpb.TestService_ClientStreamCallServer) error {
112	md, ok := metadata.FromIncomingContext(stream.Context())
113	if ok {
114		if err := stream.SendHeader(md); err != nil {
115			return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
116		}
117		stream.SetTrailer(testTrailerMetadata)
118	}
119	for {
120		in, err := stream.Recv()
121		if err == io.EOF {
122			// read done.
123			return stream.SendAndClose(&testpb.SimpleResponse{Id: int32(0)})
124		}
125		if err != nil {
126			return err
127		}
128
129		if in.Id == errorID {
130			return fmt.Errorf("got error id: %v", in.Id)
131		}
132	}
133}
134
135func (s *testServer) ServerStreamCall(in *testpb.SimpleRequest, stream testpb.TestService_ServerStreamCallServer) error {
136	md, ok := metadata.FromIncomingContext(stream.Context())
137	if ok {
138		if err := stream.SendHeader(md); err != nil {
139			return status.Errorf(status.Code(err), "%v.SendHeader(%v) = %v, want %v", stream, md, err, nil)
140		}
141		stream.SetTrailer(testTrailerMetadata)
142	}
143
144	if in.Id == errorID {
145		return fmt.Errorf("got error id: %v", in.Id)
146	}
147
148	for i := 0; i < 5; i++ {
149		if err := stream.Send(&testpb.SimpleResponse{Id: in.Id}); err != nil {
150			return err
151		}
152	}
153	return nil
154}
155
156// test is an end-to-end test. It should be created with the newTest
157// func, modified as needed, and then started with its startServer method.
158// It should be cleaned up with the tearDown method.
159type test struct {
160	t                  *testing.T
161	compress           string
162	clientStatsHandler stats.Handler
163	serverStatsHandler stats.Handler
164
165	testServer testpb.TestServiceServer // nil means none
166	// srv and srvAddr are set once startServer is called.
167	srv     *grpc.Server
168	srvAddr string
169
170	cc *grpc.ClientConn // nil until requested via clientConn
171}
172
173func (te *test) tearDown() {
174	if te.cc != nil {
175		te.cc.Close()
176		te.cc = nil
177	}
178	te.srv.Stop()
179}
180
// testConfig holds environment settings shared by a test's client and server.
type testConfig struct {
	compress string // "gzip" turns on gzip compression; any other value leaves it off
}
184
185// newTest returns a new test using the provided testing.T and
186// environment.  It is returned with default values. Tests should
187// modify it before calling its startServer and clientConn methods.
188func newTest(t *testing.T, tc *testConfig, ch stats.Handler, sh stats.Handler) *test {
189	te := &test{
190		t:                  t,
191		compress:           tc.compress,
192		clientStatsHandler: ch,
193		serverStatsHandler: sh,
194	}
195	return te
196}
197
198// startServer starts a gRPC server listening. Callers should defer a
199// call to te.tearDown to clean up.
200func (te *test) startServer(ts testpb.TestServiceServer) {
201	te.testServer = ts
202	lis, err := net.Listen("tcp", "localhost:0")
203	if err != nil {
204		te.t.Fatalf("Failed to listen: %v", err)
205	}
206	var opts []grpc.ServerOption
207	if te.compress == "gzip" {
208		opts = append(opts,
209			grpc.RPCCompressor(grpc.NewGZIPCompressor()),
210			grpc.RPCDecompressor(grpc.NewGZIPDecompressor()),
211		)
212	}
213	if te.serverStatsHandler != nil {
214		opts = append(opts, grpc.StatsHandler(te.serverStatsHandler))
215	}
216	s := grpc.NewServer(opts...)
217	te.srv = s
218	if te.testServer != nil {
219		testpb.RegisterTestServiceServer(s, te.testServer)
220	}
221
222	go s.Serve(lis)
223	te.srvAddr = lis.Addr().String()
224}
225
226func (te *test) clientConn() *grpc.ClientConn {
227	if te.cc != nil {
228		return te.cc
229	}
230	opts := []grpc.DialOption{grpc.WithInsecure(), grpc.WithBlock()}
231	if te.compress == "gzip" {
232		opts = append(opts,
233			grpc.WithCompressor(grpc.NewGZIPCompressor()),
234			grpc.WithDecompressor(grpc.NewGZIPDecompressor()),
235		)
236	}
237	if te.clientStatsHandler != nil {
238		opts = append(opts, grpc.WithStatsHandler(te.clientStatsHandler))
239	}
240
241	var err error
242	te.cc, err = grpc.Dial(te.srvAddr, opts...)
243	if err != nil {
244		te.t.Fatalf("Dial(%q) = %v", te.srvAddr, err)
245	}
246	return te.cc
247}
248
// rpcType selects which TestService method a test exercises.
type rpcType int

// The supported RPC kinds, one per TestService method.
const (
	unaryRPC            rpcType = 0
	clientStreamRPC     rpcType = 1
	serverStreamRPC     rpcType = 2
	fullDuplexStreamRPC rpcType = 3
)

// rpcConfig describes the single RPC a test issues.
type rpcConfig struct {
	count    int     // number of requests and responses for streaming RPCs
	success  bool    // whether the RPC should succeed (false makes the handler return an error)
	failfast bool    // the RPC is issued with grpc.WaitForReady(!failfast)
	callType rpcType // which kind of RPC to issue
}
264
265func (te *test) doUnaryCall(c *rpcConfig) (*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
266	var (
267		resp *testpb.SimpleResponse
268		req  *testpb.SimpleRequest
269		err  error
270	)
271	tc := testpb.NewTestServiceClient(te.clientConn())
272	if c.success {
273		req = &testpb.SimpleRequest{Id: errorID + 1}
274	} else {
275		req = &testpb.SimpleRequest{Id: errorID}
276	}
277	ctx := metadata.NewOutgoingContext(context.Background(), testMetadata)
278
279	resp, err = tc.UnaryCall(ctx, req, grpc.WaitForReady(!c.failfast))
280	return req, resp, err
281}
282
283func (te *test) doFullDuplexCallRoundtrip(c *rpcConfig) ([]*testpb.SimpleRequest, []*testpb.SimpleResponse, error) {
284	var (
285		reqs  []*testpb.SimpleRequest
286		resps []*testpb.SimpleResponse
287		err   error
288	)
289	tc := testpb.NewTestServiceClient(te.clientConn())
290	stream, err := tc.FullDuplexCall(metadata.NewOutgoingContext(context.Background(), testMetadata), grpc.WaitForReady(!c.failfast))
291	if err != nil {
292		return reqs, resps, err
293	}
294	var startID int32
295	if !c.success {
296		startID = errorID
297	}
298	for i := 0; i < c.count; i++ {
299		req := &testpb.SimpleRequest{
300			Id: int32(i) + startID,
301		}
302		reqs = append(reqs, req)
303		if err = stream.Send(req); err != nil {
304			return reqs, resps, err
305		}
306		var resp *testpb.SimpleResponse
307		if resp, err = stream.Recv(); err != nil {
308			return reqs, resps, err
309		}
310		resps = append(resps, resp)
311	}
312	if err = stream.CloseSend(); err != nil && err != io.EOF {
313		return reqs, resps, err
314	}
315	if _, err = stream.Recv(); err != io.EOF {
316		return reqs, resps, err
317	}
318
319	return reqs, resps, nil
320}
321
322func (te *test) doClientStreamCall(c *rpcConfig) ([]*testpb.SimpleRequest, *testpb.SimpleResponse, error) {
323	var (
324		reqs []*testpb.SimpleRequest
325		resp *testpb.SimpleResponse
326		err  error
327	)
328	tc := testpb.NewTestServiceClient(te.clientConn())
329	stream, err := tc.ClientStreamCall(metadata.NewOutgoingContext(context.Background(), testMetadata), grpc.WaitForReady(!c.failfast))
330	if err != nil {
331		return reqs, resp, err
332	}
333	var startID int32
334	if !c.success {
335		startID = errorID
336	}
337	for i := 0; i < c.count; i++ {
338		req := &testpb.SimpleRequest{
339			Id: int32(i) + startID,
340		}
341		reqs = append(reqs, req)
342		if err = stream.Send(req); err != nil {
343			return reqs, resp, err
344		}
345	}
346	resp, err = stream.CloseAndRecv()
347	return reqs, resp, err
348}
349
350func (te *test) doServerStreamCall(c *rpcConfig) (*testpb.SimpleRequest, []*testpb.SimpleResponse, error) {
351	var (
352		req   *testpb.SimpleRequest
353		resps []*testpb.SimpleResponse
354		err   error
355	)
356
357	tc := testpb.NewTestServiceClient(te.clientConn())
358
359	var startID int32
360	if !c.success {
361		startID = errorID
362	}
363	req = &testpb.SimpleRequest{Id: startID}
364	stream, err := tc.ServerStreamCall(metadata.NewOutgoingContext(context.Background(), testMetadata), req, grpc.WaitForReady(!c.failfast))
365	if err != nil {
366		return req, resps, err
367	}
368	for {
369		var resp *testpb.SimpleResponse
370		resp, err := stream.Recv()
371		if err == io.EOF {
372			return req, resps, nil
373		} else if err != nil {
374			return req, resps, err
375		}
376		resps = append(resps, resp)
377	}
378}
379
380type expectedData struct {
381	method      string
382	serverAddr  string
383	compression string
384	reqIdx      int
385	requests    []*testpb.SimpleRequest
386	respIdx     int
387	responses   []*testpb.SimpleResponse
388	err         error
389	failfast    bool
390}
391
392type gotData struct {
393	ctx    context.Context
394	client bool
395	s      interface{} // This could be RPCStats or ConnStats.
396}
397
// Keys identifying each kind of stats event in the checkFuncWithCount maps
// used by the client-side tests.
const (
	begin      int = 0
	end        int = 1
	inPayload  int = 2
	inHeader   int = 3
	inTrailer  int = 4
	outPayload int = 5
	outHeader  int = 6
	// TODO: test outTrailer ?
	connBegin  int = 7
	connEnd    int = 8
)
410
411func checkBegin(t *testing.T, d *gotData, e *expectedData) {
412	var (
413		ok bool
414		st *stats.Begin
415	)
416	if st, ok = d.s.(*stats.Begin); !ok {
417		t.Fatalf("got %T, want Begin", d.s)
418	}
419	if d.ctx == nil {
420		t.Fatalf("d.ctx = nil, want <non-nil>")
421	}
422	if st.BeginTime.IsZero() {
423		t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime)
424	}
425	if d.client {
426		if st.FailFast != e.failfast {
427			t.Fatalf("st.FailFast = %v, want %v", st.FailFast, e.failfast)
428		}
429	}
430}
431
432func checkInHeader(t *testing.T, d *gotData, e *expectedData) {
433	var (
434		ok bool
435		st *stats.InHeader
436	)
437	if st, ok = d.s.(*stats.InHeader); !ok {
438		t.Fatalf("got %T, want InHeader", d.s)
439	}
440	if d.ctx == nil {
441		t.Fatalf("d.ctx = nil, want <non-nil>")
442	}
443	if !d.client {
444		if st.FullMethod != e.method {
445			t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
446		}
447		if st.LocalAddr.String() != e.serverAddr {
448			t.Fatalf("st.LocalAddr = %v, want %v", st.LocalAddr, e.serverAddr)
449		}
450		if st.Compression != e.compression {
451			t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
452		}
453
454		if connInfo, ok := d.ctx.Value(connCtxKey{}).(*stats.ConnTagInfo); ok {
455			if connInfo.RemoteAddr != st.RemoteAddr {
456				t.Fatalf("connInfo.RemoteAddr = %v, want %v", connInfo.RemoteAddr, st.RemoteAddr)
457			}
458			if connInfo.LocalAddr != st.LocalAddr {
459				t.Fatalf("connInfo.LocalAddr = %v, want %v", connInfo.LocalAddr, st.LocalAddr)
460			}
461		} else {
462			t.Fatalf("got context %v, want one with connCtxKey", d.ctx)
463		}
464		if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok {
465			if rpcInfo.FullMethodName != st.FullMethod {
466				t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod)
467			}
468		} else {
469			t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx)
470		}
471	}
472}
473
474func checkInPayload(t *testing.T, d *gotData, e *expectedData) {
475	var (
476		ok bool
477		st *stats.InPayload
478	)
479	if st, ok = d.s.(*stats.InPayload); !ok {
480		t.Fatalf("got %T, want InPayload", d.s)
481	}
482	if d.ctx == nil {
483		t.Fatalf("d.ctx = nil, want <non-nil>")
484	}
485	if d.client {
486		b, err := proto.Marshal(e.responses[e.respIdx])
487		if err != nil {
488			t.Fatalf("failed to marshal message: %v", err)
489		}
490		if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) {
491			t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx])
492		}
493		e.respIdx++
494		if string(st.Data) != string(b) {
495			t.Fatalf("st.Data = %v, want %v", st.Data, b)
496		}
497		if st.Length != len(b) {
498			t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
499		}
500	} else {
501		b, err := proto.Marshal(e.requests[e.reqIdx])
502		if err != nil {
503			t.Fatalf("failed to marshal message: %v", err)
504		}
505		if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) {
506			t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx])
507		}
508		e.reqIdx++
509		if string(st.Data) != string(b) {
510			t.Fatalf("st.Data = %v, want %v", st.Data, b)
511		}
512		if st.Length != len(b) {
513			t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
514		}
515	}
516	// Below are sanity checks that WireLength and RecvTime are populated.
517	// TODO: check values of WireLength and RecvTime.
518	if len(st.Data) > 0 && st.WireLength == 0 {
519		t.Fatalf("st.WireLength = %v with non-empty data, want <non-zero>",
520			st.WireLength)
521	}
522	if st.RecvTime.IsZero() {
523		t.Fatalf("st.ReceivedTime = %v, want <non-zero>", st.RecvTime)
524	}
525}
526
527func checkInTrailer(t *testing.T, d *gotData, e *expectedData) {
528	var (
529		ok bool
530	)
531	if _, ok = d.s.(*stats.InTrailer); !ok {
532		t.Fatalf("got %T, want InTrailer", d.s)
533	}
534	if d.ctx == nil {
535		t.Fatalf("d.ctx = nil, want <non-nil>")
536	}
537}
538
539func checkOutHeader(t *testing.T, d *gotData, e *expectedData) {
540	var (
541		ok bool
542		st *stats.OutHeader
543	)
544	if st, ok = d.s.(*stats.OutHeader); !ok {
545		t.Fatalf("got %T, want OutHeader", d.s)
546	}
547	if d.ctx == nil {
548		t.Fatalf("d.ctx = nil, want <non-nil>")
549	}
550	if d.client {
551		if st.FullMethod != e.method {
552			t.Fatalf("st.FullMethod = %s, want %v", st.FullMethod, e.method)
553		}
554		if st.RemoteAddr.String() != e.serverAddr {
555			t.Fatalf("st.RemoteAddr = %v, want %v", st.RemoteAddr, e.serverAddr)
556		}
557		if st.Compression != e.compression {
558			t.Fatalf("st.Compression = %v, want %v", st.Compression, e.compression)
559		}
560
561		if rpcInfo, ok := d.ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo); ok {
562			if rpcInfo.FullMethodName != st.FullMethod {
563				t.Fatalf("rpcInfo.FullMethod = %s, want %v", rpcInfo.FullMethodName, st.FullMethod)
564			}
565		} else {
566			t.Fatalf("got context %v, want one with rpcCtxKey", d.ctx)
567		}
568	}
569}
570
571func checkOutPayload(t *testing.T, d *gotData, e *expectedData) {
572	var (
573		ok bool
574		st *stats.OutPayload
575	)
576	if st, ok = d.s.(*stats.OutPayload); !ok {
577		t.Fatalf("got %T, want OutPayload", d.s)
578	}
579	if d.ctx == nil {
580		t.Fatalf("d.ctx = nil, want <non-nil>")
581	}
582	if d.client {
583		b, err := proto.Marshal(e.requests[e.reqIdx])
584		if err != nil {
585			t.Fatalf("failed to marshal message: %v", err)
586		}
587		if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.requests[e.reqIdx]) {
588			t.Fatalf("st.Payload = %T, want %T", st.Payload, e.requests[e.reqIdx])
589		}
590		e.reqIdx++
591		if string(st.Data) != string(b) {
592			t.Fatalf("st.Data = %v, want %v", st.Data, b)
593		}
594		if st.Length != len(b) {
595			t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
596		}
597	} else {
598		b, err := proto.Marshal(e.responses[e.respIdx])
599		if err != nil {
600			t.Fatalf("failed to marshal message: %v", err)
601		}
602		if reflect.TypeOf(st.Payload) != reflect.TypeOf(e.responses[e.respIdx]) {
603			t.Fatalf("st.Payload = %T, want %T", st.Payload, e.responses[e.respIdx])
604		}
605		e.respIdx++
606		if string(st.Data) != string(b) {
607			t.Fatalf("st.Data = %v, want %v", st.Data, b)
608		}
609		if st.Length != len(b) {
610			t.Fatalf("st.Lenght = %v, want %v", st.Length, len(b))
611		}
612	}
613	// Below are sanity checks that WireLength and SentTime are populated.
614	// TODO: check values of WireLength and SentTime.
615	if len(st.Data) > 0 && st.WireLength == 0 {
616		t.Fatalf("st.WireLength = %v with non-empty data, want <non-zero>",
617			st.WireLength)
618	}
619	if st.SentTime.IsZero() {
620		t.Fatalf("st.SentTime = %v, want <non-zero>", st.SentTime)
621	}
622}
623
624func checkOutTrailer(t *testing.T, d *gotData, e *expectedData) {
625	var (
626		ok bool
627		st *stats.OutTrailer
628	)
629	if st, ok = d.s.(*stats.OutTrailer); !ok {
630		t.Fatalf("got %T, want OutTrailer", d.s)
631	}
632	if d.ctx == nil {
633		t.Fatalf("d.ctx = nil, want <non-nil>")
634	}
635	if st.Client {
636		t.Fatalf("st IsClient = true, want false")
637	}
638}
639
640func checkEnd(t *testing.T, d *gotData, e *expectedData) {
641	var (
642		ok bool
643		st *stats.End
644	)
645	if st, ok = d.s.(*stats.End); !ok {
646		t.Fatalf("got %T, want End", d.s)
647	}
648	if d.ctx == nil {
649		t.Fatalf("d.ctx = nil, want <non-nil>")
650	}
651	if st.BeginTime.IsZero() {
652		t.Fatalf("st.BeginTime = %v, want <non-zero>", st.BeginTime)
653	}
654	if st.EndTime.IsZero() {
655		t.Fatalf("st.EndTime = %v, want <non-zero>", st.EndTime)
656	}
657
658	actual, ok := status.FromError(st.Error)
659	if !ok {
660		t.Fatalf("expected st.Error to be a statusError, got %v (type %T)", st.Error, st.Error)
661	}
662
663	expectedStatus, _ := status.FromError(e.err)
664	if actual.Code() != expectedStatus.Code() || actual.Message() != expectedStatus.Message() {
665		t.Fatalf("st.Error = %v, want %v", st.Error, e.err)
666	}
667
668	if st.Client {
669		if !reflect.DeepEqual(st.Trailer, testTrailerMetadata) {
670			t.Fatalf("st.Trailer = %v, want %v", st.Trailer, testTrailerMetadata)
671		}
672	} else {
673		if st.Trailer != nil {
674			t.Fatalf("st.Trailer = %v, want nil", st.Trailer)
675		}
676	}
677}
678
679func checkConnBegin(t *testing.T, d *gotData, e *expectedData) {
680	var (
681		ok bool
682		st *stats.ConnBegin
683	)
684	if st, ok = d.s.(*stats.ConnBegin); !ok {
685		t.Fatalf("got %T, want ConnBegin", d.s)
686	}
687	if d.ctx == nil {
688		t.Fatalf("d.ctx = nil, want <non-nil>")
689	}
690	st.IsClient() // TODO remove this.
691}
692
693func checkConnEnd(t *testing.T, d *gotData, e *expectedData) {
694	var (
695		ok bool
696		st *stats.ConnEnd
697	)
698	if st, ok = d.s.(*stats.ConnEnd); !ok {
699		t.Fatalf("got %T, want ConnEnd", d.s)
700	}
701	if d.ctx == nil {
702		t.Fatalf("d.ctx = nil, want <non-nil>")
703	}
704	st.IsClient() // TODO remove this.
705}
706
707type statshandler struct {
708	mu      sync.Mutex
709	gotRPC  []*gotData
710	gotConn []*gotData
711}
712
713func (h *statshandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context {
714	return context.WithValue(ctx, connCtxKey{}, info)
715}
716
717func (h *statshandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context {
718	return context.WithValue(ctx, rpcCtxKey{}, info)
719}
720
721func (h *statshandler) HandleConn(ctx context.Context, s stats.ConnStats) {
722	h.mu.Lock()
723	defer h.mu.Unlock()
724	h.gotConn = append(h.gotConn, &gotData{ctx, s.IsClient(), s})
725}
726
727func (h *statshandler) HandleRPC(ctx context.Context, s stats.RPCStats) {
728	h.mu.Lock()
729	defer h.mu.Unlock()
730	h.gotRPC = append(h.gotRPC, &gotData{ctx, s.IsClient(), s})
731}
732
733func checkConnStats(t *testing.T, got []*gotData) {
734	if len(got) <= 0 || len(got)%2 != 0 {
735		for i, g := range got {
736			t.Errorf(" - %v, %T = %+v, ctx: %v", i, g.s, g.s, g.ctx)
737		}
738		t.Fatalf("got %v stats, want even positive number", len(got))
739	}
740	// The first conn stats must be a ConnBegin.
741	checkConnBegin(t, got[0], nil)
742	// The last conn stats must be a ConnEnd.
743	checkConnEnd(t, got[len(got)-1], nil)
744}
745
746func checkServerStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
747	if len(got) != len(checkFuncs) {
748		for i, g := range got {
749			t.Errorf(" - %v, %T", i, g.s)
750		}
751		t.Fatalf("got %v stats, want %v stats", len(got), len(checkFuncs))
752	}
753
754	var rpcctx context.Context
755	for i := 0; i < len(got); i++ {
756		if _, ok := got[i].s.(stats.RPCStats); ok {
757			if rpcctx != nil && got[i].ctx != rpcctx {
758				t.Fatalf("got different contexts with stats %T", got[i].s)
759			}
760			rpcctx = got[i].ctx
761		}
762	}
763
764	for i, f := range checkFuncs {
765		f(t, got[i], expect)
766	}
767}
768
769func testServerStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs []func(t *testing.T, d *gotData, e *expectedData)) {
770	h := &statshandler{}
771	te := newTest(t, tc, nil, h)
772	te.startServer(&testServer{})
773	defer te.tearDown()
774
775	var (
776		reqs   []*testpb.SimpleRequest
777		resps  []*testpb.SimpleResponse
778		err    error
779		method string
780
781		req  *testpb.SimpleRequest
782		resp *testpb.SimpleResponse
783		e    error
784	)
785
786	switch cc.callType {
787	case unaryRPC:
788		method = "/grpc.testing.TestService/UnaryCall"
789		req, resp, e = te.doUnaryCall(cc)
790		reqs = []*testpb.SimpleRequest{req}
791		resps = []*testpb.SimpleResponse{resp}
792		err = e
793	case clientStreamRPC:
794		method = "/grpc.testing.TestService/ClientStreamCall"
795		reqs, resp, e = te.doClientStreamCall(cc)
796		resps = []*testpb.SimpleResponse{resp}
797		err = e
798	case serverStreamRPC:
799		method = "/grpc.testing.TestService/ServerStreamCall"
800		req, resps, e = te.doServerStreamCall(cc)
801		reqs = []*testpb.SimpleRequest{req}
802		err = e
803	case fullDuplexStreamRPC:
804		method = "/grpc.testing.TestService/FullDuplexCall"
805		reqs, resps, err = te.doFullDuplexCallRoundtrip(cc)
806	}
807	if cc.success != (err == nil) {
808		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
809	}
810	te.cc.Close()
811	te.srv.GracefulStop() // Wait for the server to stop.
812
813	for {
814		h.mu.Lock()
815		if len(h.gotRPC) >= len(checkFuncs) {
816			h.mu.Unlock()
817			break
818		}
819		h.mu.Unlock()
820		time.Sleep(10 * time.Millisecond)
821	}
822
823	for {
824		h.mu.Lock()
825		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
826			h.mu.Unlock()
827			break
828		}
829		h.mu.Unlock()
830		time.Sleep(10 * time.Millisecond)
831	}
832
833	expect := &expectedData{
834		serverAddr:  te.srvAddr,
835		compression: tc.compress,
836		method:      method,
837		requests:    reqs,
838		responses:   resps,
839		err:         err,
840	}
841
842	h.mu.Lock()
843	checkConnStats(t, h.gotConn)
844	h.mu.Unlock()
845	checkServerStats(t, h.gotRPC, expect, checkFuncs)
846}
847
848func TestServerStatsUnaryRPC(t *testing.T) {
849	testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){
850		checkInHeader,
851		checkBegin,
852		checkInPayload,
853		checkOutHeader,
854		checkOutPayload,
855		checkOutTrailer,
856		checkEnd,
857	})
858}
859
860func TestServerStatsUnaryRPCError(t *testing.T) {
861	testServerStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, callType: unaryRPC}, []func(t *testing.T, d *gotData, e *expectedData){
862		checkInHeader,
863		checkBegin,
864		checkInPayload,
865		checkOutHeader,
866		checkOutTrailer,
867		checkEnd,
868	})
869}
870
871func TestServerStatsClientStreamRPC(t *testing.T) {
872	count := 5
873	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
874		checkInHeader,
875		checkBegin,
876		checkOutHeader,
877	}
878	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
879		checkInPayload,
880	}
881	for i := 0; i < count; i++ {
882		checkFuncs = append(checkFuncs, ioPayFuncs...)
883	}
884	checkFuncs = append(checkFuncs,
885		checkOutPayload,
886		checkOutTrailer,
887		checkEnd,
888	)
889	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: clientStreamRPC}, checkFuncs)
890}
891
892func TestServerStatsClientStreamRPCError(t *testing.T) {
893	count := 1
894	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: clientStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
895		checkInHeader,
896		checkBegin,
897		checkOutHeader,
898		checkInPayload,
899		checkOutTrailer,
900		checkEnd,
901	})
902}
903
904func TestServerStatsServerStreamRPC(t *testing.T) {
905	count := 5
906	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
907		checkInHeader,
908		checkBegin,
909		checkInPayload,
910		checkOutHeader,
911	}
912	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
913		checkOutPayload,
914	}
915	for i := 0; i < count; i++ {
916		checkFuncs = append(checkFuncs, ioPayFuncs...)
917	}
918	checkFuncs = append(checkFuncs,
919		checkOutTrailer,
920		checkEnd,
921	)
922	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: serverStreamRPC}, checkFuncs)
923}
924
925func TestServerStatsServerStreamRPCError(t *testing.T) {
926	count := 5
927	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: serverStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
928		checkInHeader,
929		checkBegin,
930		checkInPayload,
931		checkOutHeader,
932		checkOutTrailer,
933		checkEnd,
934	})
935}
936
937func TestServerStatsFullDuplexRPC(t *testing.T) {
938	count := 5
939	checkFuncs := []func(t *testing.T, d *gotData, e *expectedData){
940		checkInHeader,
941		checkBegin,
942		checkOutHeader,
943	}
944	ioPayFuncs := []func(t *testing.T, d *gotData, e *expectedData){
945		checkInPayload,
946		checkOutPayload,
947	}
948	for i := 0; i < count; i++ {
949		checkFuncs = append(checkFuncs, ioPayFuncs...)
950	}
951	checkFuncs = append(checkFuncs,
952		checkOutTrailer,
953		checkEnd,
954	)
955	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, callType: fullDuplexStreamRPC}, checkFuncs)
956}
957
958func TestServerStatsFullDuplexRPCError(t *testing.T) {
959	count := 5
960	testServerStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, callType: fullDuplexStreamRPC}, []func(t *testing.T, d *gotData, e *expectedData){
961		checkInHeader,
962		checkBegin,
963		checkOutHeader,
964		checkInPayload,
965		checkOutTrailer,
966		checkEnd,
967	})
968}
969
970type checkFuncWithCount struct {
971	f func(t *testing.T, d *gotData, e *expectedData)
972	c int // expected count
973}
974
975func checkClientStats(t *testing.T, got []*gotData, expect *expectedData, checkFuncs map[int]*checkFuncWithCount) {
976	var expectLen int
977	for _, v := range checkFuncs {
978		expectLen += v.c
979	}
980	if len(got) != expectLen {
981		for i, g := range got {
982			t.Errorf(" - %v, %T", i, g.s)
983		}
984		t.Fatalf("got %v stats, want %v stats", len(got), expectLen)
985	}
986
987	var tagInfoInCtx *stats.RPCTagInfo
988	for i := 0; i < len(got); i++ {
989		if _, ok := got[i].s.(stats.RPCStats); ok {
990			tagInfoInCtxNew, _ := got[i].ctx.Value(rpcCtxKey{}).(*stats.RPCTagInfo)
991			if tagInfoInCtx != nil && tagInfoInCtx != tagInfoInCtxNew {
992				t.Fatalf("got context containing different tagInfo with stats %T", got[i].s)
993			}
994			tagInfoInCtx = tagInfoInCtxNew
995		}
996	}
997
998	for _, s := range got {
999		switch s.s.(type) {
1000		case *stats.Begin:
1001			if checkFuncs[begin].c <= 0 {
1002				t.Fatalf("unexpected stats: %T", s.s)
1003			}
1004			checkFuncs[begin].f(t, s, expect)
1005			checkFuncs[begin].c--
1006		case *stats.OutHeader:
1007			if checkFuncs[outHeader].c <= 0 {
1008				t.Fatalf("unexpected stats: %T", s.s)
1009			}
1010			checkFuncs[outHeader].f(t, s, expect)
1011			checkFuncs[outHeader].c--
1012		case *stats.OutPayload:
1013			if checkFuncs[outPayload].c <= 0 {
1014				t.Fatalf("unexpected stats: %T", s.s)
1015			}
1016			checkFuncs[outPayload].f(t, s, expect)
1017			checkFuncs[outPayload].c--
1018		case *stats.InHeader:
1019			if checkFuncs[inHeader].c <= 0 {
1020				t.Fatalf("unexpected stats: %T", s.s)
1021			}
1022			checkFuncs[inHeader].f(t, s, expect)
1023			checkFuncs[inHeader].c--
1024		case *stats.InPayload:
1025			if checkFuncs[inPayload].c <= 0 {
1026				t.Fatalf("unexpected stats: %T", s.s)
1027			}
1028			checkFuncs[inPayload].f(t, s, expect)
1029			checkFuncs[inPayload].c--
1030		case *stats.InTrailer:
1031			if checkFuncs[inTrailer].c <= 0 {
1032				t.Fatalf("unexpected stats: %T", s.s)
1033			}
1034			checkFuncs[inTrailer].f(t, s, expect)
1035			checkFuncs[inTrailer].c--
1036		case *stats.End:
1037			if checkFuncs[end].c <= 0 {
1038				t.Fatalf("unexpected stats: %T", s.s)
1039			}
1040			checkFuncs[end].f(t, s, expect)
1041			checkFuncs[end].c--
1042		case *stats.ConnBegin:
1043			if checkFuncs[connBegin].c <= 0 {
1044				t.Fatalf("unexpected stats: %T", s.s)
1045			}
1046			checkFuncs[connBegin].f(t, s, expect)
1047			checkFuncs[connBegin].c--
1048		case *stats.ConnEnd:
1049			if checkFuncs[connEnd].c <= 0 {
1050				t.Fatalf("unexpected stats: %T", s.s)
1051			}
1052			checkFuncs[connEnd].f(t, s, expect)
1053			checkFuncs[connEnd].c--
1054		default:
1055			t.Fatalf("unexpected stats: %T", s.s)
1056		}
1057	}
1058}
1059
1060func testClientStats(t *testing.T, tc *testConfig, cc *rpcConfig, checkFuncs map[int]*checkFuncWithCount) {
1061	h := &statshandler{}
1062	te := newTest(t, tc, h, nil)
1063	te.startServer(&testServer{})
1064	defer te.tearDown()
1065
1066	var (
1067		reqs   []*testpb.SimpleRequest
1068		resps  []*testpb.SimpleResponse
1069		method string
1070		err    error
1071
1072		req  *testpb.SimpleRequest
1073		resp *testpb.SimpleResponse
1074		e    error
1075	)
1076	switch cc.callType {
1077	case unaryRPC:
1078		method = "/grpc.testing.TestService/UnaryCall"
1079		req, resp, e = te.doUnaryCall(cc)
1080		reqs = []*testpb.SimpleRequest{req}
1081		resps = []*testpb.SimpleResponse{resp}
1082		err = e
1083	case clientStreamRPC:
1084		method = "/grpc.testing.TestService/ClientStreamCall"
1085		reqs, resp, e = te.doClientStreamCall(cc)
1086		resps = []*testpb.SimpleResponse{resp}
1087		err = e
1088	case serverStreamRPC:
1089		method = "/grpc.testing.TestService/ServerStreamCall"
1090		req, resps, e = te.doServerStreamCall(cc)
1091		reqs = []*testpb.SimpleRequest{req}
1092		err = e
1093	case fullDuplexStreamRPC:
1094		method = "/grpc.testing.TestService/FullDuplexCall"
1095		reqs, resps, err = te.doFullDuplexCallRoundtrip(cc)
1096	}
1097	if cc.success != (err == nil) {
1098		t.Fatalf("cc.success: %v, got error: %v", cc.success, err)
1099	}
1100	te.cc.Close()
1101	te.srv.GracefulStop() // Wait for the server to stop.
1102
1103	lenRPCStats := 0
1104	for _, v := range checkFuncs {
1105		lenRPCStats += v.c
1106	}
1107	for {
1108		h.mu.Lock()
1109		if len(h.gotRPC) >= lenRPCStats {
1110			h.mu.Unlock()
1111			break
1112		}
1113		h.mu.Unlock()
1114		time.Sleep(10 * time.Millisecond)
1115	}
1116
1117	for {
1118		h.mu.Lock()
1119		if _, ok := h.gotConn[len(h.gotConn)-1].s.(*stats.ConnEnd); ok {
1120			h.mu.Unlock()
1121			break
1122		}
1123		h.mu.Unlock()
1124		time.Sleep(10 * time.Millisecond)
1125	}
1126
1127	expect := &expectedData{
1128		serverAddr:  te.srvAddr,
1129		compression: tc.compress,
1130		method:      method,
1131		requests:    reqs,
1132		responses:   resps,
1133		failfast:    cc.failfast,
1134		err:         err,
1135	}
1136
1137	h.mu.Lock()
1138	checkConnStats(t, h.gotConn)
1139	h.mu.Unlock()
1140	checkClientStats(t, h.gotRPC, expect, checkFuncs)
1141}
1142
1143func TestClientStatsUnaryRPC(t *testing.T) {
1144	testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: true, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{
1145		begin:      {checkBegin, 1},
1146		outHeader:  {checkOutHeader, 1},
1147		outPayload: {checkOutPayload, 1},
1148		inHeader:   {checkInHeader, 1},
1149		inPayload:  {checkInPayload, 1},
1150		inTrailer:  {checkInTrailer, 1},
1151		end:        {checkEnd, 1},
1152	})
1153}
1154
1155func TestClientStatsUnaryRPCError(t *testing.T) {
1156	testClientStats(t, &testConfig{compress: ""}, &rpcConfig{success: false, failfast: false, callType: unaryRPC}, map[int]*checkFuncWithCount{
1157		begin:      {checkBegin, 1},
1158		outHeader:  {checkOutHeader, 1},
1159		outPayload: {checkOutPayload, 1},
1160		inHeader:   {checkInHeader, 1},
1161		inTrailer:  {checkInTrailer, 1},
1162		end:        {checkEnd, 1},
1163	})
1164}
1165
1166func TestClientStatsClientStreamRPC(t *testing.T) {
1167	count := 5
1168	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{
1169		begin:      {checkBegin, 1},
1170		outHeader:  {checkOutHeader, 1},
1171		inHeader:   {checkInHeader, 1},
1172		outPayload: {checkOutPayload, count},
1173		inTrailer:  {checkInTrailer, 1},
1174		inPayload:  {checkInPayload, 1},
1175		end:        {checkEnd, 1},
1176	})
1177}
1178
1179func TestClientStatsClientStreamRPCError(t *testing.T) {
1180	count := 1
1181	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: clientStreamRPC}, map[int]*checkFuncWithCount{
1182		begin:      {checkBegin, 1},
1183		outHeader:  {checkOutHeader, 1},
1184		inHeader:   {checkInHeader, 1},
1185		outPayload: {checkOutPayload, 1},
1186		inTrailer:  {checkInTrailer, 1},
1187		end:        {checkEnd, 1},
1188	})
1189}
1190
1191func TestClientStatsServerStreamRPC(t *testing.T) {
1192	count := 5
1193	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{
1194		begin:      {checkBegin, 1},
1195		outHeader:  {checkOutHeader, 1},
1196		outPayload: {checkOutPayload, 1},
1197		inHeader:   {checkInHeader, 1},
1198		inPayload:  {checkInPayload, count},
1199		inTrailer:  {checkInTrailer, 1},
1200		end:        {checkEnd, 1},
1201	})
1202}
1203
1204func TestClientStatsServerStreamRPCError(t *testing.T) {
1205	count := 5
1206	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: serverStreamRPC}, map[int]*checkFuncWithCount{
1207		begin:      {checkBegin, 1},
1208		outHeader:  {checkOutHeader, 1},
1209		outPayload: {checkOutPayload, 1},
1210		inHeader:   {checkInHeader, 1},
1211		inTrailer:  {checkInTrailer, 1},
1212		end:        {checkEnd, 1},
1213	})
1214}
1215
1216func TestClientStatsFullDuplexRPC(t *testing.T) {
1217	count := 5
1218	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: true, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{
1219		begin:      {checkBegin, 1},
1220		outHeader:  {checkOutHeader, 1},
1221		outPayload: {checkOutPayload, count},
1222		inHeader:   {checkInHeader, 1},
1223		inPayload:  {checkInPayload, count},
1224		inTrailer:  {checkInTrailer, 1},
1225		end:        {checkEnd, 1},
1226	})
1227}
1228
1229func TestClientStatsFullDuplexRPCError(t *testing.T) {
1230	count := 5
1231	testClientStats(t, &testConfig{compress: "gzip"}, &rpcConfig{count: count, success: false, failfast: false, callType: fullDuplexStreamRPC}, map[int]*checkFuncWithCount{
1232		begin:      {checkBegin, 1},
1233		outHeader:  {checkOutHeader, 1},
1234		outPayload: {checkOutPayload, 1},
1235		inHeader:   {checkInHeader, 1},
1236		inTrailer:  {checkInTrailer, 1},
1237		end:        {checkEnd, 1},
1238	})
1239}
1240
1241func TestTags(t *testing.T) {
1242	b := []byte{5, 2, 4, 3, 1}
1243	ctx := stats.SetTags(context.Background(), b)
1244	if tg := stats.OutgoingTags(ctx); !reflect.DeepEqual(tg, b) {
1245		t.Errorf("OutgoingTags(%v) = %v; want %v", ctx, tg, b)
1246	}
1247	if tg := stats.Tags(ctx); tg != nil {
1248		t.Errorf("Tags(%v) = %v; want nil", ctx, tg)
1249	}
1250
1251	ctx = stats.SetIncomingTags(context.Background(), b)
1252	if tg := stats.Tags(ctx); !reflect.DeepEqual(tg, b) {
1253		t.Errorf("Tags(%v) = %v; want %v", ctx, tg, b)
1254	}
1255	if tg := stats.OutgoingTags(ctx); tg != nil {
1256		t.Errorf("OutgoingTags(%v) = %v; want nil", ctx, tg)
1257	}
1258}
1259
1260func TestTrace(t *testing.T) {
1261	b := []byte{5, 2, 4, 3, 1}
1262	ctx := stats.SetTrace(context.Background(), b)
1263	if tr := stats.OutgoingTrace(ctx); !reflect.DeepEqual(tr, b) {
1264		t.Errorf("OutgoingTrace(%v) = %v; want %v", ctx, tr, b)
1265	}
1266	if tr := stats.Trace(ctx); tr != nil {
1267		t.Errorf("Trace(%v) = %v; want nil", ctx, tr)
1268	}
1269
1270	ctx = stats.SetIncomingTrace(context.Background(), b)
1271	if tr := stats.Trace(ctx); !reflect.DeepEqual(tr, b) {
1272		t.Errorf("Trace(%v) = %v; want %v", ctx, tr, b)
1273	}
1274	if tr := stats.OutgoingTrace(ctx); tr != nil {
1275		t.Errorf("OutgoingTrace(%v) = %v; want nil", ctx, tr)
1276	}
1277}
1278