1package common 2 3import ( 4 "errors" 5 "fmt" 6 "net" 7 "sort" 8 "sync" 9 "sync/atomic" 10 "time" 11 12 "golang.org/x/time/rate" 13 14 "github.com/drakkan/sftpgo/v2/util" 15) 16 17var ( 18 errNoBucket = errors.New("no bucket found") 19 errReserve = errors.New("unable to reserve token") 20 rateLimiterProtocolValues = []string{ProtocolSSH, ProtocolFTP, ProtocolWebDAV, ProtocolHTTP} 21) 22 23// RateLimiterType defines the supported rate limiters types 24type RateLimiterType int 25 26// Supported rate limiter types 27const ( 28 rateLimiterTypeGlobal RateLimiterType = iota + 1 29 rateLimiterTypeSource 30) 31 32// RateLimiterConfig defines the configuration for a rate limiter 33type RateLimiterConfig struct { 34 // Average defines the maximum rate allowed. 0 means disabled 35 Average int64 `json:"average" mapstructure:"average"` 36 // Period defines the period as milliseconds. Default: 1000 (1 second). 37 // The rate is actually defined by dividing average by period. 38 // So for a rate below 1 req/s, one needs to define a period larger than a second. 39 Period int64 `json:"period" mapstructure:"period"` 40 // Burst is the maximum number of requests allowed to go through in the 41 // same arbitrarily small period of time. Default: 1. 42 Burst int `json:"burst" mapstructure:"burst"` 43 // Type defines the rate limiter type: 44 // - rateLimiterTypeGlobal is a global rate limiter independent from the source 45 // - rateLimiterTypeSource is a per-source rate limiter 46 Type int `json:"type" mapstructure:"type"` 47 // Protocols defines the protocols for this rate limiter. 48 // Available protocols are: "SFTP", "FTP", "DAV". 49 // A rate limiter with no protocols defined is disabled 50 Protocols []string `json:"protocols" mapstructure:"protocols"` 51 // AllowList defines a list of IP addresses and IP ranges excluded from rate limiting 52 AllowList []string `json:"allow_list" mapstructure:"mapstructure"` 53 // If the rate limit is exceeded, the defender is enabled, and this is a per-source limiter, 54 // a new defender event will be generated 55 GenerateDefenderEvents bool `json:"generate_defender_events" mapstructure:"generate_defender_events"` 56 // The number of per-ip rate limiters kept in memory will vary between the 57 // soft and hard limit 58 EntriesSoftLimit int `json:"entries_soft_limit" mapstructure:"entries_soft_limit"` 59 EntriesHardLimit int `json:"entries_hard_limit" mapstructure:"entries_hard_limit"` 60} 61 62func (r *RateLimiterConfig) isEnabled() bool { 63 return r.Average > 0 && len(r.Protocols) > 0 64} 65 66func (r *RateLimiterConfig) validate() error { 67 if r.Burst < 1 { 68 return fmt.Errorf("invalid burst %v. It must be >= 1", r.Burst) 69 } 70 if r.Period < 100 { 71 return fmt.Errorf("invalid period %v. It must be >= 100", r.Period) 72 } 73 if r.Type != int(rateLimiterTypeGlobal) && r.Type != int(rateLimiterTypeSource) { 74 return fmt.Errorf("invalid type %v", r.Type) 75 } 76 if r.Type != int(rateLimiterTypeGlobal) { 77 if r.EntriesSoftLimit <= 0 { 78 return fmt.Errorf("invalid entries_soft_limit %v", r.EntriesSoftLimit) 79 } 80 if r.EntriesHardLimit <= r.EntriesSoftLimit { 81 return fmt.Errorf("invalid entries_hard_limit %v must be > %v", r.EntriesHardLimit, r.EntriesSoftLimit) 82 } 83 } 84 r.Protocols = util.RemoveDuplicates(r.Protocols) 85 for _, protocol := range r.Protocols { 86 if !util.IsStringInSlice(protocol, rateLimiterProtocolValues) { 87 return fmt.Errorf("invalid protocol %#v", protocol) 88 } 89 } 90 return nil 91} 92 93func (r *RateLimiterConfig) getLimiter() *rateLimiter { 94 limiter := &rateLimiter{ 95 burst: r.Burst, 96 globalBucket: nil, 97 generateDefenderEvents: r.GenerateDefenderEvents, 98 } 99 var maxDelay time.Duration 100 period := time.Duration(r.Period) * time.Millisecond 101 rtl := float64(r.Average*int64(time.Second)) / float64(period) 102 limiter.rate = rate.Limit(rtl) 103 if rtl < 1 { 104 maxDelay = period / 2 105 } else { 106 maxDelay = time.Second / (time.Duration(rtl) * 2) 107 } 108 if maxDelay > 10*time.Second { 109 maxDelay = 10 * time.Second 110 } 111 limiter.maxDelay = maxDelay 112 limiter.buckets = sourceBuckets{ 113 buckets: make(map[string]sourceRateLimiter), 114 hardLimit: r.EntriesHardLimit, 115 softLimit: r.EntriesSoftLimit, 116 } 117 if r.Type != int(rateLimiterTypeSource) { 118 limiter.globalBucket = rate.NewLimiter(limiter.rate, limiter.burst) 119 } 120 return limiter 121} 122 123// RateLimiter defines a rate limiter 124type rateLimiter struct { 125 rate rate.Limit 126 burst int 127 maxDelay time.Duration 128 globalBucket *rate.Limiter 129 buckets sourceBuckets 130 generateDefenderEvents bool 131 allowList []func(net.IP) bool 132} 133 134// Wait blocks until the limit allows one event to happen 135// or returns an error if the time to wait exceeds the max 136// allowed delay 137func (rl *rateLimiter) Wait(source string) (time.Duration, error) { 138 if len(rl.allowList) > 0 { 139 ip := net.ParseIP(source) 140 if ip != nil { 141 for idx := range rl.allowList { 142 if rl.allowList[idx](ip) { 143 return 0, nil 144 } 145 } 146 } 147 } 148 var res *rate.Reservation 149 if rl.globalBucket != nil { 150 res = rl.globalBucket.Reserve() 151 } else { 152 var err error 153 res, err = rl.buckets.reserve(source) 154 if err != nil { 155 rateLimiter := rate.NewLimiter(rl.rate, rl.burst) 156 res = rl.buckets.addAndReserve(rateLimiter, source) 157 } 158 } 159 if !res.OK() { 160 return 0, errReserve 161 } 162 delay := res.Delay() 163 if delay > rl.maxDelay { 164 res.Cancel() 165 if rl.generateDefenderEvents && rl.globalBucket == nil { 166 AddDefenderEvent(source, HostEventLimitExceeded) 167 } 168 return delay, fmt.Errorf("rate limit exceed, wait time to respect rate %v, max wait time allowed %v", delay, rl.maxDelay) 169 } 170 time.Sleep(delay) 171 return 0, nil 172} 173 174type sourceRateLimiter struct { 175 lastActivity int64 176 bucket *rate.Limiter 177} 178 179func (s *sourceRateLimiter) updateLastActivity() { 180 atomic.StoreInt64(&s.lastActivity, time.Now().UnixNano()) 181} 182 183func (s *sourceRateLimiter) getLastActivity() int64 { 184 return atomic.LoadInt64(&s.lastActivity) 185} 186 187type sourceBuckets struct { 188 sync.RWMutex 189 buckets map[string]sourceRateLimiter 190 hardLimit int 191 softLimit int 192} 193 194func (b *sourceBuckets) reserve(source string) (*rate.Reservation, error) { 195 b.RLock() 196 defer b.RUnlock() 197 198 if src, ok := b.buckets[source]; ok { 199 src.updateLastActivity() 200 return src.bucket.Reserve(), nil 201 } 202 203 return nil, errNoBucket 204} 205 206func (b *sourceBuckets) addAndReserve(r *rate.Limiter, source string) *rate.Reservation { 207 b.Lock() 208 defer b.Unlock() 209 210 b.cleanup() 211 212 src := sourceRateLimiter{ 213 bucket: r, 214 } 215 src.updateLastActivity() 216 b.buckets[source] = src 217 return src.bucket.Reserve() 218} 219 220func (b *sourceBuckets) cleanup() { 221 if len(b.buckets) >= b.hardLimit { 222 numToRemove := len(b.buckets) - b.softLimit 223 224 kvList := make(kvList, 0, len(b.buckets)) 225 226 for k, v := range b.buckets { 227 kvList = append(kvList, kv{ 228 Key: k, 229 Value: v.getLastActivity(), 230 }) 231 } 232 233 sort.Sort(kvList) 234 235 for idx, kv := range kvList { 236 if idx >= numToRemove { 237 break 238 } 239 240 delete(b.buckets, kv.Key) 241 } 242 } 243} 244