xref: /openbsd/regress/lib/libssl/dtls/dtlstest.c (revision c9675a23)
1 /* $OpenBSD: dtlstest.c,v 1.18 2022/11/26 16:08:56 tb Exp $ */
2 /*
3  * Copyright (c) 2020, 2021 Joel Sing <jsing@openbsd.org>
4  *
5  * Permission to use, copy, modify, and distribute this software for any
6  * purpose with or without fee is hereby granted, provided that the above
7  * copyright notice and this permission notice appear in all copies.
8  *
9  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16  */
17 
#include <sys/socket.h>
#include <sys/time.h>
#include <netinet/in.h>

#include <err.h>
#include <errno.h>
#include <limits.h>
#include <poll.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>

#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/ssl.h>

#include "bio_local.h"
#include "ssl_local.h"
32 
/*
 * Verbosity for the packet monkey and I/O helpers: non-zero prints DEBUG
 * lines for packet events, values above 1 also hexdump packet contents.
 */
int debug = 0;

/*
 * Paths taken from the command line in main(). server_ca_file is assigned
 * but not otherwise read in this file.
 */
const char *server_key_file;
const char *server_cert_file;
const char *server_ca_file;

/* Most recently generated DTLS cookie, shared by the cookie callbacks. */
char dtls_cookie[32];

/*
 * Declared here rather than taken from a header (presumably libssl
 * internal); used to force the initial record-layer epoch for a test.
 */
struct tls12_record_layer;
void tls12_record_layer_set_initial_epoch(struct tls12_record_layer *rl,
    uint16_t epoch);
43 
/*
 * Write len bytes starting at buf to stderr as " 0xNN," tokens, eight per
 * line, and terminate a partial final line with a newline.
 *
 * buf is taken as an untyped pointer so that callers holding char, unsigned
 * char or uint8_t buffers can pass them without pointer-sign mismatches.
 */
static void
hexdump(const void *buf, size_t len)
{
	const unsigned char *p = buf;
	size_t i;

	for (i = 1; i <= len; i++)
		fprintf(stderr, " 0x%02hhx,%s", p[i - 1], i % 8 ? "" : "\n");

	if (len % 8)
		fprintf(stderr, "\n");
}
55 
/* Private BIO_ctrl() commands understood by the packet monkey BIO. */
enum {
	BIO_C_DELAY_COUNT = 1000,	/* countdown before a held packet is released (1-31) */
	BIO_C_DELAY_FLUSH = 1001,	/* write out any held packet immediately */
	BIO_C_DELAY_PACKET = 1002,	/* hold back the Nth subsequent write (1-31) */
	BIO_C_DROP_PACKET = 1003,	/* discard the Nth subsequent write (1-31) */
	BIO_C_DROP_RANDOM = 1004,	/* discard each write with probability 1/num (0 disables) */
};

