1package ratelimiter 2 3import ( 4 "net/http" 5 "time" 6 7 "github.com/prometheus/client_golang/prometheus" 8 "golang.org/x/time/rate" 9 10 "gitlab.com/gitlab-org/gitlab-pages/internal/lru" 11 "gitlab.com/gitlab-org/gitlab-pages/internal/request" 12) 13 14const ( 15 // DefaultSourceIPLimitPerSecond is the limit per second that rate.Limiter 16 // needs to generate tokens every second. 17 // The default value is 20 requests per second. 18 DefaultSourceIPLimitPerSecond = 20.0 19 // DefaultSourceIPBurstSize is the maximum burst allowed per rate limiter. 20 // E.g. The first 100 requests within 1s will succeed, but the 101st will fail. 21 DefaultSourceIPBurstSize = 100 22 23 // based on an avg ~4,000 unique IPs per minute 24 // https://log.gprd.gitlab.net/app/lens#/edit/f7110d00-2013-11ec-8c8e-ed83b5469915?_g=h@e78830b 25 DefaultSourceIPCacheSize = 5000 26) 27 28// Option function to configure a RateLimiter 29type Option func(*RateLimiter) 30 31// KeyFunc returns unique identifier for the subject of rate limit(e.g. client IP or domain) 32type KeyFunc func(*http.Request) string 33 34// RateLimiter holds an LRU cache of elements to be rate limited. 35// It uses "golang.org/x/time/rate" as its Token Bucket rate limiter per source IP entry. 36// See example https://www.fatalerrors.org/a/design-and-implementation-of-time-rate-limiter-for-golang-standard-library.html 37// It also holds a now function that can be mocked in unit tests. 38type RateLimiter struct { 39 name string 40 now func() time.Time 41 limitPerSecond float64 42 burstSize int 43 blockedCount *prometheus.GaugeVec 44 cache *lru.Cache 45 key KeyFunc 46 47 cacheOptions []lru.Option 48} 49 50// New creates a new RateLimiter with default values that can be configured via Option functions 51func New(name string, opts ...Option) *RateLimiter { 52 rl := &RateLimiter{ 53 name: name, 54 now: time.Now, 55 limitPerSecond: DefaultSourceIPLimitPerSecond, 56 burstSize: DefaultSourceIPBurstSize, 57 key: request.GetRemoteAddrWithoutPort, 58 } 59 60 for _, opt := range opts { 61 opt(rl) 62 } 63 64 rl.cache = lru.New(name, rl.cacheOptions...) 65 66 return rl 67} 68 69// WithNow replaces the RateLimiter now function 70func WithNow(now func() time.Time) Option { 71 return func(rl *RateLimiter) { 72 rl.now = now 73 } 74} 75 76// WithLimitPerSecond allows configuring limit per second for RateLimiter 77func WithLimitPerSecond(limit float64) Option { 78 return func(rl *RateLimiter) { 79 rl.limitPerSecond = limit 80 } 81} 82 83// WithBurstSize configures burst per key for the RateLimiter 84func WithBurstSize(burst int) Option { 85 return func(rl *RateLimiter) { 86 rl.burstSize = burst 87 } 88} 89 90// WithBlockedCountMetric configures metric reporting how many requests were blocked 91func WithBlockedCountMetric(m *prometheus.GaugeVec) Option { 92 return func(rl *RateLimiter) { 93 rl.blockedCount = m 94 } 95} 96 97// WithCacheMaxSize configures cache size for ratelimiter 98func WithCacheMaxSize(size int64) Option { 99 return func(rl *RateLimiter) { 100 rl.cacheOptions = append(rl.cacheOptions, lru.WithMaxSize(size)) 101 } 102} 103 104// WithCachedEntriesMetric configures metric reporting how many keys are currently stored in 105// the rate-limiter cache 106func WithCachedEntriesMetric(m *prometheus.GaugeVec) Option { 107 return func(rl *RateLimiter) { 108 rl.cacheOptions = append(rl.cacheOptions, lru.WithCachedEntriesMetric(m)) 109 } 110} 111 112// WithCachedRequestsMetric configures metric for how many times we ask key cache 113func WithCachedRequestsMetric(m *prometheus.CounterVec) Option { 114 return func(rl *RateLimiter) { 115 rl.cacheOptions = append(rl.cacheOptions, lru.WithCachedRequestsMetric(m)) 116 } 117} 118 119func (rl *RateLimiter) limiter(key string) *rate.Limiter { 120 limiterI, _ := rl.cache.FindOrFetch(key, key, func() (interface{}, error) { 121 return rate.NewLimiter(rate.Limit(rl.limitPerSecond), rl.burstSize), nil 122 }) 123 124 return limiterI.(*rate.Limiter) 125} 126 127// RequestAllowed checks that the real remote IP address is allowed to perform an operation 128func (rl *RateLimiter) RequestAllowed(r *http.Request) bool { 129 rateLimitedKey := rl.key(r) 130 limiter := rl.limiter(rateLimitedKey) 131 132 // AllowN allows us to use the rl.now function, so we can test this more easily. 133 return limiter.AllowN(rl.now(), 1) 134} 135