1 // SPDX-License-Identifier: GPL-2.0
2 #include <sys/un.h>
3 
4 #include "test_progs.h"
5 
6 #include "connect_unix_prog.skel.h"
7 #include "sendmsg_unix_prog.skel.h"
8 #include "recvmsg_unix_prog.skel.h"
9 #include "getsockname_unix_prog.skel.h"
10 #include "getpeername_unix_prog.skel.h"
11 #include "network_helpers.h"
12 
/* AF_UNIX addresses shared by the subtests: the address a test asks for,
 * the address the BPF programs are expected to substitute for it, and the
 * address the client binds to in test_xmsg() (so the recvmsg hook runs).
 * NOTE(review): how these strings map to sun_path (abstract vs. filesystem)
 * is decided by make_sockaddr()/start_server(), not visible here.
 */
#define SERVUN_ADDRESS         "bpf_cgroup_unix_test"
#define SERVUN_REWRITE_ADDRESS "bpf_cgroup_unix_test_rewrite"
#define SRCUN_ADDRESS	       "bpf_cgroup_unix_test_src"
16 
/* Socket operation a subtest drives; test_sock_addr() dispatches on this. */
enum sock_addr_test_type {
	SOCK_ADDR_TEST_BIND		= 0,
	SOCK_ADDR_TEST_CONNECT		= 1,
	SOCK_ADDR_TEST_SENDMSG		= 2,
	SOCK_ADDR_TEST_RECVMSG		= 3,
	SOCK_ADDR_TEST_GETSOCKNAME	= 4,
	SOCK_ADDR_TEST_GETPEERNAME	= 5,
};
25 
/* Opens a BPF skeleton and attaches it to the cgroup behind @cg_fd;
 * returns the skeleton, or NULL on failure.
 */
typedef void *(*load_fn)(int cg_fd);

/* Releases a skeleton previously returned by a load_fn. */
typedef void (*destroy_fn)(void *obj);
28 
/* Description of one sock_addr subtest.
 * Keep the field order stable: entries may be initialized positionally.
 */
