1 /*	$OpenBSD: handshake_table.c,v 1.15 2020/05/14 18:04:19 tb Exp $	*/
2 /*
3  * Copyright (c) 2019 Theo Buehler <tb@openbsd.org>
4  *
5  * Permission to use, copy, modify, and distribute this software for any
6  * purpose with or without fee is hereby granted, provided that the above
7  * copyright notice and this permission notice appear in all copies.
8  *
9  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
10  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
11  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
12  * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
13  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
14  * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
15  * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
16  */
17 
18 #include <err.h>
19 #include <stdint.h>
20 #include <stdio.h>
21 #include <stdlib.h>
22 #include <unistd.h>
23 
24 #include "tls13_handshake.h"
25 
/* One table row for every value a uint8_t flag set can take. */
#define MAX_FLAGS (1 + UINT8_MAX)
27 
28 /*
29  * From RFC 8446:
30  *
31  * Appendix A.  State Machine
32  *
33  *    This appendix provides a summary of the legal state transitions for
34  *    the client and server handshakes.  State names (in all capitals,
35  *    e.g., START) have no formal meaning but are provided for ease of
36  *    comprehension.  Actions which are taken only in certain circumstances
37  *    are indicated in [].  The notation "K_{send,recv} = foo" means "set
38  *    the send/recv key to the given key".
39  *
40  * A.1.  Client
41  *
42  *                               START <----+
43  *                Send ClientHello |        | Recv HelloRetryRequest
44  *           [K_send = early data] |        |
45  *                                 v        |
46  *            /                 WAIT_SH ----+
47  *            |                    | Recv ServerHello
48  *            |                    | K_recv = handshake
49  *        Can |                    V
50  *       send |                 WAIT_EE
51  *      early |                    | Recv EncryptedExtensions
52  *       data |           +--------+--------+
53  *            |     Using |                 | Using certificate
54  *            |       PSK |                 v
55  *            |           |            WAIT_CERT_CR
56  *            |           |        Recv |       | Recv CertificateRequest
57  *            |           | Certificate |       v
58  *            |           |             |    WAIT_CERT
59  *            |           |             |       | Recv Certificate
60  *            |           |             v       v
61  *            |           |              WAIT_CV
62  *            |           |                 | Recv CertificateVerify
63  *            |           +> WAIT_FINISHED <+
64  *            |                  | Recv Finished
65  *            \                  | [Send EndOfEarlyData]
66  *                               | K_send = handshake
67  *                               | [Send Certificate [+ CertificateVerify]]
68  *     Can send                  | Send Finished
69  *     app data   -->            | K_send = K_recv = application
70  *     after here                v
71  *                           CONNECTED
72  *
73  *    Note that with the transitions as shown above, clients may send
74  *    alerts that derive from post-ServerHello messages in the clear or
75  *    with the early data keys.  If clients need to send such alerts, they
76  *    SHOULD first rekey to the handshake keys if possible.
77  *
78  */
79 
80 struct child {
81 	enum tls13_message_type	mt;
82 	uint8_t			flag;
83 	uint8_t			forced;
84 	uint8_t			illegal;
85 };
86 
87 #define DEFAULT			0x00
88 
89 static struct child stateinfo[][TLS13_NUM_MESSAGE_TYPES] = {
90 	[CLIENT_HELLO] = {
91 		{SERVER_HELLO_RETRY_REQUEST, DEFAULT, 0, 0},
92 		{SERVER_HELLO, WITHOUT_HRR, 0, 0},
93 	},
94 	[SERVER_HELLO_RETRY_REQUEST] = {
95 		{CLIENT_HELLO_RETRY, DEFAULT, 0, 0},
96 	},
97 	[CLIENT_HELLO_RETRY] = {
98 		{SERVER_HELLO, DEFAULT, 0, 0},
99 	},
100 	[SERVER_HELLO] = {
101 		{SERVER_ENCRYPTED_EXTENSIONS, DEFAULT, 0, 0},
102 	},
103 	[SERVER_ENCRYPTED_EXTENSIONS] = {
104 		{SERVER_CERTIFICATE_REQUEST, DEFAULT, 0, 0},
105 		{SERVER_CERTIFICATE, WITHOUT_CR, 0, 0},
106 		{SERVER_FINISHED, WITH_PSK, 0, 0},
107 	},
108 	[SERVER_CERTIFICATE_REQUEST] = {
109 		{SERVER_CERTIFICATE, DEFAULT, 0, 0},
110 	},
111 	[SERVER_CERTIFICATE] = {
112 		{SERVER_CERTIFICATE_VERIFY, DEFAULT, 0, 0},
113 	},
114 	[SERVER_CERTIFICATE_VERIFY] = {
115 		{SERVER_FINISHED, DEFAULT, 0, 0},
116 	},
117 	[SERVER_FINISHED] = {
118 		{CLIENT_FINISHED, DEFAULT, WITHOUT_CR | WITH_PSK, 0},
119 		{CLIENT_CERTIFICATE, DEFAULT, 0, WITHOUT_CR | WITH_PSK},
120 	},
121 	[CLIENT_CERTIFICATE] = {
122 		{CLIENT_FINISHED, DEFAULT, 0, 0},
123 		{CLIENT_CERTIFICATE_VERIFY, WITH_CCV, 0, 0},
124 	},
125 	[CLIENT_CERTIFICATE_VERIFY] = {
126 		{CLIENT_FINISHED, DEFAULT, 0, 0},
127 	},
128 	[CLIENT_FINISHED] = {
129 		{APPLICATION_DATA, DEFAULT, 0, 0},
130 	},
131 	[APPLICATION_DATA] = {
132 		{0, DEFAULT, 0, 0},
133 	},
134 };
135 
136 const size_t	 stateinfo_count = sizeof(stateinfo) / sizeof(stateinfo[0]);
137 
138 void		 build_table(enum tls13_message_type
139 		     table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],
140 		     struct child current, struct child end,
141 		     struct child path[], uint8_t flags, unsigned int depth);
142 size_t		 count_handshakes(void);
143 void		 edge(enum tls13_message_type start,
144 		     enum tls13_message_type end, uint8_t flag);
145 const char	*flag2str(uint8_t flag);
146 void		 flag_label(uint8_t flag);
147 void		 forced_edges(enum tls13_message_type start,
148 		     enum tls13_message_type end, uint8_t forced);
149 int		 generate_graphics(void);
150 void		 fprint_entry(FILE *stream,
151 		     enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES],
152 		     uint8_t flags);
153 void		 fprint_flags(FILE *stream, uint8_t flags);
154 const char	*mt2str(enum tls13_message_type mt);
155 __dead void	 usage(void);
156 int		 verify_table(enum tls13_message_type
157 		     table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES], int print);
158 
159 const char *
160 flag2str(uint8_t flag)
161 {
162 	const char *ret;
163 
164 	if (flag & (flag - 1))
165 		errx(1, "more than one bit is set");
166 
167 	switch (flag) {
168 	case INITIAL:
169 		ret = "INITIAL";
170 		break;
171 	case NEGOTIATED:
172 		ret = "NEGOTIATED";
173 		break;
174 	case WITHOUT_CR:
175 		ret = "WITHOUT_CR";
176 		break;
177 	case WITHOUT_HRR:
178 		ret = "WITHOUT_HRR";
179 		break;
180 	case WITH_PSK:
181 		ret = "WITH_PSK";
182 		break;
183 	case WITH_CCV:
184 		ret = "WITH_CCV";
185 		break;
186 	case WITH_0RTT:
187 		ret = "WITH_0RTT";
188 		break;
189 	default:
190 		ret = "UNKNOWN";
191 	}
192 
193 	return ret;
194 }
195 
196 const char *
197 mt2str(enum tls13_message_type mt)
198 {
199 	const char *ret;
200 
201 	switch (mt) {
202 	case INVALID:
203 		ret = "INVALID";
204 		break;
205 	case CLIENT_HELLO:
206 		ret = "CLIENT_HELLO";
207 		break;
208 	case CLIENT_HELLO_RETRY:
209 		ret = "CLIENT_HELLO_RETRY";
210 		break;
211 	case CLIENT_END_OF_EARLY_DATA:
212 		ret = "CLIENT_END_OF_EARLY_DATA";
213 		break;
214 	case CLIENT_CERTIFICATE:
215 		ret = "CLIENT_CERTIFICATE";
216 		break;
217 	case CLIENT_CERTIFICATE_VERIFY:
218 		ret = "CLIENT_CERTIFICATE_VERIFY";
219 		break;
220 	case CLIENT_FINISHED:
221 		ret = "CLIENT_FINISHED";
222 		break;
223 	case SERVER_HELLO:
224 		ret = "SERVER_HELLO";
225 		break;
226 	case SERVER_HELLO_RETRY_REQUEST:
227 		ret = "SERVER_HELLO_RETRY_REQUEST";
228 		break;
229 	case SERVER_ENCRYPTED_EXTENSIONS:
230 		ret = "SERVER_ENCRYPTED_EXTENSIONS";
231 		break;
232 	case SERVER_CERTIFICATE:
233 		ret = "SERVER_CERTIFICATE";
234 		break;
235 	case SERVER_CERTIFICATE_VERIFY:
236 		ret = "SERVER_CERTIFICATE_VERIFY";
237 		break;
238 	case SERVER_CERTIFICATE_REQUEST:
239 		ret = "SERVER_CERTIFICATE_REQUEST";
240 		break;
241 	case SERVER_FINISHED:
242 		ret = "SERVER_FINISHED";
243 		break;
244 	case APPLICATION_DATA:
245 		ret = "APPLICATION_DATA";
246 		break;
247 	case TLS13_NUM_MESSAGE_TYPES:
248 		ret = "TLS13_NUM_MESSAGE_TYPES";
249 		break;
250 	default:
251 		ret = "UNKNOWN";
252 		break;
253 	}
254 
255 	return ret;
256 }
257 
/*
 * Write a flag set to stream as its bit names joined by " | ", lowest bit
 * first.  An empty set is written as the name of flag value 0.
 */
