1 /* $OpenBSD: handshake_table.c,v 1.18 2022/12/01 13:49:12 tb Exp $ */
2 /*
3 * Copyright (c) 2019 Theo Buehler <tb@openbsd.org>
4 *
5 * Permission to use, copy, modify, and distribute this software for any
6 * purpose with or without fee is hereby granted, provided that the above
7 * copyright notice and this permission notice appear in all copies.
8 *
9 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16 */
17
18 #include <err.h>
19 #include <stdint.h>
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <unistd.h>
23
24 #include "tls13_handshake.h"
25
/* One table row for every possible uint8_t flags value. */
enum { MAX_FLAGS = 1 + UINT8_MAX };
27
/*
 * From RFC 8446:
 *
 * Appendix A.  State Machine
 *
 * This appendix provides a summary of the legal state transitions for
 * the client and server handshakes.  State names (in all capitals,
 * e.g., START) have no formal meaning but are provided for ease of
 * comprehension.  Actions which are taken only in certain circumstances
 * are indicated in [].  The notation "K_{send,recv} = foo" means "set
 * the send/recv key to the given key".
 *
 * A.1.  Client
 *
 *                               START <----+
 *                Send ClientHello |        | Recv HelloRetryRequest
 *           [K_send = early data] |        |
 *                                 v        |
 *            /                 WAIT_SH ----+
 *            |                    | Recv ServerHello
 *            |                    | K_recv = handshake
 *        Can |                    V
 *       send |                 WAIT_EE
 *      early |                    | Recv EncryptedExtensions
 *       data |           +--------+--------+
 *            |     Using |                 | Using certificate
 *            |       PSK |                 v
 *            |           |            WAIT_CERT_CR
 *            |           |        Recv |       | Recv CertificateRequest
 *            |           | Certificate |       v
 *            |           |             |   WAIT_CERT
 *            |           |             |       | Recv Certificate
 *            |           |             v       v
 *            |           |              WAIT_CV
 *            |           |                 | Recv CertificateVerify
 *            |           +> WAIT_FINISHED <+
 *            |                    | Recv Finished
 *            \                    | [Send EndOfEarlyData]
 *                                 | K_send = handshake
 *                                 | [Send Certificate [+ CertificateVerify]]
 *     Can send                    | Send Finished
 *     app data   -->              | K_send = K_recv = application
 *     after here                  v
 *                             CONNECTED
 *
 * Note that with the transitions as shown above, clients may send
 * alerts that derive from post-ServerHello messages in the clear or
 * with the early data keys.  If clients need to send such alerts, they
 * SHOULD first rekey to the handshake keys if possible.
 *
 */
