1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright Amazon.com Inc. or its affiliates. */
3 
4 #include "vmlinux.h"
5 
6 #include <bpf/bpf_helpers.h>
7 #include <bpf/bpf_endian.h>
8 #include "bpf_tracing_net.h"
9 #include "bpf_kfuncs.h"
10 #include "test_siphash.h"
11 #include "test_tcp_custom_syncookie.h"
12 
13 #define MAX_PACKET_OFF 0xffff
14 
15 /* Hash is calculated for each client and split into ISN and TS.
16  *
17  *       MSB                                   LSB
18  * ISN:  | 31 ... 8 | 7 6 |   5 |    4 | 3 2 1 0 |
19  *       |   Hash_1 | MSS | ECN | SACK |  WScale |
20  *
21  * TS:   | 31 ... 8 |          7 ... 0           |
22  *       |   Random |           Hash_2           |
23  */
24 #define COOKIE_BITS	8
25 #define COOKIE_MASK	(((__u32)1 << COOKIE_BITS) - 1)
26 
27 enum {
28 	/* 0xf is invalid thus means that SYN did not have WScale. */
29 	BPF_SYNCOOKIE_WSCALE_MASK	= (1 << 4) - 1,
30 	BPF_SYNCOOKIE_SACK		= (1 << 4),
31 	BPF_SYNCOOKIE_ECN		= (1 << 5),
32 };
33 
34 #define MSS_LOCAL_IPV4	65495
35 #define MSS_LOCAL_IPV6	65476
36 
37 const __u16 msstab4[] = {
38 	536,
39 	1300,
40 	1460,
41 	MSS_LOCAL_IPV4,
42 };
43 
44 const __u16 msstab6[] = {
45 	1280 - 60, /* IPV6_MIN_MTU - 60 */
46 	1480 - 60,
47 	9000 - 60,
48 	MSS_LOCAL_IPV6,
49 };
50 
51 static siphash_key_t test_key_siphash = {
52 	{ 0x0706050403020100ULL, 0x0f0e0d0c0b0a0908ULL }
53 };
54 
55 struct tcp_syncookie {
56 	struct __sk_buff *skb;
57 	void *data;
58 	void *data_end;
59 	struct ethhdr *eth;
60 	struct iphdr *ipv4;
61 	struct ipv6hdr *ipv6;
62 	struct tcphdr *tcp;
63 	__be32 *ptr32;
64 	struct bpf_tcp_req_attrs attrs;
65 	u32 off;
66 	u32 cookie;
67 	u64 first;
68 };
69 
70 bool handled_syn, handled_ack;
71 
72 static int tcp_load_headers(struct tcp_syncookie *ctx)
73 {
74 	ctx->data = (void *)(long)ctx->skb->data;
75 	ctx->data_end = (void *)(long)ctx->skb->data_end;
76 	ctx->eth = (struct ethhdr *)(long)ctx->skb->data;
77 
78 	if (ctx->eth + 1 > ctx->data_end)
79 		goto err;
80 
81 	switch (bpf_ntohs(ctx->eth->h_proto)) {
82 	case ETH_P_IP:
83 		ctx->ipv4 = (struct iphdr *)(ctx->eth + 1);
84 
85 		if (ctx->ipv4 + 1 > ctx->data_end)
86 			goto err;
87 
88 		if (ctx->ipv4->ihl != sizeof(*ctx->ipv4) / 4)
89 			goto err;
90 
91 		if (ctx->ipv4->version != 4)
92 			goto err;
93 
94 		if (ctx->ipv4->protocol != IPPROTO_TCP)
95 			goto err;
96 
97 		ctx->tcp = (struct tcphdr *)(ctx->ipv4 + 1);
98 		break;
99 	case ETH_P_IPV6:
100 		ctx->ipv6 = (struct ipv6hdr *)(ctx->eth + 1);
101 
102 		if (ctx->ipv6 + 1 > ctx->data_end)
103 			goto err;
104 
105 		if (ctx->ipv6->version != 6)
106 			goto err;
107 
108 		if (ctx->ipv6->nexthdr != NEXTHDR_TCP)
109 			goto err;
110 
111 		ctx->tcp = (struct tcphdr *)(ctx->ipv6 + 1);
112 		break;
113 	default:
114 		goto err;
115 	}
116 
117 	if (ctx->tcp + 1 > ctx->data_end)
118 		goto err;
119 
120 	return 0;
121 err:
122 	return -1;
123 }
124 
125 static int tcp_reload_headers(struct tcp_syncookie *ctx)
126 {
127 	/* Without volatile,
128 	 * R3 32-bit pointer arithmetic prohibited
129 	 */
130 	volatile u64 data_len = ctx->skb->data_end - ctx->skb->data;
131 
132 	if (ctx->tcp->doff < sizeof(*ctx->tcp) / 4)
133 		goto err;
134 
135 	/* Needed to calculate csum and parse TCP options. */
136 	if (bpf_skb_change_tail(ctx->skb, data_len + 60 - ctx->tcp->doff * 4, 0))
137 		goto err;
138 
139 	ctx->data = (void *)(long)ctx->skb->data;
140 	ctx->data_end = (void *)(long)ctx->skb->data_end;
141 	ctx->eth = (struct ethhdr *)(long)ctx->skb->data;
142 	if (ctx->ipv4) {
143 		ctx->ipv4 = (struct iphdr *)(ctx->eth + 1);
144 		ctx->ipv6 = NULL;
145 		ctx->tcp = (struct tcphdr *)(ctx->ipv4 + 1);
146 	} else {
147 		ctx->ipv4 = NULL;
148 		ctx->ipv6 = (struct ipv6hdr *)(ctx->eth + 1);
149 		ctx->tcp = (struct tcphdr *)(ctx->ipv6 + 1);
150 	}
151 
152 	if ((void *)ctx->tcp + 60 > ctx->data_end)
153 		goto err;
154 
155 	return 0;
156 err:
157 	return -1;
158 }
159 
160 static __sum16 tcp_v4_csum(struct tcp_syncookie *ctx, __wsum csum)
161 {
162 	return csum_tcpudp_magic(ctx->ipv4->saddr, ctx->ipv4->daddr,
163 				 ctx->tcp->doff * 4, IPPROTO_TCP, csum);
164 }
165 
166 static __sum16 tcp_v6_csum(struct tcp_syncookie *ctx, __wsum csum)
167 {
168 	return csum_ipv6_magic(&ctx->ipv6->saddr, &ctx->ipv6->daddr,
169 			       ctx->tcp->doff * 4, IPPROTO_TCP, csum);
170 }
171 
172 static int tcp_validate_header(struct tcp_syncookie *ctx)
173 {
174 	s64 csum;
175 
176 	if (tcp_reload_headers(ctx))
177 		goto err;
178 
179 	csum = bpf_csum_diff(0, 0, (void *)ctx->tcp, ctx->tcp->doff * 4, 0);
180 	if (csum < 0)
181 		goto err;
182 
183 	if (ctx->ipv4) {
184 		/* check tcp_v4_csum(csum) is 0 if not on lo. */
185 
186 		csum = bpf_csum_diff(0, 0, (void *)ctx->ipv4, ctx->ipv4->ihl * 4, 0);
187 		if (csum < 0)
188 			goto err;
189 
190 		if (csum_fold(csum) != 0)
191 			goto err;
192 	} else if (ctx->ipv6) {
193 		/* check tcp_v6_csum(csum) is 0 if not on lo. */
194 	}
195 
196 	return 0;
197 err:
198 	return -1;
199 }
200 
201 static __always_inline void *next(struct tcp_syncookie *ctx, __u32 sz)
202 {
203 	__u64 off = ctx->off;
204 	__u8 *data;
205 
206 	/* Verifier forbids access to packet when offset exceeds MAX_PACKET_OFF */
207 	if (off > MAX_PACKET_OFF - sz)
208 		return NULL;
209 
210 	data = ctx->data + off;
211 	barrier_var(data);
212 	if (data + sz >= ctx->data_end)
213 		return NULL;
214 
215 	ctx->off += sz;
216 	return data;
217 }
218 
219 static int tcp_parse_option(__u32 index, struct tcp_syncookie *ctx)
220 {
221 	__u8 *opcode, *opsize, *wscale;
222 	__u32 *tsval, *tsecr;
223 	__u16 *mss;
224 	__u32 off;
225 
226 	off = ctx->off;
227 	opcode = next(ctx, 1);
228 	if (!opcode)
229 		goto stop;
230 
231 	if (*opcode == TCPOPT_EOL)
232 		goto stop;
233 
234 	if (*opcode == TCPOPT_NOP)
235 		goto next;
236 
237 	opsize = next(ctx, 1);
238 	if (!opsize)
239 		goto stop;
240 
241 	if (*opsize < 2)
242 		goto stop;
243 
244 	switch (*opcode) {
245 	case TCPOPT_MSS:
246 		mss = next(ctx, 2);
247 		if (*opsize == TCPOLEN_MSS && ctx->tcp->syn && mss)
248 			ctx->attrs.mss = get_unaligned_be16(mss);
249 		break;
250 	case TCPOPT_WINDOW:
251 		wscale = next(ctx, 1);
252 		if (*opsize == TCPOLEN_WINDOW && ctx->tcp->syn && wscale) {
253 			ctx->attrs.wscale_ok = 1;
254 			ctx->attrs.snd_wscale = *wscale;
255 		}
256 		break;
257 	case TCPOPT_TIMESTAMP:
258 		tsval = next(ctx, 4);
259 		tsecr = next(ctx, 4);
260 		if (*opsize == TCPOLEN_TIMESTAMP && tsval && tsecr) {
261 			ctx->attrs.rcv_tsval = get_unaligned_be32(tsval);
262 			ctx->attrs.rcv_tsecr = get_unaligned_be32(tsecr);
263 
264 			if (ctx->tcp->syn && ctx->attrs.rcv_tsecr)
265 				ctx->attrs.tstamp_ok = 0;
266 			else
267 				ctx->attrs.tstamp_ok = 1;
268 		}
269 		break;
270 	case TCPOPT_SACK_PERM:
271 		if (*opsize == TCPOLEN_SACK_PERM && ctx->tcp->syn)
272 			ctx->attrs.sack_ok = 1;
273 		break;
274 	}
275 
276 	ctx->off = off + *opsize;
277 next:
278 	return 0;
279 stop:
280 	return 1;
281 }
282 
283 static void tcp_parse_options(struct tcp_syncookie *ctx)
284 {
285 	ctx->off = (__u8 *)(ctx->tcp + 1) - (__u8 *)ctx->data,
286 
287 	bpf_loop(40, tcp_parse_option, ctx, 0);
288 }
289 
290 static int tcp_validate_sysctl(struct tcp_syncookie *ctx)
291 {
292 	if ((ctx->ipv4 && ctx->attrs.mss != MSS_LOCAL_IPV4) ||
293 	    (ctx->ipv6 && ctx->attrs.mss != MSS_LOCAL_IPV6))
294 		goto err;
295 
296 	if (!ctx->attrs.wscale_ok || ctx->attrs.snd_wscale != 7)
297 		goto err;
298 
299 	if (!ctx->attrs.tstamp_ok)
300 		goto err;
301 
302 	if (!ctx->attrs.sack_ok)
303 		goto err;
304 
305 	if (!ctx->tcp->ece || !ctx->tcp->cwr)
306 		goto err;
307 
308 	return 0;
309 err:
310 	return -1;
311 }
312 
313 static void tcp_prepare_cookie(struct tcp_syncookie *ctx)
314 {
315 	u32 seq = bpf_ntohl(ctx->tcp->seq);
316 	u64 first = 0, second;
317 	int mssind = 0;
318 	u32 hash;
319 
320 	if (ctx->ipv4) {
321 		for (mssind = ARRAY_SIZE(msstab4) - 1; mssind; mssind--)
322 			if (ctx->attrs.mss >= msstab4[mssind])
323 				break;
324 
325 		ctx->attrs.mss = msstab4[mssind];
326 
327 		first = (u64)ctx->ipv4->saddr << 32 | ctx->ipv4->daddr;
328 	} else if (ctx->ipv6) {
329 		for (mssind = ARRAY_SIZE(msstab6) - 1; mssind; mssind--)
330 			if (ctx->attrs.mss >= msstab6[mssind])
331 				break;
332 
333 		ctx->attrs.mss = msstab6[mssind];
334 
335 		first = (u64)ctx->ipv6->saddr.in6_u.u6_addr8[0] << 32 |
336 			ctx->ipv6->daddr.in6_u.u6_addr32[0];
337 	}
338 
339 	second = (u64)seq << 32 | ctx->tcp->source << 16 | ctx->tcp->dest;
340 	hash = siphash_2u64(first, second, &test_key_siphash);
341 
342 	if (ctx->attrs.tstamp_ok) {
343 		ctx->attrs.rcv_tsecr = bpf_get_prandom_u32();
344 		ctx->attrs.rcv_tsecr &= ~COOKIE_MASK;
345 		ctx->attrs.rcv_tsecr |= hash & COOKIE_MASK;
346 	}
347 
348 	hash &= ~COOKIE_MASK;
349 	hash |= mssind << 6;
350 
351 	if (ctx->attrs.wscale_ok)
352 		hash |= ctx->attrs.snd_wscale & BPF_SYNCOOKIE_WSCALE_MASK;
353 
354 	if (ctx->attrs.sack_ok)
355 		hash |= BPF_SYNCOOKIE_SACK;
356 
357 	if (ctx->attrs.tstamp_ok && ctx->tcp->ece && ctx->tcp->cwr)
358 		hash |= BPF_SYNCOOKIE_ECN;
359 
360 	ctx->cookie = hash;
361 }
362 
363 static void tcp_write_options(struct tcp_syncookie *ctx)
364 {
365 	ctx->ptr32 = (__be32 *)(ctx->tcp + 1);
366 
367 	*ctx->ptr32++ = bpf_htonl(TCPOPT_MSS << 24 | TCPOLEN_MSS << 16 |
368 				  ctx->attrs.mss);
369 
370 	if (ctx->attrs.wscale_ok)
371 		*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
372 					  TCPOPT_WINDOW << 16 |
373 					  TCPOLEN_WINDOW << 8 |
374 					  ctx->attrs.snd_wscale);
375 
376 	if (ctx->attrs.tstamp_ok) {
377 		if (ctx->attrs.sack_ok)
378 			*ctx->ptr32++ = bpf_htonl(TCPOPT_SACK_PERM << 24 |
379 						  TCPOLEN_SACK_PERM << 16 |
380 						  TCPOPT_TIMESTAMP << 8 |
381 						  TCPOLEN_TIMESTAMP);
382 		else
383 			*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
384 						  TCPOPT_NOP << 16 |
385 						  TCPOPT_TIMESTAMP << 8 |
386 						  TCPOLEN_TIMESTAMP);
387 
388 		*ctx->ptr32++ = bpf_htonl(ctx->attrs.rcv_tsecr);
389 		*ctx->ptr32++ = bpf_htonl(ctx->attrs.rcv_tsval);
390 	} else if (ctx->attrs.sack_ok) {
391 		*ctx->ptr32++ = bpf_htonl(TCPOPT_NOP << 24 |
392 					  TCPOPT_NOP << 16 |
393 					  TCPOPT_SACK_PERM << 8 |
394 					  TCPOLEN_SACK_PERM);
395 	}
396 }
397 
398 static int tcp_handle_syn(struct tcp_syncookie *ctx)
399 {
400 	s64 csum;
401 
402 	if (tcp_validate_header(ctx))
403 		goto err;
404 
405 	tcp_parse_options(ctx);
406 
407 	if (tcp_validate_sysctl(ctx))
408 		goto err;
409 
410 	tcp_prepare_cookie(ctx);
411 	tcp_write_options(ctx);
412 
413 	swap(ctx->tcp->source, ctx->tcp->dest);
414 	ctx->tcp->check = 0;
415 	ctx->tcp->ack_seq = bpf_htonl(bpf_ntohl(ctx->tcp->seq) + 1);
416 	ctx->tcp->seq = bpf_htonl(ctx->cookie);
417 	ctx->tcp->doff = ((long)ctx->ptr32 - (long)ctx->tcp) >> 2;
418 	ctx->tcp->ack = 1;
419 	if (!ctx->attrs.tstamp_ok || !ctx->tcp->ece || !ctx->tcp->cwr)
420 		ctx->tcp->ece = 0;
421 	ctx->tcp->cwr = 0;
422 
423 	csum = bpf_csum_diff(0, 0, (void *)ctx->tcp, ctx->tcp->doff * 4, 0);
424 	if (csum < 0)
425 		goto err;
426 
427 	if (ctx->ipv4) {
428 		swap(ctx->ipv4->saddr, ctx->ipv4->daddr);
429 		ctx->tcp->check = tcp_v4_csum(ctx, csum);
430 
431 		ctx->ipv4->check = 0;
432 		ctx->ipv4->tos = 0;
433 		ctx->ipv4->tot_len = bpf_htons((long)ctx->ptr32 - (long)ctx->ipv4);
434 		ctx->ipv4->id = 0;
435 		ctx->ipv4->ttl = 64;
436 
437 		csum = bpf_csum_diff(0, 0, (void *)ctx->ipv4, sizeof(*ctx->ipv4), 0);
438 		if (csum < 0)
439 			goto err;
440 
441 		ctx->ipv4->check = csum_fold(csum);
442 	} else if (ctx->ipv6) {
443 		swap(ctx->ipv6->saddr, ctx->ipv6->daddr);
444 		ctx->tcp->check = tcp_v6_csum(ctx, csum);
445 
446 		*(__be32 *)ctx->ipv6 = bpf_htonl(0x60000000);
447 		ctx->ipv6->payload_len = bpf_htons((long)ctx->ptr32 - (long)ctx->tcp);
448 		ctx->ipv6->hop_limit = 64;
449 	}
450 
451 	swap_array(ctx->eth->h_source, ctx->eth->h_dest);
452 
453 	if (bpf_skb_change_tail(ctx->skb, (long)ctx->ptr32 - (long)ctx->eth, 0))
454 		goto err;
455 
456 	return bpf_redirect(ctx->skb->ifindex, 0);
457 err:
458 	return TC_ACT_SHOT;
459 }
460 
461 static int tcp_validate_cookie(struct tcp_syncookie *ctx)
462 {
463 	u32 cookie = bpf_ntohl(ctx->tcp->ack_seq) - 1;
464 	u32 seq = bpf_ntohl(ctx->tcp->seq) - 1;
465 	u64 first = 0, second;
466 	int mssind;
467 	u32 hash;
468 
469 	if (ctx->ipv4)
470 		first = (u64)ctx->ipv4->saddr << 32 | ctx->ipv4->daddr;
471 	else if (ctx->ipv6)
472 		first = (u64)ctx->ipv6->saddr.in6_u.u6_addr8[0] << 32 |
473 			ctx->ipv6->daddr.in6_u.u6_addr32[0];
474 
475 	second = (u64)seq << 32 | ctx->tcp->source << 16 | ctx->tcp->dest;
476 	hash = siphash_2u64(first, second, &test_key_siphash);
477 
478 	if (ctx->attrs.tstamp_ok)
479 		hash -= ctx->attrs.rcv_tsecr & COOKIE_MASK;
480 	else
481 		hash &= ~COOKIE_MASK;
482 
483 	hash -= cookie & ~COOKIE_MASK;
484 	if (hash)
485 		goto err;
486 
487 	mssind = (cookie & (3 << 6)) >> 6;
488 	if (ctx->ipv4) {
489 		if (mssind > ARRAY_SIZE(msstab4))
490 			goto err;
491 
492 		ctx->attrs.mss = msstab4[mssind];
493 	} else {
494 		if (mssind > ARRAY_SIZE(msstab6))
495 			goto err;
496 
497 		ctx->attrs.mss = msstab6[mssind];
498 	}
499 
500 	ctx->attrs.snd_wscale = cookie & BPF_SYNCOOKIE_WSCALE_MASK;
501 	ctx->attrs.rcv_wscale = ctx->attrs.snd_wscale;
502 	ctx->attrs.wscale_ok = ctx->attrs.snd_wscale == BPF_SYNCOOKIE_WSCALE_MASK;
503 	ctx->attrs.sack_ok = cookie & BPF_SYNCOOKIE_SACK;
504 	ctx->attrs.ecn_ok = cookie & BPF_SYNCOOKIE_ECN;
505 
506 	return 0;
507 err:
508 	return -1;
509 }
510 
511 static int tcp_handle_ack(struct tcp_syncookie *ctx)
512 {
513 	struct bpf_sock_tuple tuple;
514 	struct bpf_sock *skc;
515 	int ret = TC_ACT_OK;
516 	struct sock *sk;
517 	u32 tuple_size;
518 
519 	if (ctx->ipv4) {
520 		tuple.ipv4.saddr = ctx->ipv4->saddr;
521 		tuple.ipv4.daddr = ctx->ipv4->daddr;
522 		tuple.ipv4.sport = ctx->tcp->source;
523 		tuple.ipv4.dport = ctx->tcp->dest;
524 		tuple_size = sizeof(tuple.ipv4);
525 	} else if (ctx->ipv6) {
526 		__builtin_memcpy(tuple.ipv6.saddr, &ctx->ipv6->saddr, sizeof(tuple.ipv6.saddr));
527 		__builtin_memcpy(tuple.ipv6.daddr, &ctx->ipv6->daddr, sizeof(tuple.ipv6.daddr));
528 		tuple.ipv6.sport = ctx->tcp->source;
529 		tuple.ipv6.dport = ctx->tcp->dest;
530 		tuple_size = sizeof(tuple.ipv6);
531 	} else {
532 		goto out;
533 	}
534 
535 	skc = bpf_skc_lookup_tcp(ctx->skb, &tuple, tuple_size, -1, 0);
536 	if (!skc)
537 		goto out;
538 
539 	if (skc->state != TCP_LISTEN)
540 		goto release;
541 
542 	sk = (struct sock *)bpf_skc_to_tcp_sock(skc);
543 	if (!sk)
544 		goto err;
545 
546 	if (tcp_validate_header(ctx))
547 		goto err;
548 
549 	tcp_parse_options(ctx);
550 
551 	if (tcp_validate_cookie(ctx))
552 		goto err;
553 
554 	ret = bpf_sk_assign_tcp_reqsk(ctx->skb, sk, &ctx->attrs, sizeof(ctx->attrs));
555 	if (ret < 0)
556 		goto err;
557 
558 release:
559 	bpf_sk_release(skc);
560 out:
561 	return ret;
562 
563 err:
564 	ret = TC_ACT_SHOT;
565 	goto release;
566 }
567 
568 SEC("tc")
569 int tcp_custom_syncookie(struct __sk_buff *skb)
570 {
571 	struct tcp_syncookie ctx = {
572 		.skb = skb,
573 	};
574 
575 	if (tcp_load_headers(&ctx))
576 		return TC_ACT_OK;
577 
578 	if (ctx.tcp->rst)
579 		return TC_ACT_OK;
580 
581 	if (ctx.tcp->syn) {
582 		if (ctx.tcp->ack)
583 			return TC_ACT_OK;
584 
585 		handled_syn = true;
586 
587 		return tcp_handle_syn(&ctx);
588 	}
589 
590 	handled_ack = true;
591 
592 	return tcp_handle_ack(&ctx);
593 }
594 
595 char _license[] SEC("license") = "GPL";
596