1 // SPDX-License-Identifier: GPL-2.0
2 // Copyright (c) 2018 Facebook
3 
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
7 
8 #include <arpa/inet.h>
9 #include <net/if.h>
10 #include <netinet/in.h>
11 #include <sys/socket.h>
12 #include <sys/types.h>
13 
14 
15 #include <bpf/bpf.h>
16 #include <bpf/libbpf.h>
17 
18 #include "bpf_rlimit.h"
19 #include "cgroup_helpers.h"
20 
21 #define CGROUP_PATH		"/skb_cgroup_test"
22 #define NUM_CGROUP_LEVELS	4
23 
24 /* RFC 4291, Section 2.7.1 */
25 #define LINKLOCAL_MULTICAST	"ff02::1"
26 
/*
 * Build an IPv6 destination address in @dst: @ip on port 1025, scoped to
 * the interface named @iface.
 *
 * @dst is zeroed first, so every field not set here is 0.
 *
 * Returns 0 on success. Returns -1 (after logging) if @ip is not a valid
 * IPv6 literal or @iface cannot be resolved to an interface index.
 */
static int mk_dst_addr(const char *ip, const char *iface,
		       struct sockaddr_in6 *dst)
{
	unsigned int ifindex;

	memset(dst, 0, sizeof(*dst));
	dst->sin6_family = AF_INET6;
	dst->sin6_port = htons(1025);

	/* inet_pton() returns 1 only when the text parsed as an address. */
	if (inet_pton(AF_INET6, ip, &dst->sin6_addr) != 1) {
		log_err("Invalid IPv6: %s", ip);
		return -1;
	}

	/* A link-local multicast destination needs an explicit scope. */
	ifindex = if_nametoindex(iface);
	dst->sin6_scope_id = ifindex;
	if (ifindex == 0) {
		log_err("Failed to get index of iface: %s", iface);
		return -1;
	}

	return 0;
}
48 
49 static int send_packet(const char *iface)
50 {
51 	struct sockaddr_in6 dst;
52 	char msg[] = "msg";
53 	int err = 0;
54 	int fd = -1;
55 
56 	if (mk_dst_addr(LINKLOCAL_MULTICAST, iface, &dst))
57 		goto err;
58 
59 	fd = socket(AF_INET6, SOCK_DGRAM, 0);
60 	if (fd == -1) {
61 		log_err("Failed to create UDP socket");
62 		goto err;
63 	}
64 
65 	if (sendto(fd, &msg, sizeof(msg), 0, (const struct sockaddr *)&dst,
66 		   sizeof(dst)) == -1) {
67 		log_err("Failed to send datagram");
68 		goto err;
69 	}
70 
71 	goto out;
72 err:
73 	err = -1;
74 out:
75 	if (fd >= 0)
76 		close(fd);
77 	return err;
78 }
79 
80 int get_map_fd_by_prog_id(int prog_id)
81 {
82 	struct bpf_prog_info info = {};
83 	__u32 info_len = sizeof(info);
84 	__u32 map_ids[1];
85 	int prog_fd = -1;
86 	int map_fd = -1;
87 
88 	prog_fd = bpf_prog_get_fd_by_id(prog_id);
89 	if (prog_fd < 0) {
90 		log_err("Failed to get fd by prog id %d", prog_id);
91 		goto err;
92 	}
93 
94 	info.nr_map_ids = 1;
95 	info.map_ids = (__u64) (unsigned long) map_ids;
96 
97 	if (bpf_obj_get_info_by_fd(prog_fd, &info, &info_len)) {
98 		log_err("Failed to get info by prog fd %d", prog_fd);
99 		goto err;
100 	}
101 
102 	if (!info.nr_map_ids) {
103 		log_err("No maps found for prog fd %d", prog_fd);
104 		goto err;
105 	}
106 
107 	map_fd = bpf_map_get_fd_by_id(map_ids[0]);
108 	if (map_fd < 0)
109 		log_err("Failed to get fd by map id %d", map_ids[0]);
110 err:
111 	if (prog_fd >= 0)
112 		close(prog_fd);
113 	return map_fd;
114 }
115 
116 int check_ancestor_cgroup_ids(int prog_id)
117 {
118 	__u64 actual_ids[NUM_CGROUP_LEVELS], expected_ids[NUM_CGROUP_LEVELS];
119 	__u32 level;
120 	int err = 0;
121 	int map_fd;
122 
123 	expected_ids[0] = get_cgroup_id("/..");	/* root cgroup */
124 	expected_ids[1] = get_cgroup_id("");
125 	expected_ids[2] = get_cgroup_id(CGROUP_PATH);
126 	expected_ids[3] = 0; /* non-existent cgroup */
127 
128 	map_fd = get_map_fd_by_prog_id(prog_id);
129 	if (map_fd < 0)
130 		goto err;
131 
132 	for (level = 0; level < NUM_CGROUP_LEVELS; ++level) {
133 		if (bpf_map_lookup_elem(map_fd, &level, &actual_ids[level])) {
134 			log_err("Failed to lookup key %d", level);
135 			goto err;
136 		}
137 		if (actual_ids[level] != expected_ids[level]) {
138 			log_err("%llx (actual) != %llx (expected), level: %u\n",
139 				actual_ids[level], expected_ids[level], level);
140 			goto err;
141 		}
142 	}
143 
144 	goto out;
145 err:
146 	err = -1;
147 out:
148 	if (map_fd >= 0)
149 		close(map_fd);
150 	return err;
151 }
152 
153 int main(int argc, char **argv)
154 {
155 	int cgfd = -1;
156 	int err = 0;
157 
158 	if (argc < 3) {
159 		fprintf(stderr, "Usage: %s iface prog_id\n", argv[0]);
160 		exit(EXIT_FAILURE);
161 	}
162 
163 	cgfd = cgroup_setup_and_join(CGROUP_PATH);
164 	if (cgfd < 0)
165 		goto err;
166 
167 	if (send_packet(argv[1]))
168 		goto err;
169 
170 	if (check_ancestor_cgroup_ids(atoi(argv[2])))
171 		goto err;
172 
173 	goto out;
174 err:
175 	err = -1;
176 out:
177 	close(cgfd);
178 	cleanup_cgroup_environment();
179 	printf("[%s]\n", err ? "FAIL" : "PASS");
180 	return err;
181 }
182