1package limithandler_test
2
3import (
4	"context"
5	"net"
6	"sync"
7	"testing"
8	"time"
9
10	"github.com/stretchr/testify/assert"
11	"github.com/stretchr/testify/require"
12	"gitlab.com/gitlab-org/gitaly/v14/internal/middleware/limithandler"
13	pb "gitlab.com/gitlab-org/gitaly/v14/internal/middleware/limithandler/testdata"
14	"gitlab.com/gitlab-org/gitaly/v14/internal/testhelper"
15	"google.golang.org/grpc"
16)
17
18func TestMain(m *testing.M) {
19	testhelper.Run(m)
20}
21
// fixedLockKey is the key function handed to limithandler.New in these tests.
// It ignores its context and always yields the same key, so every request is
// limited under one shared concurrency id.
func fixedLockKey(_ context.Context) string {
	const key = "fixed-id"
	return key
}
25
26func TestUnaryLimitHandler(t *testing.T) {
27	s := &server{blockCh: make(chan struct{})}
28
29	limithandler.SetMaxRepoConcurrency(map[string]int{"/test.limithandler.Test/Unary": 2})
30	lh := limithandler.New(fixedLockKey)
31	interceptor := lh.UnaryInterceptor()
32	srv, serverSocketPath := runServer(t, s, grpc.UnaryInterceptor(interceptor))
33	defer srv.Stop()
34
35	client, conn := newClient(t, serverSocketPath)
36	defer conn.Close()
37
38	ctx, cancel := testhelper.Context()
39	defer cancel()
40
41	wg := &sync.WaitGroup{}
42	for i := 0; i < 10; i++ {
43		wg.Add(1)
44		go func() {
45			defer wg.Done()
46
47			resp, err := client.Unary(ctx, &pb.UnaryRequest{})
48			if !assert.NoError(t, err) {
49				return
50			}
51			if !assert.NotNil(t, resp) {
52				return
53			}
54			assert.True(t, resp.Ok)
55		}()
56	}
57
58	time.Sleep(100 * time.Millisecond)
59
60	require.Equal(t, 2, s.getRequestCount())
61
62	close(s.blockCh)
63	wg.Wait()
64}
65
66func TestStreamLimitHandler(t *testing.T) {
67	testCases := []struct {
68		desc                 string
69		fullname             string
70		f                    func(*testing.T, context.Context, pb.TestClient)
71		maxConcurrency       int
72		expectedRequestCount int
73	}{
74		{
75			desc:     "Single request, multiple responses",
76			fullname: "/test.limithandler.Test/StreamOutput",
77			f: func(t *testing.T, ctx context.Context, client pb.TestClient) {
78				stream, err := client.StreamOutput(ctx, &pb.StreamOutputRequest{})
79				require.NoError(t, err)
80				require.NotNil(t, stream)
81
82				r, err := stream.Recv()
83				require.NoError(t, err)
84				require.NotNil(t, r)
85				require.True(t, r.Ok)
86			},
87			maxConcurrency:       3,
88			expectedRequestCount: 3,
89		},
90		{
91			desc:     "Multiple requests, single response",
92			fullname: "/test.limithandler.Test/StreamInput",
93			f: func(t *testing.T, ctx context.Context, client pb.TestClient) {
94				stream, err := client.StreamInput(ctx)
95				require.NoError(t, err)
96				require.NotNil(t, stream)
97
98				require.NoError(t, stream.Send(&pb.StreamInputRequest{}))
99				r, err := stream.CloseAndRecv()
100				require.NoError(t, err)
101				require.NotNil(t, r)
102				require.True(t, r.Ok)
103			},
104			maxConcurrency:       3,
105			expectedRequestCount: 3,
106		},
107		{
108			desc:     "Multiple requests, multiple responses",
109			fullname: "/test.limithandler.Test/Bidirectional",
110			f: func(t *testing.T, ctx context.Context, client pb.TestClient) {
111				stream, err := client.Bidirectional(ctx)
112				require.NoError(t, err)
113				require.NotNil(t, stream)
114
115				require.NoError(t, stream.Send(&pb.BidirectionalRequest{}))
116				require.NoError(t, stream.CloseSend())
117
118				r, err := stream.Recv()
119				require.NoError(t, err)
120				require.NotNil(t, r)
121				require.True(t, r.Ok)
122			},
123			maxConcurrency:       3,
124			expectedRequestCount: 3,
125		},
126		{
127			// Make sure that _streams_ are limited but that _requests_ on each
128			// allowed stream are not limited.
129			desc:     "Multiple requests with same id, multiple responses",
130			fullname: "/test.limithandler.Test/Bidirectional",
131			f: func(t *testing.T, ctx context.Context, client pb.TestClient) {
132				stream, err := client.Bidirectional(ctx)
133				require.NoError(t, err)
134				require.NotNil(t, stream)
135
136				// Since the concurrency id is fixed all requests have the same
137				// id, but subsequent requests in a stream, even with the same
138				// id, should bypass the concurrency limiter
139				for i := 0; i < 10; i++ {
140					require.NoError(t, stream.Send(&pb.BidirectionalRequest{}))
141				}
142				require.NoError(t, stream.CloseSend())
143
144				r, err := stream.Recv()
145				require.NoError(t, err)
146				require.NotNil(t, r)
147				require.True(t, r.Ok)
148			},
149			maxConcurrency: 3,
150			// 3 (concurrent streams allowed) * 10 (requests per stream)
151			expectedRequestCount: 30,
152		},
153		{
154			desc:     "With a max concurrency of 0",
155			fullname: "/test.limithandler.Test/StreamOutput",
156			f: func(t *testing.T, ctx context.Context, client pb.TestClient) {
157				stream, err := client.StreamOutput(ctx, &pb.StreamOutputRequest{})
158				require.NoError(t, err)
159				require.NotNil(t, stream)
160
161				r, err := stream.Recv()
162				require.NoError(t, err)
163				require.NotNil(t, r)
164				require.True(t, r.Ok)
165			},
166			maxConcurrency:       0,
167			expectedRequestCount: 10, // Allow all
168		},
169	}
170
171	for _, tc := range testCases {
172		t.Run(tc.desc, func(t *testing.T) {
173			s := &server{blockCh: make(chan struct{})}
174
175			limithandler.SetMaxRepoConcurrency(map[string]int{
176				tc.fullname: tc.maxConcurrency,
177			})
178
179			lh := limithandler.New(fixedLockKey)
180			interceptor := lh.StreamInterceptor()
181			srv, serverSocketPath := runServer(t, s, grpc.StreamInterceptor(interceptor))
182			defer srv.Stop()
183
184			client, conn := newClient(t, serverSocketPath)
185			defer conn.Close()
186
187			ctx, cancel := testhelper.Context()
188			defer cancel()
189
190			wg := &sync.WaitGroup{}
191			for i := 0; i < 10; i++ {
192				wg.Add(1)
193				go func() {
194					defer wg.Done()
195					tc.f(t, ctx, client)
196				}()
197			}
198
199			time.Sleep(100 * time.Millisecond)
200
201			require.Equal(t, tc.expectedRequestCount, s.getRequestCount())
202
203			close(s.blockCh)
204			wg.Wait()
205		})
206	}
207}
208
209func runServer(t *testing.T, s *server, opt ...grpc.ServerOption) (*grpc.Server, string) {
210	serverSocketPath := testhelper.GetTemporaryGitalySocketFileName(t)
211	grpcServer := grpc.NewServer(opt...)
212	pb.RegisterTestServer(grpcServer, s)
213
214	lis, err := net.Listen("unix", serverSocketPath)
215	require.NoError(t, err)
216
217	go grpcServer.Serve(lis)
218
219	return grpcServer, "unix://" + serverSocketPath
220}
221
222func newClient(t *testing.T, serverSocketPath string) (pb.TestClient, *grpc.ClientConn) {
223	connOpts := []grpc.DialOption{
224		grpc.WithInsecure(),
225	}
226	conn, err := grpc.Dial(serverSocketPath, connOpts...)
227	if err != nil {
228		t.Fatal(err)
229	}
230
231	return pb.NewTestClient(conn), conn
232}
233