1 #include "misc-rstfilter.h"
2 #include "util-malloc.h"
3 #include "siphash24.h"
4 #include <time.h>
5
/*
 * State for the RST filter: a table of 4-bit saturating counters, packed
 * two per byte, indexed by a keyed hash of the connection tuple.
 */
struct ResetFilter
{
    /* Hash key; copied into both words of the siphash24() key. */
    unsigned long long seed;
    /* Number of 4-bit counters; set from next_pow2(), so a power of two. */
    size_t bucket_count;
    /* bucket_count - 1; masks a hash down to a counter index. */
    size_t bucket_mask;
    /* Incremented on every lookup and mixed into the hash that picks the
     * counter to drain. */
    unsigned counter;
    /* Counter storage: byte i holds counter 2*i in its high nibble and
     * counter 2*i+1 in its low nibble. */
    unsigned char *buckets;
};
14
/*
 * Round n up to the nearest power of two (n itself if it already is one).
 * Returns 1 for n == 0.
 *
 * If n is larger than the largest power of two a size_t can represent,
 * that largest power is returned. Rounding up is impossible there, and the
 * previous bit-count approach shifted by the full width of size_t, which is
 * undefined behaviour.
 */
static size_t
next_pow2(size_t n)
{
    /* Highest single bit representable in a size_t. */
    const size_t top = ~((size_t)-1 >> 1);
    size_t result = 1;

    if (n > top)
        return top;

    /* result never exceeds top here, so the shift cannot overflow. */
    while (result < n)
        result <<= 1;

    return result;
}
36
37 struct ResetFilter *
rstfilter_create(unsigned long long seed,size_t bucket_count)38 rstfilter_create(unsigned long long seed, size_t bucket_count)
39 {
40 struct ResetFilter *rf;
41
42 rf = CALLOC(1, sizeof(*rf));
43 rf->seed = seed;
44 rf->bucket_count = next_pow2(bucket_count);
45 rf->bucket_mask = rf->bucket_count - 1;
46 rf->buckets = CALLOC(rf->bucket_count/2, sizeof(*rf->buckets));
47
48 return rf;
49 }
50
51
52 void
rstfilter_destroy(struct ResetFilter * rf)53 rstfilter_destroy(struct ResetFilter *rf)
54 {
55 if (rf == NULL)
56 return;
57 free(rf->buckets);
58 free(rf);
59 }
60
61 int
rstfilter_is_filter(struct ResetFilter * rf,ipaddress src_ip,unsigned src_port,ipaddress dst_ip,unsigned dst_port)62 rstfilter_is_filter(struct ResetFilter *rf,
63 ipaddress src_ip, unsigned src_port,
64 ipaddress dst_ip, unsigned dst_port)
65 {
66 uint64_t hash;
67 uint64_t input[5];
68 uint64_t key[2];
69 size_t index;
70 unsigned char *p;
71 int result = 0;
72
73 /*
74 * Setup the input
75 */
76 switch (src_ip.version) {
77 case 4:
78 input[0] = src_ip.ipv4;
79 input[1] = src_port;
80 input[2] = dst_ip.ipv4;
81 input[3] = dst_port;
82 break;
83 case 6:
84 input[0] = src_ip.ipv6.hi;
85 input[1] = src_ip.ipv6.lo;
86 input[2] = dst_ip.ipv6.hi;
87 input[3] = dst_ip.ipv6.lo;
88 input[4] = src_port<<16 | dst_port;
89 break;
90 }
91 key[0] = rf->seed;
92 key[1] = rf->seed;
93
94 /*
95 * Grab the bucket
96 */
97 hash = siphash24(input, sizeof(input), key);
98 index = hash & rf->bucket_mask;
99
100 /*
101 * Find the result (1=filterout, 0=sendrst)
102 */
103 p = &rf->buckets[index/2];
104 if (index & 1) {
105 if ((*p & 0x0F) == 0x0F)
106 result = 1; /* filter out */
107 else
108 *p = (*p) + 0x01;
109 } else {
110 if ((*p & 0xF0) == 0xF0)
111 result = 1; /* filter out */
112 else
113 *p = (*p) + 0x10;
114 }
115
116 /*
117 * Empty a random bucket
118 */
119 input[0] = (unsigned)hash;
120 input[1] = rf->counter++;
121 hash = siphash24(input, sizeof(input), key);
122 index = hash & rf->bucket_mask;
123 p = &rf->buckets[index/2];
124 if (index & 1) {
125 if ((*p & 0x0F))
126 *p = (*p) - 0x01;
127 } else {
128 if ((*p & 0xF0))
129 *p = (*p) - 0x10;
130 }
131
132 return result;
133 }
134
135
136
137 int
rstfilter_selftest(void)138 rstfilter_selftest(void)
139 {
140 struct ResetFilter *rf;
141 size_t i;
142 unsigned count_filtered = 0;
143 unsigned count_passed = 0;
144
145 ipaddress src;
146 ipaddress dst;
147
148 src.version = 4;
149 src.ipv4 = 1;
150 dst.version = 4;
151 dst.ipv4 = 3;
152
153 rf = rstfilter_create(time(0), 64);
154
155 /* Verify the first 15 packets pass the filter */
156 for (i=0; i<15; i++) {
157 int x;
158
159 x = rstfilter_is_filter(rf, src, 2, dst, 4);
160 if (x) {
161 fprintf(stderr, "[-] rstfilter failed, line=%u\n", __LINE__);
162 return 1;
163 }
164 }
165
166 /* Now run 10000 more times */
167 for (i=0; i<1000; i++) {
168 int x;
169 x = rstfilter_is_filter(rf, src, 2, dst, 4);
170 count_filtered += x;
171 count_passed += !x;
172 }
173
174 /* SOME must have passed, due to us emptying random buckets */
175 if (count_passed == 0) {
176 fprintf(stderr, "[-] rstfilter failed, line=%u\n", __LINE__);
177 return 1;
178 }
179
180 /* However, while some pass, the vast majority should be filtered */
181 if (count_passed > count_filtered/10) {
182 fprintf(stderr, "[-] rstfilter failed, line=%u\n", __LINE__);
183 return 1;
184 }
185 //printf("filtered=%u passed=%u\n", count_filtered, count_passed);
186 return 0;
187 }
188
189
190
191
192
193
194
195