79
80 struct child {
81 enum tls13_message_type mt;
82 uint8_t flag;
83 uint8_t forced;
84 uint8_t illegal;
85 };
86
87 static struct child stateinfo[][TLS13_NUM_MESSAGE_TYPES] = {
88 [CLIENT_HELLO] = {
89 {
90 .mt = SERVER_HELLO_RETRY_REQUEST,
91 },
92 {
93 .mt = SERVER_HELLO,
94 .flag = WITHOUT_HRR,
95 },
96 },
97 [SERVER_HELLO_RETRY_REQUEST] = {
98 {
99 .mt = CLIENT_HELLO_RETRY,
100 },
101 },
102 [CLIENT_HELLO_RETRY] = {
103 {
104 .mt = SERVER_HELLO,
105 },
106 },
107 [SERVER_HELLO] = {
108 {
109 .mt = SERVER_ENCRYPTED_EXTENSIONS,
110 },
111 },
112 [SERVER_ENCRYPTED_EXTENSIONS] = {
113 {
114 .mt = SERVER_CERTIFICATE_REQUEST,
115 },
116 { .mt = SERVER_CERTIFICATE,
117 .flag = WITHOUT_CR,
118 },
119 {
120 .mt = SERVER_FINISHED,
121 .flag = WITH_PSK,
122 },
123 },
124 [SERVER_CERTIFICATE_REQUEST] = {
125 {
126 .mt = SERVER_CERTIFICATE,
127 },
128 },
129 [SERVER_CERTIFICATE] = {
130 {
131 .mt = SERVER_CERTIFICATE_VERIFY,
132 },
133 },
134 [SERVER_CERTIFICATE_VERIFY] = {
135 {
136 .mt = SERVER_FINISHED,
137 },
138 },
139 [SERVER_FINISHED] = {
140 {
141 .mt = CLIENT_FINISHED,
142 .forced = WITHOUT_CR | WITH_PSK,
143 },
144 {
145 .mt = CLIENT_CERTIFICATE,
146 .illegal = WITHOUT_CR | WITH_PSK,
147 },
148 },
149 [CLIENT_CERTIFICATE] = {
150 {
151 .mt = CLIENT_FINISHED,
152 },
153 {
154 .mt = CLIENT_CERTIFICATE_VERIFY,
155 .flag = WITH_CCV,
156 },
157 },
158 [CLIENT_CERTIFICATE_VERIFY] = {
159 {
160 .mt = CLIENT_FINISHED,
161 },
162 },
163 [CLIENT_FINISHED] = {
164 {
165 .mt = APPLICATION_DATA,
166 },
167 },
168 [APPLICATION_DATA] = {
169 {
170 .mt = 0,
171 },
172 },
173 };
174
175 const size_t stateinfo_count = sizeof(stateinfo) / sizeof(stateinfo[0]);
176
177 void build_table(enum tls13_message_type
178 table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],
179 struct child current, struct child end,
180 struct child path[], uint8_t flags, unsigned int depth);
181 size_t count_handshakes(void);
182 void edge(enum tls13_message_type start,
183 enum tls13_message_type end, uint8_t flag);
184 const char *flag2str(uint8_t flag);
185 void flag_label(uint8_t flag);
186 void forced_edges(enum tls13_message_type start,
187 enum tls13_message_type end, uint8_t forced);
188 int generate_graphics(void);
189 void fprint_entry(FILE *stream,
190 enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES],
191 uint8_t flags);
192 void fprint_flags(FILE *stream, uint8_t flags);
193 const char *mt2str(enum tls13_message_type mt);
194 void usage(void);
195 int verify_table(enum tls13_message_type
196 table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES], int print);
197
198 const char *
flag2str(uint8_t flag)199 flag2str(uint8_t flag)
200 {
201 const char *ret;
202
203 if (flag & (flag - 1))
204 errx(1, "more than one bit is set");
205
206 switch (flag) {
207 case INITIAL:
208 ret = "INITIAL";
209 break;
210 case NEGOTIATED:
211 ret = "NEGOTIATED";
212 break;
213 case WITHOUT_CR:
214 ret = "WITHOUT_CR";
215 break;
216 case WITHOUT_HRR:
217 ret = "WITHOUT_HRR";
218 break;
219 case WITH_PSK:
220 ret = "WITH_PSK";
221 break;
222 case WITH_CCV:
223 ret = "WITH_CCV";
224 break;
225 case WITH_0RTT:
226 ret = "WITH_0RTT";
227 break;
228 default:
229 ret = "UNKNOWN";
230 }
231
232 return ret;
233 }
234
235 const char *
mt2str(enum tls13_message_type mt)236 mt2str(enum tls13_message_type mt)
237 {
238 const char *ret;
239
240 switch (mt) {
241 case INVALID:
242 ret = "INVALID";
243 break;
244 case CLIENT_HELLO:
245 ret = "CLIENT_HELLO";
246 break;
247 case CLIENT_HELLO_RETRY:
248 ret = "CLIENT_HELLO_RETRY";
249 break;
250 case CLIENT_END_OF_EARLY_DATA:
251 ret = "CLIENT_END_OF_EARLY_DATA";
252 break;
253 case CLIENT_CERTIFICATE:
254 ret = "CLIENT_CERTIFICATE";
255 break;
256 case CLIENT_CERTIFICATE_VERIFY:
257 ret = "CLIENT_CERTIFICATE_VERIFY";
258 break;
259 case CLIENT_FINISHED:
260 ret = "CLIENT_FINISHED";
261 break;
262 case SERVER_HELLO:
263 ret = "SERVER_HELLO";
264 break;
265 case SERVER_HELLO_RETRY_REQUEST:
266 ret = "SERVER_HELLO_RETRY_REQUEST";
267 break;
268 case SERVER_ENCRYPTED_EXTENSIONS:
269 ret = "SERVER_ENCRYPTED_EXTENSIONS";
270 break;
271 case SERVER_CERTIFICATE:
272 ret = "SERVER_CERTIFICATE";
273 break;
274 case SERVER_CERTIFICATE_VERIFY:
275 ret = "SERVER_CERTIFICATE_VERIFY";
276 break;
277 case SERVER_CERTIFICATE_REQUEST:
278 ret = "SERVER_CERTIFICATE_REQUEST";
279 break;
280 case SERVER_FINISHED:
281 ret = "SERVER_FINISHED";
282 break;
283 case APPLICATION_DATA:
284 ret = "APPLICATION_DATA";
285 break;
286 case TLS13_NUM_MESSAGE_TYPES:
287 ret = "TLS13_NUM_MESSAGE_TYPES";
288 break;
289 default:
290 ret = "UNKNOWN";
291 break;
292 }
293
294 return ret;
295 }
296
/*
 * Write flags to stream as the names of its set bits, lowest bit first,
 * separated by " | ".  A zero value is written as flag2str(0).
 */