/* Per-BIO state of the packet monkey filter. */
struct bio_packet_monkey_ctx {
	/* Writes still to pass before the held packet is sent ahead of one. */
	unsigned int delay_count;
	/* Bit N-1 set: hold back the Nth subsequent write. */
	unsigned int delay_mask;
	/* Non-zero: drop each write with probability 1/drop_rand. */
	unsigned int drop_rand;
	/* Bit N-1 set: drop the Nth subsequent write (unused while drop_rand is set). */
	unsigned int drop_mask;
	/* Heap copy of the packet being held back, or NULL. */
	uint8_t *delayed_msg;
	/* Length in bytes of delayed_msg. */
	size_t delayed_msg_len;
};
70 
71 static int
bio_packet_monkey_new(BIO * bio)72 bio_packet_monkey_new(BIO *bio)
73 {
74 	struct bio_packet_monkey_ctx *ctx;
75 
76 	if ((ctx = calloc(1, sizeof(*ctx))) == NULL)
77 		return 0;
78 
79 	bio->flags = 0;
80 	bio->init = 1;
81 	bio->num = 0;
82 	bio->ptr = ctx;
83 
84 	return 1;
85 }
86 
87 static int
bio_packet_monkey_free(BIO * bio)88 bio_packet_monkey_free(BIO *bio)
89 {
90 	struct bio_packet_monkey_ctx *ctx;
91 
92 	if (bio == NULL)
93 		return 1;
94 
95 	ctx = bio->ptr;
96 	free(ctx->delayed_msg);
97 	free(ctx);
98 
99 	return 1;
100 }
101 
102 static int
bio_packet_monkey_delay_flush(BIO * bio)103 bio_packet_monkey_delay_flush(BIO *bio)
104 {
105 	struct bio_packet_monkey_ctx *ctx = bio->ptr;
106 
107 	if (ctx->delayed_msg == NULL)
108 		return 1;
109 
110 	if (debug)
111 		fprintf(stderr, "DEBUG: flushing delayed packet...\n");
112 	if (debug > 1)
113 		hexdump(ctx->delayed_msg, ctx->delayed_msg_len);
114 
115 	BIO_write(bio->next_bio, ctx->delayed_msg, ctx->delayed_msg_len);
116 
117 	free(ctx->delayed_msg);
118 	ctx->delayed_msg = NULL;
119 
120 	return BIO_ctrl(bio->next_bio, BIO_CTRL_FLUSH, 0, NULL);
121 }
122 
123 static long
bio_packet_monkey_ctrl(BIO * bio,int cmd,long num,void * ptr)124 bio_packet_monkey_ctrl(BIO *bio, int cmd, long num, void *ptr)
125 {
126 	struct bio_packet_monkey_ctx *ctx;
127 
128 	ctx = bio->ptr;
129 
130 	switch (cmd) {
131 	case BIO_C_DELAY_COUNT:
132 		if (num < 1 || num > 31)
133 			return 0;
134 		ctx->delay_count = num;
135 		return 1;
136 
137 	case BIO_C_DELAY_FLUSH:
138 		return bio_packet_monkey_delay_flush(bio);
139 
140 	case BIO_C_DELAY_PACKET:
141 		if (num < 1 || num > 31)
142 			return 0;
143 		ctx->delay_mask |= 1 << ((unsigned int)num - 1);
144 		return 1;
145 
146 	case BIO_C_DROP_PACKET:
147 		if (num < 1 || num > 31)
148 			return 0;
149 		ctx->drop_mask |= 1 << ((unsigned int)num - 1);
150 		return 1;
151 
152 	case BIO_C_DROP_RANDOM:
153 		if (num < 0 || (size_t)num > UINT_MAX)
154 			return 0;
155 		ctx->drop_rand = (unsigned int)num;
156 		return 1;
157 	}
158 
159 	if (bio->next_bio == NULL)
160 		return 0;
161 
162 	return BIO_ctrl(bio->next_bio, cmd, num, ptr);
163 }
164 
165 static int
bio_packet_monkey_read(BIO * bio,char * out,int out_len)166 bio_packet_monkey_read(BIO *bio, char *out, int out_len)
167 {
168 	struct bio_packet_monkey_ctx *ctx = bio->ptr;
169 	int ret;
170 
171 	if (ctx == NULL || bio->next_bio == NULL)
172 		return 0;
173 
174 	ret = BIO_read(bio->next_bio, out, out_len);
175 
176 	if (ret > 0) {
177 		if (debug)
178 			fprintf(stderr, "DEBUG: read packet...\n");
179 		if (debug > 1)
180 			hexdump(out, ret);
181 	}
182 
183 	BIO_clear_retry_flags(bio);
184 	if (ret <= 0 && BIO_should_retry(bio->next_bio))
185 		BIO_set_retry_read(bio);
186 
187 	return ret;
188 }
189 
190 static int
bio_packet_monkey_write(BIO * bio,const char * in,int in_len)191 bio_packet_monkey_write(BIO *bio, const char *in, int in_len)
192 {
193 	struct bio_packet_monkey_ctx *ctx = bio->ptr;
194 	const char *label = "writing";
195 	int delay = 0, drop = 0;
196 	int ret;
197 
198 	if (ctx == NULL || bio->next_bio == NULL)
199 		return 0;
200 
201 	if (ctx->delayed_msg != NULL && ctx->delay_count > 0)
202 		ctx->delay_count--;
203 
204 	if (ctx->delayed_msg != NULL && ctx->delay_count == 0) {
205 		if (debug)
206 			fprintf(stderr, "DEBUG: writing delayed packet...\n");
207 		if (debug > 1)
208 			hexdump(ctx->delayed_msg, ctx->delayed_msg_len);
209 
210 		ret = BIO_write(bio->next_bio, ctx->delayed_msg,
211 		    ctx->delayed_msg_len);
212 
213 		BIO_clear_retry_flags(bio);
214 		if (ret <= 0 && BIO_should_retry(bio->next_bio)) {
215 			BIO_set_retry_write(bio);
216 			return (ret);
217 		}
218 
219 		free(ctx->delayed_msg);
220 		ctx->delayed_msg = NULL;
221 	}
222 
223 	if (ctx->delay_mask > 0) {
224 		delay = ctx->delay_mask & 1;
225 		ctx->delay_mask >>= 1;
226 	}
227 	if (ctx->drop_rand > 0) {
228 		drop = arc4random_uniform(ctx->drop_rand) == 0;
229 	} else if (ctx->drop_mask > 0) {
230 		drop = ctx->drop_mask & 1;
231 		ctx->drop_mask >>= 1;
232 	}
233 
234 	if (delay)
235 		label = "delaying";
236 	if (drop)
237 		label = "dropping";
238 	if (debug)
239 		fprintf(stderr, "DEBUG: %s packet...\n", label);
240 	if (debug > 1)
241 		hexdump(in, in_len);
242 
243 	if (drop)
244 		return in_len;
245 
246 	if (delay) {
247 		if (ctx->delayed_msg != NULL)
248 			return 0;
249 		if ((ctx->delayed_msg = calloc(1, in_len)) == NULL)
250 			return 0;
251 		memcpy(ctx->delayed_msg, in, in_len);
252 		ctx->delayed_msg_len = in_len;
253 		return in_len;
254 	}
255 
256 	ret = BIO_write(bio->next_bio, in, in_len);
257 
258 	BIO_clear_retry_flags(bio);
259 	if (ret <= 0 && BIO_should_retry(bio->next_bio))
260 		BIO_set_retry_write(bio);
261 
262 	return ret;
263 }
264 
265 static int
bio_packet_monkey_puts(BIO * bio,const char * str)266 bio_packet_monkey_puts(BIO *bio, const char *str)
267 {
268 	return bio_packet_monkey_write(bio, str, strlen(str));
269 }
270 
271 static const BIO_METHOD bio_packet_monkey = {
272 	.type = BIO_TYPE_BUFFER,
273 	.name = "packet monkey",
274 	.bread = bio_packet_monkey_read,
275 	.bwrite = bio_packet_monkey_write,
276 	.bputs = bio_packet_monkey_puts,
277 	.ctrl = bio_packet_monkey_ctrl,
278 	.create = bio_packet_monkey_new,
279 	.destroy = bio_packet_monkey_free
280 };
281 
282 static const BIO_METHOD *
BIO_f_packet_monkey(void)283 BIO_f_packet_monkey(void)
284 {
285 	return &bio_packet_monkey;
286 }
287 
288 static BIO *
BIO_new_packet_monkey(void)289 BIO_new_packet_monkey(void)
290 {
291 	return BIO_new(BIO_f_packet_monkey());
292 }
293 
294 static int
BIO_packet_monkey_delay(BIO * bio,int num,int count)295 BIO_packet_monkey_delay(BIO *bio, int num, int count)
296 {
297 	if (!BIO_ctrl(bio, BIO_C_DELAY_COUNT, count, NULL))
298 		return 0;
299 
300 	return BIO_ctrl(bio, BIO_C_DELAY_PACKET, num, NULL);
301 }
302 
303 static int
BIO_packet_monkey_delay_flush(BIO * bio)304 BIO_packet_monkey_delay_flush(BIO *bio)
305 {
306 	return BIO_ctrl(bio, BIO_C_DELAY_FLUSH, 0, NULL);
307 }
308 
309 static int
BIO_packet_monkey_drop(BIO * bio,int num)310 BIO_packet_monkey_drop(BIO *bio, int num)
311 {
312 	return BIO_ctrl(bio, BIO_C_DROP_PACKET, num, NULL);
313 }
314 
#if 0
/*
 * Drop each written packet with probability 1/num (0 disables random
 * drops). Presumably compiled out because no test in dtls_tests uses it,
 * and an unused static function would draw a warning.
 */
