1 // SPDX-License-Identifier: GPL-2.0
2 
3 #define _GNU_SOURCE
4 
5 #include <errno.h>
6 #include <limits.h>
7 #include <fcntl.h>
8 #include <string.h>
9 #include <stdbool.h>
10 #include <stdint.h>
11 #include <stdio.h>
12 #include <stdlib.h>
13 #include <strings.h>
14 #include <signal.h>
15 #include <unistd.h>
16 
17 #include <sys/poll.h>
18 #include <sys/sendfile.h>
19 #include <sys/stat.h>
20 #include <sys/socket.h>
21 #include <sys/types.h>
22 #include <sys/mman.h>
23 
24 #include <netdb.h>
25 #include <netinet/in.h>
26 
27 #include <linux/tcp.h>
28 
extern int optind;

#ifndef IPPROTO_MPTCP
#define IPPROTO_MPTCP 262
#endif
#ifndef TCP_ULP
#define TCP_ULP 31
#endif

/* poll() timeout in milliseconds (-t gives seconds; <= 0 becomes -1, no timeout). */
static int  poll_timeout = 10 * 1000;
/* Set by -l: act as the listening side. */
static bool listen_mode;
/*
 * Written from the SIGUSR1 handler. An object assigned from an asynchronous
 * signal handler must be 'volatile sig_atomic_t' for that access to be well
 * defined (C11 7.14.1.1p5); a plain bool does not qualify.
 */
static volatile sig_atomic_t quit;

/* How payload data is moved between the local fds and the peer socket (-m). */
enum cfg_mode {
	CFG_MODE_POLL,		/* interleaved read/write driven by poll() */
	CFG_MODE_MMAP,		/* mmap() the input file, then write() it */
	CFG_MODE_SENDFILE,	/* sendfile() the input file */
};

/* Receive-path MSG_PEEK behaviour (-P). */
enum cfg_peek {
	CFG_NONE_PEEK,		/* plain read() */
	CFG_WITH_PEEK,		/* keep the recv(MSG_PEEK) data, then drain it */
	CFG_AFTER_PEEK,		/* recv(MSG_PEEK) first, keep the data of a later read() */
};

