1 /*	$OpenBSD: handshake_table.c,v 1.18 2022/12/01 13:49:12 tb Exp $	*/
2 /*
3  * Copyright (c) 2019 Theo Buehler <tb@openbsd.org>
4  *
5  * Permission to use, copy, modify, and distribute this software for any
6  * purpose with or without fee is hereby granted, provided that the above
7  * copyright notice and this permission notice appear in all copies.
8  *
9  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16  */
17 
18 #include <err.h>
19 #include <stdint.h>
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <unistd.h>
23 
24 #include "tls13_handshake.h"
25 
/*
 * Handshake flag sets are uint8_t values; a table indexed by flag set
 * therefore needs one row for every uint8_t value.
 */
#define MAX_FLAGS (UINT8_MAX + 1)
27 
28 /*
29  * From RFC 8446:
30  *
31  * Appendix A.  State Machine
32  *
33  *    This appendix provides a summary of the legal state transitions for
34  *    the client and server handshakes.  State names (in all capitals,
35  *    e.g., START) have no formal meaning but are provided for ease of
36  *    comprehension.  Actions which are taken only in certain circumstances
37  *    are indicated in [].  The notation "K_{send,recv} = foo" means "set
38  *    the send/recv key to the given key".
39  *
40  * A.1.  Client
41  *
42  *                               START <----+
43  *                Send ClientHello |        | Recv HelloRetryRequest
44  *           [K_send = early data] |        |
45  *                                 v        |
46  *            /                 WAIT_SH ----+
47  *            |                    | Recv ServerHello
48  *            |                    | K_recv = handshake
49  *        Can |                    V
50  *       send |                 WAIT_EE
51  *      early |                    | Recv EncryptedExtensions
52  *       data |           +--------+--------+
53  *            |     Using |                 | Using certificate
54  *            |       PSK |                 v
55  *            |           |            WAIT_CERT_CR
56  *            |           |        Recv |       | Recv CertificateRequest
57  *            |           | Certificate |       v
58  *            |           |             |    WAIT_CERT
59  *            |           |             |       | Recv Certificate
60  *            |           |             v       v
61  *            |           |              WAIT_CV
62  *            |           |                 | Recv CertificateVerify
63  *            |           +> WAIT_FINISHED <+
64  *            |                  | Recv Finished
65  *            \                  | [Send EndOfEarlyData]
66  *                               | K_send = handshake
67  *                               | [Send Certificate [+ CertificateVerify]]
68  *     Can send                  | Send Finished
69  *     app data   -->            | K_send = K_recv = application
70  *     after here                v
71  *                           CONNECTED
72  *
73  *    Note that with the transitions as shown above, clients may send
74  *    alerts that derive from post-ServerHello messages in the clear or
75  *    with the early data keys.  If clients need to send such alerts, they
76  *    SHOULD first rekey to the handshake keys if possible.
77  *
78  */
79 
/*
 * One outgoing edge of the handshake state graph.  Edges are stored in
 * stateinfo[] under the message type they leave from.
 */
