1 #include <sys/types.h>
2 #include <sys/stat.h>
3 #include <stdlib.h>
4 #include <stdio.h>
5 #include <string.h>
6 #include <unistd.h>
7 #include <signal.h>
8 #include <assert.h>
9 #include <sys/time.h>
10 #include <stdarg.h>
11 #include <errno.h>
12 #include <stdint.h>
13 #include <fcntl.h>
14 #include <time.h>
15 #include <openssl/err.h>
16 
17 #include "shared/socket_address.h"
18 #include "inc.h"
19 #include "tcpcrypt_ctl.h"
20 #include "tcpcrypt_divert.h"
21 #include "tcpcrypt.h"
22 #include "tcpcryptd.h"
23 #ifndef __WIN32__
24 #include "priv.h"
25 #endif
26 #include "profile.h"
27 #include "test.h"
28 #include "crypto.h"
29 #include "tcpcrypt_strings.h"
30 #include "config.h"
31 #include "util.h"
32 
/* Number of elements in a true array (never a pointer).  The argument is
 * parenthesized so expressions such as p->p_params expand safely. */
#define ARRAY_SIZE(n)	(sizeof(n) / sizeof((n)[0]))
/* Size of the preallocated timer pool (see alloc_timers()). */
#define MAX_TIMERS 1024
35 
/* Global daemon configuration; defaults and command-line options are stored
 * here by main(). */
struct conf _conf;
/* Active divert driver, set by tcpcryptd() from divert_get(); NULL before. */
struct divert *_divert;
38 
/* A control request whose handler returned EBUSY.  It is queued by
 * backlog_ctl() and retried by backlog_ctl_process(). */
struct backlog_ctl {
	struct backlog_ctl	*bc_next;	/* next queued request */
	struct socket_address	bc_sa;		/* client to send the reply to */
	struct tcpcrypt_ctl	bc_ctl;		/* the request; keep last: backlog_ctl()
						 * stores its tcc_data payload in the
						 * bytes allocated past the struct */
};
44 
/* A timer from the pool allocated by alloc_timers(). */
struct timer {
	struct timeval	t_time;		/* absolute expiry time */
	timer_cb	t_cb;		/* callback, invoked as t_cb(t_arg) */
	void		*t_arg;		/* argument passed to t_cb */
	struct timer	*t_next;	/* next in _state.s_timers (sorted by
					 * expiry) or in the free list */
	struct timer	*t_prev;	/* previous in _state.s_timers */
	int		t_id;		/* index in _state.s_timer_map */
};
53 
/* One network self-test probe (queued by add_test(), stepped by
 * run_network_test()). */
struct network_test {
	int			nt_port;	/* server port to connect to */
	int			nt_proto;	/* TEST_TCP or TEST_CRYPT */
	int			nt_req;		/* index into REQS[] of the request sent */
	int			nt_s;		/* test socket */
	int			nt_state;	/* TEST_STATE_* progress */
	int			nt_err;		/* result code given to test_finish() */
	int			nt_last_state;	/* nt_state when the test finished */
	int			nt_flags;	/* value parsed from the server reply */
	int			nt_crypt;	/* 1 if the TCP_CRYPT_SESSID lookup succeeded */
	time_t			nt_start;	/* time() when connect() was started */
	struct tcpcrypt_ctl	nt_ctl;		/* endpoints used for tcpcryptd_*sockopt() */
	struct network_test	*nt_next;	/* next test in _state.s_network_tests */
};
68 
/* Process-wide daemon state (zero-initialized). */
static struct state {
	struct backlog_ctl	s_backlog_ctl;		/* list head of deferred control requests */
	int			s_ctl;			/* control socket; cleanup() closes it if > 0 */
	struct socket_address	s_ctl_addr;		/* address the control socket is bound to */
	int			s_raw;			/* closed by cleanup() if > 0; NOTE(review):
							 * not assigned in this file */
	struct timer		s_timers;		/* sentinel: t_next is the earliest timer,
							 * t_prev the last one */
	struct timer		*s_timer_map[MAX_TIMERS];	/* every pool timer, indexed by t_id */
	struct timer		s_timer_free;		/* sentinel of the free-timer list */
	struct timeval		s_now;			/* cached current time (see get_time()) */
	int			s_divert;		/* descriptor from the divert driver's open() */
	int			s_time_set;		/* non-zero while s_now is valid */
	packet_hook		s_post_packet_hook;	/* see packet_handler() */
	packet_hook		s_pre_packet_hook;	/* see packet_handler() */
	struct network_test	s_network_tests;	/* sentinel of the queued network tests */
	void			*s_nt_timer;		/* retest timer set by test_results() */
	struct in_addr		s_nt_ip;		/* test server address; INADDR_ANY = unknown */
} _state;
86 
/* Sentinel head of the descriptors watched by do_cycle(); entries start at
 * _fds.fd_next (see add_fd()). */
static struct fd _fds;
88 
/* Entry point of a built-in test selected with -t. */
typedef void (*test_cb)(void);

/* A built-in test: entry point plus a human-readable description. */
struct test {
        test_cb t_cb;
        char    *t_desc;
};

/* Built-in tests, indexed by the -t argument (see do_test() and usage()). */
static struct test _tests[] = {
	{ test_sym_throughput, "Symmetric cipher throughput" },
	{ test_mac_throughput, "Symmetric MAC throughput" },
	{ test_dropper,	       "Packet dropper" },
};
101 
/*
 * Remove any filesystem node at the pathname of @sa so that a fresh bind()
 * can succeed.  Null addresses and addresses without a pathname are ignored.
 * A missing file is not an error; any other unlink() failure is only warned
 * about.
 */
