1// Copyright 2016 Michal Witkowski. All Rights Reserved.
2// See LICENSE for licensing terms.
3
4package grpc_middleware
5
6import (
7	"context"
8	"fmt"
9	"testing"
10
11	"github.com/stretchr/testify/require"
12	"google.golang.org/grpc"
13	"google.golang.org/grpc/metadata"
14)
15
16var (
17	someServiceName  = "SomeService.StreamMethod"
18	parentUnaryInfo  = &grpc.UnaryServerInfo{FullMethod: someServiceName}
19	parentStreamInfo = &grpc.StreamServerInfo{
20		FullMethod:     someServiceName,
21		IsServerStream: true,
22	}
23	someValue     = 1
24	parentContext = context.WithValue(context.TODO(), "parent", someValue)
25)
26
27func TestChainUnaryServer(t *testing.T) {
28	input := "input"
29	output := "output"
30
31	first := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
32		requireContextValue(t, ctx, "parent", "first interceptor must know the parent context value")
33		require.Equal(t, parentUnaryInfo, info, "first interceptor must know the someUnaryServerInfo")
34		ctx = context.WithValue(ctx, "first", 1)
35		return handler(ctx, req)
36	}
37	second := func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
38		requireContextValue(t, ctx, "parent", "second interceptor must know the parent context value")
39		requireContextValue(t, ctx, "first", "second interceptor must know the first context value")
40		require.Equal(t, parentUnaryInfo, info, "second interceptor must know the someUnaryServerInfo")
41		ctx = context.WithValue(ctx, "second", 1)
42		return handler(ctx, req)
43	}
44	handler := func(ctx context.Context, req interface{}) (interface{}, error) {
45		require.EqualValues(t, input, req, "handler must get the input")
46		requireContextValue(t, ctx, "parent", "handler must know the parent context value")
47		requireContextValue(t, ctx, "first", "handler must know the first context value")
48		requireContextValue(t, ctx, "second", "handler must know the second context value")
49		return output, nil
50	}
51
52	chain := ChainUnaryServer(first, second)
53	out, _ := chain(parentContext, input, parentUnaryInfo, handler)
54	require.EqualValues(t, output, out, "chain must return handler's output")
55}
56
57func TestChainStreamServer(t *testing.T) {
58	someService := &struct{}{}
59	recvMessage := "received"
60	sentMessage := "sent"
61	outputError := fmt.Errorf("some error")
62
63	first := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
64		requireContextValue(t, stream.Context(), "parent", "first interceptor must know the parent context value")
65		require.Equal(t, parentStreamInfo, info, "first interceptor must know the parentStreamInfo")
66		require.Equal(t, someService, srv, "first interceptor must know someService")
67		wrapped := WrapServerStream(stream)
68		wrapped.WrappedContext = context.WithValue(stream.Context(), "first", 1)
69		return handler(srv, wrapped)
70	}
71	second := func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
72		requireContextValue(t, stream.Context(), "parent", "second interceptor must know the parent context value")
73		requireContextValue(t, stream.Context(), "first", "second interceptor must know the first context value")
74		require.Equal(t, parentStreamInfo, info, "second interceptor must know the parentStreamInfo")
75		require.Equal(t, someService, srv, "second interceptor must know someService")
76		wrapped := WrapServerStream(stream)
77		wrapped.WrappedContext = context.WithValue(stream.Context(), "second", 1)
78		return handler(srv, wrapped)
79	}
80	handler := func(srv interface{}, stream grpc.ServerStream) error {
81		require.Equal(t, someService, srv, "handler must know someService")
82		requireContextValue(t, stream.Context(), "parent", "handler must know the parent context value")
83		requireContextValue(t, stream.Context(), "first", "handler must know the first context value")
84		requireContextValue(t, stream.Context(), "second", "handler must know the second context value")
85		require.NoError(t, stream.RecvMsg(recvMessage), "handler must have access to stream messages")
86		require.NoError(t, stream.SendMsg(sentMessage), "handler must be able to send stream messages")
87		return outputError
88	}
89	fakeStream := &fakeServerStream{ctx: parentContext, recvMessage: recvMessage}
90	chain := ChainStreamServer(first, second)
91	err := chain(someService, fakeStream, parentStreamInfo, handler)
92	require.Equal(t, outputError, err, "chain must return handler's error")
93	require.Equal(t, sentMessage, fakeStream.sentMessage, "handler's sent message must propagate to stream")
94}
95
96func TestChainUnaryClient(t *testing.T) {
97	ignoredMd := metadata.Pairs("foo", "bar")
98	parentOpts := []grpc.CallOption{grpc.Header(&ignoredMd)}
99	reqMessage := "request"
100	replyMessage := "reply"
101	outputError := fmt.Errorf("some error")
102
103	first := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
104		requireContextValue(t, ctx, "parent", "first must know the parent context value")
105		require.Equal(t, someServiceName, method, "first must know someService")
106		require.Len(t, opts, 1, "first should see parent CallOptions")
107		wrappedCtx := context.WithValue(ctx, "first", 1)
108		return invoker(wrappedCtx, method, req, reply, cc, opts...)
109	}
110	second := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
111		requireContextValue(t, ctx, "parent", "second must know the parent context value")
112		requireContextValue(t, ctx, "first", "second must know the first context value")
113		require.Equal(t, someServiceName, method, "second must know someService")
114		require.Len(t, opts, 1, "second should see parent CallOptions")
115		wrappedOpts := append(opts, grpc.WaitForReady(false))
116		wrappedCtx := context.WithValue(ctx, "second", 1)
117		return invoker(wrappedCtx, method, req, reply, cc, wrappedOpts...)
118	}
119	invoker := func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, opts ...grpc.CallOption) error {
120		require.Equal(t, someServiceName, method, "invoker must know someService")
121		requireContextValue(t, ctx, "parent", "invoker must know the parent context value")
122		requireContextValue(t, ctx, "first", "invoker must know the first context value")
123		requireContextValue(t, ctx, "second", "invoker must know the second context value")
124		require.Len(t, opts, 2, "invoker should see both CallOpts from second and parent")
125		return outputError
126	}
127	chain := ChainUnaryClient(first, second)
128	err := chain(parentContext, someServiceName, reqMessage, replyMessage, nil, invoker, parentOpts...)
129	require.Equal(t, outputError, err, "chain must return invokers's error")
130}
131
132func TestChainStreamClient(t *testing.T) {
133	ignoredMd := metadata.Pairs("foo", "bar")
134	parentOpts := []grpc.CallOption{grpc.Header(&ignoredMd)}
135	clientStream := &fakeClientStream{}
136	fakeStreamDesc := &grpc.StreamDesc{ClientStreams: true, ServerStreams: true, StreamName: someServiceName}
137
138	first := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
139		requireContextValue(t, ctx, "parent", "first must know the parent context value")
140		require.Equal(t, someServiceName, method, "first must know someService")
141		require.Len(t, opts, 1, "first should see parent CallOptions")
142		wrappedCtx := context.WithValue(ctx, "first", 1)
143		return streamer(wrappedCtx, desc, cc, method, opts...)
144	}
145	second := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
146		requireContextValue(t, ctx, "parent", "second must know the parent context value")
147		requireContextValue(t, ctx, "first", "second must know the first context value")
148		require.Equal(t, someServiceName, method, "second must know someService")
149		require.Len(t, opts, 1, "second should see parent CallOptions")
150		wrappedOpts := append(opts, grpc.WaitForReady(false))
151		wrappedCtx := context.WithValue(ctx, "second", 1)
152		return streamer(wrappedCtx, desc, cc, method, wrappedOpts...)
153	}
154	streamer := func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) {
155		require.Equal(t, someServiceName, method, "streamer must know someService")
156		require.Equal(t, fakeStreamDesc, desc, "streamer must see the right StreamDesc")
157
158		requireContextValue(t, ctx, "parent", "streamer must know the parent context value")
159		requireContextValue(t, ctx, "first", "streamer must know the first context value")
160		requireContextValue(t, ctx, "second", "streamer must know the second context value")
161		require.Len(t, opts, 2, "streamer should see both CallOpts from second and parent")
162		return clientStream, nil
163	}
164	chain := ChainStreamClient(first, second)
165	someStream, err := chain(parentContext, fakeStreamDesc, nil, someServiceName, streamer, parentOpts...)
166	require.NoError(t, err, "chain must not return an error")
167	require.Equal(t, clientStream, someStream, "chain must return invokers's clientstream")
168}
169
170func requireContextValue(t *testing.T, ctx context.Context, key string, msg ...interface{}) {
171	val := ctx.Value(key)
172	require.NotNil(t, val, msg...)
173	require.Equal(t, someValue, val, msg...)
174}
175