1 /* $OpenBSD: tlstest.c,v 1.2 2023/07/02 17:21:33 beck Exp $ */ 2 /* 3 * Copyright (c) 2020, 2021 Joel Sing <jsing@openbsd.org> 4 * 5 * Permission to use, copy, modify, and distribute this software for any 6 * purpose with or without fee is hereby granted, provided that the above 7 * copyright notice and this permission notice appear in all copies. 8 * 9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 */ 17 18 #include <err.h> 19 20 #include <openssl/bio.h> 21 #include <openssl/err.h> 22 #include <openssl/ssl.h> 23 24 const char *server_ca_file; 25 const char *server_cert_file; 26 const char *server_key_file; 27 28 int debug = 0; 29 30 static void 31 hexdump(const unsigned char *buf, size_t len) 32 { 33 size_t i; 34 35 for (i = 1; i <= len; i++) 36 fprintf(stderr, " 0x%02hhx,%s", buf[i - 1], i % 8 ? "" : "\n"); 37 38 if (len % 8) 39 fprintf(stderr, "\n"); 40 } 41 42 static SSL * 43 tls_client(BIO *rbio, BIO *wbio) 44 { 45 SSL_CTX *ssl_ctx = NULL; 46 SSL *ssl = NULL; 47 48 if ((ssl_ctx = SSL_CTX_new(TLS_method())) == NULL) 49 errx(1, "client context"); 50 51 if ((ssl = SSL_new(ssl_ctx)) == NULL) 52 errx(1, "client ssl"); 53 54 BIO_up_ref(rbio); 55 BIO_up_ref(wbio); 56 57 SSL_set_bio(ssl, rbio, wbio); 58 59 SSL_CTX_free(ssl_ctx); 60 61 return ssl; 62 } 63 64 static SSL * 65 tls_server(BIO *rbio, BIO *wbio) 66 { 67 SSL_CTX *ssl_ctx = NULL; 68 SSL *ssl = NULL; 69 70 if ((ssl_ctx = SSL_CTX_new(TLS_method())) == NULL) 71 errx(1, "server context"); 72 73 SSL_CTX_set_dh_auto(ssl_ctx, 2); 74 75 if (SSL_CTX_use_certificate_file(ssl_ctx, server_cert_file, 76 SSL_FILETYPE_PEM) != 1) { 77 fprintf(stderr, "FAIL: Failed to load server certificate"); 78 goto failure; 79 } 80 if (SSL_CTX_use_PrivateKey_file(ssl_ctx, server_key_file, 81 SSL_FILETYPE_PEM) != 1) { 82 fprintf(stderr, "FAIL: Failed to load server private key"); 83 goto failure; 84 } 85 86 if ((ssl = SSL_new(ssl_ctx)) == NULL) 87 errx(1, "server ssl"); 88 89 BIO_up_ref(rbio); 90 BIO_up_ref(wbio); 91 92 SSL_set_bio(ssl, rbio, wbio); 93 94 failure: 95 SSL_CTX_free(ssl_ctx); 96 97 return ssl; 98 } 99 100 static int 101 ssl_error(SSL *ssl, const char *name, const char *desc, int ssl_ret) 102 { 103 int ssl_err; 104 105 ssl_err = SSL_get_error(ssl, ssl_ret); 106 107 if (ssl_err == SSL_ERROR_WANT_READ) { 108 return 1; 109 } else if (ssl_err == SSL_ERROR_WANT_WRITE) { 110 return 1; 111 } else if (ssl_err == SSL_ERROR_SYSCALL && errno == 0) { 112 /* Yup, this is apparently a thing... */ 113 } else { 114 fprintf(stderr, "FAIL: %s %s failed - ssl err = %d, errno = %d\n", 115 name, desc, ssl_err, errno); 116 ERR_print_errors_fp(stderr); 117 return 0; 118 } 119 120 return 1; 121 } 122 123 static int 124 do_connect(SSL *ssl, const char *name, int *done) 125 { 126 int ssl_ret; 127 128 if ((ssl_ret = SSL_connect(ssl)) == 1) { 129 fprintf(stderr, "INFO: %s connect done\n", name); 130 *done = 1; 131 return 1; 132 } 133 134 return ssl_error(ssl, name, "connect", ssl_ret); 135 } 136 137 static int 138 do_accept(SSL *ssl, const char *name, int *done) 139 { 140 int ssl_ret; 141 142 if ((ssl_ret = SSL_accept(ssl)) == 1) { 143 fprintf(stderr, "INFO: %s accept done\n", name); 144 *done = 1; 145 return 1; 146 } 147 148 return ssl_error(ssl, name, "accept", ssl_ret); 149 } 150 151 static int 152 do_read(SSL *ssl, const char *name, int *done) 153 { 154 uint8_t buf[512]; 155 int ssl_ret; 156 157 if ((ssl_ret = SSL_read(ssl, buf, sizeof(buf))) > 0) { 158 fprintf(stderr, "INFO: %s read done\n", name); 159 if (debug > 1) 160 hexdump(buf, ssl_ret); 161 *done = 1; 162 return 1; 163 } 164 165 return ssl_error(ssl, name, "read", ssl_ret); 166 } 167 168 static int 169 do_write(SSL *ssl, const char *name, int *done) 170 { 171 const uint8_t buf[] = "Hello, World!\n"; 172 int ssl_ret; 173 174 if ((ssl_ret = SSL_write(ssl, buf, sizeof(buf))) > 0) { 175 fprintf(stderr, "INFO: %s write done\n", name); 176 *done = 1; 177 return 1; 178 } 179 180 return ssl_error(ssl, name, "write", ssl_ret); 181 } 182 183 static int 184 do_shutdown(SSL *ssl, const char *name, int *done) 185 { 186 int ssl_ret; 187 188 ssl_ret = SSL_shutdown(ssl); 189 if (ssl_ret == 1) { 190 fprintf(stderr, "INFO: %s shutdown done\n", name); 191 *done = 1; 192 return 1; 193 } 194 return ssl_error(ssl, name, "shutdown", ssl_ret); 195 } 196 197 typedef int (*ssl_func)(SSL *ssl, const char *name, int *done); 198 199 static int 200 do_client_server_loop(SSL *client, ssl_func client_func, SSL *server, 201 ssl_func server_func) 202 { 203 int client_done = 0, server_done = 0; 204 int i = 0; 205 206 do { 207 if (!client_done) { 208 if (debug) 209 fprintf(stderr, "DEBUG: client loop\n"); 210 if (!client_func(client, "client", &client_done)) 211 return 0; 212 } 213 if (!server_done) { 214 if (debug) 215 fprintf(stderr, "DEBUG: server loop\n"); 216 if (!server_func(server, "server", &server_done)) 217 return 0; 218 } 219 } while (i++ < 100 && (!client_done || !server_done)); 220 221 if (!client_done || !server_done) 222 fprintf(stderr, "FAIL: gave up\n"); 223 224 return client_done && server_done; 225 } 226 227 struct tls_test { 228 const unsigned char *desc; 229 const SSL_METHOD *(*client_method)(void); 230 uint16_t client_min_version; 231 uint16_t client_max_version; 232 const char *client_ciphers; 233 const SSL_METHOD *(*server_method)(void); 234 uint16_t server_min_version; 235 uint16_t server_max_version; 236 const char *server_ciphers; 237 }; 238 239 static const struct tls_test tls_tests[] = { 240 { 241 .desc = "Default client and server", 242 }, 243 { 244 .desc = "Default client and TLSv1.2 server", 245 .server_max_version = TLS1_2_VERSION, 246 }, 247 { 248 .desc = "Default client and default server with ECDHE KEX", 249 .server_ciphers = "ECDHE-RSA-AES128-SHA", 250 }, 251 { 252 .desc = "Default client and TLSv1.2 server with ECDHE KEX", 253 .server_max_version = TLS1_2_VERSION, 254 .server_ciphers = "ECDHE-RSA-AES128-SHA", 255 }, 256 { 257 .desc = "Default client and default server with DHE KEX", 258 .server_ciphers = "DHE-RSA-AES128-SHA", 259 }, 260 { 261 .desc = "Default client and TLSv1.2 server with DHE KEX", 262 .server_max_version = TLS1_2_VERSION, 263 .server_ciphers = "DHE-RSA-AES128-SHA", 264 }, 265 { 266 .desc = "Default client and default server with RSA KEX", 267 .server_ciphers = "AES128-SHA", 268 }, 269 { 270 .desc = "Default client and TLSv1.2 server with RSA KEX", 271 .server_max_version = TLS1_2_VERSION, 272 .server_ciphers = "AES128-SHA", 273 }, 274 { 275 .desc = "TLSv1.2 client and default server", 276 .client_max_version = TLS1_2_VERSION, 277 }, 278 { 279 .desc = "TLSv1.2 client and default server with ECDHE KEX", 280 .client_max_version = TLS1_2_VERSION, 281 .client_ciphers = "ECDHE-RSA-AES128-SHA", 282 }, 283 { 284 .desc = "TLSv1.2 client and default server with DHE KEX", 285 .server_max_version = TLS1_2_VERSION, 286 .client_ciphers = "DHE-RSA-AES128-SHA", 287 }, 288 { 289 .desc = "TLSv1.2 client and default server with RSA KEX", 290 .client_max_version = TLS1_2_VERSION, 291 .client_ciphers = "AES128-SHA", 292 }, 293 }; 294 295 #define N_TLS_TESTS (sizeof(tls_tests) / sizeof(*tls_tests)) 296 297 static int 298 tlstest(const struct tls_test *tt) 299 { 300 BIO *client_wbio = NULL, *server_wbio = NULL; 301 SSL *client = NULL, *server = NULL; 302 int failed = 1; 303 304 fprintf(stderr, "\n== Testing %s... ==\n", tt->desc); 305 306 if ((client_wbio = BIO_new(BIO_s_mem())) == NULL) 307 goto failure; 308 if (BIO_set_mem_eof_return(client_wbio, -1) <= 0) 309 goto failure; 310 311 if ((server_wbio = BIO_new(BIO_s_mem())) == NULL) 312 goto failure; 313 if (BIO_set_mem_eof_return(server_wbio, -1) <= 0) 314 goto failure; 315 316 if ((client = tls_client(server_wbio, client_wbio)) == NULL) 317 goto failure; 318 if (tt->client_min_version != 0) { 319 if (!SSL_set_min_proto_version(client, tt->client_min_version)) 320 goto failure; 321 } 322 if (tt->client_max_version != 0) { 323 if (!SSL_set_max_proto_version(client, tt->client_max_version)) 324 goto failure; 325 } 326 if (tt->client_ciphers != NULL) { 327 if (!SSL_set_cipher_list(client, tt->client_ciphers)) 328 goto failure; 329 } 330 331 if ((server = tls_server(client_wbio, server_wbio)) == NULL) 332 goto failure; 333 if (tt->server_min_version != 0) { 334 if (!SSL_set_min_proto_version(server, tt->server_min_version)) 335 goto failure; 336 } 337 if (tt->server_max_version != 0) { 338 if (!SSL_set_max_proto_version(server, tt->server_max_version)) 339 goto failure; 340 } 341 if (tt->server_ciphers != NULL) { 342 if (!SSL_set_cipher_list(server, tt->server_ciphers)) 343 goto failure; 344 } 345 346 if (!do_client_server_loop(client, do_connect, server, do_accept)) { 347 fprintf(stderr, "FAIL: client and server handshake failed\n"); 348 goto failure; 349 } 350 351 if (!do_client_server_loop(client, do_write, server, do_read)) { 352 fprintf(stderr, "FAIL: client write and server read I/O failed\n"); 353 goto failure; 354 } 355 356 if (!do_client_server_loop(client, do_read, server, do_write)) { 357 fprintf(stderr, "FAIL: client read and server write I/O failed\n"); 358 goto failure; 359 } 360 361 if (!do_client_server_loop(client, do_shutdown, server, do_shutdown)) { 362 fprintf(stderr, "FAIL: client and server shutdown failed\n"); 363 goto failure; 364 } 365 366 fprintf(stderr, "INFO: Done!\n"); 367 368 failed = 0; 369 370 failure: 371 BIO_free(client_wbio); 372 BIO_free(server_wbio); 373 374 SSL_free(client); 375 SSL_free(server); 376 377 return failed; 378 } 379 380 int 381 main(int argc, char **argv) 382 { 383 int failed = 0; 384 size_t i; 385 386 if (argc != 4) { 387 fprintf(stderr, "usage: %s keyfile certfile cafile\n", 388 argv[0]); 389 exit(1); 390 } 391 392 server_key_file = argv[1]; 393 server_cert_file = argv[2]; 394 server_ca_file = argv[3]; 395 396 for (i = 0; i < N_TLS_TESTS; i++) 397 failed |= tlstest(&tls_tests[i]); 398 399 return failed; 400 } 401