1/*
2 *
3 * Copyright 2020 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package test
20
21import (
22	"context"
23	"io"
24	"testing"
25
26	"google.golang.org/grpc"
27	"google.golang.org/grpc/codes"
28	"google.golang.org/grpc/internal/stubserver"
29	"google.golang.org/grpc/status"
30	testpb "google.golang.org/grpc/test/grpc_testing"
31)
32
// ctxKey is the key type for the context values that the interceptors in these
// tests install and look up. Using a dedicated named type rather than a bare
// string keeps these keys from colliding with context keys defined elsewhere.
type (
	ctxKey string
)
34
// TestChainUnaryServerInterceptor verifies that interceptors passed to
// grpc.ChainUnaryInterceptor run in the order given. On the way in, each
// interceptor must see the context values added by the ones before it and none
// added by the ones after it. On the way out, each appends its marker byte to
// the response body, so the innermost marker ('3') comes first and the
// outermost ('1') comes last.
func (s) TestChainUnaryServerInterceptor(t *testing.T) {
	var (
		firstIntKey  = ctxKey("firstIntKey")
		secondIntKey = ctxKey("secondIntKey")
	)

	// appendMarker converts the response returned by the next handler in the
	// chain into a new *testpb.SimpleResponse whose payload body is the inner
	// body with marker appended. name identifies the calling interceptor in
	// error messages.
	appendMarker := func(resp interface{}, marker byte, name string) (interface{}, error) {
		simpleResp, ok := resp.(*testpb.SimpleResponse)
		if !ok {
			return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at %s", name)
		}
		return &testpb.SimpleResponse{
			Payload: &testpb.Payload{
				Type: simpleResp.GetPayload().GetType(),
				Body: append(simpleResp.GetPayload().GetBody(), marker),
			},
		}, nil
	}

	// firstInt is outermost, so neither key may be present yet. It installs
	// firstIntKey for the interceptors it wraps.
	firstInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		if ctx.Value(firstIntKey) != nil {
			return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", firstIntKey)
		}
		if ctx.Value(secondIntKey) != nil {
			return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", secondIntKey)
		}

		resp, err := handler(context.WithValue(ctx, firstIntKey, 0), req)
		if err != nil {
			// Keep the inner error so a failure deeper in the chain is diagnosable.
			return nil, status.Errorf(codes.Internal, "failed to handle request at firstInt: %v", err)
		}
		return appendMarker(resp, '1', "firstInt")
	}

	// secondInt runs inside firstInt. It must see firstIntKey but not
	// secondIntKey, which it installs for lastInt.
	secondInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		if ctx.Value(firstIntKey) == nil {
			return nil, status.Errorf(codes.Internal, "second interceptor should have %v in context", firstIntKey)
		}
		if ctx.Value(secondIntKey) != nil {
			return nil, status.Errorf(codes.Internal, "second interceptor should not have %v in context", secondIntKey)
		}

		resp, err := handler(context.WithValue(ctx, secondIntKey, 1), req)
		if err != nil {
			return nil, status.Errorf(codes.Internal, "failed to handle request at secondInt: %v", err)
		}
		return appendMarker(resp, '2', "secondInt")
	}

	// lastInt is innermost and must see both keys.
	lastInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		if ctx.Value(firstIntKey) == nil {
			return nil, status.Errorf(codes.Internal, "last interceptor should have %v in context", firstIntKey)
		}
		if ctx.Value(secondIntKey) == nil {
			return nil, status.Errorf(codes.Internal, "last interceptor should have %v in context", secondIntKey)
		}

		resp, err := handler(ctx, req)
		if err != nil {
			return nil, status.Errorf(codes.Internal, "failed to handle request at lastInt: %v", err)
		}
		return appendMarker(resp, '3', "lastInt")
	}

	sopts := []grpc.ServerOption{
		grpc.ChainUnaryInterceptor(firstInt, secondInt, lastInt),
	}

	ss := &stubserver.StubServer{
		UnaryCallF: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
			payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 0)
			if err != nil {
				return nil, status.Errorf(codes.Aborted, "failed to make payload: %v", err)
			}

			return &testpb.SimpleResponse{
				Payload: payload,
			}, nil
		},
	}
	if err := ss.Start(sopts); err != nil {
		t.Fatalf("Error starting endpoint server: %v", err)
	}
	defer ss.Stop()

	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	resp, err := ss.Client.UnaryCall(ctx, &testpb.SimpleRequest{})
	if s, ok := status.FromError(err); !ok || s.Code() != codes.OK {
		t.Fatalf("ss.Client.UnaryCall(ctx, _) = %v, %v; want nil, <status with Code()=OK>", resp, err)
	}

	// The handler requests a size-0 payload, so the body is expected to hold
	// only the interceptor markers, innermost first.
	if got, want := string(resp.GetPayload().GetBody()), "321"; got != want {
		t.Fatalf("ss.Client.UnaryCall(ctx, _) returned body %q; want %q", got, want)
	}
}
151
// TestChainOnBaseUnaryServerInterceptor checks how grpc.UnaryInterceptor and
// grpc.ChainUnaryInterceptor combine when both are supplied. The interceptor
// registered through UnaryInterceptor must run first. The chained interceptor
// is expected to observe the context value the base interceptor installed.
func (s) TestChainOnBaseUnaryServerInterceptor(t *testing.T) {
	key := ctxKey("baseIntKey")

	// The base interceptor is expected to be the first to touch the context,
	// so the key must still be absent when it is invoked.
	base := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		if v := ctx.Value(key); v != nil {
			return nil, status.Errorf(codes.Internal, "base interceptor should not have %v in context", key)
		}
		return handler(context.WithValue(ctx, key, 1), req)
	}

	// The chained interceptor runs inside the base one and therefore must find
	// the key before passing the call on.
	chained := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		if ctx.Value(key) != nil {
			return handler(ctx, req)
		}
		return nil, status.Errorf(codes.Internal, "chain interceptor should have %v in context", key)
	}

	ss := &stubserver.StubServer{
		EmptyCallF: func(context.Context, *testpb.Empty) (*testpb.Empty, error) {
			return new(testpb.Empty), nil
		},
	}
	if err := ss.Start([]grpc.ServerOption{
		grpc.UnaryInterceptor(base),
		grpc.ChainUnaryInterceptor(chained),
	}); err != nil {
		t.Fatalf("Error starting endpoint server: %v", err)
	}
	defer ss.Stop()

	callCtx, cancelCall := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancelCall()

	out, callErr := ss.Client.EmptyCall(callCtx, &testpb.Empty{})
	st, isStatus := status.FromError(callErr)
	if !isStatus || st.Code() != codes.OK {
		t.Fatalf("ss.Client.EmptyCall(ctx, _) = %v, %v; want nil, <status with Code()=OK>", out, callErr)
	}
}
194
// TestChainStreamServerInterceptor verifies that the stream interceptors passed
// to grpc.ChainStreamInterceptor are each invoked exactly once, in the order
// given, before the service handler runs.
func (s) TestChainStreamServerInterceptor(t *testing.T) {
	// callCounts[0], [1] and [2] count invocations of firstInt, secondInt and
	// lastInt; callCounts[3] counts invocations of the FullDuplexCall handler.
	callCounts := make([]int, 4)

	// checkCounts returns an Internal status error naming the first counter
	// whose value differs from want[i], or nil if all match. Deriving the
	// message from the expected value keeps the two from drifting apart.
	checkCounts := func(want ...int) error {
		for i, w := range want {
			if got := callCounts[i]; got != w {
				return status.Errorf(codes.Internal, "callCounts[%d] should be %d, but got=%d", i, w, got)
			}
		}
		return nil
	}

	firstInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		if err := checkCounts(0, 0, 0, 0); err != nil {
			return err
		}
		callCounts[0]++
		return handler(srv, stream)
	}

	secondInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		if err := checkCounts(1, 0, 0, 0); err != nil {
			return err
		}
		callCounts[1]++
		return handler(srv, stream)
	}

	lastInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
		if err := checkCounts(1, 1, 0, 0); err != nil {
			return err
		}
		callCounts[2]++
		return handler(srv, stream)
	}

	sopts := []grpc.ServerOption{
		grpc.ChainStreamInterceptor(firstInt, secondInt, lastInt),
	}

	ss := &stubserver.StubServer{
		FullDuplexCallF: func(stream testpb.TestService_FullDuplexCallServer) error {
			// Every interceptor must have run exactly once before the handler.
			if err := checkCounts(1, 1, 1, 0); err != nil {
				return err
			}
			callCounts[3]++
			return nil
		},
	}
	if err := ss.Start(sopts); err != nil {
		t.Fatalf("Error starting endpoint server: %v", err)
	}
	defer ss.Stop()

	ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
	defer cancel()
	stream, err := ss.Client.FullDuplexCall(ctx)
	if err != nil {
		t.Fatalf("failed to FullDuplexCall: %v", err)
	}

	// The handler returns nil without sending, so a successful RPC ends with
	// io.EOF; any status error reported by an interceptor shows up here.
	if _, err := stream.Recv(); err != io.EOF {
		t.Fatalf("stream.Recv() = %v; want io.EOF", err)
	}

	// The RPC has completed, so every stage must have run exactly once.
	if err := checkCounts(1, 1, 1, 1); err != nil {
		t.Fatalf("after the RPC completed: %v", err)
	}
}
292