xref: /qemu/tools/ebpf/rss.bpf.c (revision 4a1babe5)
1 /*
2  * eBPF RSS program
3  *
4  * Developed by Daynix Computing LTD (http://www.daynix.com)
5  *
6  * Authors:
7  *  Andrew Melnychenko <andrew@daynix.com>
8  *  Yuri Benditovich <yuri.benditovich@daynix.com>
9  *
10  * This work is licensed under the terms of the GNU GPL, version 2.  See
11  * the COPYING file in the top-level directory.
12  *
13  * Prepare:
14  * Requires llvm, clang, bpftool, linux kernel tree
15  *
16  * Build rss.bpf.skeleton.h:
17  * make -f Makefile.ebpf clean all
18  */
19 
20 #include <stddef.h>
21 #include <stdbool.h>
22 #include <linux/bpf.h>
23 
24 #include <linux/in.h>
25 #include <linux/if_ether.h>
26 #include <linux/ip.h>
27 #include <linux/ipv6.h>
28 
29 #include <linux/udp.h>
30 #include <linux/tcp.h>
31 
32 #include <bpf/bpf_helpers.h>
33 #include <bpf/bpf_endian.h>
34 #include <linux/virtio_net.h>
35 
/* Capacity (max_entries) of tap_rss_map_indirection_table. */
36 #define INDIRECTION_TABLE_SIZE 128
/*
 * Size of the Toeplitz hash input buffer.  The largest input built by
 * calculate_rss_hash() is two IPv6 addresses plus two ports:
 * 16 + 16 + 2 + 2 = 36 bytes.
 */
37 #define HASH_CALCULATION_BUFFER_SIZE 36
38 
/*
 * RSS configuration, stored as the single entry (key 0) of
 * tap_rss_map_configurations.
 *
 * The struct is packed; presumably its layout must match the copy written
 * by the userspace loader exactly -- keep field order and sizes in sync
 * with the writer (confirm).
 */
39 struct rss_config_t {
40     __u8 redirect;          /* 0: return default_queue without hashing */
41     __u8 populate_hash;     /* not read anywhere in this program */
42     __u32 hash_types;       /* VIRTIO_NET_RSS_HASH_TYPE_* bitmask */
43     __u16 indirections_len; /* entries in use; hash is reduced modulo this */
44     __u16 default_queue;    /* fallback queue when no table entry is used */
45 } __attribute__((packed));
46 
/*
 * Toeplitz hash key, stored as the single entry (key 0) of
 * tap_rss_map_toeplitz_key.
 *
 * net_toeplitz_add() uses leftmost_32_bits as the initial 32-bit key window
 * and shifts next_byte[] into it one bit per input bit, most significant
 * bit first.
 */
47 struct toeplitz_key_data_t {
    /*
     * First 32 key bits as a plain integer with the first key byte in the
     * most significant position.  NOTE(review): this implies the writer
     * stores it in host byte order -- confirm.
     */
48     __u32 leftmost_32_bits;
49     __u8 next_byte[HASH_CALCULATION_BUFFER_SIZE]; /* remaining key bytes */
50 };
51 
/*
 * Header fields that parse_packet() extracts for the RSS hash.
 * The is_* members are used as 0/1 flags.  Addresses and ports are copied
 * unchanged from the packet headers, so they stay in network byte order.
 */
52 struct packet_hash_info_t {
53     __u8 is_ipv4;
54     __u8 is_ipv6;
55     __u8 is_udp;
56     __u8 is_tcp;
57     __u8 is_ipv6_ext_src;   /* in6_ext_src valid (Home Address option seen) */
58     __u8 is_ipv6_ext_dst;   /* in6_ext_dst valid (type-2 routing header seen) */
59     __u8 is_fragmented;     /* fragment detected; L4 ports are not parsed */
60 
61     __u16 src_port;         /* network byte order */
62     __u16 dst_port;         /* network byte order */
63 
    /* IPv4 and IPv6 addresses share storage; is_ipv4/is_ipv6 tells which. */
64     union {
65         struct {
66             __be32 in_src;
67             __be32 in_dst;
68         };
69 
70         struct {
71             struct in6_addr in6_src;
72             struct in6_addr in6_dst;
73             struct in6_addr in6_ext_src;
74             struct in6_addr in6_ext_dst;
75         };
76     };
77 };
78 
/*
 * Single-entry array map (key 0) holding struct rss_config_t.
 * BPF_F_MMAPABLE allows userspace to mmap() the map; presumably the loader
 * updates the configuration through that mapping (confirm).
 */