struct sock_addr_test {
	/* Which driver test_sock_addr() runs for this entry. */
	enum sock_addr_test_type type;
	/* Subtest name passed to test__start_subtest(). */
	const char *name;
	/* BPF prog properties */
	load_fn loadfn;
	destroy_fn destroyfn;
	/* Socket properties */
	int socket_family;
	int socket_type;
	/* Address the test asks for (an IP:port pair, or an AF_UNIX address
	 * with port 0), and the address it expects to observe once the BPF
	 * program has run.
	 */
	const char *requested_addr;
	unsigned short requested_port;
	const char *expected_addr;
	unsigned short expected_port;
	/* Expected source/local address (port not compared); NULL skips the check. */
	const char *expected_src_addr;
};
45 
46 static void *connect_unix_prog_load(int cgroup_fd)
47 {
48 	struct connect_unix_prog *skel;
49 
50 	skel = connect_unix_prog__open_and_load();
51 	if (!ASSERT_OK_PTR(skel, "skel_open"))
52 		goto cleanup;
53 
54 	skel->links.connect_unix_prog = bpf_program__attach_cgroup(
55 		skel->progs.connect_unix_prog, cgroup_fd);
56 	if (!ASSERT_OK_PTR(skel->links.connect_unix_prog, "prog_attach"))
57 		goto cleanup;
58 
59 	return skel;
60 cleanup:
61 	connect_unix_prog__destroy(skel);
62 	return NULL;
63 }
64 
/* Release a skeleton returned by connect_unix_prog_load(). */
static void connect_unix_prog_destroy(void *skel)
{
	struct connect_unix_prog *obj = skel;

	connect_unix_prog__destroy(obj);
}
69 
70 static void *sendmsg_unix_prog_load(int cgroup_fd)
71 {
72 	struct sendmsg_unix_prog *skel;
73 
74 	skel = sendmsg_unix_prog__open_and_load();
75 	if (!ASSERT_OK_PTR(skel, "skel_open"))
76 		goto cleanup;
77 
78 	skel->links.sendmsg_unix_prog = bpf_program__attach_cgroup(
79 		skel->progs.sendmsg_unix_prog, cgroup_fd);
80 	if (!ASSERT_OK_PTR(skel->links.sendmsg_unix_prog, "prog_attach"))
81 		goto cleanup;
82 
83 	return skel;
84 cleanup:
85 	sendmsg_unix_prog__destroy(skel);
86 	return NULL;
87 }
88 
/* Release a skeleton returned by sendmsg_unix_prog_load(). */
static void sendmsg_unix_prog_destroy(void *skel)
{
	struct sendmsg_unix_prog *obj = skel;

	sendmsg_unix_prog__destroy(obj);
}
93 
94 static void *recvmsg_unix_prog_load(int cgroup_fd)
95 {
96 	struct recvmsg_unix_prog *skel;
97 
98 	skel = recvmsg_unix_prog__open_and_load();
99 	if (!ASSERT_OK_PTR(skel, "skel_open"))
100 		goto cleanup;
101 
102 	skel->links.recvmsg_unix_prog = bpf_program__attach_cgroup(
103 		skel->progs.recvmsg_unix_prog, cgroup_fd);
104 	if (!ASSERT_OK_PTR(skel->links.recvmsg_unix_prog, "prog_attach"))
105 		goto cleanup;
106 
107 	return skel;
108 cleanup:
109 	recvmsg_unix_prog__destroy(skel);
110 	return NULL;
111 }
112 
/* Release a skeleton returned by recvmsg_unix_prog_load(). */
static void recvmsg_unix_prog_destroy(void *skel)
{
	struct recvmsg_unix_prog *obj = skel;

	recvmsg_unix_prog__destroy(obj);
}
117 
118 static void *getsockname_unix_prog_load(int cgroup_fd)
119 {
120 	struct getsockname_unix_prog *skel;
121 
122 	skel = getsockname_unix_prog__open_and_load();
123 	if (!ASSERT_OK_PTR(skel, "skel_open"))
124 		goto cleanup;
125 
126 	skel->links.getsockname_unix_prog = bpf_program__attach_cgroup(
127 		skel->progs.getsockname_unix_prog, cgroup_fd);
128 	if (!ASSERT_OK_PTR(skel->links.getsockname_unix_prog, "prog_attach"))
129 		goto cleanup;
130 
131 	return skel;
132 cleanup:
133 	getsockname_unix_prog__destroy(skel);
134 	return NULL;
135 }
136 
/* Release a skeleton returned by getsockname_unix_prog_load(). */
static void getsockname_unix_prog_destroy(void *skel)
{
	struct getsockname_unix_prog *obj = skel;

	getsockname_unix_prog__destroy(obj);
}
141 
142 static void *getpeername_unix_prog_load(int cgroup_fd)
143 {
144 	struct getpeername_unix_prog *skel;
145 
146 	skel = getpeername_unix_prog__open_and_load();
147 	if (!ASSERT_OK_PTR(skel, "skel_open"))
148 		goto cleanup;
149 
150 	skel->links.getpeername_unix_prog = bpf_program__attach_cgroup(
151 		skel->progs.getpeername_unix_prog, cgroup_fd);
152 	if (!ASSERT_OK_PTR(skel->links.getpeername_unix_prog, "prog_attach"))
153 		goto cleanup;
154 
155 	return skel;
156 cleanup:
157 	getpeername_unix_prog__destroy(skel);
158 	return NULL;
159 }
160 
/* Release a skeleton returned by getpeername_unix_prog_load(). */
static void getpeername_unix_prog_destroy(void *skel)
{
	struct getpeername_unix_prog *obj = skel;

	getpeername_unix_prog__destroy(obj);
}
165 
166 static struct sock_addr_test tests[] = {
167 	{
168 		SOCK_ADDR_TEST_CONNECT,
169 		"connect_unix",
170 		connect_unix_prog_load,
171 		connect_unix_prog_destroy,
172 		AF_UNIX,
173 		SOCK_STREAM,
174 		SERVUN_ADDRESS,
175 		0,
176 		SERVUN_REWRITE_ADDRESS,
177 		0,
178 		NULL,
179 	},
180 	{
181 		SOCK_ADDR_TEST_SENDMSG,
182 		"sendmsg_unix",
183 		sendmsg_unix_prog_load,
184 		sendmsg_unix_prog_destroy,
185 		AF_UNIX,
186 		SOCK_DGRAM,
187 		SERVUN_ADDRESS,
188 		0,
189 		SERVUN_REWRITE_ADDRESS,
190 		0,
191 		NULL,
192 	},
193 	{
194 		SOCK_ADDR_TEST_RECVMSG,
195 		"recvmsg_unix-dgram",
196 		recvmsg_unix_prog_load,
197 		recvmsg_unix_prog_destroy,
198 		AF_UNIX,
199 		SOCK_DGRAM,
200 		SERVUN_REWRITE_ADDRESS,
201 		0,
202 		SERVUN_REWRITE_ADDRESS,
203 		0,
204 		SERVUN_ADDRESS,
205 	},
206 	{
207 		SOCK_ADDR_TEST_RECVMSG,
208 		"recvmsg_unix-stream",
209 		recvmsg_unix_prog_load,
210 		recvmsg_unix_prog_destroy,
211 		AF_UNIX,
212 		SOCK_STREAM,
213 		SERVUN_REWRITE_ADDRESS,
214 		0,
215 		SERVUN_REWRITE_ADDRESS,
216 		0,
217 		SERVUN_ADDRESS,
218 	},
219 	{
220 		SOCK_ADDR_TEST_GETSOCKNAME,
221 		"getsockname_unix",
222 		getsockname_unix_prog_load,
223 		getsockname_unix_prog_destroy,
224 		AF_UNIX,
225 		SOCK_STREAM,
226 		SERVUN_ADDRESS,
227 		0,
228 		SERVUN_REWRITE_ADDRESS,
229 		0,
230 		NULL,
231 	},
232 	{
233 		SOCK_ADDR_TEST_GETPEERNAME,
234 		"getpeername_unix",
235 		getpeername_unix_prog_load,
236 		getpeername_unix_prog_destroy,
237 		AF_UNIX,
238 		SOCK_STREAM,
239 		SERVUN_ADDRESS,
240 		0,
241 		SERVUN_REWRITE_ADDRESS,
242 		0,
243 		NULL,
244 	},
245 };
246 
/* Signature shared by getsockname() and getpeername(). */
typedef int (*info_fn)(int sockfd, struct sockaddr *addr, socklen_t *addrlen);
248 
/* Compare two socket addresses; 0 means they match.
 *
 * Returns -1 when the families or lengths differ, or the family is not
 * AF_INET/AF_INET6/AF_UNIX. For AF_INET and AF_INET6 it returns 1 on a
 * mismatch; ports are only compared when @cmp_port is true. For AF_UNIX it
 * returns memcmp() over the first @addr1_len bytes of both structures.
 */
