1// Copyright (C) 2019 Storj Labs, Inc.
2// See LICENSE for copying information.
3
4package web
5
6import (
7	"context"
8	"net"
9	"net/http"
10	"strings"
11	"sync"
12	"time"
13
14	"golang.org/x/time/rate"
15)
16
// RateLimiterConfig configures a RateLimiter.
type RateLimiterConfig struct {
	// Duration is the interval per event for a single key: each key's limiter
	// is created with a rate of one event per Duration (1s/Duration events per
	// second). It is also how long a key may stay idle before its state is
	// discarded, and the interval at which Run performs that cleanup.
	// NOTE(review): the help text calls this a "rate", but the value is an interval.
	Duration  time.Duration `help:"the rate at which request are allowed" default:"5m"`
	// Burst is the burst size passed to each key's rate.Limiter.
	Burst     int           `help:"number of events before the limit kicks in" default:"5" testDefault:"3"`
	// NumLimits caps how many keys are tracked at once. When the cap is reached
	// and a new key arrives, expired entries are dropped first. If that frees no
	// room, the least recently seen entry is evicted.
	NumLimits int           `help:"number of clients whose rate limits we store" default:"1000" testDefault:"10"`
}
23
// RateLimiter imposes a rate limit per key.
//
// Each distinct key produced by keyFunc gets its own *rate.Limiter. The
// limiter is created lazily on the key's first request and stored in limits.
type RateLimiter struct {
	config  RateLimiterConfig
	// mu guards limits and the lastSeen field of every entry in it.
	mu      sync.Mutex
	// limits maps a request key to that key's limiter state.
	limits  map[string]*userLimit
	// keyFunc derives the rate-limit key from an incoming request.
	keyFunc func(*http.Request) (string, error)
}
31
// userLimit is the per-key limiter.
type userLimit struct {
	// limiter decides whether a request for this key is allowed.
	limiter  *rate.Limiter
	// lastSeen is when this key was created or last requested. Entries idle for
	// longer than RateLimiterConfig.Duration are eligible for removal.
	// Read and written only while holding RateLimiter.mu.
	lastSeen time.Time
}
37
38// NewIPRateLimiter constructs a RateLimiter that limits based on IP address.
39func NewIPRateLimiter(config RateLimiterConfig) *RateLimiter {
40	return NewRateLimiter(config, GetRequestIP)
41}
42
43// NewRateLimiter constructs a RateLimiter.
44func NewRateLimiter(config RateLimiterConfig, keyFunc func(*http.Request) (string, error)) *RateLimiter {
45	return &RateLimiter{
46		config:  config,
47		limits:  make(map[string]*userLimit),
48		keyFunc: keyFunc,
49	}
50}
51
52// Run occasionally cleans old rate-limiting data, until context cancel.
53func (rl *RateLimiter) Run(ctx context.Context) {
54	cleanupTicker := time.NewTicker(rl.config.Duration)
55	defer cleanupTicker.Stop()
56	for {
57		select {
58		case <-ctx.Done():
59			return
60		case <-cleanupTicker.C:
61			rl.cleanupLimiters()
62		}
63	}
64}
65
66// cleanupLimiters removes old rate limits to free memory.
67func (rl *RateLimiter) cleanupLimiters() {
68	rl.mu.Lock()
69	defer rl.mu.Unlock()
70	for k, v := range rl.limits {
71		if time.Since(v.lastSeen) > rl.config.Duration {
72			delete(rl.limits, k)
73		}
74	}
75}
76
77// Limit applies per-key rate limiting as an HTTP Handler.
78func (rl *RateLimiter) Limit(next http.Handler) http.Handler {
79	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
80		key, err := rl.keyFunc(r)
81		if err != nil {
82			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
83			return
84		}
85		limit := rl.getUserLimit(key)
86		if !limit.Allow() {
87			http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
88			return
89		}
90		next.ServeHTTP(w, r)
91	})
92}
93
// GetRequestIP gets the original IP address of the request by handling the request headers.
//
// The address is taken from the first non-blank source in this order:
//  1. the X-Real-IP header;
//  2. the first (left-most) entry of the comma-separated X-Forwarded-For header;
//  3. the host part of r.RemoteAddr, split with net.SplitHostPort.
//
// Surrounding whitespace is trimmed from header values. An error is returned
// only when the function falls back to r.RemoteAddr and that value cannot be
// split into host and port.
//
// NOTE(review): unless a trusted reverse proxy overwrites both headers, a
// client can set them itself and so choose its own rate-limit key. Confirm
// that every deployment sits behind such a proxy.
func GetRequestIP(r *http.Request) (ip string, err error) {
	if realIP := strings.TrimSpace(r.Header.Get("X-REAL-IP")); realIP != "" {
		return realIP, nil
	}

	// Entries are separated by "," with optional whitespace, so "a, b" and
	// "a,b" are both valid. Split on the bare comma and trim the result instead
	// of splitting on ", ". SplitN always returns at least one element, even
	// for an empty header.
	forwarded := strings.SplitN(r.Header.Get("X-FORWARDED-FOR"), ",", 2)
	if first := strings.TrimSpace(forwarded[0]); first != "" {
		return first, nil
	}

	ip, _, err = net.SplitHostPort(r.RemoteAddr)
	return ip, err
}
113
114// getUserLimit returns a rate limiter for a key.
115func (rl *RateLimiter) getUserLimit(key string) *rate.Limiter {
116	rl.mu.Lock()
117	defer rl.mu.Unlock()
118
119	v, exists := rl.limits[key]
120	if !exists {
121		if len(rl.limits) >= rl.config.NumLimits {
122			// Tracking only N limits prevents an out-of-memory DOS attack
123			// Returning StatusTooManyRequests would be just as bad
124			// The least-bad option may be to remove the oldest key
125			oldestKey := ""
126			var oldestTime *time.Time
127			for key, v := range rl.limits {
128				// while we're looping, we'd prefer to just delete expired records
129				if time.Since(v.lastSeen) > rl.config.Duration {
130					delete(rl.limits, key)
131				}
132				// but we're prepared to delete the oldest non-expired
133				if oldestTime == nil || v.lastSeen.Before(*oldestTime) {
134					oldestTime = &v.lastSeen
135					oldestKey = key
136				}
137			}
138			// only delete the oldest non-expired if there's still an issue
139			if oldestKey != "" && len(rl.limits) >= rl.config.NumLimits {
140				delete(rl.limits, oldestKey)
141			}
142		}
143		limiter := rate.NewLimiter(rate.Limit(time.Second)/rate.Limit(rl.config.Duration), rl.config.Burst)
144		rl.limits[key] = &userLimit{limiter, time.Now()}
145		return limiter
146	}
147	v.lastSeen = time.Now()
148	return v.limiter
149}
150
151// Burst returns the number of events that happen before the rate limit.
152func (rl *RateLimiter) Burst() int {
153	return rl.config.Burst
154}
155
156// Duration returns the amount of time required between events.
157func (rl *RateLimiter) Duration() time.Duration {
158	return rl.config.Duration
159}
160