static void ensure_socket_address_unlinked(struct socket_address *sa)
{
	const char *node;

	if (socket_address_is_null(sa))
		return;

	node = socket_address_pathname(sa);
	if (node == NULL)
		return;

	if (unlink(node) == 0 || errno == ENOENT)
		return;

	warn("unlink(%s)", node);
}
116 
cleanup()117 static void cleanup()
118 {
119 	_divert->close();
120 
121 	if (_state.s_ctl > 0)
122 		close(_state.s_ctl);
123 
124 	if (_state.s_raw > 0)
125 		close(_state.s_raw);
126 
127 	profile_end();
128 }
129 
/*
 * SIGINT/SIGTERM handler: end the current output line, release resources and
 * exit with status 0.
 * NOTE(review): stdio, cleanup() and exit() are not async-signal-safe.
 */
static void sig(int num)
{
	(void) num;

	putchar('\n');
	cleanup();
	exit(0);
}
137 
dump_state(void)138 static void dump_state(void)
139 {
140 	struct fd *fd = &_fds;
141 
142 	xprintf(XP_ALWAYS, "==== DUMPING STATE ====\n");
143 
144 	while ((fd = fd->fd_next))
145 		xprintf(XP_ALWAYS, "FD %d state %d\n", fd->fd_fd, fd->fd_state);
146 
147 	xprintf(XP_ALWAYS, "=======================\n");
148 }
149 
/* SIGUSR1 handler: dump the watched fd list.
 * NOTE(review): dump_state() uses stdio, which is not async-signal-safe. */
static void sigusr1(int num)
{
	(void) num;

	dump_state();
}
154 
set_time(struct timeval * tv)155 void set_time(struct timeval *tv)
156 {
157 	_state.s_now	  = *tv;
158 	_state.s_time_set = 1;
159 }
160 
get_time(void)161 static struct timeval *get_time(void)
162 {
163 	if (!_state.s_time_set) {
164 		struct timeval tv;
165 
166 		gettimeofday(&tv, NULL);
167 		set_time(&tv);
168 	}
169 
170 	return &_state.s_now;
171 }
172 
alloc_timers()173 static void alloc_timers()
174 {
175 	int i;
176 	struct timer *t;
177 
178 	for (i = 0; i < MAX_TIMERS; i++) {
179 		t = xmalloc(sizeof(*t));
180 		memset(t, 0, sizeof(*t));
181 		t->t_id = i;
182 		_state.s_timer_map[i] = t;
183 
184 		t->t_next = _state.s_timer_free.t_next;
185 		_state.s_timer_free.t_next = t;
186 	}
187 }
188 
/*
 * Schedule cb(arg) to run @usec microseconds after the cached current time
 * (see get_time()).  Active timers are kept in _state.s_timers, sorted by
 * expiry, earliest first.
 *
 * Returns an opaque handle for clear_timer().  When timers are disabled (-i),
 * nothing is scheduled and a dummy non-NULL handle is returned.  Exits if all
 * MAX_TIMERS timers are in use.
 */
void *add_timer(unsigned int usec, timer_cb cb, void *arg)
{
	struct timer *t, *prev, *cur;
	int sec;

	if (_conf.cf_disable_timers)
		return (void*) 0x666;

	if (!_state.s_timer_map[0])
		alloc_timers();

	t = _state.s_timer_free.t_next;
	/* assert() vanishes under NDEBUG, turning pool exhaustion into a
	 * NULL dereference; fail explicitly instead */
	if (!t)
		errx(1, "add_timer(): all %d timers in use", MAX_TIMERS);
	_state.s_timer_free.t_next = t->t_next;
	t->t_next = NULL;

	t->t_time = *(get_time());
	t->t_time.tv_sec  += usec / (1000 * 1000);
	t->t_time.tv_usec += usec % (1000 * 1000);

	/* carry microsecond overflow into seconds */
	sec = t->t_time.tv_usec / (1000 * 1000);
	if (sec) {
		t->t_time.tv_sec  += sec;
		t->t_time.tv_usec  = t->t_time.tv_usec % (1000 * 1000);
	}

	t->t_cb   = cb;
	t->t_arg  = arg;

	prev = &_state.s_timers;
	cur  = prev->t_next;

	/* find the first timer that does not expire before t */
	while (cur) {
		if (time_diff(&t->t_time, &cur->t_time) >= 0) {
			t->t_next   = cur;
			cur->t_prev = t;
			break;
		}

		prev = cur;
		cur  = cur->t_next;
	}

	prev->t_next = t;
	t->t_prev    = prev;

	/* the sentinel's t_prev tracks the last timer */
	if (!t->t_next)
		_state.s_timers.t_prev = t;

	return t;
}
240 
clear_timer(void * timer)241 void clear_timer(void *timer)
242 {
243 	struct timer *prev = &_state.s_timers;
244 	struct timer *t    = prev->t_next;
245 
246 	if (_conf.cf_disable_timers)
247 		return;
248 
249 	while (t) {
250 		if (t == timer) {
251 			prev->t_next = t->t_next;
252 
253 			t->t_next = _state.s_timer_free.t_next;
254 			_state.s_timer_free.t_next = t;
255 			return;
256 		}
257 
258 		prev = t;
259 		t    = t->t_next;
260 	}
261 
262 	assert(!"Timer not found");
263 }
264 
packet_handler(void * packet,int len,int flags)265 static int packet_handler(void *packet, int len, int flags)
266 {
267 	int rc;
268 
269 	/* XXX implement as pre packet hook */
270 	if (_conf.cf_accept)
271 		return DIVERT_ACCEPT;
272 	else if (_conf.cf_modify)
273 		return DIVERT_MODIFY;
274 
275 	if (_state.s_pre_packet_hook) {
276 		rc = _state.s_pre_packet_hook(-1, packet, len, flags);
277 
278 		if (rc != -1)
279 			return rc;
280 	}
281 
282 	rc = tcpcrypt_packet(packet, len, flags);
283 
284 	if (_state.s_post_packet_hook)
285 		return _state.s_post_packet_hook(rc, packet, len, flags);
286 
287 	return rc;
288 }
289 
/* Install @p as the post-processing hook when @post is non-zero, otherwise as
 * the pre-processing hook (both are used by packet_handler()). */
