1 /* $OpenBSD: tlstest.c,v 1.13 2021/04/04 16:19:47 tb Exp $ */
2 /*
3 * Copyright (c) 2017 Joel Sing <jsing@openbsd.org>
4 *
5 * Permission to use, copy, modify, and distribute this software for any
6 * purpose with or without fee is hereby granted, provided that the above
7 * copyright notice and this permission notice appear in all copies.
8 *
9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 */
17
18 #include <sys/socket.h>
19
20 #include <err.h>
21 #include <fcntl.h>
22 #include <stdio.h>
23 #include <string.h>
24 #include <unistd.h>
25
26 #include <tls.h>
27
/*
 * The callback-based transport uses two circular buffers, one per
 * direction: client_buffer holds bytes written by the server for the
 * client to read, and server_buffer holds bytes written by the client for
 * the server to read.  A buffer is empty when its read and write cursors
 * are equal, and the writer always leaves one slot unused, so at most
 * CIRCULAR_BUFFER_SIZE - 1 bytes can be queued at once.
 *
 * All of this state is private to this test program, hence static.
 */
#define CIRCULAR_BUFFER_SIZE 512

static unsigned char client_buffer[CIRCULAR_BUFFER_SIZE];
static unsigned char *client_readptr, *client_writeptr;

static unsigned char server_buffer[CIRCULAR_BUFFER_SIZE];
static unsigned char *server_readptr, *server_writeptr;

/* CA, certificate and key file paths, taken from the command line. */
static char *cafile, *certfile, *keyfile;

/* Non-zero enables logging of circular buffer traffic to stderr. */
static int debug = 0;
39
40 static void
circular_init(void)41 circular_init(void)
42 {
43 client_readptr = client_writeptr = client_buffer;
44 server_readptr = server_writeptr = server_buffer;
45 }
46
47 static ssize_t
circular_read(char * name,unsigned char * buf,size_t bufsize,unsigned char ** readptr,unsigned char * writeptr,unsigned char * outbuf,size_t outlen)48 circular_read(char *name, unsigned char *buf, size_t bufsize,
49 unsigned char **readptr, unsigned char *writeptr,
50 unsigned char *outbuf, size_t outlen)
51 {
52 unsigned char *nextptr = *readptr;
53 size_t n = 0;
54
55 while (n < outlen) {
56 if (nextptr == writeptr)
57 break;
58 *outbuf++ = *nextptr++;
59 if ((size_t)(nextptr - buf) >= bufsize)
60 nextptr = buf;
61 *readptr = nextptr;
62 n++;
63 }
64
65 if (debug && n > 0)
66 fprintf(stderr, "%s buffer: read %zi bytes\n", name, n);
67
68 return (n > 0 ? (ssize_t)n : TLS_WANT_POLLIN);
69 }
70
71 static ssize_t
circular_write(char * name,unsigned char * buf,size_t bufsize,unsigned char * readptr,unsigned char ** writeptr,const unsigned char * inbuf,size_t inlen)72 circular_write(char *name, unsigned char *buf, size_t bufsize,
73 unsigned char *readptr, unsigned char **writeptr,
74 const unsigned char *inbuf, size_t inlen)
75 {
76 unsigned char *nextptr = *writeptr;
77 unsigned char *prevptr;
78 size_t n = 0;
79
80 while (n < inlen) {
81 prevptr = nextptr++;
82 if ((size_t)(nextptr - buf) >= bufsize)
83 nextptr = buf;
84 if (nextptr == readptr)
85 break;
86 *prevptr = *inbuf++;
87 *writeptr = nextptr;
88 n++;
89 }
90
91 if (debug && n > 0)
92 fprintf(stderr, "%s buffer: wrote %zi bytes\n", name, n);
93
94 return (n > 0 ? (ssize_t)n : TLS_WANT_POLLOUT);
95 }
96
97 static ssize_t
client_read(struct tls * ctx,void * buf,size_t buflen,void * cb_arg)98 client_read(struct tls *ctx, void *buf, size_t buflen, void *cb_arg)
99 {
100 return circular_read("client", client_buffer, sizeof(client_buffer),
101 &client_readptr, client_writeptr, buf, buflen);
102 }
103
104 static ssize_t
client_write(struct tls * ctx,const void * buf,size_t buflen,void * cb_arg)105 client_write(struct tls *ctx, const void *buf, size_t buflen, void *cb_arg)
106 {
107 return circular_write("server", server_buffer, sizeof(server_buffer),
108 server_readptr, &server_writeptr, buf, buflen);
109 }
110
111 static ssize_t
server_read(struct tls * ctx,void * buf,size_t buflen,void * cb_arg)112 server_read(struct tls *ctx, void *buf, size_t buflen, void *cb_arg)
113 {
114 return circular_read("server", server_buffer, sizeof(server_buffer),
115 &server_readptr, server_writeptr, buf, buflen);
116 }
117
118 static ssize_t
server_write(struct tls * ctx,const void * buf,size_t buflen,void * cb_arg)119 server_write(struct tls *ctx, const void *buf, size_t buflen, void *cb_arg)
120 {
121 return circular_write("client", client_buffer, sizeof(client_buffer),
122 client_readptr, &client_writeptr, buf, buflen);
123 }
124
125 static int
do_tls_handshake(char * name,struct tls * ctx)126 do_tls_handshake(char *name, struct tls *ctx)
127 {
128 int rv;
129
130 rv = tls_handshake(ctx);
131 if (rv == 0)
132 return (1);
133 if (rv == TLS_WANT_POLLIN || rv == TLS_WANT_POLLOUT)
134 return (0);
135
136 errx(1, "%s handshake failed: %s", name, tls_error(ctx));
137 }
138
139 static int
do_tls_close(char * name,struct tls * ctx)140 do_tls_close(char *name, struct tls *ctx)
141 {
142 int rv;
143
144 rv = tls_close(ctx);
145 if (rv == 0)
146 return (1);
147 if (rv == TLS_WANT_POLLIN || rv == TLS_WANT_POLLOUT)
148 return (0);
149
150 errx(1, "%s close failed: %s", name, tls_error(ctx));
151 }
152
/*
 * Drive the client and server handshakes alternately until both sides have
 * completed.  Each side gets at most 101 turns.
 *
 * Returns 0 on success.  If either side is still pending after the last
 * turn, prints a FAIL line labelled by desc and returns 1.  Hard handshake
 * errors terminate the program inside do_tls_handshake().
 */