static int
BIO_packet_monkey_drop_random(BIO *bio, int num)
{
	long ret;

	ret = BIO_ctrl(bio, BIO_C_DROP_RANDOM, num, NULL);

	return ret;
}
#endif
322 
/*
 * Create two UDP sockets on the loopback interface:
 * - a server socket bound to a kernel-chosen port;
 * - a client socket connected to that server address.
 *
 * The descriptors are stored in *client_sock and *server_sock, and the
 * server's bound address is copied to *server_sin. Any failure terminates
 * the process via err(). Otherwise 1 is returned.
 */
static int
datagram_pair(int *client_sock, int *server_sock,
    struct sockaddr_in *server_sin)
{
	struct sockaddr_in addr;
	socklen_t addr_len = sizeof(addr);
	int srv, cli;

	memset(&addr, 0, sizeof(addr));
	addr.sin_family = AF_INET;
	addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
	addr.sin_port = 0;	/* let the kernel pick a free port */

	srv = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
	if (srv == -1)
		err(1, "server socket");
	if (bind(srv, (struct sockaddr *)&addr, sizeof(addr)) == -1)
		err(1, "server bind");
	/* Learn which port was actually assigned. */
	if (getsockname(srv, (struct sockaddr *)&addr, &addr_len) == -1)
		err(1, "server getsockname");

