1/*
2 *
3 * Copyright 2020 gRPC authors.
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 *     http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package test
20
21import (
22	"context"
23	"io"
24	"testing"
25
26	"google.golang.org/grpc"
27	"google.golang.org/grpc/codes"
28	"google.golang.org/grpc/status"
29	testpb "google.golang.org/grpc/test/grpc_testing"
30)
31
// ctxKey is the type used for context value keys in the tests in this file.
// Using a dedicated named type rather than a bare string keeps these keys from
// colliding with context keys set by other packages.
type ctxKey string
33
34func (s) TestChainUnaryServerInterceptor(t *testing.T) {
35	var (
36		firstIntKey  = ctxKey("firstIntKey")
37		secondIntKey = ctxKey("secondIntKey")
38	)
39
40	firstInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
41		if ctx.Value(firstIntKey) != nil {
42			return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", firstIntKey)
43		}
44		if ctx.Value(secondIntKey) != nil {
45			return nil, status.Errorf(codes.Internal, "first interceptor should not have %v in context", secondIntKey)
46		}
47
48		firstCtx := context.WithValue(ctx, firstIntKey, 0)
49		resp, err := handler(firstCtx, req)
50		if err != nil {
51			return nil, status.Errorf(codes.Internal, "failed to handle request at firstInt")
52		}
53
54		simpleResp, ok := resp.(*testpb.SimpleResponse)
55		if !ok {
56			return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at firstInt")
57		}
58		return &testpb.SimpleResponse{
59			Payload: &testpb.Payload{
60				Type: simpleResp.GetPayload().GetType(),
61				Body: append(simpleResp.GetPayload().GetBody(), '1'),
62			},
63		}, nil
64	}
65
66	secondInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
67		if ctx.Value(firstIntKey) == nil {
68			return nil, status.Errorf(codes.Internal, "second interceptor should have %v in context", firstIntKey)
69		}
70		if ctx.Value(secondIntKey) != nil {
71			return nil, status.Errorf(codes.Internal, "second interceptor should not have %v in context", secondIntKey)
72		}
73
74		secondCtx := context.WithValue(ctx, secondIntKey, 1)
75		resp, err := handler(secondCtx, req)
76		if err != nil {
77			return nil, status.Errorf(codes.Internal, "failed to handle request at secondInt")
78		}
79
80		simpleResp, ok := resp.(*testpb.SimpleResponse)
81		if !ok {
82			return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at secondInt")
83		}
84		return &testpb.SimpleResponse{
85			Payload: &testpb.Payload{
86				Type: simpleResp.GetPayload().GetType(),
87				Body: append(simpleResp.GetPayload().GetBody(), '2'),
88			},
89		}, nil
90	}
91
92	lastInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
93		if ctx.Value(firstIntKey) == nil {
94			return nil, status.Errorf(codes.Internal, "last interceptor should have %v in context", firstIntKey)
95		}
96		if ctx.Value(secondIntKey) == nil {
97			return nil, status.Errorf(codes.Internal, "last interceptor should not have %v in context", secondIntKey)
98		}
99
100		resp, err := handler(ctx, req)
101		if err != nil {
102			return nil, status.Errorf(codes.Internal, "failed to handle request at lastInt at lastInt")
103		}
104
105		simpleResp, ok := resp.(*testpb.SimpleResponse)
106		if !ok {
107			return nil, status.Errorf(codes.Internal, "failed to get *testpb.SimpleResponse at lastInt")
108		}
109		return &testpb.SimpleResponse{
110			Payload: &testpb.Payload{
111				Type: simpleResp.GetPayload().GetType(),
112				Body: append(simpleResp.GetPayload().GetBody(), '3'),
113			},
114		}, nil
115	}
116
117	sopts := []grpc.ServerOption{
118		grpc.ChainUnaryInterceptor(firstInt, secondInt, lastInt),
119	}
120
121	ss := &stubServer{
122		unaryCall: func(ctx context.Context, in *testpb.SimpleRequest) (*testpb.SimpleResponse, error) {
123			payload, err := newPayload(testpb.PayloadType_COMPRESSABLE, 0)
124			if err != nil {
125				return nil, status.Errorf(codes.Aborted, "failed to make payload: %v", err)
126			}
127
128			return &testpb.SimpleResponse{
129				Payload: payload,
130			}, nil
131		},
132	}
133	if err := ss.Start(sopts); err != nil {
134		t.Fatalf("Error starting endpoint server: %v", err)
135	}
136	defer ss.Stop()
137
138	resp, err := ss.client.UnaryCall(context.Background(), &testpb.SimpleRequest{})
139	if s, ok := status.FromError(err); !ok || s.Code() != codes.OK {
140		t.Fatalf("ss.client.UnaryCall(context.Background(), _) = %v, %v; want nil, <status with Code()=OK>", resp, err)
141	}
142
143	respBytes := resp.Payload.GetBody()
144	if string(respBytes) != "321" {
145		t.Fatalf("invalid response: want=%s, but got=%s", "321", resp)
146	}
147}
148
149func (s) TestChainOnBaseUnaryServerInterceptor(t *testing.T) {
150	baseIntKey := ctxKey("baseIntKey")
151
152	baseInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
153		if ctx.Value(baseIntKey) != nil {
154			return nil, status.Errorf(codes.Internal, "base interceptor should not have %v in context", baseIntKey)
155		}
156
157		baseCtx := context.WithValue(ctx, baseIntKey, 1)
158		return handler(baseCtx, req)
159	}
160
161	chainInt := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
162		if ctx.Value(baseIntKey) == nil {
163			return nil, status.Errorf(codes.Internal, "chain interceptor should have %v in context", baseIntKey)
164		}
165
166		return handler(ctx, req)
167	}
168
169	sopts := []grpc.ServerOption{
170		grpc.UnaryInterceptor(baseInt),
171		grpc.ChainUnaryInterceptor(chainInt),
172	}
173
174	ss := &stubServer{
175		emptyCall: func(ctx context.Context, in *testpb.Empty) (*testpb.Empty, error) {
176			return &testpb.Empty{}, nil
177		},
178	}
179	if err := ss.Start(sopts); err != nil {
180		t.Fatalf("Error starting endpoint server: %v", err)
181	}
182	defer ss.Stop()
183
184	resp, err := ss.client.EmptyCall(context.Background(), &testpb.Empty{})
185	if s, ok := status.FromError(err); !ok || s.Code() != codes.OK {
186		t.Fatalf("ss.client.EmptyCall(context.Background(), _) = %v, %v; want nil, <status with Code()=OK>", resp, err)
187	}
188}
189
190func (s) TestChainStreamServerInterceptor(t *testing.T) {
191	callCounts := make([]int, 4)
192
193	firstInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
194		if callCounts[0] != 0 {
195			return status.Errorf(codes.Internal, "callCounts[0] should be 0, but got=%d", callCounts[0])
196		}
197		if callCounts[1] != 0 {
198			return status.Errorf(codes.Internal, "callCounts[1] should be 0, but got=%d", callCounts[1])
199		}
200		if callCounts[2] != 0 {
201			return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2])
202		}
203		if callCounts[3] != 0 {
204			return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3])
205		}
206		callCounts[0]++
207		return handler(srv, stream)
208	}
209
210	secondInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
211		if callCounts[0] != 1 {
212			return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0])
213		}
214		if callCounts[1] != 0 {
215			return status.Errorf(codes.Internal, "callCounts[1] should be 0, but got=%d", callCounts[1])
216		}
217		if callCounts[2] != 0 {
218			return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2])
219		}
220		if callCounts[3] != 0 {
221			return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3])
222		}
223		callCounts[1]++
224		return handler(srv, stream)
225	}
226
227	lastInt := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
228		if callCounts[0] != 1 {
229			return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0])
230		}
231		if callCounts[1] != 1 {
232			return status.Errorf(codes.Internal, "callCounts[1] should be 1, but got=%d", callCounts[1])
233		}
234		if callCounts[2] != 0 {
235			return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2])
236		}
237		if callCounts[3] != 0 {
238			return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3])
239		}
240		callCounts[2]++
241		return handler(srv, stream)
242	}
243
244	sopts := []grpc.ServerOption{
245		grpc.ChainStreamInterceptor(firstInt, secondInt, lastInt),
246	}
247
248	ss := &stubServer{
249		fullDuplexCall: func(stream testpb.TestService_FullDuplexCallServer) error {
250			if callCounts[0] != 1 {
251				return status.Errorf(codes.Internal, "callCounts[0] should be 1, but got=%d", callCounts[0])
252			}
253			if callCounts[1] != 1 {
254				return status.Errorf(codes.Internal, "callCounts[1] should be 1, but got=%d", callCounts[1])
255			}
256			if callCounts[2] != 1 {
257				return status.Errorf(codes.Internal, "callCounts[2] should be 0, but got=%d", callCounts[2])
258			}
259			if callCounts[3] != 0 {
260				return status.Errorf(codes.Internal, "callCounts[3] should be 0, but got=%d", callCounts[3])
261			}
262			callCounts[3]++
263			return nil
264		},
265	}
266	if err := ss.Start(sopts); err != nil {
267		t.Fatalf("Error starting endpoint server: %v", err)
268	}
269	defer ss.Stop()
270
271	stream, err := ss.client.FullDuplexCall(context.Background())
272	if err != nil {
273		t.Fatalf("failed to FullDuplexCall: %v", err)
274	}
275
276	_, err = stream.Recv()
277	if err != io.EOF {
278		t.Fatalf("failed to recv from stream: %v", err)
279	}
280
281	if callCounts[3] != 1 {
282		t.Fatalf("callCounts[3] should be 1, but got=%d", callCounts[3])
283	}
284}
285