struct child {
	/* Message type this edge leads to; 0 terminates a stateinfo row. */
	enum tls13_message_type	mt;
	/* Flag ORed into the path's flag set when this edge is taken. */
	uint8_t			flag;
	/* If non-zero: edge is only taken when one of these flags is set. */
	uint8_t			forced;
	/* If non-zero: edge is skipped when any of these flags is set. */
	uint8_t			illegal;
};
86 
87 static struct child stateinfo[][TLS13_NUM_MESSAGE_TYPES] = {
88 	[CLIENT_HELLO] = {
89 		{
90 			.mt = SERVER_HELLO_RETRY_REQUEST,
91 		},
92 		{
93 			.mt = SERVER_HELLO,
94 			.flag = WITHOUT_HRR,
95 		},
96 	},
97 	[SERVER_HELLO_RETRY_REQUEST] = {
98 		{
99 			.mt = CLIENT_HELLO_RETRY,
100 		},
101 	},
102 	[CLIENT_HELLO_RETRY] = {
103 		{
104 			.mt = SERVER_HELLO,
105 		},
106 	},
107 	[SERVER_HELLO] = {
108 		{
109 			.mt = SERVER_ENCRYPTED_EXTENSIONS,
110 		},
111 	},
112 	[SERVER_ENCRYPTED_EXTENSIONS] = {
113 		{
114 			.mt = SERVER_CERTIFICATE_REQUEST,
115 		},
116 		{	.mt = SERVER_CERTIFICATE,
117 			.flag = WITHOUT_CR,
118 		},
119 		{
120 			.mt = SERVER_FINISHED,
121 			.flag = WITH_PSK,
122 		},
123 	},
124 	[SERVER_CERTIFICATE_REQUEST] = {
125 		{
126 			.mt = SERVER_CERTIFICATE,
127 		},
128 	},
129 	[SERVER_CERTIFICATE] = {
130 		{
131 			.mt = SERVER_CERTIFICATE_VERIFY,
132 		},
133 	},
134 	[SERVER_CERTIFICATE_VERIFY] = {
135 		{
136 			.mt = SERVER_FINISHED,
137 		},
138 	},
139 	[SERVER_FINISHED] = {
140 		{
141 			.mt = CLIENT_FINISHED,
142 			.forced = WITHOUT_CR | WITH_PSK,
143 		},
144 		{
145 			.mt = CLIENT_CERTIFICATE,
146 			.illegal = WITHOUT_CR | WITH_PSK,
147 		},
148 	},
149 	[CLIENT_CERTIFICATE] = {
150 		{
151 			.mt = CLIENT_FINISHED,
152 		},
153 		{
154 			.mt = CLIENT_CERTIFICATE_VERIFY,
155 			.flag = WITH_CCV,
156 		},
157 	},
158 	[CLIENT_CERTIFICATE_VERIFY] = {
159 		{
160 			.mt = CLIENT_FINISHED,
161 		},
162 	},
163 	[CLIENT_FINISHED] = {
164 		{
165 			.mt = APPLICATION_DATA,
166 		},
167 	},
168 	[APPLICATION_DATA] = {
169 		{
170 			.mt = 0,
171 		},
172 	},
173 };
174 
175 const size_t	 stateinfo_count = sizeof(stateinfo) / sizeof(stateinfo[0]);
176 
/* Prototypes for the functions defined below. */
void		 build_table(enum tls13_message_type
		     table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],
		     struct child current, struct child end,
		     struct child path[], uint8_t flags, unsigned int depth);
size_t		 count_handshakes(void);
void		 edge(enum tls13_message_type start,
		     enum tls13_message_type end, uint8_t flag);
const char	*flag2str(uint8_t flag);
void		 flag_label(uint8_t flag);
void		 forced_edges(enum tls13_message_type start,
		     enum tls13_message_type end, uint8_t forced);
int		 generate_graphics(void);
void		 fprint_entry(FILE *stream,
		     enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES],
		     uint8_t flags);
void		 fprint_flags(FILE *stream, uint8_t flags);
const char	*mt2str(enum tls13_message_type mt);
void		 usage(void);
int		 verify_table(enum tls13_message_type
		     table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES], int print);
197 
198 const char *
199 flag2str(uint8_t flag)
200 {
201 	const char *ret;
202 
203 	if (flag & (flag - 1))
204 		errx(1, "more than one bit is set");
205 
206 	switch (flag) {
207 	case INITIAL:
208 		ret = "INITIAL";
209 		break;
210 	case NEGOTIATED:
211 		ret = "NEGOTIATED";
212 		break;
213 	case WITHOUT_CR:
214 		ret = "WITHOUT_CR";
215 		break;
216 	case WITHOUT_HRR:
217 		ret = "WITHOUT_HRR";
218 		break;
219 	case WITH_PSK:
220 		ret = "WITH_PSK";
221 		break;
222 	case WITH_CCV:
223 		ret = "WITH_CCV";
224 		break;
225 	case WITH_0RTT:
226 		ret = "WITH_0RTT";
227 		break;
228 	default:
229 		ret = "UNKNOWN";
230 	}
231 
232 	return ret;
233 }
234 
235 const char *
236 mt2str(enum tls13_message_type mt)
237 {
238 	const char *ret;
239 
240 	switch (mt) {
241 	case INVALID:
242 		ret = "INVALID";
243 		break;
244 	case CLIENT_HELLO:
245 		ret = "CLIENT_HELLO";
246 		break;
247 	case CLIENT_HELLO_RETRY:
248 		ret = "CLIENT_HELLO_RETRY";
249 		break;
250 	case CLIENT_END_OF_EARLY_DATA:
251 		ret = "CLIENT_END_OF_EARLY_DATA";
252 		break;
253 	case CLIENT_CERTIFICATE:
254 		ret = "CLIENT_CERTIFICATE";
255 		break;
256 	case CLIENT_CERTIFICATE_VERIFY:
257 		ret = "CLIENT_CERTIFICATE_VERIFY";
258 		break;
259 	case CLIENT_FINISHED:
260 		ret = "CLIENT_FINISHED";
261 		break;
262 	case SERVER_HELLO:
263 		ret = "SERVER_HELLO";
264 		break;
265 	case SERVER_HELLO_RETRY_REQUEST:
266 		ret = "SERVER_HELLO_RETRY_REQUEST";
267 		break;
268 	case SERVER_ENCRYPTED_EXTENSIONS:
269 		ret = "SERVER_ENCRYPTED_EXTENSIONS";
270 		break;
271 	case SERVER_CERTIFICATE:
272 		ret = "SERVER_CERTIFICATE";
273 		break;
274 	case SERVER_CERTIFICATE_VERIFY:
275 		ret = "SERVER_CERTIFICATE_VERIFY";
276 		break;
277 	case SERVER_CERTIFICATE_REQUEST:
278 		ret = "SERVER_CERTIFICATE_REQUEST";
279 		break;
280 	case SERVER_FINISHED:
281 		ret = "SERVER_FINISHED";
282 		break;
283 	case APPLICATION_DATA:
284 		ret = "APPLICATION_DATA";
285 		break;
286 	case TLS13_NUM_MESSAGE_TYPES:
287 		ret = "TLS13_NUM_MESSAGE_TYPES";
288 		break;
289 	default:
290 		ret = "UNKNOWN";
291 		break;
292 	}
293 
294 	return ret;
295 }
296 
/*
 * Write a flag set to stream as the names of its set bits joined by
 * " | ", lowest bit first.  An empty set is written as flag2str(0).
 */