void set_packet_hook(int post, packet_hook p)
{
	packet_hook *slot = post ? &_state.s_post_packet_hook
				 : &_state.s_pre_packet_hook;

	*slot = p;
}
297 
backlog_ctl(struct tcpcrypt_ctl * c,struct socket_address * sa)298 static void backlog_ctl(struct tcpcrypt_ctl *c, struct socket_address *sa)
299 {
300 	struct backlog_ctl *b;
301 
302 	b = xmalloc(sizeof(*b) + c->tcc_dlen);
303 	memset(b, 0, sizeof(*b));
304 
305 	memcpy(&b->bc_sa, sa, sizeof(*sa));
306 	memcpy(&b->bc_ctl, c, sizeof(*c));
307 	memcpy(b->bc_ctl.tcc_data, c->tcc_data, c->tcc_dlen);
308 
309 	b->bc_next = _state.s_backlog_ctl.bc_next;
310 	_state.s_backlog_ctl.bc_next = b;
311 }
312 
do_handle_ctl(struct tcpcrypt_ctl * c,struct socket_address * sa)313 static int do_handle_ctl(struct tcpcrypt_ctl *c, struct socket_address *sa)
314 {
315 	int l, rc;
316 
317 	if (c->tcc_flags & TCC_SET)
318 		c->tcc_err = tcpcryptd_setsockopt(c, c->tcc_opt, c->tcc_data,
319 					 	  c->tcc_dlen);
320 	else
321 		c->tcc_err = tcpcryptd_getsockopt(c, c->tcc_opt, c->tcc_data,
322 						  &c->tcc_dlen);
323 
324 	/* we can either have client retry, or we queue things up.  The latter
325 	 * is more efficient but more painful to implement.  I'll go for the
326 	 * latter anyway, i'm sure nobody will mind (I'm the one coding after
327 	 * all).
328 	 */
329 	if (c->tcc_err == EBUSY)
330 		return 0;
331 
332 	l = sizeof(*c) + c->tcc_dlen;
333 	rc = sendto(_state.s_ctl, (void*) c, l, 0, &sa->addr.sa, sa->addr_len);
334 
335 	if (rc == -1)
336 		err(1, "sendto()");
337 
338 	if (rc != l)
339 		errx(1, "short write");
340 
341 	return 1;
342 }
343 
backlog_ctl_process(void)344 static void backlog_ctl_process(void)
345 {
346 	struct backlog_ctl *prev = &_state.s_backlog_ctl;
347 	struct backlog_ctl *b = prev->bc_next;
348 
349 	while (b) {
350 		if (do_handle_ctl(&b->bc_ctl, &b->bc_sa)) {
351 			struct backlog_ctl *next = b->bc_next;
352 
353 			prev->bc_next = next;
354 			free(b);
355 			b = next;
356 		} else {
357 			prev = b;
358 			b = b->bc_next;
359 		}
360 	}
361 }
362 
handle_ctl(int ctl)363 static void handle_ctl(int ctl)
364 {
365 	unsigned char buf[4096];
366 	struct tcpcrypt_ctl *c = (struct tcpcrypt_ctl*) buf;
367 	int rc;
368 	struct socket_address sa = SOCKET_ADDRESS_ANY;
369 
370 	rc = recvfrom(ctl, (void*) buf, sizeof(buf), 0, &sa.addr.sa, &sa.addr_len);
371 	if (rc == -1)
372 		err(1, "read(ctl)");
373 
374 	if (rc == 0)
375 		errx(1, "EOF");
376 
377 	if (rc < sizeof(*c)) {
378 		xprintf(XP_ALWAYS, "fsadlfijasldkjf\n");
379 		return;
380 	}
381 
382 	if (c->tcc_dlen + sizeof(*c) != rc) {
383 		xprintf(XP_ALWAYS, "bad len\n");
384 		return;
385 	}
386 
387 	if (!do_handle_ctl(c, &sa))
388 		backlog_ctl(c, &sa);
389 }
390 
dispatch_timers(void)391 static void dispatch_timers(void)
392 {
393 	struct timer *head = &_state.s_timers;
394 	struct timer *t;
395 	struct timer tmp;
396 
397 	while ((t = head->t_next)) {
398 		if (time_diff(&t->t_time, get_time()) < 0)
399 			break;
400 
401 		/* timers can add timers so lets fixup linked list first */
402 		tmp = *t;
403 
404 		clear_timer(t);
405 
406 		tmp.t_cb(tmp.t_arg);
407 	}
408 }
409 
add_test(int port,int proto,int req)410 static void add_test(int port, int proto, int req)
411 {
412 	struct network_test *t = xmalloc(sizeof(*t));
413 	struct network_test *cur = &_state.s_network_tests;
414 
415 	memset(t, 0, sizeof(*t));
416 
417 	t->nt_port  = port;
418 	t->nt_proto = proto;
419 	t->nt_req   = req;
420 
421 	while (cur->nt_next)
422 		cur = cur->nt_next;
423 
424 	cur->nt_next = t;
425 }
426 
test_port(int port)427 static void test_port(int port)
428 {
429 	add_test(port, TEST_TCP, 0);
430 	add_test(port, TEST_TCP, 1);
431 	add_test(port, TEST_CRYPT, 2);
432 }
433 
prepare_ctl(struct network_test * nt)434 static void prepare_ctl(struct network_test *nt)
435 {
436 	struct sockaddr_in s_in;
437 	struct tcpcrypt_ctl *ctl = &nt->nt_ctl;
438 	int s = nt->nt_s;
439 	socklen_t sl = sizeof(s_in);
440 
441 	memset(&s_in, 0, sizeof(s_in));
442 	s_in.sin_family      = AF_INET;
443 	s_in.sin_addr.s_addr = INADDR_ANY;
444 	s_in.sin_port        = htons(0);
445 
446 	if (bind(s, (struct sockaddr*) &s_in, sizeof(s_in)) == -1)
447 		err(1, "bind()");
448 
449 	if (getsockname(s, (struct sockaddr*) &s_in, &sl) == -1)
450 		err(1, "getsockname()");
451 
452 	ctl->tcc_src   = s_in.sin_addr;
453 	ctl->tcc_sport = s_in.sin_port;
454 }
455 
#ifdef __WIN32__
/* Put socket @s into non-blocking mode (Winsock; the result is not checked). */
void set_nonblocking(int s)
{
	u_long enable = 1;

	ioctlsocket(s, FIONBIO, &enable);
}
#else
/* Add O_NONBLOCK to the file status flags of @s; exits on failure. */
void set_nonblocking(int s)
{
	int fl = fcntl(s, F_GETFL, 0);

	if (fl == -1)
		err(1, "fcntl()");

	fl |= O_NONBLOCK;
	if (fcntl(s, F_SETFL, fl) == -1)
		err(1, "fcntl()");
}
#endif
475 
test_connect(struct network_test * t)476 static void test_connect(struct network_test *t)
477 {
478 	int s;
479 	struct sockaddr_in s_in;
480 
481 	if ((s = socket(AF_INET, SOCK_STREAM, 0)) == -1)
482 		err(1, "socket()");
483 
484 	t->nt_s = s;
485 
486 	prepare_ctl(t);
487 
488 	if (t->nt_proto == TEST_TCP) {
489 		int off = 0;
490 
491 		if (tcpcryptd_setsockopt(&t->nt_ctl, TCP_CRYPT_ENABLE, &off,
492 					 sizeof(off)) == -1)
493 			errx(1, "tcpcryptd_setsockopt()");
494 	} else {
495 		int one = 1;
496 		assert(t->nt_proto == TEST_CRYPT);
497 		if (tcpcryptd_setsockopt(&t->nt_ctl, TCP_CRYPT_NOCACHE, &one,
498 					 sizeof(one)) == -1)
499 			errx(1, "tcpcryptd_setsockopt()");
500 	}
501 
502 	set_nonblocking(s);
503 
504 	memset(&s_in, 0, sizeof(s_in));
505 
506 	s_in.sin_family      = AF_INET;
507 	s_in.sin_port        = htons(t->nt_port);
508 	s_in.sin_addr        = _state.s_nt_ip;
509 
510 	if (connect(s, (struct sockaddr*) &s_in, sizeof(s_in)) == -1) {
511 #ifdef __WIN32__
512 		if (WSAGetLastError() != WSAEWOULDBLOCK)
513 #else
514 		if (errno != EINPROGRESS)
515 #endif
516 			err(1, "connect()");
517 	}
518 
519 	t->nt_ctl.tcc_dst   = s_in.sin_addr;
520 	t->nt_ctl.tcc_dport = s_in.sin_port;
521 
522 	t->nt_state = TEST_STATE_CONNECTING;
523 	t->nt_start = time(NULL);
524 }
525 
test_finish(struct network_test * t,int rc)526 static void test_finish(struct network_test *t, int rc)
527 {
528 	t->nt_last_state = t->nt_state;
529 	t->nt_err        = rc;
530 	t->nt_state      = TEST_STATE_DONE;
531 
532 	close(t->nt_s);
533 
534 	printf("Test result: " \
535 	       "port %d crypt %d req %d state %d err %d flags %d\n",
536 	       t->nt_port,
537 	       t->nt_proto == TEST_CRYPT ? 1 : 0,
538 	       t->nt_req,
539 	       t->nt_last_state,
540 	       t->nt_err,
541 	       t->nt_flags);
542 }
543 
test_success(struct network_test * t)544 static void test_success(struct network_test *t)
545 {
546 	t->nt_state = TEST_SUCCESS;
547 	test_finish(t, 0);
548 }
549 
test_connecting(struct network_test * t)550 static void test_connecting(struct network_test *t)
551 {
552 	int s = t->nt_s;
553 	struct timeval tv;
554 	fd_set fds;
555 	int rc;
556 	socklen_t sz = sizeof(rc);
557 	char *buf = NULL;
558 	unsigned char sid[1024];
559 	unsigned int sidlen = sizeof(sid);
560 	struct sockaddr_in s_in;
561 	socklen_t sl = sizeof(s_in);
562 
563 	tv.tv_sec  = 0;
564 	tv.tv_usec = 0;
565 
566 	FD_ZERO(&fds);
567 	FD_SET(s, &fds);
568 
569 	if (select(s + 1, NULL, &fds, NULL, &tv) == -1)
570 		err(1, "select()");
571 
572 	if (!FD_ISSET(s, &fds))
573 		return;
574 
575 	if (getsockopt(s, SOL_SOCKET, SO_ERROR, &rc, &sz) == -1)
576 		err(1, "getsockopt()");
577 
578 	if (rc != 0) {
579 		test_finish(t, rc);
580 		return;
581 	}
582 
583 	if (getsockname(s, (struct sockaddr*) &s_in, &sl) == -1)
584 		err(1, "getsockname()");
585 
586 	t->nt_ctl.tcc_src = s_in.sin_addr;
587 
588 	rc = tcpcryptd_getsockopt(&t->nt_ctl, TCP_CRYPT_SESSID, sid, &sidlen);
589 
590 	if (rc == EBUSY)
591 		return;
592 
593 	t->nt_crypt = rc != -1;
594 
595 	assert(t->nt_req < (sizeof(REQS) / sizeof(*REQS)));
596 	buf = REQS[t->nt_req];
597 
598 	if (send(s, buf, strlen(buf), 0) != strlen(buf))
599 		err(1, "send()");
600 
601 	t->nt_state = TEST_STATE_REQ_SENT;
602 }
603 
test_req_sent(struct network_test * t)604 static void test_req_sent(struct network_test *t)
605 {
606 	int s = t->nt_s;
607 	fd_set fds;
608 	struct timeval tv;
609 	char buf[1024];
610 	int rc;
611 
612 	FD_ZERO(&fds);
613 	FD_SET(s, &fds);
614 
615 	tv.tv_sec  = 0;
616 	tv.tv_usec = 0;
617 
618 	if (select(s + 1, &fds, NULL, NULL, &tv) == -1)
619 		err(1, "select()");
620 
621 	if (!FD_ISSET(s, &fds))
622 		return;
623 
624 	rc = recv(s, buf, sizeof(buf) - 1, 0);
625 	if (rc == -1) {
626 		test_finish(t, errno);
627 		return;
628 	}
629 
630 	if (rc == 0) {
631 		test_finish(t, TEST_ERR_DISCONNECT);
632 		return;
633 	}
634 
635 	buf[rc] = 0;
636 
637 	if (strncmp(buf, TEST_REPLY, strlen(TEST_REPLY)) != 0) {
638 		test_finish(t, TEST_ERR_BADINPUT);
639 		return;
640 	}
641 
642 	t->nt_flags = atoi(&buf[rc - 1]);
643 
644 	if (t->nt_proto == TEST_TCP && t->nt_crypt == 1) {
645 		test_finish(t, TEST_ERR_UNEXPECTED_CRYPT);
646 		return;
647 	}
648 
649 	if (t->nt_proto == TEST_CRYPT && t->nt_crypt != 1) {
650 		test_finish(t, TEST_ERR_NO_CRYPT);
651 		return;
652 	}
653 
654 	test_success(t);
655 }
656 
run_network_test(struct network_test * t)657 static void run_network_test(struct network_test *t)
658 {
659 	if (t->nt_start && (time(NULL) - t->nt_start) > 5) {
660 		test_finish(t, TEST_ERR_TIMEOUT);
661 		return;
662 	}
663 
664 	switch (t->nt_state) {
665 	case TEST_STATE_START:
666 		test_connect(t);
667 		break;
668 
669 	case TEST_STATE_CONNECTING:
670 		test_connecting(t);
671 		break;
672 
673 	case TEST_STATE_REQ_SENT:
674 		test_req_sent(t);
675 		break;
676 	}
677 }
678 
resolve_server(void)679 static int resolve_server(void)
680 {
681 	struct hostent *he = gethostbyname(_conf.cf_test_server);
682 	struct in_addr **addr;
683 
684 	_state.s_nt_ip.s_addr = INADDR_ANY;
685 
686 	if (!he)
687 		return 0;
688 
689 	addr = (struct in_addr**) he->h_addr_list;
690 
691 	if (!addr[0])
692 		return 0;
693 
694 	_state.s_nt_ip = *addr[0];
695 
696 	return 1;
697 }
698 
test_network(void)699 static void test_network(void)
700 {
701 	resolve_server();
702 
703 	if (_state.s_nt_ip.s_addr == INADDR_ANY) {
704 		xprintf(XP_ALWAYS, "Won't test network - can't resolve %s\n",
705 			_conf.cf_test_server);
706 		return;
707 	}
708 
709 	xprintf(XP_ALWAYS, "Testing network via %s\n",
710 		inet_ntoa(_state.s_nt_ip));
711 
712 	test_port(80);
713 	test_port(7777);
714 }
715 
retest_network(void * ignored)716 static void retest_network(void* ignored)
717 {
718 	_conf.cf_disable = 0;
719 	test_network();
720 }
721 
test_results(void)722 static void test_results(void)
723 {
724 	struct network_test *t = _state.s_network_tests.nt_next;
725 	int tot = 0;
726 	int fail = 0;
727 
728 	xprintf(XP_ALWAYS, "Tests done!");
729 
730 	while (t) {
731 		tot++;
732 
733 		if (t->nt_last_state != TEST_SUCCESS) {
734 			fail++;
735 			xprintf(XP_ALWAYS, " %d", tot);
736 		}
737 
738 		t = t->nt_next;
739 	}
740 
741 	if (fail) {
742 		unsigned long mins = 30;
743 		unsigned long timeout = 1000 * 1000 * 60 * mins;
744 
745 		xprintf(XP_ALWAYS, " failed [%d/%d]!\n", fail, tot);
746 
747 		t = _state.s_network_tests.nt_next;
748 		if (t->nt_last_state == TEST_SUCCESS) {
749 			xprintf(XP_ALWAYS,
750 			        "Disabling tcpcrypt for %lu minutes\n", mins);
751 
752 			_conf.cf_disable = 1;
753 			_state.s_nt_timer = add_timer(timeout, retest_network,
754 					  	      NULL);
755 		}
756 	} else {
757 		xprintf(XP_ALWAYS, " All passed\n");
758 		/* XXX retest later? */
759 	}
760 }
761 
run_network_tests(void)762 static int run_network_tests(void)
763 {
764 	struct network_test *t = _state.s_network_tests.nt_next;
765 
766 	while (t) {
767 		if (t->nt_state != TEST_STATE_DONE) {
768 			run_network_test(t);
769 			return 1;
770 		}
771 
772 		t = t->nt_next;
773 	}
774 
775 	t = _state.s_network_tests.nt_next;
776 	if (t) {
777 		test_results();
778 
779 		while (t) {
780 			struct network_test *next = t->nt_next;
781 			free(t);
782 			t = next;
783 		}
784 
785 		_state.s_network_tests.nt_next = NULL;
786 	}
787 
788 	return 0;
789 }
790 
do_cycle(void)791 static void do_cycle(void)
792 {
793 	fd_set rd, wr;
794 	int max = 0;
795 	struct timer *t;
796 	struct timeval tv, *tvp = NULL;
797 	int testing = 0;
798 	struct fd *fd = &_fds;
799 
800 	testing = run_network_tests();
801 
802 	FD_ZERO(&rd);
803 	FD_ZERO(&wr);
804 
805         /* prepare select */
806         while (fd->fd_next) {
807                 struct fd *next = fd->fd_next;
808 
809                 /* unlink dead sockets */
810                 if (next->fd_state == FDS_DEAD) {
811 			fd->fd_next = next->fd_next;
812                         free(next);
813                         continue;
814                 }
815 
816                 fd = next;
817 
818                 switch (fd->fd_state) {
819 		case FDS_IDLE:
820 			continue;
821 
822                 case FDS_WRITE:
823                         FD_SET(fd->fd_fd, &wr);
824                         break;
825 
826                 case FDS_READ:
827                         FD_SET(fd->fd_fd, &rd);
828                         break;
829                 }
830 
831                 max = fd->fd_fd > max ? fd->fd_fd : max;
832         }
833 
834 	t = _state.s_timers.t_next;
835 
836 	if (t) {
837 		int diff = time_diff(get_time(), &t->t_time);
838 
839 		assert(diff > 0);
840 		tv.tv_sec  = diff / (1000 * 1000);
841 		tv.tv_usec = diff % (1000 * 1000);
842 		tvp = &tv;
843 	} else
844 		tvp = NULL;
845 
846 	_state.s_time_set = 0;
847 
848 	if (testing && !tvp) {
849 		tv.tv_sec = 0;
850 		tv.tv_usec = 1000;
851 		tvp = &tv;
852 	}
853 
854 	if (select(max + 1, &rd, &wr, NULL, tvp) == -1) {
855 		if (errno == EINTR)
856 			return;
857 
858 		err(1, "select()");
859 	}
860 
861 	fd = &_fds;
862 
863 	while ((fd = fd->fd_next)) {
864 		if (fd->fd_state == FDS_READ && FD_ISSET(fd->fd_fd, &rd))
865 			fd->fd_cb(fd);
866 
867 		if (fd->fd_state == FDS_WRITE && FD_ISSET(fd->fd_fd, &wr))
868 			fd->fd_cb(fd);
869 	}
870 
871 	dispatch_timers();
872 
873 	if (_divert->cycle)
874 		_divert->cycle();
875 
876 	if (_conf.cf_rdr)
877 		backlog_ctl_process();
878 }
879 
do_test(void)880 static void do_test(void)
881 {
882 	struct test *t;
883 
884 	if (_conf.cf_test < 0
885 	    || _conf.cf_test >= sizeof(_tests) / sizeof(*_tests))
886 		errx(1, "Test %d out of range", _conf.cf_test);
887 
888 	t = &_tests[_conf.cf_test];
889 
890 	printf("Running test %d: %s\n", _conf.cf_test, t->t_desc);
891 	t->t_cb();
892 	printf("Test done\n");
893 }
894 
bind_control_socket(struct socket_address * sa,const char * descr)895 static int bind_control_socket(struct socket_address *sa, const char *descr)
896 {
897 	int r, s;
898 	static const int error_len = 1000;
899 	char error[error_len];
900 	mode_t mask;
901 	const char *path;
902 
903 	r = resolve_socket_address_local(_conf.cf_ctl, sa, error, error_len);
904 	if (r != 0)
905 		errx(1, "interpreting socket address '%s': %s", descr, error);
906 	{
907 		char name[1000];
908 		socket_address_pretty(name, 1000, sa);
909 		xprintf(XP_DEFAULT, "Opening control socket at %s\n", name);
910 	}
911 
912 	if ((s = socket(sa->addr.sa.sa_family, SOCK_DGRAM, 0)) <= 0)
913 		err(1, "socket()");
914 
915 	ensure_socket_address_unlinked(sa);
916 	mask = umask(0);
917 	if (bind(s, &sa->addr.sa, sa->addr_len) != 0)
918 		err(1, "bind()");
919 	umask(mask);
920 
921 	/* in case of old systems where bind() ignores the umask: */
922 	if ((path = socket_address_pathname(sa)) != NULL) {
923 		if (chmod(path, 0777) != 0)
924 			warnx("Setting permissions on control socket");
925 	}
926 
927 	return s;
928 }
929 
_drop_privs(const char * dir,const char * name)930 void _drop_privs(const char *dir, const char *name) {
931 	xprintf(XP_DEFAULT, "Attempting to drop privileges with chroot=%s and user=%s\n",
932 		dir ? dir : "(NONE)", name ? name : "(NONE)");
933 	drop_privs(dir, name);
934 }
935 
/*
 * Register descriptor @f with callback @cb at the head of the main loop's fd
 * list.  The new entry starts in FDS_READ and is returned.
 */
