1// Copyright 2016 Michal Witkowski. All Rights Reserved. 2// See LICENSE for licensing terms. 3 4package grpc_prometheus 5 6import ( 7 "net" 8 "testing" 9 10 "time" 11 12 "io" 13 14 pb_testproto "github.com/grpc-ecosystem/go-grpc-prometheus/examples/testproto" 15 "github.com/prometheus/client_golang/prometheus" 16 "github.com/stretchr/testify/assert" 17 "github.com/stretchr/testify/require" 18 "github.com/stretchr/testify/suite" 19 "golang.org/x/net/context" 20 "google.golang.org/grpc" 21 "google.golang.org/grpc/codes" 22 "google.golang.org/grpc/status" 23) 24 25var ( 26 // client metrics must satisfy the Collector interface 27 _ prometheus.Collector = NewClientMetrics() 28) 29 30func TestClientInterceptorSuite(t *testing.T) { 31 suite.Run(t, &ClientInterceptorTestSuite{}) 32} 33 34type ClientInterceptorTestSuite struct { 35 suite.Suite 36 37 serverListener net.Listener 38 server *grpc.Server 39 clientConn *grpc.ClientConn 40 testClient pb_testproto.TestServiceClient 41 ctx context.Context 42} 43 44func (s *ClientInterceptorTestSuite) SetupSuite() { 45 var err error 46 47 EnableClientHandlingTimeHistogram() 48 49 s.serverListener, err = net.Listen("tcp", "127.0.0.1:0") 50 require.NoError(s.T(), err, "must be able to allocate a port for serverListener") 51 52 // This is the point where we hook up the interceptor 53 s.server = grpc.NewServer() 54 pb_testproto.RegisterTestServiceServer(s.server, &testService{t: s.T()}) 55 56 go func() { 57 s.server.Serve(s.serverListener) 58 }() 59 60 s.clientConn, err = grpc.Dial( 61 s.serverListener.Addr().String(), 62 grpc.WithInsecure(), 63 grpc.WithBlock(), 64 grpc.WithUnaryInterceptor(UnaryClientInterceptor), 65 grpc.WithStreamInterceptor(StreamClientInterceptor), 66 grpc.WithTimeout(2*time.Second)) 67 require.NoError(s.T(), err, "must not error on client Dial") 68 s.testClient = pb_testproto.NewTestServiceClient(s.clientConn) 69} 70 71func (s *ClientInterceptorTestSuite) SetupTest() { 72 // Make all RPC calls last at most 2 sec, meaning all async issues or deadlock will not kill tests. 73 s.ctx, _ = context.WithTimeout(context.TODO(), 2*time.Second) 74} 75 76func (s *ClientInterceptorTestSuite) TearDownSuite() { 77 if s.serverListener != nil { 78 s.server.Stop() 79 s.T().Logf("stopped grpc.Server at: %v", s.serverListener.Addr().String()) 80 s.serverListener.Close() 81 82 } 83 if s.clientConn != nil { 84 s.clientConn.Close() 85 } 86} 87 88func (s *ClientInterceptorTestSuite) TestUnaryIncrementsStarted() { 89 var before int 90 var after int 91 92 before = sumCountersForMetricAndLabels(s.T(), "grpc_client_started_total", "PingEmpty", "unary") 93 s.testClient.PingEmpty(s.ctx, &pb_testproto.Empty{}) 94 after = sumCountersForMetricAndLabels(s.T(), "grpc_client_started_total", "PingEmpty", "unary") 95 assert.EqualValues(s.T(), before+1, after, "grpc_client_started_total should be incremented for PingEmpty") 96 97 before = sumCountersForMetricAndLabels(s.T(), "grpc_client_started_total", "PingError", "unary") 98 s.testClient.PingError(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.Unavailable)}) 99 after = sumCountersForMetricAndLabels(s.T(), "grpc_client_started_total", "PingError", "unary") 100 assert.EqualValues(s.T(), before+1, after, "grpc_client_started_total should be incremented for PingError") 101} 102 103func (s *ClientInterceptorTestSuite) TestUnaryIncrementsHandled() { 104 var before int 105 var after int 106 107 before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingEmpty", "unary", "OK") 108 s.testClient.PingEmpty(s.ctx, &pb_testproto.Empty{}) // should return with code=OK 109 after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingEmpty", "unary", "OK") 110 assert.EqualValues(s.T(), before+1, after, "grpc_client_handled_count should be incremented for PingEmpty") 111 112 before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingError", "unary", "FailedPrecondition") 113 s.testClient.PingError(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.FailedPrecondition)}) // should return with code=FailedPrecondition 114 after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingError", "unary", "FailedPrecondition") 115 assert.EqualValues(s.T(), before+1, after, "grpc_client_handled_total should be incremented for PingError") 116} 117 118func (s *ClientInterceptorTestSuite) TestUnaryIncrementsHistograms() { 119 var before int 120 var after int 121 122 before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingEmpty", "unary") 123 s.testClient.PingEmpty(s.ctx, &pb_testproto.Empty{}) // should return with code=OK 124 after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingEmpty", "unary") 125 assert.EqualValues(s.T(), before+1, after, "grpc_client_handled_count should be incremented for PingEmpty") 126 127 before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingError", "unary") 128 s.testClient.PingError(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.FailedPrecondition)}) // should return with code=FailedPrecondition 129 after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingError", "unary") 130 assert.EqualValues(s.T(), before+1, after, "grpc_client_handling_seconds_count should be incremented for PingError") 131} 132 133func (s *ClientInterceptorTestSuite) TestStreamingIncrementsStarted() { 134 var before int 135 var after int 136 137 before = sumCountersForMetricAndLabels(s.T(), "grpc_client_started_total", "PingList", "server_stream") 138 s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{}) 139 after = sumCountersForMetricAndLabels(s.T(), "grpc_client_started_total", "PingList", "server_stream") 140 assert.EqualValues(s.T(), before+1, after, "grpc_client_started_total should be incremented for PingList") 141} 142 143func (s *ClientInterceptorTestSuite) TestStreamingIncrementsHistograms() { 144 var before int 145 var after int 146 147 before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingList", "server_stream") 148 ss, _ := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{}) // should return with code=OK 149 // Do a read, just for kicks. 150 for { 151 _, err := ss.Recv() 152 if err == io.EOF { 153 break 154 } 155 require.NoError(s.T(), err, "reading pingList shouldn't fail") 156 } 157 after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingList", "server_stream") 158 assert.EqualValues(s.T(), before+1, after, "grpc_client_handling_seconds_count should be incremented for PingList OK") 159 160 before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingList", "server_stream") 161 ss, err := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.FailedPrecondition)}) // should return with code=FailedPrecondition 162 require.NoError(s.T(), err, "PingList must not fail immediately") 163 // Do a read, just to progate errors. 164 _, err = ss.Recv() 165 st, _ := status.FromError(err) 166 require.Equal(s.T(), codes.FailedPrecondition, st.Code(), "Recv must return FailedPrecondition, otherwise the test is wrong") 167 168 after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handling_seconds_count", "PingList", "server_stream") 169 assert.EqualValues(s.T(), before+1, after, "grpc_client_handling_seconds_count should be incremented for PingList FailedPrecondition") 170} 171 172func (s *ClientInterceptorTestSuite) TestStreamingIncrementsHandled() { 173 var before int 174 var after int 175 176 before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingList", "server_stream", "OK") 177 ss, _ := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{}) // should return with code=OK 178 // Do a read, just for kicks. 179 for { 180 _, err := ss.Recv() 181 if err == io.EOF { 182 break 183 } 184 require.NoError(s.T(), err, "reading pingList shouldn't fail") 185 } 186 after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingList", "server_stream", "OK") 187 assert.EqualValues(s.T(), before+1, after, "grpc_client_handled_total should be incremented for PingList OK") 188 189 before = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingList", "server_stream", "FailedPrecondition") 190 ss, err := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{ErrorCodeReturned: uint32(codes.FailedPrecondition)}) // should return with code=FailedPrecondition 191 require.NoError(s.T(), err, "PingList must not fail immediately") 192 // Do a read, just to progate errors. 193 _, err = ss.Recv() 194 st, _ := status.FromError(err) 195 require.Equal(s.T(), codes.FailedPrecondition, st.Code(), "Recv must return FailedPrecondition, otherwise the test is wrong") 196 197 after = sumCountersForMetricAndLabels(s.T(), "grpc_client_handled_total", "PingList", "server_stream", "FailedPrecondition") 198 assert.EqualValues(s.T(), before+1, after, "grpc_client_handled_total should be incremented for PingList FailedPrecondition") 199} 200 201func (s *ClientInterceptorTestSuite) TestStreamingIncrementsMessageCounts() { 202 beforeRecv := sumCountersForMetricAndLabels(s.T(), "grpc_client_msg_received_total", "PingList", "server_stream") 203 beforeSent := sumCountersForMetricAndLabels(s.T(), "grpc_client_msg_sent_total", "PingList", "server_stream") 204 ss, _ := s.testClient.PingList(s.ctx, &pb_testproto.PingRequest{}) // should return with code=OK 205 // Do a read, just for kicks. 206 count := 0 207 for { 208 _, err := ss.Recv() 209 if err == io.EOF { 210 break 211 } 212 require.NoError(s.T(), err, "reading pingList shouldn't fail") 213 count++ 214 } 215 require.EqualValues(s.T(), countListResponses, count, "Number of received msg on the wire must match") 216 afterSent := sumCountersForMetricAndLabels(s.T(), "grpc_client_msg_sent_total", "PingList", "server_stream") 217 afterRecv := sumCountersForMetricAndLabels(s.T(), "grpc_client_msg_received_total", "PingList", "server_stream") 218 219 assert.EqualValues(s.T(), beforeSent+1, afterSent, "grpc_client_msg_sent_total should be incremented 20 times for PingList") 220 assert.EqualValues(s.T(), beforeRecv+countListResponses, afterRecv, "grpc_client_msg_sent_total should be incremented ones for PingList ") 221} 222