1package middleware 2 3import ( 4 "context" 5 "net" 6 "testing" 7 "time" 8 9 "github.com/stretchr/testify/assert" 10 "github.com/stretchr/testify/require" 11 "gitlab.com/gitlab-org/gitaly/v14/internal/helper" 12 "gitlab.com/gitlab-org/gitaly/v14/internal/praefect/grpc-proxy/proxy" 13 "gitlab.com/gitlab-org/gitaly/v14/internal/praefect/mock" 14 "gitlab.com/gitlab-org/gitaly/v14/internal/praefect/nodes/tracker" 15 "gitlab.com/gitlab-org/gitaly/v14/internal/praefect/protoregistry" 16 "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" 17 "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testcfg" 18 "google.golang.org/grpc" 19 "google.golang.org/protobuf/types/known/emptypb" 20) 21 22type simpleService struct { 23 mock.UnimplementedSimpleServiceServer 24} 25 26func (s *simpleService) RepoAccessorUnary(ctx context.Context, in *mock.RepoRequest) (*emptypb.Empty, error) { 27 if in.GetRepo() == nil { 28 return nil, helper.ErrInternalf("error") 29 } 30 31 return &emptypb.Empty{}, nil 32} 33 34func (s *simpleService) RepoMutatorUnary(ctx context.Context, in *mock.RepoRequest) (*emptypb.Empty, error) { 35 if in.GetRepo() == nil { 36 return nil, helper.ErrInternalf("error") 37 } 38 39 return &emptypb.Empty{}, nil 40} 41 42func TestStreamInterceptor(t *testing.T) { 43 ctx, cancel := testhelper.Context() 44 defer cancel() 45 46 window := 1 * time.Second 47 threshold := 5 48 errTracker, err := tracker.NewErrors(ctx, window, uint32(threshold), uint32(threshold)) 49 require.NoError(t, err) 50 nodeName := "node-1" 51 52 internalSrv := grpc.NewServer() 53 54 internalServerSocketPath := testhelper.GetTemporaryGitalySocketFileName(t) 55 lis, err := net.Listen("unix", internalServerSocketPath) 56 require.NoError(t, err) 57 58 registry, err := protoregistry.NewFromPaths("praefect/mock/mock.proto") 59 require.NoError(t, err) 60 61 mock.RegisterSimpleServiceServer(internalSrv, &simpleService{}) 62 63 go internalSrv.Serve(lis) 64 defer internalSrv.Stop() 65 66 srvOptions := []grpc.ServerOption{ 67 grpc.ForceServerCodec(proxy.NewCodec()), 68 grpc.UnknownServiceHandler(proxy.TransparentHandler(func(ctx context.Context, 69 fullMethodName string, 70 peeker proxy.StreamPeeker, 71 ) (*proxy.StreamParameters, error) { 72 cc, err := grpc.Dial("unix://"+internalServerSocketPath, 73 grpc.WithDefaultCallOptions(grpc.ForceCodec(proxy.NewCodec())), 74 grpc.WithInsecure(), 75 grpc.WithStreamInterceptor(StreamErrorHandler(registry, errTracker, nodeName)), 76 ) 77 require.NoError(t, err) 78 f, err := peeker.Peek() 79 require.NoError(t, err) 80 return proxy.NewStreamParameters(proxy.Destination{Conn: cc, Ctx: ctx, Msg: f}, nil, func() error { return nil }, nil), nil 81 })), 82 } 83 84 praefectSocket := testhelper.GetTemporaryGitalySocketFileName(t) 85 praefectLis, err := net.Listen("unix", praefectSocket) 86 require.NoError(t, err) 87 88 praefectSrv := grpc.NewServer(srvOptions...) 89 defer praefectSrv.Stop() 90 go praefectSrv.Serve(praefectLis) 91 92 praefectCC, err := grpc.Dial("unix://"+praefectSocket, grpc.WithInsecure()) 93 require.NoError(t, err) 94 95 simpleClient := mock.NewSimpleServiceClient(praefectCC) 96 97 _, repo, _ := testcfg.BuildWithRepo(t) 98 99 for i := 0; i < threshold; i++ { 100 _, err = simpleClient.RepoAccessorUnary(ctx, &mock.RepoRequest{ 101 Repo: repo, 102 }) 103 require.NoError(t, err) 104 _, err = simpleClient.RepoMutatorUnary(ctx, &mock.RepoRequest{ 105 Repo: repo, 106 }) 107 require.NoError(t, err) 108 } 109 110 assert.False(t, errTracker.WriteThresholdReached(nodeName)) 111 assert.False(t, errTracker.ReadThresholdReached(nodeName)) 112 113 for i := 0; i < threshold; i++ { 114 _, err = simpleClient.RepoAccessorUnary(ctx, &mock.RepoRequest{ 115 Repo: nil, 116 }) 117 require.Error(t, err) 118 _, err = simpleClient.RepoMutatorUnary(ctx, &mock.RepoRequest{ 119 Repo: nil, 120 }) 121 require.Error(t, err) 122 } 123 124 assert.True(t, errTracker.WriteThresholdReached(nodeName)) 125 assert.True(t, errTracker.ReadThresholdReached(nodeName)) 126 127 time.Sleep(window) 128 129 for i := 0; i < threshold; i++ { 130 _, err = simpleClient.RepoAccessorUnary(ctx, &mock.RepoRequest{ 131 Repo: repo, 132 }) 133 require.NoError(t, err) 134 _, err = simpleClient.RepoMutatorUnary(ctx, &mock.RepoRequest{ 135 Repo: repo, 136 }) 137 require.NoError(t, err) 138 } 139 140 assert.False(t, errTracker.WriteThresholdReached(nodeName)) 141 assert.False(t, errTracker.ReadThresholdReached(nodeName)) 142} 143