1/* SPDX-License-Identifier: MIT
2 *
3 * Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
4 */
5
6package ratelimiter
7
8import (
9	"net"
10	"testing"
11	"time"
12)
13
// result is one step of the TestRatelimiter script: the verdict that
// Ratelimiter.Allow must return for every test address at that step.
type result struct {
	// allowed is the expected return value of Allow.
	allowed bool
	// text labels the step in failure messages.
	text string
	// wait is how long the fake clock is advanced before the step runs.
	wait time.Duration
}
19
20func TestRatelimiter(t *testing.T) {
21	var rate Ratelimiter
22	var expectedResults []result
23
24	nano := func(nano int64) time.Duration {
25		return time.Nanosecond * time.Duration(nano)
26	}
27
28	add := func(res result) {
29		expectedResults = append(
30			expectedResults,
31			res,
32		)
33	}
34
35	for i := 0; i < packetsBurstable; i++ {
36		add(result{
37			allowed: true,
38			text:    "initial burst",
39		})
40	}
41
42	add(result{
43		allowed: false,
44		text:    "after burst",
45	})
46
47	add(result{
48		allowed: true,
49		wait:    nano(time.Second.Nanoseconds() / packetsPerSecond),
50		text:    "filling tokens for single packet",
51	})
52
53	add(result{
54		allowed: false,
55		text:    "not having refilled enough",
56	})
57
58	add(result{
59		allowed: true,
60		wait:    2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)),
61		text:    "filling tokens for two packet burst",
62	})
63
64	add(result{
65		allowed: true,
66		text:    "second packet in 2 packet burst",
67	})
68
69	add(result{
70		allowed: false,
71		text:    "packet following 2 packet burst",
72	})
73
74	ips := []net.IP{
75		net.ParseIP("127.0.0.1"),
76		net.ParseIP("192.168.1.1"),
77		net.ParseIP("172.167.2.3"),
78		net.ParseIP("97.231.252.215"),
79		net.ParseIP("248.97.91.167"),
80		net.ParseIP("188.208.233.47"),
81		net.ParseIP("104.2.183.179"),
82		net.ParseIP("72.129.46.120"),
83		net.ParseIP("2001:0db8:0a0b:12f0:0000:0000:0000:0001"),
84		net.ParseIP("f5c2:818f:c052:655a:9860:b136:6894:25f0"),
85		net.ParseIP("b2d7:15ab:48a7:b07c:a541:f144:a9fe:54fc"),
86		net.ParseIP("a47b:786e:1671:a22b:d6f9:4ab0:abc7:c918"),
87		net.ParseIP("ea1e:d155:7f7a:98fb:2bf5:9483:80f6:5445"),
88		net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
89	}
90
91	now := time.Now()
92	rate.timeNow = func() time.Time {
93		return now
94	}
95	defer func() {
96		// Lock to avoid data race with cleanup goroutine from Init.
97		rate.mu.Lock()
98		defer rate.mu.Unlock()
99
100		rate.timeNow = time.Now
101	}()
102	timeSleep := func(d time.Duration) {
103		now = now.Add(d + 1)
104		rate.cleanup()
105	}
106
107	rate.Init()
108	defer rate.Close()
109
110	for i, res := range expectedResults {
111		timeSleep(res.wait)
112		for _, ip := range ips {
113			allowed := rate.Allow(ip)
114			if allowed != res.allowed {
115				t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed)
116			}
117		}
118	}
119}
120