// Package memorystore defines an in-memory storage system for limiting.
package memorystore

import (
	"sync"
	"sync/atomic"
	"time"
	"unsafe"

	"github.com/sethvargo/go-limiter"
	"github.com/sethvargo/go-limiter/internal/fasttime"
)

// Compile-time check that store satisfies the limiter.Store interface.
var _ limiter.Store = (*store)(nil)

// store is the in-memory limiter.Store implementation. A bucket is created
// lazily per key on first Take; stale buckets are garbage-collected by the
// purge goroutine started in New.
type store struct {
	// tokens is the number of tokens permitted per interval; rate is the fill
	// rate handed to each new bucket (interval nanoseconds per token).
	tokens   uint64
	interval time.Duration
	rate     float64

	// sweepInterval is how often purge runs; sweepMinTTL (in nanoseconds) is
	// how long a bucket must be inactive before purge deletes it.
	sweepInterval time.Duration
	sweepMinTTL   uint64

	// data maps a limit key to its bucket. Guarded by dataLock; the buckets
	// themselves are updated via atomic CAS and are safe outside the lock.
	data     map[string]*bucket
	dataLock sync.RWMutex

	// stopped flips to 1 (atomically) on Close; stopCh signals the purge
	// goroutine to exit.
	stopped uint32
	stopCh  chan struct{}
}

// Config is used as input to New. It defines the behavior of the storage
// system.
type Config struct {
	// Tokens is the number of tokens to allow per interval. The default value is
	// 1.
	Tokens uint64

	// Interval is the time interval upon which to enforce rate limiting. The
	// default value is 1 second.
	Interval time.Duration

	// SweepInterval is the rate at which to run the garbage collection on stale
	// entries. Setting this to a low value will optimize memory consumption, but
	// will likely reduce performance and increase lock contention. Setting this
	// to a high value will maximize throughput, but will increase the memory
	// footprint. This can be tuned in combination with SweepMinTTL to control how
	// long stale entries are kept. The default value is 6 hours.
	SweepInterval time.Duration

	// SweepMinTTL is the minimum amount of time a session must be inactive before
	// clearing it from the entries. There's no validation, but this should be at
	// least as high as your rate limit, or else the data store will purge records
	// before the limit is applied. The default value is 12 hours.
	SweepMinTTL time.Duration

	// InitialAlloc is the size to use for the in-memory map. Go will
	// automatically expand the buffer, but choosing a higher number can trade
	// memory consumption for performance as it limits the number of times the map
	// needs to expand. The default value is 4096.
	InitialAlloc int
}

// New creates an in-memory rate limiter that uses a bucketing model to limit
// the number of permitted events over an interval. It's optimized for runtime
// and memory efficiency. A nil config is treated as all-defaults.
func New(c *Config) (limiter.Store, error) {
	if c == nil {
		c = new(Config)
	}

	// Apply the documented defaults for any unset (zero) config values.
	tokens := uint64(1)
	if c.Tokens > 0 {
		tokens = c.Tokens
	}

	interval := 1 * time.Second
	if c.Interval > 0 {
		interval = c.Interval
	}

	sweepInterval := 6 * time.Hour
	if c.SweepInterval > 0 {
		sweepInterval = c.SweepInterval
	}

	sweepMinTTL := 12 * time.Hour
	if c.SweepMinTTL > 0 {
		sweepMinTTL = c.SweepMinTTL
	}

	initialAlloc := 4096
	if c.InitialAlloc > 0 {
		initialAlloc = c.InitialAlloc
	}

	s := &store{
		tokens:   tokens,
		interval: interval,
		rate:     float64(interval) / float64(tokens),

		sweepInterval: sweepInterval,
		sweepMinTTL:   uint64(sweepMinTTL),

		data:   make(map[string]*bucket, initialAlloc),
		stopCh: make(chan struct{}),
	}
	// Background garbage collection; stopped via stopCh in Close.
	go s.purge()
	return s, nil
}

// Take attempts to remove a token from the named key. If the take is
// successful, it returns true, otherwise false. It also returns the configured
// limit, remaining tokens, and reset time.
func (s *store) Take(key string) (uint64, uint64, uint64, bool) {
	// If the store is stopped, all requests are rejected.
	if atomic.LoadUint32(&s.stopped) == 1 {
		return 0, 0, 0, false
	}

	// Acquire a read lock first - this allows others to concurrently check
	// limits without taking a full lock. The bucket itself is updated with
	// atomic CAS, so calling take after releasing the lock is safe.
	s.dataLock.RLock()
	if b, ok := s.data[key]; ok {
		s.dataLock.RUnlock()
		return b.take()
	}
	s.dataLock.RUnlock()

	// Unfortunately we did not find the key in the map. Take out a full lock.
	// We have to check if the key exists again, because it's possible another
	// goroutine created it between our shared lock and exclusive lock.
	s.dataLock.Lock()
	if b, ok := s.data[key]; ok {
		s.dataLock.Unlock()
		return b.take()
	}

	// This is the first time we've seen this entry (or it's been garbage
	// collected), so create the bucket and take an initial request.
	b := newBucket(s.tokens, s.interval, s.rate)

	// Add it to the map and take.
	s.data[key] = b
	s.dataLock.Unlock()
	return b.take()
}

// Close stops the memory limiter and cleans up any outstanding sessions. You
// should absolutely always call Close() as it releases the memory consumed by
// the map AND releases the tickers.
func (s *store) Close() error {
	// Only the first caller transitions stopped 0 -> 1; later calls are no-ops.
	if !atomic.CompareAndSwapUint32(&s.stopped, 0, 1) {
		return nil
	}

	// Close the channel to prevent future purging.
	close(s.stopCh)

	// Delete all the things.
	s.dataLock.Lock()
	for k := range s.data {
		delete(s.data, k)
	}
	s.dataLock.Unlock()
	return nil
}

