1 // SPDX-License-Identifier: GPL-2.0
2 /* Author: Dmitry Safonov <dima@arista.com> */
/* This is an over-simplified TCP_REPAIR for TCP_ESTABLISHED sockets.
 * It tests that a TCP-AO enabled connection can be restored.
5  * For the proper socket repair see:
6  * https://github.com/checkpoint-restore/criu/blob/criu-dev/soccr/soccr.h
7  */
8 #include <inttypes.h>
9 #include "aolib.h"
10 
/*
 * Traffic shape shared by both threads. The client calls
 * test_client_verify() with msg_len and nr_packets, and the server expects
 * test_server_run() to serve quota bytes.
 *
 * The values live in enum constants so that quota's initializer is a real
 * constant expression. In ISO C a const-qualified object is not a constant
 * expression, so initialising one file-scope object from others relies on
 * a compiler extension.
 */
enum {
	RESTORE_NR_PACKETS = 20,
	RESTORE_MSG_LEN = 100,
};

static const size_t nr_packets = RESTORE_NR_PACKETS;
static const size_t msg_len = RESTORE_MSG_LEN;
static const size_t quota = (size_t)RESTORE_NR_PACKETS * RESTORE_MSG_LEN;

/* Expands to a test on a variable named `inj` that must be in scope. */
#define fault(type)	(inj == FAULT_ ## type)
15 
try_server_run(const char * tst_name,unsigned int port,fault_t inj,test_cnt cnt_expected)16 static void try_server_run(const char *tst_name, unsigned int port,
17 			   fault_t inj, test_cnt cnt_expected)
18 {
19 	const char *cnt_name = "TCPAOGood";
20 	struct tcp_ao_counters ao1, ao2;
21 	uint64_t before_cnt, after_cnt;
22 	int sk, lsk;
23 	time_t timeout;
24 	ssize_t bytes;
25 
26 	if (fault(TIMEOUT))
27 		cnt_name = "TCPAOBad";
28 	lsk = test_listen_socket(this_ip_addr, port, 1);
29 
30 	if (test_add_key(lsk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
31 		test_error("setsockopt(TCP_AO_ADD_KEY)");
32 	synchronize_threads(); /* 1: MKT added => connect() */
33 
34 	if (test_wait_fd(lsk, TEST_TIMEOUT_SEC, 0))
35 		test_error("test_wait_fd()");
36 
37 	sk = accept(lsk, NULL, NULL);
38 	if (sk < 0)
39 		test_error("accept()");
40 
41 	synchronize_threads(); /* 2: accepted => send data */
42 	close(lsk);
43 
44 	bytes = test_server_run(sk, quota, TEST_TIMEOUT_SEC);
45 	if (bytes != quota) {
46 		test_fail("%s: server served: %zd", tst_name, bytes);
47 		goto out;
48 	}
49 
50 	before_cnt = netstat_get_one(cnt_name, NULL);
51 	if (test_get_tcp_ao_counters(sk, &ao1))
52 		test_error("test_get_tcp_ao_counters()");
53 
54 	timeout = fault(TIMEOUT) ? TEST_RETRANSMIT_SEC : TEST_TIMEOUT_SEC;
55 	bytes = test_server_run(sk, quota, timeout);
56 	if (fault(TIMEOUT)) {
57 		if (bytes > 0)
58 			test_fail("%s: server served: %zd", tst_name, bytes);
59 		else
60 			test_ok("%s: server couldn't serve", tst_name);
61 	} else {
62 		if (bytes != quota)
63 			test_fail("%s: server served: %zd", tst_name, bytes);
64 		else
65 			test_ok("%s: server alive", tst_name);
66 	}
67 	if (test_get_tcp_ao_counters(sk, &ao2))
68 		test_error("test_get_tcp_ao_counters()");
69 	after_cnt = netstat_get_one(cnt_name, NULL);
70 
71 	test_tcp_ao_counters_cmp(tst_name, &ao1, &ao2, cnt_expected);
72 
73 	if (after_cnt <= before_cnt) {
74 		test_fail("%s: %s counter did not increase: %zu <= %zu",
75 				tst_name, cnt_name, after_cnt, before_cnt);
76 	} else {
77 		test_ok("%s: counter %s increased %zu => %zu",
78 			tst_name, cnt_name, before_cnt, after_cnt);
79 	}
80 
81 	/*
82 	 * Before close() as that will send FIN and move the peer in TCP_CLOSE
83 	 * and that will prevent reading AO counters from the peer's socket.
84 	 */
85 	synchronize_threads(); /* 3: verified => closed */
86 out:
87 	close(sk);
88 }
89 
server_fn(void * arg)90 static void *server_fn(void *arg)
91 {
92 	unsigned int port = test_server_port;
93 
94 	try_server_run("TCP-AO migrate to another socket", port++,
95 		       0, TEST_CNT_GOOD);
96 	try_server_run("TCP-AO with wrong send ISN", port++,
97 		       FAULT_TIMEOUT, TEST_CNT_BAD);
98 	try_server_run("TCP-AO with wrong receive ISN", port++,
99 		       FAULT_TIMEOUT, TEST_CNT_BAD);
100 	try_server_run("TCP-AO with wrong send SEQ ext number", port++,
101 		       FAULT_TIMEOUT, TEST_CNT_BAD);
102 	try_server_run("TCP-AO with wrong receive SEQ ext number", port++,
103 		       FAULT_TIMEOUT, TEST_CNT_NS_BAD | TEST_CNT_GOOD);
104 
105 	synchronize_threads(); /* don't race to exit: client exits */
106 	return NULL;
107 }
108 
test_get_sk_checkpoint(unsigned int server_port,sockaddr_af * saddr,struct tcp_sock_state * img,struct tcp_ao_repair * ao_img)109 static void test_get_sk_checkpoint(unsigned int server_port, sockaddr_af *saddr,
110 				   struct tcp_sock_state *img,
111 				   struct tcp_ao_repair *ao_img)
112 {
113 	int sk;
114 
115 	sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
116 	if (sk < 0)
117 		test_error("socket()");
118 
119 	if (test_add_key(sk, DEFAULT_TEST_PASSWORD, this_ip_dest, -1, 100, 100))
120 		test_error("setsockopt(TCP_AO_ADD_KEY)");
121 
122 	synchronize_threads(); /* 1: MKT added => connect() */
123 	if (test_connect_socket(sk, this_ip_dest, server_port) <= 0)
124 		test_error("failed to connect()");
125 
126 	synchronize_threads(); /* 2: accepted => send data */
127 	if (test_client_verify(sk, msg_len, nr_packets, TEST_TIMEOUT_SEC))
128 		test_fail("pre-migrate verify failed");
129 
130 	test_enable_repair(sk);
131 	test_sock_checkpoint(sk, img, saddr);
132 	test_ao_checkpoint(sk, ao_img);
133 	test_kill_sk(sk);
134 }
135 
test_sk_restore(const char * tst_name,unsigned int server_port,sockaddr_af * saddr,struct tcp_sock_state * img,struct tcp_ao_repair * ao_img,fault_t inj,test_cnt cnt_expected)136 static void test_sk_restore(const char *tst_name, unsigned int server_port,
137 			    sockaddr_af *saddr, struct tcp_sock_state *img,
138 			    struct tcp_ao_repair *ao_img,
139 			    fault_t inj, test_cnt cnt_expected)
140 {
141 	const char *cnt_name = "TCPAOGood";
142 	struct tcp_ao_counters ao1, ao2;
143 	uint64_t before_cnt, after_cnt;
144 	time_t timeout;
145 	int sk;
146 
147 	if (fault(TIMEOUT))
148 		cnt_name = "TCPAOBad";
149 
150 	before_cnt = netstat_get_one(cnt_name, NULL);
151 	sk = socket(test_family, SOCK_STREAM, IPPROTO_TCP);
152 	if (sk < 0)
153 		test_error("socket()");
154 
155 	test_enable_repair(sk);
156 	test_sock_restore(sk, img, saddr, this_ip_dest, server_port);
157 	if (test_add_repaired_key(sk, DEFAULT_TEST_PASSWORD, 0, this_ip_dest, -1, 100, 100))
158 		test_error("setsockopt(TCP_AO_ADD_KEY)");
159 	test_ao_restore(sk, ao_img);
160 
161 	if (test_get_tcp_ao_counters(sk, &ao1))
162 		test_error("test_get_tcp_ao_counters()");
163 
164 	test_disable_repair(sk);
165 	test_sock_state_free(img);
166 
167 	timeout = fault(TIMEOUT) ? TEST_RETRANSMIT_SEC : TEST_TIMEOUT_SEC;
168 	if (test_client_verify(sk, msg_len, nr_packets, timeout)) {
169 		if (fault(TIMEOUT))
170 			test_ok("%s: post-migrate connection is broken", tst_name);
171 		else
172 			test_fail("%s: post-migrate connection is working", tst_name);
173 	} else {
174 		if (fault(TIMEOUT))
175 			test_fail("%s: post-migrate connection still working", tst_name);
176 		else
177 			test_ok("%s: post-migrate connection is alive", tst_name);
178 	}
179 	if (test_get_tcp_ao_counters(sk, &ao2))
180 		test_error("test_get_tcp_ao_counters()");
181 	after_cnt = netstat_get_one(cnt_name, NULL);
182 
183 	test_tcp_ao_counters_cmp(tst_name, &ao1, &ao2, cnt_expected);
184 
185 	if (after_cnt <= before_cnt) {
186 		test_fail("%s: %s counter did not increase: %zu <= %zu",
187 				tst_name, cnt_name, after_cnt, before_cnt);
188 	} else {
189 		test_ok("%s: counter %s increased %zu => %zu",
190 			tst_name, cnt_name, before_cnt, after_cnt);
191 	}
192 	synchronize_threads(); /* 3: verified => closed */
193 	close(sk);
194 }
195 
client_fn(void * arg)196 static void *client_fn(void *arg)
197 {
198 	unsigned int port = test_server_port;
199 	struct tcp_sock_state tcp_img;
200 	struct tcp_ao_repair ao_img;
201 	sockaddr_af saddr;
202 
203 	test_get_sk_checkpoint(port, &saddr, &tcp_img, &ao_img);
204 	test_sk_restore("TCP-AO migrate to another socket", port++,
205 			&saddr, &tcp_img, &ao_img, 0, TEST_CNT_GOOD);
206 
207 	test_get_sk_checkpoint(port, &saddr, &tcp_img, &ao_img);
208 	ao_img.snt_isn += 1;
209 	test_sk_restore("TCP-AO with wrong send ISN", port++,
210 			&saddr, &tcp_img, &ao_img, FAULT_TIMEOUT, TEST_CNT_BAD);
211 
212 	test_get_sk_checkpoint(port, &saddr, &tcp_img, &ao_img);
213 	ao_img.rcv_isn += 1;
214 	test_sk_restore("TCP-AO with wrong receive ISN", port++,
215 			&saddr, &tcp_img, &ao_img, FAULT_TIMEOUT, TEST_CNT_BAD);
216 
217 	test_get_sk_checkpoint(port, &saddr, &tcp_img, &ao_img);
218 	ao_img.snd_sne += 1;
219 	test_sk_restore("TCP-AO with wrong send SEQ ext number", port++,
220 			&saddr, &tcp_img, &ao_img, FAULT_TIMEOUT,
221 			TEST_CNT_NS_BAD | TEST_CNT_GOOD);
222 
223 	test_get_sk_checkpoint(port, &saddr, &tcp_img, &ao_img);
224 	ao_img.rcv_sne += 1;
225 	test_sk_restore("TCP-AO with wrong receive SEQ ext number", port++,
226 			&saddr, &tcp_img, &ao_img, FAULT_TIMEOUT,
227 			TEST_CNT_NS_GOOD | TEST_CNT_BAD);
228 
229 	return NULL;
230 }
231 
main(int argc,char * argv[])232 int main(int argc, char *argv[])
233 {
234 	test_init(20, server_fn, client_fn);
235 	return 0;
236 }
237