1 // SPDX-License-Identifier: GPL-2.0
2 /* Copyright (c) 2020, Tessares SA. */
3 /* Copyright (c) 2022, SUSE. */
4 
5 #include <test_progs.h>
6 #include "cgroup_helpers.h"
7 #include "network_helpers.h"
8 #include "mptcp_sock.skel.h"
9 
/*
 * Fallback definition, used only when no included header already defines
 * TCP_CA_NAME_MAX. Sizes the congestion-control name buffers below.
 */
#ifndef TCP_CA_NAME_MAX
#define TCP_CA_NAME_MAX	16
#endif

/*
 * Value type read back from the skeleton's socket_storage_map (keyed by the
 * client socket fd) by verify_tsk() and verify_msk().
 * NOTE(review): this layout presumably has to match the storage struct used
 * by the BPF program behind mptcp_sock.skel.h; keep the two in sync.
 */
struct mptcp_storage {
	/* Both verify helpers expect this to be exactly 1. */
	__u32 invoked;
	/* Expected 1 for an MPTCP client socket, 0 for plain TCP. */
	__u32 is_mptcp;
	/* Only compared against @first in verify_msk(); never dereferenced here. */
	struct sock *sk;
	/* Compared with the token the skeleton exposes via its bss. */
	__u32 token;
	/* Expected to equal @sk in verify_msk(); never dereferenced here. */
	struct sock *first;
	/* Compared with the system default read from procfs. */
	char ca_name[TCP_CA_NAME_MAX];
};
22 
23 static int verify_tsk(int map_fd, int client_fd)
24 {
25 	int err, cfd = client_fd;
26 	struct mptcp_storage val;
27 
28 	err = bpf_map_lookup_elem(map_fd, &cfd, &val);
29 	if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
30 		return err;
31 
32 	if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
33 		err++;
34 
35 	if (!ASSERT_EQ(val.is_mptcp, 0, "unexpected is_mptcp"))
36 		err++;
37 
38 	return err;
39 }
40 
41 static void get_msk_ca_name(char ca_name[])
42 {
43 	size_t len;
44 	int fd;
45 
46 	fd = open("/proc/sys/net/ipv4/tcp_congestion_control", O_RDONLY);
47 	if (!ASSERT_GE(fd, 0, "failed to open tcp_congestion_control"))
48 		return;
49 
50 	len = read(fd, ca_name, TCP_CA_NAME_MAX);
51 	if (!ASSERT_GT(len, 0, "failed to read ca_name"))
52 		goto err;
53 
54 	if (len > 0 && ca_name[len - 1] == '\n')
55 		ca_name[len - 1] = '\0';
56 
57 err:
58 	close(fd);
59 }
60 
61 static int verify_msk(int map_fd, int client_fd, __u32 token)
62 {
63 	char ca_name[TCP_CA_NAME_MAX];
64 	int err, cfd = client_fd;
65 	struct mptcp_storage val;
66 
67 	if (!ASSERT_GT(token, 0, "invalid token"))
68 		return -1;
69 
70 	get_msk_ca_name(ca_name);
71 
72 	err = bpf_map_lookup_elem(map_fd, &cfd, &val);
73 	if (!ASSERT_OK(err, "bpf_map_lookup_elem"))
74 		return err;
75 
76 	if (!ASSERT_EQ(val.invoked, 1, "unexpected invoked count"))
77 		err++;
78 
79 	if (!ASSERT_EQ(val.is_mptcp, 1, "unexpected is_mptcp"))
80 		err++;
81 
82 	if (!ASSERT_EQ(val.token, token, "unexpected token"))
83 		err++;
84 
85 	if (!ASSERT_EQ(val.first, val.sk, "unexpected first"))
86 		err++;
87 
88 	if (!ASSERT_STRNEQ(val.ca_name, ca_name, TCP_CA_NAME_MAX, "unexpected ca_name"))
89 		err++;
90 
91 	return err;
92 }
93 
94 static int run_test(int cgroup_fd, int server_fd, bool is_mptcp)
95 {
96 	int client_fd, prog_fd, map_fd, err;
97 	struct mptcp_sock *sock_skel;
98 
99 	sock_skel = mptcp_sock__open_and_load();
100 	if (!ASSERT_OK_PTR(sock_skel, "skel_open_load"))
101 		return -EIO;
102 
103 	err = mptcp_sock__attach(sock_skel);
104 	if (!ASSERT_OK(err, "skel_attach"))
105 		goto out;
106 
107 	prog_fd = bpf_program__fd(sock_skel->progs._sockops);
108 	if (!ASSERT_GE(prog_fd, 0, "bpf_program__fd")) {
109 		err = -EIO;
110 		goto out;
111 	}
112 
113 	map_fd = bpf_map__fd(sock_skel->maps.socket_storage_map);
114 	if (!ASSERT_GE(map_fd, 0, "bpf_map__fd")) {
115 		err = -EIO;
116 		goto out;
117 	}
118 
119 	err = bpf_prog_attach(prog_fd, cgroup_fd, BPF_CGROUP_SOCK_OPS, 0);
120 	if (!ASSERT_OK(err, "bpf_prog_attach"))
121 		goto out;
122 
123 	client_fd = connect_to_fd(server_fd, 0);
124 	if (!ASSERT_GE(client_fd, 0, "connect to fd")) {
125 		err = -EIO;
126 		goto out;
127 	}
128 
129 	err += is_mptcp ? verify_msk(map_fd, client_fd, sock_skel->bss->token) :
130 			  verify_tsk(map_fd, client_fd);
131 
132 	close(client_fd);
133 
134 out:
135 	mptcp_sock__destroy(sock_skel);
136 	return err;
137 }
138 
139 static void test_base(void)
140 {
141 	int server_fd, cgroup_fd;
142 
143 	cgroup_fd = test__join_cgroup("/mptcp");
144 	if (!ASSERT_GE(cgroup_fd, 0, "test__join_cgroup"))
145 		return;
146 
147 	/* without MPTCP */
148 	server_fd = start_server(AF_INET, SOCK_STREAM, NULL, 0, 0);
149 	if (!ASSERT_GE(server_fd, 0, "start_server"))
150 		goto with_mptcp;
151 
152 	ASSERT_OK(run_test(cgroup_fd, server_fd, false), "run_test tcp");
153 
154 	close(server_fd);
155 
156 with_mptcp:
157 	/* with MPTCP */
158 	server_fd = start_mptcp_server(AF_INET, NULL, 0, 0);
159 	if (!ASSERT_GE(server_fd, 0, "start_mptcp_server"))
160 		goto close_cgroup_fd;
161 
162 	ASSERT_OK(run_test(cgroup_fd, server_fd, true), "run_test mptcp");
163 
164 	close(server_fd);
165 
166 close_cgroup_fd:
167 	close(cgroup_fd);
168 }
169 
/* test_progs entry point; test__start_subtest() decides whether "base" runs. */
void test_mptcp(void)
{
	if (!test__start_subtest("base"))
		return;

	test_base();
}
175