1package limithandler 2 3import ( 4 "context" 5 "strconv" 6 "sync" 7 "testing" 8 "time" 9 10 "github.com/stretchr/testify/assert" 11 "github.com/stretchr/testify/require" 12) 13 14type counter struct { 15 sync.Mutex 16 max int 17 current int 18 queued int 19 dequeued int 20 enter int 21 exit int 22} 23 24func (c *counter) up() { 25 c.Lock() 26 defer c.Unlock() 27 28 c.current = c.current + 1 29 if c.current > c.max { 30 c.max = c.current 31 } 32} 33 34func (c *counter) down() { 35 c.Lock() 36 defer c.Unlock() 37 38 c.current = c.current - 1 39} 40 41func (c *counter) currentVal() int { 42 c.Lock() 43 defer c.Unlock() 44 return c.current 45} 46 47func (c *counter) Queued(ctx context.Context) { 48 c.Lock() 49 defer c.Unlock() 50 c.queued++ 51} 52 53func (c *counter) Dequeued(ctx context.Context) { 54 c.Lock() 55 defer c.Unlock() 56 c.dequeued++ 57} 58 59func (c *counter) Enter(ctx context.Context, acquireTime time.Duration) { 60 c.Lock() 61 defer c.Unlock() 62 c.enter++ 63} 64 65func (c *counter) Exit(ctx context.Context) { 66 c.Lock() 67 defer c.Unlock() 68 c.exit++ 69} 70 71func TestLimiter(t *testing.T) { 72 tests := []struct { 73 name string 74 concurrency int 75 maxConcurrency int 76 iterations int 77 buckets int 78 wantMonitorCalls bool 79 }{ 80 { 81 name: "single", 82 concurrency: 1, 83 maxConcurrency: 1, 84 iterations: 1, 85 buckets: 1, 86 wantMonitorCalls: true, 87 }, 88 { 89 name: "two-at-a-time", 90 concurrency: 100, 91 maxConcurrency: 2, 92 iterations: 10, 93 buckets: 1, 94 wantMonitorCalls: true, 95 }, 96 { 97 name: "two-by-two", 98 concurrency: 100, 99 maxConcurrency: 2, 100 iterations: 4, 101 buckets: 2, 102 wantMonitorCalls: true, 103 }, 104 { 105 name: "no-limit", 106 concurrency: 10, 107 maxConcurrency: 0, 108 iterations: 200, 109 buckets: 1, 110 wantMonitorCalls: false, 111 }, 112 { 113 name: "wide-spread", 114 concurrency: 1000, 115 maxConcurrency: 2, 116 // We use a long delay here to prevent flakiness in CI. If the delay is 117 // too short, the first goroutines to enter the critical section will be 118 // gone before we hit the intended maximum concurrency. 119 iterations: 40, 120 buckets: 50, 121 wantMonitorCalls: true, 122 }, 123 } 124 for _, tt := range tests { 125 t.Run(tt.name, func(t *testing.T) { 126 expectedGaugeMax := tt.maxConcurrency * tt.buckets 127 if tt.maxConcurrency <= 0 { 128 expectedGaugeMax = tt.concurrency 129 } 130 131 gauge := &counter{} 132 133 limiter := NewLimiter(tt.maxConcurrency, gauge) 134 wg := sync.WaitGroup{} 135 wg.Add(tt.concurrency) 136 137 full := sync.NewCond(&sync.Mutex{}) 138 139 // primePump waits for the gauge to reach the minimum 140 // expected max concurrency so that the limiter is 141 // "warmed" up before proceeding with the test 142 primePump := func() { 143 full.L.Lock() 144 defer full.L.Unlock() 145 146 gauge.up() 147 148 if gauge.max >= expectedGaugeMax { 149 full.Broadcast() 150 return 151 } 152 153 full.Wait() // wait until full is broadcast 154 } 155 156 // We know of an edge case that can lead to the rate limiter 157 // occasionally letting one or two extra goroutines run 158 // concurrently. 159 for c := 0; c < tt.concurrency; c++ { 160 go func(counter int) { 161 for i := 0; i < tt.iterations; i++ { 162 lockKey := strconv.Itoa((i ^ counter) % tt.buckets) 163 164 _, err := limiter.Limit(context.Background(), lockKey, func() (interface{}, error) { 165 primePump() 166 167 current := gauge.currentVal() 168 require.True(t, current <= expectedGaugeMax, "Expected the number of concurrent operations (%v) to not exceed the maximum concurrency (%v)", current, expectedGaugeMax) 169 170 require.True(t, limiter.countSemaphores() <= tt.buckets, "Expected the number of semaphores (%v) to be lte number of buckets (%v)", limiter.countSemaphores(), tt.buckets) 171 172 gauge.down() 173 return nil, nil 174 }) 175 require.NoError(t, err) 176 } 177 178 wg.Done() 179 }(c) 180 } 181 182 wg.Wait() 183 184 assert.Equal(t, expectedGaugeMax, gauge.max, "Expected maximum concurrency") 185 assert.Equal(t, 0, gauge.current) 186 assert.Equal(t, 0, limiter.countSemaphores()) 187 188 var wantMonitorCallCount int 189 if tt.wantMonitorCalls { 190 wantMonitorCallCount = tt.concurrency * tt.iterations 191 } else { 192 wantMonitorCallCount = 0 193 } 194 195 assert.Equal(t, wantMonitorCallCount, gauge.enter) 196 assert.Equal(t, wantMonitorCallCount, gauge.exit) 197 assert.Equal(t, wantMonitorCallCount, gauge.queued) 198 assert.Equal(t, wantMonitorCallCount, gauge.dequeued) 199 }) 200 } 201} 202