1// Copyright (C) 2019 Storj Labs, Inc. 2// See LICENSE for copying information. 3 4package web 5 6import ( 7 "context" 8 "net" 9 "net/http" 10 "strings" 11 "sync" 12 "time" 13 14 "golang.org/x/time/rate" 15) 16 17// RateLimiterConfig configures a RateLimiter. 18type RateLimiterConfig struct { 19 Duration time.Duration `help:"the rate at which request are allowed" default:"5m"` 20 Burst int `help:"number of events before the limit kicks in" default:"5" testDefault:"3"` 21 NumLimits int `help:"number of clients whose rate limits we store" default:"1000" testDefault:"10"` 22} 23 24// RateLimiter imposes a rate limit per key. 25type RateLimiter struct { 26 config RateLimiterConfig 27 mu sync.Mutex 28 limits map[string]*userLimit 29 keyFunc func(*http.Request) (string, error) 30} 31 32// userLimit is the per-key limiter. 33type userLimit struct { 34 limiter *rate.Limiter 35 lastSeen time.Time 36} 37 38// NewIPRateLimiter constructs a RateLimiter that limits based on IP address. 39func NewIPRateLimiter(config RateLimiterConfig) *RateLimiter { 40 return NewRateLimiter(config, GetRequestIP) 41} 42 43// NewRateLimiter constructs a RateLimiter. 44func NewRateLimiter(config RateLimiterConfig, keyFunc func(*http.Request) (string, error)) *RateLimiter { 45 return &RateLimiter{ 46 config: config, 47 limits: make(map[string]*userLimit), 48 keyFunc: keyFunc, 49 } 50} 51 52// Run occasionally cleans old rate-limiting data, until context cancel. 53func (rl *RateLimiter) Run(ctx context.Context) { 54 cleanupTicker := time.NewTicker(rl.config.Duration) 55 defer cleanupTicker.Stop() 56 for { 57 select { 58 case <-ctx.Done(): 59 return 60 case <-cleanupTicker.C: 61 rl.cleanupLimiters() 62 } 63 } 64} 65 66// cleanupLimiters removes old rate limits to free memory. 67func (rl *RateLimiter) cleanupLimiters() { 68 rl.mu.Lock() 69 defer rl.mu.Unlock() 70 for k, v := range rl.limits { 71 if time.Since(v.lastSeen) > rl.config.Duration { 72 delete(rl.limits, k) 73 } 74 } 75} 76 77// Limit applies per-key rate limiting as an HTTP Handler. 78func (rl *RateLimiter) Limit(next http.Handler) http.Handler { 79 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 80 key, err := rl.keyFunc(r) 81 if err != nil { 82 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 83 return 84 } 85 limit := rl.getUserLimit(key) 86 if !limit.Allow() { 87 http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) 88 return 89 } 90 next.ServeHTTP(w, r) 91 }) 92} 93 94// GetRequestIP gets the original IP address of the request by handling the request headers. 95func GetRequestIP(r *http.Request) (ip string, err error) { 96 realIP := r.Header.Get("X-REAL-IP") 97 if realIP != "" { 98 return realIP, nil 99 } 100 101 forwardedIPs := r.Header.Get("X-FORWARDED-FOR") 102 if forwardedIPs != "" { 103 ips := strings.Split(forwardedIPs, ", ") 104 if len(ips) > 0 { 105 return ips[0], nil 106 } 107 } 108 109 ip, _, err = net.SplitHostPort(r.RemoteAddr) 110 111 return ip, err 112} 113 114// getUserLimit returns a rate limiter for a key. 115func (rl *RateLimiter) getUserLimit(key string) *rate.Limiter { 116 rl.mu.Lock() 117 defer rl.mu.Unlock() 118 119 v, exists := rl.limits[key] 120 if !exists { 121 if len(rl.limits) >= rl.config.NumLimits { 122 // Tracking only N limits prevents an out-of-memory DOS attack 123 // Returning StatusTooManyRequests would be just as bad 124 // The least-bad option may be to remove the oldest key 125 oldestKey := "" 126 var oldestTime *time.Time 127 for key, v := range rl.limits { 128 // while we're looping, we'd prefer to just delete expired records 129 if time.Since(v.lastSeen) > rl.config.Duration { 130 delete(rl.limits, key) 131 } 132 // but we're prepared to delete the oldest non-expired 133 if oldestTime == nil || v.lastSeen.Before(*oldestTime) { 134 oldestTime = &v.lastSeen 135 oldestKey = key 136 } 137 } 138 // only delete the oldest non-expired if there's still an issue 139 if oldestKey != "" && len(rl.limits) >= rl.config.NumLimits { 140 delete(rl.limits, oldestKey) 141 } 142 } 143 limiter := rate.NewLimiter(rate.Limit(time.Second)/rate.Limit(rl.config.Duration), rl.config.Burst) 144 rl.limits[key] = &userLimit{limiter, time.Now()} 145 return limiter 146 } 147 v.lastSeen = time.Now() 148 return v.limiter 149} 150 151// Burst returns the number of events that happen before the rate limit. 152func (rl *RateLimiter) Burst() int { 153 return rl.config.Burst 154} 155 156// Duration returns the amount of time required between events. 157func (rl *RateLimiter) Duration() time.Duration { 158 return rl.config.Duration 159} 160