1package grpccorrelation
2
3import (
4	"context"
5	"testing"
6
7	"github.com/stretchr/testify/assert"
8	"github.com/stretchr/testify/require"
9	"gitlab.com/gitlab-org/labkit/correlation"
10	"google.golang.org/grpc"
11	"google.golang.org/grpc/metadata"
12)
13
14var (
15	_ grpc.ServerTransportStream = (*mockServerTransportStream)(nil)
16	_ grpc.ServerStream          = (*mockServerStream)(nil)
17)
18
19type tcType struct {
20	name               string
21	md                 metadata.MD
22	withoutPropagation bool
23
24	expectRandom       bool
25	expectedClientName string
26}
27
28func TestServerCorrelationInterceptors(t *testing.T) {
29	tests := []tcType{
30		{
31			name: "default",
32			md: metadata.Pairs(
33				metadataCorrelatorKey,
34				correlationID,
35				metadataClientNameKey,
36				clientName,
37			),
38			expectedClientName: clientName,
39		},
40		{
41			name: "id present but not trusted",
42			md: metadata.Pairs(
43				metadataCorrelatorKey,
44				correlationID,
45			),
46			withoutPropagation: true,
47			expectRandom:       true,
48		},
49		{
50			name: "id present, trusted but empty",
51			md: metadata.Pairs(
52				metadataCorrelatorKey,
53				"",
54			),
55			withoutPropagation: true,
56			expectRandom:       true,
57		},
58		{
59			name:               "id absent and not trusted",
60			md:                 metadata.Pairs(),
61			withoutPropagation: true,
62			expectRandom:       true,
63		},
64		{
65			name:         "id absent and trusted",
66			md:           metadata.Pairs(),
67			expectRandom: true,
68		},
69		{
70			name:         "no metadata",
71			md:           nil,
72			expectRandom: true,
73		},
74	}
75
76	t.Run("unary", func(t *testing.T) {
77		for _, tc := range tests {
78			t.Run(tc.name, testUnaryServerCorrelationInterceptor(tc, false))
79			t.Run(tc.name+" (reverse)", testUnaryServerCorrelationInterceptor(tc, true))
80		}
81	})
82	t.Run("streaming", func(t *testing.T) {
83		for _, tc := range tests {
84			t.Run(tc.name, testStreamingServerCorrelationInterceptor(tc, false))
85			t.Run(tc.name+" (reverse)", testStreamingServerCorrelationInterceptor(tc, true))
86		}
87	})
88}
89
90func testUnaryServerCorrelationInterceptor(tc tcType, reverseCorrelationID bool) func(*testing.T) {
91	return func(t *testing.T) {
92		t.Helper()
93
94		sts := &mockServerTransportStream{}
95		ctx := grpc.NewContextWithServerTransportStream(context.Background(), sts)
96		if tc.md != nil {
97			ctx = metadata.NewIncomingContext(ctx, tc.md)
98		}
99		interceptor := UnaryServerCorrelationInterceptor(constructServerOpts(tc, reverseCorrelationID)...)
100		_, err := interceptor(
101			ctx,
102			nil,
103			nil,
104			func(ctx context.Context, req interface{}) (interface{}, error) {
105				testServerCtx(ctx, t, tc, reverseCorrelationID, sts.header)
106				return nil, nil
107			},
108		)
109		require.NoError(t, err)
110	}
111}
112
113func testStreamingServerCorrelationInterceptor(tc tcType, reverseCorrelationID bool) func(*testing.T) {
114	return func(t *testing.T) {
115		t.Helper()
116
117		ctx := context.Background()
118		if tc.md != nil {
119			ctx = metadata.NewIncomingContext(ctx, tc.md)
120		}
121		ss := &mockServerStream{
122			ctx: ctx,
123		}
124		interceptor := StreamServerCorrelationInterceptor(constructServerOpts(tc, reverseCorrelationID)...)
125		err := interceptor(
126			nil,
127			ss,
128			nil,
129			func(srv interface{}, stream grpc.ServerStream) error {
130				testServerCtx(stream.Context(), t, tc, reverseCorrelationID, ss.header)
131				return nil
132			},
133		)
134		require.NoError(t, err)
135	}
136}
137
138func constructServerOpts(tc tcType, reverseCorrelationID bool) []ServerCorrelationInterceptorOption {
139	var opts []ServerCorrelationInterceptorOption
140	if tc.withoutPropagation {
141		opts = append(opts, WithoutPropagation())
142	}
143	if reverseCorrelationID {
144		opts = append(opts, WithReversePropagation())
145	}
146	return opts
147}
148
149func testServerCtx(ctx context.Context, t *testing.T, tc tcType, reverseCorrelationID bool, header metadata.MD) {
150	t.Helper()
151
152	actualID := correlation.ExtractFromContext(ctx)
153	if tc.expectRandom {
154		assert.NotEqual(t, correlationID, actualID)
155		assert.NotEmpty(t, actualID)
156	} else {
157		assert.Equal(t, correlationID, actualID)
158	}
159	vals := header.Get(metadataCorrelatorKey)
160	if reverseCorrelationID {
161		assert.Equal(t, []string{actualID}, vals)
162	} else {
163		assert.Empty(t, vals)
164	}
165	assert.Equal(t, tc.expectedClientName, correlation.ExtractClientNameFromContext(ctx))
166}
167
168type mockServerTransportStream struct {
169	header metadata.MD
170}
171
172func (s *mockServerTransportStream) Method() string {
173	panic("implement me")
174}
175
176func (s *mockServerTransportStream) SetHeader(md metadata.MD) error {
177	s.header = metadata.Join(s.header, md)
178	return nil
179}
180
181func (s *mockServerTransportStream) SendHeader(md metadata.MD) error {
182	panic("implement me")
183}
184
185func (s *mockServerTransportStream) SetTrailer(md metadata.MD) error {
186	panic("implement me")
187}
188
189type mockServerStream struct {
190	ctx    context.Context
191	header metadata.MD
192}
193
194func (s *mockServerStream) SetHeader(md metadata.MD) error {
195	s.header = metadata.Join(s.header, md)
196	return nil
197}
198
199func (s *mockServerStream) SendHeader(md metadata.MD) error {
200	panic("implement me")
201}
202
203func (s *mockServerStream) SetTrailer(md metadata.MD) {
204	panic("implement me")
205}
206
207func (s *mockServerStream) Context() context.Context {
208	return s.ctx
209}
210
211func (s *mockServerStream) SendMsg(m interface{}) error {
212	panic("implement me")
213}
214
215func (s *mockServerStream) RecvMsg(m interface{}) error {
216	panic("implement me")
217}
218