1package metadatahandler 2 3import ( 4 "context" 5 "fmt" 6 "testing" 7 "time" 8 9 grpcmwtags "github.com/grpc-ecosystem/go-grpc-middleware/tags" 10 "github.com/stretchr/testify/require" 11 "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" 12 "gitlab.com/gitlab-org/labkit/correlation" 13 "google.golang.org/grpc/metadata" 14) 15 16const ( 17 correlationID = "CORRELATION_ID" 18 clientName = "CLIENT_NAME" 19) 20 21func TestAddMetadataTags(t *testing.T) { 22 baseContext, cancel := testhelper.Context() 23 defer cancel() 24 25 testCases := []struct { 26 desc string 27 metadata metadata.MD 28 deadline bool 29 expectedMetatags metadataTags 30 }{ 31 { 32 desc: "empty metadata", 33 metadata: metadata.Pairs(), 34 deadline: false, 35 expectedMetatags: metadataTags{ 36 clientName: unknownValue, 37 callSite: unknownValue, 38 authVersion: unknownValue, 39 deadlineType: "none", 40 }, 41 }, 42 { 43 desc: "context containing metadata", 44 metadata: metadata.Pairs("call_site", "testsite"), 45 deadline: false, 46 expectedMetatags: metadataTags{ 47 clientName: unknownValue, 48 callSite: "testsite", 49 authVersion: unknownValue, 50 deadlineType: "none", 51 }, 52 }, 53 { 54 desc: "context containing metadata and a deadline", 55 metadata: metadata.Pairs("call_site", "testsite"), 56 deadline: true, 57 expectedMetatags: metadataTags{ 58 clientName: unknownValue, 59 callSite: "testsite", 60 authVersion: unknownValue, 61 deadlineType: unknownValue, 62 }, 63 }, 64 { 65 desc: "context containing metadata and a deadline type", 66 metadata: metadata.Pairs("deadline_type", "regular"), 67 deadline: true, 68 expectedMetatags: metadataTags{ 69 clientName: unknownValue, 70 callSite: unknownValue, 71 authVersion: unknownValue, 72 deadlineType: "regular", 73 }, 74 }, 75 { 76 desc: "a context without deadline but with deadline type", 77 metadata: metadata.Pairs("deadline_type", "regular"), 78 deadline: false, 79 expectedMetatags: metadataTags{ 80 clientName: unknownValue, 81 callSite: unknownValue, 82 authVersion: unknownValue, 83 deadlineType: "none", 84 }, 85 }, 86 { 87 desc: "with a context containing metadata", 88 metadata: metadata.Pairs("deadline_type", "regular", "client_name", "rails"), 89 deadline: true, 90 expectedMetatags: metadataTags{ 91 clientName: "rails", 92 callSite: unknownValue, 93 authVersion: unknownValue, 94 deadlineType: "regular", 95 }, 96 }, 97 } 98 99 for _, testCase := range testCases { 100 t.Run(testCase.desc, func(t *testing.T) { 101 ctx := metadata.NewIncomingContext(baseContext, testCase.metadata) 102 if testCase.deadline { 103 ctx, cancel = context.WithDeadline(ctx, time.Now().Add(50*time.Millisecond)) 104 defer cancel() 105 } 106 require.Equal(t, testCase.expectedMetatags, addMetadataTags(ctx, "unary")) 107 }) 108 } 109} 110 111func verifyHandler(ctx context.Context, req interface{}) (interface{}, error) { 112 require, ok := req.(*require.Assertions) 113 if !ok { 114 return nil, fmt.Errorf("unexpected type conversion failure") 115 } 116 metaTags := addMetadataTags(ctx, "unary") 117 require.Equal(clientName, metaTags.clientName) 118 119 tags := grpcmwtags.Extract(ctx) 120 require.True(tags.Has(CorrelationIDKey)) 121 require.True(tags.Has(ClientNameKey)) 122 values := tags.Values() 123 require.Equal(correlationID, values[CorrelationIDKey]) 124 require.Equal(clientName, values[ClientNameKey]) 125 126 return nil, nil 127} 128 129func TestGRPCTags(t *testing.T) { 130 require := require.New(t) 131 132 ctx := metadata.NewIncomingContext( 133 correlation.ContextWithCorrelation( 134 correlation.ContextWithClientName( 135 context.Background(), 136 clientName, 137 ), 138 correlationID, 139 ), 140 metadata.Pairs(), 141 ) 142 143 interceptor := grpcmwtags.UnaryServerInterceptor() 144 145 _, err := interceptor(ctx, require, nil, verifyHandler) 146 require.NoError(err) 147} 148 149func Test_extractServiceName(t *testing.T) { 150 tests := []struct { 151 name string 152 fullMethodName string 153 want string 154 }{ 155 { 156 name: "blank", 157 fullMethodName: "", 158 want: unknownValue, 159 }, { 160 name: "normal", 161 fullMethodName: "/gitaly.OperationService/method", 162 want: "gitaly.OperationService", 163 }, { 164 name: "malformed", 165 fullMethodName: "//method", 166 want: "", 167 }, 168 } 169 for _, tt := range tests { 170 t.Run(tt.name, func(t *testing.T) { 171 if got := extractServiceName(tt.fullMethodName); got != tt.want { 172 t.Errorf("extractServiceName() = %v, want %v", got, tt.want) 173 } 174 }) 175 } 176} 177