1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2020 Cloudflare
3 
4 #include <errno.h>
5 #include <stdbool.h>
6 #include <linux/bpf.h>
7 
8 #include <bpf/bpf_helpers.h>
9 
/*
 * Socket map used by every program below when test_sockmap is true.
 * Two slots, u32 keys, u64 values; the programs in this file only ever
 * address key 0. Entries are presumably inserted by user space before
 * traffic is generated (not visible here).
 */
struct {
	__uint(type, BPF_MAP_TYPE_SOCKMAP);
	__uint(max_entries, 2);
	__type(key, __u32);
	__type(value, __u64);
} sock_map SEC(".maps");
16 
/*
 * Hash-based counterpart of sock_map, used by every program below when
 * test_sockmap is false. Same key/value types and capacity; lookups in
 * this file use key 0.
 */
struct {
	__uint(type, BPF_MAP_TYPE_SOCKHASH);
	__uint(max_entries, 2);
	__type(key, __u32);
	__type(value, __u64);
} sock_hash SEC(".maps");
23 
/*
 * Per-verdict hit counters. Each program uses its returned verdict as the
 * key, so the two slots correspond to SK_DROP (0) and SK_PASS (1). A
 * verdict outside [0, max_entries) finds no slot and is simply not counted.
 */
struct {
	__uint(type, BPF_MAP_TYPE_ARRAY);
	__uint(max_entries, 2);
	__type(key, int);
	__type(value, unsigned int);
} verdict_map SEC(".maps");
30 
/*
 * Test knobs:
 *  - test_sockmap: true selects sock_map, false selects sock_hash, in
 *    every program in this file.
 *  - test_ingress: read only by prog_skb_verdict; when true it passes
 *    BPF_F_INGRESS to the redirect helper, otherwise flags are 0.
 */
bool test_sockmap = false; /* toggled by user-space */
bool test_ingress = false; /* toggled by user-space */
33 
34 SEC("sk_skb/stream_parser")
35 int prog_stream_parser(struct __sk_buff *skb)
36 {
37 	return skb->len;
38 }
39 
40 SEC("sk_skb/stream_verdict")
41 int prog_stream_verdict(struct __sk_buff *skb)
42 {
43 	unsigned int *count;
44 	__u32 zero = 0;
45 	int verdict;
46 
47 	if (test_sockmap)
48 		verdict = bpf_sk_redirect_map(skb, &sock_map, zero, 0);
49 	else
50 		verdict = bpf_sk_redirect_hash(skb, &sock_hash, &zero, 0);
51 
52 	count = bpf_map_lookup_elem(&verdict_map, &verdict);
53 	if (count)
54 		(*count)++;
55 
56 	return verdict;
57 }
58 
59 SEC("sk_skb")
60 int prog_skb_verdict(struct __sk_buff *skb)
61 {
62 	unsigned int *count;
63 	__u32 zero = 0;
64 	int verdict;
65 
66 	if (test_sockmap)
67 		verdict = bpf_sk_redirect_map(skb, &sock_map, zero,
68 					      test_ingress ? BPF_F_INGRESS : 0);
69 	else
70 		verdict = bpf_sk_redirect_hash(skb, &sock_hash, &zero,
71 					       test_ingress ? BPF_F_INGRESS : 0);
72 
73 	count = bpf_map_lookup_elem(&verdict_map, &verdict);
74 	if (count)
75 		(*count)++;
76 
77 	return verdict;
78 }
79 
80 SEC("sk_msg")
81 int prog_msg_verdict(struct sk_msg_md *msg)
82 {
83 	unsigned int *count;
84 	__u32 zero = 0;
85 	int verdict;
86 
87 	if (test_sockmap)
88 		verdict = bpf_msg_redirect_map(msg, &sock_map, zero, 0);
89 	else
90 		verdict = bpf_msg_redirect_hash(msg, &sock_hash, &zero, 0);
91 
92 	count = bpf_map_lookup_elem(&verdict_map, &verdict);
93 	if (count)
94 		(*count)++;
95 
96 	return verdict;
97 }
98 
99 SEC("sk_reuseport")
100 int prog_reuseport(struct sk_reuseport_md *reuse)
101 {
102 	unsigned int *count;
103 	int err, verdict;
104 	__u32 zero = 0;
105 
106 	if (test_sockmap)
107 		err = bpf_sk_select_reuseport(reuse, &sock_map, &zero, 0);
108 	else
109 		err = bpf_sk_select_reuseport(reuse, &sock_hash, &zero, 0);
110 	verdict = err ? SK_DROP : SK_PASS;
111 
112 	count = bpf_map_lookup_elem(&verdict_map, &verdict);
113 	if (count)
114 		(*count)++;
115 
116 	return verdict;
117 }
118 
/* License string emitted into the "license" ELF section; matches the GPL SPDX header. */
char _license[] SEC("license") = "GPL";
120