static int cmp_addr(const struct sockaddr_storage *addr1, socklen_t addr1_len,
		    const struct sockaddr_storage *addr2, socklen_t addr2_len,
		    bool cmp_port)
{
	if (addr1->ss_family != addr2->ss_family || addr1_len != addr2_len)
		return -1;

	switch (addr1->ss_family) {
	case AF_INET: {
		const struct sockaddr_in *lhs = (const struct sockaddr_in *)addr1;
		const struct sockaddr_in *rhs = (const struct sockaddr_in *)addr2;
		bool port_ok = !cmp_port || lhs->sin_port == rhs->sin_port;
		bool ip_ok = lhs->sin_addr.s_addr == rhs->sin_addr.s_addr;

		return !(port_ok && ip_ok);
	}
	case AF_INET6: {
		const struct sockaddr_in6 *lhs = (const struct sockaddr_in6 *)addr1;
		const struct sockaddr_in6 *rhs = (const struct sockaddr_in6 *)addr2;
		bool port_ok = !cmp_port || lhs->sin6_port == rhs->sin6_port;
		bool ip_ok = memcmp(&lhs->sin6_addr, &rhs->sin6_addr,
				    sizeof(struct in6_addr)) == 0;

		return !(port_ok && ip_ok);
	}
	case AF_UNIX:
		/* Family plus path bytes, exactly as long as the address is. */
		return memcmp(addr1, addr2, addr1_len);
	default:
		return -1;
	}
}
282 
283 static int cmp_sock_addr(info_fn fn, int sock1,
284 			 const struct sockaddr_storage *addr2,
285 			 socklen_t addr2_len, bool cmp_port)
286 {
287 	struct sockaddr_storage addr1;
288 	socklen_t len1 = sizeof(addr1);
289 
290 	memset(&addr1, 0, len1);
291 	if (fn(sock1, (struct sockaddr *)&addr1, (socklen_t *)&len1) != 0)
292 		return -1;
293 
294 	return cmp_addr(&addr1, len1, addr2, addr2_len, cmp_port);
295 }
296 
297 static int cmp_local_addr(int sock1, const struct sockaddr_storage *addr2,
298 			  socklen_t addr2_len, bool cmp_port)
299 {
300 	return cmp_sock_addr(getsockname, sock1, addr2, addr2_len, cmp_port);
301 }
302 
303 static int cmp_peer_addr(int sock1, const struct sockaddr_storage *addr2,
304 			 socklen_t addr2_len, bool cmp_port)
305 {
306 	return cmp_sock_addr(getpeername, sock1, addr2, addr2_len, cmp_port);
307 }
308 
309 static void test_bind(struct sock_addr_test *test)
310 {
311 	struct sockaddr_storage expected_addr;
312 	socklen_t expected_addr_len = sizeof(struct sockaddr_storage);
313 	int serv = -1, client = -1, err;
314 
315 	serv = start_server(test->socket_family, test->socket_type,
316 			    test->requested_addr, test->requested_port, 0);
317 	if (!ASSERT_GE(serv, 0, "start_server"))
318 		goto cleanup;
319 
320 	err = make_sockaddr(test->socket_family,
321 			    test->expected_addr, test->expected_port,
322 			    &expected_addr, &expected_addr_len);
323 	if (!ASSERT_EQ(err, 0, "make_sockaddr"))
324 		goto cleanup;
325 
326 	err = cmp_local_addr(serv, &expected_addr, expected_addr_len, true);
327 	if (!ASSERT_EQ(err, 0, "cmp_local_addr"))
328 		goto cleanup;
329 
330 	/* Try to connect to server just in case */
331 	client = connect_to_addr(&expected_addr, expected_addr_len, test->socket_type);
332 	if (!ASSERT_GE(client, 0, "connect_to_addr"))
333 		goto cleanup;
334 
335 cleanup:
336 	if (client != -1)
337 		close(client);
338 	if (serv != -1)
339 		close(serv);
340 }
341 
342 static void test_connect(struct sock_addr_test *test)
343 {
344 	struct sockaddr_storage addr, expected_addr, expected_src_addr;
345 	socklen_t addr_len = sizeof(struct sockaddr_storage),
346 		  expected_addr_len = sizeof(struct sockaddr_storage),
347 		  expected_src_addr_len = sizeof(struct sockaddr_storage);
348 	int serv = -1, client = -1, err;
349 
350 	serv = start_server(test->socket_family, test->socket_type,
351 			    test->expected_addr, test->expected_port, 0);
352 	if (!ASSERT_GE(serv, 0, "start_server"))
353 		goto cleanup;
354 
355 	err = make_sockaddr(test->socket_family, test->requested_addr, test->requested_port,
356 			    &addr, &addr_len);
357 	if (!ASSERT_EQ(err, 0, "make_sockaddr"))
358 		goto cleanup;
359 
360 	client = connect_to_addr(&addr, addr_len, test->socket_type);
361 	if (!ASSERT_GE(client, 0, "connect_to_addr"))
362 		goto cleanup;
363 
364 	err = make_sockaddr(test->socket_family, test->expected_addr, test->expected_port,
365 			    &expected_addr, &expected_addr_len);
366 	if (!ASSERT_EQ(err, 0, "make_sockaddr"))
367 		goto cleanup;
368 
369 	if (test->expected_src_addr) {
370 		err = make_sockaddr(test->socket_family, test->expected_src_addr, 0,
371 				    &expected_src_addr, &expected_src_addr_len);
372 		if (!ASSERT_EQ(err, 0, "make_sockaddr"))
373 			goto cleanup;
374 	}
375 
376 	err = cmp_peer_addr(client, &expected_addr, expected_addr_len, true);
377 	if (!ASSERT_EQ(err, 0, "cmp_peer_addr"))
378 		goto cleanup;
379 
380 	if (test->expected_src_addr) {
381 		err = cmp_local_addr(client, &expected_src_addr, expected_src_addr_len, false);
382 		if (!ASSERT_EQ(err, 0, "cmp_local_addr"))
383 			goto cleanup;
384 	}
385 cleanup:
386 	if (client != -1)
387 		close(client);
388 	if (serv != -1)
389 		close(serv);
390 }
391 
392 static void test_xmsg(struct sock_addr_test *test)
393 {
394 	struct sockaddr_storage addr, src_addr;
395 	socklen_t addr_len = sizeof(struct sockaddr_storage),
396 		  src_addr_len = sizeof(struct sockaddr_storage);
397 	struct msghdr hdr;
398 	struct iovec iov;
399 	char data = 'a';
400 	int serv = -1, client = -1, err;
401 
402 	/* Unlike the other tests, here we test that we can rewrite the src addr
403 	 * with a recvmsg() hook.
404 	 */
405 
406 	serv = start_server(test->socket_family, test->socket_type,
407 			    test->expected_addr, test->expected_port, 0);
408 	if (!ASSERT_GE(serv, 0, "start_server"))
409 		goto cleanup;
410 
411 	client = socket(test->socket_family, test->socket_type, 0);
412 	if (!ASSERT_GE(client, 0, "socket"))
413 		goto cleanup;
414 
415 	/* AF_UNIX sockets have to be bound to something to trigger the recvmsg bpf program. */
416 	if (test->socket_family == AF_UNIX) {
417 		err = make_sockaddr(AF_UNIX, SRCUN_ADDRESS, 0, &src_addr, &src_addr_len);
418 		if (!ASSERT_EQ(err, 0, "make_sockaddr"))
419 			goto cleanup;
420 
421 		err = bind(client, (const struct sockaddr *) &src_addr, src_addr_len);
422 		if (!ASSERT_OK(err, "bind"))
423 			goto cleanup;
424 	}
425 
426 	err = make_sockaddr(test->socket_family, test->requested_addr, test->requested_port,
427 			    &addr, &addr_len);
428 	if (!ASSERT_EQ(err, 0, "make_sockaddr"))
429 		goto cleanup;
430 
431 	if (test->socket_type == SOCK_DGRAM) {
432 		memset(&iov, 0, sizeof(iov));
433 		iov.iov_base = &data;
434 		iov.iov_len = sizeof(data);
435 
436 		memset(&hdr, 0, sizeof(hdr));
437 		hdr.msg_name = (void *)&addr;
438 		hdr.msg_namelen = addr_len;
439 		hdr.msg_iov = &iov;
440 		hdr.msg_iovlen = 1;
441 
442 		err = sendmsg(client, &hdr, 0);
443 		if (!ASSERT_EQ(err, sizeof(data), "sendmsg"))
444 			goto cleanup;
445 	} else {
446 		/* Testing with connection-oriented sockets is only valid for
447 		 * recvmsg() tests.
448 		 */
449 		if (!ASSERT_EQ(test->type, SOCK_ADDR_TEST_RECVMSG, "recvmsg"))
450 			goto cleanup;
451 
452 		err = connect(client, (const struct sockaddr *)&addr, addr_len);
453 		if (!ASSERT_OK(err, "connect"))
454 			goto cleanup;
455 
456 		err = send(client, &data, sizeof(data), 0);
457 		if (!ASSERT_EQ(err, sizeof(data), "send"))
458 			goto cleanup;
459 
460 		err = listen(serv, 0);
461 		if (!ASSERT_OK(err, "listen"))
462 			goto cleanup;
463 
464 		err = accept(serv, NULL, NULL);
465 		if (!ASSERT_GE(err, 0, "accept"))
466 			goto cleanup;
467 
468 		close(serv);
469 		serv = err;
470 	}
471 
472 	addr_len = src_addr_len = sizeof(struct sockaddr_storage);
473 
474 	err = recvfrom(serv, &data, sizeof(data), 0, (struct sockaddr *) &src_addr, &src_addr_len);
475 	if (!ASSERT_EQ(err, sizeof(data), "recvfrom"))
476 		goto cleanup;
477 
478 	ASSERT_EQ(data, 'a', "data mismatch");
479 
480 	if (test->expected_src_addr) {
481 		err = make_sockaddr(test->socket_family, test->expected_src_addr, 0,
482 				    &addr, &addr_len);
483 		if (!ASSERT_EQ(err, 0, "make_sockaddr"))
484 			goto cleanup;
485 
486 		err = cmp_addr(&src_addr, src_addr_len, &addr, addr_len, false);
487 		if (!ASSERT_EQ(err, 0, "cmp_addr"))
488 			goto cleanup;
489 	}
490 
491 cleanup:
492 	if (client != -1)
493 		close(client);
494 	if (serv != -1)
495 		close(serv);
496 }
497 
498 static void test_getsockname(struct sock_addr_test *test)
499 {
500 	struct sockaddr_storage expected_addr;
501 	socklen_t expected_addr_len = sizeof(struct sockaddr_storage);
502 	int serv = -1, err;
503 
504 	serv = start_server(test->socket_family, test->socket_type,
505 			    test->requested_addr, test->requested_port, 0);
506 	if (!ASSERT_GE(serv, 0, "start_server"))
507 		goto cleanup;
508 
509 	err = make_sockaddr(test->socket_family,
510 			    test->expected_addr, test->expected_port,
511 			    &expected_addr, &expected_addr_len);
512 	if (!ASSERT_EQ(err, 0, "make_sockaddr"))
513 		goto cleanup;
514 
515 	err = cmp_local_addr(serv, &expected_addr, expected_addr_len, true);
516 	if (!ASSERT_EQ(err, 0, "cmp_local_addr"))
517 		goto cleanup;
518 
519 cleanup:
520 	if (serv != -1)
521 		close(serv);
522 }
523 
524 static void test_getpeername(struct sock_addr_test *test)
525 {
526 	struct sockaddr_storage addr, expected_addr;
527 	socklen_t addr_len = sizeof(struct sockaddr_storage),
528 		  expected_addr_len = sizeof(struct sockaddr_storage);
529 	int serv = -1, client = -1, err;
530 
531 	serv = start_server(test->socket_family, test->socket_type,
532 			    test->requested_addr, test->requested_port, 0);
533 	if (!ASSERT_GE(serv, 0, "start_server"))
534 		goto cleanup;
535 
536 	err = make_sockaddr(test->socket_family, test->requested_addr, test->requested_port,
537 			    &addr, &addr_len);
538 	if (!ASSERT_EQ(err, 0, "make_sockaddr"))
539 		goto cleanup;
540 
541 	client = connect_to_addr(&addr, addr_len, test->socket_type);
542 	if (!ASSERT_GE(client, 0, "connect_to_addr"))
543 		goto cleanup;
544 
545 	err = make_sockaddr(test->socket_family, test->expected_addr, test->expected_port,
546 			    &expected_addr, &expected_addr_len);
547 	if (!ASSERT_EQ(err, 0, "make_sockaddr"))
548 		goto cleanup;
549 
550 	err = cmp_peer_addr(client, &expected_addr, expected_addr_len, true);
551 	if (!ASSERT_EQ(err, 0, "cmp_peer_addr"))
552 		goto cleanup;
553 
554 cleanup:
555 	if (client != -1)
556 		close(client);
557 	if (serv != -1)
558 		close(serv);
559 }
560 
561 void test_sock_addr(void)
562 {
563 	int cgroup_fd = -1;
564 	void *skel;
565 
566 	cgroup_fd = test__join_cgroup("/sock_addr");
567 	if (!ASSERT_GE(cgroup_fd, 0, "join_cgroup"))
568 		goto cleanup;
569 
570 	for (size_t i = 0; i < ARRAY_SIZE(tests); ++i) {
571 		struct sock_addr_test *test = &tests[i];
572 
573 		if (!test__start_subtest(test->name))
574 			continue;
575 
576 		skel = test->loadfn(cgroup_fd);
577 		if (!skel)
578 			continue;
579 
580 		switch (test->type) {
581 		/* Not exercised yet but we leave this code here for when the
582 		 * INET and INET6 sockaddr tests are migrated to this file in
583 		 * the future.
584 		 */
585 		case SOCK_ADDR_TEST_BIND:
586 			test_bind(test);
587 			break;
588 		case SOCK_ADDR_TEST_CONNECT:
589 			test_connect(test);
590 			break;
591 		case SOCK_ADDR_TEST_SENDMSG:
592 		case SOCK_ADDR_TEST_RECVMSG:
593 			test_xmsg(test);
594 			break;
595 		case SOCK_ADDR_TEST_GETSOCKNAME:
596 			test_getsockname(test);
597 			break;
598 		case SOCK_ADDR_TEST_GETPEERNAME:
599 			test_getpeername(test);
600 			break;
601 		default:
602 			ASSERT_TRUE(false, "Unknown sock addr test type");
603 			break;
604 		}
605 
606 		test->destroyfn(skel);
607 	}
608 
609 cleanup:
610 	if (cgroup_fd >= 0)
611 		close(cgroup_fd);
612 }
613