// purge continually iterates over the map and purges old values on the provided
// sweep interval. Earlier designs used a go-function-per-item expiration, but
// it actually generated *more* lock contention under normal use. The most
// performant option with real-world data was a global garbage collection on a
// fixed interval.
173func (s *store) purge() { 174 ticker := time.NewTicker(s.sweepInterval) 175 defer ticker.Stop() 176 177 for { 178 select { 179 case <-s.stopCh: 180 return 181 case <-ticker.C: 182 } 183 184 s.dataLock.Lock() 185 now := fasttime.Now() 186 for k, b := range s.data { 187 lastTick := (*bucketState)(atomic.LoadPointer(&b.bucketState)).lastTick 188 lastTime := b.startTime + (lastTick * uint64(b.interval)) 189 190 if now-lastTime > s.sweepMinTTL { 191 delete(s.data, k) 192 } 193 } 194 s.dataLock.Unlock() 195 } 196} 197 198// bucket is an internal wrapper around a taker. 199type bucket struct { 200 // startTime is the number of nanoseconds from unix epoch when this bucket was 201 // initially created. 202 startTime uint64 203 204 // maxTokens is the maximum number of tokens permitted on the bucket at any 205 // time. The number of available tokens will never exceed this value. 206 maxTokens uint64 207 208 // interval is the time at which ticking should occur. 209 interval time.Duration 210 211 // bucketState is the mutable internal state of the event. It includes the 212 // current number of available tokens and the last time the clock ticked. It 213 // should always be loaded with atomic as it is not concurrent safe. 214 bucketState unsafe.Pointer 215 216 // fillRate is the number of tokens to add per nanosecond. It is calculated 217 // based on the provided maxTokens and interval. 218 fillRate float64 219} 220 221// bucketState represents the internal bucket state. 222type bucketState struct { 223 // availableTokens is the current point-in-time number of tokens remaining. 224 // This value changes frequently and must be guarded by an atomic read/write. 225 availableTokens uint64 226 227 // lastTick is the last clock tick, used to re-calculate the number of tokens 228 // on the bucket. 229 lastTick uint64 230} 231 232// newBucket creates a new bucket from the given tokens and interval. 
func newBucket(tokens uint64, interval time.Duration, rate float64) *bucket {
	b := &bucket{
		// startTime anchors all subsequent tick math for this bucket.
		startTime: fasttime.Now(),
		maxTokens: tokens,
		interval:  interval,
		fillRate:  rate,

		// Buckets start full: every token available, lastTick zero.
		bucketState: unsafe.Pointer(&bucketState{
			availableTokens: tokens,
		}),
	}
	return b
}

// take attempts to remove a token from the bucket. If there are no tokens
// available and the clock has ticked forward, it recalculates the number of
// tokens and retries. It returns the limit, remaining tokens, time until
// refresh, and whether the take was successful.
func (b *bucket) take() (uint64, uint64, uint64, bool) {
	// Capture the current request time, current tick, and amount of time until
	// the bucket resets. `next` is an absolute timestamp (startTime is
	// documented as nanoseconds from unix epoch) marking the next tick boundary.
	now := fasttime.Now()
	currTick := tick(b.startTime, now, b.interval)
	next := b.startTime + ((currTick + 1) * uint64(b.interval))

	// Lock-free loop: snapshot the immutable state, compute the replacement,
	// and CAS it in; any interleaving writer forces a retry.
	for {
		curr := atomic.LoadPointer(&b.bucketState)
		currState := (*bucketState)(curr)
		lastTick := currState.lastTick
		tokens := currState.availableTokens

		// The clock has ticked since the state was written: refill first.
		if lastTick < currTick {
			tokens = availableTokens(currState.lastTick, currTick, b.maxTokens, b.fillRate)
			lastTick = currTick

			if !atomic.CompareAndSwapPointer(&b.bucketState, curr, unsafe.Pointer(&bucketState{
				availableTokens: tokens,
				lastTick:        lastTick,
			})) {
				// Someone else modified the value
				continue
			}
			// NOTE(review): after this CAS succeeds, `curr` is stale, so the
			// decrement CAS below fails once and the loop retries against the
			// freshly installed state; correct, at the cost of one extra pass.
		}

		if tokens > 0 {
			// Tokens remain: publish the decremented count.
			tokens--
			if !atomic.CompareAndSwapPointer(&b.bucketState, curr, unsafe.Pointer(&bucketState{
				availableTokens: tokens,
				lastTick:        lastTick,
			})) {
				// There were tokens left, but someone took them :(
				continue
			}

			return b.maxTokens, tokens, next, true
		}

		// Bucket is empty for this tick; report failure with the reset time.
		return b.maxTokens, 0, next, false
	}
}

// availableTokens returns the number of available tokens, up to max, between
// the two ticks.
297func availableTokens(last, curr, max uint64, fillRate float64) uint64 { 298 delta := curr - last 299 300 available := uint64(float64(delta) * fillRate) 301 if available > max { 302 available = max 303 } 304 305 return available 306} 307 308// tick is the total number of times the current interval has occurred between 309// when the time started (start) and the current time (curr). For example, if 310// the start time was 12:30pm and it's currently 1:00pm, and the interval was 5 311// minutes, tick would return 6 because 1:00pm is the 6th 5-minute tick. Note 312// that tick would return 5 at 12:59pm, because it hasn't reached the 6th tick 313// yet. 314func tick(start, curr uint64, interval time.Duration) uint64 { 315 return (curr - start) / uint64(interval) 316} 317