static int
do_client_server_handshake(char *desc, struct tls *client,
    struct tls *server_cctx)
{
	int attempt;
	int client_done = 0;
	int server_done = 0;

	for (attempt = 0; attempt <= 100; attempt++) {
		if (!client_done)
			client_done = do_tls_handshake("client", client);
		if (!server_done)
			server_done = do_tls_handshake("server", server_cctx);
		if (client_done && server_done)
			return (0);
	}

	printf("FAIL: %s TLS handshake did not complete\n", desc);
	return (1);
}
174
/*
 * Drive tls_close() on the client and server alternately until both sides
 * have shut down.  Each side gets at most 101 turns.
 *
 * Returns 0 on success.  If either side is still pending after the last
 * turn, prints a FAIL line labelled by desc and returns 1.  Hard close
 * errors terminate the program inside do_tls_close().
 */
static int
do_client_server_close(char *desc, struct tls *client, struct tls *server_cctx)
{
	int attempt;
	int client_done = 0;
	int server_done = 0;

	for (attempt = 0; attempt <= 100; attempt++) {
		if (!client_done)
			client_done = do_tls_close("client", client);
		if (!server_done)
			server_done = do_tls_close("server", server_cctx);
		if (client_done && server_done)
			return (0);
	}

	printf("FAIL: %s TLS close did not complete\n", desc);
	return (1);
}
195
/*
 * Exercise an already-connected client/server pair: complete the
 * handshake, then close both sides.  desc labels the INFO/FAIL output.
 * Returns 0 on success and 1 if either phase failed.
 */
