1package ratelimiter
2
3import (
4	"net/http"
5	"time"
6
7	"github.com/prometheus/client_golang/prometheus"
8	"golang.org/x/time/rate"
9
10	"gitlab.com/gitlab-org/gitlab-pages/internal/lru"
11	"gitlab.com/gitlab-org/gitlab-pages/internal/request"
12)
13
const (
	// DefaultSourceIPLimitPerSecond is the default sustained rate, in requests
	// per second, allowed for each rate-limited key. It is the refill rate of
	// the per-key rate.Limiter's token bucket (20 tokens per second).
	DefaultSourceIPLimitPerSecond = 20.0
	// DefaultSourceIPBurstSize is the default maximum burst allowed per key,
	// i.e. the capacity of each per-key token bucket. A key whose bucket is
	// full can make up to 100 requests at once; after that, requests are only
	// admitted as fast as tokens refill (DefaultSourceIPLimitPerSecond per
	// second), so slightly more than 100 requests can succeed within 1s.
	DefaultSourceIPBurstSize = 100

	// DefaultSourceIPCacheSize is a default cache size for the rate limiter's
	// LRU cache of per-key limiters, based on an average of ~4,000 unique IPs
	// per minute. It is not referenced in this file; presumably callers pass
	// it to WithCacheMaxSize — verify against call sites.
	// https://log.gprd.gitlab.net/app/lens#/edit/f7110d00-2013-11ec-8c8e-ed83b5469915?_g=h@e78830b
	DefaultSourceIPCacheSize = 5000
)
27
28// Option function to configure a RateLimiter
29type Option func(*RateLimiter)
30
31// KeyFunc returns unique identifier for the subject of rate limit(e.g. client IP or domain)
32type KeyFunc func(*http.Request) string
33
// RateLimiter holds an LRU cache of elements to be rate limited.
// It uses "golang.org/x/time/rate" as its Token Bucket rate limiter, with one
// limiter per key (by default the request's remote address; see KeyFunc).
// See example https://www.fatalerrors.org/a/design-and-implementation-of-time-rate-limiter-for-golang-standard-library.html
// It also holds a now function that can be mocked in unit tests.
type RateLimiter struct {
	// name is passed to lru.New as the cache's name when New builds cache.
	name           string
	// now is the clock handed to AllowN by RequestAllowed. It defaults to
	// time.Now and can be replaced with WithNow.
	now            func() time.Time
	// limitPerSecond is the token refill rate of every per-key rate.Limiter.
	limitPerSecond float64
	// burstSize is the bucket capacity of every per-key rate.Limiter.
	burstSize      int
	// blockedCount is set by WithBlockedCountMetric. It is not read in this
	// file; presumably the caller increments it when a request is rejected —
	// TODO confirm.
	blockedCount   *prometheus.GaugeVec
	// cache stores one *rate.Limiter per rate-limit key; it is created by New.
	cache          *lru.Cache
	// key extracts the rate-limit key from a request. New defaults it to
	// request.GetRemoteAddrWithoutPort.
	key            KeyFunc

	// cacheOptions accumulates the lru options contributed by the cache-related
	// With* options. New consumes them once, when it creates cache, so they
	// have no effect on an already-built RateLimiter.
	cacheOptions []lru.Option
}
49
50// New creates a new RateLimiter with default values that can be configured via Option functions
51func New(name string, opts ...Option) *RateLimiter {
52	rl := &RateLimiter{
53		name:           name,
54		now:            time.Now,
55		limitPerSecond: DefaultSourceIPLimitPerSecond,
56		burstSize:      DefaultSourceIPBurstSize,
57		key:            request.GetRemoteAddrWithoutPort,
58	}
59
60	for _, opt := range opts {
61		opt(rl)
62	}
63
64	rl.cache = lru.New(name, rl.cacheOptions...)
65
66	return rl
67}
68
69// WithNow replaces the RateLimiter now function
70func WithNow(now func() time.Time) Option {
71	return func(rl *RateLimiter) {
72		rl.now = now
73	}
74}
75
76// WithLimitPerSecond allows configuring limit per second for RateLimiter
77func WithLimitPerSecond(limit float64) Option {
78	return func(rl *RateLimiter) {
79		rl.limitPerSecond = limit
80	}
81}
82
83// WithBurstSize configures burst per key for the RateLimiter
84func WithBurstSize(burst int) Option {
85	return func(rl *RateLimiter) {
86		rl.burstSize = burst
87	}
88}
89
90// WithBlockedCountMetric configures metric reporting how many requests were blocked
91func WithBlockedCountMetric(m *prometheus.GaugeVec) Option {
92	return func(rl *RateLimiter) {
93		rl.blockedCount = m
94	}
95}
96
97// WithCacheMaxSize configures cache size for ratelimiter
98func WithCacheMaxSize(size int64) Option {
99	return func(rl *RateLimiter) {
100		rl.cacheOptions = append(rl.cacheOptions, lru.WithMaxSize(size))
101	}
102}
103
104// WithCachedEntriesMetric configures metric reporting how many keys are currently stored in
105// the rate-limiter cache
106func WithCachedEntriesMetric(m *prometheus.GaugeVec) Option {
107	return func(rl *RateLimiter) {
108		rl.cacheOptions = append(rl.cacheOptions, lru.WithCachedEntriesMetric(m))
109	}
110}
111
112// WithCachedRequestsMetric configures metric for how many times we ask key cache
113func WithCachedRequestsMetric(m *prometheus.CounterVec) Option {
114	return func(rl *RateLimiter) {
115		rl.cacheOptions = append(rl.cacheOptions, lru.WithCachedRequestsMetric(m))
116	}
117}
118
119func (rl *RateLimiter) limiter(key string) *rate.Limiter {
120	limiterI, _ := rl.cache.FindOrFetch(key, key, func() (interface{}, error) {
121		return rate.NewLimiter(rate.Limit(rl.limitPerSecond), rl.burstSize), nil
122	})
123
124	return limiterI.(*rate.Limiter)
125}
126
127// RequestAllowed checks that the real remote IP address is allowed to perform an operation
128func (rl *RateLimiter) RequestAllowed(r *http.Request) bool {
129	rateLimitedKey := rl.key(r)
130	limiter := rl.limiter(rateLimitedKey)
131
132	// AllowN allows us to use the rl.now function, so we can test this more easily.
133	return limiter.AllowN(rl.now(), 1)
134}
135