xref: /openbsd/lib/libcrypto/mlkem/mlkem1024.c (revision 9a8fba7c)
1*9a8fba7cStb /* $OpenBSD: mlkem1024.c,v 1.6 2025/01/03 08:19:24 tb Exp $ */
208c63c71Sbeck /*
308c63c71Sbeck  * Copyright (c) 2024, Google Inc.
408c63c71Sbeck  * Copyright (c) 2024, Bob Beck <beck@obtuse.com>
508c63c71Sbeck  *
608c63c71Sbeck  * Permission to use, copy, modify, and/or distribute this software for any
708c63c71Sbeck  * purpose with or without fee is hereby granted, provided that the above
808c63c71Sbeck  * copyright notice and this permission notice appear in all copies.
908c63c71Sbeck  *
1008c63c71Sbeck  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
1108c63c71Sbeck  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
1208c63c71Sbeck  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
1308c63c71Sbeck  * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
1408c63c71Sbeck  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
1508c63c71Sbeck  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
1608c63c71Sbeck  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
1708c63c71Sbeck  */
1808c63c71Sbeck 
1908c63c71Sbeck #include <assert.h>
2008c63c71Sbeck #include <stdlib.h>
2108c63c71Sbeck #include <string.h>
2208c63c71Sbeck 
2308c63c71Sbeck #include "bytestring.h"
24ef1019e6Stb #include "mlkem.h"
2508c63c71Sbeck 
2608c63c71Sbeck #include "sha3_internal.h"
2708c63c71Sbeck #include "mlkem_internal.h"
2808c63c71Sbeck #include "constant_time.h"
2908c63c71Sbeck #include "crypto_internal.h"
3008c63c71Sbeck 
/*
 * Temporary override (to be removed later): make every LCRYPTO_ALIAS(sym)
 * in this file expand to nothing, discarding whatever definition the
 * headers included above may have provided.
 */
#undef LCRYPTO_ALIAS
#define LCRYPTO_ALIAS(A)
3408c63c71Sbeck 
3508c63c71Sbeck /*
3608c63c71Sbeck  * See
3708c63c71Sbeck  * https://csrc.nist.gov/pubs/fips/203/final
3808c63c71Sbeck  */
3908c63c71Sbeck 
4008c63c71Sbeck static void
prf(uint8_t * out,size_t out_len,const uint8_t in[33])4108c63c71Sbeck prf(uint8_t *out, size_t out_len, const uint8_t in[33])
4208c63c71Sbeck {
4308c63c71Sbeck 	sha3_ctx ctx;
4408c63c71Sbeck 	shake256_init(&ctx);
4508c63c71Sbeck 	shake_update(&ctx, in, 33);
4608c63c71Sbeck 	shake_xof(&ctx);
4708c63c71Sbeck 	shake_out(&ctx, out, out_len);
4808c63c71Sbeck }
4908c63c71Sbeck 
5008c63c71Sbeck /* Section 4.1 */
5108c63c71Sbeck static void
hash_h(uint8_t out[32],const uint8_t * in,size_t len)5208c63c71Sbeck hash_h(uint8_t out[32], const uint8_t *in, size_t len)
5308c63c71Sbeck {
5408c63c71Sbeck 	sha3_ctx ctx;
5508c63c71Sbeck 	sha3_init(&ctx, 32);
5608c63c71Sbeck 	sha3_update(&ctx, in, len);
5708c63c71Sbeck 	sha3_final(out, &ctx);
5808c63c71Sbeck }
5908c63c71Sbeck 
6008c63c71Sbeck static void
hash_g(uint8_t out[64],const uint8_t * in,size_t len)6108c63c71Sbeck hash_g(uint8_t out[64], const uint8_t *in, size_t len)
6208c63c71Sbeck {
6308c63c71Sbeck 	sha3_ctx ctx;
6408c63c71Sbeck 	sha3_init(&ctx, 64);
6508c63c71Sbeck 	sha3_update(&ctx, in, len);
6608c63c71Sbeck 	sha3_final(out, &ctx);
6708c63c71Sbeck }
6808c63c71Sbeck 
6908c63c71Sbeck /* this is called 'J' in the spec */
7008c63c71Sbeck static void
kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES],const uint8_t failure_secret[32],const uint8_t * in,size_t len)7108c63c71Sbeck kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES], const uint8_t failure_secret[32],
7208c63c71Sbeck     const uint8_t *in, size_t len)
7308c63c71Sbeck {
7408c63c71Sbeck 	sha3_ctx ctx;
7508c63c71Sbeck 	shake256_init(&ctx);
7608c63c71Sbeck 	shake_update(&ctx, failure_secret, 32);
7708c63c71Sbeck 	shake_update(&ctx, in, len);
7808c63c71Sbeck 	shake_xof(&ctx);
7908c63c71Sbeck 	shake_out(&ctx, out, MLKEM_SHARED_SECRET_BYTES);
8008c63c71Sbeck }
8108c63c71Sbeck 
#define DEGREE 256
#define RANK1024 4

/* Barrett reduction: kBarrettMultiplier == floor(2^kBarrettShift / kPrime). */
static const size_t kBarrettMultiplier = 5039;
static const unsigned kBarrettShift = 24;

/* The modulus and derived values. */
static const uint16_t kPrime = 3329;
static const int kLog2Prime = 12;
static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;

/* Coefficient compression widths (d_u, d_v) for the 1024 parameter set. */
static const int kDU1024 = 11;
static const int kDV1024 = 5;

/*
 * kInverseDegree is 128^-1 mod 3329. It is 128 rather than 256 because
 * kPrime has no 512th root of unity, so the transforms run only 7 layers.
 */
static const uint16_t kInverseDegree = 3303;

/* Byte length of a vector encoded with 12-bit coefficients. */
static const size_t kEncodedVectorSize =
    (/*kLog2Prime=*/12 * DEGREE / 8) * RANK1024;

/* Byte length of a vector compressed to kDU1024-bit coefficients. */
static const size_t kCompressedVectorSize =
    /*kDU1024=*/ 11 * RANK1024 * DEGREE / 8;
10208c63c71Sbeck 
10308c63c71Sbeck typedef struct scalar {
10408c63c71Sbeck 	/* On every function entry and exit, 0 <= c < kPrime. */
10508c63c71Sbeck 	uint16_t c[DEGREE];
10608c63c71Sbeck } scalar;
10708c63c71Sbeck 
10808c63c71Sbeck typedef struct vector {
10908c63c71Sbeck 	scalar v[RANK1024];
11008c63c71Sbeck } vector;
11108c63c71Sbeck 
11208c63c71Sbeck typedef struct matrix {
11308c63c71Sbeck 	scalar v[RANK1024][RANK1024];
11408c63c71Sbeck } matrix;
11508c63c71Sbeck 
/*
 * The three tables below hold powers of 17 (a primitive 256th root of unity
 * mod 3329), indexed through a 7-bit bit reversal. They were generated with
 * this Python helper:
 *
 *	p = 3329
 *
 *	def bitreverse(i):
 *	    ret = 0
 *	    for n in range(7):
 *	        bit = i & 1
 *	        ret <<= 1
 *	        ret |= bit
 *	        i >>= 1
 *	    return ret
 */

/* Forward NTT twiddles: [pow(17, bitreverse(i), p) for i in range(128)] */
static const uint16_t kNTTRoots[128] = {
	1,    1729, 2580, 3289, 2642, 630,  1897, 848,  1062, 1919, 193,  797,
	2786, 3260, 569,  1746, 296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
	1426, 2094, 535,  2882, 2393, 2879, 1974, 821,  289,  331,  3253, 1756,
	1197, 2304, 2277, 2055, 650,  1977, 2513, 632,  2865, 33,   1320, 1915,
	2319, 1435, 807,  452,  1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
	2474, 3110, 1227, 910,  17,   2761, 583,  2649, 1637, 723,  2288, 1100,
	1409, 2662, 3281, 233,  756,  2156, 3015, 3050, 1703, 1651, 2789, 1789,
	1847, 952,  1461, 2687, 939,  2308, 2437, 2388, 733,  2337, 268,  641,
	1584, 2298, 2037, 3220, 375,  2549, 2090, 1645, 1063, 319,  2773, 757,
	2099, 561,  2466, 2594, 2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
	1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,
};