static int
do_client_server_test(char *desc, struct tls *client, struct tls *server_cctx)
{
	int failed;

	failed = do_client_server_handshake(desc, client, server_cctx);
	if (failed != 0)
		return (1);
	printf("INFO: %s TLS handshake completed successfully\n", desc);

	/* XXX - Do some reads and writes... */

	failed = do_client_server_close(desc, client, server_cctx);
	if (failed != 0)
		return (1);
	printf("INFO: %s TLS close completed successfully\n", desc);

	return (0);
}
213
214 static int
test_tls_cbs(struct tls * client,struct tls * server)215 test_tls_cbs(struct tls *client, struct tls *server)
216 {
217 struct tls *server_cctx;
218 int failure;
219
220 circular_init();
221
222 if (tls_accept_cbs(server, &server_cctx, server_read, server_write,
223 NULL) == -1)
224 errx(1, "failed to accept: %s", tls_error(server));
225
226 if (tls_connect_cbs(client, client_read, client_write, NULL,
227 "test") == -1)
228 errx(1, "failed to connect: %s", tls_error(client));
229
230 failure = do_client_server_test("callback", client, server_cctx);
231
232 tls_free(server_cctx);
233
234 return (failure);
235 }
236
/*
 * Run the client/server test over a pair of non-blocking pipes, one per
 * direction.  Setup errors are fatal.  All pipe descriptors are closed
 * before returning.  Returns non-zero on test failure.
 */
static int
test_tls_fds(struct tls *client, struct tls *server)
{
	struct tls *server_cctx;
	int s2c[2];	/* server writes s2c[1], client reads s2c[0] */
	int c2s[2];	/* client writes c2s[1], server reads c2s[0] */
	int rv;

	if (pipe2(s2c, O_NONBLOCK) == -1)
		err(1, "failed to create pipe");
	if (pipe2(c2s, O_NONBLOCK) == -1)
		err(1, "failed to create pipe");

	if (tls_accept_fds(server, &server_cctx, c2s[0], s2c[1]) == -1)
		errx(1, "failed to accept: %s", tls_error(server));
	if (tls_connect_fds(client, s2c[0], c2s[1], "test") == -1)
		errx(1, "failed to connect: %s", tls_error(client));

	rv = do_client_server_test("file descriptor", client, server_cctx);

	tls_free(server_cctx);

	close(s2c[0]);
	close(s2c[1]);
	close(c2s[0]);
	close(c2s[1]);

	return (rv);
}
266
/*
 * Run the client/server test over a non-blocking AF_UNIX stream socket
 * pair.  Setup errors are fatal.  Both sockets are closed before
 * returning.  Returns non-zero on test failure.
 */
