1 #include <stdio.h>
2 #include <stdint.h>
3 #include <stdlib.h>
4 #include <string.h>
5 #include <assert.h>
6 #include <errno.h>
7 #include <time.h>
8 #include <fcntl.h>
9 #include <unistd.h>
10 
11 #include "inc.h"
12 #include "util.h"
13 #include "tcpcrypt.h"
14 #include "tcpcrypt_divert.h"
15 #include "tcpcryptd.h"
16 #include "crypto.h"
17 #include "profile.h"
18 #include "checksum.h"
19 #include "test.h"
20 
21 struct conn {
22 	struct sockaddr_in	c_addr[2];
23 	struct tc		*c_tc;
24 	struct conn		*c_next;
25 };
26 
/*
 * Connection table, indexed by conn_hash().  A non-empty bucket starts with a
 * zeroed dummy head node created by add_connection(); real entries hang off
 * its c_next.
 *
 * XXX someone that knows what they're doing code a proper hash table
 */
static struct conn *_connection_map[65536];
29 
/* Singly linked list node used to recycle objects (see get_free/put_free). */
struct freelist {
	void *f_obj;			/* the recycled object */
	struct freelist *f_next;	/* next node, NULL at the end */
};
34 
/*
 * Pending retransmission: the timer handle, the number of attempts so far and
 * the raw packet, stored inline after the header.
 */
struct retransmit {
	void	*r_timer;	/* handle returned by add_timer() */
	int	r_num;		/* retransmissions done so far */
	uint8_t	r_packet[];	/* packet bytes (C99 flexible array member) */
};
40 
/* Entry in a list of configured ciphers (see do_add_ciphers()). */
struct ciphers {
	struct cipher_list *c_cipher;	/* cipher implementation */
	unsigned char c_spec[4];	/* wire encoding of the cipher spec */
	int c_speclen;			/* number of bytes used in c_spec */
	struct ciphers *c_next;		/* next entry, NULL at the end */
};
47 
48 static struct tc		*_sockopts[65536];
49 static struct tc_sess		_sessions;
50 static struct ciphers		_ciphers_pkey;
51 static struct ciphers		_ciphers_sym;
52 static struct freelist		_free_free;
53 static struct freelist		_free_tc;
54 static struct freelist		_free_conn;
55 static struct tc_cipher_spec	_pkey[MAX_CIPHERS];
56 static int			_pkey_len;
57 static struct tc_scipher	_sym[MAX_CIPHERS];
58 static int			_sym_len;
59 
/*
 * Callback for foreach_opt(): gets the option kind, the payload length
 * (without the two header bytes) and a pointer to the payload.  A non-zero
 * return stops the walk.
 */
typedef int (*opt_cb)(struct tc *tc, int tcpop, int len, void *data);

/* Predicate for seqmap_find(): non-zero when entry s matches seq. */
typedef int (*sm_cb)(struct tc_seq *s, uint32_t seq);
62 
get_free(struct freelist * f,unsigned int sz)63 static void *get_free(struct freelist *f, unsigned int sz)
64 {
65 	struct freelist *x = f->f_next;
66 	void *o;
67 
68 	if (x) {
69 		o = x->f_obj;
70 		f->f_next = x->f_next;
71 
72 		if (f != &_free_free) {
73 			x->f_next         = _free_free.f_next;
74 			_free_free.f_next = x;
75 			x->f_obj	  = x;
76 		}
77 	} else {
78 		xprintf(XP_DEBUG, "Gotta malloc %u\n", sz);
79 		o = xmalloc(sz);
80 	}
81 
82 	return o;
83 }
84 
put_free(struct freelist * f,void * obj)85 static void put_free(struct freelist *f, void *obj)
86 {
87 	struct freelist *x = get_free(&_free_free, sizeof(*f));
88 
89 	x->f_obj  = obj;
90 	x->f_next = f->f_next;
91 	f->f_next = x;
92 }
93 
get_tc(void)94 static struct tc *get_tc(void)
95 {
96 	return get_free(&_free_tc, sizeof(struct tc));
97 }
98 
put_tc(struct tc * tc)99 static void put_tc(struct tc *tc)
100 {
101 	put_free(&_free_tc, tc);
102 }
103 
get_connection(void)104 static struct conn *get_connection(void)
105 {
106 	return get_free(&_free_conn, sizeof(struct conn));
107 }
108 
put_connection(struct conn * c)109 static void put_connection(struct conn *c)
110 {
111 	put_free(&_free_conn, c);
112 }
113 
do_add_ciphers(struct ciphers * c,void * spec,int * speclen,int sz,void * specend)114 static void do_add_ciphers(struct ciphers *c, void *spec, int *speclen, int sz,
115 			   void *specend)
116 {
117 	uint8_t *p = (uint8_t*) spec + *speclen;
118 
119 	c = c->c_next;
120 
121 	while (c) {
122 		unsigned char *sp = c->c_spec;
123 
124 		assert(p + sz <= (uint8_t*) specend);
125 
126 		memcpy(p, sp, sz);
127 		p        += sz;
128 		*speclen += sz;
129 
130 		c = c->c_next;
131 	}
132 }
133 
bad_packet(char * msg)134 static int bad_packet(char *msg)
135 {
136 	xprintf(XP_ALWAYS, "%s\n", msg);
137 
138 	return 0;
139 }
140 
tc_init(struct tc * tc)141 static void tc_init(struct tc *tc)
142 {
143 	memset(tc, 0, sizeof(*tc));
144 
145 	tc->tc_state        = _conf.cf_disable ? STATE_DISABLED : STATE_CLOSED;
146 	tc->tc_mtu	    = TC_MTU;
147 	tc->tc_mss_clamp    = 40; /* XXX */
148 	tc->tc_sack_disable = 1;
149 	tc->tc_rto	    = 100 * 1000; /* XXX */
150 	tc->tc_nocache	    = _conf.cf_nocache;
151 
152 	tc->tc_ciphers_pkey     = _pkey;
153 	tc->tc_ciphers_pkey_len = _pkey_len;
154 	tc->tc_ciphers_sym      = _sym;
155 	tc->tc_ciphers_sym_len  = _sym_len;
156 }
157 
158 /* XXX */
tc_reset(struct tc * tc)159 static void tc_reset(struct tc *tc)
160 {
161 	struct conn *c = tc->tc_conn;
162 
163 	assert(c);
164 	tc_init(tc);
165 	tc->tc_conn = c;
166 }
167 
kill_retransmit(struct tc * tc)168 static void kill_retransmit(struct tc *tc)
169 {
170 	if (!tc->tc_retransmit)
171 		return;
172 
173 	clear_timer(tc->tc_retransmit->r_timer);
174 	free(tc->tc_retransmit);
175 	tc->tc_retransmit = NULL;
176 }
177 
crypto_free_keyset(struct tc * tc,struct tc_keyset * ks)178 static void crypto_free_keyset(struct tc *tc, struct tc_keyset *ks)
179 {
180 	if (ks->tc_alg_tx)
181 		crypt_sym_destroy(ks->tc_alg_tx);
182 
183 	if (ks->tc_alg_rx)
184 		crypt_sym_destroy(ks->tc_alg_rx);
185 }
186 
do_kill_rdr(struct tc * tc)187 static void do_kill_rdr(struct tc *tc)
188 {
189 	struct fd *fd = tc->tc_rdr_fd;
190 
191 	tc->tc_state = STATE_DISABLED;
192 
193 	if (fd) {
194 		fd->fd_state = FDS_DEAD;
195 #ifdef __WIN32__
196 		closesocket(fd->fd_fd);
197 #else
198 		close(fd->fd_fd);
199 #endif
200 		fd->fd_fd = -1;
201 		tc->tc_rdr_fd = NULL;
202 	}
203 }
204 
kill_rdr(struct tc * tc)205 static void kill_rdr(struct tc *tc)
206 {
207 	struct tc *peer = tc->tc_rdr_peer;
208 
209 	do_kill_rdr(tc);
210 
211 	if (peer) {
212 		assert(peer->tc_rdr_peer == tc);
213 
214 		/* XXX will still leak conn and tc (if we don't receive other
215 		 * packets) */
216 		do_kill_rdr(peer);
217 	}
218 }
219 
tc_finish(struct tc * tc)220 static void tc_finish(struct tc *tc)
221 {
222 	if (tc->tc_crypt_pub)
223 		crypt_pub_destroy(tc->tc_crypt_pub);
224 
225 	if (tc->tc_crypt_sym)
226 		crypt_sym_destroy(tc->tc_crypt_sym);
227 
228 	crypto_free_keyset(tc, &tc->tc_key_current);
229 	crypto_free_keyset(tc, &tc->tc_key_next);
230 
231 	kill_retransmit(tc);
232 
233 	if (tc->tc_last_ack_timer)
234 		clear_timer(tc->tc_last_ack_timer);
235 
236 	if (tc->tc_sess)
237 		tc->tc_sess->ts_used = 0;
238 
239 	kill_rdr(tc);
240 }
241 
tc_dup(struct tc * tc)242 static struct tc *tc_dup(struct tc *tc)
243 {
244 	struct tc *x = get_tc();
245 
246 	assert(x);
247 
248 	*x = *tc;
249 
250 	assert(!x->tc_crypt);
251 	assert(!x->tc_crypt_ops);
252 
253 	return x;
254 }
255 
do_expand(struct tc * tc,uint8_t tag,struct stuff * out)256 static void do_expand(struct tc *tc, uint8_t tag, struct stuff *out)
257 {
258 	int len = tc->tc_crypt_pub->cp_k_len;
259 
260 	assert(len <= sizeof(out->s_data));
261 
262 	crypt_expand(tc->tc_crypt_pub->cp_hkdf, &tag, sizeof(tag), out->s_data,
263 		     len);
264 
265 	out->s_len = len;
266 }
267 
compute_nextk(struct tc * tc,struct stuff * out)268 static void compute_nextk(struct tc *tc, struct stuff *out)
269 {
270 	do_expand(tc, CONST_NEXTK, out);
271 }
272 
compute_mk(struct tc * tc,struct stuff * out)273 static void compute_mk(struct tc *tc, struct stuff *out)
274 {
275 	int len = tc->tc_crypt_pub->cp_k_len;
276 	unsigned char tag = CONST_REKEY;
277 
278 	assert(len <= sizeof(out->s_data));
279 
280 	crypt_expand(tc->tc_crypt_pub->cp_hkdf, &tag, sizeof(tag), out->s_data,
281 		     len);
282 
283 	out->s_len = len;
284 }
285 
compute_sid(struct tc * tc,struct stuff * out,int v)286 static void compute_sid(struct tc *tc, struct stuff *out, int v)
287 {
288 	do_expand(tc, CONST_SESSID, out);
289 
290 	assert(out->s_len + 1 <= sizeof(out->s_data));
291 	memmove(out->s_data + 1, out->s_data, out->s_len);
292 
293 	assert(tc->tc_cipher_pkey.tcs_algo);
294 
295 	out->s_data[0] = tc->tc_cipher_pkey.tcs_algo | v;
296 	out->s_len++;
297 }
298 
set_expand_key(struct tc * tc,struct stuff * s)299 static void set_expand_key(struct tc *tc, struct stuff *s)
300 {
301 	crypt_set_key(tc->tc_crypt_pub->cp_hkdf, s->s_data, s->s_len);
302 }
303 
session_cache(struct tc * tc)304 static void session_cache(struct tc *tc)
305 {
306 	struct tc_sess *s = tc->tc_sess;
307 
308 	if (tc->tc_nocache)
309 		return;
310 
311 	if (!s) {
312 		s = xmalloc(sizeof(*s));
313 		if (!s)
314 			err(1, "malloc()");
315 
316 		memset(s, 0, sizeof(*s));
317 		s->ts_next	  = _sessions.ts_next;
318 		_sessions.ts_next = s;
319 		tc->tc_sess	  = s;
320 
321 		s->ts_dir	 = tc->tc_dir;
322 		s->ts_role 	 = tc->tc_role;
323 		s->ts_ip   	 = tc->tc_dst_ip;
324 		s->ts_port 	 = tc->tc_dst_port;
325 		s->ts_pub_spec   = tc->tc_cipher_pkey.tcs_algo;
326 		s->ts_pub	 = crypt_new(tc->tc_crypt_pub->cp_ctr);
327 		s->ts_sym	 = crypt_new(tc->tc_crypt_sym->cs_ctr);
328 	}
329 
330 	set_expand_key(tc, &tc->tc_nk);
331 	profile_add(1, "session_cache crypto_mac_set_key");
332 
333 	compute_sid(tc, &s->ts_sid, TC_OPT_VLEN);
334 	profile_add(1, "session_cache compute_sid");
335 
336 	compute_mk(tc, &s->ts_mk);
337 	profile_add(1, "session_cache compute_mk");
338 
339 	compute_nextk(tc, &s->ts_nk);
340 	profile_add(1, "session_cache compute_nk");
341 }
342 
init_algo(struct tc * tc,struct crypt_sym * cs,struct crypt_sym ** algo,struct tc_keys * keys)343 static void init_algo(struct tc *tc, struct crypt_sym *cs,
344 		      struct crypt_sym **algo, struct tc_keys *keys)
345 {
346 	*algo = crypt_new(cs->cs_ctr);
347 
348 	cs = *algo;
349 
350 	assert(keys->tk_prk.s_len >= cs->cs_key_len);
351 
352 	crypt_set_key(cs->cs_cipher, keys->tk_prk.s_data, cs->cs_key_len);
353 }
354 
compute_keys(struct tc * tc,struct tc_keyset * out)355 static void compute_keys(struct tc *tc, struct tc_keyset *out)
356 {
357 	struct crypt_sym **tx, **rx;
358 
359 	set_expand_key(tc, &tc->tc_mk);
360 
361 	profile_add(1, "compute keys mac set key");
362 
363 	do_expand(tc, CONST_KEY_C, &out->tc_client.tk_prk);
364 	do_expand(tc, CONST_KEY_S, &out->tc_server.tk_prk);
365 
366 	profile_add(1, "compute keys calculated keys");
367 
368 	switch (tc->tc_role) {
369 	case ROLE_CLIENT:
370 		tx = &out->tc_alg_tx;
371 		rx = &out->tc_alg_rx;
372 		break;
373 
374 	case ROLE_SERVER:
375 		tx = &out->tc_alg_rx;
376 		rx = &out->tc_alg_tx;
377 		break;
378 
379 	default:
380 		assert(!"Unknown role");
381 		abort();
382 		break;
383 	}
384 
385 	init_algo(tc, tc->tc_crypt_sym, tx, &out->tc_client);
386 	init_algo(tc, tc->tc_crypt_sym, rx, &out->tc_server);
387 	profile_add(1, "initialized algos");
388 }
389 
get_algo_info(struct tc * tc)390 static void get_algo_info(struct tc *tc)
391 {
392 	tc->tc_mac_size = tc->tc_crypt_sym->cs_mac_len;
393 	tc->tc_sym_ivmode = IVMODE_SEQ; /* XXX */
394 }
395 
/* Placeholder: currently does nothing. */
static void scrub_sensitive(struct tc *tc)
{
	(void) tc;
}
399 
copy_stuff(struct stuff * dst,struct stuff * src)400 static void copy_stuff(struct stuff *dst, struct stuff *src)
401 {
402 	memcpy(dst, src, sizeof(*dst));
403 }
404 
session_resume(struct tc * tc)405 static int session_resume(struct tc *tc)
406 {
407 	struct tc_sess *s = tc->tc_sess;
408 
409 	if (!s)
410 		return 0;
411 
412 	copy_stuff(&tc->tc_sid, &s->ts_sid);
413 	copy_stuff(&tc->tc_mk, &s->ts_mk);
414 	copy_stuff(&tc->tc_nk, &s->ts_nk);
415 
416 	tc->tc_role	 	    = s->ts_role;
417 	tc->tc_crypt_sym 	    = crypt_new(s->ts_sym->cs_ctr);
418 	tc->tc_crypt_pub 	    = crypt_new(s->ts_pub->cp_ctr);
419 	tc->tc_cipher_pkey.tcs_algo = s->ts_pub_spec;
420 
421 	return 1;
422 }
423 
enable_encryption(struct tc * tc)424 static void enable_encryption(struct tc *tc)
425 {
426 	profile_add(1, "enable_encryption in");
427 
428 	tc->tc_state   = STATE_ENCRYPTING;
429 	tc->tc_rdr_len = 0;
430 
431 	if (!session_resume(tc)) {
432 		set_expand_key(tc, &tc->tc_ss);
433 
434 		profile_add(1, "enable_encryption mac set key");
435 
436 		compute_sid(tc, &tc->tc_sid, 0);
437 		profile_add(1, "enable_encryption compute SID");
438 
439 		compute_mk(tc, &tc->tc_mk);
440 		profile_add(1, "enable_encryption compute mk");
441 
442 		compute_nextk(tc, &tc->tc_nk);
443 		profile_add(1, "enable_encryption did compute_nextk");
444 	}
445 
446 	compute_keys(tc, &tc->tc_key_current);
447 	profile_add(1, "enable_encryption compute keys");
448 
449 	get_algo_info(tc);
450 
451 	session_cache(tc);
452 	profile_add(1, "enable_encryption session cache");
453 
454 	scrub_sensitive(tc);
455 }
456 
conn_hash(uint16_t src,uint16_t dst)457 static int conn_hash(uint16_t src, uint16_t dst)
458 {
459 	return (src + dst) %
460 		(sizeof(_connection_map) / sizeof(*_connection_map));
461 }
462 
get_head(uint16_t src,uint16_t dst)463 static struct conn *get_head(uint16_t src, uint16_t dst)
464 {
465 	return _connection_map[conn_hash(src, dst)];
466 }
467 
do_lookup_connection_prev(struct sockaddr_in * src,struct sockaddr_in * dst,struct conn ** prev)468 static struct tc *do_lookup_connection_prev(struct sockaddr_in *src,
469 					    struct sockaddr_in *dst,
470 					    struct conn **prev)
471 {
472 	struct conn *head;
473 	struct conn *c;
474 
475 	head = get_head(src->sin_port, dst->sin_port);
476 	if (!head)
477 		return NULL;
478 
479 	c     = head->c_next;
480 	*prev = head;
481 
482 	while (c) {
483 		if (   src->sin_addr.s_addr == c->c_addr[0].sin_addr.s_addr
484 		    && dst->sin_addr.s_addr == c->c_addr[1].sin_addr.s_addr
485 		    && src->sin_port == c->c_addr[0].sin_port
486 		    && dst->sin_port == c->c_addr[1].sin_port)
487 			return c->c_tc;
488 
489 		*prev = c;
490 		c = c->c_next;
491 	}
492 
493 	return NULL;
494 }
495 
lookup_connection_prev(struct ip * ip,struct tcphdr * tcp,int flags,struct conn ** prev)496 static struct tc *lookup_connection_prev(struct ip *ip, struct tcphdr *tcp,
497 				    	 int flags, struct conn **prev)
498 {
499 	struct sockaddr_in addr[2];
500 	int idx = flags & DF_IN ? 1 : 0;
501 
502 	addr[idx].sin_addr.s_addr  = ip->ip_src.s_addr;
503 	addr[idx].sin_port         = tcp->th_sport;
504 	addr[!idx].sin_addr.s_addr = ip->ip_dst.s_addr;
505 	addr[!idx].sin_port        = tcp->th_dport;
506 
507 	return do_lookup_connection_prev(&addr[0], &addr[1], prev);
508 }
509 
/* Like lookup_connection_prev() for callers that don't need the predecessor. */
static struct tc *lookup_connection(struct ip *ip, struct tcphdr *tcp,
				    int flags)
{
	struct conn *unused;
	struct tc *tc;

	tc = lookup_connection_prev(ip, tcp, flags, &unused);

	return tc;
}
517 
sockopt_find_port(int port)518 static struct tc *sockopt_find_port(int port)
519 {
520 	return _sockopts[port];
521 }
522 
sockopt_find(struct tcpcrypt_ctl * ctl)523 static struct tc *sockopt_find(struct tcpcrypt_ctl *ctl)
524 {
525 	struct ip ip;
526 	struct tcphdr tcp;
527 
528 	if (!ctl->tcc_dport)
529 		return sockopt_find_port(ctl->tcc_sport);
530 
531 	/* XXX */
532 	ip.ip_src = ctl->tcc_src;
533 	ip.ip_dst = ctl->tcc_dst;
534 
535 	tcp.th_sport = ctl->tcc_sport;
536 	tcp.th_dport = ctl->tcc_dport;
537 
538 	return lookup_connection(&ip, &tcp, 0);
539 }
540 
sockopt_clear(unsigned short port)541 static void sockopt_clear(unsigned short port)
542 {
543 	_sockopts[port] = NULL;
544 }
545 
get_tcp(struct ip * ip)546 struct tcphdr *get_tcp(struct ip *ip)
547 {
548         return (struct tcphdr*) ((unsigned long) ip + ip->ip_hl * 4);
549 }
550 
do_inject_ip(struct ip * ip)551 static void do_inject_ip(struct ip *ip)
552 {
553 	xprintf(XP_NOISY, "Injecting ");
554 	print_packet(ip, get_tcp(ip), 0, NULL);
555 
556 	_divert->inject(ip, ntohs(ip->ip_len));
557 }
558 
inject_ip(struct ip * ip)559 static void inject_ip(struct ip *ip)
560 {
561 	if (_conf.cf_rdr)
562 		return;
563 
564 	do_inject_ip(ip);
565 }
566 
retransmit(void * a)567 static void retransmit(void *a)
568 {
569 	struct tc *tc = a;
570 	struct ip *ip;
571 
572 	xprintf(XP_DEBUG, "Retransmitting %p\n", tc);
573 
574 	assert(tc->tc_retransmit);
575 
576 	if (tc->tc_retransmit->r_num++ >= 10) {
577 		xprintf(XP_DEFAULT, "Retransmit timeout\n");
578 		tc->tc_tcp_state = TCPSTATE_DEAD; /* XXX remove connection */
579 	}
580 
581 	ip = (struct ip*) tc->tc_retransmit->r_packet;
582 
583 	inject_ip(ip);
584 
585 	/* XXX decay */
586 	tc->tc_retransmit->r_timer = add_timer(tc->tc_rto, retransmit, tc);
587 }
588 
add_connection(struct conn * c)589 static void add_connection(struct conn *c)
590 {
591 	int idx = c->c_addr[0].sin_port;
592 	struct conn *head;
593 
594 	idx = conn_hash(c->c_addr[0].sin_port, c->c_addr[1].sin_port);
595 	if (!_connection_map[idx]) {
596 		_connection_map[idx] = xmalloc(sizeof(*c));
597 		memset(_connection_map[idx], 0, sizeof(*c));
598 	}
599 
600 	head = _connection_map[idx];
601 
602 	c->c_next    = head->c_next;
603 	head->c_next = c;
604 }
605 
do_new_connection(uint32_t saddr,uint16_t sport,uint32_t daddr,uint16_t dport,int in)606 static struct tc *do_new_connection(uint32_t saddr, uint16_t sport,
607 				    uint32_t daddr, uint16_t dport,
608 				    int in)
609 {
610 	struct tc *tc;
611 	struct conn *c;
612 	int idx = in ? 1 : 0;
613 
614 	c = get_connection();
615 	assert(c);
616 	profile_add(2, "alloc connection");
617 
618 	memset(c, 0, sizeof(*c));
619 	c->c_addr[idx].sin_addr.s_addr  = saddr;
620 	c->c_addr[idx].sin_port         = sport;
621 	c->c_addr[!idx].sin_addr.s_addr = daddr;
622 	c->c_addr[!idx].sin_port        = dport;
623 	profile_add(2, "setup connection");
624 
625 	tc = sockopt_find_port(c->c_addr[0].sin_port);
626 	if (!tc) {
627 		tc = get_tc();
628 		assert(tc);
629 
630 		profile_add(2, "TC malloc");
631 
632 		tc_init(tc);
633 
634 		profile_add(2, "TC init");
635 	} else {
636 		/* For servers, we gotta duplicate options on child sockets.
637 		 * For clients, we just steal it.
638 		 */
639 		if (in)
640 			tc = tc_dup(tc);
641 		else
642 			sockopt_clear(c->c_addr[0].sin_port);
643 	}
644 
645 	tc->tc_dst_ip.s_addr = c->c_addr[1].sin_addr.s_addr;
646 	tc->tc_dst_port	     = c->c_addr[1].sin_port;
647 	tc->tc_conn	     = c;
648 
649 	c->c_tc	= tc;
650 
651 	add_connection(c);
652 
653 	return tc;
654 }
655 
new_connection(struct ip * ip,struct tcphdr * tcp,int flags)656 static struct tc *new_connection(struct ip *ip, struct tcphdr *tcp, int flags)
657 {
658 	return do_new_connection(ip->ip_src.s_addr, tcp->th_sport,
659 				 ip->ip_dst.s_addr, tcp->th_dport,
660 				 flags & DF_IN);
661 }
662 
do_remove_connection(struct tc * tc,struct conn * prev)663 static void do_remove_connection(struct tc *tc, struct conn *prev)
664 {
665 	struct conn *item;
666 
667 	assert(tc);
668 	assert(prev);
669 
670 	item = prev->c_next;
671 	assert(item);
672 
673 	tc_finish(tc);
674 	put_tc(tc);
675 
676 	prev->c_next = item->c_next;
677 	put_connection(item);
678 }
679 
/*
 * Remove the connection this packet belongs to.  The connection must exist:
 * do_remove_connection() asserts on both the state and the predecessor.
 */
