/*	$OpenBSD: mlkem_tests.c,v 1.2 2024/12/26 00:10:19 tb Exp $ */
/*
 * Copyright (c) 2024 Google Inc.
 * Copyright (c) 2024 Theo Buehler <tb@openbsd.org>
 * Copyright (c) 2024 Bob Beck <beck@obtuse.com>
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */
190c814320Stb 
#include <err.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "bytestring.h"
#include "mlkem.h"

#include "mlkem_internal.h"

#include "mlkem_tests_util.h"
#include "parse_test_file.h"
330c814320Stb 
/*
 * Test file flavour. The decap and keygen drivers in this file use it to
 * choose between the regular parser description and the NIST one.
 */
enum test_type {
	TEST_TYPE_NORMAL,
	TEST_TYPE_NIST,
};
380c814320Stb 
390c814320Stb struct decap_ctx {
400c814320Stb 	struct parse *parse_ctx;
410c814320Stb 
420c814320Stb 	void *private_key;
430c814320Stb 	size_t private_key_len;
440c814320Stb 
450c814320Stb 	mlkem_parse_private_key_fn parse_private_key;
460c814320Stb 	mlkem_decap_fn decap;
470c814320Stb };
480c814320Stb 
/*
 * Indices of the lines making up one regular decap test case; they index
 * decap_state_machine[]. N_DECAP_STATES is the number of lines.
 */
enum decap_states {
	DECAP_PRIVATE_KEY,
	DECAP_CIPHERTEXT,
	DECAP_RESULT,
	DECAP_SHARED_SECRET,
	N_DECAP_STATES,
};
560c814320Stb 
570c814320Stb static const struct line_spec decap_state_machine[] = {
580c814320Stb 	[DECAP_PRIVATE_KEY] = {
590c814320Stb 		.state = DECAP_PRIVATE_KEY,
600c814320Stb 		.type = LINE_HEX,
610c814320Stb 		.name = "private key",
620c814320Stb 		.label = "private_key",
630c814320Stb 	},
640c814320Stb 	[DECAP_CIPHERTEXT] = {
650c814320Stb 		.state = DECAP_CIPHERTEXT,
660c814320Stb 		.type = LINE_HEX,
670c814320Stb 		.name = "cipher text",
680c814320Stb 		.label = "ciphertext",
690c814320Stb 	},
700c814320Stb 	[DECAP_RESULT] = {
710c814320Stb 		.state = DECAP_RESULT,
720c814320Stb 		.type = LINE_STRING_MATCH,
730c814320Stb 		.name = "result",
740c814320Stb 		.label = "result",
750c814320Stb 		.match = "fail",
760c814320Stb 	},
770c814320Stb 	[DECAP_SHARED_SECRET] = {
780c814320Stb 		.state = DECAP_SHARED_SECRET,
790c814320Stb 		.type = LINE_HEX,
800c814320Stb 		.name = "shared secret",
810c814320Stb 		.label = "shared_secret",
820c814320Stb 	},
830c814320Stb };
840c814320Stb 
850c814320Stb static int
decap_init(void * ctx,void * parse_ctx)860c814320Stb decap_init(void *ctx, void *parse_ctx)
870c814320Stb {
880c814320Stb 	struct decap_ctx *decap = ctx;
890c814320Stb 
900c814320Stb 	decap->parse_ctx = parse_ctx;
910c814320Stb 
920c814320Stb 	return 1;
930c814320Stb }
940c814320Stb 
/*
 * Finish callback of the decap parsers. The decap context owns nothing that
 * needs releasing, so this is a no-op.
 */
static void
decap_finish(void *ctx)
{
	(void)ctx;
}
1000c814320Stb 
1010c814320Stb static int
MlkemDecapFileTest(struct decap_ctx * decap)1020c814320Stb MlkemDecapFileTest(struct decap_ctx *decap)
1030c814320Stb {
1040c814320Stb 	struct parse *p = decap->parse_ctx;
1050c814320Stb 	uint8_t shared_secret_buf[MLKEM_SHARED_SECRET_BYTES];
1060c814320Stb 	CBS ciphertext, shared_secret, private_key;
1070c814320Stb 	int should_fail;
1080c814320Stb 	int failed = 1;
1090c814320Stb 
1100c814320Stb 	parse_get_cbs(p, DECAP_CIPHERTEXT, &ciphertext);
1110c814320Stb 	parse_get_cbs(p, DECAP_SHARED_SECRET, &shared_secret);
1120c814320Stb 	parse_get_cbs(p, DECAP_PRIVATE_KEY, &private_key);
1130c814320Stb 	parse_get_int(p, DECAP_RESULT, &should_fail);
1140c814320Stb 
1150c814320Stb 	if (!decap->parse_private_key(decap->private_key, &private_key)) {
1160c814320Stb 		if ((failed = !should_fail))
1170c814320Stb 			parse_info(p, "parse private key");
1180c814320Stb 		goto err;
1190c814320Stb 	}
1200c814320Stb 	if (!decap->decap(shared_secret_buf,
1210c814320Stb 	    CBS_data(&ciphertext), CBS_len(&ciphertext), decap->private_key)) {
1220c814320Stb 		if ((failed = !should_fail))
123*1d58e56aStb 			parse_info(p, "decap");
1240c814320Stb 		goto err;
1250c814320Stb 	}
1260c814320Stb 
1270c814320Stb 	failed = !parse_data_equal(p, "shared_secret", &shared_secret,
1280c814320Stb 	    shared_secret_buf, sizeof(shared_secret_buf));
1290c814320Stb 
1300c814320Stb 	if (should_fail != failed) {
1310c814320Stb 		parse_info(p, "FAIL: should_fail %d, failed %d",
1320c814320Stb 		    should_fail, failed);
1330c814320Stb 		failed = 1;
1340c814320Stb 	}
1350c814320Stb 
1360c814320Stb  err:
1370c814320Stb 	return failed;
1380c814320Stb }
1390c814320Stb 
/*
 * run_test_case callback of decap_parse: adapts the untyped context pointer
 * to MlkemDecapFileTest() and returns its result.
 */