79 struct {
80     __uint(type, BPF_MAP_TYPE_ARRAY);
81     __uint(key_size, sizeof(__u32));
82     __uint(value_size, sizeof(struct rss_config_t));
83     __uint(max_entries, 1);
84     __uint(map_flags, BPF_F_MMAPABLE);
85 } tap_rss_map_configurations SEC(".maps");
86 
/* Single-entry array map (key 0) holding the Toeplitz key; mmap()-able. */
87 struct {
88     __uint(type, BPF_MAP_TYPE_ARRAY);
89     __uint(key_size, sizeof(__u32));
90     __uint(value_size, sizeof(struct toeplitz_key_data_t));
91     __uint(max_entries, 1);
92     __uint(map_flags, BPF_F_MMAPABLE);
93 } tap_rss_map_toeplitz_key SEC(".maps");
94 
/*
 * RSS indirection table: up to INDIRECTION_TABLE_SIZE __u16 queue indices,
 * indexed by (hash % rss_config_t.indirections_len); mmap()-able.
 */
95 struct {
96     __uint(type, BPF_MAP_TYPE_ARRAY);
97     __uint(key_size, sizeof(__u32));
98     __uint(value_size, sizeof(__u16));
99     __uint(max_entries, INDIRECTION_TABLE_SIZE);
100     __uint(map_flags, BPF_F_MMAPABLE);
101 } tap_rss_map_indirection_table SEC(".maps");
102 
103 static inline void net_rx_rss_add_chunk(__u8 *rss_input, size_t *bytes_written,
104                                         const void *ptr, size_t size) {
105     __builtin_memcpy(&rss_input[*bytes_written], ptr, size);
106     *bytes_written += size;
107 }
108 
109 static inline
110 void net_toeplitz_add(__u32 *result,
111                       __u8 *input,
112                       __u32 len
113         , struct toeplitz_key_data_t *key) {
114 
115     __u32 accumulator = *result;
116     __u32 leftmost_32_bits = key->leftmost_32_bits;
117     __u32 byte;
118 
119     for (byte = 0; byte < HASH_CALCULATION_BUFFER_SIZE; byte++) {
120         __u8 input_byte = input[byte];
121         __u8 key_byte = key->next_byte[byte];
122         __u8 bit;
123 
124         for (bit = 0; bit < 8; bit++) {
125             if (input_byte & (1 << 7)) {
126                 accumulator ^= leftmost_32_bits;
127             }
128 
129             leftmost_32_bits =
130                     (leftmost_32_bits << 1) | ((key_byte & (1 << 7)) >> 7);
131 
132             input_byte <<= 1;
133             key_byte <<= 1;
134         }
135     }
136 
137     *result = accumulator;
138 }
139 
140 
141 static inline int ip6_extension_header_type(__u8 hdr_type)
142 {
143     switch (hdr_type) {
144     case IPPROTO_HOPOPTS:
145     case IPPROTO_ROUTING:
146     case IPPROTO_FRAGMENT:
147     case IPPROTO_ICMPV6:
148     case IPPROTO_NONE:
149     case IPPROTO_DSTOPTS:
150     case IPPROTO_MH:
151         return 1;
152     default:
153         return 0;
154     }
155 }
156 /*
 * Loop bounds for parse_ipv6_ext(), which walks the IPv6 extension-header
 * chain looking for the Home Address option and the type-2 routing header
 * address.
 *
 * The IANA IPv6 parameters registry
 * (https://www.iana.org/assignments/ipv6-parameters/ipv6-parameters.xhtml)
 * lists 11 extension header types, so a packet is not expected to carry
 * more than 11 extension headers.  It also lists 27 TLV option types for
 * the Destination and Hop-by-Hop Options headers.
 *
 * IP6_EXTENSIONS_COUNT caps how many extension headers are walked, and
 * IP6_OPTIONS_COUNT caps how many TLV options are examined in each
 * Destination Options header.  Anything beyond these limits is not
 * inspected.
 */
