// Package memorystore defines an in-memory storage system for limiting.
package memorystore

import (
	"sync"
	"sync/atomic"
	"time"
	"unsafe"

	"github.com/sethvargo/go-limiter"
	"github.com/sethvargo/go-limiter/internal/fasttime"
)

var _ limiter.Store = (*store)(nil)

type store struct {
	tokens   uint64
	interval time.Duration
	rate     float64

	sweepInterval time.Duration
	sweepMinTTL   uint64

	data     map[string]*bucket
	dataLock sync.RWMutex

	stopped uint32
	stopCh  chan struct{}
}

// Config is used as input to New. It defines the behavior of the storage
// system.
type Config struct {
	// Tokens is the number of tokens to allow per interval. The default value is
	// 1.
	Tokens uint64

	// Interval is the time interval upon which to enforce rate limiting. The
	// default value is 1 second.
	Interval time.Duration

	// SweepInterval is the rate at which to run the garbage collection on stale
	// entries. Setting this to a low value will optimize memory consumption, but
	// will likely reduce performance and increase lock contention. Setting this
	// to a high value will maximize throughput, but will increase the memory
	// footprint. This can be tuned in combination with SweepMinTTL to control how
	// long stale entries are kept. The default value is 6 hours.
	SweepInterval time.Duration

	// SweepMinTTL is the minimum amount of time a session must be inactive before
	// clearing it from the entries. There's no validation, but this should be at
	// least as high as your rate limit, or else the data store will purge records
	// before the limit is applied. The default value is 12 hours.
	SweepMinTTL time.Duration

	// InitialAlloc is the initial size to use for the in-memory map. Go will
	// automatically expand the buffer, but choosing a higher number can trade
	// memory consumption for performance as it limits the number of times the map
	// needs to expand. The default value is 4096.
	InitialAlloc int
}
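
// For illustration, a tuned configuration might look like the sketch below
// (the values are examples, not recommendations): 100 requests per minute per
// key, with a garbage-collection pass every hour that evicts keys idle for at
// least 2 hours.
//
//	cfg := &Config{
//		Tokens:        100,
//		Interval:      1 * time.Minute,
//		SweepInterval: 1 * time.Hour,
//		SweepMinTTL:   2 * time.Hour,
//		InitialAlloc:  8192,
//	}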

// New creates an in-memory rate limiter that uses a bucketing model to limit
// the number of permitted events over an interval. It's optimized for runtime
// and memory efficiency.
func New(c *Config) (limiter.Store, error) {
	if c == nil {
		c = new(Config)
	}

	tokens := uint64(1)
	if c.Tokens > 0 {
		tokens = c.Tokens
	}

	interval := 1 * time.Second
	if c.Interval > 0 {
		interval = c.Interval
	}

	sweepInterval := 6 * time.Hour
	if c.SweepInterval > 0 {
		sweepInterval = c.SweepInterval
	}

	sweepMinTTL := 12 * time.Hour
	if c.SweepMinTTL > 0 {
		sweepMinTTL = c.SweepMinTTL
	}

	initialAlloc := 4096
	if c.InitialAlloc > 0 {
		initialAlloc = c.InitialAlloc
	}

	s := &store{
		tokens:   tokens,
		interval: interval,
		rate:     float64(interval) / float64(tokens),

		sweepInterval: sweepInterval,
		sweepMinTTL:   uint64(sweepMinTTL),

		data:   make(map[string]*bucket, initialAlloc),
		stopCh: make(chan struct{}),
	}
	go s.purge()
	return s, nil
}
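
// A minimal usage sketch, as seen from an importing package (error handling
// elided; the limit here is illustrative):
//
//	store, err := memorystore.New(&memorystore.Config{
//		Tokens:   5,
//		Interval: time.Second,
//	})
//	if err != nil {
//		// handle the error
//	}
//	defer store.Close()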

// Take attempts to remove a token from the named key. If the take is
// successful, it returns true, otherwise false. It also returns the configured
// limit, remaining tokens, and reset time.
func (s *store) Take(key string) (uint64, uint64, uint64, bool) {
	// If the store is stopped, all requests are rejected.
	if atomic.LoadUint32(&s.stopped) == 1 {
		return 0, 0, 0, false
	}

	// Acquire a read lock first - this allows others to concurrently check limits
	// without taking a full lock.
	s.dataLock.RLock()
	if b, ok := s.data[key]; ok {
		s.dataLock.RUnlock()
		return b.take()
	}
	s.dataLock.RUnlock()

	// Unfortunately we did not find the key in the map. Take out a full lock. We
	// have to check if the key exists again, because it's possible another
	// goroutine created it between our shared lock and exclusive lock.
	s.dataLock.Lock()
	if b, ok := s.data[key]; ok {
		s.dataLock.Unlock()
		return b.take()
	}

	// This is the first time we've seen this entry (or it's been garbage
	// collected), so create the bucket and take an initial request.
	b := newBucket(s.tokens, s.interval, s.rate)

	// Add it to the map and take.
	s.data[key] = b
	s.dataLock.Unlock()
	return b.take()
}
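
// An illustrative call site (the key is arbitrary). Per bucket.take below,
// reset is the unix-nanosecond timestamp at which the bucket next refills.
//
//	limit, remaining, reset, ok := store.Take("client-1.2.3.4")
//	if !ok {
//		// Over the limit for this interval (or the store has been closed).
//	}
//	_, _, _ = limit, remaining, reset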

// Close stops the memory limiter and cleans up any outstanding sessions. You
// should absolutely always call Close() as it releases the memory consumed by
// the map AND releases the ticker used for garbage collection.
func (s *store) Close() error {
	if !atomic.CompareAndSwapUint32(&s.stopped, 0, 1) {
		return nil
	}

	// Close the channel to prevent future purging.
	close(s.stopCh)

	// Delete all the things.
	s.dataLock.Lock()
	for k := range s.data {
		delete(s.data, k)
	}
	s.dataLock.Unlock()
	return nil
}

// purge continually iterates over the map and purges old values on the provided
// sweep interval. Earlier designs used a goroutine-per-item expiration, but
// it actually generated *more* lock contention under normal use. The most
// performant option with real-world data was a global garbage collection on a
// fixed interval.
func (s *store) purge() {
	ticker := time.NewTicker(s.sweepInterval)
	defer ticker.Stop()

	for {
		select {
		case <-s.stopCh:
			return
		case <-ticker.C:
		}

		s.dataLock.Lock()
		now := fasttime.Now()
		for k, b := range s.data {
			lastTick := (*bucketState)(atomic.LoadPointer(&b.bucketState)).lastTick
			lastTime := b.startTime + (lastTick * uint64(b.interval))

			if now-lastTime > s.sweepMinTTL {
				delete(s.data, k)
			}
		}
		s.dataLock.Unlock()
	}
}

// bucket is an internal wrapper around a taker.
type bucket struct {
	// startTime is the number of nanoseconds from unix epoch when this bucket was
	// initially created.
	startTime uint64

	// maxTokens is the maximum number of tokens permitted on the bucket at any
	// time. The number of available tokens will never exceed this value.
	maxTokens uint64

	// interval is the duration of a single clock tick.
	interval time.Duration

	// bucketState is the mutable internal state of the bucket. It includes the
	// current number of available tokens and the last time the clock ticked. It
	// must always be loaded and swapped atomically, as it is not safe for
	// concurrent access.
	bucketState unsafe.Pointer

	// fillRate is the token refill rate, derived from the configured tokens and
	// interval in New. It is used to re-calculate the number of available tokens
	// when the clock ticks forward.
	fillRate float64
}

// bucketState represents the internal bucket state.
type bucketState struct {
	// availableTokens is the current point-in-time number of tokens remaining.
	// This value changes frequently and must be guarded by an atomic read/write.
	availableTokens uint64

	// lastTick is the last clock tick, used to re-calculate the number of tokens
	// on the bucket.
	lastTick uint64
}

// newBucket creates a new bucket from the given tokens and interval.
func newBucket(tokens uint64, interval time.Duration, rate float64) *bucket {
	b := &bucket{
		startTime: fasttime.Now(),
		maxTokens: tokens,
		interval:  interval,
		fillRate:  rate,

		bucketState: unsafe.Pointer(&bucketState{
			availableTokens: tokens,
		}),
	}
	return b
}

// take attempts to remove a token from the bucket. If there are no tokens
// available and the clock has ticked forward, it recalculates the number of
// tokens and retries. It returns the limit, remaining tokens, reset time, and
// whether the take was successful.
func (b *bucket) take() (uint64, uint64, uint64, bool) {
	// Capture the current request time, current tick, and the time at which the
	// bucket resets.
	now := fasttime.Now()
	currTick := tick(b.startTime, now, b.interval)
	next := b.startTime + ((currTick + 1) * uint64(b.interval))

	for {
		curr := atomic.LoadPointer(&b.bucketState)
		currState := (*bucketState)(curr)
		lastTick := currState.lastTick
		tokens := currState.availableTokens

		if lastTick < currTick {
			tokens = availableTokens(currState.lastTick, currTick, b.maxTokens, b.fillRate)
			lastTick = currTick

			if !atomic.CompareAndSwapPointer(&b.bucketState, curr, unsafe.Pointer(&bucketState{
				availableTokens: tokens,
				lastTick:        lastTick,
			})) {
				// Someone else modified the value
				continue
			}
		}

		if tokens > 0 {
			tokens--
			if !atomic.CompareAndSwapPointer(&b.bucketState, curr, unsafe.Pointer(&bucketState{
				availableTokens: tokens,
				lastTick:        lastTick,
			})) {
				// There were tokens left, but someone took them :(
				continue
			}

			return b.maxTokens, tokens, next, true
		}

		// No tokens available; return the reset time of the next tick.
		return b.maxTokens, 0, next, false
	}
}

// availableTokens returns the number of available tokens, up to max, between
// the two ticks.
func availableTokens(last, curr, max uint64, fillRate float64) uint64 {
	delta := curr - last

	available := uint64(float64(delta) * fillRate)
	if available > max {
		available = max
	}

	return available
}
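
// As a worked example: with the rate computed in New (float64(interval) /
// float64(tokens)), interval = 1s and tokens = 2 give fillRate = 5e8. A delta
// of one tick then yields 5e8, which is clamped to max (2), so for typical
// configurations any elapsed tick refills the bucket completely.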

// tick is the total number of times the current interval has occurred between
// when the time started (start) and the current time (curr). For example, if
// the start time was 12:30pm and it's currently 1:00pm, and the interval was 5
// minutes, tick would return 6 because 1:00pm is the 6th 5-minute tick. Note
// that tick would return 5 at 12:59pm, because it hasn't reached the 6th tick
// yet.
func tick(start, curr uint64, interval time.Duration) uint64 {
	return (curr - start) / uint64(interval)
}