static int
decap_run_test_case(void *ctx)
{
	return MlkemDecapFileTest(ctx);
}
1450c814320Stb 
1460c814320Stb static const struct test_parse decap_parse = {
1470c814320Stb 	.states = decap_state_machine,
1480c814320Stb 	.num_states = N_DECAP_STATES,
1490c814320Stb 
1500c814320Stb 	.init = decap_init,
1510c814320Stb 	.finish = decap_finish,
1520c814320Stb 
1530c814320Stb 	.run_test_case = decap_run_test_case,
1540c814320Stb };
1550c814320Stb 
/*
 * Indices of the instruction lines of a NIST decap test file; they index
 * nist_decap_instruction_state_machine[].
 */
enum nist_decap_instructions {
	NIST_DECAP_DK,
	N_NIST_DECAP_INSTRUCTIONS,
};
1600c814320Stb 
1610c814320Stb static const struct line_spec nist_decap_instruction_state_machine[] = {
1620c814320Stb 	[NIST_DECAP_DK] = {
1630c814320Stb 		.state = NIST_DECAP_DK,
1640c814320Stb 		.type = LINE_HEX,
1650c814320Stb 		.name = "private key (instruction [dk])",
1660c814320Stb 		.label = "dk",
1670c814320Stb 	},
1680c814320Stb };
1690c814320Stb 
/*
 * Indices of the lines making up one NIST decap test case; they index
 * nist_decap_state_machine[].
 */
enum nist_decap_states {
	NIST_DECAP_C,
	NIST_DECAP_K,
	N_NIST_DECAP_STATES,
};
1750c814320Stb 
1760c814320Stb static const struct line_spec nist_decap_state_machine[] = {
1770c814320Stb 	[NIST_DECAP_C] = {
1780c814320Stb 		.state = NIST_DECAP_C,
1790c814320Stb 		.type = LINE_HEX,
1800c814320Stb 		.name = "ciphertext (c)",
1810c814320Stb 		.label = "c",
1820c814320Stb 	},
1830c814320Stb 	[NIST_DECAP_K] = {
1840c814320Stb 		.state = NIST_DECAP_K,
1850c814320Stb 		.type = LINE_HEX,
1860c814320Stb 		.name = "shared secret (k)",
1870c814320Stb 		.label = "k",
1880c814320Stb 	},
1890c814320Stb };
1900c814320Stb 
1910c814320Stb static int
MlkemNistDecapFileTest(struct decap_ctx * decap)1920c814320Stb MlkemNistDecapFileTest(struct decap_ctx *decap)
1930c814320Stb {
1940c814320Stb 	struct parse *p = decap->parse_ctx;
1950c814320Stb 	uint8_t shared_secret[MLKEM_SHARED_SECRET_BYTES];
1960c814320Stb 	CBS dk, c, k;
1970c814320Stb 	int failed = 1;
1980c814320Stb 
1990c814320Stb 	parse_instruction_get_cbs(p, NIST_DECAP_DK, &dk);
2000c814320Stb 	parse_get_cbs(p, NIST_DECAP_C, &c);
2010c814320Stb 	parse_get_cbs(p, NIST_DECAP_K, &k);
2020c814320Stb 
2030c814320Stb 	if (!parse_length_equal(p, "private key",
2040c814320Stb 	    decap->private_key_len, CBS_len(&dk)))
2050c814320Stb 		goto err;
2060c814320Stb 	if (!parse_length_equal(p, "shared secret",
2070c814320Stb 	    MLKEM_SHARED_SECRET_BYTES, CBS_len(&k)))
2080c814320Stb 		goto err;
2090c814320Stb 
2100c814320Stb 	if (!decap->parse_private_key(decap->private_key, &dk)) {
2110c814320Stb 		parse_info(p, "parse private key");
2120c814320Stb 		goto err;
2130c814320Stb 	}
2140c814320Stb 	if (!decap->decap(shared_secret, CBS_data(&c), CBS_len(&c),
2150c814320Stb 	    decap->private_key)) {
2160c814320Stb 		parse_info(p, "decap");
2170c814320Stb 		goto err;
2180c814320Stb 	}
2190c814320Stb 
2200c814320Stb 	failed = !parse_data_equal(p, "shared secret", &k,
2210c814320Stb 	    shared_secret, MLKEM_SHARED_SECRET_BYTES);
2220c814320Stb 
2230c814320Stb  err:
2240c814320Stb 	return failed;
2250c814320Stb }
2260c814320Stb 
/*
 * run_test_case callback of nist_decap_parse: adapts the untyped context
 * pointer to MlkemNistDecapFileTest() and returns its result.
 */
