1package ratelimit
2
3import (
4	"context"
5	"errors"
6	"testing"
7
8	"google.golang.org/grpc"
9
10	"github.com/stretchr/testify/assert"
11)
12
// errMsgFake is the error text returned by every stub handler in this file.
// Seeing it in a test result means the interceptor invoked the handler and
// returned the handler's error.
const errMsgFake = "fake error"
14
// mockPassLimiter is a limiter stub that never limits: its Limit method
// always reports false, so every call should be allowed through.
type mockPassLimiter struct{}

// Limit unconditionally reports false (do not limit).
func (*mockPassLimiter) Limit() bool { return false }
20
// TestUnaryServerInterceptor_RateLimitPass checks that when the limiter does
// not limit, the unary interceptor runs the handler. The test expects the
// handler's nil response and errMsgFake error to come back unchanged.
func TestUnaryServerInterceptor_RateLimitPass(t *testing.T) {
	interceptor := UnaryServerInterceptor(&mockPassLimiter{})
	handlerCalled := false
	handler := func(_ context.Context, _ interface{}) (interface{}, error) {
		handlerCalled = true
		return nil, errors.New(errMsgFake)
	}
	info := &grpc.UnaryServerInfo{
		FullMethod: "FakeMethod",
	}
	// Pass a real (empty) context instead of nil: a nil Context is never valid,
	// and any interceptor or handler that calls a method on it would panic.
	resp, err := interceptor(context.Background(), nil, info, handler)
	assert.True(t, handlerCalled, "handler should run when the limiter allows the call")
	assert.Nil(t, resp)
	assert.EqualError(t, err, errMsgFake)
}
33
// mockFailLimiter is a limiter stub that always limits: its Limit method
// always reports true, so every call should be rejected.
type mockFailLimiter struct{}

// Limit unconditionally reports true (limit this call).
func (*mockFailLimiter) Limit() bool { return true }
39
// TestUnaryServerInterceptor_RateLimitFail checks that when the limiter limits,
// the unary interceptor does not run the handler. It expects a nil response and
// a ResourceExhausted status error that names the rejected method.
func TestUnaryServerInterceptor_RateLimitFail(t *testing.T) {
	interceptor := UnaryServerInterceptor(&mockFailLimiter{})
	handlerCalled := false
	handler := func(_ context.Context, _ interface{}) (interface{}, error) {
		handlerCalled = true
		return nil, errors.New(errMsgFake)
	}
	info := &grpc.UnaryServerInfo{
		FullMethod: "FakeMethod",
	}
	// Pass a real (empty) context instead of nil: a nil Context is never valid.
	resp, err := interceptor(context.Background(), nil, info, handler)
	// Assert the short-circuit directly rather than inferring it from the error text.
	assert.False(t, handlerCalled, "handler must not run when the limiter rejects the call")
	assert.Nil(t, resp)
	assert.EqualError(t, err, "rpc error: code = ResourceExhausted desc = FakeMethod is rejected by grpc_ratelimit middleware, please retry later.")
}
52
// TestStreamServerInterceptor_RateLimitPass verifies that a stream call the
// limiter lets through reaches the handler. The handler's error is expected
// back verbatim.
func TestStreamServerInterceptor_RateLimitPass(t *testing.T) {
	intercept := StreamServerInterceptor(&mockPassLimiter{})
	streamInfo := &grpc.StreamServerInfo{FullMethod: "FakeMethod"}
	failingHandler := func(srv interface{}, stream grpc.ServerStream) error {
		return errors.New(errMsgFake)
	}

	got := intercept(nil, nil, streamInfo, failingHandler)
	assert.EqualError(t, got, errMsgFake)
}
64
// TestStreamServerInterceptor_RateLimitFail checks that when the limiter limits,
// the stream interceptor does not run the handler. It expects a
// ResourceExhausted status error that names the rejected method.
func TestStreamServerInterceptor_RateLimitFail(t *testing.T) {
	interceptor := StreamServerInterceptor(&mockFailLimiter{})
	handlerCalled := false
	handler := func(_ interface{}, _ grpc.ServerStream) error {
		handlerCalled = true
		return errors.New(errMsgFake)
	}
	info := &grpc.StreamServerInfo{
		FullMethod: "FakeMethod",
	}
	// nil srv and stream are enough here: the stub handler never reads them.
	err := interceptor(nil, nil, info, handler)
	// Assert the short-circuit directly rather than inferring it from the error text.
	assert.False(t, handlerCalled, "handler must not run when the limiter rejects the call")
	assert.EqualError(t, err, "rpc error: code = ResourceExhausted desc = FakeMethod is rejected by grpc_ratelimit middleware, please retry later.")
}
76