void
fprint_flags(FILE *stream, uint8_t flags)
{
	const char	*sep = "";
	unsigned int	 mask;

	if (flags == 0) {
		fprintf(stream, "%s", flag2str(flags));
		return;
	}

	for (mask = 1; mask <= UINT8_MAX; mask <<= 1) {
		if ((flags & mask) == 0)
			continue;
		fprintf(stream, "%s%s", sep, flag2str(flags & mask));
		sep = " | ";
	}
}
278 
279 void
280 fprint_entry(FILE *stream,
281     enum tls13_message_type path[TLS13_NUM_MESSAGE_TYPES], uint8_t flags)
282 {
283 	int i;
284 
285 	fprintf(stream, "\t[");
286 	fprint_flags(stream, flags);
287 	fprintf(stream, "] = {\n");
288 
289 	for (i = 0; i < TLS13_NUM_MESSAGE_TYPES; i++) {
290 		if (path[i] == 0)
291 			break;
292 		fprintf(stream, "\t\t%s,\n", mt2str(path[i]));
293 	}
294 	fprintf(stream, "\t},\n");
295 }
296 
297 void
298 edge(enum tls13_message_type start, enum tls13_message_type end,
299     uint8_t flag)
300 {
301 	printf("\t%s -> %s", mt2str(start), mt2str(end));
302 	flag_label(flag);
303 	printf(";\n");
304 }
305 
/* Print a DOT label attribute naming flag; print nothing for flag 0. */
void
flag_label(uint8_t flag)
{
	if (flag == 0)
		return;

	printf(" [label=\"%s\"]", flag2str(flag));
}
312 
313 void
314 forced_edges(enum tls13_message_type start, enum tls13_message_type end,
315     uint8_t forced)
316 {
317 	uint8_t	forced_flag, i;
318 
319 	if (forced == 0)
320 		return;
321 
322 	for (i = 0; i < 8; i++) {
323 		forced_flag = forced & (1U << i);
324 		if (forced_flag)
325 			edge(start, end, forced_flag);
326 	}
327 }
328 
329 int
330 generate_graphics(void)
331 {
332 	enum tls13_message_type	start, end;
333 	unsigned int		child;
334 	uint8_t			flag;
335 	uint8_t			forced;
336 
337 	printf("digraph G {\n");
338 	printf("\t%s [shape=box];\n", mt2str(CLIENT_HELLO));
339 	printf("\t%s [shape=box];\n", mt2str(APPLICATION_DATA));
340 
341 	for (start = CLIENT_HELLO; start < APPLICATION_DATA; start++) {
342 		for (child = 0; stateinfo[start][child].mt != 0; child++) {
343 			end = stateinfo[start][child].mt;
344 			flag = stateinfo[start][child].flag;
345 			forced = stateinfo[start][child].forced;
346 
347 			if (forced == 0)
348 				edge(start, end, flag);
349 			else
350 				forced_edges(start, end, forced);
351 		}
352 	}
353 
354 	printf("}\n");
355 	return 0;
356 }
357 
358 extern enum tls13_message_type	handshakes[][TLS13_NUM_MESSAGE_TYPES];
359 extern size_t			handshake_count;
360 
361 size_t
362 count_handshakes(void)
363 {
364 	size_t	ret = 0, i;
365 
366 	for (i = 0; i < handshake_count; i++) {
367 		if (handshakes[i][0] != INVALID)
368 			ret++;
369 	}
370 
371 	return ret;
372 }
373 
374 void
375 build_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],
376     struct child current, struct child end, struct child path[], uint8_t flags,
377     unsigned int depth)
378 {
379 	unsigned int i;
380 
381 	if (depth >= TLS13_NUM_MESSAGE_TYPES - 1)
382 		errx(1, "recursed too deeply");
383 
384 	/* Record current node. */
385 	path[depth++] = current;
386 	flags |= current.flag;
387 
388 	/* If we haven't reached the end, recurse over the children. */
389 	if (current.mt != end.mt) {
390 		for (i = 0; stateinfo[current.mt][i].mt != 0; i++) {
391 			struct child child = stateinfo[current.mt][i];
392 			int forced = stateinfo[current.mt][i].forced;
393 			int illegal = stateinfo[current.mt][i].illegal;
394 
395 			if ((forced == 0 || (forced & flags)) &&
396 			    (illegal == 0 || !(illegal & flags)))
397 				build_table(table, child, end, path, flags,
398 				    depth);
399 		}
400 		return;
401 	}
402 
403 	if (flags == 0)
404 		errx(1, "path does not set flags");
405 
406 	if (table[flags][0] != 0)
407 		errx(1, "path traversed twice");
408 
409 	for (i = 0; i < depth; i++)
410 		table[flags][i] = path[i].mt;
411 }
412 
413 int
414 verify_table(enum tls13_message_type table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES],
415     int print)
416 {
417 	int	success = 1, i;
418 	size_t	num_valid, num_found = 0;
419 	uint8_t	flags = 0;
420 
421 	do {
422 		if (table[flags][0] == 0)
423 			continue;
424 
425 		num_found++;
426 
427 		for (i = 0; i < TLS13_NUM_MESSAGE_TYPES; i++) {
428 			if (table[flags][i] != handshakes[flags][i]) {
429 				fprintf(stderr,
430 				    "incorrect entry %d of handshake ", i);
431 				fprint_flags(stderr, flags);
432 				fprintf(stderr, "\n");
433 				success = 0;
434 			}
435 		}
436 
437 		if (print)
438 			fprint_entry(stdout, table[flags], flags);
439 	} while(++flags != 0);
440 
441 	num_valid = count_handshakes();
442 	if (num_valid != num_found) {
443 		fprintf(stderr,
444 		    "incorrect number of handshakes: want %zu, got %zu.\n",
445 		    num_valid, num_found);
446 		success = 0;
447 	}
448 
449 	return success;
450 }
451 
452 __dead void
453 usage(void)
454 {
455 	fprintf(stderr, "usage: handshake_table [-C | -g]\n");
456 	exit(1);
457 }
458 
459 int
460 main(int argc, char *argv[])
461 {
462 	static enum tls13_message_type
463 	    hs_table[MAX_FLAGS][TLS13_NUM_MESSAGE_TYPES] = {
464 		[INITIAL] = {
465 			CLIENT_HELLO,
466 			SERVER_HELLO_RETRY_REQUEST,
467 			CLIENT_HELLO_RETRY,
468 			SERVER_HELLO,
469 		},
470 	};
471 	struct child	start = {
472 		CLIENT_HELLO, DEFAULT, 0, 0,
473 	};
474 	struct child	end = {
475 		APPLICATION_DATA, DEFAULT, 0, 0,
476 	};
477 	struct child	path[TLS13_NUM_MESSAGE_TYPES] = {{0}};
478 	uint8_t		flags = NEGOTIATED;
479 	unsigned int	depth = 0;
480 	int		ch, graphviz = 0, print = 0;
481 
482 	while ((ch = getopt(argc, argv, "Cg")) != -1) {
483 		switch (ch) {
484 		case 'C':
485 			print = 1;
486 			break;
487 		case 'g':
488 			graphviz = 1;
489 			break;
490 		default:
491 			usage();
492 		}
493 	}
494 	argc -= optind;
495 	argv += optind;
496 
497 	if (argc != 0)
498 		usage();
499 
500 	if (graphviz && print)
501 		usage();
502 
503 	if (graphviz)
504 		return generate_graphics();
505 
506 	build_table(hs_table, start, end, path, flags, depth);
507 	if (!verify_table(hs_table, print))
508 		return 1;
509 
510 	if (!print)
511 		printf("SUCCESS\n");
512 
513 	return 0;
514 }
515