static int
nist_decap_run_test_case(void *ctx)
{
	return MlkemNistDecapFileTest(ctx);
}
2320c814320Stb 
2330c814320Stb static const struct test_parse nist_decap_parse = {
2340c814320Stb 	.instructions = nist_decap_instruction_state_machine,
2350c814320Stb 	.num_instructions = N_NIST_DECAP_INSTRUCTIONS,
2360c814320Stb 
2370c814320Stb 	.states = nist_decap_state_machine,
2380c814320Stb 	.num_states = N_NIST_DECAP_STATES,
2390c814320Stb 
2400c814320Stb 	.init = decap_init,
2410c814320Stb 	.finish = decap_finish,
2420c814320Stb 
2430c814320Stb 	.run_test_case = nist_decap_run_test_case,
2440c814320Stb };
2450c814320Stb 
2460c814320Stb static int
mlkem_decap_tests(const char * fn,size_t size,enum test_type test_type)2470c814320Stb mlkem_decap_tests(const char *fn, size_t size, enum test_type test_type)
2480c814320Stb {
2490c814320Stb 	struct MLKEM768_private_key private_key768;
2500c814320Stb 	struct decap_ctx decap768 = {
2510c814320Stb 		.private_key = &private_key768,
2520c814320Stb 		.private_key_len = MLKEM768_PRIVATE_KEY_BYTES,
2530c814320Stb 
2540c814320Stb 		.parse_private_key = mlkem768_parse_private_key,
2550c814320Stb 		.decap = mlkem768_decap,
2560c814320Stb 	};
2570c814320Stb 	struct MLKEM1024_private_key private_key1024;
2580c814320Stb 	struct decap_ctx decap1024 = {
2590c814320Stb 		.private_key = &private_key1024,
2600c814320Stb 		.private_key_len = MLKEM1024_PRIVATE_KEY_BYTES,
2610c814320Stb 
2620c814320Stb 		.parse_private_key = mlkem1024_parse_private_key,
2630c814320Stb 		.decap = mlkem1024_decap,
2640c814320Stb 	};
2650c814320Stb 
2660c814320Stb 	if (size == 768 && test_type == TEST_TYPE_NORMAL)
2670c814320Stb 		return parse_test_file(fn, &decap_parse, &decap768);
2680c814320Stb 	if (size == 768 && test_type == TEST_TYPE_NIST)
2690c814320Stb 		return parse_test_file(fn, &nist_decap_parse, &decap768);
2700c814320Stb 	if (size == 1024 && test_type == TEST_TYPE_NORMAL)
2710c814320Stb 		return parse_test_file(fn, &decap_parse, &decap1024);
2720c814320Stb 	if (size == 1024 && test_type == TEST_TYPE_NIST)
2730c814320Stb 		return parse_test_file(fn, &nist_decap_parse, &decap1024);
2740c814320Stb 
2750c814320Stb 	errx(1, "unknown decap test: size %zu, type %d", size, test_type);
2760c814320Stb }
2770c814320Stb 
2780c814320Stb struct encap_ctx {
2790c814320Stb 	struct parse *parse_ctx;
2800c814320Stb 
2810c814320Stb 	void *public_key;
2820c814320Stb 	uint8_t *ciphertext;
2830c814320Stb 	size_t ciphertext_len;
2840c814320Stb 
2850c814320Stb 	mlkem_parse_public_key_fn parse_public_key;
2860c814320Stb 	mlkem_encap_external_entropy_fn encap_external_entropy;
2870c814320Stb };
2880c814320Stb 
/*
 * Indices of the lines making up one encap test case; they index
 * encap_state_machine[]. N_ENCAP_STATES is the number of lines.
 */