static int
test_tls_socket(struct tls *client, struct tls *server)
{
	struct tls *server_cctx;
	int pair[2];	/* pair[0]: server end, pair[1]: client end */
	int rv;

	if (socketpair(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, PF_UNSPEC,
	    pair) == -1)
		err(1, "failed to create socketpair");

	if (tls_accept_socket(server, &server_cctx, pair[0]) == -1)
		errx(1, "failed to accept: %s", tls_error(server));
	if (tls_connect_socket(client, pair[1], "test") == -1)
		errx(1, "failed to connect: %s", tls_error(client));

	rv = do_client_server_test("socket", client, server_cctx);

	tls_free(server_cctx);

	close(pair[0]);
	close(pair[1]);

	return (rv);
}
293
294 static int
test_tls(char * client_protocols,char * server_protocols,char * ciphers)295 test_tls(char *client_protocols, char *server_protocols, char *ciphers)
296 {
297 struct tls_config *client_cfg, *server_cfg;
298 struct tls *client, *server;
299 uint32_t protocols;
300 int failure = 0;
301
302 if ((client = tls_client()) == NULL)
303 errx(1, "failed to create tls client");
304 if ((client_cfg = tls_config_new()) == NULL)
305 errx(1, "failed to create tls client config");
306 tls_config_insecure_noverifyname(client_cfg);
307 if (tls_config_parse_protocols(&protocols, client_protocols) == -1)
308 errx(1, "failed to parse protocols: %s", tls_config_error(client_cfg));
309 if (tls_config_set_protocols(client_cfg, protocols) == -1)
310 errx(1, "failed to set protocols: %s", tls_config_error(client_cfg));
311 if (tls_config_set_ciphers(client_cfg, ciphers) == -1)
312 errx(1, "failed to set ciphers: %s", tls_config_error(client_cfg));
313 if (tls_config_set_ca_file(client_cfg, cafile) == -1)
314 errx(1, "failed to set ca: %s", tls_config_error(client_cfg));
315
316 if ((server = tls_server()) == NULL)
317 errx(1, "failed to create tls server");
318 if ((server_cfg = tls_config_new()) == NULL)
319 errx(1, "failed to create tls server config");
320 if (tls_config_parse_protocols(&protocols, server_protocols) == -1)
321 errx(1, "failed to parse protocols: %s", tls_config_error(server_cfg));
322 if (tls_config_set_protocols(server_cfg, protocols) == -1)
323 errx(1, "failed to set protocols: %s", tls_config_error(server_cfg));
324 if (tls_config_set_ciphers(server_cfg, ciphers) == -1)
325 errx(1, "failed to set ciphers: %s", tls_config_error(server_cfg));
326 if (tls_config_set_keypair_file(server_cfg, certfile, keyfile) == -1)
327 errx(1, "failed to set keypair: %s",
328 tls_config_error(server_cfg));
329
330 if (tls_configure(client, client_cfg) == -1)
331 errx(1, "failed to configure client: %s", tls_error(client));
332 tls_reset(server);
333 if (tls_configure(server, server_cfg) == -1)
334 errx(1, "failed to configure server: %s", tls_error(server));
335
336 tls_config_free(client_cfg);
337 tls_config_free(server_cfg);
338
339 failure |= test_tls_cbs(client, server);
340
341 tls_free(client);
342 tls_free(server);
343
344 return (failure);
345 }
346
347 static int
do_tls_tests(void)348 do_tls_tests(void)
349 {
350 struct tls_config *client_cfg, *server_cfg;
351 struct tls *client, *server;
352 int failure = 0;
353
354 printf("== TLS tests ==\n");
355
356 if ((client = tls_client()) == NULL)
357 errx(1, "failed to create tls client");
358 if ((client_cfg = tls_config_new()) == NULL)
359 errx(1, "failed to create tls client config");
360 tls_config_insecure_noverifyname(client_cfg);
361 if (tls_config_set_ca_file(client_cfg, cafile) == -1)
362 errx(1, "failed to set ca: %s", tls_config_error(client_cfg));
363
364 if ((server = tls_server()) == NULL)
365 errx(1, "failed to create tls server");
366 if ((server_cfg = tls_config_new()) == NULL)
367 errx(1, "failed to create tls server config");
368 if (tls_config_set_keypair_file(server_cfg, certfile, keyfile) == -1)
369 errx(1, "failed to set keypair: %s",
370 tls_config_error(server_cfg));
371
372 tls_reset(client);
373 if (tls_configure(client, client_cfg) == -1)
374 errx(1, "failed to configure client: %s", tls_error(client));
375 tls_reset(server);
376 if (tls_configure(server, server_cfg) == -1)
377 errx(1, "failed to configure server: %s", tls_error(server));
378
379 failure |= test_tls_cbs(client, server);
380
381 tls_reset(client);
382 if (tls_configure(client, client_cfg) == -1)
383 errx(1, "failed to configure client: %s", tls_error(client));
384 tls_reset(server);
385 if (tls_configure(server, server_cfg) == -1)
386 errx(1, "failed to configure server: %s", tls_error(server));
387
388 failure |= test_tls_fds(client, server);
389
390 tls_reset(client);
391 if (tls_configure(client, client_cfg) == -1)
392 errx(1, "failed to configure client: %s", tls_error(client));
393 tls_reset(server);
394 if (tls_configure(server, server_cfg) == -1)
395 errx(1, "failed to configure server: %s", tls_error(server));
396
397 tls_config_free(client_cfg);
398 tls_config_free(server_cfg);
399
400 failure |= test_tls_socket(client, server);
401
402 tls_free(client);
403 tls_free(server);
404
405 printf("\n");
406
407 return (failure);
408 }
409
410 static int
do_tls_ordering_tests(void)411 do_tls_ordering_tests(void)
412 {
413 struct tls *client = NULL, *server = NULL, *server_cctx = NULL;
414 struct tls_config *client_cfg, *server_cfg;
415 int failure = 0;
416
417 printf("== TLS ordering tests ==\n");
418
419 if ((client = tls_client()) == NULL)
420 errx(1, "failed to create tls client");
421 if ((client_cfg = tls_config_new()) == NULL)
422 errx(1, "failed to create tls client config");
423 tls_config_insecure_noverifyname(client_cfg);
424 if (tls_config_set_ca_file(client_cfg, cafile) == -1)
425 errx(1, "failed to set ca: %s", tls_config_error(client_cfg));
426
427 if ((server = tls_server()) == NULL)
428 errx(1, "failed to create tls server");
429 if ((server_cfg = tls_config_new()) == NULL)
430 errx(1, "failed to create tls server config");
431 if (tls_config_set_keypair_file(server_cfg, certfile, keyfile) == -1)
432 errx(1, "failed to set keypair: %s",
433 tls_config_error(server_cfg));
434
435 if (tls_configure(client, client_cfg) == -1)
436 errx(1, "failed to configure client: %s", tls_error(client));
437 if (tls_configure(server, server_cfg) == -1)
438 errx(1, "failed to configure server: %s", tls_error(server));
439
440 tls_config_free(client_cfg);
441 tls_config_free(server_cfg);
442
443 if (tls_handshake(client) != -1) {
444 printf("FAIL: TLS handshake succeeded on unconnnected "
445 "client context\n");
446 failure = 1;
447 goto done;
448 }
449
450 circular_init();
451
452 if (tls_accept_cbs(server, &server_cctx, server_read, server_write,
453 NULL) == -1)
454 errx(1, "failed to accept: %s", tls_error(server));
455
456 if (tls_connect_cbs(client, client_read, client_write, NULL,
457 "test") == -1)
458 errx(1, "failed to connect: %s", tls_error(client));
459
460 if (do_client_server_handshake("ordering", client, server_cctx) != 0) {
461 failure = 1;
462 goto done;
463 }
464
465 if (tls_handshake(client) != -1) {
466 printf("FAIL: TLS handshake succeeded twice\n");
467 failure = 1;
468 goto done;
469 }
470
471 if (tls_handshake(server_cctx) != -1) {
472 printf("FAIL: TLS handshake succeeded twice\n");
473 failure = 1;
474 goto done;
475 }
476
477 if (do_client_server_close("ordering", client, server_cctx) != 0) {
478 failure = 1;
479 goto done;
480 }
481
482 done:
483 tls_free(client);
484 tls_free(server);
485 tls_free(server_cctx);
486
487 printf("\n");
488
489 return (failure);
490 }
491
/* Client and server protocol strings for one version-negotiation test. */
struct test_versions {
	char *client;	/* protocols enabled on the client */
	char *server;	/* protocols enabled on the server */
};

