1// Copyright 2017 David Ackroyd. All Rights Reserved.
2// See LICENSE for licensing terms.
3
4package grpc_recovery_test
5
6import (
7	"context"
8	"testing"
9
10	grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
11	grpc_recovery "github.com/grpc-ecosystem/go-grpc-middleware/recovery"
12	grpc_testing "github.com/grpc-ecosystem/go-grpc-middleware/testing"
13	pb_testproto "github.com/grpc-ecosystem/go-grpc-middleware/testing/testproto"
14	"github.com/stretchr/testify/assert"
15	"github.com/stretchr/testify/require"
16	"github.com/stretchr/testify/suite"
17	"google.golang.org/grpc"
18	"google.golang.org/grpc/codes"
19	"google.golang.org/grpc/status"
20)
21
22var (
23	goodPing     = &pb_testproto.PingRequest{Value: "something", SleepTimeMs: 9999}
24	panicPing    = &pb_testproto.PingRequest{Value: "panic", SleepTimeMs: 9999}
25	nilPanicPing = &pb_testproto.PingRequest{Value: "nilpanic", SleepTimeMs: 9999}
26)
27
28type recoveryAssertService struct {
29	pb_testproto.TestServiceServer
30}
31
32func (s *recoveryAssertService) Ping(ctx context.Context, ping *pb_testproto.PingRequest) (*pb_testproto.PingResponse, error) {
33	if ping.Value == "panic" {
34		panic("very bad thing happened")
35	}
36	if ping.Value == "nilpanic" {
37		panic(nil)
38	}
39	return s.TestServiceServer.Ping(ctx, ping)
40}
41
42func (s *recoveryAssertService) PingList(ping *pb_testproto.PingRequest, stream pb_testproto.TestService_PingListServer) error {
43	if ping.Value == "panic" {
44		panic("very bad thing happened")
45	}
46	if ping.Value == "nilpanic" {
47		panic(nil)
48	}
49	return s.TestServiceServer.PingList(ping, stream)
50}
51
52func TestRecoverySuite(t *testing.T) {
53	s := &RecoverySuite{
54		InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{
55			TestService: &recoveryAssertService{TestServiceServer: &grpc_testing.TestPingService{T: t}},
56			ServerOpts: []grpc.ServerOption{
57				grpc_middleware.WithStreamServerChain(
58					grpc_recovery.StreamServerInterceptor()),
59				grpc_middleware.WithUnaryServerChain(
60					grpc_recovery.UnaryServerInterceptor()),
61			},
62		},
63	}
64	suite.Run(t, s)
65}
66
67type RecoverySuite struct {
68	*grpc_testing.InterceptorTestSuite
69}
70
71func (s *RecoverySuite) TestUnary_SuccessfulRequest() {
72	_, err := s.Client.Ping(s.SimpleCtx(), goodPing)
73	require.NoError(s.T(), err, "no error must occur")
74}
75
76func (s *RecoverySuite) TestUnary_PanickingRequest() {
77	_, err := s.Client.Ping(s.SimpleCtx(), panicPing)
78	require.Error(s.T(), err, "there must be an error")
79	assert.Equal(s.T(), codes.Internal, status.Code(err), "must error with internal")
80	assert.Equal(s.T(), "very bad thing happened", status.Convert(err).Message(), "must error with message")
81}
82
83func (s *RecoverySuite) TestUnary_NilPanickingRequest() {
84	_, err := s.Client.Ping(s.SimpleCtx(), nilPanicPing)
85	require.Error(s.T(), err, "there must be an error")
86	assert.Equal(s.T(), codes.Internal, status.Code(err), "must error with internal")
87	assert.Equal(s.T(), "<nil>", status.Convert(err).Message(), "must error with <nil>")
88}
89
90func (s *RecoverySuite) TestStream_SuccessfulReceive() {
91	stream, err := s.Client.PingList(s.SimpleCtx(), goodPing)
92	require.NoError(s.T(), err, "should not fail on establishing the stream")
93	pong, err := stream.Recv()
94	require.NoError(s.T(), err, "no error must occur")
95	require.NotNil(s.T(), pong, "pong must not be nil")
96}
97
98func (s *RecoverySuite) TestStream_PanickingReceive() {
99	stream, err := s.Client.PingList(s.SimpleCtx(), panicPing)
100	require.NoError(s.T(), err, "should not fail on establishing the stream")
101	_, err = stream.Recv()
102	require.Error(s.T(), err, "there must be an error")
103	assert.Equal(s.T(), codes.Internal, status.Code(err), "must error with internal")
104	assert.Equal(s.T(), "very bad thing happened", status.Convert(err).Message(), "must error with message")
105}
106
107func (s *RecoverySuite) TestStream_NilPanickingReceive() {
108	stream, err := s.Client.PingList(s.SimpleCtx(), nilPanicPing)
109	require.NoError(s.T(), err, "should not fail on establishing the stream")
110	_, err = stream.Recv()
111	require.Error(s.T(), err, "there must be an error")
112	assert.Equal(s.T(), codes.Internal, status.Code(err), "must error with internal")
113	assert.Equal(s.T(), "<nil>", status.Convert(err).Message(), "must error with <nil>")
114}
115
116func TestRecoveryOverrideSuite(t *testing.T) {
117	opts := []grpc_recovery.Option{
118		grpc_recovery.WithRecoveryHandler(func(p interface{}) (err error) {
119			return status.Errorf(codes.Unknown, "panic triggered: %v", p)
120		}),
121	}
122	s := &RecoveryOverrideSuite{
123		InterceptorTestSuite: &grpc_testing.InterceptorTestSuite{
124			TestService: &recoveryAssertService{TestServiceServer: &grpc_testing.TestPingService{T: t}},
125			ServerOpts: []grpc.ServerOption{
126				grpc_middleware.WithStreamServerChain(
127					grpc_recovery.StreamServerInterceptor(opts...)),
128				grpc_middleware.WithUnaryServerChain(
129					grpc_recovery.UnaryServerInterceptor(opts...)),
130			},
131		},
132	}
133	suite.Run(t, s)
134}
135
136type RecoveryOverrideSuite struct {
137	*grpc_testing.InterceptorTestSuite
138}
139
140func (s *RecoveryOverrideSuite) TestUnary_SuccessfulRequest() {
141	_, err := s.Client.Ping(s.SimpleCtx(), goodPing)
142	require.NoError(s.T(), err, "no error must occur")
143}
144
145func (s *RecoveryOverrideSuite) TestUnary_PanickingRequest() {
146	_, err := s.Client.Ping(s.SimpleCtx(), panicPing)
147	require.Error(s.T(), err, "there must be an error")
148	assert.Equal(s.T(), codes.Unknown, status.Code(err), "must error with unknown")
149	assert.Equal(s.T(), "panic triggered: very bad thing happened", status.Convert(err).Message(), "must error with message")
150}
151
152func (s *RecoveryOverrideSuite) TestStream_SuccessfulReceive() {
153	stream, err := s.Client.PingList(s.SimpleCtx(), goodPing)
154	require.NoError(s.T(), err, "should not fail on establishing the stream")
155	pong, err := stream.Recv()
156	require.NoError(s.T(), err, "no error must occur")
157	require.NotNil(s.T(), pong, "pong must not be nil")
158}
159
160func (s *RecoveryOverrideSuite) TestStream_PanickingReceive() {
161	stream, err := s.Client.PingList(s.SimpleCtx(), panicPing)
162	require.NoError(s.T(), err, "should not fail on establishing the stream")
163	_, err = stream.Recv()
164	require.Error(s.T(), err, "there must be an error")
165	assert.Equal(s.T(), codes.Unknown, status.Code(err), "must error with unknown")
166	assert.Equal(s.T(), "panic triggered: very bad thing happened", status.Convert(err).Message(), "must error with message")
167}
168