1package limithandler_test 2 3import ( 4 "context" 5 "net" 6 "sync" 7 "testing" 8 "time" 9 10 "github.com/stretchr/testify/assert" 11 "github.com/stretchr/testify/require" 12 "gitlab.com/gitlab-org/gitaly/v14/internal/middleware/limithandler" 13 pb "gitlab.com/gitlab-org/gitaly/v14/internal/middleware/limithandler/testdata" 14 "gitlab.com/gitlab-org/gitaly/v14/internal/testhelper" 15 "google.golang.org/grpc" 16) 17 18func TestMain(m *testing.M) { 19 testhelper.Run(m) 20} 21 22func fixedLockKey(ctx context.Context) string { 23 return "fixed-id" 24} 25 26func TestUnaryLimitHandler(t *testing.T) { 27 s := &server{blockCh: make(chan struct{})} 28 29 limithandler.SetMaxRepoConcurrency(map[string]int{"/test.limithandler.Test/Unary": 2}) 30 lh := limithandler.New(fixedLockKey) 31 interceptor := lh.UnaryInterceptor() 32 srv, serverSocketPath := runServer(t, s, grpc.UnaryInterceptor(interceptor)) 33 defer srv.Stop() 34 35 client, conn := newClient(t, serverSocketPath) 36 defer conn.Close() 37 38 ctx, cancel := testhelper.Context() 39 defer cancel() 40 41 wg := &sync.WaitGroup{} 42 for i := 0; i < 10; i++ { 43 wg.Add(1) 44 go func() { 45 defer wg.Done() 46 47 resp, err := client.Unary(ctx, &pb.UnaryRequest{}) 48 if !assert.NoError(t, err) { 49 return 50 } 51 if !assert.NotNil(t, resp) { 52 return 53 } 54 assert.True(t, resp.Ok) 55 }() 56 } 57 58 time.Sleep(100 * time.Millisecond) 59 60 require.Equal(t, 2, s.getRequestCount()) 61 62 close(s.blockCh) 63 wg.Wait() 64} 65 66func TestStreamLimitHandler(t *testing.T) { 67 testCases := []struct { 68 desc string 69 fullname string 70 f func(*testing.T, context.Context, pb.TestClient) 71 maxConcurrency int 72 expectedRequestCount int 73 }{ 74 { 75 desc: "Single request, multiple responses", 76 fullname: "/test.limithandler.Test/StreamOutput", 77 f: func(t *testing.T, ctx context.Context, client pb.TestClient) { 78 stream, err := client.StreamOutput(ctx, &pb.StreamOutputRequest{}) 79 require.NoError(t, err) 80 require.NotNil(t, stream) 81 82 r, err := stream.Recv() 83 require.NoError(t, err) 84 require.NotNil(t, r) 85 require.True(t, r.Ok) 86 }, 87 maxConcurrency: 3, 88 expectedRequestCount: 3, 89 }, 90 { 91 desc: "Multiple requests, single response", 92 fullname: "/test.limithandler.Test/StreamInput", 93 f: func(t *testing.T, ctx context.Context, client pb.TestClient) { 94 stream, err := client.StreamInput(ctx) 95 require.NoError(t, err) 96 require.NotNil(t, stream) 97 98 require.NoError(t, stream.Send(&pb.StreamInputRequest{})) 99 r, err := stream.CloseAndRecv() 100 require.NoError(t, err) 101 require.NotNil(t, r) 102 require.True(t, r.Ok) 103 }, 104 maxConcurrency: 3, 105 expectedRequestCount: 3, 106 }, 107 { 108 desc: "Multiple requests, multiple responses", 109 fullname: "/test.limithandler.Test/Bidirectional", 110 f: func(t *testing.T, ctx context.Context, client pb.TestClient) { 111 stream, err := client.Bidirectional(ctx) 112 require.NoError(t, err) 113 require.NotNil(t, stream) 114 115 require.NoError(t, stream.Send(&pb.BidirectionalRequest{})) 116 require.NoError(t, stream.CloseSend()) 117 118 r, err := stream.Recv() 119 require.NoError(t, err) 120 require.NotNil(t, r) 121 require.True(t, r.Ok) 122 }, 123 maxConcurrency: 3, 124 expectedRequestCount: 3, 125 }, 126 { 127 // Make sure that _streams_ are limited but that _requests_ on each 128 // allowed stream are not limited. 129 desc: "Multiple requests with same id, multiple responses", 130 fullname: "/test.limithandler.Test/Bidirectional", 131 f: func(t *testing.T, ctx context.Context, client pb.TestClient) { 132 stream, err := client.Bidirectional(ctx) 133 require.NoError(t, err) 134 require.NotNil(t, stream) 135 136 // Since the concurrency id is fixed all requests have the same 137 // id, but subsequent requests in a stream, even with the same 138 // id, should bypass the concurrency limiter 139 for i := 0; i < 10; i++ { 140 require.NoError(t, stream.Send(&pb.BidirectionalRequest{})) 141 } 142 require.NoError(t, stream.CloseSend()) 143 144 r, err := stream.Recv() 145 require.NoError(t, err) 146 require.NotNil(t, r) 147 require.True(t, r.Ok) 148 }, 149 maxConcurrency: 3, 150 // 3 (concurrent streams allowed) * 10 (requests per stream) 151 expectedRequestCount: 30, 152 }, 153 { 154 desc: "With a max concurrency of 0", 155 fullname: "/test.limithandler.Test/StreamOutput", 156 f: func(t *testing.T, ctx context.Context, client pb.TestClient) { 157 stream, err := client.StreamOutput(ctx, &pb.StreamOutputRequest{}) 158 require.NoError(t, err) 159 require.NotNil(t, stream) 160 161 r, err := stream.Recv() 162 require.NoError(t, err) 163 require.NotNil(t, r) 164 require.True(t, r.Ok) 165 }, 166 maxConcurrency: 0, 167 expectedRequestCount: 10, // Allow all 168 }, 169 } 170 171 for _, tc := range testCases { 172 t.Run(tc.desc, func(t *testing.T) { 173 s := &server{blockCh: make(chan struct{})} 174 175 limithandler.SetMaxRepoConcurrency(map[string]int{ 176 tc.fullname: tc.maxConcurrency, 177 }) 178 179 lh := limithandler.New(fixedLockKey) 180 interceptor := lh.StreamInterceptor() 181 srv, serverSocketPath := runServer(t, s, grpc.StreamInterceptor(interceptor)) 182 defer srv.Stop() 183 184 client, conn := newClient(t, serverSocketPath) 185 defer conn.Close() 186 187 ctx, cancel := testhelper.Context() 188 defer cancel() 189 190 wg := &sync.WaitGroup{} 191 for i := 0; i < 10; i++ { 192 wg.Add(1) 193 go func() { 194 defer wg.Done() 195 tc.f(t, ctx, client) 196 }() 197 } 198 199 time.Sleep(100 * time.Millisecond) 200 201 require.Equal(t, tc.expectedRequestCount, s.getRequestCount()) 202 203 close(s.blockCh) 204 wg.Wait() 205 }) 206 } 207} 208 209func runServer(t *testing.T, s *server, opt ...grpc.ServerOption) (*grpc.Server, string) { 210 serverSocketPath := testhelper.GetTemporaryGitalySocketFileName(t) 211 grpcServer := grpc.NewServer(opt...) 212 pb.RegisterTestServer(grpcServer, s) 213 214 lis, err := net.Listen("unix", serverSocketPath) 215 require.NoError(t, err) 216 217 go grpcServer.Serve(lis) 218 219 return grpcServer, "unix://" + serverSocketPath 220} 221 222func newClient(t *testing.T, serverSocketPath string) (pb.TestClient, *grpc.ClientConn) { 223 connOpts := []grpc.DialOption{ 224 grpc.WithInsecure(), 225 } 226 conn, err := grpc.Dial(serverSocketPath, connOpts...) 227 if err != nil { 228 t.Fatal(err) 229 } 230 231 return pb.NewTestClient(conn), conn 232} 233