1package metadatahandler
2
3import (
4	"context"
5	"fmt"
6	"testing"
7	"time"
8
9	grpcmwtags "github.com/grpc-ecosystem/go-grpc-middleware/tags"
10	"github.com/stretchr/testify/require"
11	"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
12	"gitlab.com/gitlab-org/labkit/correlation"
13	"google.golang.org/grpc/metadata"
14)
15
// Fixture values: TestGRPCTags places them into the test context and
// verifyHandler expects to read them back.
const correlationID, clientName = "CORRELATION_ID", "CLIENT_NAME"
20
21func TestAddMetadataTags(t *testing.T) {
22	baseContext, cancel := testhelper.Context()
23	defer cancel()
24
25	testCases := []struct {
26		desc             string
27		metadata         metadata.MD
28		deadline         bool
29		expectedMetatags metadataTags
30	}{
31		{
32			desc:     "empty metadata",
33			metadata: metadata.Pairs(),
34			deadline: false,
35			expectedMetatags: metadataTags{
36				clientName:   unknownValue,
37				callSite:     unknownValue,
38				authVersion:  unknownValue,
39				deadlineType: "none",
40			},
41		},
42		{
43			desc:     "context containing metadata",
44			metadata: metadata.Pairs("call_site", "testsite"),
45			deadline: false,
46			expectedMetatags: metadataTags{
47				clientName:   unknownValue,
48				callSite:     "testsite",
49				authVersion:  unknownValue,
50				deadlineType: "none",
51			},
52		},
53		{
54			desc:     "context containing metadata and a deadline",
55			metadata: metadata.Pairs("call_site", "testsite"),
56			deadline: true,
57			expectedMetatags: metadataTags{
58				clientName:   unknownValue,
59				callSite:     "testsite",
60				authVersion:  unknownValue,
61				deadlineType: unknownValue,
62			},
63		},
64		{
65			desc:     "context containing metadata and a deadline type",
66			metadata: metadata.Pairs("deadline_type", "regular"),
67			deadline: true,
68			expectedMetatags: metadataTags{
69				clientName:   unknownValue,
70				callSite:     unknownValue,
71				authVersion:  unknownValue,
72				deadlineType: "regular",
73			},
74		},
75		{
76			desc:     "a context without deadline but with deadline type",
77			metadata: metadata.Pairs("deadline_type", "regular"),
78			deadline: false,
79			expectedMetatags: metadataTags{
80				clientName:   unknownValue,
81				callSite:     unknownValue,
82				authVersion:  unknownValue,
83				deadlineType: "none",
84			},
85		},
86		{
87			desc:     "with a context containing metadata",
88			metadata: metadata.Pairs("deadline_type", "regular", "client_name", "rails"),
89			deadline: true,
90			expectedMetatags: metadataTags{
91				clientName:   "rails",
92				callSite:     unknownValue,
93				authVersion:  unknownValue,
94				deadlineType: "regular",
95			},
96		},
97	}
98
99	for _, testCase := range testCases {
100		t.Run(testCase.desc, func(t *testing.T) {
101			ctx := metadata.NewIncomingContext(baseContext, testCase.metadata)
102			if testCase.deadline {
103				ctx, cancel = context.WithDeadline(ctx, time.Now().Add(50*time.Millisecond))
104				defer cancel()
105			}
106			require.Equal(t, testCase.expectedMetatags, addMetadataTags(ctx, "unary"))
107		})
108	}
109}
110
111func verifyHandler(ctx context.Context, req interface{}) (interface{}, error) {
112	require, ok := req.(*require.Assertions)
113	if !ok {
114		return nil, fmt.Errorf("unexpected type conversion failure")
115	}
116	metaTags := addMetadataTags(ctx, "unary")
117	require.Equal(clientName, metaTags.clientName)
118
119	tags := grpcmwtags.Extract(ctx)
120	require.True(tags.Has(CorrelationIDKey))
121	require.True(tags.Has(ClientNameKey))
122	values := tags.Values()
123	require.Equal(correlationID, values[CorrelationIDKey])
124	require.Equal(clientName, values[ClientNameKey])
125
126	return nil, nil
127}
128
129func TestGRPCTags(t *testing.T) {
130	require := require.New(t)
131
132	ctx := metadata.NewIncomingContext(
133		correlation.ContextWithCorrelation(
134			correlation.ContextWithClientName(
135				context.Background(),
136				clientName,
137			),
138			correlationID,
139		),
140		metadata.Pairs(),
141	)
142
143	interceptor := grpcmwtags.UnaryServerInterceptor()
144
145	_, err := interceptor(ctx, require, nil, verifyHandler)
146	require.NoError(err)
147}
148
149func Test_extractServiceName(t *testing.T) {
150	tests := []struct {
151		name           string
152		fullMethodName string
153		want           string
154	}{
155		{
156			name:           "blank",
157			fullMethodName: "",
158			want:           unknownValue,
159		}, {
160			name:           "normal",
161			fullMethodName: "/gitaly.OperationService/method",
162			want:           "gitaly.OperationService",
163		}, {
164			name:           "malformed",
165			fullMethodName: "//method",
166			want:           "",
167		},
168	}
169	for _, tt := range tests {
170		t.Run(tt.name, func(t *testing.T) {
171			if got := extractServiceName(tt.fullMethodName); got != tt.want {
172				t.Errorf("extractServiceName() = %v, want %v", got, tt.want)
173			}
174		})
175	}
176}
177