struct fd *add_fd(int f, fd_cb cb)
{
	struct fd *entry = xmalloc(sizeof(*entry));

	memset(entry, 0, sizeof(*entry));
	entry->fd_fd    = f;
	entry->fd_cb    = cb;
	entry->fd_state = FDS_READ;

	entry->fd_next = _fds.fd_next;
	_fds.fd_next   = entry;

	return entry;
}
950 
process_divert(struct fd * fd)951 static void process_divert(struct fd *fd)
952 {
953 	_divert->next_packet(fd->fd_fd);
954 	backlog_ctl_process();
955 }
956 
process_ctl(struct fd * fd)957 static void process_ctl(struct fd *fd)
958 {
959 	handle_ctl(fd->fd_fd);
960 }
961 
tcpcryptd(void)962 void tcpcryptd(void)
963 {
964 	_divert = divert_get();
965 	assert(_divert);
966 
967 	_state.s_divert = _divert->open(_conf.cf_divert, packet_handler);
968 
969 	_state.s_ctl = bind_control_socket(&_state.s_ctl_addr, _conf.cf_ctl);
970 
971 	_drop_privs(_conf.cf_jail_dir, _conf.cf_jail_user);
972 
973 	printf("Running\n");
974 
975 	if (!_conf.cf_disable && !_conf.cf_disable_network_test)
976 		test_network();
977 
978 	add_fd(_state.s_divert, process_divert);
979 	add_fd(_state.s_ctl, process_ctl);
980 
981 	while (1)
982 		do_cycle();
983 }
984 
/* Apply an algorithm preference (setup_tcpcrypt() passes the -C cipher).
 * id 0 means "no preference"; any other id is not implemented and asserts. */
