package limithandler

import (
	"context"
	"strconv"
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

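// counter is a test gauge that tracks the current and maximum observed
// concurrency and counts how often each of the limiter's monitor
// callbacks is invoked.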
type counter struct {
	sync.Mutex
	max      int
	current  int
	queued   int
	dequeued int
	enter    int
	exit     int
}

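// up increments the current concurrency level and updates the observed maximum.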
func (c *counter) up() {
	c.Lock()
	defer c.Unlock()

	c.current++
	if c.current > c.max {
		c.max = c.current
	}
}

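// down decrements the current concurrency level.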
func (c *counter) down() {
	c.Lock()
	defer c.Unlock()

	c.current--
}

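// currentVal returns the current concurrency level while holding the lock.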
func (c *counter) currentVal() int {
	c.Lock()
	defer c.Unlock()
	return c.current
}

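// Queued, Dequeued, Enter, and Exit satisfy the monitor interface that
// NewLimiter consumes; each one simply counts its invocations.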
func (c *counter) Queued(ctx context.Context) {
	c.Lock()
	defer c.Unlock()
	c.queued++
}

func (c *counter) Dequeued(ctx context.Context) {
	c.Lock()
	defer c.Unlock()
	c.dequeued++
}

func (c *counter) Enter(ctx context.Context, acquireTime time.Duration) {
	c.Lock()
	defer c.Unlock()
	c.enter++
}

func (c *counter) Exit(ctx context.Context) {
	c.Lock()
	defer c.Unlock()
	c.exit++
}

func TestLimiter(t *testing.T) {
	tests := []struct {
		name             string
		concurrency      int
		maxConcurrency   int
		iterations       int
		buckets          int
		wantMonitorCalls bool
	}{
		{
			name:             "single",
			concurrency:      1,
			maxConcurrency:   1,
			iterations:       1,
			buckets:          1,
			wantMonitorCalls: true,
		},
		{
			name:             "two-at-a-time",
			concurrency:      100,
			maxConcurrency:   2,
			iterations:       10,
			buckets:          1,
			wantMonitorCalls: true,
		},
		{
			name:             "two-by-two",
			concurrency:      100,
			maxConcurrency:   2,
			iterations:       4,
			buckets:          2,
			wantMonitorCalls: true,
		},
		{
			name:             "no-limit",
			concurrency:      10,
			maxConcurrency:   0,
			iterations:       200,
			buckets:          1,
			wantMonitorCalls: false,
		},
		{
			name:           "wide-spread",
			concurrency:    1000,
			maxConcurrency: 2,
			// Use a high goroutine and iteration count here to prevent
			// flakiness in CI. If there is too little contention, the first
			// goroutines to enter the critical section will be gone before
			// we hit the intended maximum concurrency.
			iterations:       40,
			buckets:          50,
			wantMonitorCalls: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
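			// When no limit is configured, every goroutine may run at once,
			// so the expected peak equals the total goroutine count.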
			expectedGaugeMax := tt.maxConcurrency * tt.buckets
			if tt.maxConcurrency <= 0 {
				expectedGaugeMax = tt.concurrency
			}

			gauge := &counter{}

			limiter := NewLimiter(tt.maxConcurrency, gauge)
			wg := sync.WaitGroup{}
			wg.Add(tt.concurrency)

			full := sync.NewCond(&sync.Mutex{})

			// primePump waits for the gauge to reach the minimum expected
			// maximum concurrency so that the limiter is "warmed up" before
			// the test proceeds.
			primePump := func() {
				full.L.Lock()
				defer full.L.Unlock()

				gauge.up()

				if gauge.max >= expectedGaugeMax {
					full.Broadcast()
					return
				}

				full.Wait() // wait until full is broadcast
			}

			// We know of an edge case that can lead to the rate limiter
			// occasionally letting one or two extra goroutines run
			// concurrently.
			for c := 0; c < tt.concurrency; c++ {
				go func(counter int) {
					defer wg.Done()

					for i := 0; i < tt.iterations; i++ {
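						// XOR the goroutine and iteration counters so that
						// lock keys spread across all of the buckets.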
						lockKey := strconv.Itoa((i ^ counter) % tt.buckets)

						_, err := limiter.Limit(context.Background(), lockKey, func() (interface{}, error) {
							primePump()

							current := gauge.currentVal()
							require.True(t, current <= expectedGaugeMax, "Expected the number of concurrent operations (%v) to not exceed the maximum concurrency (%v)", current, expectedGaugeMax)

							require.True(t, limiter.countSemaphores() <= tt.buckets, "Expected the number of semaphores (%v) to be at most the number of buckets (%v)", limiter.countSemaphores(), tt.buckets)

							gauge.down()
							return nil, nil
						})
						require.NoError(t, err)
					}
				}(c)
			}

			wg.Wait()

			assert.Equal(t, expectedGaugeMax, gauge.max, "Expected maximum concurrency")
			assert.Equal(t, 0, gauge.current)
			assert.Equal(t, 0, limiter.countSemaphores())

			var wantMonitorCallCount int
			if tt.wantMonitorCalls {
				wantMonitorCallCount = tt.concurrency * tt.iterations
			}

			assert.Equal(t, wantMonitorCallCount, gauge.enter)
			assert.Equal(t, wantMonitorCallCount, gauge.exit)
			assert.Equal(t, wantMonitorCallCount, gauge.queued)
			assert.Equal(t, wantMonitorCallCount, gauge.dequeued)
		})
	}
}