enum encap_states {
	ENCAP_ENTROPY,
	ENCAP_PUBLIC_KEY,
	ENCAP_RESULT,
	ENCAP_CIPHERTEXT,
	ENCAP_SHARED_SECRET,
	N_ENCAP_STATES,
};
2970c814320Stb 
2980c814320Stb static const struct line_spec encap_state_machine[] = {
2990c814320Stb 	[ENCAP_ENTROPY] = {
3000c814320Stb 		.state = ENCAP_ENTROPY,
3010c814320Stb 		.type = LINE_HEX,
3020c814320Stb 		.name = "entropy",
3030c814320Stb 		.label = "entropy",
3040c814320Stb 	},
3050c814320Stb 	[ENCAP_PUBLIC_KEY] = {
3060c814320Stb 		.state = ENCAP_PUBLIC_KEY,
3070c814320Stb 		.type = LINE_HEX,
3080c814320Stb 		.name = "public key",
3090c814320Stb 		.label = "public_key",
3100c814320Stb 	},
3110c814320Stb 	[ENCAP_RESULT] = {
3120c814320Stb 		.state = ENCAP_RESULT,
3130c814320Stb 		.type = LINE_STRING_MATCH,
3140c814320Stb 		.name = "result",
3150c814320Stb 		.label = "result",
3160c814320Stb 		.match = "fail",
3170c814320Stb 	},
3180c814320Stb 	[ENCAP_CIPHERTEXT] = {
3190c814320Stb 		.state = ENCAP_CIPHERTEXT,
3200c814320Stb 		.type = LINE_HEX,
3210c814320Stb 		.name = "ciphertext",
3220c814320Stb 		.label = "ciphertext",
3230c814320Stb 	},
3240c814320Stb 	[ENCAP_SHARED_SECRET] = {
3250c814320Stb 		.state = ENCAP_SHARED_SECRET,
3260c814320Stb 		.type = LINE_HEX,
3270c814320Stb 		.name = "shared secret",
3280c814320Stb 		.label = "shared_secret",
3290c814320Stb 	},
3300c814320Stb };
3310c814320Stb 
3320c814320Stb static int
encap_init(void * ctx,void * parse_ctx)3330c814320Stb encap_init(void *ctx, void *parse_ctx)
3340c814320Stb {
3350c814320Stb 	struct encap_ctx *encap = ctx;
3360c814320Stb 
3370c814320Stb 	encap->parse_ctx = parse_ctx;
3380c814320Stb 
3390c814320Stb 	return 1;
3400c814320Stb }
3410c814320Stb 
/*
 * Finish callback of the encap parser. The encap context owns nothing that
 * needs releasing, so this is a no-op.
 */
static void
encap_finish(void *ctx)
{
	(void)ctx;
}
3470c814320Stb 
3480c814320Stb static int
MlkemEncapFileTest(struct encap_ctx * encap)3490c814320Stb MlkemEncapFileTest(struct encap_ctx *encap)
3500c814320Stb {
3510c814320Stb 	struct parse *p = encap->parse_ctx;
3520c814320Stb 	uint8_t shared_secret_buf[MLKEM_SHARED_SECRET_BYTES];
3530c814320Stb 	CBS entropy, public_key, ciphertext, shared_secret;
3540c814320Stb 	int should_fail;
3550c814320Stb 	int failed = 1;
3560c814320Stb 
3570c814320Stb 	parse_get_cbs(p, ENCAP_ENTROPY, &entropy);
3580c814320Stb 	parse_get_cbs(p, ENCAP_PUBLIC_KEY, &public_key);
3590c814320Stb 	parse_get_cbs(p, ENCAP_CIPHERTEXT, &ciphertext);
3600c814320Stb 	parse_get_cbs(p, ENCAP_SHARED_SECRET, &shared_secret);
3610c814320Stb 	parse_get_int(p, ENCAP_RESULT, &should_fail);
3620c814320Stb 
3630c814320Stb 	if (!encap->parse_public_key(encap->public_key, &public_key)) {
3640c814320Stb 		if ((failed = !should_fail))
3650c814320Stb 			parse_info(p, "parse public key");
3660c814320Stb 		goto err;
3670c814320Stb 	}
3680c814320Stb 	encap->encap_external_entropy(encap->ciphertext, shared_secret_buf,
3690c814320Stb 	    encap->public_key, CBS_data(&entropy));
3700c814320Stb 
3710c814320Stb 	failed = !parse_data_equal(p, "shared_secret", &shared_secret,
3720c814320Stb 	    shared_secret_buf, sizeof(shared_secret_buf));
3730c814320Stb 	failed |= !parse_data_equal(p, "ciphertext", &ciphertext,
3740c814320Stb 	    encap->ciphertext, encap->ciphertext_len);
3750c814320Stb 
3760c814320Stb 	if (should_fail != failed) {
3770c814320Stb 		parse_info(p, "FAIL: should_fail %d, failed %d",
3780c814320Stb 		    should_fail, failed);
3790c814320Stb 		failed = 1;
3800c814320Stb 	}
3810c814320Stb 
3820c814320Stb  err:
3830c814320Stb 	return failed;
3840c814320Stb }
3850c814320Stb 
/*
 * run_test_case callback of encap_parse: adapts the untyped context pointer
 * to MlkemEncapFileTest() and returns its result.
 */