static void do_set_preference(int id, int type)
{
	(void) type;

	if (id != 0)
		assert(!"implement");
}
992 
setup_tcpcrypt(void)993 static void setup_tcpcrypt(void)
994 {
995 	struct cipher_list *c;
996 
997 	/* set cipher preference */
998 	do_set_preference(_conf.cf_cipher, TYPE_SYM);
999 
1000 	/* add ciphers */
1001 	c = crypt_cipher_list();
1002 
1003 	while (c) {
1004 		tcpcrypt_register_cipher(c);
1005 
1006 		c = c->c_next;
1007 	}
1008 
1009 	/* setup */
1010 	tcpcrypt_init();
1011 }
1012 
pwn(void)1013 static void pwn(void)
1014 {
1015 	printf("Initializing...\n");
1016 	setup_tcpcrypt();
1017 
1018 	if (_conf.cf_test != -1)
1019 		do_test();
1020 	else
1021 		tcpcryptd();
1022 }
1023 
xprintf(int level,char * fmt,...)1024 void xprintf(int level, char *fmt, ...)
1025 {
1026 	va_list ap;
1027 
1028 	if (_conf.cf_verbose < level)
1029 		return;
1030 
1031 	va_start(ap, fmt);
1032 	vprintf(fmt, ap);
1033 	va_end(ap);
1034 }
1035 
hexdump(void * x,int len)1036 void hexdump(void *x, int len)
1037 {
1038 	uint8_t *p = x;
1039 	int did = 0;
1040 	int level = XP_ALWAYS;
1041 
1042 	xprintf(level, "Dumping %d bytes\n", len);
1043 	while (len--) {
1044 		xprintf(level, "%.2X ", *p++);
1045 
1046 		if (++did == 16) {
1047 			if (len)
1048 				xprintf(level, "\n");
1049 
1050 			did = 0;
1051 		}
1052 	}
1053 
1054 	xprintf(level, "\n");
1055 }
1056 
/*
 * Report a fatal OpenSSL failure in the style of errx(3).  Prints the
 * formatted message to stdout, followed by ": " and the description of the
 * earliest error in OpenSSL's error queue.  Then exits with status @x.
 */