static enum cfg_mode cfg_mode = CFG_MODE_POLL;
static enum cfg_peek cfg_peek = CFG_NONE_PEEK;
static const char *cfg_host;			/* the positional address argument */
static const char *cfg_port	= "12000";	/* -p */
static int cfg_sock_proto	= IPPROTO_MPTCP;	/* -s: IPPROTO_MPTCP or IPPROTO_TCP */
static bool tcpulp_audit;			/* -u: only run the TCP_ULP check */
static int pf = AF_INET;			/* -6 or a ':' in the host selects AF_INET6 */
static int cfg_sndbuf;				/* -S: SO_SNDBUF, 0 = leave default */
static int cfg_rcvbuf;				/* -R: SO_RCVBUF, 0 = leave default */
static bool cfg_join;				/* -j */
static bool cfg_remove;				/* -r */
static unsigned int cfg_do_w;			/* -r: max bytes per write() */
static int cfg_wait;				/* delay in microseconds, passed to usleep() */
static uint32_t cfg_mark;			/* -M: SO_MARK, 0 = not set */
68 
die_usage(void)69 static void die_usage(void)
70 {
71 	fprintf(stderr, "Usage: mptcp_connect [-6] [-u] [-s MPTCP|TCP] [-p port] [-m mode]"
72 		"[-l] [-w sec] connect_address\n");
73 	fprintf(stderr, "\t-6 use ipv6\n");
74 	fprintf(stderr, "\t-t num -- set poll timeout to num\n");
75 	fprintf(stderr, "\t-S num -- set SO_SNDBUF to num\n");
76 	fprintf(stderr, "\t-R num -- set SO_RCVBUF to num\n");
77 	fprintf(stderr, "\t-p num -- use port num\n");
78 	fprintf(stderr, "\t-s [MPTCP|TCP] -- use mptcp(default) or tcp sockets\n");
79 	fprintf(stderr, "\t-m [poll|mmap|sendfile] -- use poll(default)/mmap+write/sendfile\n");
80 	fprintf(stderr, "\t-M mark -- set socket packet mark\n");
81 	fprintf(stderr, "\t-u -- check mptcp ulp\n");
82 	fprintf(stderr, "\t-w num -- wait num sec before closing the socket\n");
83 	fprintf(stderr,
84 		"\t-P [saveWithPeek|saveAfterPeek] -- save data with/after MSG_PEEK form tcp socket\n");
85 	exit(1);
86 }
87 
handle_signal(int nr)88 static void handle_signal(int nr)
89 {
90 	quit = true;
91 }
92 
getxinfo_strerr(int err)93 static const char *getxinfo_strerr(int err)
94 {
95 	if (err == EAI_SYSTEM)
96 		return strerror(errno);
97 
98 	return gai_strerror(err);
99 }
100 
xgetnameinfo(const struct sockaddr * addr,socklen_t addrlen,char * host,socklen_t hostlen,char * serv,socklen_t servlen)101 static void xgetnameinfo(const struct sockaddr *addr, socklen_t addrlen,
102 			 char *host, socklen_t hostlen,
103 			 char *serv, socklen_t servlen)
104 {
105 	int flags = NI_NUMERICHOST | NI_NUMERICSERV;
106 	int err = getnameinfo(addr, addrlen, host, hostlen, serv, servlen,
107 			      flags);
108 
109 	if (err) {
110 		const char *errstr = getxinfo_strerr(err);
111 
112 		fprintf(stderr, "Fatal: getnameinfo: %s\n", errstr);
113 		exit(1);
114 	}
115 }
116 
xgetaddrinfo(const char * node,const char * service,const struct addrinfo * hints,struct addrinfo ** res)117 static void xgetaddrinfo(const char *node, const char *service,
118 			 const struct addrinfo *hints,
119 			 struct addrinfo **res)
120 {
121 	int err = getaddrinfo(node, service, hints, res);
122 
123 	if (err) {
124 		const char *errstr = getxinfo_strerr(err);
125 
126 		fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
127 			node ? node : "", service ? service : "", errstr);
128 		exit(1);
129 	}
130 }
131 
set_rcvbuf(int fd,unsigned int size)132 static void set_rcvbuf(int fd, unsigned int size)
133 {
134 	int err;
135 
136 	err = setsockopt(fd, SOL_SOCKET, SO_RCVBUF, &size, sizeof(size));
137 	if (err) {
138 		perror("set SO_RCVBUF");
139 		exit(1);
140 	}
141 }
142 
set_sndbuf(int fd,unsigned int size)143 static void set_sndbuf(int fd, unsigned int size)
144 {
145 	int err;
146 
147 	err = setsockopt(fd, SOL_SOCKET, SO_SNDBUF, &size, sizeof(size));
148 	if (err) {
149 		perror("set SO_SNDBUF");
150 		exit(1);
151 	}
152 }
153 
set_mark(int fd,uint32_t mark)154 static void set_mark(int fd, uint32_t mark)
155 {
156 	int err;
157 
158 	err = setsockopt(fd, SOL_SOCKET, SO_MARK, &mark, sizeof(mark));
159 	if (err) {
160 		perror("set SO_MARK");
161 		exit(1);
162 	}
163 }
164 
sock_listen_mptcp(const char * const listenaddr,const char * const port)165 static int sock_listen_mptcp(const char * const listenaddr,
166 			     const char * const port)
167 {
168 	int sock;
169 	struct addrinfo hints = {
170 		.ai_protocol = IPPROTO_TCP,
171 		.ai_socktype = SOCK_STREAM,
172 		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
173 	};
174 
175 	hints.ai_family = pf;
176 
177 	struct addrinfo *a, *addr;
178 	int one = 1;
179 
180 	xgetaddrinfo(listenaddr, port, &hints, &addr);
181 	hints.ai_family = pf;
182 
183 	for (a = addr; a; a = a->ai_next) {
184 		sock = socket(a->ai_family, a->ai_socktype, cfg_sock_proto);
185 		if (sock < 0)
186 			continue;
187 
188 		if (-1 == setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
189 				     sizeof(one)))
190 			perror("setsockopt");
191 
192 		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
193 			break; /* success */
194 
195 		perror("bind");
196 		close(sock);
197 		sock = -1;
198 	}
199 
200 	freeaddrinfo(addr);
201 
202 	if (sock < 0) {
203 		fprintf(stderr, "Could not create listen socket\n");
204 		return sock;
205 	}
206 
207 	if (listen(sock, 20)) {
208 		perror("listen");
209 		close(sock);
210 		return -1;
211 	}
212 
213 	return sock;
214 }
215 
sock_test_tcpulp(const char * const remoteaddr,const char * const port)216 static bool sock_test_tcpulp(const char * const remoteaddr,
217 			     const char * const port)
218 {
219 	struct addrinfo hints = {
220 		.ai_protocol = IPPROTO_TCP,
221 		.ai_socktype = SOCK_STREAM,
222 	};
223 	struct addrinfo *a, *addr;
224 	int sock = -1, ret = 0;
225 	bool test_pass = false;
226 
227 	hints.ai_family = AF_INET;
228 
229 	xgetaddrinfo(remoteaddr, port, &hints, &addr);
230 	for (a = addr; a; a = a->ai_next) {
231 		sock = socket(a->ai_family, a->ai_socktype, IPPROTO_TCP);
232 		if (sock < 0) {
233 			perror("socket");
234 			continue;
235 		}
236 		ret = setsockopt(sock, IPPROTO_TCP, TCP_ULP, "mptcp",
237 				 sizeof("mptcp"));
238 		if (ret == -1 && errno == EOPNOTSUPP)
239 			test_pass = true;
240 		close(sock);
241 
242 		if (test_pass)
243 			break;
244 		if (!ret)
245 			fprintf(stderr,
246 				"setsockopt(TCP_ULP) returned 0\n");
247 		else
248 			perror("setsockopt(TCP_ULP)");
249 	}
250 	return test_pass;
251 }
252 
sock_connect_mptcp(const char * const remoteaddr,const char * const port,int proto)253 static int sock_connect_mptcp(const char * const remoteaddr,
254 			      const char * const port, int proto)
255 {
256 	struct addrinfo hints = {
257 		.ai_protocol = IPPROTO_TCP,
258 		.ai_socktype = SOCK_STREAM,
259 	};
260 	struct addrinfo *a, *addr;
261 	int sock = -1;
262 
263 	hints.ai_family = pf;
264 
265 	xgetaddrinfo(remoteaddr, port, &hints, &addr);
266 	for (a = addr; a; a = a->ai_next) {
267 		sock = socket(a->ai_family, a->ai_socktype, proto);
268 		if (sock < 0) {
269 			perror("socket");
270 			continue;
271 		}
272 
273 		if (cfg_mark)
274 			set_mark(sock, cfg_mark);
275 
276 		if (connect(sock, a->ai_addr, a->ai_addrlen) == 0)
277 			break; /* success */
278 
279 		perror("connect()");
280 		close(sock);
281 		sock = -1;
282 	}
283 
284 	freeaddrinfo(addr);
285 	return sock;
286 }
287 
do_rnd_write(const int fd,char * buf,const size_t len)288 static size_t do_rnd_write(const int fd, char *buf, const size_t len)
289 {
290 	static bool first = true;
291 	unsigned int do_w;
292 	ssize_t bw;
293 
294 	do_w = rand() & 0xffff;
295 	if (do_w == 0 || do_w > len)
296 		do_w = len;
297 
298 	if (cfg_join && first && do_w > 100)
299 		do_w = 100;
300 
301 	if (cfg_remove && do_w > cfg_do_w)
302 		do_w = cfg_do_w;
303 
304 	bw = write(fd, buf, do_w);
305 	if (bw < 0)
306 		perror("write");
307 
308 	/* let the join handshake complete, before going on */
309 	if (cfg_join && first) {
310 		usleep(200000);
311 		first = false;
312 	}
313 
314 	if (cfg_remove)
315 		usleep(200000);
316 
317 	return bw;
318 }
319 
do_write(const int fd,char * buf,const size_t len)320 static size_t do_write(const int fd, char *buf, const size_t len)
321 {
322 	size_t offset = 0;
323 
324 	while (offset < len) {
325 		size_t written;
326 		ssize_t bw;
327 
328 		bw = write(fd, buf + offset, len - offset);
329 		if (bw < 0) {
330 			perror("write");
331 			return 0;
332 		}
333 
334 		written = (size_t)bw;
335 		offset += written;
336 	}
337 
338 	return offset;
339 }
340 
do_rnd_read(const int fd,char * buf,const size_t len)341 static ssize_t do_rnd_read(const int fd, char *buf, const size_t len)
342 {
343 	int ret = 0;
344 	char tmp[16384];
345 	size_t cap = rand();
346 
347 	cap &= 0xffff;
348 
349 	if (cap == 0)
350 		cap = 1;
351 	else if (cap > len)
352 		cap = len;
353 
354 	if (cfg_peek == CFG_WITH_PEEK) {
355 		ret = recv(fd, buf, cap, MSG_PEEK);
356 		ret = (ret < 0) ? ret : read(fd, tmp, ret);
357 	} else if (cfg_peek == CFG_AFTER_PEEK) {
358 		ret = recv(fd, buf, cap, MSG_PEEK);
359 		ret = (ret < 0) ? ret : read(fd, buf, cap);
360 	} else {
361 		ret = read(fd, buf, cap);
362 	}
363 
364 	return ret;
365 }
366 
set_nonblock(int fd)367 static void set_nonblock(int fd)
368 {
369 	int flags = fcntl(fd, F_GETFL);
370 
371 	if (flags == -1)
372 		return;
373 
374 	fcntl(fd, F_SETFL, flags | O_NONBLOCK);
375 }
376 
copyfd_io_poll(int infd,int peerfd,int outfd)377 static int copyfd_io_poll(int infd, int peerfd, int outfd)
378 {
379 	struct pollfd fds = {
380 		.fd = peerfd,
381 		.events = POLLIN | POLLOUT,
382 	};
383 	unsigned int woff = 0, wlen = 0;
384 	char wbuf[8192];
385 
386 	set_nonblock(peerfd);
387 
388 	for (;;) {
389 		char rbuf[8192];
390 		ssize_t len;
391 
392 		if (fds.events == 0)
393 			break;
394 
395 		switch (poll(&fds, 1, poll_timeout)) {
396 		case -1:
397 			if (errno == EINTR)
398 				continue;
399 			perror("poll");
400 			return 1;
401 		case 0:
402 			fprintf(stderr, "%s: poll timed out (events: "
403 				"POLLIN %u, POLLOUT %u)\n", __func__,
404 				fds.events & POLLIN, fds.events & POLLOUT);
405 			return 2;
406 		}
407 
408 		if (fds.revents & POLLIN) {
409 			len = do_rnd_read(peerfd, rbuf, sizeof(rbuf));
410 			if (len == 0) {
411 				/* no more data to receive:
412 				 * peer has closed its write side
413 				 */
414 				fds.events &= ~POLLIN;
415 
416 				if ((fds.events & POLLOUT) == 0)
417 					/* and nothing more to send */
418 					break;
419 
420 			/* Else, still have data to transmit */
421 			} else if (len < 0) {
422 				perror("read");
423 				return 3;
424 			}
425 
426 			do_write(outfd, rbuf, len);
427 		}
428 
429 		if (fds.revents & POLLOUT) {
430 			if (wlen == 0) {
431 				woff = 0;
432 				wlen = read(infd, wbuf, sizeof(wbuf));
433 			}
434 
435 			if (wlen > 0) {
436 				ssize_t bw;
437 
438 				bw = do_rnd_write(peerfd, wbuf + woff, wlen);
439 				if (bw < 0)
440 					return 111;
441 
442 				woff += bw;
443 				wlen -= bw;
444 			} else if (wlen == 0) {
445 				/* We have no more data to send. */
446 				fds.events &= ~POLLOUT;
447 
448 				if ((fds.events & POLLIN) == 0)
449 					/* ... and peer also closed already */
450 					break;
451 
452 				/* ... but we still receive.
453 				 * Close our write side, ev. give some time
454 				 * for address notification and/or checking
455 				 * the current status
456 				 */
457 				if (cfg_wait)
458 					usleep(cfg_wait);
459 				shutdown(peerfd, SHUT_WR);
460 			} else {
461 				if (errno == EINTR)
462 					continue;
463 				perror("read");
464 				return 4;
465 			}
466 		}
467 
468 		if (fds.revents & (POLLERR | POLLNVAL)) {
469 			fprintf(stderr, "Unexpected revents: "
470 				"POLLERR/POLLNVAL(%x)\n", fds.revents);
471 			return 5;
472 		}
473 	}
474 
475 	/* leave some time for late join/announce */
476 	if (cfg_join || cfg_remove)
477 		usleep(cfg_wait);
478 
479 	close(peerfd);
480 	return 0;
481 }
482 
do_recvfile(int infd,int outfd)483 static int do_recvfile(int infd, int outfd)
484 {
485 	ssize_t r;
486 
487 	do {
488 		char buf[16384];
489 
490 		r = do_rnd_read(infd, buf, sizeof(buf));
491 		if (r > 0) {
492 			if (write(outfd, buf, r) != r)
493 				break;
494 		} else if (r < 0) {
495 			perror("read");
496 		}
497 	} while (r > 0);
498 
499 	return (int)r;
500 }
501 
do_mmap(int infd,int outfd,unsigned int size)502 static int do_mmap(int infd, int outfd, unsigned int size)
503 {
504 	char *inbuf = mmap(NULL, size, PROT_READ, MAP_SHARED, infd, 0);
505 	ssize_t ret = 0, off = 0;
506 	size_t rem;
507 
508 	if (inbuf == MAP_FAILED) {
509 		perror("mmap");
510 		return 1;
511 	}
512 
513 	rem = size;
514 
515 	while (rem > 0) {
516 		ret = write(outfd, inbuf + off, rem);
517 
518 		if (ret < 0) {
519 			perror("write");
520 			break;
521 		}
522 
523 		off += ret;
524 		rem -= ret;
525 	}
526 
527 	munmap(inbuf, size);
528 	return rem;
529 }
530 
get_infd_size(int fd)531 static int get_infd_size(int fd)
532 {
533 	struct stat sb;
534 	ssize_t count;
535 	int err;
536 
537 	err = fstat(fd, &sb);
538 	if (err < 0) {
539 		perror("fstat");
540 		return -1;
541 	}
542 
543 	if ((sb.st_mode & S_IFMT) != S_IFREG) {
544 		fprintf(stderr, "%s: stdin is not a regular file\n", __func__);
545 		return -2;
546 	}
547 
548 	count = sb.st_size;
549 	if (count > INT_MAX) {
550 		fprintf(stderr, "File too large: %zu\n", count);
551 		return -3;
552 	}
553 
554 	return (int)count;
555 }
556 
do_sendfile(int infd,int outfd,unsigned int count)557 static int do_sendfile(int infd, int outfd, unsigned int count)
558 {
559 	while (count > 0) {
560 		ssize_t r;
561 
562 		r = sendfile(outfd, infd, NULL, count);
563 		if (r < 0) {
564 			perror("sendfile");
565 			return 3;
566 		}
567 
568 		count -= r;
569 	}
570 
571 	return 0;
572 }
573 
copyfd_io_mmap(int infd,int peerfd,int outfd,unsigned int size)574 static int copyfd_io_mmap(int infd, int peerfd, int outfd,
575 			  unsigned int size)
576 {
577 	int err;
578 
579 	if (listen_mode) {
580 		err = do_recvfile(peerfd, outfd);
581 		if (err)
582 			return err;
583 
584 		err = do_mmap(infd, peerfd, size);
585 	} else {
586 		err = do_mmap(infd, peerfd, size);
587 		if (err)
588 			return err;
589 
590 		shutdown(peerfd, SHUT_WR);
591 
592 		err = do_recvfile(peerfd, outfd);
593 	}
594 
595 	return err;
596 }
597 
copyfd_io_sendfile(int infd,int peerfd,int outfd,unsigned int size)598 static int copyfd_io_sendfile(int infd, int peerfd, int outfd,
599 			      unsigned int size)
600 {
601 	int err;
602 
603 	if (listen_mode) {
604 		err = do_recvfile(peerfd, outfd);
605 		if (err)
606 			return err;
607 
608 		err = do_sendfile(infd, peerfd, size);
609 	} else {
610 		err = do_sendfile(infd, peerfd, size);
611 		if (err)
612 			return err;
613 		err = do_recvfile(peerfd, outfd);
614 	}
615 
616 	return err;
617 }
618 
copyfd_io(int infd,int peerfd,int outfd)619 static int copyfd_io(int infd, int peerfd, int outfd)
620 {
621 	int file_size;
622 
623 	switch (cfg_mode) {
624 	case CFG_MODE_POLL:
625 		return copyfd_io_poll(infd, peerfd, outfd);
626 	case CFG_MODE_MMAP:
627 		file_size = get_infd_size(infd);
628 		if (file_size < 0)
629 			return file_size;
630 		return copyfd_io_mmap(infd, peerfd, outfd, file_size);
631 	case CFG_MODE_SENDFILE:
632 		file_size = get_infd_size(infd);
633 		if (file_size < 0)
634 			return file_size;
635 		return copyfd_io_sendfile(infd, peerfd, outfd, file_size);
636 	}
637 
638 	fprintf(stderr, "Invalid mode %d\n", cfg_mode);
639 
640 	die_usage();
641 	return 1;
642 }
643 
check_sockaddr(int pf,struct sockaddr_storage * ss,socklen_t salen)644 static void check_sockaddr(int pf, struct sockaddr_storage *ss,
645 			   socklen_t salen)
646 {
647 	struct sockaddr_in6 *sin6;
648 	struct sockaddr_in *sin;
649 	socklen_t wanted_size = 0;
650 
651 	switch (pf) {
652 	case AF_INET:
653 		wanted_size = sizeof(*sin);
654 		sin = (void *)ss;
655 		if (!sin->sin_port)
656 			fprintf(stderr, "accept: something wrong: ip connection from port 0");
657 		break;
658 	case AF_INET6:
659 		wanted_size = sizeof(*sin6);
660 		sin6 = (void *)ss;
661 		if (!sin6->sin6_port)
662 			fprintf(stderr, "accept: something wrong: ipv6 connection from port 0");
663 		break;
664 	default:
665 		fprintf(stderr, "accept: Unknown pf %d, salen %u\n", pf, salen);
666 		return;
667 	}
668 
669 	if (salen != wanted_size)
670 		fprintf(stderr, "accept: size mismatch, got %d expected %d\n",
671 			(int)salen, wanted_size);
672 
673 	if (ss->ss_family != pf)
674 		fprintf(stderr, "accept: pf mismatch, expect %d, ss_family is %d\n",
675 			(int)ss->ss_family, pf);
676 }
677 
check_getpeername(int fd,struct sockaddr_storage * ss,socklen_t salen)678 static void check_getpeername(int fd, struct sockaddr_storage *ss, socklen_t salen)
679 {
680 	struct sockaddr_storage peerss;
681 	socklen_t peersalen = sizeof(peerss);
682 
683 	if (getpeername(fd, (struct sockaddr *)&peerss, &peersalen) < 0) {
684 		perror("getpeername");
685 		return;
686 	}
687 
688 	if (peersalen != salen) {
689 		fprintf(stderr, "%s: %d vs %d\n", __func__, peersalen, salen);
690 		return;
691 	}
692 
693 	if (memcmp(ss, &peerss, peersalen)) {
694 		char a[INET6_ADDRSTRLEN];
695 		char b[INET6_ADDRSTRLEN];
696 		char c[INET6_ADDRSTRLEN];
697 		char d[INET6_ADDRSTRLEN];
698 
699 		xgetnameinfo((struct sockaddr *)ss, salen,
700 			     a, sizeof(a), b, sizeof(b));
701 
702 		xgetnameinfo((struct sockaddr *)&peerss, peersalen,
703 			     c, sizeof(c), d, sizeof(d));
704 
705 		fprintf(stderr, "%s: memcmp failure: accept %s vs peername %s, %s vs %s salen %d vs %d\n",
706 			__func__, a, c, b, d, peersalen, salen);
707 	}
708 }
709 
check_getpeername_connect(int fd)710 static void check_getpeername_connect(int fd)
711 {
712 	struct sockaddr_storage ss;
713 	socklen_t salen = sizeof(ss);
714 	char a[INET6_ADDRSTRLEN];
715 	char b[INET6_ADDRSTRLEN];
716 
717 	if (getpeername(fd, (struct sockaddr *)&ss, &salen) < 0) {
718 		perror("getpeername");
719 		return;
720 	}
721 
722 	xgetnameinfo((struct sockaddr *)&ss, salen,
723 		     a, sizeof(a), b, sizeof(b));
724 
725 	if (strcmp(cfg_host, a) || strcmp(cfg_port, b))
726 		fprintf(stderr, "%s: %s vs %s, %s vs %s\n", __func__,
727 			cfg_host, a, cfg_port, b);
728 }
729 
maybe_close(int fd)730 static void maybe_close(int fd)
731 {
732 	unsigned int r = rand();
733 
734 	if (!(cfg_join || cfg_remove) && (r & 1))
735 		close(fd);
736 }
737 
main_loop_s(int listensock)738 int main_loop_s(int listensock)
739 {
740 	struct sockaddr_storage ss;
741 	struct pollfd polls;
742 	socklen_t salen;
743 	int remotesock;
744 
745 	polls.fd = listensock;
746 	polls.events = POLLIN;
747 
748 	switch (poll(&polls, 1, poll_timeout)) {
749 	case -1:
750 		perror("poll");
751 		return 1;
752 	case 0:
753 		fprintf(stderr, "%s: timed out\n", __func__);
754 		close(listensock);
755 		return 2;
756 	}
757 
758 	salen = sizeof(ss);
759 	remotesock = accept(listensock, (struct sockaddr *)&ss, &salen);
760 	if (remotesock >= 0) {
761 		maybe_close(listensock);
762 		check_sockaddr(pf, &ss, salen);
763 		check_getpeername(remotesock, &ss, salen);
764 
765 		return copyfd_io(0, remotesock, 1);
766 	}
767 
768 	perror("accept");
769 
770 	return 1;
771 }
772 
init_rng(void)773 static void init_rng(void)
774 {
775 	int fd = open("/dev/urandom", O_RDONLY);
776 	unsigned int foo;
777 
778 	if (fd > 0) {
779 		int ret = read(fd, &foo, sizeof(foo));
780 
781 		if (ret < 0)
782 			srand(fd + foo);
783 		close(fd);
784 	}
785 
786 	srand(foo);
787 }
788 
main_loop(void)789 int main_loop(void)
790 {
791 	int fd;
792 
793 	/* listener is ready. */
794 	fd = sock_connect_mptcp(cfg_host, cfg_port, cfg_sock_proto);
795 	if (fd < 0)
796 		return 2;
797 
798 	check_getpeername_connect(fd);
799 
800 	if (cfg_rcvbuf)
801 		set_rcvbuf(fd, cfg_rcvbuf);
802 	if (cfg_sndbuf)
803 		set_sndbuf(fd, cfg_sndbuf);
804 
805 	return copyfd_io(0, fd, 1);
806 }
807 
parse_proto(const char * proto)808 int parse_proto(const char *proto)
809 {
810 	if (!strcasecmp(proto, "MPTCP"))
811 		return IPPROTO_MPTCP;
812 	if (!strcasecmp(proto, "TCP"))
813 		return IPPROTO_TCP;
814 
815 	fprintf(stderr, "Unknown protocol: %s\n.", proto);
816 	die_usage();
817 
818 	/* silence compiler warning */
819 	return 0;
820 }
821 
parse_mode(const char * mode)822 int parse_mode(const char *mode)
823 {
824 	if (!strcasecmp(mode, "poll"))
825 		return CFG_MODE_POLL;
826 	if (!strcasecmp(mode, "mmap"))
827 		return CFG_MODE_MMAP;
828 	if (!strcasecmp(mode, "sendfile"))
829 		return CFG_MODE_SENDFILE;
830 
831 	fprintf(stderr, "Unknown test mode: %s\n", mode);
832 	fprintf(stderr, "Supported modes are:\n");
833 	fprintf(stderr, "\t\t\"poll\" - interleaved read/write using poll()\n");
834 	fprintf(stderr, "\t\t\"mmap\" - send entire input file (mmap+write), then read response (-l will read input first)\n");
835 	fprintf(stderr, "\t\t\"sendfile\" - send entire input file (sendfile), then read response (-l will read input first)\n");
836 
837 	die_usage();
838 
839 	/* silence compiler warning */
840 	return 0;
841 }
842 
parse_peek(const char * mode)843 int parse_peek(const char *mode)
844 {
845 	if (!strcasecmp(mode, "saveWithPeek"))
846 		return CFG_WITH_PEEK;
847 	if (!strcasecmp(mode, "saveAfterPeek"))
848 		return CFG_AFTER_PEEK;
849 
850 	fprintf(stderr, "Unknown: %s\n", mode);
851 	fprintf(stderr, "Supported MSG_PEEK mode are:\n");
852 	fprintf(stderr,
853 		"\t\t\"saveWithPeek\" - recv data with flags 'MSG_PEEK' and save the peek data into file\n");
854 	fprintf(stderr,
855 		"\t\t\"saveAfterPeek\" - read and save data into file after recv with flags 'MSG_PEEK'\n");
856 
857 	die_usage();
858 
859 	/* silence compiler warning */
860 	return 0;
861 }
862 
parse_int(const char * size)863 static int parse_int(const char *size)
864 {
865 	unsigned long s;
866 
867 	errno = 0;
868 
869 	s = strtoul(size, NULL, 0);
870 
871 	if (errno) {
872 		fprintf(stderr, "Invalid sndbuf size %s (%s)\n",
873 			size, strerror(errno));
874 		die_usage();
875 	}
876 
877 	if (s > INT_MAX) {
878 		fprintf(stderr, "Invalid sndbuf size %s (%s)\n",
879 			size, strerror(ERANGE));
880 		die_usage();
881 	}
882 
883 	return (int)s;
884 }
885 
parse_opts(int argc,char ** argv)886 static void parse_opts(int argc, char **argv)
887 {
888 	int c;
889 
890 	while ((c = getopt(argc, argv, "6jr:lp:s:hut:m:S:R:w:M:P:")) != -1) {
891 		switch (c) {
892 		case 'j':
893 			cfg_join = true;
894 			cfg_mode = CFG_MODE_POLL;
895 			cfg_wait = 400000;
896 			break;
897 		case 'r':
898 			cfg_remove = true;
899 			cfg_mode = CFG_MODE_POLL;
900 			cfg_wait = 400000;
901 			cfg_do_w = atoi(optarg);
902 			if (cfg_do_w <= 0)
903 				cfg_do_w = 50;
904 			break;
905 		case 'l':
906 			listen_mode = true;
907 			break;
908 		case 'p':
909 			cfg_port = optarg;
910 			break;
911 		case 's':
912 			cfg_sock_proto = parse_proto(optarg);
913 			break;
914 		case 'h':
915 			die_usage();
916 			break;
917 		case 'u':
918 			tcpulp_audit = true;
919 			break;
920 		case '6':
921 			pf = AF_INET6;
922 			break;
923 		case 't':
924 			poll_timeout = atoi(optarg) * 1000;
925 			if (poll_timeout <= 0)
926 				poll_timeout = -1;
927 			break;
928 		case 'm':
929 			cfg_mode = parse_mode(optarg);
930 			break;
931 		case 'S':
932 			cfg_sndbuf = parse_int(optarg);
933 			break;
934 		case 'R':
935 			cfg_rcvbuf = parse_int(optarg);
936 			break;
937 		case 'w':
938 			cfg_wait = atoi(optarg)*1000000;
939 			break;
940 		case 'M':
941 			cfg_mark = strtol(optarg, NULL, 0);
942 			break;
943 		case 'P':
944 			cfg_peek = parse_peek(optarg);
945 			break;
946 		}
947 	}
948 
949 	if (optind + 1 != argc)
950 		die_usage();
951 	cfg_host = argv[optind];
952 
953 	if (strchr(cfg_host, ':'))
954 		pf = AF_INET6;
955 }
956 
main(int argc,char * argv[])957 int main(int argc, char *argv[])
958 {
959 	init_rng();
960 
961 	signal(SIGUSR1, handle_signal);
962 	parse_opts(argc, argv);
963 
964 	if (tcpulp_audit)
965 		return sock_test_tcpulp(cfg_host, cfg_port) ? 0 : 1;
966 
967 	if (listen_mode) {
968 		int fd = sock_listen_mptcp(cfg_host, cfg_port);
969 
970 		if (fd < 0)
971 			return 1;
972 
973 		if (cfg_rcvbuf)
974 			set_rcvbuf(fd, cfg_rcvbuf);
975 		if (cfg_sndbuf)
976 			set_sndbuf(fd, cfg_sndbuf);
977 		if (cfg_mark)
978 			set_mark(fd, cfg_mark);
979 
980 		return main_loop_s(fd);
981 	}
982 
983 	return main_loop();
984 }
985