static int
encap_run_test_case(void *ctx)
{
	return MlkemEncapFileTest(ctx);
}
3910c814320Stb 
3920c814320Stb static const struct test_parse encap_parse = {
3930c814320Stb 	.states = encap_state_machine,
3940c814320Stb 	.num_states = N_ENCAP_STATES,
3950c814320Stb 
3960c814320Stb 	.init = encap_init,
3970c814320Stb 	.finish = encap_finish,
3980c814320Stb 
3990c814320Stb 	.run_test_case = encap_run_test_case,
4000c814320Stb };
4010c814320Stb 
4020c814320Stb static int
mlkem_encap_tests(const char * fn,size_t size)4030c814320Stb mlkem_encap_tests(const char *fn, size_t size)
4040c814320Stb {
4050c814320Stb 	struct MLKEM768_public_key public_key768;
4060c814320Stb 	uint8_t ciphertext768[MLKEM768_CIPHERTEXT_BYTES];
4070c814320Stb 	struct encap_ctx encap768 = {
4080c814320Stb 		.public_key = &public_key768,
4090c814320Stb 		.ciphertext = ciphertext768,
4100c814320Stb 		.ciphertext_len = sizeof(ciphertext768),
4110c814320Stb 
4120c814320Stb 		.parse_public_key = mlkem768_parse_public_key,
4130c814320Stb 		.encap_external_entropy = mlkem768_encap_external_entropy,
4140c814320Stb 	};
4150c814320Stb 	struct MLKEM1024_public_key public_key1024;
4160c814320Stb 	uint8_t ciphertext1024[MLKEM1024_CIPHERTEXT_BYTES];
4170c814320Stb 	struct encap_ctx encap1024 = {
4180c814320Stb 		.public_key = &public_key1024,
4190c814320Stb 		.ciphertext = ciphertext1024,
4200c814320Stb 		.ciphertext_len = sizeof(ciphertext1024),
4210c814320Stb 
4220c814320Stb 		.parse_public_key = mlkem1024_parse_public_key,
4230c814320Stb 		.encap_external_entropy = mlkem1024_encap_external_entropy,
4240c814320Stb 	};
4250c814320Stb 
4260c814320Stb 	if (size == 768)
4270c814320Stb 		return parse_test_file(fn, &encap_parse, &encap768);
4280c814320Stb 	if (size == 1024)
4290c814320Stb 		return parse_test_file(fn, &encap_parse, &encap1024);
4300c814320Stb 
4310c814320Stb 	errx(1, "unknown encap test: size %zu", size);
4320c814320Stb }
4330c814320Stb 
4340c814320Stb struct keygen_ctx {
4350c814320Stb 	struct parse *parse_ctx;
4360c814320Stb 
4370c814320Stb 	void *private_key;
4380c814320Stb 	void *encoded_public_key;
4390c814320Stb 	size_t encoded_public_key_len;
4400c814320Stb 	size_t private_key_len;
4410c814320Stb 	size_t public_key_len;
4420c814320Stb 
4430c814320Stb 	mlkem_generate_key_external_entropy_fn generate_key_external_entropy;
4440c814320Stb 	mlkem_encode_private_key_fn encode_private_key;
4450c814320Stb };
4460c814320Stb 
/*
 * Indices of the lines making up one regular keygen test case; they index
 * keygen_state_machine[]. N_KEYGEN_STATES is the number of lines.
 */
enum keygen_states {
	KEYGEN_SEED,
	KEYGEN_PUBLIC_KEY,
	KEYGEN_PRIVATE_KEY,
	N_KEYGEN_STATES,
};
4530c814320Stb 
4540c814320Stb static const struct line_spec keygen_state_machine[] = {
4550c814320Stb 	[KEYGEN_SEED] = {
4560c814320Stb 		.state = KEYGEN_SEED,
4570c814320Stb 		.type = LINE_HEX,
4580c814320Stb 		.name = "seed",
4590c814320Stb 		.label = "seed",
4600c814320Stb 	},
4610c814320Stb 	[KEYGEN_PUBLIC_KEY] = {
4620c814320Stb 		.state = KEYGEN_PUBLIC_KEY,
4630c814320Stb 		.type = LINE_HEX,
4640c814320Stb 		.name = "public key",
4650c814320Stb 		.label = "public_key",
4660c814320Stb 	},
4670c814320Stb 	[KEYGEN_PRIVATE_KEY] = {
4680c814320Stb 		.state = KEYGEN_PRIVATE_KEY,
4690c814320Stb 		.type = LINE_HEX,
4700c814320Stb 		.name = "private key",
4710c814320Stb 		.label = "private_key",
4720c814320Stb 	},
4730c814320Stb };
4740c814320Stb 
4750c814320Stb static int
keygen_init(void * ctx,void * parse_ctx)4760c814320Stb keygen_init(void *ctx, void *parse_ctx)
4770c814320Stb {
4780c814320Stb 	struct keygen_ctx *keygen = ctx;
4790c814320Stb 
4800c814320Stb 	keygen->parse_ctx = parse_ctx;
4810c814320Stb 
4820c814320Stb 	return 1;
4830c814320Stb }
4840c814320Stb 
/*
 * Finish callback of the keygen parsers. The keygen context owns nothing
 * that needs releasing, so this is a no-op.
 */