void errssl(int x, char *fmt, ...)
{
	va_list ap;

	va_start(ap, fmt);
	vprintf(fmt, ap);
	va_end(ap);

	printf(": %s\n", ERR_error_string(ERR_get_error(), NULL));

	/* @x is the exit status, as with err()/errx() */
	exit(x);
}
1068 
add_param(struct params * p,char * optarg)1069 static void add_param(struct params *p, char *optarg)
1070 {
1071 	if (p->p_paramc >= ARRAY_SIZE(p->p_params))
1072 		errx(1, "too many parameters\n");
1073 
1074 	p->p_params[p->p_paramc++] = optarg;
1075 }
1076 
get_param(struct params * p,int idx)1077 static char *get_param(struct params *p, int idx)
1078 {
1079 	if (idx >= p->p_paramc)
1080 		return NULL;
1081 
1082 	return p->p_params[idx];
1083 }
1084 
/*
 * Convert a 64-bit value from big-endian (network) byte order to host order.
 * The bytes of @x as stored in memory are read as a big-endian number, which
 * works on any host.  ntohl() must not be used here: it converts only 32
 * bits and would drop the upper half.
 */
uint64_t xbe64toh(uint64_t x)
{
	unsigned char b[sizeof(x)];
	uint64_t v = 0;
	size_t i;

	memcpy(b, &x, sizeof(b));
	for (i = 0; i < sizeof(b); i++)
		v = (v << 8) | b[i];

	return v;
}

