1package common
2
3import (
4	"errors"
5	"fmt"
6	"net"
7	"sort"
8	"sync"
9	"sync/atomic"
10	"time"
11
12	"golang.org/x/time/rate"
13
14	"github.com/drakkan/sftpgo/v2/util"
15)
16
17var (
18	errNoBucket               = errors.New("no bucket found")
19	errReserve                = errors.New("unable to reserve token")
20	rateLimiterProtocolValues = []string{ProtocolSSH, ProtocolFTP, ProtocolWebDAV, ProtocolHTTP}
21)
22
// RateLimiterType enumerates the kinds of rate limiter that can be
// configured through RateLimiterConfig.Type.
type RateLimiterType int

// Supported rate limiter types. The zero value is not a valid type, so a
// configuration that leaves Type unset fails validation.
const (
	// rateLimiterTypeGlobal uses a single bucket shared by every source.
	rateLimiterTypeGlobal RateLimiterType = 1
	// rateLimiterTypeSource uses a separate bucket for each source.
	rateLimiterTypeSource RateLimiterType = 2
)
31
// RateLimiterConfig defines the configuration for a rate limiter
type RateLimiterConfig struct {
	// Average defines the maximum rate allowed. 0 means disabled
	Average int64 `json:"average" mapstructure:"average"`
	// Period defines the period as milliseconds. Default: 1000 (1 second).
	// The rate is actually defined by dividing average by period.
	// So for a rate below 1 req/s, one needs to define a period larger than a second.
	// Values below 100 are rejected by validation.
	Period int64 `json:"period" mapstructure:"period"`
	// Burst is the maximum number of requests allowed to go through in the
	// same arbitrarily small period of time. Default: 1. It must be >= 1.
	Burst int `json:"burst" mapstructure:"burst"`
	// Type defines the rate limiter type:
	// - rateLimiterTypeGlobal (1) is a global rate limiter independent from the source
	// - rateLimiterTypeSource (2) is a per-source rate limiter
	Type int `json:"type" mapstructure:"type"`
	// Protocols defines the protocols for this rate limiter.
	// Accepted values are the ones listed in rateLimiterProtocolValues:
	// ProtocolSSH, ProtocolFTP, ProtocolWebDAV and ProtocolHTTP.
	// A rate limiter with no protocols defined is disabled
	Protocols []string `json:"protocols" mapstructure:"protocols"`
	// AllowList defines a list of IP addresses and IP ranges excluded from rate limiting.
	// The mapstructure key matches the JSON key so the setting can be loaded
	// from configuration sources decoded with mapstructure.
	AllowList []string `json:"allow_list" mapstructure:"allow_list"`
	// If the rate limit is exceeded, the defender is enabled, and this is a per-source limiter,
	// a new defender event will be generated
	GenerateDefenderEvents bool `json:"generate_defender_events" mapstructure:"generate_defender_events"`
	// The number of per-ip rate limiters kept in memory will vary between the
	// soft and hard limit: once the hard limit is reached, entries are evicted
	// until only the soft limit remains. Both are validated only for per-source
	// limiters, where the soft limit must be > 0 and the hard limit > soft limit.
	EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"`
	EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"`
}
61
62func (r *RateLimiterConfig) isEnabled() bool {
63	return r.Average > 0 && len(r.Protocols) > 0
64}
65
66func (r *RateLimiterConfig) validate() error {
67	if r.Burst < 1 {
68		return fmt.Errorf("invalid burst %v. It must be >= 1", r.Burst)
69	}
70	if r.Period < 100 {
71		return fmt.Errorf("invalid period %v. It must be >= 100", r.Period)
72	}
73	if r.Type != int(rateLimiterTypeGlobal) && r.Type != int(rateLimiterTypeSource) {
74		return fmt.Errorf("invalid type %v", r.Type)
75	}
76	if r.Type != int(rateLimiterTypeGlobal) {
77		if r.EntriesSoftLimit <= 0 {
78			return fmt.Errorf("invalid entries_soft_limit %v", r.EntriesSoftLimit)
79		}
80		if r.EntriesHardLimit <= r.EntriesSoftLimit {
81			return fmt.Errorf("invalid entries_hard_limit %v must be > %v", r.EntriesHardLimit, r.EntriesSoftLimit)
82		}
83	}
84	r.Protocols = util.RemoveDuplicates(r.Protocols)
85	for _, protocol := range r.Protocols {
86		if !util.IsStringInSlice(protocol, rateLimiterProtocolValues) {
87			return fmt.Errorf("invalid protocol %#v", protocol)
88		}
89	}
90	return nil
91}
92
93func (r *RateLimiterConfig) getLimiter() *rateLimiter {
94	limiter := &rateLimiter{
95		burst:                  r.Burst,
96		globalBucket:           nil,
97		generateDefenderEvents: r.GenerateDefenderEvents,
98	}
99	var maxDelay time.Duration
100	period := time.Duration(r.Period) * time.Millisecond
101	rtl := float64(r.Average*int64(time.Second)) / float64(period)
102	limiter.rate = rate.Limit(rtl)
103	if rtl < 1 {
104		maxDelay = period / 2
105	} else {
106		maxDelay = time.Second / (time.Duration(rtl) * 2)
107	}
108	if maxDelay > 10*time.Second {
109		maxDelay = 10 * time.Second
110	}
111	limiter.maxDelay = maxDelay
112	limiter.buckets = sourceBuckets{
113		buckets:   make(map[string]sourceRateLimiter),
114		hardLimit: r.EntriesHardLimit,
115		softLimit: r.EntriesSoftLimit,
116	}
117	if r.Type != int(rateLimiterTypeSource) {
118		limiter.globalBucket = rate.NewLimiter(limiter.rate, limiter.burst)
119	}
120	return limiter
121}
122
// rateLimiter defines a rate limiter.
//
// When globalBucket is non-nil every source shares it; otherwise Wait uses
// (and lazily creates) one bucket per source in buckets.
type rateLimiter struct {
	// rate is the refill rate (events per second) used for the global bucket
	// and for every per-source bucket.
	rate                   rate.Limit
	// burst is the bucket size passed to rate.NewLimiter.
	burst                  int
	// maxDelay is the longest delay Wait will sleep; longer delays cancel the
	// reservation and return an error.
	maxDelay               time.Duration
	// globalBucket is the shared bucket; nil for per-source limiters.
	globalBucket           *rate.Limiter
	// buckets holds the per-source buckets, used only when globalBucket is nil.
	buckets                sourceBuckets
	// generateDefenderEvents makes Wait report HostEventLimitExceeded for the
	// source when a per-source limit is exceeded.
	generateDefenderEvents bool
	// allowList holds predicates; if any returns true for the source IP, Wait
	// returns immediately without reserving a token. getLimiter does not set
	// it — presumably assigned by the caller afterwards; verify.
	allowList              []func(net.IP) bool
}
133
134// Wait blocks until the limit allows one event to happen
135// or returns an error if the time to wait exceeds the max
136// allowed delay
137func (rl *rateLimiter) Wait(source string) (time.Duration, error) {
138	if len(rl.allowList) > 0 {
139		ip := net.ParseIP(source)
140		if ip != nil {
141			for idx := range rl.allowList {
142				if rl.allowList[idx](ip) {
143					return 0, nil
144				}
145			}
146		}
147	}
148	var res *rate.Reservation
149	if rl.globalBucket != nil {
150		res = rl.globalBucket.Reserve()
151	} else {
152		var err error
153		res, err = rl.buckets.reserve(source)
154		if err != nil {
155			rateLimiter := rate.NewLimiter(rl.rate, rl.burst)
156			res = rl.buckets.addAndReserve(rateLimiter, source)
157		}
158	}
159	if !res.OK() {
160		return 0, errReserve
161	}
162	delay := res.Delay()
163	if delay > rl.maxDelay {
164		res.Cancel()
165		if rl.generateDefenderEvents && rl.globalBucket == nil {
166			AddDefenderEvent(source, HostEventLimitExceeded)
167		}
168		return delay, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay)
169	}
170	time.Sleep(delay)
171	return 0, nil
172}
173
174type sourceRateLimiter struct {
175	lastActivity int64
176	bucket       *rate.Limiter
177}
178
179func (s *sourceRateLimiter) updateLastActivity() {
180	atomic.StoreInt64(&s.lastActivity, time.Now().UnixNano())
181}
182
183func (s *sourceRateLimiter) getLastActivity() int64 {
184	return atomic.LoadInt64(&s.lastActivity)
185}
186
187type sourceBuckets struct {
188	sync.RWMutex
189	buckets   map[string]sourceRateLimiter
190	hardLimit int
191	softLimit int
192}
193
194func (b *sourceBuckets) reserve(source string) (*rate.Reservation, error) {
195	b.RLock()
196	defer b.RUnlock()
197
198	if src, ok := b.buckets[source]; ok {
199		src.updateLastActivity()
200		return src.bucket.Reserve(), nil
201	}
202
203	return nil, errNoBucket
204}
205
206func (b *sourceBuckets) addAndReserve(r *rate.Limiter, source string) *rate.Reservation {
207	b.Lock()
208	defer b.Unlock()
209
210	b.cleanup()
211
212	src := sourceRateLimiter{
213		bucket: r,
214	}
215	src.updateLastActivity()
216	b.buckets[source] = src
217	return src.bucket.Reserve()
218}
219
220func (b *sourceBuckets) cleanup() {
221	if len(b.buckets) >= b.hardLimit {
222		numToRemove := len(b.buckets) - b.softLimit
223
224		kvList := make(kvList, 0, len(b.buckets))
225
226		for k, v := range b.buckets {
227			kvList = append(kvList, kv{
228				Key:   k,
229				Value: v.getLastActivity(),
230			})
231		}
232
233		sort.Sort(kvList)
234
235		for idx, kv := range kvList {
236			if idx >= numToRemove {
237				break
238			}
239
240			delete(b.buckets, kv.Key)
241		}
242	}
243}
244