void
fprint_flags(FILE *stream, uint8_t flags)
{
	const char	*sep = "";
	unsigned int	 bit;

	if (flags == 0) {
		fprintf(stream, "%s", flag2str(0));
		return;
	}

	for (bit = 0; bit < 8; bit++) {
		uint8_t mask = 1U << bit;

		if ((flags & mask) == 0)
			continue;
		fprintf(stream, "%s%s", sep, flag2str(mask));
		sep = " | ";
	}
}
317 
318 void
319 fprint_entry(FILE *stream,
320     enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES], uint8_t flags)
321 {
322 	int i;
323 
324 	fprintf(stream, "\t[");
325 	fprint_flags(stream, flags);
326 	fprintf(stream, "] = {\n");
327 
328 	for (i = 0; i < TLS13_NUM_MESSAGE_TYPES; i++) {
329 		if (path[i] == 0)
330 			break;
331 		fprintf(stream, "\t\t%s,\n", mt2str(path[i]));
332 	}
333 	fprintf(stream, "\t},\n");
334 }
335 
336 void
337 edge(enum tls13_message_type start, enum tls13_message_type end,
338     uint8_t flag)
339 {
340 	printf("\t%s -> %s", mt2str(start), mt2str(end));
341 	flag_label(flag);
342 	printf(";\n");
343 }
344 
/* Print a dot edge label for a single flag; a zero flag prints nothing. */
void
flag_label(uint8_t flag)
{
	if (flag == 0)
		return;

	printf(" [label=\"%s\"]", flag2str(flag));
}
351 
352 void
353 forced_edges(enum tls13_message_type start, enum tls13_message_type end,
354     uint8_t forced)
355 {
356 	uint8_t	forced_flag, i;
357 
358 	if (forced == 0)
359 		return;
360 
361 	for (i = 0; i < 8; i++) {
362 		forced_flag = forced & (1U << i);
363 		if (forced_flag)
364 			edge(start, end, forced_flag);
365 	}
366 }
367 
368 int
369 generate_graphics(void)
370 {
371 	enum tls13_message_type	start, end;
372 	unsigned int		child;
373 	uint8_t			flag;
374 	uint8_t			forced;
375 
376 	printf("digraph G {\n");
377 	printf("\t%s [shape=box];\n", mt2str(CLIENT_HELLO));
378 	printf("\t%s [shape=box];\n", mt2str(APPLICATION_DATA));
379 
380 	for (start = CLIENT_HELLO; start < APPLICATION_DATA; start++) {
381 		for (child = 0; stateinfo[start][child].mt != 0; child++) {
382 			end = stateinfo[start][child].mt;
383 			flag = stateinfo[start][child].flag;
384 			forced = stateinfo[start][child].forced;
385 
386 			if (forced == 0)
387 				edge(start, end, flag);
388 			else
389 				forced_edges(start, end, forced);
390 		}
391 	}
392 
393 	printf("}\n");
394 	return 0;
395 }
396 
/*
 * Reference handshake table that the generated table is verified against,
 * and its number of rows.  Both are defined outside this file (presumably
 * by the libssl TLS 1.3 handshake code; confirm against the build).
 */