/*
 * Combinations run by do_tls_version_tests().  The table pairs each single
 * version against "all", then "all" against each single version, then
 * each version against itself.
 */
static struct test_versions tls_test_versions[] = {
	{ .client = "tlsv1.3", .server = "all" },
	{ .client = "tlsv1.2", .server = "all" },
	{ .client = "tlsv1.1", .server = "all" },
	{ .client = "tlsv1.0", .server = "all" },
	{ .client = "all", .server = "tlsv1.3" },
	{ .client = "all", .server = "tlsv1.2" },
	{ .client = "all", .server = "tlsv1.1" },
	{ .client = "all", .server = "tlsv1.0" },
	{ .client = "tlsv1.3", .server = "tlsv1.3" },
	{ .client = "tlsv1.2", .server = "tlsv1.2" },
	{ .client = "tlsv1.1", .server = "tlsv1.1" },
	{ .client = "tlsv1.0", .server = "tlsv1.0" },
};

/* Number of entries in tls_test_versions. */
#define N_TLS_VERSION_TESTS \
    (sizeof(tls_test_versions) / sizeof(tls_test_versions[0]))
514
515 static int
do_tls_version_tests(void)516 do_tls_version_tests(void)
517 {
518 struct test_versions *tv;
519 int failure = 0;
520 size_t i;
521
522 printf("== TLS version tests ==\n");
523
524 for (i = 0; i < N_TLS_VERSION_TESTS; i++) {
525 tv = &tls_test_versions[i];
526 printf("INFO: version test %zu - client versions '%s' "
527 "and server versions '%s'\n", i, tv->client, tv->server);
528 failure |= test_tls(tv->client, tv->server, "legacy");
529 printf("\n");
530 }
531
532 return failure;
533 }
534
535 int
main(int argc,char ** argv)536 main(int argc, char **argv)
537 {
538 int failure = 0;
539
540 if (argc != 4) {
541 fprintf(stderr, "usage: %s cafile certfile keyfile\n",
542 argv[0]);
543 return (1);
544 }
545
546 cafile = argv[1];
547 certfile = argv[2];
548 keyfile = argv[3];
549
550 failure |= do_tls_tests();
551 failure |= do_tls_ordering_tests();
552 failure |= do_tls_version_tests();
553
554 return (failure);
555 }
556