static void remove_connection(struct ip *ip, struct tcphdr *tcp, int flags)
{
	struct conn *before = NULL;
	struct tc *tc = lookup_connection_prev(ip, tcp, flags, &before);

	do_remove_connection(tc, before);
}
689 
kill_connection(struct tc * tc)690 static void kill_connection(struct tc *tc)
691 {
692 	struct conn *c = tc->tc_conn;
693 	struct conn *prev;
694 	struct tc *found;
695 
696 	assert(c);
697 	found = do_lookup_connection_prev(&c->c_addr[0], &c->c_addr[1], &prev);
698 	assert(found);
699 	assert(found == tc);
700 
701 	do_remove_connection(tc, prev);
702 }
703 
last_ack(void * a)704 static void last_ack(void *a)
705 {
706 	struct tc *tc = a;
707 
708 	tc->tc_last_ack_timer = NULL;
709 	xprintf(XP_NOISY, "Last ack for %p\n");
710 	kill_connection(tc);
711 }
712 
tcp_data(struct tcphdr * tcp)713 static void *tcp_data(struct tcphdr *tcp)
714 {
715 	return (char*) tcp + (tcp->th_off << 2);
716 }
717 
tcp_data_len(struct ip * ip,struct tcphdr * tcp)718 static int tcp_data_len(struct ip *ip, struct tcphdr *tcp)
719 {
720 	int hl = (ip->ip_hl << 2) + (tcp->th_off << 2);
721 
722 	return ntohs(ip->ip_len) - hl;
723 }
724 
find_opt(struct tcphdr * tcp,unsigned char opt)725 static void *find_opt(struct tcphdr *tcp, unsigned char opt)
726 {
727 	unsigned char *p = (unsigned char*) (tcp + 1);
728 	int len = (tcp->th_off << 2) - sizeof(*tcp);
729 	int o, l;
730 
731 	assert(len >= 0);
732 
733 	while (len > 0) {
734 		if (*p == opt) {
735 			if (*(p + 1) > len) {
736 				xprintf(XP_ALWAYS, "fek\n");
737 				return NULL;
738 			}
739 
740 			return p;
741 		}
742 
743 		o = *p++;
744 		len--;
745 
746 		switch (o) {
747 		case TCPOPT_EOL:
748 		case TCPOPT_NOP:
749 			continue;
750 		}
751 
752 		if (!len) {
753 			xprintf(XP_ALWAYS, "fuck\n");
754 			return NULL;
755 		}
756 
757 		l = *p++;
758 		len--;
759 		if (l > (len + 2) || l < 2) {
760 			xprintf(XP_ALWAYS, "fuck2 %d %d\n", l, len);
761 			return NULL;
762 		}
763 
764 		p += l - 2;
765 		len -= l - 2;
766 	}
767 	assert(len == 0);
768 
769 	return NULL;
770 }
771 
/* Recompute the IP header checksum, then the TCP checksum. */
void checksum_packet(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
{
	/* IP first, then TCP */
	checksum_ip(ip);

	checksum_tcp(tc, ip, tcp);
}
777 
set_ip_len(struct ip * ip,unsigned short len)778 static void set_ip_len(struct ip *ip, unsigned short len)
779 {
780 	unsigned short old = ntohs(ip->ip_len);
781 	int diff;
782 	int sum;
783 
784 	ip->ip_len = htons(len);
785 
786 	diff	   = len - old;
787 	sum  	   = ntohs(~ip->ip_sum);
788 	sum 	  += diff;
789 	sum	   = (sum >> 16) + (sum & 0xffff);
790 	sum	  += (sum >> 16);
791 	ip->ip_sum = htons(~sum);
792 }
793 
/*
 * Walk the TCP options and call cb for every option that carries a length
 * byte.  cb receives the option kind, the payload length (option length
 * minus two) and a pointer to the payload.  EOL and NOP bytes are skipped
 * without calling cb.
 *
 * The walk stops when cb returns non-zero or when a malformed option is
 * found (which is logged).
 */
static void foreach_opt(struct tc *tc, struct tcphdr *tcp, opt_cb cb)
{
	unsigned char *p = (unsigned char*) (tcp + 1);
	int len = (tcp->th_off << 2) - sizeof(*tcp);
	int o, l;

	assert(len >= 0);

	while (len > 0) {
		o = *p++;
		len--;

		/* XXX optimize */
		if (o == TCPOPT_EOL || o == TCPOPT_NOP)
			continue;

		if (!len) {
			xprintf(XP_ALWAYS, "fuck\n");
			return;
		}

		/* l counts the kind and length bytes already consumed */
		l = *p++;
		len--;
		if (l < 2 || l > (len + 2)) {
			xprintf(XP_ALWAYS, "fuck2 %d %d\n", l, len);
			return;
		}
		l -= 2;

		if (cb(tc, o, l, p))
			return;

		p   += l;
		len -= l;
	}
	assert(len == 0);
}
836 
do_ops_len(struct tc * tc,int tcpop,int len,void * data)837 static int do_ops_len(struct tc *tc, int tcpop, int len, void *data)
838 {
839 	tc->tc_optlen += len + 2;
840 
841 	return 0;
842 }
843 
tcp_ops_len(struct tc * tc,struct tcphdr * tcp)844 static int tcp_ops_len(struct tc *tc, struct tcphdr *tcp)
845 {
846 	int nops   = 40;
847 	uint8_t *p = (uint8_t*) (tcp + 1);
848 
849 	tc->tc_optlen = 0;
850 
851 	foreach_opt(tc, tcp, do_ops_len);
852 
853 	nops -= tc->tc_optlen;
854 	p    += tc->tc_optlen;
855 
856 	assert(nops >= 0);
857 
858 	while (nops--) {
859 		if (*p != TCPOPT_NOP && *p != TCPOPT_EOL)
860 			return (tcp->th_off << 2) - 20;
861 
862 		p++;
863 	}
864 
865 	return tc->tc_optlen;
866 }
867 
tcp_opts_alloc(struct tc * tc,struct ip * ip,struct tcphdr * tcp,int len)868 static void *tcp_opts_alloc(struct tc *tc, struct ip *ip, struct tcphdr *tcp,
869 			    int len)
870 {
871 	int opslen = (tcp->th_off << 2) + len;
872 	int pad = opslen % 4;
873 	char *p;
874 	int dlen = ntohs(ip->ip_len) - (ip->ip_hl << 2) - (tcp->th_off << 2);
875 	int ol = (tcp->th_off << 2) - sizeof(*tcp);
876 
877 	assert(len);
878 
879 	/* find space in tail if full of nops */
880 	if (ol == 40) {
881 		ol = tcp_ops_len(tc, tcp);
882 		assert(ol <= 40);
883 
884 		if (40 - ol >= len)
885 			return (uint8_t*) (tcp + 1) + ol;
886 	}
887 
888 	if (pad)
889 		len += 4 - pad;
890 
891 	if (ntohs(ip->ip_len) + len > tc->tc_mtu)
892 		return NULL;
893 
894 	p = (char*) tcp + (tcp->th_off << 2);
895 	memmove(p + len, p, dlen);
896 	memset(p, 0, len);
897 
898 	assert(((tcp->th_off << 2) + len) <= 60);
899 
900 	set_ip_len(ip, ntohs(ip->ip_len) + len);
901 	tcp->th_off += len >> 2;
902 
903 	return p;
904 }
905 
session_find_host(struct tc * tc,struct in_addr * in,int port)906 static struct tc_sess *session_find_host(struct tc *tc, struct in_addr *in,
907 					 int port)
908 {
909 	struct tc_sess *s = _sessions.ts_next;
910 
911 	while (s) {
912 		/* we're liberal - lets only check host */
913 		if (!s->ts_used
914 		    && (s->ts_dir == tc->tc_dir)
915 		    && (s->ts_ip.s_addr == in->s_addr))
916 			return s;
917 
918 		s = s->ts_next;
919 	}
920 
921 	return NULL;
922 }
923 
is_eno(int tcpop,void * data,int len)924 static int is_eno(int tcpop, void *data, int len)
925 {
926 	uint16_t *exid = data;
927 
928 	if (tcpop != TCPOPT_EXP)
929 		return 0;
930 
931 	if (len < sizeof(*exid))
932 		return 0;
933 
934 	if (ntohs(*exid) != EXID_ENO)
935 		return 0;
936 
937 	return 1;
938 }
939 
/*
 * If the option is an ENO option, advance *data past the two-byte experiment
 * id, shrink *len accordingly and return 1.  Otherwise return 0 and leave
 * both untouched.
 */
static int get_eno(int tcpop, void **data, int *len)
{
	unsigned char *payload = *data;

	if (!is_eno(tcpop, payload, *len))
		return 0;

	assert(*len >= 2);

	*data = payload + 2;
	*len -= 2;

	return 1;
}
952 
do_set_eno_transcript(struct tc * tc,int tcpop,int len,void * data)953 static int do_set_eno_transcript(struct tc *tc, int tcpop, int len, void *data)
954 {
955 	uint8_t *p = &tc->tc_eno[tc->tc_eno_len];
956 
957 	if (!is_eno(tcpop, data, len))
958 		return 0;
959 
960 	assert(len + 2 + tc->tc_eno_len < sizeof(tc->tc_eno));
961 
962 	*p++ = TCPOPT_EXP;
963 	*p++ = len + 2;
964 
965 	memcpy(p, data, len);
966 
967 	tc->tc_eno_len += 2 + len;
968 
969 	return 0;
970 }
971 
set_eno(struct tcpopt_eno * eno,int len)972 static void set_eno(struct tcpopt_eno *eno, int len)
973 {
974 	eno->toe_kind = TCPOPT_EXP;
975 	eno->toe_len  = len;
976 	eno->toe_exid = htons(0x454E);
977 }
978 
set_eno_transcript(struct tc * tc,struct tcphdr * tcp)979 static void set_eno_transcript(struct tc *tc, struct tcphdr *tcp)
980 {
981 	struct tcpopt_eno *eno;
982 
983 	foreach_opt(tc, tcp, do_set_eno_transcript);
984 
985 	assert(tc->tc_eno_len + sizeof(*eno) < sizeof(tc->tc_eno));
986 
987 	eno = (struct tcpopt_eno*) &tc->tc_eno[tc->tc_eno_len];
988 	set_eno(eno, sizeof(*eno));
989 
990 	tc->tc_eno_len += sizeof(*eno);
991 }
992 
send_rst(struct tc * tc)993 static void send_rst(struct tc *tc)
994 {
995         struct ip *ip = (struct ip*) tc->tc_rdr_buf;
996         struct tcphdr *tcp = (struct tcphdr*) get_tcp(ip);
997         struct in_addr addr;
998         int port;
999 
1000         addr.s_addr = ip->ip_src.s_addr;
1001         ip->ip_src.s_addr = ip->ip_dst.s_addr;
1002         ip->ip_dst.s_addr = addr.s_addr;
1003 
1004         port = tcp->th_sport;
1005         tcp->th_sport = tcp->th_dport;
1006         tcp->th_dport = port;
1007 
1008         tcp->th_flags = TH_RST | TH_ACK;
1009         tcp->th_ack   = htonl(ntohl(tcp->th_seq) + 1);
1010         tcp->th_seq   = htonl(0);
1011 
1012 	checksum_packet(tc, ip, tcp);
1013 
1014 	xprintf(XP_ALWAYS, "Sending RST\n");
1015 
1016         do_inject_ip(ip);
1017 }
1018 
rdr_check_connect(struct tc * tc)1019 static void rdr_check_connect(struct tc *tc)
1020 {
1021         int e;
1022         socklen_t len = sizeof(e);
1023 	struct fd *fd = tc->tc_rdr_fd;
1024         struct ip *ip = (struct ip*) tc->tc_rdr_buf;
1025 
1026         if (getsockopt(fd->fd_fd, SOL_SOCKET, SO_ERROR, &e, &len) == -1) {
1027                 perror("getsockopt()");
1028 		kill_rdr(tc);
1029                 return;
1030         }
1031 
1032         if (e != 0) {
1033 #ifdef __WIN32__
1034 		if (e == WSAECONNREFUSED)
1035 #else
1036                 if (e == ECONNREFUSED)
1037 #endif
1038                         send_rst(tc);
1039 
1040 		kill_rdr(tc);
1041                 return;
1042         }
1043 
1044 	xprintf(XP_NOISY, "Connected %p %s\n",
1045 		tc, tc->tc_rdr_inbound ?  "inbound" : "");
1046 
1047 	tc->tc_rdr_connected = 1;
1048 	fd->fd_state = FDS_IDLE;
1049 
1050 	if (tc->tc_rdr_inbound) {
1051                 /* we need to manually redirect... */
1052                 struct tcphdr *tcp = get_tcp(ip);
1053 
1054                 ip->ip_dst.s_addr = inet_addr("127.0.0.1");
1055                 tcp->th_dport = htons(REDIRECT_PORT);
1056                 checksum_packet(tc, ip, tcp);
1057 	}
1058 
1059 	/* inject the local SYN so that user connects to proxy */
1060 	if (!tc->tc_rdr_peer->tc_rdr_drop_sa)
1061 		do_inject_ip(ip);
1062 }
1063 
do_output_closed(struct tc * tc,struct ip * ip,struct tcphdr * tcp)1064 static int do_output_closed(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
1065 {
1066 	struct tc_sess *ts = tc->tc_sess;
1067 	struct tcpopt_eno *eno;
1068 	struct tc_sid *sopt;
1069 	int len;
1070 	uint8_t *p;
1071 
1072 	tc->tc_dir = DIR_OUT;
1073 
1074 	if (tcp->th_flags != TH_SYN)
1075 		return DIVERT_ACCEPT;
1076 
1077 	if (!ts && !tc->tc_nocache)
1078 		ts = session_find_host(tc, &ip->ip_dst, tcp->th_dport);
1079 
1080 	len = sizeof(*eno) + tc->tc_ciphers_pkey_len;
1081 
1082 	if (tc->tc_app_support)
1083 		len += 1;
1084 
1085 	if (ts)
1086 		len += sizeof(*sopt);
1087 
1088 	eno = tcp_opts_alloc(tc, ip, tcp, len);
1089 	if (!eno) {
1090 		xprintf(XP_ALWAYS, "No space for hello\n");
1091 		tc->tc_state = STATE_DISABLED;
1092 
1093 		/* XXX try without session resumption */
1094 
1095 		return DIVERT_ACCEPT;
1096 	}
1097 
1098 	set_eno(eno, len);
1099 
1100 	memcpy(eno->toe_opts, tc->tc_ciphers_pkey, tc->tc_ciphers_pkey_len);
1101 
1102 	p = eno->toe_opts + tc->tc_ciphers_pkey_len;
1103 
1104 	if (tc->tc_app_support)
1105 		*p++ = tc->tc_app_support << 1;
1106 
1107 	tc->tc_state = STATE_HELLO_SENT;
1108 
1109 	if (!ts) {
1110 		if (!_conf.cf_nocache)
1111 			xprintf(XP_DEBUG, "Can't find session for host\n");
1112 	} else {
1113 		/* session caching */
1114 		sopt = (struct tc_sid*) p;
1115 
1116 		assert(ts->ts_sid.s_len >= sizeof(*sopt));
1117 		memcpy(sopt, &ts->ts_sid.s_data, sizeof(*sopt));
1118 
1119 		tc->tc_state = STATE_NEXTK1_SENT;
1120 		assert(!ts->ts_used || ts == tc->tc_sess);
1121 		tc->tc_sess  = ts;
1122 		ts->ts_used  = 1;
1123 	}
1124 
1125 	tc->tc_eno_len = 0;
1126 	set_eno_transcript(tc, tcp);
1127 
1128 	return DIVERT_MODIFY;
1129 }
1130 
do_output_hello_rcvd(struct tc * tc,struct ip * ip,struct tcphdr * tcp)1131 static int do_output_hello_rcvd(struct tc *tc, struct ip *ip,
1132 				struct tcphdr *tcp)
1133 {
1134 	struct tcpopt_eno *eno;
1135 	int len;
1136 	int app_support = tc->tc_app_support & 1;
1137 
1138 	len = sizeof(*eno) + sizeof(tc->tc_cipher_pkey);
1139 
1140 	if (app_support)
1141 		len++;
1142 
1143 	eno = tcp_opts_alloc(tc, ip, tcp, len);
1144 	if (!eno) {
1145 		xprintf(XP_ALWAYS, "No space for ENO\n");
1146 		tc->tc_state = STATE_DISABLED;
1147 
1148 		return DIVERT_ACCEPT;
1149 	}
1150 
1151 	set_eno(eno, len);
1152 
1153 	memcpy(eno->toe_opts, &tc->tc_cipher_pkey, sizeof(tc->tc_cipher_pkey));
1154 
1155 	if (app_support)
1156 		eno->toe_opts[sizeof(tc->tc_cipher_pkey)] = app_support << 1;
1157 
1158 	/* don't set on retransmit.  XXX check if same */
1159 	if (tc->tc_state != STATE_PKCONF_SENT)
1160 		set_eno_transcript(tc, tcp);
1161 
1162 	tc->tc_state = STATE_PKCONF_SENT;
1163 
1164 	return DIVERT_MODIFY;
1165 }
1166 
seqmap_find_start(struct tc_seq * s,uint32_t seq)1167 static int seqmap_find_start(struct tc_seq *s, uint32_t seq)
1168 {
1169 	return s->sm_start == seq;
1170 }
1171 
seqmap_find_end(struct tc_seq * s,uint32_t seq)1172 static int seqmap_find_end(struct tc_seq *s, uint32_t seq)
1173 {
1174 	return s->sm_end == seq;
1175 }
1176 
1177 /* kernel -> internet */
seqmap_find_ack_out(struct tc_seq * s,uint32_t ack)1178 static int seqmap_find_ack_out(struct tc_seq *s, uint32_t ack)
1179 {
1180 	return (s->sm_end - s->sm_off) == ack;
1181 }
1182 
1183 /* internet -> kernel */
seqmap_find_ack_in(struct tc_seq * s,uint32_t ack)1184 static int seqmap_find_ack_in(struct tc_seq *s, uint32_t ack)
1185 {
1186 	return (s->sm_end + s->sm_off) == ack;
1187 }
1188 
/*
 * Search the circular sequence map backwards from the newest entry (sm_idx)
 * for an entry accepted by cb.  The search ends at the first all-zero entry
 * or after one full lap.  Returns the entry, or NULL when none matches.
 */
static struct tc_seq *seqmap_find(struct tc_seqmap *sm, uint32_t seq, sm_cb cb)
{
	int pos = sm->sm_idx;

	for (;;) {
		struct tc_seq *ent = &sm->sm_seq[pos];

		/* unused slot: nothing older to look at */
		if (!ent->sm_start && !ent->sm_end && !ent->sm_off)
			return NULL;

		if (cb(ent, seq))
			return ent;

		/* step back, wrapping around at slot 0 */
		pos = (pos == 0) ? MAX_SEQMAP - 1 : pos - 1;

		if (pos == sm->sm_idx)
			break;
	}

	return NULL;
}
1210 
/*
 * Offset recorded in the seqmap entry selected by cb for seq, or 0 when no
 * entry matches.
 */
static uint32_t get_seq_off(struct tc *tc, uint32_t seq,
			    struct tc_seqmap *seqmap, sm_cb cb)
{
	struct tc_seq *hit = seqmap_find(seqmap, seq, cb);

	/* XXX 0 is also returned when nothing was found */
	return hit ? hit->sm_off : 0;
}
1221 
add_seq(struct tc * tc,struct ip * ip,struct tcphdr * tcp,int len,struct tc_seqmap * seqmap)1222 static void add_seq(struct tc *tc, struct ip *ip, struct tcphdr *tcp, int len,
1223 		    struct tc_seqmap *seqmap)
1224 {
1225 	uint32_t dlen = tcp_data_len(ip, tcp);
1226 	uint32_t seq  = ntohl(tcp->th_seq);
1227 	uint32_t off  = len;
1228 	struct tc_seq *s, *rtr;
1229 
1230 	/* find cumulative offset until now, based on last packet */
1231 	s = seqmap_find(seqmap, seq, seqmap_find_end);
1232 	if (!s) {
1233 		/* can't find last packet... but it's ok if we just started */
1234 		s = &seqmap->sm_seq[seqmap->sm_idx];
1235 
1236 		if (seqmap->sm_idx != 0
1237 		    || s->sm_start != 0 || s->sm_end != 0 || s->sm_off != 0) {
1238 			xprintf(XP_ALWAYS, "Damn - can't find seq %u\n", seq);
1239 			return;
1240 		}
1241 	}
1242 
1243 	/* Check if it's a retransmit.
1244 	 * XXX be more efficient
1245 	 */
1246 	rtr = seqmap_find(seqmap, seq, seqmap_find_start);
1247 	if (rtr) {
1248 		if (rtr->sm_end != (seq + dlen)) {
1249 			xprintf(XP_ALWAYS, "Damn - retransmitted diff size\n");
1250 			return;
1251 		}
1252 
1253 		/* retransmit */
1254 		return;
1255 	}
1256 
1257 	off += s->sm_off;
1258 
1259 	/* add an entry for this packet */
1260 	seqmap->sm_idx = (seqmap->sm_idx + 1) % MAX_SEQMAP;
1261 	s = &seqmap->sm_seq[seqmap->sm_idx];
1262 
1263 	s->sm_start = seq;
1264 	s->sm_end   = seq + dlen;
1265 	s->sm_off   = off;
1266 }
1267 
1268 /*
1269  * 1.  Record an entry for how much padding we're adding for this packet.
1270  * 2.  Fix up the sequence number for this packet.
1271  */
fixup_seq_add(struct tc * tc,struct ip * ip,struct tcphdr * tcp,int len,int in)1272 static void fixup_seq_add(struct tc *tc, struct ip *ip, struct tcphdr *tcp,
1273 			  int len, int in)
1274 {
1275 	uint32_t ack, seq;
1276 
1277 	if (_conf.cf_rdr)
1278 		return;
1279 
1280 	if (in) {
1281 		if (len)
1282 			add_seq(tc, ip, tcp, len, &tc->tc_rseqm);
1283 
1284 		ack  = ntohl(tcp->th_ack) - tc->tc_seq_off;
1285 		ack -= get_seq_off(tc, ack, &tc->tc_seqm, seqmap_find_ack_in);
1286 
1287 		tcp->th_ack = htonl(ack);
1288 
1289 		seq  = ntohl(tcp->th_seq);
1290 		seq -= get_seq_off(tc, seq, &tc->tc_rseqm, seqmap_find_end);
1291 		seq -= tc->tc_rseq_off;
1292 
1293 		tcp->th_seq = htonl(seq);
1294 	} else {
1295 		if (len)
1296 			add_seq(tc, ip, tcp, len, &tc->tc_seqm);
1297 
1298 		seq  = ntohl(tcp->th_seq);
1299 		seq += get_seq_off(tc, seq, &tc->tc_seqm, seqmap_find_end);
1300 		seq += tc->tc_seq_off;
1301 
1302 		tcp->th_seq = htonl(seq);
1303 
1304 		ack  = ntohl(tcp->th_ack) + tc->tc_rseq_off;
1305 		ack += get_seq_off(tc, ack, &tc->tc_rseqm, seqmap_find_ack_out);
1306 
1307 		tcp->th_ack = htonl(ack);
1308 	}
1309 
1310 	return;
1311 }
1312 
data_alloc(struct tc * tc,struct ip * ip,struct tcphdr * tcp,int len,int retx)1313 static void *data_alloc(struct tc *tc, struct ip *ip, struct tcphdr *tcp,
1314 			int len, int retx)
1315 {
1316 	int totlen = ntohs(ip->ip_len);
1317 	int hl     = (ip->ip_hl << 2) + (tcp->th_off << 2);
1318 	void *p;
1319 
1320 	if (_conf.cf_rdr) {
1321 		assert(len < sizeof(tc->tc_rdr_buf));
1322 		tc->tc_rdr_len = len;
1323 
1324 		return tc->tc_rdr_buf;
1325 	}
1326 
1327 	assert(totlen == hl);
1328 	p = (char*) tcp + (tcp->th_off << 2);
1329 
1330 	totlen += len;
1331 	assert(totlen <= 1500);
1332 	set_ip_len(ip, totlen);
1333 
1334 	if (!retx)
1335 		tc->tc_seq_off = len;
1336 
1337 	return p;
1338 }
1339 
/*
 * Fill @len bytes at @p with the low eight bits of successive rand() results.
 *
 * NOTE(review): rand() is not a cryptographic generator, yet this helper is
 * what generate_nonce() uses - confirm whether that is acceptable.
 */
static void do_random(void *p, int len)
{
	uint8_t *out = p;

	for (; len--; out++)
		*out = rand() & 0xff;
}
1347 
generate_nonce(struct tc * tc,int len)1348 static void generate_nonce(struct tc *tc, int len)
1349 {
1350 	profile_add(1, "generated nonce in");
1351 
1352 	assert(tc->tc_nonce_len == 0);
1353 
1354 	tc->tc_nonce_len = len;
1355 
1356 	do_random(tc->tc_nonce, tc->tc_nonce_len);
1357 
1358 	profile_add(1, "generated nonce out");
1359 }
1360 
add_eno(struct tc * tc,struct ip * ip,struct tcphdr * tcp)1361 static int add_eno(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
1362 {
1363 	struct tcpopt_eno *eno;
1364 	int len = sizeof(*eno);
1365 
1366 	eno = tcp_opts_alloc(tc, ip, tcp, len);
1367 	if (!eno) {
1368 		xprintf(XP_ALWAYS, "No space for ENO\n");
1369 		tc->tc_state = STATE_DISABLED;
1370 		return -1;
1371 	}
1372 
1373 	set_eno(eno, len);
1374 
1375 	return 0;
1376 }
1377 
do_output_pkconf_rcvd(struct tc * tc,struct ip * ip,struct tcphdr * tcp,int retx)1378 static int do_output_pkconf_rcvd(struct tc *tc, struct ip *ip,
1379 				 struct tcphdr *tcp, int retx)
1380 {
1381 	int len;
1382 	uint16_t klen;
1383 	struct tc_init1 *init1;
1384 	void *key;
1385 	uint8_t *p;
1386 
1387 	/* Add the minimal ENO option to indicate support */
1388 	if (add_eno(tc, ip, tcp) == -1)
1389 		return DIVERT_ACCEPT;
1390 
1391 	if (!retx)
1392 		generate_nonce(tc, tc->tc_crypt_pub->cp_n_c);
1393 
1394 	klen = crypt_get_key(tc->tc_crypt_pub->cp_pub, &key);
1395 	len  = sizeof(*init1)
1396 	       + tc->tc_ciphers_sym_len
1397 	       + tc->tc_nonce_len
1398 	       + klen;
1399 
1400 	init1 = data_alloc(tc, ip, tcp, len, retx);
1401 
1402 	init1->i1_magic    = htonl(TC_INIT1);
1403 	init1->i1_len      = htonl(len);
1404 	init1->i1_nciphers = tc->tc_ciphers_sym_len;
1405 
1406 	p = init1->i1_data;
1407 
1408 	memcpy(p, tc->tc_ciphers_sym, tc->tc_ciphers_sym_len);
1409 	p += tc->tc_ciphers_sym_len;
1410 
1411 	memcpy(p, tc->tc_nonce, tc->tc_nonce_len);
1412 	p += tc->tc_nonce_len;
1413 
1414 	memcpy(p, key, klen);
1415 	p += klen;
1416 
1417 	tc->tc_state = STATE_INIT1_SENT;
1418 	tc->tc_role  = ROLE_CLIENT;
1419 
1420 	assert(len <= sizeof(tc->tc_init1));
1421 
1422 	memcpy(tc->tc_init1, init1, len);
1423 	tc->tc_init1_len = len;
1424 
1425 	tc->tc_isn = ntohl(tcp->th_seq) + len;
1426 
1427 	return DIVERT_MODIFY;
1428 }
1429 
do_output_init1_rcvd(struct tc * tc,struct ip * ip,struct tcphdr * tcp)1430 static int do_output_init1_rcvd(struct tc *tc, struct ip *ip,
1431 				struct tcphdr *tcp)
1432 {
1433 	return DIVERT_ACCEPT;
1434 }
1435 
is_init(struct ip * ip,struct tcphdr * tcp,int init)1436 static int is_init(struct ip *ip, struct tcphdr *tcp, int init)
1437 {
1438 	struct tc_init1 *i1 = tcp_data(tcp);
1439 	int dlen = tcp_data_len(ip, tcp);
1440 
1441 	if (dlen < sizeof(*i1))
1442 		return 0;
1443 
1444 	if (ntohl(i1->i1_magic) != init)
1445 		return 0;
1446 
1447 	return 1;
1448 }
1449 
do_output_init2_sent(struct tc * tc,struct ip * ip,struct tcphdr * tcp)1450 static int do_output_init2_sent(struct tc *tc, struct ip *ip,
1451 				struct tcphdr *tcp)
1452 {
1453 	/* we generated this packet */
1454 	int is_init2 = is_init(ip, tcp, TC_INIT2);
1455 
1456 	/* kernel is getting pissed off and is resending SYN ack (because we're
1457 	 * delaying his connect setup)
1458 	 */
1459 	if (!is_init2) {
1460 		/* we could piggy back / retx init2 */
1461 
1462 		assert(tcp_data_len(ip, tcp) == 0);
1463 		assert(tcp->th_flags == (TH_SYN | TH_ACK));
1464 		assert(tc->tc_retransmit);
1465 
1466 		/* XXX */
1467 		ip  = (struct ip*) tc->tc_retransmit->r_packet;
1468 		tcp = (struct tcphdr*) (ip + 1);
1469 		assert(is_init(ip, tcp, TC_INIT2));
1470 
1471 		return DIVERT_DROP;
1472 	} else {
1473 		/* Let the ACK of INIT2 enable encryption.  Less efficient when
1474 		 * servers send first because we wait for that ACK to open up
1475 		 * window and let kernel send packets.
1476 		 *
1477 		 * Otherwise, be careful not to encrypt retransmits.
1478 		 */
1479 #if 0
1480 		enable_encryption(tc);
1481 #endif
1482 	}
1483 
1484 	return DIVERT_ACCEPT;
1485 }
1486 
get_iv(struct tc * tc,struct ip * ip,struct tcphdr * tcp,int enc)1487 static void *get_iv(struct tc *tc, struct ip *ip, struct tcphdr *tcp, int enc)
1488 {
1489 	static uint64_t seq;
1490 	uint64_t isn = enc ? tc->tc_isn : tc->tc_isn_peer;
1491 	void *iv = NULL;
1492 
1493 	/* XXX byte order */
1494 
1495 	if (_conf.cf_rdr) {
1496 		seq = enc ? tc->tc_rdr_tx : tc->tc_rdr_rx;
1497 
1498 		return &seq;
1499 	}
1500 
1501 	switch (tc->tc_sym_ivmode) {
1502 	case IVMODE_CRYPT:
1503 		assert(!"codeme");
1504 		break;
1505 
1506 	case IVMODE_SEQ:
1507 		/* XXX WRAP */
1508 		seq = htonl(tcp->th_seq) - isn;
1509 		iv = &seq;
1510 		break;
1511 
1512 	case IVMODE_NONE:
1513 		break;
1514 
1515 	default:
1516 		assert(!"sdfsfd");
1517 		break;
1518 	}
1519 
1520 	return iv;
1521 }
1522 
add_data(struct tc * tc,struct ip * ip,struct tcphdr * tcp,int head,int tail)1523 static int add_data(struct tc *tc, struct ip *ip, struct tcphdr *tcp,
1524 		    int head, int tail)
1525 {
1526 	int thlen   = tcp->th_off * 4;
1527 	int datalen = tcp_data_len(ip, tcp);
1528 	int totlen = (ip->ip_hl * 4) + thlen + head + datalen + tail;
1529 	uint8_t *data = tcp_data(tcp);
1530 
1531 	/* extend packet
1532          * We assume we clamped the MSS
1533          */
1534 	if (totlen >= 1500) {
1535 		xprintf(XP_DEBUG, "Damn... sending large packet %d\n", totlen);
1536 		return -1;
1537 	}
1538 
1539 	set_ip_len(ip, totlen);
1540 
1541 	/* move data forward */
1542 	memmove(data + head, data, datalen);
1543 
1544 	return 0;
1545 }
1546 
encrypt_and_mac(struct tc * tc,struct ip * ip,struct tcphdr * tcp)1547 static int encrypt_and_mac(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
1548 {
1549 	uint8_t *data = tcp_data(tcp);
1550 	int dlen = tcp_data_len(ip, tcp);
1551 	void *iv = NULL;
1552 	struct crypt *c = tc->tc_key_active->tc_alg_tx->cs_cipher;
1553 	int head;
1554 	struct tc_record *record;
1555 	int maclen = tc->tc_mac_size + tc->tc_mac_ivlen;
1556 	struct tc_flags *flags;
1557 	uint8_t *mac;
1558 
1559 	if (!dlen) {
1560 		fixup_seq_add(tc, ip, tcp, 0, 0);
1561 		return 0;
1562 	}
1563 
1564 	/* TLV + flags */
1565 	head = sizeof(*record) + 1;
1566 
1567 	if (tcp->th_flags & TH_URG)
1568 		head += 2;
1569 
1570 	/* XXX should check if add_data fails first */
1571 	fixup_seq_add(tc, ip, tcp, head + maclen, 0);
1572 
1573 	if (add_data(tc, ip, tcp, head, maclen))
1574 		return -1;
1575 
1576 	iv = get_iv(tc, ip, tcp, 1);
1577 
1578 	/* Prepare TLV */
1579 	record = tcp_data(tcp);
1580 	record->tr_control = 0;
1581 	record->tr_len     = htons(tcp_data_len(ip, tcp) - sizeof(*record));
1582 
1583 	/* Prepare flags */
1584 	flags = (struct tc_flags *) record->tr_data;
1585 	flags->tf_flags = 0;
1586 	flags->tf_flags |= tcp->th_flags & TH_FIN ? TCF_FIN : 0;
1587 	flags->tf_flags |= tcp->th_flags & TH_URG ? TCF_URG : 0;
1588 
1589 	if (flags->tf_flags & TCF_URG)
1590 		flags->tf_urp[0] = tcp->th_urp;
1591 
1592 	mac = data + tcp_data_len(ip, tcp) - maclen;
1593 
1594 	c->c_aead_encrypt(c, iv, record, sizeof(*record),
1595 			  data + sizeof(*record), dlen + head - sizeof(*record),
1596 			  mac);
1597 
1598 	profile_add(1, "do_output post sym encrypt and mac");
1599 
1600 	return 0;
1601 }
1602 
connected(struct tc * tc)1603 static int connected(struct tc *tc)
1604 {
1605 	return tc->tc_state == STATE_ENCRYPTING
1606 	       || tc->tc_state == STATE_REKEY_SENT
1607 	       || tc->tc_state == STATE_REKEY_RCVD;
1608 }
1609 
do_rekey(struct tc * tc)1610 static void do_rekey(struct tc *tc)
1611 {
1612 	assert(!tc->tc_key_next.tc_alg_rx);
1613 
1614 	tc->tc_keygen++;
1615 
1616 	assert(!"implement");
1617 //	crypto_mac_set_key(tc, tc->tc_mk.s_data, tc->tc_mk.s_len);
1618 
1619 	compute_mk(tc, &tc->tc_mk);
1620 	compute_keys(tc, &tc->tc_key_next);
1621 
1622 	xprintf(XP_DEFAULT, "Rekeying, keygen %d [%p]\n", tc->tc_keygen, tc);
1623 }
1624 
rekey_complete(struct tc * tc)1625 static int rekey_complete(struct tc *tc)
1626 {
1627 	if (tc->tc_keygenrx != tc->tc_keygen) {
1628 		assert((uint8_t)(tc->tc_keygenrx + 1) == tc->tc_keygen);
1629 
1630 		return 0;
1631 	}
1632 
1633 	if (tc->tc_keygentx != tc->tc_keygen) {
1634 		assert((uint8_t)(tc->tc_keygentx + 1) == tc->tc_keygen);
1635 
1636 		return 0;
1637 	}
1638 
1639 	assert(tc->tc_key_current.tc_alg_tx);
1640 	assert(tc->tc_key_next.tc_alg_tx);
1641 
1642 	crypto_free_keyset(tc, &tc->tc_key_current);
1643 	memcpy(&tc->tc_key_current, &tc->tc_key_next,
1644 	       sizeof(tc->tc_key_current));
1645 	memset(&tc->tc_key_next, 0, sizeof(tc->tc_key_next));
1646 
1647 	tc->tc_state = STATE_ENCRYPTING;
1648 
1649 	xprintf(XP_DEBUG, "Rekey complete %d [%p]\n", tc->tc_keygen, tc);
1650 
1651 	return 1;
1652 }
1653 
do_output_encrypting(struct tc * tc,struct ip * ip,struct tcphdr * tcp)1654 static int do_output_encrypting(struct tc *tc, struct ip *ip,
1655 				struct tcphdr *tcp)
1656 {
1657 	if (tcp->th_flags == (TH_SYN | TH_ACK)) {
1658 		/* XXX I assume we just sent ACK to dude but he didn't get it
1659 		 * yet
1660 		 */
1661 		return DIVERT_DROP;
1662 	}
1663 
1664 	/* We're retransmitting INIT2 */
1665 	if (tc->tc_retransmit) {
1666 		/* XXX */
1667 		ip  = (struct ip*) tc->tc_retransmit->r_packet;
1668 		tcp = (struct tcphdr*) (ip + 1);
1669 		assert(is_init(ip, tcp, TC_INIT2));
1670 
1671 		return DIVERT_ACCEPT;
1672 	}
1673 
1674 	assert(!(tcp->th_flags & TH_SYN));
1675 
1676 	tc->tc_key_active = &tc->tc_key_current;
1677 
1678 	profile_add(1, "do_output pre sym encrypt");
1679 	if (encrypt_and_mac(tc, ip, tcp)) {
1680 		/* hopefully pmtu disc works */
1681 		xprintf(XP_ALWAYS, "No space for MAC - dropping\n");
1682 
1683 		return DIVERT_DROP;
1684 	}
1685 
1686 	return DIVERT_MODIFY;
1687 }
1688 
sack_disable(struct tc * tc,struct tcphdr * tcp)1689 static int sack_disable(struct tc *tc, struct tcphdr *tcp)
1690 {
1691 	struct {
1692 		uint8_t	kind;
1693 		uint8_t len;
1694 	} *sack;
1695 
1696 	sack = find_opt(tcp, TCPOPT_SACK_PERMITTED);
1697 	if (!sack)
1698 		return DIVERT_ACCEPT;
1699 
1700 	memset(sack, TCPOPT_NOP, sizeof(*sack));
1701 
1702 	return DIVERT_MODIFY;
1703 }
1704 
do_tcp_output(struct tc * tc,struct ip * ip,struct tcphdr * tcp)1705 static int do_tcp_output(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
1706 {
1707 	int rc = DIVERT_ACCEPT;
1708 
1709 	if (tcp->th_flags & TH_SYN)
1710 		tc->tc_isn = ntohl(tcp->th_seq) + 1;
1711 
1712 	if (tcp->th_flags == TH_SYN) {
1713 		if (tc->tc_tcp_state == TCPSTATE_LASTACK) {
1714 			tc_finish(tc);
1715 			tc_reset(tc);
1716 		}
1717 
1718 		rc = sack_disable(tc, tcp);
1719 	}
1720 
1721 	if (tcp->th_flags & TH_FIN) {
1722 		switch (tc->tc_tcp_state) {
1723 		case TCPSTATE_FIN1_RCVD:
1724 			tc->tc_tcp_state = TCPSTATE_FIN2_SENT;
1725 			break;
1726 
1727 		case TCPSTATE_FIN2_SENT:
1728 			break;
1729 
1730 		default:
1731 			tc->tc_tcp_state = TCPSTATE_FIN1_SENT;
1732 		}
1733 
1734 		return rc;
1735 	}
1736 
1737 	if (tcp->th_flags & TH_RST) {
1738 		tc->tc_tcp_state = TCPSTATE_DEAD;
1739 		return rc;
1740 	}
1741 
1742 	if (!(tcp->th_flags & TH_ACK))
1743 		return rc;
1744 
1745 	switch (tc->tc_tcp_state) {
1746 	case TCPSTATE_FIN2_RCVD:
1747 		tc->tc_tcp_state = TCPSTATE_LASTACK;
1748 		if (!tc->tc_last_ack_timer)
1749 			tc->tc_last_ack_timer = add_timer(10 * 1000 * 1000,
1750 							  last_ack, tc);
1751 		else
1752 			xprintf(XP_DEFAULT, "uarning\n");
1753 		break;
1754 	}
1755 
1756 	return rc;
1757 }
1758 
do_output_nextk1_rcvd(struct tc * tc,struct ip * ip,struct tcphdr * tcp)1759 static int do_output_nextk1_rcvd(struct tc *tc, struct ip *ip,
1760 				 struct tcphdr *tcp)
1761 {
1762 	struct tcpopt_eno *eno;
1763 	int len;
1764 	int i = 0;
1765 
1766 	if (!tc->tc_sess)
1767 		return do_output_hello_rcvd(tc, ip, tcp);
1768 
1769 	len = sizeof(*eno) + 1;
1770 
1771 	if (tc->tc_app_support)
1772 		len += 1;
1773 
1774 	eno = tcp_opts_alloc(tc, ip, tcp, len);
1775 	if (!eno) {
1776 		xprintf(XP_ALWAYS, "No space for NEXTK2\n");
1777 		tc->tc_state = STATE_DISABLED;
1778 		return DIVERT_ACCEPT;
1779 	}
1780 
1781 	set_eno(eno, len);
1782 
1783 	if (tc->tc_app_support)
1784 		eno->toe_opts[i++] = tc->tc_app_support << 1;
1785 
1786 	eno->toe_opts[i++] = tc->tc_sess->ts_pub_spec | TC_OPT_VLEN;
1787 
1788 	tc->tc_state = STATE_NEXTK2_SENT;
1789 
1790 	return DIVERT_MODIFY;
1791 }
1792 
do_output(struct tc * tc,struct ip * ip,struct tcphdr * tcp)1793 static int do_output(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
1794 {
1795 	int rc = DIVERT_ACCEPT;
1796 	int tcp_rc;
1797 
1798 	tcp_rc = do_tcp_output(tc, ip, tcp);
1799 
1800 	/* an RST half way through the handshake */
1801 	if (tc->tc_tcp_state == TCPSTATE_DEAD
1802 	    && !connected(tc))
1803 		return tcp_rc;
1804 
1805 	switch (tc->tc_state) {
1806 	case STATE_HELLO_SENT:
1807 	case STATE_NEXTK1_SENT:
1808 		/* syn re-TX.  fallthrough */
1809 		assert(tcp->th_flags & TH_SYN);
1810 	case STATE_CLOSED:
1811 		rc = do_output_closed(tc, ip, tcp);
1812 		break;
1813 
1814 	case STATE_PKCONF_SENT:
1815 		/* reTX of syn ack, or ACK (role switch) */
1816 	case STATE_HELLO_RCVD:
1817 		rc = do_output_hello_rcvd(tc, ip, tcp);
1818 		break;
1819 
1820 	case STATE_NEXTK2_SENT:
1821 		/* syn ack rtx */
1822 		assert(tc->tc_sess);
1823 		assert(tcp->th_flags == (TH_SYN | TH_ACK));
1824 	case STATE_NEXTK1_RCVD:
1825 		rc = do_output_nextk1_rcvd(tc, ip, tcp);
1826 		break;
1827 
1828 	case STATE_PKCONF_RCVD:
1829 		rc = do_output_pkconf_rcvd(tc, ip, tcp, 0);
1830 		break;
1831 
1832 	case STATE_INIT1_RCVD:
1833 		rc = do_output_init1_rcvd(tc, ip, tcp);
1834 		break;
1835 
1836 	case STATE_INIT1_SENT:
1837 		if (!is_init(ip, tcp, TC_INIT1))
1838 			rc = do_output_pkconf_rcvd(tc, ip, tcp, 1);
1839 		break;
1840 
1841 	case STATE_INIT2_SENT:
1842 		rc = do_output_init2_sent(tc, ip, tcp);
1843 		break;
1844 
1845 	case STATE_ENCRYPTING:
1846 	case STATE_REKEY_SENT:
1847 	case STATE_REKEY_RCVD:
1848 		rc = do_output_encrypting(tc, ip, tcp);
1849 		break;
1850 
1851 	case STATE_DISABLED:
1852 		rc = DIVERT_ACCEPT;
1853 		break;
1854 
1855 	default:
1856 		xprintf(XP_ALWAYS, "Unknown state %d\n", tc->tc_state);
1857 		abort();
1858 	}
1859 
1860 	if (rc == DIVERT_ACCEPT)
1861 		return tcp_rc;
1862 
1863 	return rc;
1864 }
1865 
session_find(struct tc * tc,struct tc_sid * sid)1866 static struct tc_sess *session_find(struct tc *tc, struct tc_sid *sid)
1867 {
1868 	struct tc_sess *s = _sessions.ts_next;
1869 
1870 	while (s) {
1871 		if (tc->tc_dir == s->ts_dir
1872 		    && memcmp(sid, s->ts_sid.s_data, sizeof(*sid)) == 0)
1873 			return s;
1874 
1875 		s = s->ts_next;
1876 	}
1877 
1878 	return NULL;
1879 }
1880 
do_clamp_mss(struct tc * tc,uint16_t * mss)1881 static int do_clamp_mss(struct tc *tc, uint16_t *mss)
1882 {
1883 	int len;
1884 
1885 	len = ntohs(*mss) - tc->tc_mss_clamp;
1886 	assert(len > 0);
1887 
1888 	*mss = htons(len);
1889 
1890 	xprintf(XP_NOISY, "Clamping MSS to %d\n", len);
1891 
1892 	return DIVERT_MODIFY;
1893 }
1894 
negotiate_cipher(struct tc * tc,struct tc_cipher_spec * a,int an)1895 static int negotiate_cipher(struct tc *tc, struct tc_cipher_spec *a, int an)
1896 {
1897 	struct tc_cipher_spec *b = tc->tc_ciphers_pkey;
1898 	int bn = tc->tc_ciphers_pkey_len / sizeof(*tc->tc_ciphers_pkey);
1899 	struct tc_cipher_spec *out = &tc->tc_cipher_pkey;
1900 
1901 	tc->tc_pub_cipher_list_len = an * sizeof(*a);
1902 	memcpy(tc->tc_pub_cipher_list, a, tc->tc_pub_cipher_list_len);
1903 
1904 	while (an--) {
1905 		while (bn--) {
1906 			if (a->tcs_algo == b->tcs_algo) {
1907 				out->tcs_algo = a->tcs_algo;
1908 				return 1;
1909 			}
1910 
1911 			b++;
1912 		}
1913 
1914 		a++;
1915 	}
1916 
1917 	return 0;
1918 }
1919 
init_pkey(struct tc * tc)1920 static void init_pkey(struct tc *tc)
1921 {
1922 	struct ciphers *c = _ciphers_pkey.c_next;
1923 	struct tc_cipher_spec *s;
1924 
1925 	assert(tc->tc_cipher_pkey.tcs_algo);
1926 
1927 	while (c) {
1928 		s = (struct tc_cipher_spec*) c->c_spec;
1929 
1930 		if (s->tcs_algo == tc->tc_cipher_pkey.tcs_algo) {
1931 			tc->tc_crypt_pub = crypt_new(c->c_cipher->c_ctr);
1932 			return;
1933 		}
1934 
1935 		c = c->c_next;
1936 	}
1937 
1938 	assert(!"Can't init cipher");
1939 }
1940 
check_app_support(struct tc * tc,uint8_t * data,int len)1941 static void check_app_support(struct tc *tc, uint8_t *data, int len)
1942 {
1943 	while (len--) {
1944 		/* general option */
1945 		if ((*data >> 4) == 0) {
1946 			/* application aware bit */
1947 			if (*data & 2)
1948 				tc->tc_app_support |= 2;
1949 		}
1950 
1951 		data++;
1952 	}
1953 }
1954 
can_session_resume(struct tc * tc,uint8_t * data,int len)1955 static int can_session_resume(struct tc *tc, uint8_t *data, int len)
1956 {
1957 	int i;
1958 	struct tc_sid *sid = NULL;
1959 
1960 	for (i = 0; i <= (len - (int) sizeof(*sid)); i++) {
1961 		/* XXX should check spec / other opts of var length */
1962 		if (data[i] & TC_OPT_VLEN) {
1963 			sid = (struct tc_sid*) &data[i];
1964 
1965 			if ((tc->tc_sess = session_find(tc, sid)))
1966 				break;
1967 		}
1968 	}
1969 
1970 	profile_add(2, "found session");
1971 
1972 	if (!tc->tc_sess)
1973 		return 0;
1974 
1975 	tc->tc_state = STATE_NEXTK1_RCVD;
1976 
1977 	return 1;
1978 }
1979 
input_closed_eno(struct tc * tc,uint8_t * data,int len)1980 static void input_closed_eno(struct tc *tc, uint8_t *data, int len)
1981 {
1982 	struct tc_cipher_spec *cipher = (struct tc_cipher_spec*) data;
1983 
1984 	check_app_support(tc, data, len);
1985 
1986 	if (can_session_resume(tc, data, len))
1987 		return;
1988 
1989 	if (!negotiate_cipher(tc, cipher, len)) {
1990 		xprintf(XP_ALWAYS, "No cipher\n");
1991 		tc->tc_state = STATE_DISABLED;
1992 		return;
1993 	}
1994 
1995 	init_pkey(tc);
1996 
1997 	tc->tc_state = STATE_HELLO_RCVD;
1998 }
1999 
opt_input_closed(struct tc * tc,int tcpop,int len,void * data)2000 static int opt_input_closed(struct tc *tc, int tcpop, int len, void *data)
2001 {
2002 	uint8_t *p;
2003 
2004 	profile_add(2, "opt_input_closed in");
2005 
2006 	if (get_eno(tcpop, &data, &len))
2007 		input_closed_eno(tc, data, len);
2008 
2009 	switch (tcpop) {
2010 	case TCPOPT_SACK_PERMITTED:
2011 		p     = data;
2012 		p[-2] = TCPOPT_NOP;
2013 		p[-1] = TCPOPT_NOP;
2014 		tc->tc_verdict = DIVERT_MODIFY;
2015 		break;
2016 
2017 	case TCPOPT_MAXSEG:
2018 		if (do_clamp_mss(tc, data) == DIVERT_MODIFY)
2019 			tc->tc_verdict = DIVERT_MODIFY;
2020 
2021 		tc->tc_mss_clamp = -1;
2022 		break;
2023 	}
2024 
2025 	profile_add(2, "opt_input_closed out");
2026 
2027 	return 0;
2028 }
2029 
do_input_closed(struct tc * tc,struct ip * ip,struct tcphdr * tcp)2030 static int do_input_closed(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
2031 {
2032 	tc->tc_dir = DIR_IN;
2033 
2034 	if (tcp->th_flags != TH_SYN)
2035 		return DIVERT_ACCEPT;
2036 
2037 	tc->tc_verdict = DIVERT_ACCEPT;
2038 	tc->tc_state   = STATE_DISABLED;
2039 
2040 	profile_add(1, "do_input_closed pre option parse");
2041 	foreach_opt(tc, tcp, opt_input_closed);
2042 	profile_add(1, "do_input_closed options parsed");
2043 
2044 	tc->tc_eno_len = 0;
2045 	set_eno_transcript(tc, tcp);
2046 
2047 	return tc->tc_verdict;
2048 }
2049 
make_reply(void * buf,struct ip * ip,struct tcphdr * tcp)2050 static void make_reply(void *buf, struct ip *ip, struct tcphdr *tcp)
2051 {
2052 	struct ip *ip2 = buf;
2053 	struct tcphdr *tcp2;
2054 	int dlen = ntohs(ip->ip_len) - (ip->ip_hl << 2) - (tcp->th_off << 2);
2055 
2056 	ip2->ip_v   = 4;
2057 	ip2->ip_hl  = sizeof(*ip2) >> 2;
2058 	ip2->ip_tos = 0;
2059 	ip2->ip_len = htons(sizeof(*ip2) + sizeof(*tcp2));
2060 	ip2->ip_id  = 0;
2061 	ip2->ip_off = 0;
2062 	ip2->ip_ttl = 128;
2063 	ip2->ip_p   = IPPROTO_TCP;
2064 	ip2->ip_sum = 0;
2065 	ip2->ip_src = ip->ip_dst;
2066 	ip2->ip_dst = ip->ip_src;
2067 
2068 	tcp2 = (struct tcphdr*) (ip2 + 1);
2069 	tcp2->th_sport = tcp->th_dport;
2070 	tcp2->th_dport = tcp->th_sport;
2071 	tcp2->th_seq   = tcp->th_ack;
2072 	tcp2->th_ack   = htonl(ntohl(tcp->th_seq) + dlen);
2073 	tcp2->th_x2    = 0;
2074 	tcp2->th_off   = sizeof(*tcp2) >> 2;
2075 	tcp2->th_flags = TH_ACK;
2076 	tcp2->th_win   = tcp->th_win;
2077 	tcp2->th_sum   = 0;
2078 	tcp2->th_urp   = 0;
2079 }
2080 
alloc_retransmit(struct tc * tc)2081 static void *alloc_retransmit(struct tc *tc)
2082 {
2083 	struct retransmit *r;
2084 	int len;
2085 
2086 	if (_conf.cf_rdr)
2087 		return &tc->tc_rdr_buf[512]; /* XXX */
2088 
2089 	assert(!tc->tc_retransmit);
2090 
2091 	len = sizeof(*r) + tc->tc_mtu;
2092 	r = xmalloc(len);
2093 	memset(r, 0, len);
2094 
2095 	r->r_timer = add_timer(tc->tc_rto, retransmit, tc);
2096 
2097 	tc->tc_retransmit = r;
2098 
2099 	return r->r_packet;
2100 }
2101 
find_eno(struct tcphdr * tcp)2102 static struct tcpopt_eno *find_eno(struct tcphdr *tcp)
2103 {
2104 	struct tcpopt_eno *eno = find_opt(tcp, TCPOPT_EXP);
2105 
2106 	if (!eno)
2107 		return NULL;
2108 
2109 	assert(eno->toe_len >= 2);
2110 
2111 	if (is_eno(eno->toe_kind, (unsigned char*) eno + 2, eno->toe_len - 2))
2112 		return eno;
2113 
2114 	return NULL;
2115 }
2116 
do_input_hello_sent(struct tc * tc,struct ip * ip,struct tcphdr * tcp)2117 static int do_input_hello_sent(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
2118 {
2119 	struct tc_cipher_spec *cipher;
2120 	struct tcpopt_eno *eno;
2121 	int len;
2122 
2123 	if (!(eno = find_eno(tcp))) {
2124 		tc->tc_state = STATE_DISABLED;
2125 
2126 		return DIVERT_ACCEPT;
2127 	}
2128 
2129 	len = eno->toe_len - sizeof(*eno);
2130 	assert(len >= 0);
2131 
2132 	check_app_support(tc, eno->toe_opts, len);
2133 
2134 	cipher = (struct tc_cipher_spec*) eno->toe_opts;
2135 
2136 	/* XXX truncate len as it could go to the variable options (like SID) */
2137 
2138 	if (!negotiate_cipher(tc, cipher, len)) {
2139 		xprintf(XP_ALWAYS, "No cipher\n");
2140 		tc->tc_state = STATE_DISABLED;
2141 
2142 		return DIVERT_ACCEPT;
2143 	}
2144 
2145 	set_eno_transcript(tc, tcp);
2146 
2147 	init_pkey(tc);
2148 
2149 	tc->tc_state = STATE_PKCONF_RCVD;
2150 
2151 	return DIVERT_ACCEPT;
2152 }
2153 
do_neg_sym(struct tc * tc,struct ciphers * c,struct tc_scipher * a)2154 static void do_neg_sym(struct tc *tc, struct ciphers *c, struct tc_scipher *a)
2155 {
2156 	struct tc_scipher *b;
2157 
2158 	c = c->c_next;
2159 
2160 	while (c) {
2161 		b = (struct tc_scipher*) c->c_spec;
2162 
2163 		if (b->sc_algo == a->sc_algo) {
2164 			tc->tc_crypt_sym = crypt_new(c->c_cipher->c_ctr);
2165 			tc->tc_cipher_sym.sc_algo = a->sc_algo;
2166 			break;
2167 		}
2168 
2169 		c = c->c_next;
2170 	}
2171 }
2172 
negotiate_sym_cipher(struct tc * tc,struct tc_scipher * a,int alen)2173 static int negotiate_sym_cipher(struct tc *tc, struct tc_scipher *a, int alen)
2174 {
2175 	int rc = 0;
2176 
2177 	tc->tc_sym_cipher_list_len = alen * sizeof(*a);
2178 	memcpy(tc->tc_sym_cipher_list, a, tc->tc_sym_cipher_list_len);
2179 
2180 	while (alen--) {
2181 		do_neg_sym(tc, &_ciphers_sym, a);
2182 
2183 		if (tc->tc_crypt_sym) {
2184 			rc = 1;
2185 			break;
2186 		}
2187 
2188 		a++;
2189 	}
2190 
2191 	return rc;
2192 }
2193 
select_pkey(struct tc * tc,struct tc_cipher_spec * pkey)2194 static int select_pkey(struct tc *tc, struct tc_cipher_spec *pkey)
2195 {
2196 	struct tc_cipher_spec *spec;
2197 	struct ciphers *c = _ciphers_pkey.c_next;
2198 	int i;
2199 
2200 	/* check whether we know about the cipher */
2201 	while (c) {
2202 		spec = (struct tc_cipher_spec*) c->c_spec;
2203 
2204 		if (spec->tcs_algo == pkey->tcs_algo) {
2205 			tc->tc_crypt_pub = crypt_new(c->c_cipher->c_ctr);
2206 			break;
2207 		}
2208 
2209 		c = c->c_next;
2210 	}
2211 	if (!c)
2212 		return 0;
2213 
2214 	/* check whether we were willing to accept this cipher */
2215 	for (i = 0; i < tc->tc_ciphers_pkey_len / sizeof(*tc->tc_ciphers_pkey);
2216 	     i++) {
2217 		spec = &tc->tc_ciphers_pkey[i];
2218 
2219 		if (spec->tcs_algo == pkey->tcs_algo) {
2220 			tc->tc_cipher_pkey = *pkey;
2221 			return 1;
2222 		}
2223 	}
2224 
2225 	/* XXX cleanup */
2226 
2227 	return 0;
2228 }
2229 
compute_ss(struct tc * tc)2230 static void compute_ss(struct tc *tc)
2231 {
2232 	struct iovec iov[4];
2233 
2234 	profile_add(1, "compute ss in");
2235 
2236 	iov[0].iov_base = tc->tc_eno;
2237 	iov[0].iov_len  = tc->tc_eno_len;
2238 
2239 	iov[1].iov_base = tc->tc_init1;
2240 	iov[1].iov_len  = tc->tc_init1_len;
2241 
2242 	iov[2].iov_base = tc->tc_init2;
2243 	iov[2].iov_len  = tc->tc_init2_len;
2244 
2245 	iov[3].iov_base = tc->tc_pms;
2246 	iov[3].iov_len  = tc->tc_pms_len;
2247 
2248 	crypt_set_key(tc->tc_crypt_pub->cp_hkdf,
2249 		      tc->tc_nonce, tc->tc_nonce_len);
2250 
2251 	profile_add(1, "compute ss mac set key");
2252 
2253 	tc->tc_ss.s_len = sizeof(tc->tc_ss.s_data);
2254 
2255 	crypt_extract(tc->tc_crypt_pub->cp_hkdf, iov,
2256 		      sizeof(iov) / sizeof(*iov), tc->tc_ss.s_data,
2257 	              &tc->tc_ss.s_len);
2258 
2259 	assert(tc->tc_ss.s_len <= sizeof(tc->tc_ss.s_data));
2260 
2261 	profile_add(1, "compute ss did MAC");
2262 }
2263 
process_init1(struct tc * tc,struct ip * ip,struct tcphdr * tcp,uint8_t * kxs,int kxs_len)2264 static int process_init1(struct tc *tc, struct ip *ip, struct tcphdr *tcp,
2265 			 uint8_t *kxs, int kxs_len)
2266 {
2267 	struct tc_init1 *i1;
2268 	int dlen;
2269 	uint8_t *nonce;
2270 	int nonce_len;
2271 	void *key;
2272 	int klen;
2273 	int cl;
2274 	void *pms;
2275 	int pmsl;
2276 	int len;
2277 	uint8_t *p;
2278 
2279 	if (!is_init(ip, tcp, TC_INIT1))
2280 		return bad_packet("can't find init1");
2281 
2282 	dlen = tcp_data_len(ip, tcp);
2283 	i1   = tcp_data(tcp);
2284 
2285 	if (!select_pkey(tc, &tc->tc_cipher_pkey))
2286 		return bad_packet("init1: bad public key");
2287 
2288 	klen 	  = crypt_get_key(tc->tc_crypt_pub->cp_pub, &key);
2289 	nonce_len = tc->tc_crypt_pub->cp_n_c;
2290 	len 	  = sizeof(*i1) + i1->i1_nciphers + nonce_len + klen;
2291 
2292 	/* strict len for now */
2293 	if (len != dlen || len != ntohl(i1->i1_len))
2294 	    	return bad_packet("bad init1 len");
2295 
2296 	p = i1->i1_data;
2297 	if (!negotiate_sym_cipher(tc, (struct tc_scipher *) p, i1->i1_nciphers))
2298 		return bad_packet("init1: can't negotiate");
2299 
2300 	nonce = p + i1->i1_nciphers;
2301 	key   = nonce + nonce_len;
2302 
2303 	profile_add(1, "pre pkey set key");
2304 
2305 	/* figure out key len */
2306 	if (crypt_set_key(tc->tc_crypt_pub->cp_pub, key, klen) == -1)
2307 		return bad_packet("init1: bad pubkey");
2308 
2309 	profile_add(1, "pkey set key");
2310 
2311 	generate_nonce(tc, tc->tc_crypt_pub->cp_n_s);
2312 
2313 	/* XXX fix crypto api to have from to args */
2314 	memcpy(kxs, tc->tc_nonce, tc->tc_nonce_len);
2315 	cl = crypt_encrypt(tc->tc_crypt_pub->cp_pub,
2316 			   NULL, kxs, tc->tc_nonce_len);
2317 
2318 	assert(cl <= kxs_len); /* XXX too late to check */
2319 
2320 	pms  = tc->tc_nonce;
2321 	pmsl = tc->tc_nonce_len;
2322 
2323 	if (tc->tc_crypt_pub->cp_key_agreement) {
2324 		pms = alloca(1024);
2325 		pmsl = crypt_compute_key(tc->tc_crypt_pub->cp_pub, pms);
2326 
2327 		assert(pmsl < 1024); /* XXX */
2328 	}
2329 
2330 	assert(dlen <= sizeof(tc->tc_init1));
2331 
2332 	memcpy(tc->tc_init1, i1, dlen);
2333 	tc->tc_init1_len = dlen;
2334 
2335 	assert(pmsl <= sizeof(tc->tc_pms));
2336 	memcpy(tc->tc_pms, pms, pmsl);
2337 	tc->tc_pms_len = pmsl;
2338 
2339 	assert(nonce_len <= sizeof(tc->tc_nonce));
2340 	memcpy(tc->tc_nonce, nonce, nonce_len);
2341 	tc->tc_nonce_len = nonce_len;
2342 
2343 	tc->tc_state = STATE_INIT1_RCVD;
2344 
2345 	tc->tc_isn_peer = ntohl(tcp->th_seq) + dlen;
2346 
2347 	return 1;
2348 }
2349 
swallow_data(struct ip * ip,struct tcphdr * tcp)2350 static int swallow_data(struct ip *ip, struct tcphdr *tcp)
2351 {
2352 	int len, dlen;
2353 
2354 	len  = (ip->ip_hl << 2) + (tcp->th_off << 2);
2355 	dlen = ntohs(ip->ip_len) - len;
2356 	set_ip_len(ip, len);
2357 
2358 	return dlen;
2359 }
2360 
do_input_pkconf_sent(struct tc * tc,struct ip * ip,struct tcphdr * tcp)2361 static int do_input_pkconf_sent(struct tc *tc, struct ip *ip,
2362 				struct tcphdr *tcp)
2363 {
2364 	int len, dlen;
2365 	void *buf;
2366 	struct ip *ip2;
2367 	struct tcphdr *tcp2;
2368 	struct tc_init2 *i2;
2369 	uint8_t kxs[1024];
2370 	int cipherlen;
2371 	struct tcpopt_eno *eno;
2372 	int rdr = _conf.cf_rdr;
2373 
2374 	/* Check to see if the other side added ENO per
2375 	   Section 3.2 of draft-ietf-tcpinc-tcpeno-00. */
2376 	if (!rdr && !(eno = find_eno(tcp))) {
2377 		xprintf(XP_DEBUG, "No ENO option found in expected INIT1\n");
2378 		tc->tc_state = STATE_DISABLED;
2379 
2380 		return DIVERT_ACCEPT;
2381 	}
2382 
2383 	/* syn retransmission */
2384 	if (tcp->th_flags == TH_SYN)
2385 		return do_input_closed(tc, ip, tcp);
2386 
2387 	if (!process_init1(tc, ip, tcp, kxs, sizeof(kxs))) {
2388 		/* XXX. Per Section 3.2 of draft-ietf-tcpinc-tcpeno-00,
2389 		   you are supposed to tear down the connection.
2390 		   This is a bug.
2391 		*/
2392 		tc->tc_state = STATE_DISABLED;
2393 
2394 		return DIVERT_ACCEPT;
2395 	}
2396 
2397 	cipherlen = tc->tc_crypt_pub->cp_cipher_len;
2398 
2399 	/* send init2 */
2400 	buf = alloc_retransmit(tc);
2401 	make_reply(buf, ip, tcp);
2402 	ip2 = (struct ip*) buf;
2403 	tcp2 = (struct tcphdr*) (ip2 + 1);
2404 
2405 	len = sizeof(*i2) + cipherlen;
2406 	i2  = data_alloc(tc, ip2, tcp2, len, 0);
2407 
2408 	i2->i2_magic  = htonl(TC_INIT2);
2409 	i2->i2_len    = htonl(len);
2410 	i2->i2_cipher = tc->tc_cipher_sym.sc_algo;
2411 
2412 	memcpy(i2->i2_data, kxs, cipherlen);
2413 
2414 	if (_conf.cf_rsa_client_hack)
2415 		memcpy(i2->i2_data, tc->tc_nonce, tc->tc_nonce_len);
2416 
2417 	assert(len <= sizeof(tc->tc_init2));
2418 
2419 	memcpy(tc->tc_init2, i2, len);
2420 	tc->tc_init2_len = len;
2421 
2422 	tc->tc_isn = ntohl(tcp2->th_seq) + len;
2423 
2424 	checksum_packet(tc, ip2, tcp2);
2425 
2426 	inject_ip(ip2);
2427 
2428 	tc->tc_state = STATE_INIT2_SENT;
2429 
2430 	/* swallow data - ewwww */
2431 	dlen = swallow_data(ip, tcp);
2432 
2433 	tc->tc_rseq_off = dlen;
2434 	tc->tc_role     = ROLE_SERVER;
2435 
2436 	compute_ss(tc);
2437 
2438 #if 1
2439 	return DIVERT_MODIFY;
2440 #else
2441 	/* we let the ACK of INIT2 through to complete the handshake */
2442 	return DIVERT_DROP;
2443 #endif
2444 }
2445 
select_sym(struct tc * tc,struct tc_scipher * s)2446 static int select_sym(struct tc *tc, struct tc_scipher *s)
2447 {
2448 	struct tc_scipher *me = tc->tc_ciphers_sym;
2449 	int len = tc->tc_ciphers_sym_len;
2450 	int sym = 0;
2451 	struct ciphers *c;
2452 
2453 	/* check if we approve it */
2454 	while (len) {
2455 		if (memcmp(me, s, sizeof(*s)) == 0) {
2456 			sym = 1;
2457 			break;
2458 		}
2459 
2460 		me++;
2461 		len -= sizeof(*me);
2462 		assert(len >= 0);
2463 	}
2464 
2465 	if (!sym)
2466 		return 0;
2467 
2468 	/* select ciphers */
2469 	c = _ciphers_sym.c_next;
2470 	while (c) {
2471 		me = (struct tc_scipher*) c->c_spec;
2472 
2473 		if (me->sc_algo == s->sc_algo) {
2474 			tc->tc_crypt_sym = crypt_new(c->c_cipher->c_ctr);
2475 			break;
2476 		}
2477 
2478 		c = c->c_next;
2479 	}
2480 
2481 	assert(tc->tc_crypt_sym);
2482 
2483 	memcpy(&tc->tc_cipher_sym, s, sizeof(*s));
2484 
2485 	return 1;
2486 }
2487 
process_init2(struct tc * tc,struct ip * ip,struct tcphdr * tcp)2488 static int process_init2(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
2489 {
2490 	struct tc_init2 *i2;
2491 	int len;
2492 	int nlen;
2493 	void *nonce;
2494 
2495 	if (!is_init(ip, tcp, TC_INIT2))
2496 		return bad_packet("init2: can't find opt");
2497 
2498 	i2  = tcp_data(tcp);
2499 	len = tcp_data_len(ip, tcp);
2500 
2501 	nlen = tc->tc_crypt_pub->cp_cipher_len;
2502 
2503 	if (sizeof(*i2) + nlen > len || ntohl(i2->i2_len) > len)
2504 		return bad_packet("init2: bad len");
2505 
2506 	if (!select_sym(tc, (struct tc_scipher*) (&i2->i2_cipher)))
2507 		return bad_packet("init2: select_sym()");
2508 
2509 	if (len > sizeof(tc->tc_init2))
2510 		return bad_packet("init2: too long");
2511 
2512 	memcpy(tc->tc_init2, i2, len);
2513 	tc->tc_init2_len = len;
2514 
2515 	tc->tc_isn_peer = ntohl(tcp->th_seq) + len;
2516 
2517 	nonce = i2->i2_data;
2518 	nlen  = crypt_decrypt(tc->tc_crypt_pub->cp_pub, NULL, nonce, nlen);
2519 
2520 	assert(nlen <= sizeof(tc->tc_pms));
2521 	memcpy(tc->tc_pms, nonce, nlen);
2522 	tc->tc_pms_len = nlen;
2523 
2524 	compute_ss(tc);
2525 
2526 	return 1;
2527 }
2528 
ack(struct tc * tc,struct ip * ip,struct tcphdr * tcp)2529 static void ack(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
2530 {
2531 	char buf[2048];
2532 	struct ip *ip2;
2533 	struct tcphdr *tcp2;
2534 
2535 	if (_conf.cf_rdr)
2536 		return;
2537 
2538 	ip2  = (struct ip*) buf;
2539 	tcp2 = (struct tcphdr*) (ip2 + 1);
2540 
2541 	make_reply(buf, ip, tcp);
2542 
2543 	/* XXX */
2544 	tcp2->th_seq = htonl(ntohl(tcp2->th_seq) - tc->tc_seq_off);
2545 	tcp2->th_ack = htonl(ntohl(tcp2->th_ack) - tc->tc_rseq_off);
2546 
2547 	checksum_packet(tc, ip2, tcp2);
2548 	do_inject_ip(ip2);
2549 }
2550 
do_input_init1_sent(struct tc * tc,struct ip * ip,struct tcphdr * tcp)2551 static int do_input_init1_sent(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
2552 {
2553 	int dlen = tcp_data_len(ip, tcp);
2554 
2555 	/* XXX syn ack re-TX - check pkconf */
2556 	if (tcp->th_flags == (TH_SYN | TH_ACK))
2557 		return DIVERT_ACCEPT;
2558 
2559 	/* pure ack after connect */
2560 	if (dlen == 0)
2561 		return DIVERT_ACCEPT;
2562 
2563 	if (!process_init2(tc, ip, tcp)) {
2564 		tc->tc_state = STATE_DISABLED;
2565 		return DIVERT_ACCEPT;
2566 	}
2567 
2568 	dlen = ntohs(ip->ip_len) - (ip->ip_hl << 2) - (tcp->th_off << 2);
2569 	tc->tc_rseq_off = dlen;
2570 
2571 	ack(tc, ip, tcp);
2572 
2573 	enable_encryption(tc);
2574 
2575 	/* we let this packet through to reopen window */
2576 	swallow_data(ip, tcp);
2577 	tcp->th_ack = htonl(ntohl(tcp->th_ack) - tc->tc_seq_off);
2578 
2579 	return DIVERT_MODIFY;
2580 }
2581 
/*
 * Inbound rekey pre-processing.  Currently a stub.
 *
 * If a rekey is half done (the TX and RX key generations differ and the RX
 * generation already equals the current one) tc_key_active is pointed at
 * tc_key_next.  The function then returns NULL unconditionally.
 *
 * NOTE(review): everything after the "XXX TODO" return is unreachable, and
 * `tr` is never assigned, so that code would dereference an uninitialized
 * pointer if the early return were simply removed.  It appears to be a
 * draft of the intended logic; `tr` must first be located in the packet
 * before it can be enabled.
 */
static struct tco_rekeystream *rekey_input(struct tc *tc, struct ip *ip,
					   struct tcphdr *tcp)
{
	struct tco_rekeystream *tr;

	/* half way through rekey - figure out current key */
	if (tc->tc_keygentx != tc->tc_keygenrx
	    && tc->tc_keygenrx == tc->tc_keygen)
		tc->tc_key_active = &tc->tc_key_next;

	/* XXX TODO - nothing below this return is ever executed */
	return NULL;

	/* peer announces the next key generation: start rekeying */
	if (tr->tr_key == (uint8_t) ((tc->tc_keygen + 1))) {
		do_rekey(tc);
		tc->tc_state     = STATE_REKEY_RCVD;
		tc->tc_rekey_seq = ntohl(tr->tr_seq);

		if (tc->tc_rekey_seq != ntohl(tcp->th_seq)) {
			assert(!"implement");
//			unsigned char dummy[] = "a";
//			void *iv = &tr->tr_seq;

			/* XXX assuming stream, and seq as IV */
//			crypto_decrypt(tc, iv, dummy, sizeof(dummy));
		}

		/* XXX assert that MAC checks out, else revert */
	}

	assert(tr->tr_key == tc->tc_keygen);

	if (tr->tr_key == tc->tc_keygen) {
		/* old news - we've finished rekeying */
		if (tc->tc_state == STATE_ENCRYPTING) {
			assert(tc->tc_keygen == tc->tc_keygenrx
			       && tc->tc_keygen == tc->tc_keygentx);
			return NULL;
		}

		tc->tc_key_active = &tc->tc_key_next;
	}

	return tr;
}
2627 
rekey_input_post(struct tc * tc,struct ip * ip,struct tcphdr * tcp,struct tco_rekeystream * tr)2628 static void rekey_input_post(struct tc *tc, struct ip *ip, struct tcphdr *tcp,
2629 			     struct tco_rekeystream *tr)
2630 {
2631 	/* XXX seqno wrap */
2632 	if (tc->tc_state == STATE_REKEY_SENT
2633 	    && ntohl(tcp->th_ack) >= tc->tc_rekey_seq) {
2634 	    	xprintf(XP_DEBUG, "TX rekey done %d %p\n", tc->tc_keygen, tc);
2635 		tc->tc_keygentx++;
2636 		assert(tc->tc_keygentx == tc->tc_keygen);
2637 		if (rekey_complete(tc))
2638 			return;
2639 
2640 		tc->tc_state = STATE_ENCRYPTING;
2641 	}
2642 
2643 	if (tr && (tc->tc_state = STATE_ENCRYPTING)) {
2644 		tc->tc_state     = STATE_REKEY_RCVD;
2645 		tc->tc_rekey_seq = ntohl(tr->tr_seq);
2646 	}
2647 }
2648 
check_mac_and_decrypt(struct tc * tc,struct ip * ip,struct tcphdr * tcp)2649 static int check_mac_and_decrypt(struct tc *tc, struct ip *ip,
2650 				 struct tcphdr *tcp)
2651 {
2652 	int rc;
2653 	struct tc_flags *flags;
2654 	struct tc_record *record = tcp_data(tcp);
2655 	int len = tcp_data_len(ip, tcp);
2656 	int maclen = tc->tc_mac_size + tc->tc_mac_ivlen;
2657 	uint8_t *clear;
2658 	struct crypt *c = tc->tc_key_active->tc_alg_rx->cs_cipher;
2659 	uint8_t *data = (uint8_t*) (record + 1);
2660 	uint8_t *mac = ((uint8_t*) record) + len - maclen;
2661 	void *iv = get_iv(tc, ip, tcp, 0);
2662 	int dlen;
2663 
2664 	if (len == 0) {
2665 		fixup_seq_add(tc, ip, tcp, 0, 1);
2666 		return 0;
2667 	}
2668 
2669 	/* basic length check */
2670 	if (len < (sizeof(*record) + maclen))
2671 		return -1;
2672 
2673 	/* check MAC and decrypt */
2674 	profile_add(1, "do_input pre check_mac and decrypt");
2675 
2676 	rc = c->c_aead_decrypt(c, iv, record, sizeof(*record),
2677 			      data, len - sizeof(*record) - maclen,
2678 			      mac);
2679 
2680 	profile_add(1, "do_input post check_mac and decrypt");
2681 
2682 	if (rc == -1) {
2683 		xprintf(XP_ALWAYS, "MAC check failed\n");
2684 
2685 		if (_conf.cf_debug)
2686 			abort();
2687 
2688 		return -1;
2689 	}
2690 
2691 	/* MAC passed */
2692 
2693 	if (tc->tc_sess) {
2694 		/* When we receive the first MACed packet, we know the other
2695 		 * side is setup so we can cache this session.
2696 		 */
2697 		tc->tc_sess->ts_used = 0;
2698 		tc->tc_sess	     = NULL;
2699 	}
2700 
2701 	/* check record */
2702 	dlen = len - sizeof(*record);
2703 
2704 	if (dlen != ntohs(record->tr_len))
2705 		return -1;
2706 
2707 	if (record->tr_control != 0)
2708 		return -1;
2709 
2710 	if (dlen < maclen)
2711 		return -1;
2712 
2713 	dlen -= maclen;
2714 
2715 	assert(dlen > 0);
2716 
2717 	/* check flags */
2718 	dlen -= sizeof(*flags);
2719 
2720 	if (dlen < 0) {
2721 		xprintf(XP_ALWAYS, "Short packet\n");
2722 		return -1;
2723 	}
2724 
2725 	flags = (struct tc_flags*) (record + 1);
2726 	clear = (uint8_t*) (flags + 1);
2727 
2728 	if (flags->tf_flags & TCF_URG) {
2729 		dlen  -= 2;
2730 		clear += 2;
2731 
2732 		if (dlen < 0) {
2733 			xprintf(XP_ALWAYS, "Short packett\n");
2734 			return -1;
2735 		}
2736 	}
2737 
2738 	fixup_seq_add(tc, ip, tcp, len - dlen, 1);
2739 
2740 	/* remove record, flags, MAC */
2741 	memmove(record, clear, dlen);
2742 	set_ip_len(ip, (ip->ip_hl * 4) + (tcp->th_off * 4) + dlen);
2743 
2744 	return 0;
2745 }
2746 
do_input_encrypting(struct tc * tc,struct ip * ip,struct tcphdr * tcp)2747 static int do_input_encrypting(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
2748 {
2749 	struct tco_rekeystream *tr;
2750 
2751 	tc->tc_key_active = &tc->tc_key_current;
2752 	tr = rekey_input(tc, ip, tcp);
2753 
2754 	if (check_mac_and_decrypt(tc, ip, tcp))
2755 		return DIVERT_DROP;
2756 
2757 	rekey_input_post(tc, ip, tcp, tr);
2758 
2759 	return DIVERT_MODIFY;
2760 }
2761 
do_input_init2_sent(struct tc * tc,struct ip * ip,struct tcphdr * tcp)2762 static int do_input_init2_sent(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
2763 {
2764 	int rc;
2765 
2766 	if (tc->tc_retransmit) {
2767 		assert(is_init(ip, tcp, TC_INIT1));
2768 		return DIVERT_DROP;
2769 	}
2770 
2771 	/* XXX check ACK */
2772 
2773 	enable_encryption(tc);
2774 
2775 	rc = do_input_encrypting(tc, ip, tcp);
2776 	assert(rc != DIVERT_DROP);
2777 
2778 	return rc;
2779 }
2780 
clamp_mss(struct tc * tc,struct ip * ip,struct tcphdr * tcp)2781 static int clamp_mss(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
2782 {
2783 	struct {
2784 		uint8_t	 kind;
2785 		uint8_t	 len;
2786 		uint16_t mss;
2787 	} *mss;
2788 
2789 	if (tc->tc_mss_clamp == -1)
2790 		return DIVERT_ACCEPT;
2791 
2792 	if (!(tcp->th_flags & TH_SYN))
2793 		return DIVERT_ACCEPT;
2794 
2795 	if (tc->tc_state == STATE_DISABLED)
2796 		return DIVERT_ACCEPT;
2797 
2798 	mss = find_opt(tcp, TCPOPT_MAXSEG);
2799 	if (!mss) {
2800 		mss = tcp_opts_alloc(tc, ip, tcp, sizeof(*mss));
2801 		if (!mss) {
2802 			tc->tc_state = STATE_DISABLED;
2803 
2804 			xprintf(XP_ALWAYS, "Can't clamp MSS\n");
2805 
2806 			return DIVERT_ACCEPT;
2807 		}
2808 
2809 		mss->kind = TCPOPT_MAXSEG;
2810 		mss->len  = sizeof(*mss);
2811 		mss->mss  = htons(tc->tc_mtu - sizeof(*ip) - sizeof(*tcp));
2812 	}
2813 
2814 	return do_clamp_mss(tc, &mss->mss);
2815 }
2816 
check_retransmit(struct tc * tc,struct ip * ip,struct tcphdr * tcp)2817 static void check_retransmit(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
2818 {
2819 	struct ip *ip2;
2820 	struct tcphdr *tcp2;
2821 	int seq;
2822 
2823 	if (!tc->tc_retransmit)
2824 		return;
2825 
2826 	if (!(tcp->th_flags & TH_ACK))
2827 		return;
2828 
2829 	ip2  = (struct ip*) tc->tc_retransmit->r_packet;
2830 	tcp2 = (struct tcphdr*) ((unsigned long) ip2 + (ip2->ip_hl << 2));
2831 	seq  = ntohl(tcp2->th_seq) + tcp_data_len(ip2, tcp2);
2832 
2833 	if (ntohl(tcp->th_ack) < seq)
2834 		return;
2835 
2836 	kill_retransmit(tc);
2837 }
2838 
tcp_input_pre(struct tc * tc,struct ip * ip,struct tcphdr * tcp)2839 static int tcp_input_pre(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
2840 {
2841 	int rc = DIVERT_ACCEPT;
2842 
2843 	if (tcp->th_flags & TH_SYN)
2844 		tc->tc_isn_peer = ntohl(tcp->th_seq) + 1;
2845 
2846 	if (tcp->th_flags == TH_SYN && tc->tc_tcp_state == TCPSTATE_LASTACK) {
2847 		tc_finish(tc);
2848 		tc_reset(tc);
2849 	}
2850 
2851 	/* XXX check seq numbers, etc. */
2852 
2853 	check_retransmit(tc, ip, tcp);
2854 
2855 	if (tcp->th_flags & TH_RST) {
2856 		tc->tc_tcp_state = TCPSTATE_DEAD;
2857 		return rc;
2858 	}
2859 
2860 	return rc;
2861 }
2862 
tcp_input_post(struct tc * tc,struct ip * ip,struct tcphdr * tcp)2863 static int tcp_input_post(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
2864 {
2865 	int rc = DIVERT_ACCEPT;
2866 
2867 	if (clamp_mss(tc, ip, tcp) == DIVERT_MODIFY)
2868 		rc = DIVERT_MODIFY;
2869 
2870 	profile_add(2, "did clamp MSS");
2871 
2872 	/* Make sure kernel doesn't send shit until we connect */
2873 	switch (tc->tc_state) {
2874 	case STATE_ENCRYPTING:
2875 	case STATE_REKEY_SENT:
2876 	case STATE_REKEY_RCVD:
2877 	case STATE_DISABLED:
2878 		break;
2879 
2880 	default:
2881 		tcp->th_win = htons(0);
2882 		rc = DIVERT_MODIFY;
2883 		break;
2884 	}
2885 
2886 	if (tcp->th_flags & TH_FIN) {
2887 		switch (tc->tc_tcp_state) {
2888 		case TCPSTATE_FIN1_SENT:
2889 			tc->tc_tcp_state = TCPSTATE_FIN2_RCVD;
2890 			break;
2891 
2892 		case TCPSTATE_LASTACK:
2893 		case TCPSTATE_FIN2_RCVD:
2894 			break;
2895 
2896 		default:
2897 			tc->tc_tcp_state = TCPSTATE_FIN1_RCVD;
2898 			break;
2899 		}
2900 
2901 		return rc;
2902 	}
2903 
2904 	if (tcp->th_flags & TH_RST) {
2905 		tc->tc_tcp_state = TCPSTATE_DEAD;
2906 		return rc;
2907 	}
2908 
2909 	switch (tc->tc_tcp_state) {
2910 	case TCPSTATE_FIN2_SENT:
2911 		if (tcp->th_flags & TH_ACK)
2912 			tc->tc_tcp_state = TCPSTATE_DEAD;
2913 		break;
2914 	}
2915 
2916 	return rc;
2917 }
2918 
do_input_nextk1_sent(struct tc * tc,struct ip * ip,struct tcphdr * tcp)2919 static int do_input_nextk1_sent(struct tc *tc, struct ip *ip,
2920 				struct tcphdr *tcp)
2921 {
2922 	struct tcpopt_eno *eno = find_eno(tcp);
2923 	int len;
2924 
2925 	if (!eno) {
2926 		tc->tc_state = STATE_DISABLED;
2927 
2928 		return DIVERT_ACCEPT;
2929 	}
2930 
2931 	len = eno->toe_len - sizeof(*eno);
2932 
2933 	assert(len >= 0);
2934 	check_app_support(tc, eno->toe_opts, len);
2935 
2936 	/* see if we can resume the session */
2937 	if (len > 0 && eno->toe_opts[len - 1]
2938 	               == (tc->tc_sess->ts_pub_spec | TC_OPT_VLEN)) {
2939 		enable_encryption(tc);
2940 		return DIVERT_ACCEPT;
2941 	}
2942 
2943 	/* nope */
2944 	assert(tc->tc_sess->ts_used);
2945 	tc->tc_sess->ts_used = 0;
2946 	tc->tc_sess = NULL;
2947 
2948 	if (!_conf.cf_nocache)
2949 		xprintf(XP_DEFAULT, "Session caching failed\n");
2950 
2951 	return do_input_hello_sent(tc, ip, tcp);
2952 }
2953 
do_input_nextk2_sent(struct tc * tc,struct ip * ip,struct tcphdr * tcp)2954 static int do_input_nextk2_sent(struct tc *tc, struct ip *ip,
2955 				struct tcphdr *tcp)
2956 {
2957 	int rc;
2958 
2959 	if (tcp->th_flags & TH_SYN)
2960 		return DIVERT_ACCEPT;
2961 
2962 	assert(tcp->th_flags & TH_ACK);
2963 
2964 	enable_encryption(tc);
2965 
2966 	rc = do_input_encrypting(tc, ip, tcp);
2967 	assert(rc != DIVERT_DROP);
2968 
2969 	return rc;
2970 }
2971 
do_input(struct tc * tc,struct ip * ip,struct tcphdr * tcp)2972 static int do_input(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
2973 {
2974 	int rc = DIVERT_DROP;
2975 	int tcp_rc, tcp_rc2;
2976 
2977 	tcp_rc = tcp_input_pre(tc, ip, tcp);
2978 
2979 	/* an RST half way through the handshake */
2980 	if (tc->tc_tcp_state == TCPSTATE_DEAD
2981 	    && !connected(tc))
2982 		return tcp_rc;
2983 
2984 	if (tcp_rc == DIVERT_DROP)
2985 		return DIVERT_ACCEPT; /* kernel will deal with it */
2986 
2987 	switch (tc->tc_state) {
2988 	case STATE_NEXTK1_RCVD:
2989 		/* XXX check same SID */
2990 	case STATE_HELLO_RCVD:
2991 		tc_reset(tc); /* XXX */
2992 	case STATE_CLOSED:
2993 		rc = do_input_closed(tc, ip, tcp);
2994 		break;
2995 
2996 	case STATE_HELLO_SENT:
2997 		rc = do_input_hello_sent(tc, ip, tcp);
2998 		break;
2999 
3000 	case STATE_PKCONF_RCVD:
3001 		/* XXX syn ack re-TX check that we're getting the same shit */
3002 		assert(tcp->th_flags == (TH_SYN | TH_ACK));
3003 		rc = DIVERT_ACCEPT;
3004 		break;
3005 
3006 	case STATE_NEXTK1_SENT:
3007 		rc = do_input_nextk1_sent(tc, ip, tcp);
3008 		break;
3009 
3010 	case STATE_NEXTK2_SENT:
3011 		rc = do_input_nextk2_sent(tc, ip, tcp);
3012 		break;
3013 
3014 	case STATE_PKCONF_SENT:
3015 		rc = do_input_pkconf_sent(tc, ip, tcp);
3016 		break;
3017 
3018 	case STATE_INIT1_SENT:
3019 		rc = do_input_init1_sent(tc, ip, tcp);
3020 		break;
3021 
3022 	case STATE_INIT2_SENT:
3023 		rc = do_input_init2_sent(tc, ip, tcp);
3024 		break;
3025 
3026 	case STATE_ENCRYPTING:
3027 	case STATE_REKEY_SENT:
3028 	case STATE_REKEY_RCVD:
3029 		rc = do_input_encrypting(tc, ip, tcp);
3030 		break;
3031 
3032 	case STATE_DISABLED:
3033 		rc = DIVERT_ACCEPT;
3034 		break;
3035 
3036 	default:
3037 		xprintf(XP_ALWAYS, "Unknown state %d\n", tc->tc_state);
3038 		abort();
3039 	}
3040 
3041 	tcp_rc2 = tcp_input_post(tc, ip, tcp);
3042 
3043 	if (tcp_rc == DIVERT_ACCEPT)
3044 		tcp_rc = tcp_rc2;
3045 
3046 	if (rc == DIVERT_ACCEPT)
3047 		return tcp_rc;
3048 
3049 	return rc;
3050 }
3051 
fake_ip_tcp(struct ip * ip,struct tcphdr * tcp,int len)3052 static void fake_ip_tcp(struct ip *ip, struct tcphdr *tcp, int len)
3053 {
3054 	int hl = sizeof(*ip) + sizeof(*tcp);
3055 
3056 	memset(ip, 0, hl);
3057 
3058 	ip->ip_hl     = sizeof(*ip) / 4;
3059 	ip->ip_len    = htons(len + hl);
3060 
3061 	tcp->th_flags = 0;
3062 	tcp->th_off   = sizeof(*tcp) / 4;
3063 }
3064 
/*
 * Relay data between the two sockets of a redirected (RDR) connection.
 *
 * Data is read from tc's socket into tc->tc_rdr_buf, leaving room in front
 * for a fake IP + TCP header, and written to the peer's socket.  If either
 * end is in STATE_ENCRYPTING, the data is passed through
 * do_output_encrypting() (outgoing direction) or do_input_encrypting()
 * (incoming direction) before being forwarded.
 *
 * For incoming encrypted traffic whole records are reassembled first: the
 * record header is read on its own, then the remaining tr_len bytes, with
 * tc->tc_rdr_len tracking how much of the current record has been read
 * across calls.
 *
 * Any read or write failure tears the redirection down via kill_rdr().
 */
static void proxy_connection(struct tc *tc)
{
	/* buffer layout: [struct ip][struct tcphdr][payload ...] */
	struct ip *ip = (struct ip *) tc->tc_rdr_buf;
	struct tcphdr *tcp = (struct tcphdr*) (ip + 1);
	unsigned char *p = (unsigned char*) (tcp + 1);
	unsigned char *rp = p;	/* where the next recv() writes */
        int rc;
	struct tc *peer = tc->tc_rdr_peer;
	struct tc *enc = NULL;	/* the side doing encryption, if any */
	int out = tc->tc_rdr_state == STATE_RDR_LOCAL;
	int rdlen = 1500 - 256;	/* default read size for plain traffic */
	struct tc_record *rec = (struct tc_record*) p;

	if (tc->tc_state == STATE_ENCRYPTING)
		enc = tc;
	else if (peer->tc_state == STATE_ENCRYPTING)
		enc = peer;

	/* XXX fix variables / state */
	/* the direction is reversed for inbound redirected connections */
	if (peer->tc_rdr_inbound || tc->tc_rdr_inbound)
		out = !out;

	/* For incoming traffic, first read the header (record), then read the
	 * rest
	 */
	if (enc && !out) {
		/* we're reading new data - read header */
		if (tc->tc_rdr_len == 0)
			rdlen = sizeof(*rec);
		else {
			/* we already read the header - read the rest */
			rdlen = ntohs(rec->tr_len)
				- (tc->tc_rdr_len - sizeof(*rec));

			assert(rdlen > 0);
			rp += tc->tc_rdr_len;
		}
	}

	/* EOF or error on the socket ends the redirection */
        if ((rc = recv(tc->tc_rdr_fd->fd_fd, rp, rdlen, 0)) <= 0) {
                kill_rdr(tc);
                return;
        }

	/* incoming traffic, read the rest */
	if (enc && !out) {
		/* we just started */
		if (tc->tc_rdr_len == 0) {
			/* a partially read record header is not handled */
			if (rc != rdlen) {
				kill_rdr(tc);
				return;
			}

			rdlen = ntohs(rec->tr_len);

			/* XXX */
			if (rdlen > sizeof(tc->tc_rdr_buf) - 256) {
				xprintf(XP_ALWAYS, "Record too big %d\n", rdlen);
				kill_rdr(tc);
				return;
			}
		}

		tc->tc_rdr_len += rc;
		assert(tc->tc_rdr_len >= sizeof(*rec));

		/* need to read more */
		if ((tc->tc_rdr_len - sizeof(*rec)) < ntohs(rec->tr_len))
			return;

		/* good to go! */
		rc = tc->tc_rdr_len;
		tc->tc_rdr_len = 0;
	}

	/* XXX */
	/* build a header in front of the payload so the packet-based
	 * handlers below can be reused
	 */
	fake_ip_tcp(ip, tcp, rc);

	if (enc) {
		if (out) {
			/* the return value is overwritten by the resulting
			 * payload length on the next line
			 */
			rc = do_output_encrypting(enc, ip, tcp);
			rc = tcp_data_len(ip, tcp);
			enc->tc_rdr_tx += rc;
		} else {
			if (do_input_encrypting(enc, ip, tcp) == DIVERT_DROP)
				return;

			/* count the record as received, forward the
			 * decrypted length
			 */
			enc->tc_rdr_rx += rc;
			rc = tcp_data_len(ip, tcp);
		}
	}

        /* XXX assuming non-blocking write */
        if (send(peer->tc_rdr_fd->fd_fd, p, rc, 0) != rc) {
                kill_rdr(tc);
                return;
        }
}
3163 
rdr_handshake_complete(struct tc * tc)3164 static void rdr_handshake_complete(struct tc *tc)
3165 {
3166 	int tos = 0;
3167 
3168 	if (!tc->tc_rdr_fd)
3169 		return;
3170 
3171 #ifndef __WIN32__
3172 	/* stop intercepting handshake - all ENO opts have been set */
3173 	if (setsockopt(tc->tc_rdr_fd->fd_fd, IPPROTO_IP, IP_TOS, &tos,
3174 		       sizeof(tos)) == -1) {
3175 	    perror("setsockopt(IP_TOS)");
3176 	    kill_rdr(tc);
3177 	    return;
3178 	}
3179 #else
3180 	win_handshake_complete(tc->tc_rdr_fd->fd_fd);
3181 #endif
3182 }
3183 
rdr_process_init(struct tc * tc)3184 static void rdr_process_init(struct tc *tc)
3185 {
3186 	int headroom = 40;
3187 	unsigned char buf[2048];
3188 	int len;
3189 	struct ip *ip = (struct ip *) buf;
3190 	struct tcphdr *tcp = (struct tcphdr*) (ip + 1);
3191 	struct fd *fd = tc->tc_rdr_fd;
3192 	struct tc_init1 *i1 = (struct tc_init1*) &buf[headroom];
3193 	int rem = sizeof(buf) - headroom;
3194 	fd_set fds;
3195 	struct timeval tv;
3196 
3197 	/* make sure we read only init1 and not past it.
3198 	 * First, figure out how big init is.  Then read that.
3199 	 */
3200 	if ((len = recv(fd->fd_fd, i1, sizeof(*i1), 0)) != sizeof(*i1))
3201 		goto __kill_rdr;
3202 
3203 	rem -= sizeof(*i1);
3204 
3205 	/* Read init */
3206 	len = ntohl(i1->i1_len);
3207 
3208 	if (len > rem || len < sizeof(*i1) || len < 0)
3209 		goto __kill_rdr;
3210 
3211 	rem = len - sizeof(*i1);
3212 
3213 	FD_ZERO(&fds);
3214 	FD_SET(fd->fd_fd, &fds);
3215 
3216 	tv.tv_sec = tv.tv_usec = 0;
3217 
3218 	if (select(fd->fd_fd + 1, &fds, NULL, NULL, &tv) == -1)
3219 		err(1, "select(2)");
3220 
3221 	if (!FD_ISSET(fd->fd_fd, &fds))
3222 		goto __kill_rdr;
3223 
3224 	if (recv(fd->fd_fd, i1 + 1, rem, 0) != rem)
3225 		goto __kill_rdr;
3226 
3227 	/* XXX */
3228 	fake_ip_tcp(ip, tcp, len);
3229 
3230 	switch (tc->tc_state) {
3231 	/* outbound connections */
3232 	case STATE_INIT1_SENT:
3233 		do_input_init1_sent(tc, ip, tcp);
3234 		rdr_handshake_complete(tc);
3235 		break;
3236 
3237 	/* inbound connections */
3238 	case STATE_PKCONF_SENT:
3239 		/* XXX sniff ENO */
3240 		if (is_init(ip, tcp, TC_INIT1)) {
3241 			add_eno(tc, ip, tcp);
3242 		} else {
3243 			tc->tc_state = STATE_DISABLED;
3244 			return;
3245 		}
3246 
3247 		do_input_pkconf_sent(tc, ip, tcp);
3248 		if (tc->tc_state != STATE_INIT2_SENT)
3249 			goto __kill_rdr;
3250 
3251 		if (send(fd->fd_fd, tc->tc_rdr_buf, tc->tc_rdr_len, 0)
3252 			  != tc->tc_rdr_len)
3253 			goto __kill_rdr;
3254 
3255 		enable_encryption(tc);
3256 		break;
3257 	}
3258 
3259 	return;
3260 __kill_rdr:
3261 	xprintf(XP_ALWAYS, "Error reading INIT %p\n", tc);
3262 	kill_rdr(tc);
3263 	return;
3264 }
3265 
rdr_local_handler(struct fd * fd)3266 static void rdr_local_handler(struct fd *fd)
3267 {
3268 	struct tc *tc = fd->fd_priv;
3269 	struct tc *peer = tc->tc_rdr_peer;
3270 
3271 	if (tc->tc_state == STATE_NEXTK2_SENT)
3272 		enable_encryption(tc);
3273 
3274 	if (peer->tc_state == STATE_NEXTK2_SENT)
3275 		enable_encryption(peer);
3276 
3277 	switch (tc->tc_state) {
3278 	case STATE_INIT1_SENT:
3279 	case STATE_PKCONF_SENT:
3280 		rdr_process_init(tc);
3281 		return;
3282 	}
3283 
3284 	if (tc->tc_state == STATE_ENCRYPTING
3285 	    || peer->tc_state == STATE_ENCRYPTING
3286 	    || tc->tc_state == STATE_RDR_PLAIN
3287 	    || peer->tc_state == STATE_RDR_PLAIN) {
3288 		proxy_connection(tc);
3289 		return;
3290 	}
3291 
3292 	/* XXX we should really fix this - shouldn't get here randomly.
3293 	 * We should: 1. check if socket is dead / alive
3294 	 * 2. not put this thing in select until we're ready.
3295 	 * 3. def not spin the CPU
3296 	 */
3297 #if 0
3298 	xprintf(XP_ALWAYS, "unhandled RDR %d:%d\n",
3299 		tc->tc_state, peer->tc_state);
3300 	kill_rdr(tc);
3301 #endif
3302 }
3303 
rdr_remote_handler(struct fd * fd)3304 static void rdr_remote_handler(struct fd *fd)
3305 {
3306 	struct tc *tc = fd->fd_priv;
3307 
3308 	if (!tc->tc_rdr_connected) {
3309 		rdr_check_connect(tc);
3310 		return;
3311 	}
3312 
3313 	rdr_local_handler(fd);
3314 }
3315 
/*
 * Set up redirection for a new connection, triggered by its SYN.
 *
 * Determines the original destination, starts a non-blocking connect() to
 * it (to 127.0.0.1 for inbound connections), creates a peer `struct tc` for
 * the new socket and links it with `tc`.  The SYN is saved in the peer's
 * tc_rdr_buf so it can be replayed once the connect succeeds.
 *
 * On return tc->tc_rdr_state is STATE_RDR_LOCAL if redirection was set up.
 * It is left untouched if the original destination could not be found
 * (tc_rdr_drop_sa is set instead) or if connect() failed outright (tc_state
 * becomes STATE_DISABLED).  socket(), setsockopt() and getsockname()
 * failures terminate the process via err().
 */
static void rdr_new_connection(struct tc *tc, struct ip *ip, struct tcphdr *tcp,
			       int flags)
{
        struct sockaddr_in from, to;
        int s, rc;
        struct fd *sock;
        socklen_t len;
        int tos = IPTOS_RELIABILITY;
	struct tc *peer;
	int in = flags & DF_IN;

        /* figure out where connection is going to */
        memset(&to, 0, sizeof(to));
        memset(&from, 0, sizeof(from));

        from.sin_family = to.sin_family = PF_INET;

        from.sin_port        = tcp->th_sport;
        from.sin_addr.s_addr = ip->ip_src.s_addr;

        to.sin_port          = tcp->th_dport;
        to.sin_addr.s_addr   = ip->ip_dst.s_addr;

	/* the divert backend may rewrite `to` and `flags` */
	if (_divert->orig_dest && _divert->orig_dest(&to, ip, &flags) == -1) {
		/* XXX crude - we rely on the SYN retransmit to kick
		 * things off again
		 */
		tc->tc_rdr_drop_sa = 1;
		xprintf(XP_ALWAYS, "Can't find RDR\n");
		return;
	}

	/* re-read the direction: orig_dest() was handed a pointer to flags */
	in = flags & DF_IN;

	xprintf(XP_NOISY, "RDR orig dest %s:%d\n",
		inet_ntoa(to.sin_addr), ntohs(to.sin_port));

        /* connect to destination */
        if ((s = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP)) == -1)
                err(1, "socket()");

	set_nonblocking(s);

#ifndef __WIN32__
	/* signal handshake to firewall */
        if (setsockopt(s, IPPROTO_IP, IP_TOS, &tos, sizeof(tos)) == -1)
            err(1, "setsockopt(IP_TOS)");
#endif

	/* XXX bypass firewall */
	/* inbound: remember the real destination, connect to loopback */
        if (in) {
		memcpy(&tc->tc_rdr_addr, &to, sizeof(tc->tc_rdr_addr));
                to.sin_addr.s_addr = inet_addr("127.0.0.1");
	}

	/* the socket is non-blocking, so "in progress" is not a failure */
        if ((rc = connect(s, (struct sockaddr*) &to, sizeof(to))) == -1) {
#ifdef __WIN32__
		if (WSAGetLastError() != WSAEWOULDBLOCK) {
#else
		if (errno != EINPROGRESS) {
#endif
			close(s);
			tc->tc_state = STATE_DISABLED;
			return;
		}
	}

#ifdef __WIN32__
	win_dont_rdr(s);
#endif

	/* XXX */
	/* pick the addresses that identify the peer connection */
	if (in && !tc->tc_rdr_drop_sa) {
		to.sin_port = htons(REDIRECT_PORT);
	} else {
		len = sizeof(from);

		if (getsockname(s, (struct sockaddr*) &from, &len) == -1)
			err(1, "getsockname()");

#ifdef __WIN32__
		from.sin_addr.s_addr = win_local_ip();
#endif
	}

        /* create peer */
	peer = do_new_connection(from.sin_addr.s_addr, from.sin_port,
				 to.sin_addr.s_addr, to.sin_port, in);

        xprintf(XP_NOISY, "Adding a connection %s:%d",
	        inet_ntoa(from.sin_addr),
		ntohs(from.sin_port));

        xprintf(XP_NOISY, "->%s:%d [%p]%s\n",
                inet_ntoa(to.sin_addr),
		ntohs(to.sin_port), peer,
		in ? " inbound" : "");

	/* register the socket; its events go to rdr_remote_handler() */
        sock = add_fd(s, rdr_remote_handler);
	sock->fd_priv  = peer;
	sock->fd_state = FDS_WRITE;

	peer->tc_rdr_fd      = sock;
	peer->tc_rdr_state   = STATE_RDR_REMOTE;
	peer->tc_rdr_peer    = tc;
	peer->tc_rdr_inbound = in;

	memcpy(&peer->tc_rdr_addr, &to, sizeof(peer->tc_rdr_addr));

        /* save SYN to replay once connection is successful */
	/* (`len` is reused here to hold the IP total length) */
        len = ntohs(ip->ip_len);
        assert(len < sizeof(peer->tc_rdr_buf));

        memcpy(peer->tc_rdr_buf, ip, len);
        peer->tc_rdr_len = len;

	/* outbound: retarget the saved copy of the SYN at `to` */
	if (!in) {
		ip  = (struct ip *) peer->tc_rdr_buf;
		tcp = get_tcp(ip);

		ip->ip_dst.s_addr = to.sin_addr.s_addr;
		tcp->th_dport     = to.sin_port;
		checksum_packet(tc, ip, tcp);
	}

	tc->tc_rdr_peer  = peer;
	tc->tc_rdr_state = STATE_RDR_LOCAL;

	return;
}
3446 
3447 static int handle_syn_ack(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
3448 {
3449 	switch (tc->tc_state) {
3450 	case STATE_HELLO_RCVD:
3451 		return do_output_hello_rcvd(tc, ip, tcp);
3452 
3453 	case STATE_NEXTK2_SENT:
3454 		/* syn ack rtx */
3455 	case STATE_NEXTK1_RCVD:
3456 		return do_output_nextk1_rcvd(tc, ip, tcp);
3457 
3458 	case STATE_CLOSED:
3459 	case STATE_RDR_PLAIN:
3460 		break;
3461 
3462 	default:
3463 		return DIVERT_DROP;
3464 	}
3465 
3466 	return DIVERT_ACCEPT;
3467 }
3468 
3469 static int rdr_syn_ack(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
3470 {
3471 	struct tc *peer = tc->tc_rdr_peer;
3472 
3473 	/* Linux: we let the SYN through but not the SYN ACK.  We need to let
3474 	 * the SYN through so we can get orig dest.
3475 	 */
3476 	if (tc->tc_rdr_state == STATE_RDR_NONE) {
3477 		tc->tc_rdr_drop_sa = 1;
3478 
3479 		return DIVERT_DROP;
3480 	}
3481 
3482 	if (tc->tc_rdr_drop_sa)
3483 		return handle_syn_ack(tc, ip, tcp);
3484 
3485 	if (tc->tc_rdr_inbound) {
3486 		int rc;
3487 
3488 		assert(peer);
3489 
3490 		rc = handle_syn_ack(peer, ip, tcp);
3491 
3492 		if (rc == DIVERT_DROP)
3493 			return DIVERT_DROP;
3494 
3495 		/* we're still redirecting manually */
3496 		ip->ip_src.s_addr = peer->tc_rdr_addr.sin_addr.s_addr;
3497 		tcp->th_sport     = peer->tc_rdr_addr.sin_port;
3498 		checksum_packet(tc, ip, tcp);
3499 
3500 		return DIVERT_MODIFY;
3501 	}
3502 
3503 	switch (tc->tc_state) {
3504 	case STATE_HELLO_SENT:
3505 		do_input_hello_sent(tc, ip, tcp);
3506 		break;
3507 
3508 	case STATE_NEXTK1_SENT:
3509 		do_input_nextk1_sent(tc, ip, tcp);
3510 
3511 		/* XXX wait to send an ACK */
3512 		if (tc->tc_state == STATE_ENCRYPTING)
3513 			rdr_handshake_complete(tc);
3514 		break;
3515 	}
3516 
3517 	if (tc->tc_state == STATE_DISABLED) {
3518 		tc->tc_state   = STATE_RDR_PLAIN;
3519 		tc->tc_rdr_len = 0;
3520 		rdr_handshake_complete(tc);
3521 	}
3522 
3523 	return DIVERT_ACCEPT;
3524 }
3525 
3526 static int rdr_ack(struct tc *tc, struct ip *ip, struct tcphdr *tcp)
3527 {
3528 	/* send init1 */
3529 	if (tc->tc_state == STATE_PKCONF_RCVD) {
3530 		do_output_pkconf_rcvd(tc, ip, tcp, 0);
3531 
3532 		if (send(tc->tc_rdr_fd->fd_fd, tc->tc_rdr_buf, tc->tc_rdr_len,
3533 			 0) != tc->tc_rdr_len) {
3534 			kill_rdr(tc);
3535 			return DIVERT_DROP;
3536 		}
3537 
3538 		/* drop packet - let's add ENO to it */
3539 		return DIVERT_DROP;
3540 	}
3541 
3542 	/* add eno to init1 */
3543 	if (tc->tc_state == STATE_INIT1_SENT) {
3544 		if (is_init(ip, tcp, TC_INIT1))
3545 			return do_output_pkconf_rcvd(tc, ip, tcp, 1);
3546 	}
3547 
3548 	return DIVERT_DROP;
3549 }
3550 
3551 static int rdr_syn(struct tc *tc, struct ip *ip, struct tcphdr *tcp, int flags)
3552 {
3553 	int in = flags & DIR_IN;
3554 
3555 	/* new connection */
3556 	if (tc->tc_rdr_state == STATE_RDR_NONE)
3557 		rdr_new_connection(tc, ip, tcp, flags);
3558 
3559 	if (tc->tc_rdr_state == STATE_RDR_NONE)
3560 		return DIVERT_ACCEPT;
3561 
3562 	/* incoming */
3563 	if (in) {
3564 		/* drop the locally generated SYN */
3565 		if (tc->tc_rdr_state == STATE_RDR_LOCAL
3566 		    && !tc->tc_rdr_drop_sa
3567 		    && !tc->tc_rdr_peer->tc_rdr_inbound) {
3568 			return DIVERT_DROP;
3569 		}
3570 
3571 		switch (tc->tc_state) {
3572 		case STATE_NEXTK1_RCVD:
3573 			/* XXX check same SID */
3574 		case STATE_HELLO_RCVD:
3575 		case STATE_CLOSED:
3576 			do_input_closed(tc, ip, tcp);
3577 
3578 			if (tc->tc_state == STATE_DISABLED) {
3579 				tc->tc_state   = STATE_RDR_PLAIN;
3580 				tc->tc_rdr_len = 0;
3581 			}
3582 
3583 			/* XXX clamp MSS */
3584 			return DIVERT_ACCEPT;
3585 		}
3586 
3587 		return DIVERT_DROP;
3588 	}
3589 
3590 	/* outbound */
3591 
3592 	/* Add ENO to SYN */
3593 	if (tc->tc_rdr_state == STATE_RDR_REMOTE) {
3594 		switch (tc->tc_state) {
3595 		case STATE_HELLO_SENT:
3596 		case STATE_NEXTK1_SENT:
3597 		case STATE_CLOSED:
3598 			return do_output_closed(tc, ip, tcp);
3599 		}
3600 	}
3601 
3602 	/* drop original non-ENO syn */
3603 
3604 	return DIVERT_DROP;
3605 }
3606 
3607 static int rdr_packet(struct tc *tc, struct ip *ip, struct tcphdr *tcp,
3608 		      int flags)
3609 {
3610         /* our own connections */
3611         if (ip->ip_dst.s_addr == inet_addr("127.0.0.1")
3612             && ip->ip_dst.s_addr == ip->ip_src.s_addr)
3613                 return DIVERT_ACCEPT;
3614 
3615 	if (tcp->th_flags == TH_SYN)
3616 		return rdr_syn(tc, ip, tcp, flags);
3617 
3618 	if (tcp->th_flags == (TH_SYN | TH_ACK))
3619 		return rdr_syn_ack(tc, ip, tcp);
3620 
3621 	if (tcp->th_flags & TH_ACK)
3622 		return rdr_ack(tc, ip, tcp);
3623 
3624 	return DIVERT_DROP;
3625 }
3626 
3627 int tcpcrypt_packet(void *packet, int len, int flags)
3628 {
3629 	struct ip *ip = packet;
3630 	struct tc *tc;
3631 	struct tcphdr *tcp;
3632 	int rc;
3633 
3634 	profile_add(1, "tcpcrypt_packet in");
3635 
3636 	if (ntohs(ip->ip_len) > len)
3637 		goto __bad_packet;
3638 
3639 	/* len can be larger - Ethernet padding (e.g., RSTs) */
3640 	len = ntohs(ip->ip_len);
3641 
3642 	if (ip->ip_p != IPPROTO_TCP)
3643 		return DIVERT_ACCEPT;
3644 
3645 	tcp = (struct tcphdr*) ((unsigned long) ip + (ip->ip_hl << 2));
3646 	if ((unsigned long) tcp - (unsigned long) ip + (tcp->th_off << 2) > len)
3647 		goto __bad_packet;
3648 
3649 	tc = lookup_connection(ip, tcp, flags);
3650 
3651 	/* new connection */
3652 	if (!tc) {
3653 		profile_add(1, "tcpcrypt_packet found no connection");
3654 
3655 		if (_conf.cf_disable)
3656 			return DIVERT_ACCEPT;
3657 
3658 		if (tcp->th_flags != TH_SYN) {
3659 			xprintf(XP_NOISY, "Ignoring established connection: ");
3660 			print_packet(ip, tcp, flags, tc);
3661 
3662 			return DIVERT_ACCEPT;
3663 		}
3664 
3665 		tc = new_connection(ip, tcp, flags);
3666 		profile_add(1, "tcpcrypt_packet new connection");
3667 	} else
3668 		profile_add(1, "tcpcrypt_packet found connection");
3669 
3670 	print_packet(ip, tcp, flags, tc);
3671 
3672 	tc->tc_dir_packet = (flags & DF_IN) ? DIR_IN : DIR_OUT;
3673 	tc->tc_csum       = 0;
3674 
3675 	if (_conf.cf_rdr) {
3676 		rc = rdr_packet(tc, ip, tcp, flags);
3677 	} else {
3678 		if (flags & DF_IN)
3679 			rc = do_input(tc, ip, tcp);
3680 		else
3681 			rc = do_output(tc, ip, tcp);
3682 	}
3683 
3684 	/* XXX for performance measuring - ensure sane results */
3685 	assert(!_conf.cf_debug || (tc->tc_state != STATE_DISABLED));
3686 
3687 	profile_add(1, "tcpcrypt_packet did processing");
3688 
3689 	if (rc == DIVERT_MODIFY) {
3690 		checksum_tcp(tc, ip, tcp);
3691 		profile_add(1, "tcpcrypt_packet did checksum");
3692 	}
3693 
3694 	if (tc->tc_tcp_state == TCPSTATE_DEAD
3695 	    || tc->tc_state  == STATE_DISABLED)
3696 		remove_connection(ip, tcp, flags);
3697 
3698 	profile_print();
3699 
3700 	return rc;
3701 
3702 __bad_packet:
3703 	xprintf(XP_ALWAYS, "Bad packet 2\n");
3704 	return DIVERT_ACCEPT; /* kernel will drop / deal with it */
3705 }
3706 
3707 static struct tc *sockopt_get(struct tcpcrypt_ctl *ctl)
3708 {
3709 	struct tc *tc = sockopt_find(ctl);
3710 
3711 	if (tc) {
3712 		/* XXX it depends */
3713 		if (tc->tc_rdr_peer)
3714 			return tc->tc_rdr_peer;
3715 
3716 		return tc;
3717 	}
3718 
3719 	if (ctl->tcc_sport == 0)
3720 		return NULL;
3721 
3722 	tc = get_tc();
3723 	assert(tc);
3724 
3725 	_sockopts[ctl->tcc_sport] = tc;
3726 	tc_init(tc);
3727 
3728 	return tc;
3729 }
3730 
/*
 * Copy an option value between daemon storage and a caller buffer.
 *
 * set    - non-zero: copy val into p; zero: copy p out to val
 * p, len - daemon-side storage and its size in bytes
 * val    - caller buffer
 * vallen - in: size of val; on get, out: number of bytes written
 *
 * Returns 0 on success, -1 if a set value does not fit in p.
 */
static int do_opt(int set, void *p, int len, void *val, unsigned int *vallen)
{
	if (!set) {
		/* get: truncate to what the caller has room for */
		if (*vallen < len)
			len = *vallen;

		memcpy(val, p, len);
		*vallen = len;

		return 0;
	}

	/* set: refuse values larger than the storage */
	if (*vallen > len)
		return -1;

	memcpy(p, val, *vallen);

	return 0;
}
3750 
3751 static int do_sockopt(int set, struct tc *tc, int opt, void *val,
3752 		      unsigned int *len)
3753 {
3754 	int v;
3755 	int rc;
3756 
3757 	/* do not allow options during connection */
3758 	switch (tc->tc_state) {
3759 	case STATE_CLOSED:
3760 	case STATE_ENCRYPTING:
3761 	case STATE_DISABLED:
3762 	case STATE_REKEY_SENT:
3763 	case STATE_REKEY_RCVD:
3764 	case STATE_RDR_PLAIN:
3765 		break;
3766 
3767 	default:
3768 		return EBUSY;
3769 	}
3770 
3771 	switch (opt) {
3772 	case TCP_CRYPT_ENABLE:
3773 		if (tc->tc_state == STATE_DISABLED)
3774 			v = 0;
3775 		else
3776 			v = 1;
3777 
3778 		rc = do_opt(set, &v, sizeof(v), val, len);
3779 		if (rc)
3780 			return rc;
3781 
3782 		/* XXX can't re-enable */
3783 		if (tc->tc_state == STATE_CLOSED && !v)
3784 			tc->tc_state = STATE_DISABLED;
3785 
3786 		break;
3787 
3788 	case TCP_CRYPT_APP_SUPPORT:
3789 		if (set) {
3790 			if (tc->tc_state != STATE_CLOSED)
3791 				return -1;
3792 
3793 			return do_opt(set, &tc->tc_app_support,
3794 				      sizeof(tc->tc_app_support), val, len);
3795 		} else {
3796 			unsigned char *p = val;
3797 
3798 			if (!connected(tc))
3799 				return -1;
3800 
3801 			if (*len < (tc->tc_sid.s_len + 1))
3802 				return -1;
3803 
3804 			*p++ = (char) tc->tc_app_support;
3805 			memcpy(p, tc->tc_sid.s_data, tc->tc_sid.s_len);
3806 
3807 			*len = tc->tc_sid.s_len + 1;
3808 
3809 			return 0;
3810 		}
3811 
3812 	case TCP_CRYPT_NOCACHE:
3813 		if (tc->tc_state != STATE_CLOSED)
3814 			return -1;
3815 
3816 		return do_opt(set, &tc->tc_nocache, sizeof(tc->tc_nocache),
3817 			      val, len);
3818 
3819 	case TCP_CRYPT_CMODE:
3820 		if (tc->tc_state != STATE_CLOSED)
3821 			return -1;
3822 
3823 		switch (tc->tc_cmode) {
3824 		case CMODE_ALWAYS:
3825 		case CMODE_ALWAYS_NK:
3826 			v = 1;
3827 			break;
3828 
3829 		default:
3830 			v = 0;
3831 			break;
3832 		}
3833 
3834 		rc = do_opt(set, &v, sizeof(v), val, len);
3835 		if (rc)
3836 			return rc;
3837 
3838 		if (!set)
3839 			break;
3840 
3841 		if (v)
3842 			tc->tc_cmode = CMODE_ALWAYS;
3843 		else
3844 			tc->tc_cmode = CMODE_DEFAULT;
3845 
3846 		break;
3847 
3848 	case TCP_CRYPT_SESSID:
3849 		if (set)
3850 			return -1;
3851 
3852 		if (!connected(tc))
3853 			return -1;
3854 
3855 		return do_opt(set, tc->tc_sid.s_data, tc->tc_sid.s_len,
3856 			      val, len);
3857 
3858 	default:
3859 		return -1;
3860 	}
3861 
3862 	return 0;
3863 }
3864 
3865 int tcpcryptd_setsockopt(struct tcpcrypt_ctl *s, int opt, void *val,
3866 			 unsigned int len)
3867 {
3868 	struct tc *tc;
3869 
3870 	switch (opt) {
3871 	case TCP_CRYPT_RESET:
3872 		tc = sockopt_find(s);
3873 		if (!tc)
3874 			return -1;
3875 
3876 		tc_finish(tc);
3877 		put_tc(tc);
3878 		sockopt_clear(s->tcc_sport);
3879 
3880 		return 0;
3881 	}
3882 
3883 	tc = sockopt_get(s);
3884 	if (!tc)
3885 		return -1;
3886 
3887 	return do_sockopt(1, tc, opt, val, &len);
3888 }
3889 
3890 static int do_tcpcrypt_netstat(struct conn *c, void *val, unsigned int *len)
3891 {
3892 	struct tc_netstat *n = val;
3893 	int l = *len;
3894 	int copied = 0;
3895 	struct tc *tc;
3896 	int tl;
3897 
3898 	while (c) {
3899 		tc = c->c_tc;
3900 
3901 		if (!connected(tc))
3902 			goto __next;
3903 
3904 		if (tc->tc_tcp_state == TCPSTATE_LASTACK)
3905 			goto __next;
3906 
3907 		tl = sizeof(*n) + tc->tc_sid.s_len;
3908 		if (l < tl)
3909 			break;
3910 
3911 		n->tn_sip.s_addr = c->c_addr[0].sin_addr.s_addr;
3912 		n->tn_dip.s_addr = c->c_addr[1].sin_addr.s_addr;
3913 		n->tn_sport	 = c->c_addr[0].sin_port;
3914 		n->tn_dport	 = c->c_addr[1].sin_port;
3915 		n->tn_len	 = htons(tc->tc_sid.s_len);
3916 
3917 		if (_conf.cf_rdr) {
3918 			struct tc *peer = tc->tc_rdr_peer;
3919 
3920 			switch (peer->tc_rdr_state) {
3921 			case STATE_RDR_LOCAL:
3922 				n->tn_sip.s_addr = peer->tc_rdr_addr.sin_addr
3923 								.s_addr;
3924 				n->tn_sport = peer->tc_rdr_addr.sin_port;
3925 				break;
3926 
3927 			case STATE_RDR_REMOTE:
3928 				if (ntohs(n->tn_sport) == REDIRECT_PORT)
3929 					n->tn_sport = peer->tc_rdr_addr
3930 								.sin_port;
3931 				break;
3932 			}
3933 		}
3934 
3935 		memcpy(n->tn_sid, tc->tc_sid.s_data, tc->tc_sid.s_len);
3936 		n = (struct tc_netstat*) ((unsigned long) n + tl);
3937 		copied += tl;
3938 		l -= tl;
3939 __next:
3940 		c = c->c_next;
3941 	}
3942 
3943 	*len -= copied;
3944 
3945 	return copied;
3946 }
3947 
3948 /* XXX slow */
3949 static int tcpcrypt_netstat(void *val, unsigned int *len)
3950 {
3951 	int i;
3952 	int num = sizeof(_connection_map) / sizeof(*_connection_map);
3953 	struct conn *c;
3954 	int copied = 0;
3955 	unsigned char *v = val;
3956 
3957 	for (i = 0; i < num; i++) {
3958 		c = _connection_map[i];
3959 
3960 		if (!c)
3961 			continue;
3962 
3963 		copied += do_tcpcrypt_netstat(c->c_next, &v[copied], len);
3964 	}
3965 
3966 	*len = copied;
3967 
3968 	return 0;
3969 }
3970 
3971 int tcpcryptd_getsockopt(struct tcpcrypt_ctl *s, int opt, void *val,
3972 			 unsigned int *len)
3973 {
3974 	struct tc *tc;
3975 
3976 	switch (opt) {
3977 	case TCP_CRYPT_NETSTAT:
3978 		return tcpcrypt_netstat(val, len);
3979 	}
3980 
3981 	tc = sockopt_get(s);
3982 	if (!tc)
3983 		return -1;
3984 
3985 	return do_sockopt(0, tc, opt, val, len);
3986 }
3987 
/*
 * Preference of a cipher.  Not implemented: ops is ignored and the result
 * is always 0.
 */
static int get_pref(struct crypt_ops *ops)
{
	/* XXX implement */
	return 0;
}
3996 
3997 static void do_register_cipher(struct ciphers *c, struct cipher_list *cl)
3998 {
3999 	struct ciphers *x;
4000 	int pref = 0;
4001 
4002 	x = xmalloc(sizeof(*x));
4003 	memset(x, 0, sizeof(*x));
4004 	x->c_cipher = cl;
4005 
4006 	while (c->c_next) {
4007 		if (pref >= get_pref(NULL))
4008 			break;
4009 
4010 		c = c->c_next;
4011 	}
4012 
4013 	x->c_next  = c->c_next;
4014 	c->c_next  = x;
4015 }
4016 
4017 void tcpcrypt_register_cipher(struct cipher_list *c)
4018 {
4019 	int type = c->c_type;
4020 
4021 	switch (type) {
4022 	case TYPE_PKEY:
4023 		do_register_cipher(&_ciphers_pkey, c);
4024 		break;
4025 
4026 	case TYPE_SYM:
4027 		do_register_cipher(&_ciphers_sym, c);
4028 		break;
4029 
4030 	default:
4031 		assert(!"Unknown type");
4032 		break;
4033 	}
4034 }
4035 
4036 static void init_cipher(struct ciphers *c)
4037 {
4038 	struct crypt_pub *cp;
4039 	struct crypt_sym *cs;
4040 	uint8_t spec = c->c_cipher->c_id;
4041 
4042 	switch (c->c_cipher->c_type) {
4043 	case TYPE_PKEY:
4044 		c->c_speclen = 1;
4045 
4046 		cp = c->c_cipher->c_ctr();
4047 		crypt_pub_destroy(cp);
4048 		break;
4049 
4050 	case TYPE_SYM:
4051 		c->c_speclen = 1;
4052 
4053 		cs = crypt_new(c->c_cipher->c_ctr);
4054 		crypt_sym_destroy(cs);
4055 		break;
4056 
4057 	default:
4058 		assert(!"unknown type");
4059 		abort();
4060 	}
4061 
4062 	memcpy(c->c_spec,
4063 	       ((unsigned char*) &spec) + sizeof(spec) - c->c_speclen,
4064 	       c->c_speclen);
4065 }
4066 
4067 static void do_init_ciphers(struct ciphers *c)
4068 {
4069 	struct tc *tc = get_tc();
4070 	struct ciphers *prev = c;
4071 	struct ciphers *head = c;
4072 
4073 	c = c->c_next;
4074 
4075 	while (c) {
4076 		/* XXX */
4077 		if (TC_DUMMY != TC_DUMMY) {
4078 			if (!_conf.cf_dummy) {
4079 				/* kill dummy */
4080 				prev->c_next = c->c_next;
4081 				free(c);
4082 				c = prev->c_next;
4083 				continue;
4084 			} else {
4085 				/* leave all but dummy */
4086 				head->c_next = c;
4087 				c->c_next = NULL;
4088 				return;
4089 			}
4090 		} else if (!_conf.cf_dummy) {
4091 			/* standard path */
4092 			init_cipher(c);
4093 		}
4094 
4095 		prev = c;
4096 		c = c->c_next;
4097 	}
4098 
4099 	put_tc(tc);
4100 }
4101 
4102 static void init_ciphers(void)
4103 {
4104 	do_init_ciphers(&_ciphers_pkey);
4105 	do_init_ciphers(&_ciphers_sym);
4106 
4107 	do_add_ciphers(&_ciphers_pkey, &_pkey, &_pkey_len, sizeof(*_pkey),
4108 		       (uint8_t*) _pkey + sizeof(_pkey));
4109 	do_add_ciphers(&_ciphers_sym, &_sym, &_sym_len, sizeof(*_sym),
4110                        (uint8_t*) _sym + sizeof(_sym));
4111 }
4112 
4113 static void init_random(void)
4114 {
4115 	unsigned int seed = 0;
4116 	const char *path;
4117 	FILE *f;
4118 	size_t nread;
4119 
4120 #ifdef __WIN32__
4121 	seed = time(NULL);
4122 #else
4123 	path = _conf.cf_random_path;
4124 	if (path) {
4125 		if (!(f = fopen(path, "r"))) {
4126 			err(1, "Could not open random device %s", path);
4127 		}
4128 	}
4129 	else {
4130 		path = "/dev/urandom";
4131 		if (!(f = fopen(path, "r"))) {
4132 			path = "/dev/random";
4133 			if (!(f = fopen(path, "r"))) {
4134 				errx(1, "Could not find a random device");
4135 			}
4136 		}
4137 	}
4138 	if (f) {
4139 		xprintf(XP_ALWAYS, "Reading random seed from %s ", path);
4140 		nread = fread((void*) &seed, sizeof(seed), 1, f);
4141 		if (nread != 1) {
4142 			errx(1, "Could not read random seed from %s", path);
4143 		}
4144 		xprintf(XP_ALWAYS, "\n");
4145 	}
4146 #endif
4147 	if (seed) {
4148 		srand(seed);
4149 		xprintf(XP_DEBUG, "Random seed set to %u\n", seed);
4150 	} else {
4151 		errx(1, "Could not provide random seed");
4152 	}
4153 }
4154 
4155 static struct tc *lookup_connection_rdr(struct sockaddr_in *s_in)
4156 {
4157 	int i, j;
4158 	struct conn *c;
4159 
4160 	/* XXX data strcuture fail */
4161 	for (i = 0; i < sizeof(_connection_map) / sizeof(*_connection_map); i++)
4162 	{
4163 		c = _connection_map[i];
4164 		if (!c)
4165 			continue;
4166 
4167 		while ((c = c->c_next)) {
4168 			for (j = 0; j < 2; j++) {
4169 				struct sockaddr_in *s = &c->c_addr[j];
4170 
4171 				if (s->sin_addr.s_addr == s_in->sin_addr.s_addr
4172 				    && s->sin_port == s_in->sin_port) {
4173 					return c->c_tc;
4174 				}
4175 			}
4176 		}
4177 	}
4178 
4179 	return NULL;
4180 }
4181 
4182 static void redirect_listen_handler(struct fd *fd)
4183 {
4184         struct sockaddr_in s_in;
4185         socklen_t len = sizeof(s_in);
4186 	int dude;
4187 	struct tc *tc, *peer;
4188 
4189         /* Accept redirected connection */
4190         if ((dude = accept(fd->fd_fd, (struct sockaddr*) &s_in, &len)) == -1) {
4191                 xprintf(XP_ALWAYS, "accept() failed\n");
4192                 return;
4193         }
4194 
4195         /* try to find him */
4196 	tc = lookup_connection_rdr(&s_in);
4197 	if (!tc) {
4198                 xprintf(XP_ALWAYS, "Couldn't find dude %s:%d\n",
4199 			inet_ntoa(s_in.sin_addr), ntohs(s_in.sin_port));
4200                 close(dude);
4201                 return;
4202         }
4203 
4204 	peer = tc->tc_rdr_peer;
4205 	if (!peer) {
4206 		xprintf(XP_ALWAYS, "Redirected connection from %s:%d: tc %p has no peer; "
4207 				   "closing connection\n",
4208 			inet_ntoa(s_in.sin_addr), ntohs(s_in.sin_port), tc);
4209 		close(dude);
4210 		kill_rdr(tc);
4211 		return;
4212 	}
4213 
4214 	if (tc->tc_rdr_inbound) {
4215 		struct tc *tmp = peer;
4216 
4217 		peer = tc;
4218 		tc   = tmp;
4219 	}
4220 
4221 	/* XXX */
4222 	if (!peer->tc_rdr_fd) {
4223 		close(dude);
4224 		kill_rdr(peer);
4225 		return;
4226 	}
4227 
4228 	assert(peer);
4229 	assert(peer->tc_rdr_peer == tc);
4230 	assert(peer->tc_rdr_fd);
4231 	assert(!tc->tc_rdr_fd);
4232 
4233 	fd = add_fd(dude, rdr_local_handler);
4234 	fd->fd_priv   = tc;
4235 	tc->tc_rdr_fd = fd;
4236 
4237 	memcpy(&tc->tc_rdr_addr, &s_in, sizeof(tc->tc_rdr_addr));
4238 
4239         xprintf(XP_NOISY, "Redirect proxy accepted %s:%d",
4240                 inet_ntoa(tc->tc_rdr_addr.sin_addr),
4241 		ntohs(tc->tc_rdr_addr.sin_port));
4242 
4243         xprintf(XP_NOISY, "->%s:%d\n",
4244                 inet_ntoa(peer->tc_rdr_addr.sin_addr),
4245 		ntohs(peer->tc_rdr_addr.sin_port));
4246 
4247 	/* wake up peer */
4248 	if (peer->tc_rdr_fd->fd_state == FDS_IDLE)
4249 		peer->tc_rdr_fd->fd_state = FDS_READ;
4250 }
4251 
4252 static void init_rdr(void)
4253 {
4254         int s, one = 1;
4255         struct sockaddr_in s_in;
4256 
4257 	if (!_conf.cf_rdr)
4258 		return;
4259 
4260         if ((s = socket(PF_INET, SOCK_STREAM, IPPROTO_TCP)) == -1)
4261                 err(1, "socket()");
4262 
4263         if (setsockopt(s, SOL_SOCKET, SO_REUSEADDR, &one, sizeof(one)) == -1)
4264                 err(1, "setsockopt()");
4265 
4266         memset(&s_in, 0, sizeof(s_in));
4267 
4268         s_in.sin_family      = PF_INET;
4269         s_in.sin_port        = htons(REDIRECT_PORT);
4270         s_in.sin_addr.s_addr = INADDR_ANY;
4271 
4272         if (bind(s, (struct sockaddr*) &s_in, sizeof(s_in)) == -1)
4273                 err(1, "bind()");
4274 
4275         if (listen(s, 5) == -1)
4276                 err(1, "listen()");
4277 
4278         add_fd(s, redirect_listen_handler);
4279 }
4280 
4281 void tcpcrypt_init(void)
4282 {
4283 	init_random();
4284 	init_ciphers();
4285 	init_rdr();
4286 }
4287