void
fprint_flags(FILE *stream, uint8_t flags)
{
	const char *sep = "";
	unsigned int bit;

	if (flags == 0) {
		fprintf(stream, "%s", flag2str(flags));
		return;
	}

	for (bit = 0x01; bit <= 0x80; bit <<= 1) {
		if ((flags & bit) == 0)
			continue;
		fprintf(stream, "%s%s", sep, flag2str(flags & bit));
		sep = " | ";
	}
}
317
318 void
fprint_entry(FILE * stream,enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES],uint8_t flags)319 fprint_entry(FILE *stream,
320 enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES], uint8_t flags)
321 {
322 int i;
323
324 fprintf(stream, "\t[");
325 fprint_flags(stream, flags);
326 fprintf(stream, "] = {\n");
327
328 for (i = 0; i < TLS13_NUM_MESSAGE_TYPES; i++) {
329 if (path[i] == 0)
330 break;
331 fprintf(stream, "\t\t%s,\n", mt2str(path[i]));
332 }
333 fprintf(stream, "\t},\n");
334 }
335
336 void
edge(enum tls13_message_type start,enum tls13_message_type end,uint8_t flag)337 edge(enum tls13_message_type start, enum tls13_message_type end,
338 uint8_t flag)
339 {
340 printf("\t%s -> %s", mt2str(start), mt2str(end));
341 flag_label(flag);
342 printf(";\n");
343 }
344
/* Print a graphviz label attribute naming flag; prints nothing for 0. */
void
flag_label(uint8_t flag)
{
	if (flag == 0)
		return;

	printf(" [label=\"%s\"]", flag2str(flag));
}
351
352 void
forced_edges(enum tls13_message_type start,enum tls13_message_type end,uint8_t forced)353 forced_edges(enum tls13_message_type start, enum tls13_message_type end,
354 uint8_t forced)
355 {
356 uint8_t forced_flag, i;
357
358 if (forced == 0)
359 return;
360
361 for (i = 0; i < 8; i++) {
362 forced_flag = forced & (1U << i);
363 if (forced_flag)
364 edge(start, end, forced_flag);
365 }
366 }
367
368 int
generate_graphics(void)369 generate_graphics(void)
370 {
371 enum tls13_message_type start, end;
372 unsigned int child;
373 uint8_t flag;
374 uint8_t forced;
375
376 printf("digraph G {\n");
377 printf("\t%s [shape=box];\n", mt2str(CLIENT_HELLO));
378 printf("\t%s [shape=box];\n", mt2str(APPLICATION_DATA));
379
380 for (start = CLIENT_HELLO; start < APPLICATION_DATA; start++) {
381 for (child = 0; stateinfo[start][child].mt != 0; child++) {
382 end = stateinfo[start][child].mt;
383 flag = stateinfo[start][child].flag;
384 forced = stateinfo[start][child].forced;
385
386 if (forced == 0)
387 edge(start, end, flag);
388 else
389 forced_edges(start, end, forced);
390 }
391 }
392
393 printf("}\n");
394 return 0;
395 }
396
397 extern enum tls13_message_type handshakes[][TLS13_NUM_MESSAGE_TYPES];
398 extern size_t handshake_count;
399
400 size_t
count_handshakes(void)401 count_handshakes(void)
402 {
403 size_t ret = 0, i;
404
405 for (i = 0; i < handshake_count; i++) {
406 if (handshakes[i][0] != INVALID)
407 ret++;
408 }
409
410 return ret;
411 }
412
413 void
build_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],struct child current,struct child end,struct child path[],uint8_t flags,unsigned int depth)414 build_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],
415 struct child current, struct child end, struct child path[], uint8_t flags,
416 unsigned int depth)
417 {
418 unsigned int i;
419
420 if (depth >= TLS13_NUM_MESSAGE_TYPES - 1)
421 errx(1, "recursed too deeply");
422
423 /* Record current node. */
424 path[depth++] = current;
425 flags |= current.flag;
426
427 /* If we haven't reached the end, recurse over the children. */
428 if (current.mt != end.mt) {
429 for (i = 0; stateinfo[current.mt][i].mt != 0; i++) {
430 struct child child = stateinfo[current.mt][i];
431 int forced = stateinfo[current.mt][i].forced;
432 int illegal = stateinfo[current.mt][i].illegal;
433
434 if ((forced == 0 || (forced & flags)) &&
435 (illegal == 0 || !(illegal & flags)))
436 build_table(table, child, end, path, flags,
437 depth);
438 }
439 return;
440 }
441
442 if (flags == 0)
443 errx(1, "path does not set flags");
444
445 if (table[flags][0] != 0)
446 errx(1, "path traversed twice");
447
448 for (i = 0; i < depth; i++)
449 table[flags][i] = path[i].mt;
450 }
451
452 int
verify_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],int print)453 verify_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],
454 int print)
455 {
456 int success = 1, i;
457 size_t num_valid, num_found = 0;
458 uint8_t flags = 0;
459
460 do {
461 if (table[flags][0] == 0)
462 continue;
463
464 num_found++;
465
466 for (i = 0; i < TLS13_NUM_MESSAGE_TYPES; i++) {
467 if (table[flags][i] != handshakes[flags][i]) {
468 fprintf(stderr,
469 "incorrect entry %d of handshake ", i);
470 fprint_flags(stderr, flags);
471 fprintf(stderr, "\n");
472 success = 0;
473 }
474 }
475
476 if (print)
477 fprint_entry(stdout, table[flags], flags);
478 } while(++flags != 0);
479
480 num_valid = count_handshakes();
481 if (num_valid != num_found) {
482 fprintf(stderr,
483 "incorrect number of handshakes: want %zu, got %zu.\n",
484 num_valid, num_found);
485 success = 0;
486 }
487
488 return success;
489 }
490
/* Print the command synopsis to stderr and exit with status 1. */
void
usage(void)
{
	fputs("usage: handshake_table [-C | -g]\n", stderr);
	exit(1);
}
497
498 int
main(int argc,char * argv[])499 main(int argc, char *argv[])
500 {
501 static enum tls13_message_type
502 hs_table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES] = {
503 [INITIAL] = {
504 CLIENT_HELLO,
505 SERVER_HELLO_RETRY_REQUEST,
506 CLIENT_HELLO_RETRY,
507 SERVER_HELLO,
508 },
509 };
510 struct child start = {
511 .mt = CLIENT_HELLO,
512 };
513 struct child end = {
514 .mt = APPLICATION_DATA,
515 };
516 struct child path[TLS13_NUM_MESSAGE_TYPES] = {{0}};
517 uint8_t flags = NEGOTIATED;
518 unsigned int depth = 0;
519 int ch, graphviz = 0, print = 0;
520
521 while ((ch = getopt(argc, argv, "Cg")) != -1) {
522 switch (ch) {
523 case 'C':
524 print = 1;
525 break;
526 case 'g':
527 graphviz = 1;
528 break;
529 default:
530 usage();
531 }
532 }
533 argc -= optind;
534 argv += optind;
535
536 if (argc != 0)
537 usage();
538
539 if (graphviz && print)
540 usage();
541
542 if (graphviz)
543 return generate_graphics();
544
545 build_table(hs_table, start, end, path, flags, depth);
546 if (!verify_table(hs_table, print))
547 return 1;
548
549 return 0;
550 }
551