	cli = socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP);
	if (cli == -1)
		err(1, "client socket");
	if (connect(cli, (struct sockaddr *)&addr, sizeof(addr)) == -1)
		err(1, "client connect");

	*client_sock = cli;
	*server_sock = srv;
	*server_sin = addr;

	return 1;
}
355 
356 static int
poll_timeout(SSL * client,SSL * server)357 poll_timeout(SSL *client, SSL *server)
358 {
359 	int client_timeout = 0, server_timeout = 0;
360 	struct timeval timeout;
361 
362 	if (DTLSv1_get_timeout(client, &timeout))
363 		client_timeout = timeout.tv_sec * 1000 + timeout.tv_usec / 1000;
364 
365 	if (DTLSv1_get_timeout(server, &timeout))
366 		server_timeout = timeout.tv_sec * 1000 + timeout.tv_usec / 1000;
367 
368 	if (client_timeout < 10)
369 		client_timeout = 10;
370 	if (server_timeout < 10)
371 		server_timeout = 10;
372 
373 	/* XXX */
374 	if (client_timeout <= 0)
375 		return server_timeout;
376 	if (client_timeout > 0 && server_timeout <= 0)
377 		return client_timeout;
378 	if (client_timeout < server_timeout)
379 		return client_timeout;
380 
381 	return server_timeout;
382 }
383 
384 static int
dtls_cookie_generate(SSL * ssl,unsigned char * cookie,unsigned int * cookie_len)385 dtls_cookie_generate(SSL *ssl, unsigned char *cookie,
386     unsigned int *cookie_len)
387 {
388 	arc4random_buf(dtls_cookie, sizeof(dtls_cookie));
389 	memcpy(cookie, dtls_cookie, sizeof(dtls_cookie));
390 	*cookie_len = sizeof(dtls_cookie);
391 
392 	return 1;
393 }
394 
395 static int
dtls_cookie_verify(SSL * ssl,const unsigned char * cookie,unsigned int cookie_len)396 dtls_cookie_verify(SSL *ssl, const unsigned char *cookie,
397     unsigned int cookie_len)
398 {
399 	return cookie_len == sizeof(dtls_cookie) &&
400 	    memcmp(cookie, dtls_cookie, sizeof(dtls_cookie)) == 0;
401 }
402 
403 static void
dtls_info_callback(const SSL * ssl,int type,int val)404 dtls_info_callback(const SSL *ssl, int type, int val)
405 {
406 	/*
407 	 * Squeals ahead... remove the bbio from the info callback, so we can
408 	 * drop specific messages. Ideally this would be an option for the SSL.
409 	 */
410 	if (ssl->wbio == ssl->bbio)
411 		((SSL *)ssl)->wbio = BIO_pop(ssl->wbio);
412 }
413 
414 static SSL *
dtls_client(int sock,struct sockaddr_in * server_sin,long mtu)415 dtls_client(int sock, struct sockaddr_in *server_sin, long mtu)
416 {
417 	SSL_CTX *ssl_ctx = NULL;
418 	SSL *ssl = NULL;
419 	BIO *bio = NULL;
420 
421 	if ((bio = BIO_new_dgram(sock, BIO_NOCLOSE)) == NULL)
422 		errx(1, "client bio");
423 	if (!BIO_socket_nbio(sock, 1))
424 		errx(1, "client nbio");
425 	if (!BIO_ctrl_set_connected(bio, 1, server_sin))
426 		errx(1, "client set connected");
427 
428 	if ((ssl_ctx = SSL_CTX_new(DTLS_method())) == NULL)
429 		errx(1, "client context");
430 
431 	if ((ssl = SSL_new(ssl_ctx)) == NULL)
432 		errx(1, "client ssl");
433 
434 	SSL_set_bio(ssl, bio, bio);
435 	bio = NULL;
436 
437 	if (mtu > 0) {
438 		SSL_set_options(ssl, SSL_OP_NO_QUERY_MTU);
439 		SSL_set_mtu(ssl, mtu);
440 	}
441 
442 	SSL_CTX_free(ssl_ctx);
443 	BIO_free(bio);
444 
445 	return ssl;
446 }
447 
448 static SSL *
dtls_server(int sock,long options,long mtu)449 dtls_server(int sock, long options, long mtu)
450 {
451 	SSL_CTX *ssl_ctx = NULL;
452 	SSL *ssl = NULL;
453 	BIO *bio = NULL;
454 
455 	if ((bio = BIO_new_dgram(sock, BIO_NOCLOSE)) == NULL)
456 		errx(1, "server bio");
457 	if (!BIO_socket_nbio(sock, 1))
458 		errx(1, "server nbio");
459 
460 	if ((ssl_ctx = SSL_CTX_new(DTLS_method())) == NULL)
461 		errx(1, "server context");
462 
463 	SSL_CTX_set_cookie_generate_cb(ssl_ctx, dtls_cookie_generate);
464 	SSL_CTX_set_cookie_verify_cb(ssl_ctx, dtls_cookie_verify);
465 	SSL_CTX_set_dh_auto(ssl_ctx, 2);
466 	SSL_CTX_set_options(ssl_ctx, options);
467 
468 	if (SSL_CTX_use_certificate_chain_file(ssl_ctx, server_cert_file) != 1) {
469 		fprintf(stderr, "FAIL: Failed to load server certificate");
470 		goto failure;
471 	}
472 	if (SSL_CTX_use_PrivateKey_file(ssl_ctx, server_key_file,
473 	    SSL_FILETYPE_PEM) != 1) {
474 		fprintf(stderr, "FAIL: Failed to load server private key");
475 		goto failure;
476 	}
477 
478 	if ((ssl = SSL_new(ssl_ctx)) == NULL)
479 		errx(1, "server ssl");
480 
481 	if (SSL_use_certificate_chain_file(ssl, server_cert_file) != 1) {
482 		fprintf(stderr, "FAIL: Failed to load server certificate");
483 		goto failure;
484 	}
485 	SSL_set_bio(ssl, bio, bio);
486 	bio = NULL;
487 
488 	if (mtu > 0) {
489 		SSL_set_options(ssl, SSL_OP_NO_QUERY_MTU);
490 		SSL_set_mtu(ssl, mtu);
491 	}
492 
493  failure:
494 	SSL_CTX_free(ssl_ctx);
495 	BIO_free(bio);
496 
497 	return ssl;
498 }
499 
500 static int
ssl_error(SSL * ssl,const char * name,const char * desc,int ssl_ret,short * events)501 ssl_error(SSL *ssl, const char *name, const char *desc, int ssl_ret,
502     short *events)
503 {
504 	int ssl_err;
505 
506 	ssl_err = SSL_get_error(ssl, ssl_ret);
507 
508 	if (ssl_err == SSL_ERROR_WANT_READ) {
509 		*events = POLLIN;
510 	} else if (ssl_err == SSL_ERROR_WANT_WRITE) {
511 		*events = POLLOUT;
512 	} else if (ssl_err == SSL_ERROR_SYSCALL && errno == 0) {
513 		/* Yup, this is apparently a thing... */
514 	} else {
515 		fprintf(stderr, "FAIL: %s %s failed - ssl err = %d, errno = %d\n",
516 		    name, desc, ssl_err, errno);
517 		ERR_print_errors_fp(stderr);
518 		return 0;
519 	}
520 
521 	return 1;
522 }
523 
524 static int
do_connect(SSL * ssl,const char * name,int * done,short * events)525 do_connect(SSL *ssl, const char *name, int *done, short *events)
526 {
527 	int ssl_ret;
528 
529 	if ((ssl_ret = SSL_connect(ssl)) != 1)
530 		return ssl_error(ssl, name, "connect", ssl_ret, events);
531 
532 	fprintf(stderr, "INFO: %s connect done\n", name);
533 	*done = 1;
534 
535 	return 1;
536 }
537 
538 static int
do_connect_read(SSL * ssl,const char * name,int * done,short * events)539 do_connect_read(SSL *ssl, const char *name, int *done, short *events)
540 {
541 	uint8_t buf[2048];
542 	int ssl_ret;
543 	int i;
544 
545 	if ((ssl_ret = SSL_connect(ssl)) != 1)
546 		return ssl_error(ssl, name, "connect", ssl_ret, events);
547 
548 	fprintf(stderr, "INFO: %s connect done\n", name);
549 	*done = 1;
550 
551 	for (i = 0; i < 3; i++) {
552 		fprintf(stderr, "INFO: %s reading after connect\n", name);
553 		if ((ssl_ret = SSL_read(ssl, buf, sizeof(buf))) != 3) {
554 			fprintf(stderr, "ERROR: %s read failed\n", name);
555 			return 0;
556 		}
557 	}
558 
559 	return 1;
560 }
561 
562 static int
do_connect_shutdown(SSL * ssl,const char * name,int * done,short * events)563 do_connect_shutdown(SSL *ssl, const char *name, int *done, short *events)
564 {
565 	uint8_t buf[2048];
566 	int ssl_ret;
567 
568 	if ((ssl_ret = SSL_connect(ssl)) != 1)
569 		return ssl_error(ssl, name, "connect", ssl_ret, events);
570 
571 	fprintf(stderr, "INFO: %s connect done\n", name);
572 	*done = 1;
573 
574 	ssl_ret = SSL_read(ssl, buf, sizeof(buf));
575 	if (SSL_get_error(ssl, ssl_ret) != SSL_ERROR_ZERO_RETURN) {
576 		fprintf(stderr, "FAIL: %s did not receive close-notify\n", name);
577 		return 0;
578 	}
579 
580 	fprintf(stderr, "INFO: %s received close-notify\n", name);
581 
582 	return 1;
583 }
584 
585 static int
do_accept(SSL * ssl,const char * name,int * done,short * events)586 do_accept(SSL *ssl, const char *name, int *done, short *events)
587 {
588 	int ssl_ret;
589 
590 	if ((ssl_ret = SSL_accept(ssl)) != 1)
591 		return ssl_error(ssl, name, "accept", ssl_ret, events);
592 
593 	fprintf(stderr, "INFO: %s accept done\n", name);
594 	*done = 1;
595 
596 	return 1;
597 }
598 
599 static int
do_accept_write(SSL * ssl,const char * name,int * done,short * events)600 do_accept_write(SSL *ssl, const char *name, int *done, short *events)
601 {
602 	int ssl_ret;
603 	BIO *bio;
604 	int i;
605 
606 	if ((ssl_ret = SSL_accept(ssl)) != 1)
607 		return ssl_error(ssl, name, "accept", ssl_ret, events);
608 
609 	fprintf(stderr, "INFO: %s accept done\n", name);
610 
611 	for (i = 0; i < 3; i++) {
612 		fprintf(stderr, "INFO: %s writing after accept\n", name);
613 		if ((ssl_ret = SSL_write(ssl, "abc", 3)) != 3) {
614 			fprintf(stderr, "ERROR: %s write failed\n", name);
615 			return 0;
616 		}
617 	}
618 
619 	if ((bio = SSL_get_wbio(ssl)) == NULL)
620 		errx(1, "SSL has NULL bio");
621 
622 	/* Flush any delayed packets. */
623 	BIO_packet_monkey_delay_flush(bio);
624 
625 	*done = 1;
626 	return 1;
627 }
628 
629 static int
do_accept_shutdown(SSL * ssl,const char * name,int * done,short * events)630 do_accept_shutdown(SSL *ssl, const char *name, int *done, short *events)
631 {
632 	int ssl_ret;
633 	BIO *bio;
634 
635 	if ((ssl_ret = SSL_accept(ssl)) != 1)
636 		return ssl_error(ssl, name, "accept", ssl_ret, events);
637 
638 	fprintf(stderr, "INFO: %s accept done\n", name);
639 
640 	SSL_shutdown(ssl);
641 
642 	if ((bio = SSL_get_wbio(ssl)) == NULL)
643 		errx(1, "SSL has NULL bio");
644 
645 	/* Flush any delayed packets. */
646 	BIO_packet_monkey_delay_flush(bio);
647 
648 	*done = 1;
649 	return 1;
650 }
651 
652 static int
do_read(SSL * ssl,const char * name,int * done,short * events)653 do_read(SSL *ssl, const char *name, int *done, short *events)
654 {
655 	uint8_t buf[512];
656 	int ssl_ret;
657 
658 	if ((ssl_ret = SSL_read(ssl, buf, sizeof(buf))) > 0) {
659 		fprintf(stderr, "INFO: %s read done\n", name);
660 		if (debug > 1)
661 			hexdump(buf, ssl_ret);
662 		*done = 1;
663 		return 1;
664 	}
665 
666 	return ssl_error(ssl, name, "read", ssl_ret, events);
667 }
668 
669 static int
do_write(SSL * ssl,const char * name,int * done,short * events)670 do_write(SSL *ssl, const char *name, int *done, short *events)
671 {
672 	const uint8_t buf[] = "Hello, World!\n";
673 	int ssl_ret;
674 
675 	if ((ssl_ret = SSL_write(ssl, buf, sizeof(buf))) > 0) {
676 		fprintf(stderr, "INFO: %s write done\n", name);
677 		*done = 1;
678 		return 1;
679 	}
680 
681 	return ssl_error(ssl, name, "write", ssl_ret, events);
682 }
683 
684 static int
do_shutdown(SSL * ssl,const char * name,int * done,short * events)685 do_shutdown(SSL *ssl, const char *name, int *done, short *events)
686 {
687 	int ssl_ret;
688 
689 	ssl_ret = SSL_shutdown(ssl);
690 	if (ssl_ret == 1) {
691 		fprintf(stderr, "INFO: %s shutdown done\n", name);
692 		*done = 1;
693 		return 1;
694 	}
695 	return ssl_error(ssl, name, "shutdown", ssl_ret, events);
696 }
697 
698 typedef int (ssl_func)(SSL *ssl, const char *name, int *done, short *events);
699 
700 static int
do_client_server_loop(SSL * client,ssl_func * client_func,SSL * server,ssl_func * server_func,struct pollfd pfd[2])701 do_client_server_loop(SSL *client, ssl_func *client_func, SSL *server,
702     ssl_func *server_func, struct pollfd pfd[2])
703 {
704 	int client_done = 0, server_done = 0;
705 	int i = 0;
706 
707 	pfd[0].revents = POLLIN;
708 	pfd[1].revents = POLLIN;
709 
710 	do {
711 		if (!client_done) {
712 			if (debug)
713 				fprintf(stderr, "DEBUG: client loop\n");
714 			if (DTLSv1_handle_timeout(client) > 0)
715 				fprintf(stderr, "INFO: client timeout\n");
716 			if (!client_func(client, "client", &client_done,
717 			    &pfd[0].events))
718 				return 0;
719 			if (client_done)
720 				pfd[0].events = 0;
721 		}
722 		if (!server_done) {
723 			if (debug)
724 				fprintf(stderr, "DEBUG: server loop\n");
725 			if (DTLSv1_handle_timeout(server) > 0)
726 				fprintf(stderr, "INFO: server timeout\n");
727 			if (!server_func(server, "server", &server_done,
728 			    &pfd[1].events))
729 				return 0;
730 			if (server_done)
731 				pfd[1].events = 0;
732 		}
733 		if (poll(pfd, 2, poll_timeout(client, server)) == -1)
734 			err(1, "poll");
735 
736 	} while (i++ < 100 && (!client_done || !server_done));
737 
738 	if (!client_done || !server_done)
739 		fprintf(stderr, "FAIL: gave up\n");
740 
741 	return client_done && server_done;
742 }
743 
/* Capacity of each per-test delay and drop list. */
#define MAX_PACKET_DELAYS 32
#define MAX_PACKET_DROPS 32