164 #define IP6_EXTENSIONS_COUNT 11
165 #define IP6_OPTIONS_COUNT 30
166 
/*
 * Walk the IPv6 extension-header chain starting at *l4_offset (relative to
 * the network header), whose first header type is *l4_protocol.
 *
 * While walking, it records in @info:
 *  - in6_ext_dst / is_ipv6_ext_dst: the address from a type-2 routing
 *    header that carries one address and has segments_left == 1;
 *  - in6_ext_src / is_ipv6_ext_src: the address of a Home Address (HAO)
 *    option found in a Destination Options header;
 *  - is_fragmented: set when a Fragment header is present.
 *
 * On return, *l4_protocol and *l4_offset describe the first header that
 * ip6_extension_header_type() does not classify as an extension.  If more
 * than IP6_EXTENSIONS_COUNT extension headers are chained, the walk stops
 * early and *l4_protocol is left at an extension header type.
 *
 * Returns 0 on success, or the non-zero error returned by
 * bpf_skb_load_bytes_relative() if a header or option could not be read.
 */
167 static inline int parse_ipv6_ext(struct __sk_buff *skb,
168         struct packet_hash_info_t *info,
169         __u8 *l4_protocol, size_t *l4_offset)
170 {
171     int err = 0;
172 
    /* Nothing to walk when the IPv6 header is not followed by an extension. */
173     if (!ip6_extension_header_type(*l4_protocol)) {
174         return 0;
175     }
176 
    /* Common 2-byte prefix of extension headers: next header + length. */
177     struct ipv6_opt_hdr ext_hdr = {};
178 
    /* One iteration per extension header; the constant bound keeps it finite. */
179     for (unsigned int i = 0; i < IP6_EXTENSIONS_COUNT; ++i) {
180 
181         err = bpf_skb_load_bytes_relative(skb, *l4_offset, &ext_hdr,
182                                     sizeof(ext_hdr), BPF_HDR_START_NET);
183         if (err) {
184             goto error;
185         }
186 
187         if (*l4_protocol == IPPROTO_ROUTING) {
188             struct ipv6_rt_hdr ext_rt = {};
189 
190             err = bpf_skb_load_bytes_relative(skb, *l4_offset, &ext_rt,
191                                         sizeof(ext_rt), BPF_HDR_START_NET);
192             if (err) {
193                 goto error;
194             }
195 
            /*
             * Only a type-2 routing header carrying exactly one address
             * (hdrlen == 2, i.e. 16 bytes beyond the first 8) with one
             * segment left supplies the alternative destination address.
             */
196             if ((ext_rt.type == IPV6_SRCRT_TYPE_2) &&
197                     (ext_rt.hdrlen == sizeof(struct in6_addr) / 8) &&
198                     (ext_rt.segments_left == 1)) {
199 
200                 err = bpf_skb_load_bytes_relative(skb,
201                     *l4_offset + offsetof(struct rt2_hdr, addr),
202                     &info->in6_ext_dst, sizeof(info->in6_ext_dst),
203                     BPF_HDR_START_NET);
204                 if (err) {
205                     goto error;
206                 }
207 
208                 info->is_ipv6_ext_dst = 1;
209             }
210 
211         } else if (*l4_protocol == IPPROTO_DSTOPTS) {
            /* Generic TLV option prefix: type, then length of the data. */
212             struct ipv6_opt_t {
213                 __u8 type;
214                 __u8 length;
215             } __attribute__((packed)) opt = {};
216 
            /* Options start right after the next-header/length prefix. */
217             size_t opt_offset = sizeof(ext_hdr);
218 
            /* Scan at most IP6_OPTIONS_COUNT options for a Home Address option. */
219             for (unsigned int j = 0; j < IP6_OPTIONS_COUNT; ++j) {
220                 err = bpf_skb_load_bytes_relative(skb, *l4_offset + opt_offset,
221                                         &opt, sizeof(opt), BPF_HDR_START_NET);
222                 if (err) {
223                     goto error;
224                 }
225 
226                 if (opt.type == IPV6_TLV_HAO) {
                    /*
                     * NOTE(review): the address is read without checking that
                     * the option lies entirely within this extension header.
                     */
227                     err = bpf_skb_load_bytes_relative(skb,
228                         *l4_offset + opt_offset
229                         + offsetof(struct ipv6_destopt_hao, addr),
230                         &info->in6_ext_src, sizeof(info->in6_ext_src),
231                         BPF_HDR_START_NET);
232                     if (err) {
233                         goto error;
234                     }
235 
236                     info->is_ipv6_ext_src = 1;
237                     break;
238                 }
239 
                /* Pad1 is a lone type byte; other options are type+length+data. */
240                 opt_offset += (opt.type == IPV6_TLV_PAD1) ?
241                               1 : opt.length + sizeof(opt);
242 
                /*
                 * The header spans (hdrlen + 1) * 8 bytes, but the scan stops
                 * once opt_offset + 1 reaches hdrlen * 8, i.e. when at most 9
                 * bytes remain.  That is too short for an 18-byte Home Address
                 * option, so no complete HAO can be skipped by this bound.
                 */
243                 if (opt_offset + 1 >= ext_hdr.hdrlen * 8) {
244                     break;
245                 }
246             }
247         } else if (*l4_protocol == IPPROTO_FRAGMENT) {
248             info->is_fragmented = true;
249         }
250 
        /*
         * Move to the next header: hdrlen counts 8-byte units beyond the first
         * 8 bytes.  For a Fragment header this assumes the byte in the hdrlen
         * position (reserved there) is zero, giving its fixed 8-byte size.
         */
251         *l4_protocol = ext_hdr.nexthdr;
252         *l4_offset += (ext_hdr.hdrlen + 1) * 8;
253 
254         if (!ip6_extension_header_type(ext_hdr.nexthdr)) {
255             return 0;
256         }
257     }
258 
259     return 0;
260 error:
261     return err;
262 }
263 
264 static __be16 parse_eth_type(struct __sk_buff *skb)
265 {
266     unsigned int offset = 12;
267     __be16 ret = 0;
268     int err = 0;
269 
270     err = bpf_skb_load_bytes_relative(skb, offset, &ret, sizeof(ret),
271                                 BPF_HDR_START_MAC);
272     if (err) {
273         return 0;
274     }
275 
276     switch (bpf_ntohs(ret)) {
277     case ETH_P_8021AD:
278         offset += 4;
279     case ETH_P_8021Q:
280         offset += 4;
281         err = bpf_skb_load_bytes_relative(skb, offset, &ret, sizeof(ret),
282                                     BPF_HDR_START_MAC);
283     default:
284         break;
285     }
286 
287     if (err) {
288         return 0;
289     }
290 
291     return ret;
292 }
293 
294 static inline int parse_packet(struct __sk_buff *skb,
295         struct packet_hash_info_t *info)
296 {
297     int err = 0;
298 
299     if (!info || !skb) {
300         return -1;
301     }
302 
303     size_t l4_offset = 0;
304     __u8 l4_protocol = 0;
305     __u16 l3_protocol = bpf_ntohs(parse_eth_type(skb));
306     if (l3_protocol == 0) {
307         err = -1;
308         goto error;
309     }
310 
311     if (l3_protocol == ETH_P_IP) {
312         info->is_ipv4 = 1;
313 
314         struct iphdr ip = {};
315         err = bpf_skb_load_bytes_relative(skb, 0, &ip, sizeof(ip),
316                                     BPF_HDR_START_NET);
317         if (err) {
318             goto error;
319         }
320 
321         info->in_src = ip.saddr;
322         info->in_dst = ip.daddr;
323         info->is_fragmented = !!(bpf_ntohs(ip.frag_off) & (0x2000 | 0x1fff));
324 
325         l4_protocol = ip.protocol;
326         l4_offset = ip.ihl * 4;
327     } else if (l3_protocol == ETH_P_IPV6) {
328         info->is_ipv6 = 1;
329 
330         struct ipv6hdr ip6 = {};
331         err = bpf_skb_load_bytes_relative(skb, 0, &ip6, sizeof(ip6),
332                                     BPF_HDR_START_NET);
333         if (err) {
334             goto error;
335         }
336 
337         info->in6_src = ip6.saddr;
338         info->in6_dst = ip6.daddr;
339 
340         l4_protocol = ip6.nexthdr;
341         l4_offset = sizeof(ip6);
342 
343         err = parse_ipv6_ext(skb, info, &l4_protocol, &l4_offset);
344         if (err) {
345             goto error;
346         }
347     }
348 
349     if (l4_protocol != 0 && !info->is_fragmented) {
350         if (l4_protocol == IPPROTO_TCP) {
351             info->is_tcp = 1;
352 
353             struct tcphdr tcp = {};
354             err = bpf_skb_load_bytes_relative(skb, l4_offset, &tcp, sizeof(tcp),
355                                         BPF_HDR_START_NET);
356             if (err) {
357                 goto error;
358             }
359 
360             info->src_port = tcp.source;
361             info->dst_port = tcp.dest;
362         } else if (l4_protocol == IPPROTO_UDP) { /* TODO: add udplite? */
363             info->is_udp = 1;
364 
365             struct udphdr udp = {};
366             err = bpf_skb_load_bytes_relative(skb, l4_offset, &udp, sizeof(udp),
367                                         BPF_HDR_START_NET);
368             if (err) {
369                 goto error;
370             }
371 
372             info->src_port = udp.source;
373             info->dst_port = udp.dest;
374         }
375     }
376 
377     return 0;
378 
379 error:
380     return err;
381 }
382 
383 static inline __u32 calculate_rss_hash(struct __sk_buff *skb,
384         struct rss_config_t *config, struct toeplitz_key_data_t *toe)
385 {
386     __u8 rss_input[HASH_CALCULATION_BUFFER_SIZE] = {};
387     size_t bytes_written = 0;
388     __u32 result = 0;
389     int err = 0;
390     struct packet_hash_info_t packet_info = {};
391 
392     err = parse_packet(skb, &packet_info);
393     if (err) {
394         return 0;
395     }
396 
397     if (packet_info.is_ipv4) {
398         if (packet_info.is_tcp &&
399             config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_TCPv4) {
400 
401             net_rx_rss_add_chunk(rss_input, &bytes_written,
402                                  &packet_info.in_src,
403                                  sizeof(packet_info.in_src));
404             net_rx_rss_add_chunk(rss_input, &bytes_written,
405                                  &packet_info.in_dst,
406                                  sizeof(packet_info.in_dst));
407             net_rx_rss_add_chunk(rss_input, &bytes_written,
408                                  &packet_info.src_port,
409                                  sizeof(packet_info.src_port));
410             net_rx_rss_add_chunk(rss_input, &bytes_written,
411                                  &packet_info.dst_port,
412                                  sizeof(packet_info.dst_port));
413         } else if (packet_info.is_udp &&
414                    config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_UDPv4) {
415 
416             net_rx_rss_add_chunk(rss_input, &bytes_written,
417                                  &packet_info.in_src,
418                                  sizeof(packet_info.in_src));
419             net_rx_rss_add_chunk(rss_input, &bytes_written,
420                                  &packet_info.in_dst,
421                                  sizeof(packet_info.in_dst));
422             net_rx_rss_add_chunk(rss_input, &bytes_written,
423                                  &packet_info.src_port,
424                                  sizeof(packet_info.src_port));
425             net_rx_rss_add_chunk(rss_input, &bytes_written,
426                                  &packet_info.dst_port,
427                                  sizeof(packet_info.dst_port));
428         } else if (config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_IPv4) {
429             net_rx_rss_add_chunk(rss_input, &bytes_written,
430                                  &packet_info.in_src,
431                                  sizeof(packet_info.in_src));
432             net_rx_rss_add_chunk(rss_input, &bytes_written,
433                                  &packet_info.in_dst,
434                                  sizeof(packet_info.in_dst));
435         }
436     } else if (packet_info.is_ipv6) {
437         if (packet_info.is_tcp &&
438             config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_TCPv6) {
439 
440             if (packet_info.is_ipv6_ext_src &&
441                 config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_TCP_EX) {
442 
443                 net_rx_rss_add_chunk(rss_input, &bytes_written,
444                                      &packet_info.in6_ext_src,
445                                      sizeof(packet_info.in6_ext_src));
446             } else {
447                 net_rx_rss_add_chunk(rss_input, &bytes_written,
448                                      &packet_info.in6_src,
449                                      sizeof(packet_info.in6_src));
450             }
451             if (packet_info.is_ipv6_ext_dst &&
452                 config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_TCP_EX) {
453 
454                 net_rx_rss_add_chunk(rss_input, &bytes_written,
455                                      &packet_info.in6_ext_dst,
456                                      sizeof(packet_info.in6_ext_dst));
457             } else {
458                 net_rx_rss_add_chunk(rss_input, &bytes_written,
459                                      &packet_info.in6_dst,
460                                      sizeof(packet_info.in6_dst));
461             }
462             net_rx_rss_add_chunk(rss_input, &bytes_written,
463                                  &packet_info.src_port,
464                                  sizeof(packet_info.src_port));
465             net_rx_rss_add_chunk(rss_input, &bytes_written,
466                                  &packet_info.dst_port,
467                                  sizeof(packet_info.dst_port));
468         } else if (packet_info.is_udp &&
469                    config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_UDPv6) {
470 
471             if (packet_info.is_ipv6_ext_src &&
472                config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_UDP_EX) {
473 
474                 net_rx_rss_add_chunk(rss_input, &bytes_written,
475                                      &packet_info.in6_ext_src,
476                                      sizeof(packet_info.in6_ext_src));
477             } else {
478                 net_rx_rss_add_chunk(rss_input, &bytes_written,
479                                      &packet_info.in6_src,
480                                      sizeof(packet_info.in6_src));
481             }
482             if (packet_info.is_ipv6_ext_dst &&
483                config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_UDP_EX) {
484 
485                 net_rx_rss_add_chunk(rss_input, &bytes_written,
486                                      &packet_info.in6_ext_dst,
487                                      sizeof(packet_info.in6_ext_dst));
488             } else {
489                 net_rx_rss_add_chunk(rss_input, &bytes_written,
490                                      &packet_info.in6_dst,
491                                      sizeof(packet_info.in6_dst));
492             }
493 
494             net_rx_rss_add_chunk(rss_input, &bytes_written,
495                                  &packet_info.src_port,
496                                  sizeof(packet_info.src_port));
497             net_rx_rss_add_chunk(rss_input, &bytes_written,
498                                  &packet_info.dst_port,
499                                  sizeof(packet_info.dst_port));
500 
501         } else if (config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_IPv6) {
502             if (packet_info.is_ipv6_ext_src &&
503                config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_IP_EX) {
504 
505                 net_rx_rss_add_chunk(rss_input, &bytes_written,
506                                      &packet_info.in6_ext_src,
507                                      sizeof(packet_info.in6_ext_src));
508             } else {
509                 net_rx_rss_add_chunk(rss_input, &bytes_written,
510                                      &packet_info.in6_src,
511                                      sizeof(packet_info.in6_src));
512             }
513             if (packet_info.is_ipv6_ext_dst &&
514                 config->hash_types & VIRTIO_NET_RSS_HASH_TYPE_IP_EX) {
515 
516                 net_rx_rss_add_chunk(rss_input, &bytes_written,
517                                      &packet_info.in6_ext_dst,
518                                      sizeof(packet_info.in6_ext_dst));
519             } else {
520                 net_rx_rss_add_chunk(rss_input, &bytes_written,
521                                      &packet_info.in6_dst,
522                                      sizeof(packet_info.in6_dst));
523             }
524         }
525     }
526 
527     if (bytes_written) {
528         net_toeplitz_add(&result, rss_input, bytes_written, toe);
529     }
530 
531     return result;
532 }
533 
534 SEC("socket")
535 int tun_rss_steering_prog(struct __sk_buff *skb)
536 {
537 
538     struct rss_config_t *config;
539     struct toeplitz_key_data_t *toe;
540 
541     __u32 key = 0;
542     __u32 hash = 0;
543 
544     config = bpf_map_lookup_elem(&tap_rss_map_configurations, &key);
545     toe = bpf_map_lookup_elem(&tap_rss_map_toeplitz_key, &key);
546 
547     if (config && toe) {
548         if (!config->redirect) {
549             return config->default_queue;
550         }
551 
552         hash = calculate_rss_hash(skb, config, toe);
553         if (hash) {
554             __u32 table_idx = hash % config->indirections_len;
555             __u16 *queue = 0;
556 
557             queue = bpf_map_lookup_elem(&tap_rss_map_indirection_table,
558                                         &table_idx);
559 
560             if (queue) {
561                 return *queue;
562             }
563         }
564 
565         return config->default_queue;
566     }
567 
568     return -1;
569 }
570 
/* License string placed in the object's "license" ELF section, read by the BPF loader/kernel. */
571 char _license[] SEC("license") = "GPL v2";
572