1// Copyright 2016 Michal Witkowski. All Rights Reserved.
2// See LICENSE for licensing terms.
3
4package grpc_prometheus
5
6import (
7	"net"
8	"testing"
9
10	"time"
11
12	"io"
13
14	pb_testproto "github.com/grpc-ecosystem/go-grpc-prometheus/examples/testproto"
15	"github.com/prometheus/client_golang/prometheus"
16	"github.com/stretchr/testify/assert"
17	"github.com/stretchr/testify/require"
18	"github.com/stretchr/testify/suite"
19	"golang.org/x/net/context"
20	"google.golang.org/grpc"
21	"google.golang.org/grpc/codes"
22	"google.golang.org/grpc/status"
23)
24
25var (
26	// client metrics must satisfy the Collector interface
27	_ prometheus.Collector = NewClientMetrics()
28)
29
// TestClientInterceptorSuite is the `go test` entry point that hands a fresh
// ClientInterceptorTestSuite to testify's suite runner.
func TestClientInterceptorSuite(t *testing.T) {
	clientSuite := new(ClientInterceptorTestSuite)
	suite.Run(t, clientSuite)
}
33
34type ClientInterceptorTestSuite struct {
35	suite.Suite
36
37	serverListener net.Listener
38	server         *grpc.Server
39	clientConn     *grpc.ClientConn
40	testClient     pb_testproto.TestServiceClient
41	ctx            context.Context
42}
43
44func (s *ClientInterceptorTestSuite) SetupSuite() {
45	var err error
46
47	EnableClientHandlingTimeHistogram()
48
49	s.serverListener, err = net.Listen("tcp", "127.0.0.1:0")
50	require.NoError(s.T(), err, "must be able to allocate a port for serverListener")
51
52	// This is the point where we hook up the interceptor
53	s.server = grpc.NewServer()
54	pb_testproto.RegisterTestServiceServer(s.server, &testService{t: s.T()})
55
56	go func() {
57		s.server.Serve(s.serverListener)
58	}()
59
60	s.clientConn, err = grpc.Dial(
61		s.serverListener.Addr().String(),
62		grpc.WithInsecure(),
63		grpc.WithBlock(),
64		grpc.WithUnaryInterceptor(UnaryClientInterceptor),
65		grpc.WithStreamInterceptor(StreamClientInterceptor),
66		grpc.WithTimeout(2*time.Second))
67	require.NoError(s.T(), err, "must not error on client Dial")
68	s.testClient = pb_testproto.NewTestServiceClient(s.clientConn)
69}
70
71func (s *ClientInterceptorTestSuite) SetupTest() {
72	// Make all RPC calls last at most 2 sec, meaning all async issues or deadlock will not kill tests.
73	s.ctx, _ = context.WithTimeout(context.TODO(), 2*time.Second)
74}
75
76func (s *ClientInterceptorTestSuite) TearDownSuite() {
77	if s.serverListener != nil {
78		s.server.Stop()
79		s.T().Logf("stopped grpc.Server at: %v", s.serverListener.Addr().String())
80		s.serverListener.Close()
81
82	}
83	if s.clientConn != nil {
84		s.clientConn.Close()
85	}
86}
87
88func (s *ClientInterceptorTestSuite) TestUnaryIncrementsStarted() {
89	var before int
90	var after int
91
92	before = sumCountersForMetricAndLabels(s.T(), "grpc_client_started_total", "PingEmpty", "unary")
93	s.testClient.PingEmpty(s.ctx, &pb_testproto.Empty{})
94	after = sumCountersForMetricAndLabels(s.T(), "grpc_client_started_total", "PingEmpty", "unary")
95	assert.EqualValues(s.T(), before+1, after, "grpc_client_started_total should be incremented for PingEmpty")
96
97	before = sumCountersForMetricAndLabels(s.T(), "grpc_client_started_total", "PingError", "unary")
98	s.testClient.PingError(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.Unavailable)})
99	after = sumCountersForMetricAndLabels(s.T(), "grpc_client_started_total", "PingError", "unary")
100	assert.EqualValues(s.T(), before+1, after, "grpc_client_started_total should be incremented for PingError")
101}
102
103func (s *ClientInterceptorTestSuite) TestUnaryIncrementsHandled() {
104	var before int
105	var after int
106
107	before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingEmpty", "unary", "OK")
108	s.testClient.PingEmpty(s.ctx, &pb_testproto.Empty{}) // should return with code=OK
109	after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingEmpty", "unary", "OK")
110	assert.EqualValues(s.T(), before+1, after, "grpc_client_handled_count should be incremented for PingEmpty")
111
112	before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingError", "unary", "FailedPrecondition")
113	s.testClient.PingError(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.FailedPrecondition)}) // should return with code=FailedPrecondition
114	after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingError", "unary", "FailedPrecondition")
115	assert.EqualValues(s.T(), before+1, after, "grpc_client_handled_total should be incremented for PingError")
116}
117
118func (s *ClientInterceptorTestSuite) TestUnaryIncrementsHistograms() {
119	var before int
120	var after int
121
122	before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingEmpty", "unary")
123	s.testClient.PingEmpty(s.ctx, &pb_testproto.Empty{}) // should return with code=OK
124	after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingEmpty", "unary")
125	assert.EqualValues(s.T(), before+1, after, "grpc_client_handled_count should be incremented for PingEmpty")
126
127	before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingError", "unary")
128	s.testClient.PingError(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.FailedPrecondition)}) // should return with code=FailedPrecondition
129	after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingError", "unary")
130	assert.EqualValues(s.T(), before+1, after, "grpc_client_handling_seconds_count should be incremented for PingError")
131}
132
// TestStreamingIncrementsStarted checks that grpc_client_started_total grows by
// one as soon as a server-streaming call is opened.
func (s *ClientInterceptorTestSuite) TestStreamingIncrementsStarted() {
	var before int
	var after int

	before = sumCountersForMetricAndLabels(s.T(), "grpc_client_started_total", "PingList", "server_stream")
	ss, err := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{})
	require.NoError(s.T(), err, "PingList must not fail immediately")
	after = sumCountersForMetricAndLabels(s.T(), "grpc_client_started_total", "PingList", "server_stream")
	assert.EqualValues(s.T(), before+1, after, "grpc_client_started_total should be incremented for PingList")

	// Drain the stream so the call completes here instead of lingering until
	// the test context expires.
	for {
		_, err = ss.Recv()
		if err == io.EOF {
			break
		}
		require.NoError(s.T(), err, "reading pingList shouldn't fail")
	}
}
142
// TestStreamingIncrementsHistograms checks that the handling-time histogram
// records one observation per finished server-streaming call, both for a
// stream that ends with OK and one that ends with FailedPrecondition.
func (s *ClientInterceptorTestSuite) TestStreamingIncrementsHistograms() {
	var before int
	var after int

	before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingList", "server_stream")
	ss, err := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{}) // should return with code=OK
	require.NoError(s.T(), err, "PingList must not fail immediately")
	// Read until io.EOF so the stream finishes before the histogram is inspected.
	for {
		_, err = ss.Recv()
		if err == io.EOF {
			break
		}
		require.NoError(s.T(), err, "reading pingList shouldn't fail")
	}
	after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingList", "server_stream")
	assert.EqualValues(s.T(), before+1, after, "grpc_client_handling_seconds_count should be incremented for PingList OK")

	before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingList", "server_stream")
	ss, err = s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.FailedPrecondition)}) // should return with code=FailedPrecondition
	require.NoError(s.T(), err, "PingList must not fail immediately")
	// Do a read, just to propagate errors.
	_, err = ss.Recv()
	st, _ := status.FromError(err)
	require.Equal(s.T(), codes.FailedPrecondition, st.Code(), "Recv must return FailedPrecondition, otherwise the test is wrong")

	after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingList", "server_stream")
	assert.EqualValues(s.T(), before+1, after, "grpc_client_handling_seconds_count should be incremented for PingList FailedPrecondition")
}
171
// TestStreamingIncrementsHandled checks that grpc_client_handled_total grows by
// one per finished server-streaming call, labelled with the final status code.
func (s *ClientInterceptorTestSuite) TestStreamingIncrementsHandled() {
	var before int
	var after int

	before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingList", "server_stream", "OK")
	ss, err := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{}) // should return with code=OK
	require.NoError(s.T(), err, "PingList must not fail immediately")
	// Read until io.EOF so the stream finishes before the counter is inspected.
	for {
		_, err = ss.Recv()
		if err == io.EOF {
			break
		}
		require.NoError(s.T(), err, "reading pingList shouldn't fail")
	}
	after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingList", "server_stream", "OK")
	assert.EqualValues(s.T(), before+1, after, "grpc_client_handled_total should be incremented for PingList OK")

	before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingList", "server_stream", "FailedPrecondition")
	ss, err = s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.FailedPrecondition)}) // should return with code=FailedPrecondition
	require.NoError(s.T(), err, "PingList must not fail immediately")
	// Do a read, just to propagate errors.
	_, err = ss.Recv()
	st, _ := status.FromError(err)
	require.Equal(s.T(), codes.FailedPrecondition, st.Code(), "Recv must return FailedPrecondition, otherwise the test is wrong")

	after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingList", "server_stream", "FailedPrecondition")
	assert.EqualValues(s.T(), before+1, after, "grpc_client_handled_total should be incremented for PingList FailedPrecondition")
}
200
// TestStreamingIncrementsMessageCounts checks the per-message counters of a
// server-streaming call: the client sends one request and receives
// countListResponses responses.
func (s *ClientInterceptorTestSuite) TestStreamingIncrementsMessageCounts() {
	beforeRecv := sumCountersForMetricAndLabels(s.T(), "grpc_client_msg_received_total", "PingList", "server_stream")
	beforeSent := sumCountersForMetricAndLabels(s.T(), "grpc_client_msg_sent_total", "PingList", "server_stream")
	ss, err := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{}) // should return with code=OK
	require.NoError(s.T(), err, "PingList must not fail immediately")
	// Drain the stream, counting every message that actually arrives.
	count := 0
	for {
		_, err = ss.Recv()
		if err == io.EOF {
			break
		}
		require.NoError(s.T(), err, "reading pingList shouldn't fail")
		count++
	}
	require.EqualValues(s.T(), countListResponses, count, "Number of received msg on the wire must match")
	afterSent := sumCountersForMetricAndLabels(s.T(), "grpc_client_msg_sent_total", "PingList", "server_stream")
	afterRecv := sumCountersForMetricAndLabels(s.T(), "grpc_client_msg_received_total", "PingList", "server_stream")

	assert.EqualValues(s.T(), beforeSent+1, afterSent, "grpc_client_msg_sent_total should be incremented once for PingList")
	assert.EqualValues(s.T(), beforeRecv+countListResponses, afterRecv, "grpc_client_msg_received_total should be incremented %d times for PingList", countListResponses)
}
222