static void
keygen_finish(void *ctx)
{
	(void)ctx;
}
4900c814320Stb 
4910c814320Stb static int
MlkemKeygenFileTest(struct keygen_ctx * keygen)4920c814320Stb MlkemKeygenFileTest(struct keygen_ctx *keygen)
4930c814320Stb {
4940c814320Stb 	struct parse *p = keygen->parse_ctx;
4950c814320Stb 	CBS seed, public_key, private_key;
4960c814320Stb 	uint8_t *encoded_private_key = NULL;
4970c814320Stb 	size_t encoded_private_key_len = 0;
4980c814320Stb 	int failed = 1;
4990c814320Stb 
5000c814320Stb 	parse_get_cbs(p, KEYGEN_SEED, &seed);
5010c814320Stb 	parse_get_cbs(p, KEYGEN_PUBLIC_KEY, &public_key);
5020c814320Stb 	parse_get_cbs(p, KEYGEN_PRIVATE_KEY, &private_key);
5030c814320Stb 
5040c814320Stb 	if (!parse_length_equal(p, "seed", MLKEM_SEED_BYTES, CBS_len(&seed)))
5050c814320Stb 		goto err;
5060c814320Stb 	if (!parse_length_equal(p, "public key",
5070c814320Stb 	    keygen->public_key_len, CBS_len(&public_key)))
5080c814320Stb 		goto err;
5090c814320Stb 	if (!parse_length_equal(p, "private key",
5100c814320Stb 	    keygen->private_key_len, CBS_len(&private_key)))
5110c814320Stb 		goto err;
5120c814320Stb 
5130c814320Stb 	keygen->generate_key_external_entropy(keygen->encoded_public_key,
5140c814320Stb 	    keygen->private_key, CBS_data(&seed));
5150c814320Stb 	if (!keygen->encode_private_key(keygen->private_key,
5160c814320Stb 	    &encoded_private_key, &encoded_private_key_len)) {
5170c814320Stb 		parse_info(p, "encode private key");
5180c814320Stb 		goto err;
5190c814320Stb 	}
5200c814320Stb 
5210c814320Stb 	failed = !parse_data_equal(p, "private key", &private_key,
5220c814320Stb 	    encoded_private_key, encoded_private_key_len);
5230c814320Stb 	failed |= !parse_data_equal(p, "public key", &public_key,
5240c814320Stb 	    keygen->encoded_public_key, keygen->encoded_public_key_len);
5250c814320Stb 
5260c814320Stb  err:
5270c814320Stb 	freezero(encoded_private_key, encoded_private_key_len);
5280c814320Stb 
5290c814320Stb 	return failed;
5300c814320Stb }
5310c814320Stb 
/*
 * run_test_case callback of keygen_parse: adapts the untyped context pointer
 * to MlkemKeygenFileTest() and returns its result.
 */
static int
keygen_run_test_case(void *ctx)
{
	return MlkemKeygenFileTest(ctx);
}
5370c814320Stb 
5380c814320Stb static const struct test_parse keygen_parse = {
5390c814320Stb 	.states = keygen_state_machine,
5400c814320Stb 	.num_states = N_KEYGEN_STATES,
5410c814320Stb 
5420c814320Stb 	.init = keygen_init,
5430c814320Stb 	.finish = keygen_finish,
5440c814320Stb 
5450c814320Stb 	.run_test_case = keygen_run_test_case,
5460c814320Stb };
5470c814320Stb 
/*
 * Indices of the lines making up one NIST keygen test case; they index
 * nist_keygen_state_machine[]. N_NIST_KEYGEN_STATES is the number of lines.
 */
