1 /* $OpenBSD: handshake_table.c,v 1.18 2022/12/01 13:49:12 tb Exp $ */ 2 /* 3 * Copyright (c) 2019 Theo Buehler <tb@openbsd.org> 4 * 5 * Permission to use, copy, modify, and distribute this software for any 6 * purpose with or without fee is hereby granted, provided that the above 7 * copyright notice and this permission notice appear in all copies. 8 * 9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR 12 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN 14 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF 15 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 16 */ 17 18 #include <err.h> 19 #include <stdint.h> 20 #include <stdio.h> 21 #include <stdlib.h> 22 #include <unistd.h> 23 24 #include "tls13_handshake.h" 25 26 #define MAX_FLAGS (UINT8_MAX + 1) 27 28 /* 29 * From RFC 8446: 30 * 31 * Appendix A. State Machine 32 * 33 * This appendix provides a summary of the legal state transitions for 34 * the client and server handshakes. State names (in all capitals, 35 * e.g., START) have no formal meaning but are provided for ease of 36 * comprehension. Actions which are taken only in certain circumstances 37 * are indicated in []. The notation "K_{send,recv} = foo" means "set 38 * the send/recv key to the given key". 39 * 40 * A.1. Client 41 * 42 * START <----+ 43 * Send ClientHello | | Recv HelloRetryRequest 44 * [K_send = early data] | | 45 * v | 46 * / WAIT_SH ----+ 47 * | | Recv ServerHello 48 * | | K_recv = handshake 49 * Can | V 50 * send | WAIT_EE 51 * early | | Recv EncryptedExtensions 52 * data | +--------+--------+ 53 * | Using | | Using certificate 54 * | PSK | v 55 * | | WAIT_CERT_CR 56 * | | Recv | | Recv CertificateRequest 57 * | | Certificate | v 58 * | | | WAIT_CERT 59 * | | | | Recv Certificate 60 * | | v v 61 * | | WAIT_CV 62 * | | | Recv CertificateVerify 63 * | +> WAIT_FINISHED <+ 64 * | | Recv Finished 65 * \ | [Send EndOfEarlyData] 66 * | K_send = handshake 67 * | [Send Certificate [+ CertificateVerify]] 68 * Can send | Send Finished 69 * app data --> | K_send = K_recv = application 70 * after here v 71 * CONNECTED 72 * 73 * Note that with the transitions as shown above, clients may send 74 * alerts that derive from post-ServerHello messages in the clear or 75 * with the early data keys. If clients need to send such alerts, they 76 * SHOULD first rekey to the handshake keys if possible. 77 * 78 */ 79 80 struct child { 81 enum tls13_message_type mt; 82 uint8_t flag; 83 uint8_t forced; 84 uint8_t illegal; 85 }; 86 87 static struct child stateinfo[][TLS13_NUM_MESSAGE_TYPES] = { 88 [CLIENT_HELLO] = { 89 { 90 .mt = SERVER_HELLO_RETRY_REQUEST, 91 }, 92 { 93 .mt = SERVER_HELLO, 94 .flag = WITHOUT_HRR, 95 }, 96 }, 97 [SERVER_HELLO_RETRY_REQUEST] = { 98 { 99 .mt = CLIENT_HELLO_RETRY, 100 }, 101 }, 102 [CLIENT_HELLO_RETRY] = { 103 { 104 .mt = SERVER_HELLO, 105 }, 106 }, 107 [SERVER_HELLO] = { 108 { 109 .mt = SERVER_ENCRYPTED_EXTENSIONS, 110 }, 111 }, 112 [SERVER_ENCRYPTED_EXTENSIONS] = { 113 { 114 .mt = SERVER_CERTIFICATE_REQUEST, 115 }, 116 { .mt = SERVER_CERTIFICATE, 117 .flag = WITHOUT_CR, 118 }, 119 { 120 .mt = SERVER_FINISHED, 121 .flag = WITH_PSK, 122 }, 123 }, 124 [SERVER_CERTIFICATE_REQUEST] = { 125 { 126 .mt = SERVER_CERTIFICATE, 127 }, 128 }, 129 [SERVER_CERTIFICATE] = { 130 { 131 .mt = SERVER_CERTIFICATE_VERIFY, 132 }, 133 }, 134 [SERVER_CERTIFICATE_VERIFY] = { 135 { 136 .mt = SERVER_FINISHED, 137 }, 138 }, 139 [SERVER_FINISHED] = { 140 { 141 .mt = CLIENT_FINISHED, 142 .forced = WITHOUT_CR | WITH_PSK, 143 }, 144 { 145 .mt = CLIENT_CERTIFICATE, 146 .illegal = WITHOUT_CR | WITH_PSK, 147 }, 148 }, 149 [CLIENT_CERTIFICATE] = { 150 { 151 .mt = CLIENT_FINISHED, 152 }, 153 { 154 .mt = CLIENT_CERTIFICATE_VERIFY, 155 .flag = WITH_CCV, 156 }, 157 }, 158 [CLIENT_CERTIFICATE_VERIFY] = { 159 { 160 .mt = CLIENT_FINISHED, 161 }, 162 }, 163 [CLIENT_FINISHED] = { 164 { 165 .mt = APPLICATION_DATA, 166 }, 167 }, 168 [APPLICATION_DATA] = { 169 { 170 .mt = 0, 171 }, 172 }, 173 }; 174 175 const size_t stateinfo_count = sizeof(stateinfo) / sizeof(stateinfo[0]); 176 177 void build_table(enum tls13_message_type 178 table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES], 179 struct child current, struct child end, 180 struct child path[], uint8_t flags, unsigned int depth); 181 size_t count_handshakes(void); 182 void edge(enum tls13_message_type start, 183 enum tls13_message_type end, uint8_t flag); 184 const char *flag2str(uint8_t flag); 185 void flag_label(uint8_t flag); 186 void forced_edges(enum tls13_message_type start, 187 enum tls13_message_type end, uint8_t forced); 188 int generate_graphics(void); 189 void fprint_entry(FILE *stream, 190 enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES], 191 uint8_t flags); 192 void fprint_flags(FILE *stream, uint8_t flags); 193 const char *mt2str(enum tls13_message_type mt); 194 void usage(void); 195 int verify_table(enum tls13_message_type 196 table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES], int print); 197 198 const char * 199 flag2str(uint8_t flag) 200 { 201 const char *ret; 202 203 if (flag & (flag - 1)) 204 errx(1, "more than one bit is set"); 205 206 switch (flag) { 207 case INITIAL: 208 ret = "INITIAL"; 209 break; 210 case NEGOTIATED: 211 ret = "NEGOTIATED"; 212 break; 213 case WITHOUT_CR: 214 ret = "WITHOUT_CR"; 215 break; 216 case WITHOUT_HRR: 217 ret = "WITHOUT_HRR"; 218 break; 219 case WITH_PSK: 220 ret = "WITH_PSK"; 221 break; 222 case WITH_CCV: 223 ret = "WITH_CCV"; 224 break; 225 case WITH_0RTT: 226 ret = "WITH_0RTT"; 227 break; 228 default: 229 ret = "UNKNOWN"; 230 } 231 232 return ret; 233 } 234 235 const char * 236 mt2str(enum tls13_message_type mt) 237 { 238 const char *ret; 239 240 switch (mt) { 241 case INVALID: 242 ret = "INVALID"; 243 break; 244 case CLIENT_HELLO: 245 ret = "CLIENT_HELLO"; 246 break; 247 case CLIENT_HELLO_RETRY: 248 ret = "CLIENT_HELLO_RETRY"; 249 break; 250 case CLIENT_END_OF_EARLY_DATA: 251 ret = "CLIENT_END_OF_EARLY_DATA"; 252 break; 253 case CLIENT_CERTIFICATE: 254 ret = "CLIENT_CERTIFICATE"; 255 break; 256 case CLIENT_CERTIFICATE_VERIFY: 257 ret = "CLIENT_CERTIFICATE_VERIFY"; 258 break; 259 case CLIENT_FINISHED: 260 ret = "CLIENT_FINISHED"; 261 break; 262 case SERVER_HELLO: 263 ret = "SERVER_HELLO"; 264 break; 265 case SERVER_HELLO_RETRY_REQUEST: 266 ret = "SERVER_HELLO_RETRY_REQUEST"; 267 break; 268 case SERVER_ENCRYPTED_EXTENSIONS: 269 ret = "SERVER_ENCRYPTED_EXTENSIONS"; 270 break; 271 case SERVER_CERTIFICATE: 272 ret = "SERVER_CERTIFICATE"; 273 break; 274 case SERVER_CERTIFICATE_VERIFY: 275 ret = "SERVER_CERTIFICATE_VERIFY"; 276 break; 277 case SERVER_CERTIFICATE_REQUEST: 278 ret = "SERVER_CERTIFICATE_REQUEST"; 279 break; 280 case SERVER_FINISHED: 281 ret = "SERVER_FINISHED"; 282 break; 283 case APPLICATION_DATA: 284 ret = "APPLICATION_DATA"; 285 break; 286 case TLS13_NUM_MESSAGE_TYPES: 287 ret = "TLS13_NUM_MESSAGE_TYPES"; 288 break; 289 default: 290 ret = "UNKNOWN"; 291 break; 292 } 293 294 return ret; 295 } 296 297 void 298 fprint_flags(FILE *stream, uint8_t flags) 299 { 300 int first = 1, i; 301 302 if (flags == 0) { 303 fprintf(stream, "%s", flag2str(flags)); 304 return; 305 } 306 307 for (i = 0; i < 8; i++) { 308 uint8_t set = flags & (1U << i); 309 310 if (set) { 311 fprintf(stream, "%s%s", first ? "" : " | ", 312 flag2str(set)); 313 first = 0; 314 } 315 } 316 } 317 318 void 319 fprint_entry(FILE *stream, 320 enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES], uint8_t flags) 321 { 322 int i; 323 324 fprintf(stream, "\t["); 325 fprint_flags(stream, flags); 326 fprintf(stream, "] = {\n"); 327 328 for (i = 0; i < TLS13_NUM_MESSAGE_TYPES; i++) { 329 if (path[i] == 0) 330 break; 331 fprintf(stream, "\t\t%s,\n", mt2str(path[i])); 332 } 333 fprintf(stream, "\t},\n"); 334 } 335 336 void 337 edge(enum tls13_message_type start, enum tls13_message_type end, 338 uint8_t flag) 339 { 340 printf("\t%s -> %s", mt2str(start), mt2str(end)); 341 flag_label(flag); 342 printf(";\n"); 343 } 344 345 void 346 flag_label(uint8_t flag) 347 { 348 if (flag) 349 printf(" [label=\"%s\"]", flag2str(flag)); 350 } 351 352 void 353 forced_edges(enum tls13_message_type start, enum tls13_message_type end, 354 uint8_t forced) 355 { 356 uint8_t forced_flag, i; 357 358 if (forced == 0) 359 return; 360 361 for (i = 0; i < 8; i++) { 362 forced_flag = forced & (1U << i); 363 if (forced_flag) 364 edge(start, end, forced_flag); 365 } 366 } 367 368 int 369 generate_graphics(void) 370 { 371 enum tls13_message_type start, end; 372 unsigned int child; 373 uint8_t flag; 374 uint8_t forced; 375 376 printf("digraph G {\n"); 377 printf("\t%s [shape=box];\n", mt2str(CLIENT_HELLO)); 378 printf("\t%s [shape=box];\n", mt2str(APPLICATION_DATA)); 379 380 for (start = CLIENT_HELLO; start < APPLICATION_DATA; start++) { 381 for (child = 0; stateinfo[start][child].mt != 0; child++) { 382 end = stateinfo[start][child].mt; 383 flag = stateinfo[start][child].flag; 384 forced = stateinfo[start][child].forced; 385 386 if (forced == 0) 387 edge(start, end, flag); 388 else 389 forced_edges(start, end, forced); 390 } 391 } 392 393 printf("}\n"); 394 return 0; 395 } 396 397 extern enum tls13_message_type handshakes[][TLS13_NUM_MESSAGE_TYPES]; 398 extern size_t handshake_count; 399 400 size_t 401 count_handshakes(void) 402 { 403 size_t ret = 0, i; 404 405 for (i = 0; i < handshake_count; i++) { 406 if (handshakes[i][0] != INVALID) 407 ret++; 408 } 409 410 return ret; 411 } 412 413 void 414 build_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES], 415 struct child current, struct child end, struct child path[], uint8_t flags, 416 unsigned int depth) 417 { 418 unsigned int i; 419 420 if (depth >= TLS13_NUM_MESSAGE_TYPES - 1) 421 errx(1, "recursed too deeply"); 422 423 /* Record current node. */ 424 path[depth++] = current; 425 flags |= current.flag; 426 427 /* If we haven't reached the end, recurse over the children. */ 428 if (current.mt != end.mt) { 429 for (i = 0; stateinfo[current.mt][i].mt != 0; i++) { 430 struct child child = stateinfo[current.mt][i]; 431 int forced = stateinfo[current.mt][i].forced; 432 int illegal = stateinfo[current.mt][i].illegal; 433 434 if ((forced == 0 || (forced & flags)) && 435 (illegal == 0 || !(illegal & flags))) 436 build_table(table, child, end, path, flags, 437 depth); 438 } 439 return; 440 } 441 442 if (flags == 0) 443 errx(1, "path does not set flags"); 444 445 if (table[flags][0] != 0) 446 errx(1, "path traversed twice"); 447 448 for (i = 0; i < depth; i++) 449 table[flags][i] = path[i].mt; 450 } 451 452 int 453 verify_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES], 454 int print) 455 { 456 int success = 1, i; 457 size_t num_valid, num_found = 0; 458 uint8_t flags = 0; 459 460 do { 461 if (table[flags][0] == 0) 462 continue; 463 464 num_found++; 465 466 for (i = 0; i < TLS13_NUM_MESSAGE_TYPES; i++) { 467 if (table[flags][i] != handshakes[flags][i]) { 468 fprintf(stderr, 469 "incorrect entry %d of handshake ", i); 470 fprint_flags(stderr, flags); 471 fprintf(stderr, "\n"); 472 success = 0; 473 } 474 } 475 476 if (print) 477 fprint_entry(stdout, table[flags], flags); 478 } while(++flags != 0); 479 480 num_valid = count_handshakes(); 481 if (num_valid != num_found) { 482 fprintf(stderr, 483 "incorrect number of handshakes: want %zu, got %zu.\n", 484 num_valid, num_found); 485 success = 0; 486 } 487 488 return success; 489 } 490 491 void 492 usage(void) 493 { 494 fprintf(stderr, "usage: handshake_table [-C | -g]\n"); 495 exit(1); 496 } 497 498 int 499 main(int argc, char *argv[]) 500 { 501 static enum tls13_message_type 502 hs_table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES] = { 503 [INITIAL] = { 504 CLIENT_HELLO, 505 SERVER_HELLO_RETRY_REQUEST, 506 CLIENT_HELLO_RETRY, 507 SERVER_HELLO, 508 }, 509 }; 510 struct child start = { 511 .mt = CLIENT_HELLO, 512 }; 513 struct child end = { 514 .mt = APPLICATION_DATA, 515 }; 516 struct child path[TLS13_NUM_MESSAGE_TYPES] = {{0}}; 517 uint8_t flags = NEGOTIATED; 518 unsigned int depth = 0; 519 int ch, graphviz = 0, print = 0; 520 521 while ((ch = getopt(argc, argv, "Cg")) != -1) { 522 switch (ch) { 523 case 'C': 524 print = 1; 525 break; 526 case 'g': 527 graphviz = 1; 528 break; 529 default: 530 usage(); 531 } 532 } 533 argc -= optind; 534 argv += optind; 535 536 if (argc != 0) 537 usage(); 538 539 if (graphviz && print) 540 usage(); 541 542 if (graphviz) 543 return generate_graphics(); 544 545 build_table(hs_table, start, end, path, flags, depth); 546 if (!verify_table(hs_table, print)) 547 return 1; 548 549 return 0; 550 } 551