/* Inverse NTT twiddles: [pow(17, -bitreverse(i), p) for i in range(128)] */
static const uint16_t kInverseNTTRoots[128] = {
	1,    1600, 40,   749,  2481, 1432, 2699, 687,  1583, 2760, 69,   543,
	2532, 3136, 1410, 2267, 2508, 1355, 450,  936,  447,  2794, 1235, 1903,
	1996, 1089, 3273, 283,  1853, 1990, 882,  3033, 2419, 2102, 219,  855,
	2681, 1848, 712,  682,  927,  1795, 461,  1891, 2877, 2522, 1894, 1010,
	1414, 2009, 3296, 464,  2697, 816,  1352, 2679, 1274, 1052, 1025, 2132,
	1573, 76,   2998, 3040, 1175, 2444, 394,  1219, 2300, 1455, 2117, 1607,
	2443, 554,  1179, 2186, 2303, 2926, 2237, 525,  735,  863,  2768, 1230,
	2572, 556,  3010, 2266, 1684, 1239, 780,  2954, 109,  1292, 1031, 1745,
	2688, 3061, 992,  2596, 941,  892,  1021, 2390, 642,  1868, 2377, 1482,
	1540, 540,  1678, 1626, 279,  314,  1173, 2573, 3096, 48,   667,  1920,
	2229, 1041, 2606, 1692, 680,  2746, 568,  3312,
};

/*
 * Moduli X^2 - r of the degree-2 factors used by scalar_mult:
 * [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
 */