/*
 * Hold back written packet number `packet` (1-based count of writes
 * through the packet monkey). The held packet is sent just before the
 * count-th write that follows it, or when explicitly flushed. A zero
 * packet number ends a delay list.
 */
struct dtls_delay {
	uint8_t packet;
	uint8_t count;
};

/*
 * One DTLS test case. Fields that are left unset are zero.
 *
 * desc is a plain char string: it is initialised from string literals and
 * printed with %s, and unsigned char caused pointer-sign mismatches.
 */
struct dtls_test {
	const char *desc;		/* name printed when the test runs */
	long mtu;			/* if > 0: disable MTU queries, fix the MTU */
	long ssl_options;		/* options for the server SSL_CTX */
	int client_bbio_off;		/* strip the client's buffering BIO */
	int server_bbio_off;		/* strip the server's buffering BIO */
	uint16_t initial_epoch;		/* initial record-layer epoch, both sides */
	int write_after_accept;		/* server writes data right after accept */
	int shutdown_after_accept;	/* server shuts down right after accept */
	struct dtls_delay client_delays[MAX_PACKET_DELAYS];
	struct dtls_delay server_delays[MAX_PACKET_DELAYS];
	/* 1-based packet numbers to drop; a zero entry ends the list. */
	uint8_t client_drops[MAX_PACKET_DROPS];
	uint8_t server_drops[MAX_PACKET_DROPS];
};
766 
767 static const struct dtls_test dtls_tests[] = {
768 	{
769 		.desc = "DTLS without cookies",
770 		.ssl_options = 0,
771 	},
772 	{
773 		.desc = "DTLS without cookies (initial epoch 0xfffe)",
774 		.ssl_options = 0,
775 		.initial_epoch = 0xfffe,
776 	},
777 	{
778 		.desc = "DTLS without cookies (initial epoch 0xffff)",
779 		.ssl_options = 0,
780 		.initial_epoch = 0xffff,
781 	},
782 	{
783 		.desc = "DTLS with cookies",
784 		.ssl_options = SSL_OP_COOKIE_EXCHANGE,
785 	},
786 	{
787 		.desc = "DTLS with low MTU",
788 		.mtu = 256,
789 		.ssl_options = 0,
790 	},
791 	{
792 		.desc = "DTLS with low MTU and cookies",
793 		.mtu = 256,
794 		.ssl_options = SSL_OP_COOKIE_EXCHANGE,
795 	},
796 	{
797 		.desc = "DTLS with dropped server response",
798 		.ssl_options = 0,
799 		.server_drops = { 1 },
800 	},
801 	{
802 		.desc = "DTLS with two dropped server responses",
803 		.ssl_options = 0,
804 		.server_drops = { 1, 2 },
805 	},
806 	{
807 		.desc = "DTLS with dropped ServerHello",
808 		.ssl_options = SSL_OP_NO_TICKET,
809 		.server_bbio_off = 1,
810 		.server_drops = { 1 },
811 	},
812 	{
813 		.desc = "DTLS with dropped server Certificate",
814 		.ssl_options = SSL_OP_NO_TICKET,
815 		.server_bbio_off = 1,
816 		.server_drops = { 2 },
817 	},
818 	{
819 		.desc = "DTLS with dropped ServerKeyExchange",
820 		.ssl_options = SSL_OP_NO_TICKET,
821 		.server_bbio_off = 1,
822 		.server_drops = { 3 },
823 	},
824 	{
825 		.desc = "DTLS with dropped ServerHelloDone",
826 		.ssl_options = SSL_OP_NO_TICKET,
827 		.server_bbio_off = 1,
828 		.server_drops = { 4 },
829 	},
830 #if 0
831 	/*
832 	 * These two result in the server accept completing and the
833 	 * client looping on a timeout. Presumably the server should not
834 	 * complete until the client Finished is received... this due to
835 	 * a flaw in the DTLSv1.0 specification, which is addressed in
836 	 * DTLSv1.2 (see references to "last flight" in RFC 6347 section
837 	 * 4.2.4). Our DTLS server code still needs to support this.
838 	 */
839 	{
840 		.desc = "DTLS with dropped server CCS",
841 		.ssl_options = 0,
842 		.server_bbio_off = 1,
843 		.server_drops = { 5 },
844 	},
845 	{
846 		.desc = "DTLS with dropped server Finished",
847 		.ssl_options = 0,
848 		.server_bbio_off = 1,
849 		.server_drops = { 6 },
850 	},
851 #endif
852 	{
853 		.desc = "DTLS with dropped ClientKeyExchange",
854 		.ssl_options = 0,
855 		.client_bbio_off = 1,
856 		.client_drops = { 2 },
857 	},
858 	{
859 		.desc = "DTLS with dropped client CCS",
860 		.ssl_options = 0,
861 		.client_bbio_off = 1,
862 		.client_drops = { 3 },
863 	},
864 	{
865 		.desc = "DTLS with dropped client Finished",
866 		.ssl_options = 0,
867 		.client_bbio_off = 1,
868 		.client_drops = { 4 },
869 	},
870 	{
871 		/* Send CCS after client Finished. */
872 		.desc = "DTLS with delayed client CCS",
873 		.ssl_options = 0,
874 		.client_bbio_off = 1,
875 		.client_delays = { { 3, 2 } },
876 	},
877 	{
878 		/*
879 		 * Send CCS after server Finished - note app data will be
880 		 * dropped if we send the CCS after app data.
881 		 */
882 		.desc = "DTLS with delayed server CCS",
883 		.ssl_options = SSL_OP_NO_TICKET,
884 		.server_bbio_off = 1,
885 		.server_delays = { { 5, 2 } },
886 		.write_after_accept = 1,
887 	},
888 	{
889 		.desc = "DTLS with delayed server CCS (initial epoch 0xfffe)",
890 		.ssl_options = SSL_OP_NO_TICKET,
891 		.server_bbio_off = 1,
892 		.initial_epoch = 0xfffe,
893 		.server_delays = { { 5, 2 } },
894 		.write_after_accept = 1,
895 	},
896 	{
897 		.desc = "DTLS with delayed server CCS (initial epoch 0xffff)",
898 		.ssl_options = SSL_OP_NO_TICKET,
899 		.server_bbio_off = 1,
900 		.initial_epoch = 0xffff,
901 		.server_delays = { { 5, 2 } },
902 		.write_after_accept = 1,
903 	},
904 	{
905 		/* Send Finished after app data - this is currently buffered. */
906 		.desc = "DTLS with delayed server Finished",
907 		.ssl_options = SSL_OP_NO_TICKET,
908 		.server_bbio_off = 1,
909 		.server_delays = { { 6, 3 } },
910 		.write_after_accept = 1,
911 	},
912 	{
913 		/* Send CCS after server finished and close-notify. */
914 		.desc = "DTLS with delayed server CCS (close-notify)",
915 		.ssl_options = SSL_OP_NO_TICKET,
916 		.server_bbio_off = 1,
917 		.server_delays = { { 5, 3 } },
918 		.shutdown_after_accept = 1,
919 	},
920 };
921 
922 #define N_DTLS_TESTS (sizeof(dtls_tests) / sizeof(*dtls_tests))
923 
924 static void
dtlstest_packet_monkey(SSL * ssl,const struct dtls_delay delays[],const uint8_t drops[])925 dtlstest_packet_monkey(SSL *ssl, const struct dtls_delay delays[],
926     const uint8_t drops[])
927 {
928 	BIO *bio_monkey;
929 	BIO *bio;
930 	int i;
931 
932 	if ((bio_monkey = BIO_new_packet_monkey()) == NULL)
933 		errx(1, "packet monkey");
934 
935 	for (i = 0; i < MAX_PACKET_DELAYS; i++) {
936 		if (delays[i].packet == 0)
937 			break;
938 		if (!BIO_packet_monkey_delay(bio_monkey, delays[i].packet,
939 		    delays[i].count))
940 			errx(1, "delay failure");
941 	}
942 
943 	for (i = 0; i < MAX_PACKET_DROPS; i++) {
944 		if (drops[i] == 0)
945 			break;
946 		if (!BIO_packet_monkey_drop(bio_monkey, drops[i]))
947 			errx(1, "drop failure");
948 	}
949 
950 	if ((bio = SSL_get_wbio(ssl)) == NULL)
951 		errx(1, "SSL has NULL bio");
952 
953 	BIO_up_ref(bio);
954 	bio = BIO_push(bio_monkey, bio);
955 
956 	SSL_set_bio(ssl, bio, bio);
957 }
958 
959 static int
dtlstest(const struct dtls_test * dt)960 dtlstest(const struct dtls_test *dt)
961 {
962 	SSL *client = NULL, *server = NULL;
963 	ssl_func *connect_func, *accept_func;
964 	struct sockaddr_in server_sin;
965 	struct pollfd pfd[2];
966 	int client_sock = -1;
967 	int server_sock = -1;
968 	int failed = 1;
969 
970 	fprintf(stderr, "\n== Testing %s... ==\n", dt->desc);
971 
972 	if (!datagram_pair(&client_sock, &server_sock, &server_sin))
973 		goto failure;
974 
975 	if ((client = dtls_client(client_sock, &server_sin, dt->mtu)) == NULL)
976 		goto failure;
977 
978 	if ((server = dtls_server(server_sock, dt->ssl_options, dt->mtu)) == NULL)
979 		goto failure;
980 
981 	tls12_record_layer_set_initial_epoch(client->rl, dt->initial_epoch);
982 	tls12_record_layer_set_initial_epoch(server->rl, dt->initial_epoch);
983 
984 	if (dt->client_bbio_off)
985 		SSL_set_info_callback(client, dtls_info_callback);
986 	if (dt->server_bbio_off)
987 		SSL_set_info_callback(server, dtls_info_callback);
988 
989 	dtlstest_packet_monkey(client, dt->client_delays, dt->client_drops);
990 	dtlstest_packet_monkey(server, dt->server_delays, dt->server_drops);
991 
992 	pfd[0].fd = client_sock;
993 	pfd[0].events = POLLOUT;
994 	pfd[1].fd = server_sock;
995 	pfd[1].events = POLLIN;
996 
997 	accept_func = do_accept;
998 	connect_func = do_connect;
999 
1000 	if (dt->write_after_accept) {
1001 		accept_func = do_accept_write;
1002 		connect_func = do_connect_read;
1003 	} else if (dt->shutdown_after_accept) {
1004 		accept_func = do_accept_shutdown;
1005 		connect_func = do_connect_shutdown;
1006 	}
1007 
1008 	if (!do_client_server_loop(client, connect_func, server, accept_func, pfd)) {
1009 		fprintf(stderr, "FAIL: client and server handshake failed\n");
1010 		goto failure;
1011 	}
1012 
1013 	if (dt->write_after_accept || dt->shutdown_after_accept)
1014 		goto done;
1015 
1016 	pfd[0].events = POLLIN;
1017 	pfd[1].events = POLLOUT;
1018 
1019 	if (!do_client_server_loop(client, do_read, server, do_write, pfd)) {
1020 		fprintf(stderr, "FAIL: client read and server write I/O failed\n");
1021 		goto failure;
1022 	}
1023 
1024 	pfd[0].events = POLLOUT;
1025 	pfd[1].events = POLLIN;
1026 
1027 	if (!do_client_server_loop(client, do_write, server, do_read, pfd)) {
1028 		fprintf(stderr, "FAIL: client write and server read I/O failed\n");
1029 		goto failure;
1030 	}
1031 
1032 	pfd[0].events = POLLOUT;
1033 	pfd[1].events = POLLOUT;
1034 
1035 	if (!do_client_server_loop(client, do_shutdown, server, do_shutdown, pfd)) {
1036 		fprintf(stderr, "FAIL: client and server shutdown failed\n");
1037 		goto failure;
1038 	}
1039 
1040  done:
1041 	fprintf(stderr, "INFO: Done!\n");
1042 
1043 	failed = 0;
1044 
1045  failure:
1046 	if (client_sock != -1)
1047 		close(client_sock);
1048 	if (server_sock != -1)
1049 		close(server_sock);
1050 
1051 	SSL_free(client);
1052 	SSL_free(server);
1053 
1054 	return failed;
1055 }
1056 
1057 int
main(int argc,char ** argv)1058 main(int argc, char **argv)
1059 {
1060 	int failed = 0;
1061 	size_t i;
1062 
1063 	if (argc != 4) {
1064 		fprintf(stderr, "usage: %s keyfile certfile cafile\n",
1065 		    argv[0]);
1066 		exit(1);
1067 	}
1068 
1069 	server_key_file = argv[1];
1070 	server_cert_file = argv[2];
1071 	server_ca_file = argv[3];
1072 
1073 	for (i = 0; i < N_DTLS_TESTS; i++)
1074 		failed |= dtlstest(&dtls_tests[i]);
1075 
1076 	return failed;
1077 }
1078