enum nist_keygen_states {
	NIST_KEYGEN_Z,
	NIST_KEYGEN_D,
	NIST_KEYGEN_EK,
	NIST_KEYGEN_DK,
	N_NIST_KEYGEN_STATES,
};
5550c814320Stb 
5560c814320Stb static const struct line_spec nist_keygen_state_machine[] = {
5570c814320Stb 	[NIST_KEYGEN_Z] = {
5580c814320Stb 		.state = NIST_KEYGEN_Z,
5590c814320Stb 		.type = LINE_HEX,
5600c814320Stb 		.name = "seed (z)",
5610c814320Stb 		.label = "z",
5620c814320Stb 	},
5630c814320Stb 	[NIST_KEYGEN_D] = {
5640c814320Stb 		.state = NIST_KEYGEN_D,
5650c814320Stb 		.type = LINE_HEX,
5660c814320Stb 		.name = "seed (d)",
5670c814320Stb 		.label = "d",
5680c814320Stb 	},
5690c814320Stb 	[NIST_KEYGEN_EK] = {
5700c814320Stb 		.state = NIST_KEYGEN_EK,
5710c814320Stb 		.type = LINE_HEX,
5720c814320Stb 		.name = "public key (ek)",
5730c814320Stb 		.label = "ek",
5740c814320Stb 	},
5750c814320Stb 	[NIST_KEYGEN_DK] = {
5760c814320Stb 		.state = NIST_KEYGEN_DK,
5770c814320Stb 		.type = LINE_HEX,
5780c814320Stb 		.name = "private key (dk)",
5790c814320Stb 		.label = "dk",
5800c814320Stb 	},
5810c814320Stb };
5820c814320Stb 
5830c814320Stb static int
MlkemNistKeygenFileTest(struct keygen_ctx * keygen)5840c814320Stb MlkemNistKeygenFileTest(struct keygen_ctx *keygen)
5850c814320Stb {
5860c814320Stb 	struct parse *p = keygen->parse_ctx;
5870c814320Stb 	CBB seed_cbb;
5880c814320Stb 	CBS z, d, ek, dk;
5890c814320Stb 	uint8_t seed[MLKEM_SEED_BYTES];
5900c814320Stb 	size_t seed_len;
5910c814320Stb 	uint8_t *encoded_private_key = NULL;
5920c814320Stb 	size_t encoded_private_key_len = 0;
5930c814320Stb 	int failed = 1;
5940c814320Stb 
5950c814320Stb 	parse_get_cbs(p, NIST_KEYGEN_Z, &z);
5960c814320Stb 	parse_get_cbs(p, NIST_KEYGEN_D, &d);
5970c814320Stb 	parse_get_cbs(p, NIST_KEYGEN_EK, &ek);
5980c814320Stb 	parse_get_cbs(p, NIST_KEYGEN_DK, &dk);
5990c814320Stb 
6000c814320Stb 	if (!CBB_init_fixed(&seed_cbb, seed, sizeof(seed)))
6010c814320Stb 		parse_errx(p, "CBB_init_fixed");
6020c814320Stb 	if (!CBB_add_bytes(&seed_cbb, CBS_data(&d), CBS_len(&d)))
6030c814320Stb 		parse_errx(p, "CBB_add_bytes");
6040c814320Stb 	if (!CBB_add_bytes(&seed_cbb, CBS_data(&z), CBS_len(&z)))
6050c814320Stb 		parse_errx(p, "CBB_add_bytes");
6060c814320Stb 	if (!CBB_finish(&seed_cbb, NULL, &seed_len))
6070c814320Stb 		parse_errx(p, "CBB_finish");
6080c814320Stb 
6090c814320Stb 	if (!parse_length_equal(p, "bogus z or d", MLKEM_SEED_BYTES, seed_len))
6100c814320Stb 		goto err;
6110c814320Stb 
6120c814320Stb 	keygen->generate_key_external_entropy(keygen->encoded_public_key,
6130c814320Stb 	    keygen->private_key, seed);
6140c814320Stb 	if (!keygen->encode_private_key(keygen->private_key,
6150c814320Stb 	    &encoded_private_key, &encoded_private_key_len)) {
6160c814320Stb 		parse_info(p, "encode private key");
6170c814320Stb 		goto err;
6180c814320Stb 	}
6190c814320Stb 
6200c814320Stb 	failed = !parse_data_equal(p, "public key", &ek,
6210c814320Stb 	    keygen->encoded_public_key, keygen->encoded_public_key_len);
6220c814320Stb 	failed |= !parse_data_equal(p, "private key", &dk,
6230c814320Stb 	    encoded_private_key, encoded_private_key_len);
6240c814320Stb 
6250c814320Stb  err:
6260c814320Stb 	freezero(encoded_private_key, encoded_private_key_len);
6270c814320Stb 
6280c814320Stb 	return failed;
6290c814320Stb }
6300c814320Stb 
/*
 * run_test_case callback of nist_keygen_parse: adapts the untyped context
 * pointer to MlkemNistKeygenFileTest() and returns its result.
 */