static const uint16_t kModRoots[128] = {
	17,   3312, 2761, 568,  583,  2746, 2649, 680,  1637, 1692, 723,  2606,
	2288, 1041, 1100, 2229, 1409, 1920, 2662, 667,  3281, 48,   233,  3096,
	756,  2573, 2156, 1173, 3015, 314,  3050, 279,  1703, 1626, 1651, 1678,
	2789, 540,  1789, 1540, 1847, 1482, 952,  2377, 1461, 1868, 2687, 642,
	939,  2390, 2308, 1021, 2437, 892,  2388, 941,  733,  2596, 2337, 992,
	268,  3061, 641,  2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
	375,  2954, 2549, 780,  2090, 1239, 1645, 1684, 1063, 2266, 319,  3010,
	2773, 556,  757,  2572, 2099, 1230, 561,  2768, 2466, 863,  2594, 735,
	2804, 525,  1092, 2237, 403,  2926, 1026, 2303, 1143, 2186, 2150, 1179,
	2775, 554,  886,  2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
	2110, 1219, 2935, 394,  885,  2444, 2154, 1175,
};
17508c63c71Sbeck 
17608c63c71Sbeck /* reduce_once reduces 0 <= x < 2*kPrime, mod kPrime. */
17708c63c71Sbeck static uint16_t
reduce_once(uint16_t x)17808c63c71Sbeck reduce_once(uint16_t x)
17908c63c71Sbeck {
18008c63c71Sbeck 	assert(x < 2 * kPrime);
18108c63c71Sbeck 	const uint16_t subtracted = x - kPrime;
18208c63c71Sbeck 	uint16_t mask = 0u - (subtracted >> 15);
1839ee6f1feSbeck 
18408c63c71Sbeck 	/*
1859ee6f1feSbeck 	 * Although this is a constant-time select, we omit a value barrier here.
1869ee6f1feSbeck 	 * Value barriers impede auto-vectorization (likely because it forces the
1879ee6f1feSbeck 	 * value to transit through a general-purpose register). On AArch64, this
1889ee6f1feSbeck 	 * is a difference of 2x.
1899ee6f1feSbeck 	 *
1909ee6f1feSbeck 	 * We usually add value barriers to selects because Clang turns
1919ee6f1feSbeck          * consecutive selects with the same condition into a branch instead of
1929ee6f1feSbeck 	 * CMOV/CSEL. This condition does not occur in ML-KEM, so omitting it
1939ee6f1feSbeck          * seems to be safe so far  but see
1949ee6f1feSbeck          * |scalar_centered_binomial_distribution_eta_2_with_prf|.
19508c63c71Sbeck 	 */
19608c63c71Sbeck 	return (mask & x) | (~mask & subtracted);
19708c63c71Sbeck }
19808c63c71Sbeck 
19908c63c71Sbeck /*
20008c63c71Sbeck  * constant time reduce x mod kPrime using Barrett reduction. x must be less
20108c63c71Sbeck  * than kPrime + 2×kPrime².
20208c63c71Sbeck  */
20308c63c71Sbeck static uint16_t
reduce(uint32_t x)20408c63c71Sbeck reduce(uint32_t x)
20508c63c71Sbeck {
20608c63c71Sbeck 	uint64_t product = (uint64_t)x * kBarrettMultiplier;
20708c63c71Sbeck 	uint32_t quotient = (uint32_t)(product >> kBarrettShift);
20808c63c71Sbeck 	uint32_t remainder = x - quotient * kPrime;
20908c63c71Sbeck 
21008c63c71Sbeck 	assert(x < kPrime + 2u * kPrime * kPrime);
21108c63c71Sbeck 	return reduce_once(remainder);
21208c63c71Sbeck }
21308c63c71Sbeck 
21408c63c71Sbeck static void
scalar_zero(scalar * out)21508c63c71Sbeck scalar_zero(scalar *out)
21608c63c71Sbeck {
21708c63c71Sbeck 	memset(out, 0, sizeof(*out));
21808c63c71Sbeck }
21908c63c71Sbeck 
22008c63c71Sbeck static void
vector_zero(vector * out)22108c63c71Sbeck vector_zero(vector *out)
22208c63c71Sbeck {
22308c63c71Sbeck 	memset(out, 0, sizeof(*out));
22408c63c71Sbeck }
22508c63c71Sbeck 
22608c63c71Sbeck /*
22708c63c71Sbeck  * In place number theoretic transform of a given scalar.
22808c63c71Sbeck  * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this
22908c63c71Sbeck  * transform leaves off the last iteration of the usual FFT code, with the 128
23008c63c71Sbeck  * relevant roots of unity being stored in |kNTTRoots|. This means the output
23108c63c71Sbeck  * should be seen as 128 elements in GF(3329^2), with the coefficients of the
23208c63c71Sbeck  * elements being consecutive entries in |s->c|.
23308c63c71Sbeck  */
23408c63c71Sbeck static void
scalar_ntt(scalar * s)23508c63c71Sbeck scalar_ntt(scalar *s)
23608c63c71Sbeck {
23708c63c71Sbeck 	int offset = DEGREE;
23808c63c71Sbeck 	int step;
23908c63c71Sbeck 	/*
24008c63c71Sbeck 	 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
24108c63c71Sbeck 	 * with Clang 14 on Aarch64.
24208c63c71Sbeck 	 */
24308c63c71Sbeck 	for (step = 1; step < DEGREE / 2; step <<= 1) {
24408c63c71Sbeck 		int i, j, k = 0;
24508c63c71Sbeck 
24608c63c71Sbeck 		offset >>= 1;
24708c63c71Sbeck 		for (i = 0; i < step; i++) {
24808c63c71Sbeck 			const uint32_t step_root = kNTTRoots[i + step];
24908c63c71Sbeck 
25008c63c71Sbeck 			for (j = k; j < k + offset; j++) {
25108c63c71Sbeck 				uint16_t odd, even;
25208c63c71Sbeck 
25308c63c71Sbeck 				odd = reduce(step_root * s->c[j + offset]);
25408c63c71Sbeck 				even = s->c[j];
25508c63c71Sbeck 				s->c[j] = reduce_once(odd + even);
25608c63c71Sbeck 				s->c[j + offset] = reduce_once(even - odd +
25708c63c71Sbeck 				    kPrime);
25808c63c71Sbeck 			}
25908c63c71Sbeck 			k += 2 * offset;
26008c63c71Sbeck 		}
26108c63c71Sbeck 	}
26208c63c71Sbeck }
26308c63c71Sbeck 
26408c63c71Sbeck static void
vector_ntt(vector * a)26508c63c71Sbeck vector_ntt(vector *a)
26608c63c71Sbeck {
26708c63c71Sbeck 	int i;
26808c63c71Sbeck 
26908c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
27008c63c71Sbeck 		scalar_ntt(&a->v[i]);
27108c63c71Sbeck 	}
27208c63c71Sbeck }
27308c63c71Sbeck 
27408c63c71Sbeck /*
27508c63c71Sbeck  * In place inverse number theoretic transform of a given scalar, with pairs of
27608c63c71Sbeck  * entries of s->v being interpreted as elements of GF(3329^2). Just as with the
27708c63c71Sbeck  * number theoretic transform, this leaves off the first step of the normal iFFT
27808c63c71Sbeck  * to account for the fact that 3329 does not have a 512th root of unity, using
27908c63c71Sbeck  * the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
28008c63c71Sbeck  */
28108c63c71Sbeck static void
scalar_inverse_ntt(scalar * s)28208c63c71Sbeck scalar_inverse_ntt(scalar *s)
28308c63c71Sbeck {
28408c63c71Sbeck 	int i, j, k, offset, step = DEGREE / 2;
28508c63c71Sbeck 
28608c63c71Sbeck 	/*
28708c63c71Sbeck 	 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
28808c63c71Sbeck 	 * with Clang 14 on Aarch64.
28908c63c71Sbeck 	 */
29008c63c71Sbeck 	for (offset = 2; offset < DEGREE; offset <<= 1) {
29108c63c71Sbeck 		step >>= 1;
29208c63c71Sbeck 		k = 0;
29308c63c71Sbeck 		for (i = 0; i < step; i++) {
29408c63c71Sbeck 			uint32_t step_root = kInverseNTTRoots[i + step];
29508c63c71Sbeck 			for (j = k; j < k + offset; j++) {
29608c63c71Sbeck 				uint16_t odd, even;
29708c63c71Sbeck 				odd = s->c[j + offset];
29808c63c71Sbeck 				even = s->c[j];
29908c63c71Sbeck 				s->c[j] = reduce_once(odd + even);
30008c63c71Sbeck 				s->c[j + offset] = reduce(step_root *
30108c63c71Sbeck 				    (even - odd + kPrime));
30208c63c71Sbeck 			}
30308c63c71Sbeck 			k += 2 * offset;
30408c63c71Sbeck 		}
30508c63c71Sbeck 	}
30608c63c71Sbeck 	for (i = 0; i < DEGREE; i++) {
30708c63c71Sbeck 		s->c[i] = reduce(s->c[i] * kInverseDegree);
30808c63c71Sbeck 	}
30908c63c71Sbeck }
31008c63c71Sbeck 
31108c63c71Sbeck static void
vector_inverse_ntt(vector * a)31208c63c71Sbeck vector_inverse_ntt(vector *a)
31308c63c71Sbeck {
31408c63c71Sbeck 	int i;
31508c63c71Sbeck 
31608c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
31708c63c71Sbeck 		scalar_inverse_ntt(&a->v[i]);
31808c63c71Sbeck 	}
31908c63c71Sbeck }
32008c63c71Sbeck 
32108c63c71Sbeck static void
scalar_add(scalar * lhs,const scalar * rhs)32208c63c71Sbeck scalar_add(scalar *lhs, const scalar *rhs)
32308c63c71Sbeck {
32408c63c71Sbeck 	int i;
32508c63c71Sbeck 
32608c63c71Sbeck 	for (i = 0; i < DEGREE; i++) {
32708c63c71Sbeck 		lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
32808c63c71Sbeck 	}
32908c63c71Sbeck }
33008c63c71Sbeck 
33108c63c71Sbeck static void
scalar_sub(scalar * lhs,const scalar * rhs)33208c63c71Sbeck scalar_sub(scalar *lhs, const scalar *rhs)
33308c63c71Sbeck {
33408c63c71Sbeck 	int i;
33508c63c71Sbeck 
33608c63c71Sbeck 	for (i = 0; i < DEGREE; i++) {
33708c63c71Sbeck 		lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
33808c63c71Sbeck 	}
33908c63c71Sbeck }
34008c63c71Sbeck 
34108c63c71Sbeck /*
342*9a8fba7cStb  * Multiplying two scalars in the number theoretically transformed state.
343*9a8fba7cStb  * Since 3329 does not have a 512th root of unity, this means we have to
344*9a8fba7cStb  * interpret the 2*ith and (2*i+1)th entries of the scalar as elements of
345*9a8fba7cStb  * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
346*9a8fba7cStb  * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
347*9a8fba7cStb  * |kModRoots| table. Our Barrett transform only allows us to multiply two
348*9a8fba7cStb  * reduced numbers together, so we need some intermediate reduction steps,
349*9a8fba7cStb  * even if an uint64_t could hold 3 multiplied numbers.
35008c63c71Sbeck  */
35108c63c71Sbeck static void
scalar_mult(scalar * out,const scalar * lhs,const scalar * rhs)35208c63c71Sbeck scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs)
35308c63c71Sbeck {
35408c63c71Sbeck 	int i;
35508c63c71Sbeck 
35608c63c71Sbeck 	for (i = 0; i < DEGREE / 2; i++) {
35708c63c71Sbeck 		uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
35808c63c71Sbeck 		uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] *
35908c63c71Sbeck 		    rhs->c[2 * i + 1];
36008c63c71Sbeck 		uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
36108c63c71Sbeck 		uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
36208c63c71Sbeck 
36308c63c71Sbeck 		out->c[2 * i] =
36408c63c71Sbeck 		    reduce(real_real +
36508c63c71Sbeck 		    (uint32_t)reduce(img_img) * kModRoots[i]);
36608c63c71Sbeck 		out->c[2 * i + 1] = reduce(img_real + real_img);
36708c63c71Sbeck 	}
36808c63c71Sbeck }
36908c63c71Sbeck 
37008c63c71Sbeck static void
vector_add(vector * lhs,const vector * rhs)37108c63c71Sbeck vector_add(vector *lhs, const vector *rhs)
37208c63c71Sbeck {
37308c63c71Sbeck 	int i;
37408c63c71Sbeck 
37508c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
37608c63c71Sbeck 		scalar_add(&lhs->v[i], &rhs->v[i]);
37708c63c71Sbeck 	}
37808c63c71Sbeck }
37908c63c71Sbeck 
38008c63c71Sbeck static void
matrix_mult(vector * out,const matrix * m,const vector * a)38108c63c71Sbeck matrix_mult(vector *out, const matrix *m, const vector *a)
38208c63c71Sbeck {
38308c63c71Sbeck 	int i, j;
38408c63c71Sbeck 
38508c63c71Sbeck 	vector_zero(out);
38608c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
38708c63c71Sbeck 		for (j = 0; j < RANK1024; j++) {
38808c63c71Sbeck 			scalar product;
38908c63c71Sbeck 
39008c63c71Sbeck 			scalar_mult(&product, &m->v[i][j], &a->v[j]);
39108c63c71Sbeck 			scalar_add(&out->v[i], &product);
39208c63c71Sbeck 		}
39308c63c71Sbeck 	}
39408c63c71Sbeck }
39508c63c71Sbeck 
39608c63c71Sbeck static void
matrix_mult_transpose(vector * out,const matrix * m,const vector * a)39708c63c71Sbeck matrix_mult_transpose(vector *out, const matrix *m,
39808c63c71Sbeck     const vector *a)
39908c63c71Sbeck {
40008c63c71Sbeck 	int i, j;
40108c63c71Sbeck 
40208c63c71Sbeck 	vector_zero(out);
40308c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
40408c63c71Sbeck 		for (j = 0; j < RANK1024; j++) {
40508c63c71Sbeck 			scalar product;
40608c63c71Sbeck 
40708c63c71Sbeck 			scalar_mult(&product, &m->v[j][i], &a->v[j]);
40808c63c71Sbeck 			scalar_add(&out->v[i], &product);
40908c63c71Sbeck 		}
41008c63c71Sbeck 	}
41108c63c71Sbeck }
41208c63c71Sbeck 
41308c63c71Sbeck static void
scalar_inner_product(scalar * out,const vector * lhs,const vector * rhs)41408c63c71Sbeck scalar_inner_product(scalar *out, const vector *lhs,
41508c63c71Sbeck     const vector *rhs)
41608c63c71Sbeck {
41708c63c71Sbeck 	int i;
41808c63c71Sbeck 	scalar_zero(out);
41908c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
42008c63c71Sbeck 		scalar product;
42108c63c71Sbeck 
42208c63c71Sbeck 		scalar_mult(&product, &lhs->v[i], &rhs->v[i]);
42308c63c71Sbeck 		scalar_add(out, &product);
42408c63c71Sbeck 	}
42508c63c71Sbeck }
42608c63c71Sbeck 
42708c63c71Sbeck /*
42808c63c71Sbeck  * Algorithm 6 of spec. Rejection samples a Keccak stream to get uniformly
42908c63c71Sbeck  * distributed elements. This is used for matrix expansion and only operates on
43008c63c71Sbeck  * public inputs.
43108c63c71Sbeck  */
43208c63c71Sbeck static void
scalar_from_keccak_vartime(scalar * out,sha3_ctx * keccak_ctx)43308c63c71Sbeck scalar_from_keccak_vartime(scalar *out, sha3_ctx *keccak_ctx)
43408c63c71Sbeck {
43508c63c71Sbeck 	int i, done = 0;
43608c63c71Sbeck 
43708c63c71Sbeck 	while (done < DEGREE) {
43808c63c71Sbeck 		uint8_t block[168];
43908c63c71Sbeck 
44008c63c71Sbeck 		shake_out(keccak_ctx, block, sizeof(block));
44108c63c71Sbeck 		for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
44208c63c71Sbeck 			uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
44308c63c71Sbeck 			uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
44408c63c71Sbeck 
44508c63c71Sbeck 			if (d1 < kPrime) {
44608c63c71Sbeck 				out->c[done++] = d1;
44708c63c71Sbeck 			}
44808c63c71Sbeck 			if (d2 < kPrime && done < DEGREE) {
44908c63c71Sbeck 				out->c[done++] = d2;
45008c63c71Sbeck 			}
45108c63c71Sbeck 		}
45208c63c71Sbeck 	}
45308c63c71Sbeck }
45408c63c71Sbeck 
45508c63c71Sbeck /*
45608c63c71Sbeck  * Algorithm 7 of the spec, with eta fixed to two and the PRF call
45708c63c71Sbeck  * included. Creates binominally distributed elements by sampling 2*|eta| bits,
45808c63c71Sbeck  * and setting the coefficient to the count of the first bits minus the count of
45908c63c71Sbeck  * the second bits, resulting in a centered binomial distribution. Since eta is
46008c63c71Sbeck  * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
46108c63c71Sbeck  * and 0 with probability 3/8.
46208c63c71Sbeck  */
46308c63c71Sbeck static void
scalar_centered_binomial_distribution_eta_2_with_prf(scalar * out,const uint8_t input[33])46408c63c71Sbeck scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out,
46508c63c71Sbeck     const uint8_t input[33])
46608c63c71Sbeck {
46708c63c71Sbeck 	uint8_t entropy[128];
46808c63c71Sbeck 	int i;
46908c63c71Sbeck 
47008c63c71Sbeck 	CTASSERT(sizeof(entropy) == 2 * /*kEta=*/ 2 * DEGREE / 8);
47108c63c71Sbeck 	prf(entropy, sizeof(entropy), input);
47208c63c71Sbeck 
47308c63c71Sbeck 	for (i = 0; i < DEGREE; i += 2) {
47408c63c71Sbeck 		uint8_t byte = entropy[i / 2];
4759ee6f1feSbeck 		uint16_t mask;
4769ee6f1feSbeck 		uint16_t value = (byte & 1) + ((byte >> 1) & 1);
47708c63c71Sbeck 
47808c63c71Sbeck 		value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
47967abc7a1Stb 
4809ee6f1feSbeck 		/*
4819ee6f1feSbeck 		 * Add |kPrime| if |value| underflowed. See |reduce_once| for a
4829ee6f1feSbeck 		 * discussion on why the value barrier is omitted. While this
4839ee6f1feSbeck 		 * could have been written reduce_once(value + kPrime), this is
4849ee6f1feSbeck 		 * one extra addition and small range of |value| tempts some
4859ee6f1feSbeck 		 * versions of Clang to emit a branch.
4869ee6f1feSbeck 		 */
4879ee6f1feSbeck 		mask = 0u - (value >> 15);
4889ee6f1feSbeck 		out->c[i] = ((value + kPrime) & mask) | (value & ~mask);
48908c63c71Sbeck 
49008c63c71Sbeck 		byte >>= 4;
4919ee6f1feSbeck 		value = (byte & 1) + ((byte >> 1) & 1);
49208c63c71Sbeck 		value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
4939ee6f1feSbeck 		/*  See above. */
4949ee6f1feSbeck 		mask = 0u - (value >> 15);
4959ee6f1feSbeck 		out->c[i + 1] = ((value + kPrime) & mask) | (value & ~mask);
49608c63c71Sbeck 	}
49708c63c71Sbeck }
49808c63c71Sbeck 
49908c63c71Sbeck /*
50008c63c71Sbeck  * Generates a secret vector by using
50108c63c71Sbeck  * |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed
50208c63c71Sbeck  * appending and incrementing |counter| for entry of the vector.
50308c63c71Sbeck  */
50408c63c71Sbeck static void
vector_generate_secret_eta_2(vector * out,uint8_t * counter,const uint8_t seed[32])50508c63c71Sbeck vector_generate_secret_eta_2(vector *out, uint8_t *counter,
50608c63c71Sbeck     const uint8_t seed[32])
50708c63c71Sbeck {
50808c63c71Sbeck 	uint8_t input[33];
50908c63c71Sbeck 	int i;
51008c63c71Sbeck 
51108c63c71Sbeck 	memcpy(input, seed, 32);
51208c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
51308c63c71Sbeck 		input[32] = (*counter)++;
51408c63c71Sbeck 		scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i],
51508c63c71Sbeck 		    input);
51608c63c71Sbeck 	}
51708c63c71Sbeck }
51808c63c71Sbeck 
51908c63c71Sbeck /* Expands the matrix of a seed for key generation and for encaps-CPA. */
52008c63c71Sbeck static void
matrix_expand(matrix * out,const uint8_t rho[32])52108c63c71Sbeck matrix_expand(matrix *out, const uint8_t rho[32])
52208c63c71Sbeck {
52308c63c71Sbeck 	uint8_t input[34];
52408c63c71Sbeck 	int i, j;
52508c63c71Sbeck 
52608c63c71Sbeck 	memcpy(input, rho, 32);
52708c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
52808c63c71Sbeck 		for (j = 0; j < RANK1024; j++) {
52908c63c71Sbeck 			sha3_ctx keccak_ctx;
53008c63c71Sbeck 
53108c63c71Sbeck 			input[32] = i;
53208c63c71Sbeck 			input[33] = j;
53308c63c71Sbeck 			shake128_init(&keccak_ctx);
53408c63c71Sbeck 			shake_update(&keccak_ctx, input, sizeof(input));
53508c63c71Sbeck 			shake_xof(&keccak_ctx);
53608c63c71Sbeck 			scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
53708c63c71Sbeck 		}
53808c63c71Sbeck 	}
53908c63c71Sbeck }
54008c63c71Sbeck 
/* kMasks[n] has the low n + 1 bits set, masking a chunk of n + 1 bits. */
static const uint8_t kMasks[8] = {
	0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f, 0xff,
};
54308c63c71Sbeck 
54408c63c71Sbeck static void
scalar_encode(uint8_t * out,const scalar * s,int bits)54508c63c71Sbeck scalar_encode(uint8_t *out, const scalar *s, int bits)
54608c63c71Sbeck {
54708c63c71Sbeck 	uint8_t out_byte = 0;
54808c63c71Sbeck 	int i, out_byte_bits = 0;
54908c63c71Sbeck 
55008c63c71Sbeck 	assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
55108c63c71Sbeck 	for (i = 0; i < DEGREE; i++) {
55208c63c71Sbeck 		uint16_t element = s->c[i];
55308c63c71Sbeck 		int element_bits_done = 0;
55408c63c71Sbeck 
55508c63c71Sbeck 		while (element_bits_done < bits) {
55608c63c71Sbeck 			int chunk_bits = bits - element_bits_done;
55708c63c71Sbeck 			int out_bits_remaining = 8 - out_byte_bits;
55808c63c71Sbeck 
55908c63c71Sbeck 			if (chunk_bits >= out_bits_remaining) {
56008c63c71Sbeck 				chunk_bits = out_bits_remaining;
56108c63c71Sbeck 				out_byte |= (element &
56208c63c71Sbeck 				    kMasks[chunk_bits - 1]) << out_byte_bits;
56308c63c71Sbeck 				*out = out_byte;
56408c63c71Sbeck 				out++;
56508c63c71Sbeck 				out_byte_bits = 0;
56608c63c71Sbeck 				out_byte = 0;
56708c63c71Sbeck 			} else {
56808c63c71Sbeck 				out_byte |= (element &
56908c63c71Sbeck 				    kMasks[chunk_bits - 1]) << out_byte_bits;
57008c63c71Sbeck 				out_byte_bits += chunk_bits;
57108c63c71Sbeck 			}
57208c63c71Sbeck 
57308c63c71Sbeck 			element_bits_done += chunk_bits;
57408c63c71Sbeck 			element >>= chunk_bits;
57508c63c71Sbeck 		}
57608c63c71Sbeck 	}
57708c63c71Sbeck 
57808c63c71Sbeck 	if (out_byte_bits > 0) {
57908c63c71Sbeck 		*out = out_byte;
58008c63c71Sbeck 	}
58108c63c71Sbeck }
58208c63c71Sbeck 
58308c63c71Sbeck /* scalar_encode_1 is |scalar_encode| specialised for |bits| == 1. */
58408c63c71Sbeck static void
scalar_encode_1(uint8_t out[32],const scalar * s)58508c63c71Sbeck scalar_encode_1(uint8_t out[32], const scalar *s)
58608c63c71Sbeck {
58708c63c71Sbeck 	int i, j;
58808c63c71Sbeck 
58908c63c71Sbeck 	for (i = 0; i < DEGREE; i += 8) {
59008c63c71Sbeck 		uint8_t out_byte = 0;
59108c63c71Sbeck 
59208c63c71Sbeck 		for (j = 0; j < 8; j++) {
59308c63c71Sbeck 			out_byte |= (s->c[i + j] & 1) << j;
59408c63c71Sbeck 		}
59508c63c71Sbeck 		*out = out_byte;
59608c63c71Sbeck 		out++;
59708c63c71Sbeck 	}
59808c63c71Sbeck }
59908c63c71Sbeck 
60008c63c71Sbeck /*
60108c63c71Sbeck  * Encodes an entire vector into 32*|RANK1024|*|bits| bytes. Note that since 256
60208c63c71Sbeck  * (DEGREE) is divisible by 8, the individual vector entries will always fill a
60308c63c71Sbeck  * whole number of bytes, so we do not need to worry about bit packing here.
60408c63c71Sbeck  */
60508c63c71Sbeck static void
vector_encode(uint8_t * out,const vector * a,int bits)60608c63c71Sbeck vector_encode(uint8_t *out, const vector *a, int bits)
60708c63c71Sbeck {
60808c63c71Sbeck 	int i;
60908c63c71Sbeck 
61008c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
61108c63c71Sbeck 		scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
61208c63c71Sbeck 	}
61308c63c71Sbeck }
61408c63c71Sbeck 
61508c63c71Sbeck /*
61608c63c71Sbeck  * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
61708c63c71Sbeck  * |out|. It returns one on success and zero if any parsed value is >=
61808c63c71Sbeck  * |kPrime|.
61908c63c71Sbeck  */
62008c63c71Sbeck static int
scalar_decode(scalar * out,const uint8_t * in,int bits)62108c63c71Sbeck scalar_decode(scalar *out, const uint8_t *in, int bits)
62208c63c71Sbeck {
62308c63c71Sbeck 	uint8_t in_byte = 0;
62408c63c71Sbeck 	int i, in_byte_bits_left = 0;
62508c63c71Sbeck 
62608c63c71Sbeck 	assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);
62708c63c71Sbeck 
62808c63c71Sbeck 	for (i = 0; i < DEGREE; i++) {
62908c63c71Sbeck 		uint16_t element = 0;
63008c63c71Sbeck 		int element_bits_done = 0;
63108c63c71Sbeck 
63208c63c71Sbeck 		while (element_bits_done < bits) {
63308c63c71Sbeck 			int chunk_bits = bits - element_bits_done;
63408c63c71Sbeck 
63508c63c71Sbeck 			if (in_byte_bits_left == 0) {
63608c63c71Sbeck 				in_byte = *in;
63708c63c71Sbeck 				in++;
63808c63c71Sbeck 				in_byte_bits_left = 8;
63908c63c71Sbeck 			}
64008c63c71Sbeck 
64108c63c71Sbeck 			if (chunk_bits > in_byte_bits_left) {
64208c63c71Sbeck 				chunk_bits = in_byte_bits_left;
64308c63c71Sbeck 			}
64408c63c71Sbeck 
64508c63c71Sbeck 			element |= (in_byte & kMasks[chunk_bits - 1]) <<
64608c63c71Sbeck 			    element_bits_done;
64708c63c71Sbeck 			in_byte_bits_left -= chunk_bits;
64808c63c71Sbeck 			in_byte >>= chunk_bits;
64908c63c71Sbeck 
65008c63c71Sbeck 			element_bits_done += chunk_bits;
65108c63c71Sbeck 		}
65208c63c71Sbeck 
65308c63c71Sbeck 		if (element >= kPrime) {
65408c63c71Sbeck 			return 0;
65508c63c71Sbeck 		}
65608c63c71Sbeck 		out->c[i] = element;
65708c63c71Sbeck 	}
65808c63c71Sbeck 
65908c63c71Sbeck 	return 1;
66008c63c71Sbeck }
66108c63c71Sbeck 
66208c63c71Sbeck /* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */
66308c63c71Sbeck static void
scalar_decode_1(scalar * out,const uint8_t in[32])66408c63c71Sbeck scalar_decode_1(scalar *out, const uint8_t in[32])
66508c63c71Sbeck {
66608c63c71Sbeck 	int i, j;
66708c63c71Sbeck 
66808c63c71Sbeck 	for (i = 0; i < DEGREE; i += 8) {
66908c63c71Sbeck 		uint8_t in_byte = *in;
67008c63c71Sbeck 
67108c63c71Sbeck 		in++;
67208c63c71Sbeck 		for (j = 0; j < 8; j++) {
67308c63c71Sbeck 			out->c[i + j] = in_byte & 1;
67408c63c71Sbeck 			in_byte >>= 1;
67508c63c71Sbeck 		}
67608c63c71Sbeck 	}
67708c63c71Sbeck }
67808c63c71Sbeck 
67908c63c71Sbeck /*
68008c63c71Sbeck  * Decodes 32*|RANK1024|*|bits| bytes from |in| into |out|. It returns one on
68108c63c71Sbeck  * success or zero if any parsed value is >= |kPrime|.
68208c63c71Sbeck  */
68308c63c71Sbeck static int
vector_decode(vector * out,const uint8_t * in,int bits)68408c63c71Sbeck vector_decode(vector *out, const uint8_t *in, int bits)
68508c63c71Sbeck {
68608c63c71Sbeck 	int i;
68708c63c71Sbeck 
68808c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
68908c63c71Sbeck 		if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8,
69008c63c71Sbeck 		    bits)) {
69108c63c71Sbeck 			return 0;
69208c63c71Sbeck 		}
69308c63c71Sbeck 	}
69408c63c71Sbeck 	return 1;
69508c63c71Sbeck }
69608c63c71Sbeck 
69708c63c71Sbeck /*
69808c63c71Sbeck  * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
69908c63c71Sbeck  * numbers close to each other together. The formula used is
70008c63c71Sbeck  * round(2^|bits|/kPrime*x) mod 2^|bits|.
70108c63c71Sbeck  * Uses Barrett reduction to achieve constant time. Since we need both the
70208c63c71Sbeck  * remainder (for rounding) and the quotient (as the result), we cannot use
70308c63c71Sbeck  * |reduce| here, but need to do the Barrett reduction directly.
70408c63c71Sbeck  */
70508c63c71Sbeck static uint16_t
compress(uint16_t x,int bits)70608c63c71Sbeck compress(uint16_t x, int bits)
70708c63c71Sbeck {
70808c63c71Sbeck 	uint32_t shifted = (uint32_t)x << bits;
70908c63c71Sbeck 	uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
71008c63c71Sbeck 	uint32_t quotient = (uint32_t)(product >> kBarrettShift);
71108c63c71Sbeck 	uint32_t remainder = shifted - quotient * kPrime;
71208c63c71Sbeck 
71308c63c71Sbeck 	/*
71408c63c71Sbeck 	 * Adjust the quotient to round correctly:
71508c63c71Sbeck 	 * 0 <= remainder <= kHalfPrime round to 0
71608c63c71Sbeck 	 * kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
71708c63c71Sbeck 	 * kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
71808c63c71Sbeck 	 */
71908c63c71Sbeck 	assert(remainder < 2u * kPrime);
72008c63c71Sbeck 	quotient += 1 & constant_time_lt(kHalfPrime, remainder);
72108c63c71Sbeck 	quotient += 1 & constant_time_lt(kPrime + kHalfPrime, remainder);
72208c63c71Sbeck 	return quotient & ((1 << bits) - 1);
72308c63c71Sbeck }
72408c63c71Sbeck 
72508c63c71Sbeck /*
72608c63c71Sbeck  * Decompresses |x| by using an equi-distant representative. The formula is
72708c63c71Sbeck  * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
72808c63c71Sbeck  * implement this logic using only bit operations.
72908c63c71Sbeck  */
73008c63c71Sbeck static uint16_t
decompress(uint16_t x,int bits)73108c63c71Sbeck decompress(uint16_t x, int bits)
73208c63c71Sbeck {
73308c63c71Sbeck 	uint32_t product = (uint32_t)x * kPrime;
73408c63c71Sbeck 	uint32_t power = 1 << bits;
73508c63c71Sbeck 	/* This is |product| % power, since |power| is a power of 2. */
73608c63c71Sbeck 	uint32_t remainder = product & (power - 1);
73708c63c71Sbeck 	/* This is |product| / power, since |power| is a power of 2. */
73808c63c71Sbeck 	uint32_t lower = product >> bits;
73908c63c71Sbeck 
74008c63c71Sbeck 	/*
74108c63c71Sbeck 	 * The rounding logic works since the first half of numbers mod |power| have a
74208c63c71Sbeck 	 * 0 as first bit, and the second half has a 1 as first bit, since |power| is
74308c63c71Sbeck 	 * a power of 2. As a 12 bit number, |remainder| is always positive, so we
74408c63c71Sbeck 	 * will shift in 0s for a right shift.
74508c63c71Sbeck 	 */
74608c63c71Sbeck 	return lower + (remainder >> (bits - 1));
74708c63c71Sbeck }
74808c63c71Sbeck 
74908c63c71Sbeck static void
scalar_compress(scalar * s,int bits)75008c63c71Sbeck scalar_compress(scalar *s, int bits)
75108c63c71Sbeck {
75208c63c71Sbeck 	int i;
75308c63c71Sbeck 
75408c63c71Sbeck 	for (i = 0; i < DEGREE; i++) {
75508c63c71Sbeck 		s->c[i] = compress(s->c[i], bits);
75608c63c71Sbeck 	}
75708c63c71Sbeck }
75808c63c71Sbeck 
75908c63c71Sbeck static void
scalar_decompress(scalar * s,int bits)76008c63c71Sbeck scalar_decompress(scalar *s, int bits)
76108c63c71Sbeck {
76208c63c71Sbeck 	int i;
76308c63c71Sbeck 
76408c63c71Sbeck 	for (i = 0; i < DEGREE; i++) {
76508c63c71Sbeck 		s->c[i] = decompress(s->c[i], bits);
76608c63c71Sbeck 	}
76708c63c71Sbeck }
76808c63c71Sbeck 
76908c63c71Sbeck static void
vector_compress(vector * a,int bits)77008c63c71Sbeck vector_compress(vector *a, int bits)
77108c63c71Sbeck {
77208c63c71Sbeck 	int i;
77308c63c71Sbeck 
77408c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
77508c63c71Sbeck 		scalar_compress(&a->v[i], bits);
77608c63c71Sbeck 	}
77708c63c71Sbeck }
77808c63c71Sbeck 
77908c63c71Sbeck static void
vector_decompress(vector * a,int bits)78008c63c71Sbeck vector_decompress(vector *a, int bits)
78108c63c71Sbeck {
78208c63c71Sbeck 	int i;
78308c63c71Sbeck 
78408c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
78508c63c71Sbeck 		scalar_decompress(&a->v[i], bits);
78608c63c71Sbeck 	}
78708c63c71Sbeck }
78808c63c71Sbeck 
78908c63c71Sbeck struct public_key {
79008c63c71Sbeck 	vector t;
79108c63c71Sbeck 	uint8_t rho[32];
79208c63c71Sbeck 	uint8_t public_key_hash[32];
79308c63c71Sbeck 	matrix m;
79408c63c71Sbeck };
79508c63c71Sbeck 
79608c63c71Sbeck static struct public_key *
public_key_1024_from_external(const struct MLKEM1024_public_key * external)79708c63c71Sbeck public_key_1024_from_external(const struct MLKEM1024_public_key *external)
79808c63c71Sbeck {
79908c63c71Sbeck 	return (struct public_key *)external;
80008c63c71Sbeck }
80108c63c71Sbeck 
80208c63c71Sbeck struct private_key {
80308c63c71Sbeck 	struct public_key pub;
80408c63c71Sbeck 	vector s;
80508c63c71Sbeck 	uint8_t fo_failure_secret[32];
80608c63c71Sbeck };
80708c63c71Sbeck 
80808c63c71Sbeck static struct private_key *
private_key_1024_from_external(const struct MLKEM1024_private_key * external)80908c63c71Sbeck private_key_1024_from_external(const struct MLKEM1024_private_key *external)
81008c63c71Sbeck {
81108c63c71Sbeck 	return (struct private_key *)external;
81208c63c71Sbeck }
81308c63c71Sbeck 
81408c63c71Sbeck /*
81508c63c71Sbeck  * Calls |MLKEM1024_generate_key_external_entropy| with random bytes from
81608c63c71Sbeck  * |RAND_bytes|.
81708c63c71Sbeck  */
81808c63c71Sbeck void
MLKEM1024_generate_key(uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],uint8_t optional_out_seed[MLKEM_SEED_BYTES],struct MLKEM1024_private_key * out_private_key)81908c63c71Sbeck MLKEM1024_generate_key(uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
82008c63c71Sbeck     uint8_t optional_out_seed[MLKEM_SEED_BYTES],
82108c63c71Sbeck     struct MLKEM1024_private_key *out_private_key)
82208c63c71Sbeck {
82308c63c71Sbeck 	uint8_t entropy_buf[MLKEM_SEED_BYTES];
82408c63c71Sbeck 	uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed :
82508c63c71Sbeck 	    entropy_buf;
82608c63c71Sbeck 
82708c63c71Sbeck 	arc4random_buf(entropy, MLKEM_SEED_BYTES);
82808c63c71Sbeck 	MLKEM1024_generate_key_external_entropy(out_encoded_public_key,
82908c63c71Sbeck 	    out_private_key, entropy);
83008c63c71Sbeck }
83108c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_generate_key);
83208c63c71Sbeck 
83308c63c71Sbeck int
MLKEM1024_private_key_from_seed(struct MLKEM1024_private_key * out_private_key,const uint8_t * seed,size_t seed_len)83408c63c71Sbeck MLKEM1024_private_key_from_seed(struct MLKEM1024_private_key *out_private_key,
83508c63c71Sbeck     const uint8_t *seed, size_t seed_len)
83608c63c71Sbeck {
83708c63c71Sbeck 	uint8_t public_key_bytes[MLKEM1024_PUBLIC_KEY_BYTES];
83808c63c71Sbeck 
83908c63c71Sbeck 	if (seed_len != MLKEM_SEED_BYTES) {
84008c63c71Sbeck 		return 0;
84108c63c71Sbeck 	}
84208c63c71Sbeck 	MLKEM1024_generate_key_external_entropy(public_key_bytes,
84308c63c71Sbeck 	    out_private_key, seed);
84408c63c71Sbeck 
84508c63c71Sbeck 	return 1;
84608c63c71Sbeck }
84708c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_private_key_from_seed);
84808c63c71Sbeck 
84908c63c71Sbeck static int
mlkem_marshal_public_key(CBB * out,const struct public_key * pub)85008c63c71Sbeck mlkem_marshal_public_key(CBB *out, const struct public_key *pub)
85108c63c71Sbeck {
85208c63c71Sbeck 	uint8_t *vector_output;
85308c63c71Sbeck 
85408c63c71Sbeck 	if (!CBB_add_space(out, &vector_output, kEncodedVectorSize)) {
85508c63c71Sbeck 		return 0;
85608c63c71Sbeck 	}
85708c63c71Sbeck 	vector_encode(vector_output, &pub->t, kLog2Prime);
85808c63c71Sbeck 	if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
85908c63c71Sbeck 		return 0;
86008c63c71Sbeck 	}
86108c63c71Sbeck 	return 1;
86208c63c71Sbeck }
86308c63c71Sbeck 
86408c63c71Sbeck void
MLKEM1024_generate_key_external_entropy(uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],struct MLKEM1024_private_key * out_private_key,const uint8_t entropy[MLKEM_SEED_BYTES])86508c63c71Sbeck MLKEM1024_generate_key_external_entropy(
86608c63c71Sbeck     uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
86708c63c71Sbeck     struct MLKEM1024_private_key *out_private_key,
86808c63c71Sbeck     const uint8_t entropy[MLKEM_SEED_BYTES])
86908c63c71Sbeck {
87008c63c71Sbeck 	struct private_key *priv = private_key_1024_from_external(
87108c63c71Sbeck 	    out_private_key);
87208c63c71Sbeck 	uint8_t augmented_seed[33];
87308c63c71Sbeck 	uint8_t *rho, *sigma;
87408c63c71Sbeck 	uint8_t counter = 0;
87508c63c71Sbeck 	uint8_t hashed[64];
87608c63c71Sbeck 	vector error;
87708c63c71Sbeck 	CBB cbb;
87808c63c71Sbeck 
87908c63c71Sbeck 	memcpy(augmented_seed, entropy, 32);
88008c63c71Sbeck 	augmented_seed[32] = RANK1024;
88108c63c71Sbeck 	hash_g(hashed, augmented_seed, 33);
88208c63c71Sbeck 	rho = hashed;
88308c63c71Sbeck 	sigma = hashed + 32;
88408c63c71Sbeck 	memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
88508c63c71Sbeck 	matrix_expand(&priv->pub.m, rho);
88608c63c71Sbeck 	vector_generate_secret_eta_2(&priv->s, &counter, sigma);
88708c63c71Sbeck 	vector_ntt(&priv->s);
88808c63c71Sbeck 	vector_generate_secret_eta_2(&error, &counter, sigma);
88908c63c71Sbeck 	vector_ntt(&error);
89008c63c71Sbeck 	matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
89108c63c71Sbeck 	vector_add(&priv->pub.t, &error);
89208c63c71Sbeck 
893516824a3Stb 	/* XXX - error checking. */
89408c63c71Sbeck 	CBB_init_fixed(&cbb, out_encoded_public_key, MLKEM1024_PUBLIC_KEY_BYTES);
89508c63c71Sbeck 	if (!mlkem_marshal_public_key(&cbb, &priv->pub)) {
89608c63c71Sbeck 		abort();
89708c63c71Sbeck 	}
898516824a3Stb 	CBB_cleanup(&cbb);
89908c63c71Sbeck 
90008c63c71Sbeck 	hash_h(priv->pub.public_key_hash, out_encoded_public_key,
90108c63c71Sbeck 	    MLKEM1024_PUBLIC_KEY_BYTES);
90208c63c71Sbeck 	memcpy(priv->fo_failure_secret, entropy + 32, 32);
90308c63c71Sbeck }
90408c63c71Sbeck 
90508c63c71Sbeck void
MLKEM1024_public_from_private(struct MLKEM1024_public_key * out_public_key,const struct MLKEM1024_private_key * private_key)90608c63c71Sbeck MLKEM1024_public_from_private(struct MLKEM1024_public_key *out_public_key,
90708c63c71Sbeck     const struct MLKEM1024_private_key *private_key)
90808c63c71Sbeck {
90908c63c71Sbeck 	struct public_key *const pub = public_key_1024_from_external(
91008c63c71Sbeck 	    out_public_key);
91108c63c71Sbeck 	const struct private_key *const priv = private_key_1024_from_external(
91208c63c71Sbeck 	    private_key);
91308c63c71Sbeck 
91408c63c71Sbeck 	*pub = priv->pub;
91508c63c71Sbeck }
91608c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_public_from_private);
91708c63c71Sbeck 
91808c63c71Sbeck /*
91908c63c71Sbeck  * Encrypts a message with given randomness to the ciphertext in |out|. Without
92008c63c71Sbeck  * applying the Fujisaki-Okamoto transform this would not result in a CCA secure
92108c63c71Sbeck  * scheme, since lattice schemes are vulnerable to decryption failure oracles.
92208c63c71Sbeck  */
92308c63c71Sbeck static void
encrypt_cpa(uint8_t out[MLKEM1024_CIPHERTEXT_BYTES],const struct public_key * pub,const uint8_t message[32],const uint8_t randomness[32])92408c63c71Sbeck encrypt_cpa(uint8_t out[MLKEM1024_CIPHERTEXT_BYTES],
92508c63c71Sbeck     const struct public_key *pub, const uint8_t message[32],
92608c63c71Sbeck     const uint8_t randomness[32])
92708c63c71Sbeck {
92808c63c71Sbeck 	scalar expanded_message, scalar_error;
92908c63c71Sbeck 	vector secret, error, u;
93008c63c71Sbeck 	uint8_t counter = 0;
93108c63c71Sbeck 	uint8_t input[33];
93208c63c71Sbeck 	scalar v;
93308c63c71Sbeck 
93408c63c71Sbeck 	vector_generate_secret_eta_2(&secret, &counter, randomness);
93508c63c71Sbeck 	vector_ntt(&secret);
93608c63c71Sbeck 	vector_generate_secret_eta_2(&error, &counter, randomness);
93708c63c71Sbeck 	memcpy(input, randomness, 32);
93808c63c71Sbeck 	input[32] = counter;
93908c63c71Sbeck 	scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error,
94008c63c71Sbeck 	    input);
94108c63c71Sbeck 	matrix_mult(&u, &pub->m, &secret);
94208c63c71Sbeck 	vector_inverse_ntt(&u);
94308c63c71Sbeck 	vector_add(&u, &error);
94408c63c71Sbeck 	scalar_inner_product(&v, &pub->t, &secret);
94508c63c71Sbeck 	scalar_inverse_ntt(&v);
94608c63c71Sbeck 	scalar_add(&v, &scalar_error);
94708c63c71Sbeck 	scalar_decode_1(&expanded_message, message);
94808c63c71Sbeck 	scalar_decompress(&expanded_message, 1);
94908c63c71Sbeck 	scalar_add(&v, &expanded_message);
95008c63c71Sbeck 	vector_compress(&u, kDU1024);
95108c63c71Sbeck 	vector_encode(out, &u, kDU1024);
95208c63c71Sbeck 	scalar_compress(&v, kDV1024);
95308c63c71Sbeck 	scalar_encode(out + kCompressedVectorSize, &v, kDV1024);
95408c63c71Sbeck }
95508c63c71Sbeck 
95608c63c71Sbeck /* Calls MLKEM1024_encap_external_entropy| with random bytes */
95708c63c71Sbeck void
MLKEM1024_encap(uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],const struct MLKEM1024_public_key * public_key)95808c63c71Sbeck MLKEM1024_encap(uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
95908c63c71Sbeck     uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
96008c63c71Sbeck     const struct MLKEM1024_public_key *public_key)
96108c63c71Sbeck {
96208c63c71Sbeck 	uint8_t entropy[MLKEM_ENCAP_ENTROPY];
96308c63c71Sbeck 
96408c63c71Sbeck 	arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY);
96508c63c71Sbeck 	MLKEM1024_encap_external_entropy(out_ciphertext, out_shared_secret,
96608c63c71Sbeck 	    public_key, entropy);
96708c63c71Sbeck }
96808c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_encap);
96908c63c71Sbeck 
97008c63c71Sbeck /* See section 6.2 of the spec. */
97108c63c71Sbeck void
MLKEM1024_encap_external_entropy(uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],const struct MLKEM1024_public_key * public_key,const uint8_t entropy[MLKEM_ENCAP_ENTROPY])97208c63c71Sbeck MLKEM1024_encap_external_entropy(
97308c63c71Sbeck     uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
97408c63c71Sbeck     uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
97508c63c71Sbeck     const struct MLKEM1024_public_key *public_key,
97608c63c71Sbeck     const uint8_t entropy[MLKEM_ENCAP_ENTROPY])
97708c63c71Sbeck {
97808c63c71Sbeck 	const struct public_key *pub = public_key_1024_from_external(public_key);
97908c63c71Sbeck 	uint8_t key_and_randomness[64];
98008c63c71Sbeck 	uint8_t input[64];
98108c63c71Sbeck 
98208c63c71Sbeck 	memcpy(input, entropy, MLKEM_ENCAP_ENTROPY);
98308c63c71Sbeck 	memcpy(input + MLKEM_ENCAP_ENTROPY, pub->public_key_hash,
98408c63c71Sbeck 	    sizeof(input) - MLKEM_ENCAP_ENTROPY);
98508c63c71Sbeck 	hash_g(key_and_randomness, input, sizeof(input));
98608c63c71Sbeck 	encrypt_cpa(out_ciphertext, pub, entropy, key_and_randomness + 32);
98708c63c71Sbeck 	memcpy(out_shared_secret, key_and_randomness, 32);
98808c63c71Sbeck }
98908c63c71Sbeck 
99008c63c71Sbeck static void
decrypt_cpa(uint8_t out[32],const struct private_key * priv,const uint8_t ciphertext[MLKEM1024_CIPHERTEXT_BYTES])99108c63c71Sbeck decrypt_cpa(uint8_t out[32], const struct private_key *priv,
99208c63c71Sbeck     const uint8_t ciphertext[MLKEM1024_CIPHERTEXT_BYTES])
99308c63c71Sbeck {
99408c63c71Sbeck 	scalar mask, v;
99508c63c71Sbeck 	vector u;
99608c63c71Sbeck 
99708c63c71Sbeck 	vector_decode(&u, ciphertext, kDU1024);
99808c63c71Sbeck 	vector_decompress(&u, kDU1024);
99908c63c71Sbeck 	vector_ntt(&u);
100008c63c71Sbeck 	scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV1024);
100108c63c71Sbeck 	scalar_decompress(&v, kDV1024);
100208c63c71Sbeck 	scalar_inner_product(&mask, &priv->s, &u);
100308c63c71Sbeck 	scalar_inverse_ntt(&mask);
100408c63c71Sbeck 	scalar_sub(&v, &mask);
100508c63c71Sbeck 	scalar_compress(&v, 1);
100608c63c71Sbeck 	scalar_encode_1(out, &v);
100708c63c71Sbeck }
100808c63c71Sbeck 
100908c63c71Sbeck /* See section 6.3 */
101008c63c71Sbeck int
MLKEM1024_decap(uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],const uint8_t * ciphertext,size_t ciphertext_len,const struct MLKEM1024_private_key * private_key)101108c63c71Sbeck MLKEM1024_decap(uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
101208c63c71Sbeck     const uint8_t *ciphertext, size_t ciphertext_len,
101308c63c71Sbeck     const struct MLKEM1024_private_key *private_key)
101408c63c71Sbeck {
101508c63c71Sbeck 	const struct private_key *priv = private_key_1024_from_external(
101608c63c71Sbeck 	    private_key);
101708c63c71Sbeck 	uint8_t expected_ciphertext[MLKEM1024_CIPHERTEXT_BYTES];
101808c63c71Sbeck 	uint8_t key_and_randomness[64];
101908c63c71Sbeck 	uint8_t failure_key[32];
102008c63c71Sbeck 	uint8_t decrypted[64];
102108c63c71Sbeck 	uint8_t mask;
102208c63c71Sbeck 	int i;
102308c63c71Sbeck 
102408c63c71Sbeck 	if (ciphertext_len != MLKEM1024_CIPHERTEXT_BYTES) {
102508c63c71Sbeck 		arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_BYTES);
102608c63c71Sbeck 		return 0;
102708c63c71Sbeck 	}
102808c63c71Sbeck 
102908c63c71Sbeck 	decrypt_cpa(decrypted, priv, ciphertext);
103008c63c71Sbeck 	memcpy(decrypted + 32, priv->pub.public_key_hash,
103108c63c71Sbeck 	    sizeof(decrypted) - 32);
103208c63c71Sbeck 	hash_g(key_and_randomness, decrypted, sizeof(decrypted));
103308c63c71Sbeck 	encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
103408c63c71Sbeck 	    key_and_randomness + 32);
103508c63c71Sbeck 	kdf(failure_key, priv->fo_failure_secret, ciphertext, ciphertext_len);
103608c63c71Sbeck 	mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext,
103708c63c71Sbeck 	    sizeof(expected_ciphertext)), 0);
103808c63c71Sbeck 	for (i = 0; i < MLKEM_SHARED_SECRET_BYTES; i++) {
103908c63c71Sbeck 		out_shared_secret[i] = constant_time_select_8(mask,
104008c63c71Sbeck 		    key_and_randomness[i], failure_key[i]);
104108c63c71Sbeck 	}
104208c63c71Sbeck 
104308c63c71Sbeck 	return 1;
104408c63c71Sbeck }
104508c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_decap);
104608c63c71Sbeck 
104708c63c71Sbeck int
MLKEM1024_marshal_public_key(CBB * out,const struct MLKEM1024_public_key * public_key)104808c63c71Sbeck MLKEM1024_marshal_public_key(CBB *out,
104908c63c71Sbeck     const struct MLKEM1024_public_key *public_key)
105008c63c71Sbeck {
105108c63c71Sbeck 	return mlkem_marshal_public_key(out,
105208c63c71Sbeck 	    public_key_1024_from_external(public_key));
105308c63c71Sbeck }
105408c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_marshal_public_key);
105508c63c71Sbeck 
105608c63c71Sbeck /*
105708c63c71Sbeck  * mlkem_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
105808c63c71Sbeck  * the value of |pub->public_key_hash|.
105908c63c71Sbeck  */
106008c63c71Sbeck static int
mlkem_parse_public_key_no_hash(struct public_key * pub,CBS * in)106108c63c71Sbeck mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in)
106208c63c71Sbeck {
106308c63c71Sbeck 	CBS t_bytes;
106408c63c71Sbeck 
106508c63c71Sbeck 	if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) ||
106608c63c71Sbeck 	    !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime)) {
106708c63c71Sbeck 		return 0;
106808c63c71Sbeck 	}
106908c63c71Sbeck 	memcpy(pub->rho, CBS_data(in), sizeof(pub->rho));
107008c63c71Sbeck 	if (!CBS_skip(in, sizeof(pub->rho)))
107108c63c71Sbeck 		return 0;
107208c63c71Sbeck 	matrix_expand(&pub->m, pub->rho);
107308c63c71Sbeck 	return 1;
107408c63c71Sbeck }
107508c63c71Sbeck 
107608c63c71Sbeck int
MLKEM1024_parse_public_key(struct MLKEM1024_public_key * public_key,CBS * in)107708c63c71Sbeck MLKEM1024_parse_public_key(struct MLKEM1024_public_key *public_key, CBS *in)
107808c63c71Sbeck {
107908c63c71Sbeck 	struct public_key *pub = public_key_1024_from_external(public_key);
108008c63c71Sbeck 	CBS orig_in = *in;
108108c63c71Sbeck 
108208c63c71Sbeck 	if (!mlkem_parse_public_key_no_hash(pub, in) ||
108308c63c71Sbeck 	    CBS_len(in) != 0) {
108408c63c71Sbeck 		return 0;
108508c63c71Sbeck 	}
108608c63c71Sbeck 	hash_h(pub->public_key_hash, CBS_data(&orig_in), CBS_len(&orig_in));
108708c63c71Sbeck 	return 1;
108808c63c71Sbeck }
108908c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_parse_public_key);
109008c63c71Sbeck 
109108c63c71Sbeck int
MLKEM1024_marshal_private_key(CBB * out,const struct MLKEM1024_private_key * private_key)109208c63c71Sbeck MLKEM1024_marshal_private_key(CBB *out,
109308c63c71Sbeck     const struct MLKEM1024_private_key *private_key)
109408c63c71Sbeck {
109508c63c71Sbeck 	const struct private_key *const priv = private_key_1024_from_external(
109608c63c71Sbeck 	    private_key);
109708c63c71Sbeck 	uint8_t *s_output;
109808c63c71Sbeck 
109908c63c71Sbeck 	if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) {
110008c63c71Sbeck 		return 0;
110108c63c71Sbeck 	}
110208c63c71Sbeck 	vector_encode(s_output, &priv->s, kLog2Prime);
110308c63c71Sbeck 	if (!mlkem_marshal_public_key(out, &priv->pub) ||
110408c63c71Sbeck 	    !CBB_add_bytes(out, priv->pub.public_key_hash,
110508c63c71Sbeck 	    sizeof(priv->pub.public_key_hash)) ||
110608c63c71Sbeck 	    !CBB_add_bytes(out, priv->fo_failure_secret,
110708c63c71Sbeck 	    sizeof(priv->fo_failure_secret))) {
110808c63c71Sbeck 		return 0;
110908c63c71Sbeck 	}
111008c63c71Sbeck 	return 1;
111108c63c71Sbeck }
111208c63c71Sbeck 
111308c63c71Sbeck int
MLKEM1024_parse_private_key(struct MLKEM1024_private_key * out_private_key,CBS * in)111408c63c71Sbeck MLKEM1024_parse_private_key(struct MLKEM1024_private_key *out_private_key,
111508c63c71Sbeck     CBS *in)
111608c63c71Sbeck {
111708c63c71Sbeck 	struct private_key *const priv = private_key_1024_from_external(
111808c63c71Sbeck 	    out_private_key);
111908c63c71Sbeck 	CBS s_bytes;
112008c63c71Sbeck 
112108c63c71Sbeck 	if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) ||
112208c63c71Sbeck 	    !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) ||
112308c63c71Sbeck 	    !mlkem_parse_public_key_no_hash(&priv->pub, in)) {
112408c63c71Sbeck 		return 0;
112508c63c71Sbeck 	}
112608c63c71Sbeck 	memcpy(priv->pub.public_key_hash, CBS_data(in),
112708c63c71Sbeck 	    sizeof(priv->pub.public_key_hash));
112808c63c71Sbeck 	if (!CBS_skip(in, sizeof(priv->pub.public_key_hash)))
112908c63c71Sbeck 		return 0;
113008c63c71Sbeck 	memcpy(priv->fo_failure_secret, CBS_data(in),
113108c63c71Sbeck 	    sizeof(priv->fo_failure_secret));
113208c63c71Sbeck 	if (!CBS_skip(in, sizeof(priv->fo_failure_secret)))
113308c63c71Sbeck 		return 0;
113408c63c71Sbeck 	if (CBS_len(in) != 0)
113508c63c71Sbeck 		return 0;
113608c63c71Sbeck 
113708c63c71Sbeck 	return 1;
113808c63c71Sbeck }
113908c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_parse_private_key);
1140