extern enum tls13_message_type	handshakes[][TLS13_NUM_MESSAGE_TYPES];
extern size_t			handshake_count;
399 
400 size_t
401 count_handshakes(void)
402 {
403 	size_t	ret = 0, i;
404 
405 	for (i = 0; i < handshake_count; i++) {
406 		if (handshakes[i][0] != INVALID)
407 			ret++;
408 	}
409 
410 	return ret;
411 }
412 
413 void
414 build_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],
415     struct child current, struct child end, struct child path[], uint8_t flags,
416     unsigned int depth)
417 {
418 	unsigned int i;
419 
420 	if (depth >= TLS13_NUM_MESSAGE_TYPES - 1)
421 		errx(1, "recursed too deeply");
422 
423 	/* Record current node. */
424 	path[depth++] = current;
425 	flags |= current.flag;
426 
427 	/* If we haven't reached the end, recurse over the children. */
428 	if (current.mt != end.mt) {
429 		for (i = 0; stateinfo[current.mt][i].mt != 0; i++) {
430 			struct child child = stateinfo[current.mt][i];
431 			int forced = stateinfo[current.mt][i].forced;
432 			int illegal = stateinfo[current.mt][i].illegal;
433 
434 			if ((forced == 0 || (forced & flags)) &&
435 			    (illegal == 0 || !(illegal & flags)))
436 				build_table(table, child, end, path, flags,
437 				    depth);
438 		}
439 		return;
440 	}
441 
442 	if (flags == 0)
443 		errx(1, "path does not set flags");
444 
445 	if (table[flags][0] != 0)
446 		errx(1, "path traversed twice");
447 
448 	for (i = 0; i < depth; i++)
449 		table[flags][i] = path[i].mt;
450 }
451 
452 int
453 verify_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],
454     int print)
455 {
456 	int	success = 1, i;
457 	size_t	num_valid, num_found = 0;
458 	uint8_t	flags = 0;
459 
460 	do {
461 		if (table[flags][0] == 0)
462 			continue;
463 
464 		num_found++;
465 
466 		for (i = 0; i < TLS13_NUM_MESSAGE_TYPES; i++) {
467 			if (table[flags][i] != handshakes[flags][i]) {
468 				fprintf(stderr,
469 				    "incorrect entry %d of handshake ", i);
470 				fprint_flags(stderr, flags);
471 				fprintf(stderr, "\n");
472 				success = 0;
473 			}
474 		}
475 
476 		if (print)
477 			fprint_entry(stdout, table[flags], flags);
478 	} while(++flags != 0);
479 
480 	num_valid = count_handshakes();
481 	if (num_valid != num_found) {
482 		fprintf(stderr,
483 		    "incorrect number of handshakes: want %zu, got %zu.\n",
484 		    num_valid, num_found);
485 		success = 0;
486 	}
487 
488 	return success;
489 }
490 
/* Print the command synopsis to stderr and exit with status 1. */
void
usage(void)
{
	fputs("usage: handshake_table [-C | -g]\n", stderr);
	exit(1);
}
497 
498 int
499 main(int argc, char *argv[])
500 {
501 	static enum tls13_message_type
502 	    hs_table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES] = {
503 		[INITIAL] = {
504 			CLIENT_HELLO,
505 			SERVER_HELLO_RETRY_REQUEST,
506 			CLIENT_HELLO_RETRY,
507 			SERVER_HELLO,
508 		},
509 	};
510 	struct child	start = {
511 		.mt = CLIENT_HELLO,
512 	};
513 	struct child	end = {
514 		.mt = APPLICATION_DATA,
515 	};
516 	struct child	path[TLS13_NUM_MESSAGE_TYPES] = {{0}};
517 	uint8_t		flags = NEGOTIATED;
518 	unsigned int	depth = 0;
519 	int		ch, graphviz = 0, print = 0;
520 
521 	while ((ch = getopt(argc, argv, "Cg")) != -1) {
522 		switch (ch) {
523 		case 'C':
524 			print = 1;
525 			break;
526 		case 'g':
527 			graphviz = 1;
528 			break;
529 		default:
530 			usage();
531 		}
532 	}
533 	argc -= optind;
534 	argv += optind;
535 
536 	if (argc != 0)
537 		usage();
538 
539 	if (graphviz && print)
540 		usage();
541 
542 	if (graphviz)
543 		return generate_graphics();
544 
545 	build_table(hs_table, start, end, path, flags, depth);
546 	if (!verify_table(hs_table, print))
547 		return 1;
548 
549 	return 0;
550 }
551