/*
 * Convert a 64-bit value from host to big-endian (network) byte order: the
 * result's in-memory bytes are @x, most significant byte first.  Inverse of
 * xbe64toh().
 */
uint64_t xhtobe64(uint64_t x)
{
	unsigned char b[sizeof(x)];
	uint64_t v;
	size_t i;

	for (i = 0; i < sizeof(b); i++)
		b[i] = (unsigned char) (x >> (8 * (sizeof(b) - 1 - i)));

	memcpy(&v, b, sizeof(v));
	return v;
}
1094 
driver_param(int idx)1095 char *driver_param(int idx)
1096 {
1097 	return get_param(&_conf.cf_divert_params, idx);
1098 }
1099 
test_param(int idx)1100 char *test_param(int idx)
1101 {
1102 	return get_param(&_conf.cf_test_params, idx);
1103 }
1104 
usage(char * prog)1105 static void usage(char *prog)
1106 {
1107 	int i;
1108 
1109 	printf("Usage: %s <opt>\n"
1110 	       "-h\thelp (or --help)\n"
1111 	       "-p\t<divert port> (default: %d)\n"
1112 	       "-v\tverbose\n"
1113 	       "-d\tdisable\n"
1114 	       "-c\tno cache\n"
1115 	       "-a\tdivert accept (NOP)\n"
1116 	       "-m\tdivert modify (NOP)\n"
1117 	       "-u\t<local control socket> (default: " TCPCRYPTD_CONTROL_SOCKET ")\n"
1118 	       "-n\tno crypto\n"
1119 	       "-P\tprofile\n"
1120 	       "-S\tprofile time source (0 TSC, 1 gettimeofday)\n"
1121 	       "-t\t<test>\n"
1122 	       "-T\t<test param>\n"
1123 	       "-D\tdebug\n"
1124 	       "-x\t<divert driver param>\n"
1125 	       "-N\trun as nat / middlebox\n"
1126 	       "-C\t<preferred cipher>\n"
1127 	       "-M\t<preferred MAC>\n"
1128 	       "-r\t<random device>\n"
1129 	       "-R\tRSA client hack\n"
1130 	       "-i\tdisable timers\n"
1131 	       "-f\tdisable network test\n"
1132 	       "-s\t<network test server> (default: " TCPCRYPTD_TEST_SERVER ")\n"
1133 	       "-V\tshow version (or --version)\n"
1134 	       "-U\t<jail username> (default: " TCPCRYPTD_JAIL_USER ")\n"
1135 	       "-J\t<jail directory> (default: " TCPCRYPTD_JAIL_DIR ")\n"
1136 	       "-e\tredirect\n"
1137 	       , prog, TCPCRYPTD_DIVERT_PORT);
1138 
1139 	printf("\nTests:\n");
1140 	for (i = 0; i < sizeof(_tests) / sizeof(*_tests); i++)
1141 		printf("%d) %s\n", i, _tests[i].t_desc);
1142 }
1143 
main(int argc,char * argv[])1144 int main(int argc, char *argv[])
1145 {
1146 	int ch;
1147 
1148 #ifdef __WIN32__
1149 	WSADATA wsadata;
1150 	if (WSAStartup(MAKEWORD(1,1), &wsadata) == SOCKET_ERROR)
1151 		errx(1, "WSAStartup()");
1152 #endif
1153 
1154 	_conf.cf_divert	     	      = TCPCRYPTD_DIVERT_PORT;
1155 	_conf.cf_ctl  	     	      = TCPCRYPTD_CONTROL_SOCKET;
1156 	_conf.cf_test 	     	      = -1;
1157 	_conf.cf_test_server 	      = TCPCRYPTD_TEST_SERVER;
1158 	_conf.cf_jail_dir    	      = TCPCRYPTD_JAIL_DIR;
1159 	_conf.cf_jail_user   	      = TCPCRYPTD_JAIL_USER;
1160 	_conf.cf_disable_network_test = 1;
1161 
1162 	if (argc == 2 && argv[1][0] == '-' && argv[1][1] == '-') {
1163 		if (strcmp(argv[1], "--help") == 0) {
1164 			usage(argv[0]);
1165 			exit(0);
1166 		} else if (strcmp(argv[1], "--version") == 0) {
1167 			printf("tcpcrypt version %s\n", TCPCRYPT_VERSION);
1168 			exit(0);
1169 		} else {
1170 			usage(argv[0]);
1171 			exit(1);
1172 		}
1173 	}
1174 
1175 	while ((ch = getopt(argc, argv, "hp:vdu:camnPt:T:S:Dx:NC:M:r:Rifs:VU:J:e"))
1176 	       != -1) {
1177 		switch (ch) {
1178 		case 'e':
1179 			_conf.cf_rdr = 1;
1180 			break;
1181 
1182 		case 'i':
1183 			_conf.cf_disable_timers = 1;
1184 			break;
1185 
1186 		case 'r':
1187 			_conf.cf_random_path = optarg;
1188 			break;
1189 
1190 		case 'R':
1191 			_conf.cf_rsa_client_hack = 1;
1192 			break;
1193 
1194 		case 'M':
1195 			_conf.cf_mac = atoi(optarg);
1196 			break;
1197 
1198 		case 'C':
1199 			_conf.cf_cipher = atoi(optarg);
1200 			break;
1201 
1202 		case 'N':
1203 			_conf.cf_nat = 1;
1204 			break;
1205 
1206 		case 'D':
1207 			_conf.cf_debug = 1;
1208 			break;
1209 
1210 		case 'S':
1211 			profile_setopt(PROFILE_TIME_SOURCE, atoi(optarg));
1212 			break;
1213 
1214 		case 'x':
1215 			add_param(&_conf.cf_divert_params, optarg);
1216 			break;
1217 
1218 		case 'T':
1219 			add_param(&_conf.cf_test_params, optarg);
1220 			break;
1221 
1222 		case 't':
1223 			_conf.cf_test = atoi(optarg);
1224 			break;
1225 
1226 		case 'P':
1227 			_conf.cf_profile++;
1228 			break;
1229 
1230 		case 'n':
1231 			_conf.cf_dummy = 1;
1232 			break;
1233 
1234 		case 'a':
1235 			_conf.cf_accept = 1;
1236 			break;
1237 
1238 		case 'm':
1239 			_conf.cf_modify = 1;
1240 			break;
1241 
1242 		case 'c':
1243 			_conf.cf_nocache = 1;
1244 			break;
1245 
1246 		case 'u':
1247 			_conf.cf_ctl = optarg;
1248 			break;
1249 
1250 		case 'd':
1251 			_conf.cf_disable = 1;
1252 			break;
1253 
1254 		case 'p':
1255 			_conf.cf_divert = atoi(optarg);
1256 			break;
1257 
1258 		case 'v':
1259 			_conf.cf_verbose++;
1260 			break;
1261 
1262 		case 'V':
1263 			printf("tcpcrypt version %s\n", TCPCRYPT_VERSION);
1264 			exit(0);
1265 
1266 		case 'f':
1267 			_conf.cf_disable_network_test = 1;
1268 			break;
1269 
1270 		case 's':
1271 			_conf.cf_test_server = optarg;
1272 			break;
1273 
1274 		case 'U':
1275 			_conf.cf_jail_user = optarg;
1276 			break;
1277 
1278 		case 'J':
1279 			_conf.cf_jail_dir = optarg;
1280 			break;
1281 
1282 		case 'h':
1283 			usage(argv[0]);
1284 			exit(0);
1285 			break;
1286 
1287 		default:
1288 			usage(argv[0]);
1289 			exit(1);
1290 			break;
1291 		}
1292 	}
1293 
1294 	resolve_server();
1295 
1296 	if (signal(SIGINT, sig) == SIG_ERR)
1297 		err(1, "signal(SIGINT)");
1298 
1299 	if (signal(SIGTERM, sig) == SIG_ERR)
1300 		err(1, "signal(SIGTERM)");
1301 
1302 #ifndef __WIN32__
1303 	if (signal(SIGUSR1, sigusr1) == SIG_ERR)
1304 		err(1, "signal(SIGUSR1)");
1305 
1306 	if (signal(SIGPIPE, SIG_IGN) == SIG_ERR)
1307 		err(1, "signal(SIGPIPE)");
1308 #endif
1309 
1310 	profile_setopt(PROFILE_DISCARD, 3);
1311 	profile_setopt(PROFILE_ENABLE, _conf.cf_profile);
1312 
1313 	if (atexit(dump_state))
1314 		err(1, "atexit()");
1315 
1316 	pwn();
1317 	cleanup();
1318 
1319 	exit(0);
1320 }
1321