static int
nist_keygen_run_test_case(void *ctx)
{
	return MlkemNistKeygenFileTest(ctx);
}
6360c814320Stb 
6370c814320Stb static const struct test_parse nist_keygen_parse = {
6380c814320Stb 	.states = nist_keygen_state_machine,
6390c814320Stb 	.num_states = N_NIST_KEYGEN_STATES,
6400c814320Stb 
6410c814320Stb 	.init = keygen_init,
6420c814320Stb 	.finish = keygen_finish,
6430c814320Stb 
6440c814320Stb 	.run_test_case = nist_keygen_run_test_case,
6450c814320Stb };
6460c814320Stb 
6470c814320Stb static int
mlkem_keygen_tests(const char * fn,size_t size,enum test_type test_type)6480c814320Stb mlkem_keygen_tests(const char *fn, size_t size, enum test_type test_type)
6490c814320Stb {
6500c814320Stb 	struct MLKEM768_private_key private_key768;
6510c814320Stb 	uint8_t encoded_public_key768[MLKEM768_PUBLIC_KEY_BYTES];
6520c814320Stb 	struct keygen_ctx keygen768 = {
6530c814320Stb 		.private_key = &private_key768,
6540c814320Stb 		.encoded_public_key = encoded_public_key768,
6550c814320Stb 		.encoded_public_key_len = sizeof(encoded_public_key768),
6560c814320Stb 		.private_key_len = MLKEM768_PRIVATE_KEY_BYTES,
6570c814320Stb 		.public_key_len = MLKEM768_PUBLIC_KEY_BYTES,
6580c814320Stb 		.generate_key_external_entropy =
6590c814320Stb 		    mlkem768_generate_key_external_entropy,
6600c814320Stb 		.encode_private_key =
6610c814320Stb 		    mlkem768_encode_private_key,
6620c814320Stb 	};
6630c814320Stb 	struct MLKEM1024_private_key private_key1024;
6640c814320Stb 	uint8_t encoded_public_key1024[MLKEM1024_PUBLIC_KEY_BYTES];
6650c814320Stb 	struct keygen_ctx keygen1024 = {
6660c814320Stb 		.private_key = &private_key1024,
6670c814320Stb 		.encoded_public_key = encoded_public_key1024,
6680c814320Stb 		.encoded_public_key_len = sizeof(encoded_public_key1024),
6690c814320Stb 		.private_key_len = MLKEM1024_PRIVATE_KEY_BYTES,
6700c814320Stb 		.public_key_len = MLKEM1024_PUBLIC_KEY_BYTES,
6710c814320Stb 
6720c814320Stb 		.generate_key_external_entropy =
6730c814320Stb 		    mlkem1024_generate_key_external_entropy,
6740c814320Stb 		.encode_private_key =
6750c814320Stb 		    mlkem1024_encode_private_key,
6760c814320Stb 	};
6770c814320Stb 
6780c814320Stb 	if (size == 768 && test_type == TEST_TYPE_NORMAL)
6790c814320Stb 		return parse_test_file(fn, &keygen_parse, &keygen768);
6800c814320Stb 	if (size == 768 && test_type == TEST_TYPE_NIST)
6810c814320Stb 		return parse_test_file(fn, &nist_keygen_parse, &keygen768);
6820c814320Stb 	if (size == 1024 && test_type == TEST_TYPE_NORMAL)
6830c814320Stb 		return parse_test_file(fn, &keygen_parse, &keygen1024);
6840c814320Stb 	if (size == 1024 && test_type == TEST_TYPE_NIST)
6850c814320Stb 		return parse_test_file(fn, &nist_keygen_parse, &keygen1024);
6860c814320Stb 
6870c814320Stb 	errx(1, "unknown keygen test: size %zu, type %d", size, test_type);
6880c814320Stb }
6890c814320Stb 
6900c814320Stb static int
run_mlkem_test(const char * test,const char * fn)6910c814320Stb run_mlkem_test(const char *test, const char *fn)
6920c814320Stb {
6930c814320Stb 	if (strcmp(test, "mlkem768_decap_tests") == 0)
6940c814320Stb 		return mlkem_decap_tests(fn, 768, TEST_TYPE_NORMAL);
6950c814320Stb 	if (strcmp(test, "mlkem768_nist_decap_tests") == 0)
6960c814320Stb 		return mlkem_decap_tests(fn, 768, TEST_TYPE_NIST);
6970c814320Stb 	if (strcmp(test, "mlkem1024_decap_tests") == 0)
6980c814320Stb 		return mlkem_decap_tests(fn, 1024, TEST_TYPE_NORMAL);
6990c814320Stb 	if (strcmp(test, "mlkem1024_nist_decap_tests") == 0)
7000c814320Stb 		return mlkem_decap_tests(fn, 1024, TEST_TYPE_NIST);
7010c814320Stb 
7020c814320Stb 	if (strcmp(test, "mlkem768_encap_tests") == 0)
7030c814320Stb 		return mlkem_encap_tests(fn, 768);
7040c814320Stb 	if (strcmp(test, "mlkem1024_encap_tests") == 0)
7050c814320Stb 		return mlkem_encap_tests(fn, 1024);
7060c814320Stb 
7070c814320Stb 	if (strcmp(test, "mlkem768_keygen_tests") == 0)
7080c814320Stb 		return mlkem_keygen_tests(fn, 768, TEST_TYPE_NORMAL);
7090c814320Stb 	if (strcmp(test, "mlkem768_nist_keygen_tests") == 0)
7100c814320Stb 		return mlkem_keygen_tests(fn, 768, TEST_TYPE_NIST);
7110c814320Stb 	if (strcmp(test, "mlkem1024_keygen_tests") == 0)
7120c814320Stb 		return mlkem_keygen_tests(fn, 1024, TEST_TYPE_NORMAL);
7130c814320Stb 	if (strcmp(test, "mlkem1024_nist_keygen_tests") == 0)
7140c814320Stb 		return mlkem_keygen_tests(fn, 1024, TEST_TYPE_NIST);
7150c814320Stb 
7160c814320Stb 	errx(1, "unknown test %s (test file %s)", test, fn);
7170c814320Stb }
7180c814320Stb 
/*
 * Entry point: expects exactly two arguments, the test name and the path of
 * the test file. Prints a usage message and exits with status 1 otherwise.
 * The result of the selected test becomes the exit status.
 */
int
main(int argc, const char *argv[])
{
	if (argc != 3) {
		fprintf(stderr, "usage: mlkem_test test testfile.txt\n");
		exit(1);
	}

	return run_mlkem_test(argv[1], argv[2]);
}
729