1package middleware
2
3import (
4	"context"
5	"net"
6	"testing"
7	"time"
8
9	"github.com/stretchr/testify/assert"
10	"github.com/stretchr/testify/require"
11	"gitlab.com/gitlab-org/gitaly/v14/internal/helper"
12	"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/grpc-proxy/proxy"
13	"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/mock"
14	"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/nodes/tracker"
15	"gitlab.com/gitlab-org/gitaly/v14/internal/praefect/protoregistry"
16	"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
17	"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper/testcfg"
18	"google.golang.org/grpc"
19	"google.golang.org/protobuf/types/known/emptypb"
20)
21
22type simpleService struct {
23	mock.UnimplementedSimpleServiceServer
24}
25
26func (s *simpleService) RepoAccessorUnary(ctx context.Context, in *mock.RepoRequest) (*emptypb.Empty, error) {
27	if in.GetRepo() == nil {
28		return nil, helper.ErrInternalf("error")
29	}
30
31	return &emptypb.Empty{}, nil
32}
33
34func (s *simpleService) RepoMutatorUnary(ctx context.Context, in *mock.RepoRequest) (*emptypb.Empty, error) {
35	if in.GetRepo() == nil {
36		return nil, helper.ErrInternalf("error")
37	}
38
39	return &emptypb.Empty{}, nil
40}
41
42func TestStreamInterceptor(t *testing.T) {
43	ctx, cancel := testhelper.Context()
44	defer cancel()
45
46	window := 1 * time.Second
47	threshold := 5
48	errTracker, err := tracker.NewErrors(ctx, window, uint32(threshold), uint32(threshold))
49	require.NoError(t, err)
50	nodeName := "node-1"
51
52	internalSrv := grpc.NewServer()
53
54	internalServerSocketPath := testhelper.GetTemporaryGitalySocketFileName(t)
55	lis, err := net.Listen("unix", internalServerSocketPath)
56	require.NoError(t, err)
57
58	registry, err := protoregistry.NewFromPaths("praefect/mock/mock.proto")
59	require.NoError(t, err)
60
61	mock.RegisterSimpleServiceServer(internalSrv, &simpleService{})
62
63	go internalSrv.Serve(lis)
64	defer internalSrv.Stop()
65
66	srvOptions := []grpc.ServerOption{
67		grpc.ForceServerCodec(proxy.NewCodec()),
68		grpc.UnknownServiceHandler(proxy.TransparentHandler(func(ctx context.Context,
69			fullMethodName string,
70			peeker proxy.StreamPeeker,
71		) (*proxy.StreamParameters, error) {
72			cc, err := grpc.Dial("unix://"+internalServerSocketPath,
73				grpc.WithDefaultCallOptions(grpc.ForceCodec(proxy.NewCodec())),
74				grpc.WithInsecure(),
75				grpc.WithStreamInterceptor(StreamErrorHandler(registry, errTracker, nodeName)),
76			)
77			require.NoError(t, err)
78			f, err := peeker.Peek()
79			require.NoError(t, err)
80			return proxy.NewStreamParameters(proxy.Destination{Conn: cc, Ctx: ctx, Msg: f}, nil, func() error { return nil }, nil), nil
81		})),
82	}
83
84	praefectSocket := testhelper.GetTemporaryGitalySocketFileName(t)
85	praefectLis, err := net.Listen("unix", praefectSocket)
86	require.NoError(t, err)
87
88	praefectSrv := grpc.NewServer(srvOptions...)
89	defer praefectSrv.Stop()
90	go praefectSrv.Serve(praefectLis)
91
92	praefectCC, err := grpc.Dial("unix://"+praefectSocket, grpc.WithInsecure())
93	require.NoError(t, err)
94
95	simpleClient := mock.NewSimpleServiceClient(praefectCC)
96
97	_, repo, _ := testcfg.BuildWithRepo(t)
98
99	for i := 0; i < threshold; i++ {
100		_, err = simpleClient.RepoAccessorUnary(ctx, &mock.RepoRequest{
101			Repo: repo,
102		})
103		require.NoError(t, err)
104		_, err = simpleClient.RepoMutatorUnary(ctx, &mock.RepoRequest{
105			Repo: repo,
106		})
107		require.NoError(t, err)
108	}
109
110	assert.False(t, errTracker.WriteThresholdReached(nodeName))
111	assert.False(t, errTracker.ReadThresholdReached(nodeName))
112
113	for i := 0; i < threshold; i++ {
114		_, err = simpleClient.RepoAccessorUnary(ctx, &mock.RepoRequest{
115			Repo: nil,
116		})
117		require.Error(t, err)
118		_, err = simpleClient.RepoMutatorUnary(ctx, &mock.RepoRequest{
119			Repo: nil,
120		})
121		require.Error(t, err)
122	}
123
124	assert.True(t, errTracker.WriteThresholdReached(nodeName))
125	assert.True(t, errTracker.ReadThresholdReached(nodeName))
126
127	time.Sleep(window)
128
129	for i := 0; i < threshold; i++ {
130		_, err = simpleClient.RepoAccessorUnary(ctx, &mock.RepoRequest{
131			Repo: repo,
132		})
133		require.NoError(t, err)
134		_, err = simpleClient.RepoMutatorUnary(ctx, &mock.RepoRequest{
135			Repo: repo,
136		})
137		require.NoError(t, err)
138	}
139
140	assert.False(t, errTracker.WriteThresholdReached(nodeName))
141	assert.False(t, errTracker.ReadThresholdReached(nodeName))
142}
143