1// Copyright 2016 Michal Witkowski. All Rights Reserved.
2// See LICENSE for licensing terms.
3
4package grpc_prometheus
5
6import (
7	"bufio"
8	"io"
9	"net"
10	"net/http"
11	"net/http/httptest"
12	"strconv"
13	"strings"
14	"testing"
15	"time"
16
17	pb_testproto "github.com/grpc-ecosystem/go-grpc-prometheus/examples/testproto"
18	"github.com/prometheus/client_golang/prometheus"
19	"github.com/stretchr/testify/assert"
20	"github.com/stretchr/testify/require"
21	"github.com/stretchr/testify/suite"
22	"golang.org/x/net/context"
23	"google.golang.org/grpc"
24	"google.golang.org/grpc/codes"
25	"google.golang.org/grpc/status"
26)
27
28var (
29	// server metrics must satisfy the Collector interface
30	_ prometheus.Collector = NewServerMetrics()
31)
32
33const (
34	pingDefaultValue   = "I like kittens."
35	countListResponses = 20
36)
37
38func TestServerInterceptorSuite(t *testing.T) {
39	suite.Run(t, &ServerInterceptorTestSuite{})
40}
41
42type ServerInterceptorTestSuite struct {
43	suite.Suite
44
45	serverListener net.Listener
46	server         *grpc.Server
47	clientConn     *grpc.ClientConn
48	testClient     pb_testproto.TestServiceClient
49	ctx            context.Context
50}
51
52func (s *ServerInterceptorTestSuite) SetupSuite() {
53	var err error
54
55	EnableHandlingTimeHistogram()
56
57	s.serverListener, err = net.Listen("tcp", "127.0.0.1:0")
58	require.NoError(s.T(), err, "must be able to allocate a port for serverListener")
59
60	// This is the point where we hook up the interceptor
61	s.server = grpc.NewServer(
62		grpc.StreamInterceptor(StreamServerInterceptor),
63		grpc.UnaryInterceptor(UnaryServerInterceptor),
64	)
65	pb_testproto.RegisterTestServiceServer(s.server, &testService{t: s.T()})
66
67	go func() {
68		s.server.Serve(s.serverListener)
69	}()
70
71	s.clientConn, err = grpc.Dial(s.serverListener.Addr().String(), grpc.WithInsecure(), grpc.WithBlock(), grpc.WithTimeout(2*time.Second))
72	require.NoError(s.T(), err, "must not error on client Dial")
73	s.testClient = pb_testproto.NewTestServiceClient(s.clientConn)
74
75	// Important! Pre-register stuff here.
76	Register(s.server)
77}
78
79func (s *ServerInterceptorTestSuite) SetupTest() {
80	// Make all RPC calls last at most 2 sec, meaning all async issues or deadlock will not kill tests.
81	s.ctx, _ = context.WithTimeout(context.TODO(), 2*time.Second)
82}
83
84func (s *ServerInterceptorTestSuite) TearDownSuite() {
85	if s.serverListener != nil {
86		s.server.Stop()
87		s.T().Logf("stopped grpc.Server at: %v", s.serverListener.Addr().String())
88		s.serverListener.Close()
89
90	}
91	if s.clientConn != nil {
92		s.clientConn.Close()
93	}
94}
95
96func (s *ServerInterceptorTestSuite) TestRegisterPresetsStuff() {
97	for testID, testCase := range []struct {
98		metricName     string
99		existingLabels []string
100	}{
101		{"grpc_server_started_total", []string{"mwitkow.testproto.TestService", "PingEmpty", "unary"}},
102		{"grpc_server_started_total", []string{"mwitkow.testproto.TestService", "PingList", "server_stream"}},
103		{"grpc_server_msg_received_total", []string{"mwitkow.testproto.TestService", "PingList", "server_stream"}},
104		{"grpc_server_msg_sent_total", []string{"mwitkow.testproto.TestService", "PingEmpty", "unary"}},
105		{"grpc_server_handling_seconds_sum", []string{"mwitkow.testproto.TestService", "PingEmpty", "unary"}},
106		{"grpc_server_handling_seconds_count", []string{"mwitkow.testproto.TestService", "PingList", "server_stream"}},
107		{"grpc_server_handled_total", []string{"mwitkow.testproto.TestService", "PingList", "server_stream", "OutOfRange"}},
108		{"grpc_server_handled_total", []string{"mwitkow.testproto.TestService", "PingList", "server_stream", "Aborted"}},
109		{"grpc_server_handled_total", []string{"mwitkow.testproto.TestService", "PingEmpty", "unary", "FailedPrecondition"}},
110		{"grpc_server_handled_total", []string{"mwitkow.testproto.TestService", "PingEmpty", "unary", "ResourceExhausted"}},
111	} {
112		lineCount := len(fetchPrometheusLines(s.T(), testCase.metricName, testCase.existingLabels...))
113		assert.NotEqual(s.T(), 0, lineCount, "metrics must exist for test case %d", testID)
114	}
115}
116
117func (s *ServerInterceptorTestSuite) TestUnaryIncrementsStarted() {
118	var before int
119	var after int
120
121	before = sumCountersForMetricAndLabels(s.T(), "grpc_server_started_total", "PingEmpty", "unary")
122	s.testClient.PingEmpty(s.ctx, &pb_testproto.Empty{})
123	after = sumCountersForMetricAndLabels(s.T(), "grpc_server_started_total", "PingEmpty", "unary")
124	assert.EqualValues(s.T(), before+1, after, "grpc_server_started_total should be incremented for PingEmpty")
125
126	before = sumCountersForMetricAndLabels(s.T(), "grpc_server_started_total", "PingError", "unary")
127	s.testClient.PingError(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.Unavailable)})
128	after = sumCountersForMetricAndLabels(s.T(), "grpc_server_started_total", "PingError", "unary")
129	assert.EqualValues(s.T(), before+1, after, "grpc_server_started_total should be incremented for PingError")
130}
131
132func (s *ServerInterceptorTestSuite) TestUnaryIncrementsHandled() {
133	var before int
134	var after int
135
136	before = sumCountersForMetricAndLabels(s.T(), "grpc_server_handled_total", "PingEmpty", "unary", "OK")
137	s.testClient.PingEmpty(s.ctx, &pb_testproto.Empty{}) // should return with code=OK
138	after = sumCountersForMetricAndLabels(s.T(), "grpc_server_handled_total", "PingEmpty", "unary", "OK")
139	assert.EqualValues(s.T(), before+1, after, "grpc_server_handled_count should be incremented for PingEmpty")
140
141	before = sumCountersForMetricAndLabels(s.T(), "grpc_server_handled_total", "PingError", "unary", "FailedPrecondition")
142	s.testClient.PingError(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.FailedPrecondition)}) // should return with code=FailedPrecondition
143	after = sumCountersForMetricAndLabels(s.T(), "grpc_server_handled_total", "PingError", "unary", "FailedPrecondition")
144	assert.EqualValues(s.T(), before+1, after, "grpc_server_handled_total should be incremented for PingError")
145}
146
147func (s *ServerInterceptorTestSuite) TestUnaryIncrementsHistograms() {
148	var before int
149	var after int
150
151	before = sumCountersForMetricAndLabels(s.T(), "grpc_server_handling_seconds_count", "PingEmpty", "unary")
152	s.testClient.PingEmpty(s.ctx, &pb_testproto.Empty{}) // should return with code=OK
153	after = sumCountersForMetricAndLabels(s.T(), "grpc_server_handling_seconds_count", "PingEmpty", "unary")
154	assert.EqualValues(s.T(), before+1, after, "grpc_server_handled_count should be incremented for PingEmpty")
155
156	before = sumCountersForMetricAndLabels(s.T(), "grpc_server_handling_seconds_count", "PingError", "unary")
157	s.testClient.PingError(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.FailedPrecondition)}) // should return with code=FailedPrecondition
158	after = sumCountersForMetricAndLabels(s.T(), "grpc_server_handling_seconds_count", "PingError", "unary")
159	assert.EqualValues(s.T(), before+1, after, "grpc_server_handling_seconds_count should be incremented for PingError")
160}
161
162func (s *ServerInterceptorTestSuite) TestStreamingIncrementsStarted() {
163	var before int
164	var after int
165
166	before = sumCountersForMetricAndLabels(s.T(), "grpc_server_started_total", "PingList", "server_stream")
167	s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{})
168	after = sumCountersForMetricAndLabels(s.T(), "grpc_server_started_total", "PingList", "server_stream")
169	assert.EqualValues(s.T(), before+1, after, "grpc_server_started_total should be incremented for PingList")
170}
171
172func (s *ServerInterceptorTestSuite) TestStreamingIncrementsHistograms() {
173	var before int
174	var after int
175
176	before = sumCountersForMetricAndLabels(s.T(), "grpc_server_handling_seconds_count", "PingList", "server_stream")
177	ss, _ := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{}) // should return with code=OK
178	// Do a read, just for kicks.
179	for {
180		_, err := ss.Recv()
181		if err == io.EOF {
182			break
183		}
184		require.NoError(s.T(), err, "reading pingList shouldn't fail")
185	}
186	after = sumCountersForMetricAndLabels(s.T(), "grpc_server_handling_seconds_count", "PingList", "server_stream")
187	assert.EqualValues(s.T(), before+1, after, "grpc_server_handling_seconds_count should be incremented for PingList OK")
188
189	before = sumCountersForMetricAndLabels(s.T(), "grpc_server_handling_seconds_count", "PingList", "server_stream")
190	_, err := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.FailedPrecondition)}) // should return with code=FailedPrecondition
191	require.NoError(s.T(), err, "PingList must not fail immediately")
192
193	after = sumCountersForMetricAndLabels(s.T(), "grpc_server_handling_seconds_count", "PingList", "server_stream")
194	assert.EqualValues(s.T(), before+1, after, "grpc_server_handling_seconds_count should be incremented for PingList FailedPrecondition")
195}
196
197func (s *ServerInterceptorTestSuite) TestStreamingIncrementsHandled() {
198	var before int
199	var after int
200
201	before = sumCountersForMetricAndLabels(s.T(), "grpc_server_handled_total", "PingList", "server_stream", "OK")
202	ss, _ := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{}) // should return with code=OK
203	// Do a read, just for kicks.
204	for {
205		_, err := ss.Recv()
206		if err == io.EOF {
207			break
208		}
209		require.NoError(s.T(), err, "reading pingList shouldn't fail")
210	}
211	after = sumCountersForMetricAndLabels(s.T(), "grpc_server_handled_total", "PingList", "server_stream", "OK")
212	assert.EqualValues(s.T(), before+1, after, "grpc_server_handled_total should be incremented for PingList OK")
213
214	before = sumCountersForMetricAndLabels(s.T(), "grpc_server_handled_total", "PingList", "server_stream", "FailedPrecondition")
215	_, err := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.FailedPrecondition)}) // should return with code=FailedPrecondition
216	require.NoError(s.T(), err, "PingList must not fail immediately")
217
218	after = sumCountersForMetricAndLabels(s.T(), "grpc_server_handled_total", "PingList", "server_stream", "FailedPrecondition")
219	assert.EqualValues(s.T(), before+1, after, "grpc_server_handled_total should be incremented for PingList FailedPrecondition")
220}
221
222func (s *ServerInterceptorTestSuite) TestStreamingIncrementsMessageCounts() {
223	beforeRecv := sumCountersForMetricAndLabels(s.T(), "grpc_server_msg_received_total", "PingList", "server_stream")
224	beforeSent := sumCountersForMetricAndLabels(s.T(), "grpc_server_msg_sent_total", "PingList", "server_stream")
225	ss, _ := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{}) // should return with code=OK
226	// Do a read, just for kicks.
227	count := 0
228	for {
229		_, err := ss.Recv()
230		if err == io.EOF {
231			break
232		}
233		require.NoError(s.T(), err, "reading pingList shouldn't fail")
234		count++
235	}
236	require.EqualValues(s.T(), countListResponses, count, "Number of received msg on the wire must match")
237	afterSent := sumCountersForMetricAndLabels(s.T(), "grpc_server_msg_sent_total", "PingList", "server_stream")
238	afterRecv := sumCountersForMetricAndLabels(s.T(), "grpc_server_msg_received_total", "PingList", "server_stream")
239
240	assert.EqualValues(s.T(), beforeSent+countListResponses, afterSent, "grpc_server_msg_sent_total should be incremented 20 times for PingList")
241	assert.EqualValues(s.T(), beforeRecv+1, afterRecv, "grpc_server_msg_sent_total should be incremented ones for PingList ")
242}
243
244func fetchPrometheusLines(t *testing.T, metricName string, matchingLabelValues ...string) []string {
245	resp := httptest.NewRecorder()
246	req, err := http.NewRequest("GET", "/", nil)
247	require.NoError(t, err, "failed creating request for Prometheus handler")
248	prometheus.Handler().ServeHTTP(resp, req)
249	reader := bufio.NewReader(resp.Body)
250	ret := []string{}
251	for {
252		line, err := reader.ReadString('\n')
253		if err == io.EOF {
254			break
255		} else {
256			require.NoError(t, err, "error reading stuff")
257		}
258		if !strings.HasPrefix(line, metricName) {
259			continue
260		}
261		matches := true
262		for _, labelValue := range matchingLabelValues {
263			if !strings.Contains(line, `"`+labelValue+`"`) {
264				matches = false
265			}
266		}
267		if matches {
268			ret = append(ret, line)
269		}
270
271	}
272	return ret
273}
274
275func sumCountersForMetricAndLabels(t *testing.T, metricName string, matchingLabelValues ...string) int {
276	count := 0
277	for _, line := range fetchPrometheusLines(t, metricName, matchingLabelValues...) {
278		valueString := line[strings.LastIndex(line, " ")+1 : len(line)-1]
279		valueFloat, err := strconv.ParseFloat(valueString, 32)
280		require.NoError(t, err, "failed parsing value for line: %v", line)
281		count += int(valueFloat)
282	}
283	return count
284}
285
286type testService struct {
287	t *testing.T
288}
289
290func (s *testService) PingEmpty(ctx context.Context, _ *pb_testproto.Empty) (*pb_testproto.PingResponse, error) {
291	return &pb_testproto.PingResponse{Value: pingDefaultValue, Counter: 42}, nil
292}
293
294func (s *testService) Ping(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) {
295	// Send user trailers and headers.
296	return &pb_testproto.PingResponse{Value: ping.Value, Counter: 42}, nil
297}
298
299func (s *testService) PingError(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.Empty, error) {
300	code := codes.Code(ping.ErrorCodeReturned)
301	return nil, status.Errorf(code, "Userspace error.")
302}
303
304func (s *testService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error {
305	if ping.ErrorCodeReturned != 0 {
306		return status.Errorf(codes.Code(ping.ErrorCodeReturned), "foobar")
307	}
308	// Send user trailers and headers.
309	for i := 0; i < countListResponses; i++ {
310		stream.Send(&pb_testproto.PingResponse{Value: ping.Value, Counter: int32(i)})
311	}
312	return nil
313}
314