xref: /openbsd/lib/libcrypto/mlkem/mlkem1024.c (revision 67abc7a1)
1*67abc7a1Stb /* $OpenBSD: mlkem1024.c,v 1.4 2024/12/18 10:55:21 tb Exp $ */
208c63c71Sbeck /*
308c63c71Sbeck  * Copyright (c) 2024, Google Inc.
408c63c71Sbeck  * Copyright (c) 2024, Bob Beck <beck@obtuse.com>
508c63c71Sbeck  *
608c63c71Sbeck  * Permission to use, copy, modify, and/or distribute this software for any
708c63c71Sbeck  * purpose with or without fee is hereby granted, provided that the above
808c63c71Sbeck  * copyright notice and this permission notice appear in all copies.
908c63c71Sbeck  *
1008c63c71Sbeck  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
1108c63c71Sbeck  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
1208c63c71Sbeck  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
1308c63c71Sbeck  * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
1408c63c71Sbeck  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
1508c63c71Sbeck  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
1608c63c71Sbeck  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
1708c63c71Sbeck  */
1808c63c71Sbeck 
1908c63c71Sbeck #include <openssl/mlkem.h>
2008c63c71Sbeck 
2108c63c71Sbeck #include <assert.h>
2208c63c71Sbeck #include <stdlib.h>
2308c63c71Sbeck #include <string.h>
2408c63c71Sbeck 
2508c63c71Sbeck #include "bytestring.h"
2608c63c71Sbeck 
2708c63c71Sbeck #include "sha3_internal.h"
2808c63c71Sbeck #include "mlkem_internal.h"
2908c63c71Sbeck #include "constant_time.h"
3008c63c71Sbeck #include "crypto_internal.h"
3108c63c71Sbeck 
3208c63c71Sbeck /* Remove later */
3308c63c71Sbeck #undef LCRYPTO_ALIAS
3408c63c71Sbeck #define LCRYPTO_ALIAS(A)
3508c63c71Sbeck 
3608c63c71Sbeck /*
3708c63c71Sbeck  * See
3808c63c71Sbeck  * https://csrc.nist.gov/pubs/fips/203/final
3908c63c71Sbeck  */
4008c63c71Sbeck 
4108c63c71Sbeck static void
4208c63c71Sbeck prf(uint8_t *out, size_t out_len, const uint8_t in[33])
4308c63c71Sbeck {
4408c63c71Sbeck 	sha3_ctx ctx;
4508c63c71Sbeck 	shake256_init(&ctx);
4608c63c71Sbeck 	shake_update(&ctx, in, 33);
4708c63c71Sbeck 	shake_xof(&ctx);
4808c63c71Sbeck 	shake_out(&ctx, out, out_len);
4908c63c71Sbeck }
5008c63c71Sbeck 
5108c63c71Sbeck /* Section 4.1 */
5208c63c71Sbeck static void
5308c63c71Sbeck hash_h(uint8_t out[32], const uint8_t *in, size_t len)
5408c63c71Sbeck {
5508c63c71Sbeck 	sha3_ctx ctx;
5608c63c71Sbeck 	sha3_init(&ctx, 32);
5708c63c71Sbeck 	sha3_update(&ctx, in, len);
5808c63c71Sbeck 	sha3_final(out, &ctx);
5908c63c71Sbeck }
6008c63c71Sbeck 
6108c63c71Sbeck static void
6208c63c71Sbeck hash_g(uint8_t out[64], const uint8_t *in, size_t len)
6308c63c71Sbeck {
6408c63c71Sbeck 	sha3_ctx ctx;
6508c63c71Sbeck 	sha3_init(&ctx, 64);
6608c63c71Sbeck 	sha3_update(&ctx, in, len);
6708c63c71Sbeck 	sha3_final(out, &ctx);
6808c63c71Sbeck }
6908c63c71Sbeck 
7008c63c71Sbeck /* this is called 'J' in the spec */
7108c63c71Sbeck static void
7208c63c71Sbeck kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES], const uint8_t failure_secret[32],
7308c63c71Sbeck     const uint8_t *in, size_t len)
7408c63c71Sbeck {
7508c63c71Sbeck 	sha3_ctx ctx;
7608c63c71Sbeck 	shake256_init(&ctx);
7708c63c71Sbeck 	shake_update(&ctx, failure_secret, 32);
7808c63c71Sbeck 	shake_update(&ctx, in, len);
7908c63c71Sbeck 	shake_xof(&ctx);
8008c63c71Sbeck 	shake_out(&ctx, out, MLKEM_SHARED_SECRET_BYTES);
8108c63c71Sbeck }
8208c63c71Sbeck 
/* Number of coefficients held by each |scalar|. */
#define DEGREE 256
/* Number of scalars per |vector| and per row/column of a |matrix|. */
#define RANK1024 4

/*
 * Barrett reduction constants used by reduce() and compress():
 * kBarrettMultiplier is floor(2^kBarrettShift / kPrime), i.e.
 * floor(16777216 / 3329) = 5039.
 */
static const size_t kBarrettMultiplier = 5039;
static const unsigned kBarrettShift = 24;
/* The modulus; every stored coefficient satisfies 0 <= c < kPrime. */
static const uint16_t kPrime = 3329;
/* Bits per fully encoded coefficient (2^12 = 4096 > kPrime). */
static const int kLog2Prime = 12;
/* (kPrime - 1) / 2 = 1664; the rounding threshold used by compress(). */
static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
/*
 * NOTE(review): presumably the d_u and d_v compression bit widths of the
 * ML-KEM-1024 parameter set in FIPS 203 -- confirm against the callers.
 */
static const int kDU1024 = 11;
static const int kDV1024 = 5;

/*
 * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
 * root of unity.
 */
static const uint16_t kInverseDegree = 3303;
/* 12 * 256 / 8 * 4 = 1536 bytes: RANK1024 scalars at 12 bits per coefficient. */
static const size_t kEncodedVectorSize =
    (/*kLog2Prime=*/12 * DEGREE / 8) * RANK1024;
/* 11 * 4 * 256 / 8 = 1408 bytes: RANK1024 scalars at 11 bits per coefficient. */
static const size_t kCompressedVectorSize = /*kDU1024=*/ 11 * RANK1024 * DEGREE /
    8;
10308c63c71Sbeck 
/* A polynomial stored as DEGREE coefficients. */
typedef struct scalar {
	/* On every function entry and exit, 0 <= c < kPrime. */
	uint16_t c[DEGREE];
} scalar;

/* RANK1024 scalars. */
typedef struct vector {
	scalar v[RANK1024];
} vector;

/*
 * A RANK1024 x RANK1024 grid of scalars. matrix_mult() treats the first index
 * as the row: out[i] accumulates v[i][j] * a[j].
 */
typedef struct matrix {
	scalar v[RANK1024][RANK1024];
} matrix;
11608c63c71Sbeck 
/*
 * This bit of Python will be referenced in some of the following comments.
 * bitreverse() reverses the low 7 bits of its argument:
 *
 * p = 3329
 *
 * def bitreverse(i):
 *     ret = 0
 *     for n in range(7):
 *         bit = i & 1
 *         ret <<= 1
 *         ret |= bit
 *         i >>= 1
 *     return ret
 */
13108c63c71Sbeck 
/*
 * kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)]
 *
 * Read by scalar_ntt() as kNTTRoots[i + step].
 */
static const uint16_t kNTTRoots[128] = {
	1,    1729, 2580, 3289, 2642, 630,  1897, 848,  1062, 1919, 193,  797,
	2786, 3260, 569,  1746, 296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
	1426, 2094, 535,  2882, 2393, 2879, 1974, 821,  289,  331,  3253, 1756,
	1197, 2304, 2277, 2055, 650,  1977, 2513, 632,  2865, 33,   1320, 1915,
	2319, 1435, 807,  452,  1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
	2474, 3110, 1227, 910,  17,   2761, 583,  2649, 1637, 723,  2288, 1100,
	1409, 2662, 3281, 233,  756,  2156, 3015, 3050, 1703, 1651, 2789, 1789,
	1847, 952,  1461, 2687, 939,  2308, 2437, 2388, 733,  2337, 268,  641,
	1584, 2298, 2037, 3220, 375,  2549, 2090, 1645, 1063, 319,  2773, 757,
	2099, 561,  2466, 2594, 2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
	1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,
};
14608c63c71Sbeck 
/*
 * kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)]
 *
 * Read by scalar_inverse_ntt() as kInverseNTTRoots[i + step].
 */
static const uint16_t kInverseNTTRoots[128] = {
	1,    1600, 40,   749,  2481, 1432, 2699, 687,  1583, 2760, 69,   543,
	2532, 3136, 1410, 2267, 2508, 1355, 450,  936,  447,  2794, 1235, 1903,
	1996, 1089, 3273, 283,  1853, 1990, 882,  3033, 2419, 2102, 219,  855,
	2681, 1848, 712,  682,  927,  1795, 461,  1891, 2877, 2522, 1894, 1010,
	1414, 2009, 3296, 464,  2697, 816,  1352, 2679, 1274, 1052, 1025, 2132,
	1573, 76,   2998, 3040, 1175, 2444, 394,  1219, 2300, 1455, 2117, 1607,
	2443, 554,  1179, 2186, 2303, 2926, 2237, 525,  735,  863,  2768, 1230,
	2572, 556,  3010, 2266, 1684, 1239, 780,  2954, 109,  1292, 1031, 1745,
	2688, 3061, 992,  2596, 941,  892,  1021, 2390, 642,  1868, 2377, 1482,
	1540, 540,  1678, 1626, 279,  314,  1173, 2573, 3096, 48,   667,  1920,
	2229, 1041, 2606, 1692, 680,  2746, 568,  3312,
};
16108c63c71Sbeck 
/*
 * kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)]
 *
 * Read by scalar_mult() as kModRoots[i] for the i-th coefficient pair.
 */
static const uint16_t kModRoots[128] = {
	17,   3312, 2761, 568,  583,  2746, 2649, 680,  1637, 1692, 723,  2606,
	2288, 1041, 1100, 2229, 1409, 1920, 2662, 667,  3281, 48,   233,  3096,
	756,  2573, 2156, 1173, 3015, 314,  3050, 279,  1703, 1626, 1651, 1678,
	2789, 540,  1789, 1540, 1847, 1482, 952,  2377, 1461, 1868, 2687, 642,
	939,  2390, 2308, 1021, 2437, 892,  2388, 941,  733,  2596, 2337, 992,
	268,  3061, 641,  2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
	375,  2954, 2549, 780,  2090, 1239, 1645, 1684, 1063, 2266, 319,  3010,
	2773, 556,  757,  2572, 2099, 1230, 561,  2768, 2466, 863,  2594, 735,
	2804, 525,  1092, 2237, 403,  2926, 1026, 2303, 1143, 2186, 2150, 1179,
	2775, 554,  886,  2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
	2110, 1219, 2935, 394,  885,  2444, 2154, 1175,
};
17608c63c71Sbeck 
/*
 * reduce_once reduces 0 <= x < 2*kPrime, mod kPrime, without branching on x.
 *
 * |subtracted| is x - kPrime truncated to 16 bits. Because x < 2*kPrime, its
 * top bit is set exactly when x < kPrime (the subtraction wrapped), so |mask|
 * is all ones when |x| is already reduced and zero otherwise.
 */
static uint16_t
reduce_once(uint16_t x)
{
	assert(x < 2 * kPrime);
	const uint16_t subtracted = x - kPrime;
	uint16_t mask = 0u - (subtracted >> 15);

	/*
	 * Although this is a constant-time select, we omit a value barrier here.
	 * Value barriers impede auto-vectorization (likely because it forces the
	 * value to transit through a general-purpose register). On AArch64, this
	 * is a difference of 2x.
	 *
	 * We usually add value barriers to selects because Clang turns
	 * consecutive selects with the same condition into a branch instead of
	 * CMOV/CSEL. This condition does not occur in ML-KEM, so omitting it
	 * seems to be safe so far -- but see
	 * |scalar_centered_binomial_distribution_eta_2_with_prf|.
	 */
	return (mask & x) | (~mask & subtracted);
}
19908c63c71Sbeck 
20008c63c71Sbeck /*
20108c63c71Sbeck  * constant time reduce x mod kPrime using Barrett reduction. x must be less
20208c63c71Sbeck  * than kPrime + 2×kPrime².
20308c63c71Sbeck  */
20408c63c71Sbeck static uint16_t
20508c63c71Sbeck reduce(uint32_t x)
20608c63c71Sbeck {
20708c63c71Sbeck 	uint64_t product = (uint64_t)x * kBarrettMultiplier;
20808c63c71Sbeck 	uint32_t quotient = (uint32_t)(product >> kBarrettShift);
20908c63c71Sbeck 	uint32_t remainder = x - quotient * kPrime;
21008c63c71Sbeck 
21108c63c71Sbeck 	assert(x < kPrime + 2u * kPrime * kPrime);
21208c63c71Sbeck 	return reduce_once(remainder);
21308c63c71Sbeck }
21408c63c71Sbeck 
21508c63c71Sbeck static void
21608c63c71Sbeck scalar_zero(scalar *out)
21708c63c71Sbeck {
21808c63c71Sbeck 	memset(out, 0, sizeof(*out));
21908c63c71Sbeck }
22008c63c71Sbeck 
22108c63c71Sbeck static void
22208c63c71Sbeck vector_zero(vector *out)
22308c63c71Sbeck {
22408c63c71Sbeck 	memset(out, 0, sizeof(*out));
22508c63c71Sbeck }
22608c63c71Sbeck 
22708c63c71Sbeck /*
22808c63c71Sbeck  * In place number theoretic transform of a given scalar.
22908c63c71Sbeck  * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this
23008c63c71Sbeck  * transform leaves off the last iteration of the usual FFT code, with the 128
23108c63c71Sbeck  * relevant roots of unity being stored in |kNTTRoots|. This means the output
23208c63c71Sbeck  * should be seen as 128 elements in GF(3329^2), with the coefficients of the
23308c63c71Sbeck  * elements being consecutive entries in |s->c|.
23408c63c71Sbeck  */
23508c63c71Sbeck static void
23608c63c71Sbeck scalar_ntt(scalar *s)
23708c63c71Sbeck {
23808c63c71Sbeck 	int offset = DEGREE;
23908c63c71Sbeck 	int step;
24008c63c71Sbeck 	/*
24108c63c71Sbeck 	 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
24208c63c71Sbeck 	 * with Clang 14 on Aarch64.
24308c63c71Sbeck 	 */
24408c63c71Sbeck 	for (step = 1; step < DEGREE / 2; step <<= 1) {
24508c63c71Sbeck 		int i, j, k = 0;
24608c63c71Sbeck 
24708c63c71Sbeck 		offset >>= 1;
24808c63c71Sbeck 		for (i = 0; i < step; i++) {
24908c63c71Sbeck 			const uint32_t step_root = kNTTRoots[i + step];
25008c63c71Sbeck 
25108c63c71Sbeck 			for (j = k; j < k + offset; j++) {
25208c63c71Sbeck 				uint16_t odd, even;
25308c63c71Sbeck 
25408c63c71Sbeck 				odd = reduce(step_root * s->c[j + offset]);
25508c63c71Sbeck 				even = s->c[j];
25608c63c71Sbeck 				s->c[j] = reduce_once(odd + even);
25708c63c71Sbeck 				s->c[j + offset] = reduce_once(even - odd +
25808c63c71Sbeck 				    kPrime);
25908c63c71Sbeck 			}
26008c63c71Sbeck 			k += 2 * offset;
26108c63c71Sbeck 		}
26208c63c71Sbeck 	}
26308c63c71Sbeck }
26408c63c71Sbeck 
26508c63c71Sbeck static void
26608c63c71Sbeck vector_ntt(vector *a)
26708c63c71Sbeck {
26808c63c71Sbeck 	int i;
26908c63c71Sbeck 
27008c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
27108c63c71Sbeck 		scalar_ntt(&a->v[i]);
27208c63c71Sbeck 	}
27308c63c71Sbeck }
27408c63c71Sbeck 
27508c63c71Sbeck /*
27608c63c71Sbeck  * In place inverse number theoretic transform of a given scalar, with pairs of
27708c63c71Sbeck  * entries of s->v being interpreted as elements of GF(3329^2). Just as with the
27808c63c71Sbeck  * number theoretic transform, this leaves off the first step of the normal iFFT
27908c63c71Sbeck  * to account for the fact that 3329 does not have a 512th root of unity, using
28008c63c71Sbeck  * the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
28108c63c71Sbeck  */
28208c63c71Sbeck static void
28308c63c71Sbeck scalar_inverse_ntt(scalar *s)
28408c63c71Sbeck {
28508c63c71Sbeck 	int i, j, k, offset, step = DEGREE / 2;
28608c63c71Sbeck 
28708c63c71Sbeck 	/*
28808c63c71Sbeck 	 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
28908c63c71Sbeck 	 * with Clang 14 on Aarch64.
29008c63c71Sbeck 	 */
29108c63c71Sbeck 	for (offset = 2; offset < DEGREE; offset <<= 1) {
29208c63c71Sbeck 		step >>= 1;
29308c63c71Sbeck 		k = 0;
29408c63c71Sbeck 		for (i = 0; i < step; i++) {
29508c63c71Sbeck 			uint32_t step_root = kInverseNTTRoots[i + step];
29608c63c71Sbeck 			for (j = k; j < k + offset; j++) {
29708c63c71Sbeck 				uint16_t odd, even;
29808c63c71Sbeck 				odd = s->c[j + offset];
29908c63c71Sbeck 				even = s->c[j];
30008c63c71Sbeck 				s->c[j] = reduce_once(odd + even);
30108c63c71Sbeck 				s->c[j + offset] = reduce(step_root *
30208c63c71Sbeck 				    (even - odd + kPrime));
30308c63c71Sbeck 			}
30408c63c71Sbeck 			k += 2 * offset;
30508c63c71Sbeck 		}
30608c63c71Sbeck 	}
30708c63c71Sbeck 	for (i = 0; i < DEGREE; i++) {
30808c63c71Sbeck 		s->c[i] = reduce(s->c[i] * kInverseDegree);
30908c63c71Sbeck 	}
31008c63c71Sbeck }
31108c63c71Sbeck 
31208c63c71Sbeck static void
31308c63c71Sbeck vector_inverse_ntt(vector *a)
31408c63c71Sbeck {
31508c63c71Sbeck 	int i;
31608c63c71Sbeck 
31708c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
31808c63c71Sbeck 		scalar_inverse_ntt(&a->v[i]);
31908c63c71Sbeck 	}
32008c63c71Sbeck }
32108c63c71Sbeck 
32208c63c71Sbeck static void
32308c63c71Sbeck scalar_add(scalar *lhs, const scalar *rhs)
32408c63c71Sbeck {
32508c63c71Sbeck 	int i;
32608c63c71Sbeck 
32708c63c71Sbeck 	for (i = 0; i < DEGREE; i++) {
32808c63c71Sbeck 		lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
32908c63c71Sbeck 	}
33008c63c71Sbeck }
33108c63c71Sbeck 
33208c63c71Sbeck static void
33308c63c71Sbeck scalar_sub(scalar *lhs, const scalar *rhs)
33408c63c71Sbeck {
33508c63c71Sbeck 	int i;
33608c63c71Sbeck 
33708c63c71Sbeck 	for (i = 0; i < DEGREE; i++) {
33808c63c71Sbeck 		lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
33908c63c71Sbeck 	}
34008c63c71Sbeck }
34108c63c71Sbeck 
34208c63c71Sbeck /*
34308c63c71Sbeck  * Multiplying two scalars in the number theoretically transformed state. Since
34408c63c71Sbeck  * 3329 does not have a 512th root of unity, this means we have to interpret
34508c63c71Sbeck  * the 2*ith and (2*i+1)th entries of the scalar as elements of GF(3329)[X]/(X^2
34608c63c71Sbeck  * - 17^(2*bitreverse(i)+1)) The value of 17^(2*bitreverse(i)+1) mod 3329 is
34708c63c71Sbeck  * stored in the precomputed |kModRoots| table. Note that our Barrett transform
34808c63c71Sbeck  * only allows us to multipy two reduced numbers together, so we need some
34908c63c71Sbeck  * intermediate reduction steps, even if an uint64_t could hold 3 multiplied
35008c63c71Sbeck  * numbers.
35108c63c71Sbeck  */
35208c63c71Sbeck static void
35308c63c71Sbeck scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs)
35408c63c71Sbeck {
35508c63c71Sbeck 	int i;
35608c63c71Sbeck 
35708c63c71Sbeck 	for (i = 0; i < DEGREE / 2; i++) {
35808c63c71Sbeck 		uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
35908c63c71Sbeck 		uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] *
36008c63c71Sbeck 		    rhs->c[2 * i + 1];
36108c63c71Sbeck 		uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
36208c63c71Sbeck 		uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
36308c63c71Sbeck 
36408c63c71Sbeck 		out->c[2 * i] =
36508c63c71Sbeck 		    reduce(real_real +
36608c63c71Sbeck 		    (uint32_t)reduce(img_img) * kModRoots[i]);
36708c63c71Sbeck 		out->c[2 * i + 1] = reduce(img_real + real_img);
36808c63c71Sbeck 	}
36908c63c71Sbeck }
37008c63c71Sbeck 
37108c63c71Sbeck static void
37208c63c71Sbeck vector_add(vector *lhs, const vector *rhs)
37308c63c71Sbeck {
37408c63c71Sbeck 	int i;
37508c63c71Sbeck 
37608c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
37708c63c71Sbeck 		scalar_add(&lhs->v[i], &rhs->v[i]);
37808c63c71Sbeck 	}
37908c63c71Sbeck }
38008c63c71Sbeck 
38108c63c71Sbeck static void
38208c63c71Sbeck matrix_mult(vector *out, const matrix *m, const vector *a)
38308c63c71Sbeck {
38408c63c71Sbeck 	int i, j;
38508c63c71Sbeck 
38608c63c71Sbeck 	vector_zero(out);
38708c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
38808c63c71Sbeck 		for (j = 0; j < RANK1024; j++) {
38908c63c71Sbeck 			scalar product;
39008c63c71Sbeck 
39108c63c71Sbeck 			scalar_mult(&product, &m->v[i][j], &a->v[j]);
39208c63c71Sbeck 			scalar_add(&out->v[i], &product);
39308c63c71Sbeck 		}
39408c63c71Sbeck 	}
39508c63c71Sbeck }
39608c63c71Sbeck 
39708c63c71Sbeck static void
39808c63c71Sbeck matrix_mult_transpose(vector *out, const matrix *m,
39908c63c71Sbeck     const vector *a)
40008c63c71Sbeck {
40108c63c71Sbeck 	int i, j;
40208c63c71Sbeck 
40308c63c71Sbeck 	vector_zero(out);
40408c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
40508c63c71Sbeck 		for (j = 0; j < RANK1024; j++) {
40608c63c71Sbeck 			scalar product;
40708c63c71Sbeck 
40808c63c71Sbeck 			scalar_mult(&product, &m->v[j][i], &a->v[j]);
40908c63c71Sbeck 			scalar_add(&out->v[i], &product);
41008c63c71Sbeck 		}
41108c63c71Sbeck 	}
41208c63c71Sbeck }
41308c63c71Sbeck 
41408c63c71Sbeck static void
41508c63c71Sbeck scalar_inner_product(scalar *out, const vector *lhs,
41608c63c71Sbeck     const vector *rhs)
41708c63c71Sbeck {
41808c63c71Sbeck 	int i;
41908c63c71Sbeck 	scalar_zero(out);
42008c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
42108c63c71Sbeck 		scalar product;
42208c63c71Sbeck 
42308c63c71Sbeck 		scalar_mult(&product, &lhs->v[i], &rhs->v[i]);
42408c63c71Sbeck 		scalar_add(out, &product);
42508c63c71Sbeck 	}
42608c63c71Sbeck }
42708c63c71Sbeck 
42808c63c71Sbeck /*
42908c63c71Sbeck  * Algorithm 6 of spec. Rejection samples a Keccak stream to get uniformly
43008c63c71Sbeck  * distributed elements. This is used for matrix expansion and only operates on
43108c63c71Sbeck  * public inputs.
43208c63c71Sbeck  */
43308c63c71Sbeck static void
43408c63c71Sbeck scalar_from_keccak_vartime(scalar *out, sha3_ctx *keccak_ctx)
43508c63c71Sbeck {
43608c63c71Sbeck 	int i, done = 0;
43708c63c71Sbeck 
43808c63c71Sbeck 	while (done < DEGREE) {
43908c63c71Sbeck 		uint8_t block[168];
44008c63c71Sbeck 
44108c63c71Sbeck 		shake_out(keccak_ctx, block, sizeof(block));
44208c63c71Sbeck 		for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
44308c63c71Sbeck 			uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
44408c63c71Sbeck 			uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
44508c63c71Sbeck 
44608c63c71Sbeck 			if (d1 < kPrime) {
44708c63c71Sbeck 				out->c[done++] = d1;
44808c63c71Sbeck 			}
44908c63c71Sbeck 			if (d2 < kPrime && done < DEGREE) {
45008c63c71Sbeck 				out->c[done++] = d2;
45108c63c71Sbeck 			}
45208c63c71Sbeck 		}
45308c63c71Sbeck 	}
45408c63c71Sbeck }
45508c63c71Sbeck 
/*
 * Algorithm 7 of the spec, with eta fixed to two and the PRF call
 * included. Creates binomially distributed elements by sampling 2*|eta| bits,
 * and setting the coefficient to the count of the first |eta| bits minus the
 * count of the second |eta| bits, resulting in a centered binomial
 * distribution. Since eta is two this gives -2 and 2 each with probability
 * 1/16, -1 and 1 each with probability 1/4, and 0 with probability 3/8.
 *
 * |input| is passed to prf() to produce 128 bytes of entropy; each byte
 * supplies two coefficients (low nibble first). Negative results are stored
 * as kPrime + value so that every coefficient is in [0, kPrime).
 */
static void
scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out,
    const uint8_t input[33])
{
	uint8_t entropy[128];
	int i;

	CTASSERT(sizeof(entropy) == 2 * /*kEta=*/ 2 * DEGREE / 8);
	prf(entropy, sizeof(entropy), input);

	for (i = 0; i < DEGREE; i += 2) {
		uint8_t byte = entropy[i / 2];
		uint16_t mask;
		uint16_t value = (byte & 1) + ((byte >> 1) & 1);

		value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);

		/*
		 * Add |kPrime| if |value| underflowed. See |reduce_once| for a
		 * discussion on why the value barrier is omitted. While this
		 * could have been written reduce_once(value + kPrime), this is
		 * one extra addition and the small range of |value| tempts some
		 * versions of Clang to emit a branch.
		 */
		mask = 0u - (value >> 15);
		out->c[i] = ((value + kPrime) & mask) | (value & ~mask);

		/* The high nibble yields the next coefficient. */
		byte >>= 4;
		value = (byte & 1) + ((byte >> 1) & 1);
		value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
		/* See above. */
		mask = 0u - (value >> 15);
		out->c[i + 1] = ((value + kPrime) & mask) | (value & ~mask);
	}
}
49908c63c71Sbeck 
50008c63c71Sbeck /*
50108c63c71Sbeck  * Generates a secret vector by using
50208c63c71Sbeck  * |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed
50308c63c71Sbeck  * appending and incrementing |counter| for entry of the vector.
50408c63c71Sbeck  */
50508c63c71Sbeck static void
50608c63c71Sbeck vector_generate_secret_eta_2(vector *out, uint8_t *counter,
50708c63c71Sbeck     const uint8_t seed[32])
50808c63c71Sbeck {
50908c63c71Sbeck 	uint8_t input[33];
51008c63c71Sbeck 	int i;
51108c63c71Sbeck 
51208c63c71Sbeck 	memcpy(input, seed, 32);
51308c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
51408c63c71Sbeck 		input[32] = (*counter)++;
51508c63c71Sbeck 		scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i],
51608c63c71Sbeck 		    input);
51708c63c71Sbeck 	}
51808c63c71Sbeck }
51908c63c71Sbeck 
52008c63c71Sbeck /* Expands the matrix of a seed for key generation and for encaps-CPA. */
52108c63c71Sbeck static void
52208c63c71Sbeck matrix_expand(matrix *out, const uint8_t rho[32])
52308c63c71Sbeck {
52408c63c71Sbeck 	uint8_t input[34];
52508c63c71Sbeck 	int i, j;
52608c63c71Sbeck 
52708c63c71Sbeck 	memcpy(input, rho, 32);
52808c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
52908c63c71Sbeck 		for (j = 0; j < RANK1024; j++) {
53008c63c71Sbeck 			sha3_ctx keccak_ctx;
53108c63c71Sbeck 
53208c63c71Sbeck 			input[32] = i;
53308c63c71Sbeck 			input[33] = j;
53408c63c71Sbeck 			shake128_init(&keccak_ctx);
53508c63c71Sbeck 			shake_update(&keccak_ctx, input, sizeof(input));
53608c63c71Sbeck 			shake_xof(&keccak_ctx);
53708c63c71Sbeck 			scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
53808c63c71Sbeck 		}
53908c63c71Sbeck 	}
54008c63c71Sbeck }
54108c63c71Sbeck 
/* kMasks[n - 1] has the low n bits set, for 1 <= n <= 8. */
static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
	0x1f, 0x3f, 0x7f, 0xff};
54408c63c71Sbeck 
54508c63c71Sbeck static void
54608c63c71Sbeck scalar_encode(uint8_t *out, const scalar *s, int bits)
54708c63c71Sbeck {
54808c63c71Sbeck 	uint8_t out_byte = 0;
54908c63c71Sbeck 	int i, out_byte_bits = 0;
55008c63c71Sbeck 
55108c63c71Sbeck 	assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
55208c63c71Sbeck 	for (i = 0; i < DEGREE; i++) {
55308c63c71Sbeck 		uint16_t element = s->c[i];
55408c63c71Sbeck 		int element_bits_done = 0;
55508c63c71Sbeck 
55608c63c71Sbeck 		while (element_bits_done < bits) {
55708c63c71Sbeck 			int chunk_bits = bits - element_bits_done;
55808c63c71Sbeck 			int out_bits_remaining = 8 - out_byte_bits;
55908c63c71Sbeck 
56008c63c71Sbeck 			if (chunk_bits >= out_bits_remaining) {
56108c63c71Sbeck 				chunk_bits = out_bits_remaining;
56208c63c71Sbeck 				out_byte |= (element &
56308c63c71Sbeck 				    kMasks[chunk_bits - 1]) << out_byte_bits;
56408c63c71Sbeck 				*out = out_byte;
56508c63c71Sbeck 				out++;
56608c63c71Sbeck 				out_byte_bits = 0;
56708c63c71Sbeck 				out_byte = 0;
56808c63c71Sbeck 			} else {
56908c63c71Sbeck 				out_byte |= (element &
57008c63c71Sbeck 				    kMasks[chunk_bits - 1]) << out_byte_bits;
57108c63c71Sbeck 				out_byte_bits += chunk_bits;
57208c63c71Sbeck 			}
57308c63c71Sbeck 
57408c63c71Sbeck 			element_bits_done += chunk_bits;
57508c63c71Sbeck 			element >>= chunk_bits;
57608c63c71Sbeck 		}
57708c63c71Sbeck 	}
57808c63c71Sbeck 
57908c63c71Sbeck 	if (out_byte_bits > 0) {
58008c63c71Sbeck 		*out = out_byte;
58108c63c71Sbeck 	}
58208c63c71Sbeck }
58308c63c71Sbeck 
58408c63c71Sbeck /* scalar_encode_1 is |scalar_encode| specialised for |bits| == 1. */
58508c63c71Sbeck static void
58608c63c71Sbeck scalar_encode_1(uint8_t out[32], const scalar *s)
58708c63c71Sbeck {
58808c63c71Sbeck 	int i, j;
58908c63c71Sbeck 
59008c63c71Sbeck 	for (i = 0; i < DEGREE; i += 8) {
59108c63c71Sbeck 		uint8_t out_byte = 0;
59208c63c71Sbeck 
59308c63c71Sbeck 		for (j = 0; j < 8; j++) {
59408c63c71Sbeck 			out_byte |= (s->c[i + j] & 1) << j;
59508c63c71Sbeck 		}
59608c63c71Sbeck 		*out = out_byte;
59708c63c71Sbeck 		out++;
59808c63c71Sbeck 	}
59908c63c71Sbeck }
60008c63c71Sbeck 
60108c63c71Sbeck /*
60208c63c71Sbeck  * Encodes an entire vector into 32*|RANK1024|*|bits| bytes. Note that since 256
60308c63c71Sbeck  * (DEGREE) is divisible by 8, the individual vector entries will always fill a
60408c63c71Sbeck  * whole number of bytes, so we do not need to worry about bit packing here.
60508c63c71Sbeck  */
60608c63c71Sbeck static void
60708c63c71Sbeck vector_encode(uint8_t *out, const vector *a, int bits)
60808c63c71Sbeck {
60908c63c71Sbeck 	int i;
61008c63c71Sbeck 
61108c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
61208c63c71Sbeck 		scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
61308c63c71Sbeck 	}
61408c63c71Sbeck }
61508c63c71Sbeck 
61608c63c71Sbeck /*
61708c63c71Sbeck  * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
61808c63c71Sbeck  * |out|. It returns one on success and zero if any parsed value is >=
61908c63c71Sbeck  * |kPrime|.
62008c63c71Sbeck  */
62108c63c71Sbeck static int
62208c63c71Sbeck scalar_decode(scalar *out, const uint8_t *in, int bits)
62308c63c71Sbeck {
62408c63c71Sbeck 	uint8_t in_byte = 0;
62508c63c71Sbeck 	int i, in_byte_bits_left = 0;
62608c63c71Sbeck 
62708c63c71Sbeck 	assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);
62808c63c71Sbeck 
62908c63c71Sbeck 	for (i = 0; i < DEGREE; i++) {
63008c63c71Sbeck 		uint16_t element = 0;
63108c63c71Sbeck 		int element_bits_done = 0;
63208c63c71Sbeck 
63308c63c71Sbeck 		while (element_bits_done < bits) {
63408c63c71Sbeck 			int chunk_bits = bits - element_bits_done;
63508c63c71Sbeck 
63608c63c71Sbeck 			if (in_byte_bits_left == 0) {
63708c63c71Sbeck 				in_byte = *in;
63808c63c71Sbeck 				in++;
63908c63c71Sbeck 				in_byte_bits_left = 8;
64008c63c71Sbeck 			}
64108c63c71Sbeck 
64208c63c71Sbeck 			if (chunk_bits > in_byte_bits_left) {
64308c63c71Sbeck 				chunk_bits = in_byte_bits_left;
64408c63c71Sbeck 			}
64508c63c71Sbeck 
64608c63c71Sbeck 			element |= (in_byte & kMasks[chunk_bits - 1]) <<
64708c63c71Sbeck 			    element_bits_done;
64808c63c71Sbeck 			in_byte_bits_left -= chunk_bits;
64908c63c71Sbeck 			in_byte >>= chunk_bits;
65008c63c71Sbeck 
65108c63c71Sbeck 			element_bits_done += chunk_bits;
65208c63c71Sbeck 		}
65308c63c71Sbeck 
65408c63c71Sbeck 		if (element >= kPrime) {
65508c63c71Sbeck 			return 0;
65608c63c71Sbeck 		}
65708c63c71Sbeck 		out->c[i] = element;
65808c63c71Sbeck 	}
65908c63c71Sbeck 
66008c63c71Sbeck 	return 1;
66108c63c71Sbeck }
66208c63c71Sbeck 
66308c63c71Sbeck /* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */
66408c63c71Sbeck static void
66508c63c71Sbeck scalar_decode_1(scalar *out, const uint8_t in[32])
66608c63c71Sbeck {
66708c63c71Sbeck 	int i, j;
66808c63c71Sbeck 
66908c63c71Sbeck 	for (i = 0; i < DEGREE; i += 8) {
67008c63c71Sbeck 		uint8_t in_byte = *in;
67108c63c71Sbeck 
67208c63c71Sbeck 		in++;
67308c63c71Sbeck 		for (j = 0; j < 8; j++) {
67408c63c71Sbeck 			out->c[i + j] = in_byte & 1;
67508c63c71Sbeck 			in_byte >>= 1;
67608c63c71Sbeck 		}
67708c63c71Sbeck 	}
67808c63c71Sbeck }
67908c63c71Sbeck 
68008c63c71Sbeck /*
68108c63c71Sbeck  * Decodes 32*|RANK1024|*|bits| bytes from |in| into |out|. It returns one on
68208c63c71Sbeck  * success or zero if any parsed value is >= |kPrime|.
68308c63c71Sbeck  */
68408c63c71Sbeck static int
68508c63c71Sbeck vector_decode(vector *out, const uint8_t *in, int bits)
68608c63c71Sbeck {
68708c63c71Sbeck 	int i;
68808c63c71Sbeck 
68908c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
69008c63c71Sbeck 		if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8,
69108c63c71Sbeck 		    bits)) {
69208c63c71Sbeck 			return 0;
69308c63c71Sbeck 		}
69408c63c71Sbeck 	}
69508c63c71Sbeck 	return 1;
69608c63c71Sbeck }
69708c63c71Sbeck 
69808c63c71Sbeck /*
69908c63c71Sbeck  * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
70008c63c71Sbeck  * numbers close to each other together. The formula used is
70108c63c71Sbeck  * round(2^|bits|/kPrime*x) mod 2^|bits|.
70208c63c71Sbeck  * Uses Barrett reduction to achieve constant time. Since we need both the
70308c63c71Sbeck  * remainder (for rounding) and the quotient (as the result), we cannot use
70408c63c71Sbeck  * |reduce| here, but need to do the Barrett reduction directly.
70508c63c71Sbeck  */
70608c63c71Sbeck static uint16_t
70708c63c71Sbeck compress(uint16_t x, int bits)
70808c63c71Sbeck {
70908c63c71Sbeck 	uint32_t shifted = (uint32_t)x << bits;
71008c63c71Sbeck 	uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
71108c63c71Sbeck 	uint32_t quotient = (uint32_t)(product >> kBarrettShift);
71208c63c71Sbeck 	uint32_t remainder = shifted - quotient * kPrime;
71308c63c71Sbeck 
71408c63c71Sbeck 	/*
71508c63c71Sbeck 	 * Adjust the quotient to round correctly:
71608c63c71Sbeck 	 * 0 <= remainder <= kHalfPrime round to 0
71708c63c71Sbeck 	 * kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
71808c63c71Sbeck 	 * kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
71908c63c71Sbeck 	 */
72008c63c71Sbeck 	assert(remainder < 2u * kPrime);
72108c63c71Sbeck 	quotient += 1 & constant_time_lt(kHalfPrime, remainder);
72208c63c71Sbeck 	quotient += 1 & constant_time_lt(kPrime + kHalfPrime, remainder);
72308c63c71Sbeck 	return quotient & ((1 << bits) - 1);
72408c63c71Sbeck }
72508c63c71Sbeck 
72608c63c71Sbeck /*
72708c63c71Sbeck  * Decompresses |x| by using an equi-distant representative. The formula is
72808c63c71Sbeck  * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
72908c63c71Sbeck  * implement this logic using only bit operations.
73008c63c71Sbeck  */
73108c63c71Sbeck static uint16_t
73208c63c71Sbeck decompress(uint16_t x, int bits)
73308c63c71Sbeck {
73408c63c71Sbeck 	uint32_t product = (uint32_t)x * kPrime;
73508c63c71Sbeck 	uint32_t power = 1 << bits;
73608c63c71Sbeck 	/* This is |product| % power, since |power| is a power of 2. */
73708c63c71Sbeck 	uint32_t remainder = product & (power - 1);
73808c63c71Sbeck 	/* This is |product| / power, since |power| is a power of 2. */
73908c63c71Sbeck 	uint32_t lower = product >> bits;
74008c63c71Sbeck 
74108c63c71Sbeck 	/*
74208c63c71Sbeck 	 * The rounding logic works since the first half of numbers mod |power| have a
74308c63c71Sbeck 	 * 0 as first bit, and the second half has a 1 as first bit, since |power| is
74408c63c71Sbeck 	 * a power of 2. As a 12 bit number, |remainder| is always positive, so we
74508c63c71Sbeck 	 * will shift in 0s for a right shift.
74608c63c71Sbeck 	 */
74708c63c71Sbeck 	return lower + (remainder >> (bits - 1));
74808c63c71Sbeck }
74908c63c71Sbeck 
75008c63c71Sbeck static void
75108c63c71Sbeck scalar_compress(scalar *s, int bits)
75208c63c71Sbeck {
75308c63c71Sbeck 	int i;
75408c63c71Sbeck 
75508c63c71Sbeck 	for (i = 0; i < DEGREE; i++) {
75608c63c71Sbeck 		s->c[i] = compress(s->c[i], bits);
75708c63c71Sbeck 	}
75808c63c71Sbeck }
75908c63c71Sbeck 
76008c63c71Sbeck static void
76108c63c71Sbeck scalar_decompress(scalar *s, int bits)
76208c63c71Sbeck {
76308c63c71Sbeck 	int i;
76408c63c71Sbeck 
76508c63c71Sbeck 	for (i = 0; i < DEGREE; i++) {
76608c63c71Sbeck 		s->c[i] = decompress(s->c[i], bits);
76708c63c71Sbeck 	}
76808c63c71Sbeck }
76908c63c71Sbeck 
77008c63c71Sbeck static void
77108c63c71Sbeck vector_compress(vector *a, int bits)
77208c63c71Sbeck {
77308c63c71Sbeck 	int i;
77408c63c71Sbeck 
77508c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
77608c63c71Sbeck 		scalar_compress(&a->v[i], bits);
77708c63c71Sbeck 	}
77808c63c71Sbeck }
77908c63c71Sbeck 
78008c63c71Sbeck static void
78108c63c71Sbeck vector_decompress(vector *a, int bits)
78208c63c71Sbeck {
78308c63c71Sbeck 	int i;
78408c63c71Sbeck 
78508c63c71Sbeck 	for (i = 0; i < RANK1024; i++) {
78608c63c71Sbeck 		scalar_decompress(&a->v[i], bits);
78708c63c71Sbeck 	}
78808c63c71Sbeck }
78908c63c71Sbeck 
79008c63c71Sbeck struct public_key {
79108c63c71Sbeck 	vector t;
79208c63c71Sbeck 	uint8_t rho[32];
79308c63c71Sbeck 	uint8_t public_key_hash[32];
79408c63c71Sbeck 	matrix m;
79508c63c71Sbeck };
79608c63c71Sbeck 
/*
 * Reinterprets the caller-visible public key object as the internal
 * |struct public_key|. Note that the const qualifier is dropped by the cast.
 * NOTE(review): relies on the external struct being large and aligned enough
 * to hold a |struct public_key| - confirm against the public header.
 */
static struct public_key *
public_key_1024_from_external(const struct MLKEM1024_public_key *external)
{
	struct public_key *internal = (struct public_key *)external;

	return internal;
}
80208c63c71Sbeck 
80308c63c71Sbeck struct private_key {
80408c63c71Sbeck 	struct public_key pub;
80508c63c71Sbeck 	vector s;
80608c63c71Sbeck 	uint8_t fo_failure_secret[32];
80708c63c71Sbeck };
80808c63c71Sbeck 
/*
 * Reinterprets the caller-visible private key object as the internal
 * |struct private_key|. Note that the const qualifier is dropped by the cast.
 * NOTE(review): relies on the external struct being large and aligned enough
 * to hold a |struct private_key| - confirm against the public header.
 */
static struct private_key *
private_key_1024_from_external(const struct MLKEM1024_private_key *external)
{
	struct private_key *internal = (struct private_key *)external;

	return internal;
}
81408c63c71Sbeck 
81508c63c71Sbeck /*
81608c63c71Sbeck  * Calls |MLKEM1024_generate_key_external_entropy| with random bytes from
81708c63c71Sbeck  * |RAND_bytes|.
81808c63c71Sbeck  */
81908c63c71Sbeck void
82008c63c71Sbeck MLKEM1024_generate_key(uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
82108c63c71Sbeck     uint8_t optional_out_seed[MLKEM_SEED_BYTES],
82208c63c71Sbeck     struct MLKEM1024_private_key *out_private_key)
82308c63c71Sbeck {
82408c63c71Sbeck 	uint8_t entropy_buf[MLKEM_SEED_BYTES];
82508c63c71Sbeck 	uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed :
82608c63c71Sbeck 	    entropy_buf;
82708c63c71Sbeck 
82808c63c71Sbeck 	arc4random_buf(entropy, MLKEM_SEED_BYTES);
82908c63c71Sbeck 	MLKEM1024_generate_key_external_entropy(out_encoded_public_key,
83008c63c71Sbeck 	    out_private_key, entropy);
83108c63c71Sbeck }
83208c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_generate_key);
83308c63c71Sbeck 
83408c63c71Sbeck int
83508c63c71Sbeck MLKEM1024_private_key_from_seed(struct MLKEM1024_private_key *out_private_key,
83608c63c71Sbeck     const uint8_t *seed, size_t seed_len)
83708c63c71Sbeck {
83808c63c71Sbeck 	uint8_t public_key_bytes[MLKEM1024_PUBLIC_KEY_BYTES];
83908c63c71Sbeck 
84008c63c71Sbeck 	if (seed_len != MLKEM_SEED_BYTES) {
84108c63c71Sbeck 		return 0;
84208c63c71Sbeck 	}
84308c63c71Sbeck 	MLKEM1024_generate_key_external_entropy(public_key_bytes,
84408c63c71Sbeck 	    out_private_key, seed);
84508c63c71Sbeck 
84608c63c71Sbeck 	return 1;
84708c63c71Sbeck }
84808c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_private_key_from_seed);
84908c63c71Sbeck 
85008c63c71Sbeck static int
85108c63c71Sbeck mlkem_marshal_public_key(CBB *out, const struct public_key *pub)
85208c63c71Sbeck {
85308c63c71Sbeck 	uint8_t *vector_output;
85408c63c71Sbeck 
85508c63c71Sbeck 	if (!CBB_add_space(out, &vector_output, kEncodedVectorSize)) {
85608c63c71Sbeck 		return 0;
85708c63c71Sbeck 	}
85808c63c71Sbeck 	vector_encode(vector_output, &pub->t, kLog2Prime);
85908c63c71Sbeck 	if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
86008c63c71Sbeck 		return 0;
86108c63c71Sbeck 	}
86208c63c71Sbeck 	return 1;
86308c63c71Sbeck }
86408c63c71Sbeck 
86508c63c71Sbeck void
86608c63c71Sbeck MLKEM1024_generate_key_external_entropy(
86708c63c71Sbeck     uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
86808c63c71Sbeck     struct MLKEM1024_private_key *out_private_key,
86908c63c71Sbeck     const uint8_t entropy[MLKEM_SEED_BYTES])
87008c63c71Sbeck {
87108c63c71Sbeck 	struct private_key *priv = private_key_1024_from_external(
87208c63c71Sbeck 	    out_private_key);
87308c63c71Sbeck 	uint8_t augmented_seed[33];
87408c63c71Sbeck 	uint8_t *rho, *sigma;
87508c63c71Sbeck 	uint8_t counter = 0;
87608c63c71Sbeck 	uint8_t hashed[64];
87708c63c71Sbeck 	vector error;
87808c63c71Sbeck 	CBB cbb;
87908c63c71Sbeck 
88008c63c71Sbeck 	memcpy(augmented_seed, entropy, 32);
88108c63c71Sbeck 	augmented_seed[32] = RANK1024;
88208c63c71Sbeck 	hash_g(hashed, augmented_seed, 33);
88308c63c71Sbeck 	rho = hashed;
88408c63c71Sbeck 	sigma = hashed + 32;
88508c63c71Sbeck 	memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
88608c63c71Sbeck 	matrix_expand(&priv->pub.m, rho);
88708c63c71Sbeck 	vector_generate_secret_eta_2(&priv->s, &counter, sigma);
88808c63c71Sbeck 	vector_ntt(&priv->s);
88908c63c71Sbeck 	vector_generate_secret_eta_2(&error, &counter, sigma);
89008c63c71Sbeck 	vector_ntt(&error);
89108c63c71Sbeck 	matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
89208c63c71Sbeck 	vector_add(&priv->pub.t, &error);
89308c63c71Sbeck 
894516824a3Stb 	/* XXX - error checking. */
89508c63c71Sbeck 	CBB_init_fixed(&cbb, out_encoded_public_key, MLKEM1024_PUBLIC_KEY_BYTES);
89608c63c71Sbeck 	if (!mlkem_marshal_public_key(&cbb, &priv->pub)) {
89708c63c71Sbeck 		abort();
89808c63c71Sbeck 	}
899516824a3Stb 	CBB_cleanup(&cbb);
90008c63c71Sbeck 
90108c63c71Sbeck 	hash_h(priv->pub.public_key_hash, out_encoded_public_key,
90208c63c71Sbeck 	    MLKEM1024_PUBLIC_KEY_BYTES);
90308c63c71Sbeck 	memcpy(priv->fo_failure_secret, entropy + 32, 32);
90408c63c71Sbeck }
90508c63c71Sbeck 
90608c63c71Sbeck void
90708c63c71Sbeck MLKEM1024_public_from_private(struct MLKEM1024_public_key *out_public_key,
90808c63c71Sbeck     const struct MLKEM1024_private_key *private_key)
90908c63c71Sbeck {
91008c63c71Sbeck 	struct public_key *const pub = public_key_1024_from_external(
91108c63c71Sbeck 	    out_public_key);
91208c63c71Sbeck 	const struct private_key *const priv = private_key_1024_from_external(
91308c63c71Sbeck 	    private_key);
91408c63c71Sbeck 
91508c63c71Sbeck 	*pub = priv->pub;
91608c63c71Sbeck }
91708c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_public_from_private);
91808c63c71Sbeck 
91908c63c71Sbeck /*
92008c63c71Sbeck  * Encrypts a message with given randomness to the ciphertext in |out|. Without
92108c63c71Sbeck  * applying the Fujisaki-Okamoto transform this would not result in a CCA secure
92208c63c71Sbeck  * scheme, since lattice schemes are vulnerable to decryption failure oracles.
92308c63c71Sbeck  */
92408c63c71Sbeck static void
92508c63c71Sbeck encrypt_cpa(uint8_t out[MLKEM1024_CIPHERTEXT_BYTES],
92608c63c71Sbeck     const struct public_key *pub, const uint8_t message[32],
92708c63c71Sbeck     const uint8_t randomness[32])
92808c63c71Sbeck {
92908c63c71Sbeck 	scalar expanded_message, scalar_error;
93008c63c71Sbeck 	vector secret, error, u;
93108c63c71Sbeck 	uint8_t counter = 0;
93208c63c71Sbeck 	uint8_t input[33];
93308c63c71Sbeck 	scalar v;
93408c63c71Sbeck 
93508c63c71Sbeck 	vector_generate_secret_eta_2(&secret, &counter, randomness);
93608c63c71Sbeck 	vector_ntt(&secret);
93708c63c71Sbeck 	vector_generate_secret_eta_2(&error, &counter, randomness);
93808c63c71Sbeck 	memcpy(input, randomness, 32);
93908c63c71Sbeck 	input[32] = counter;
94008c63c71Sbeck 	scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error,
94108c63c71Sbeck 	    input);
94208c63c71Sbeck 	matrix_mult(&u, &pub->m, &secret);
94308c63c71Sbeck 	vector_inverse_ntt(&u);
94408c63c71Sbeck 	vector_add(&u, &error);
94508c63c71Sbeck 	scalar_inner_product(&v, &pub->t, &secret);
94608c63c71Sbeck 	scalar_inverse_ntt(&v);
94708c63c71Sbeck 	scalar_add(&v, &scalar_error);
94808c63c71Sbeck 	scalar_decode_1(&expanded_message, message);
94908c63c71Sbeck 	scalar_decompress(&expanded_message, 1);
95008c63c71Sbeck 	scalar_add(&v, &expanded_message);
95108c63c71Sbeck 	vector_compress(&u, kDU1024);
95208c63c71Sbeck 	vector_encode(out, &u, kDU1024);
95308c63c71Sbeck 	scalar_compress(&v, kDV1024);
95408c63c71Sbeck 	scalar_encode(out + kCompressedVectorSize, &v, kDV1024);
95508c63c71Sbeck }
95608c63c71Sbeck 
95708c63c71Sbeck /* Calls MLKEM1024_encap_external_entropy| with random bytes */
95808c63c71Sbeck void
95908c63c71Sbeck MLKEM1024_encap(uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
96008c63c71Sbeck     uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
96108c63c71Sbeck     const struct MLKEM1024_public_key *public_key)
96208c63c71Sbeck {
96308c63c71Sbeck 	uint8_t entropy[MLKEM_ENCAP_ENTROPY];
96408c63c71Sbeck 
96508c63c71Sbeck 	arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY);
96608c63c71Sbeck 	MLKEM1024_encap_external_entropy(out_ciphertext, out_shared_secret,
96708c63c71Sbeck 	    public_key, entropy);
96808c63c71Sbeck }
96908c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_encap);
97008c63c71Sbeck 
97108c63c71Sbeck /* See section 6.2 of the spec. */
97208c63c71Sbeck void
97308c63c71Sbeck MLKEM1024_encap_external_entropy(
97408c63c71Sbeck     uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
97508c63c71Sbeck     uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
97608c63c71Sbeck     const struct MLKEM1024_public_key *public_key,
97708c63c71Sbeck     const uint8_t entropy[MLKEM_ENCAP_ENTROPY])
97808c63c71Sbeck {
97908c63c71Sbeck 	const struct public_key *pub = public_key_1024_from_external(public_key);
98008c63c71Sbeck 	uint8_t key_and_randomness[64];
98108c63c71Sbeck 	uint8_t input[64];
98208c63c71Sbeck 
98308c63c71Sbeck 	memcpy(input, entropy, MLKEM_ENCAP_ENTROPY);
98408c63c71Sbeck 	memcpy(input + MLKEM_ENCAP_ENTROPY, pub->public_key_hash,
98508c63c71Sbeck 	    sizeof(input) - MLKEM_ENCAP_ENTROPY);
98608c63c71Sbeck 	hash_g(key_and_randomness, input, sizeof(input));
98708c63c71Sbeck 	encrypt_cpa(out_ciphertext, pub, entropy, key_and_randomness + 32);
98808c63c71Sbeck 	memcpy(out_shared_secret, key_and_randomness, 32);
98908c63c71Sbeck }
99008c63c71Sbeck 
99108c63c71Sbeck static void
99208c63c71Sbeck decrypt_cpa(uint8_t out[32], const struct private_key *priv,
99308c63c71Sbeck     const uint8_t ciphertext[MLKEM1024_CIPHERTEXT_BYTES])
99408c63c71Sbeck {
99508c63c71Sbeck 	scalar mask, v;
99608c63c71Sbeck 	vector u;
99708c63c71Sbeck 
99808c63c71Sbeck 	vector_decode(&u, ciphertext, kDU1024);
99908c63c71Sbeck 	vector_decompress(&u, kDU1024);
100008c63c71Sbeck 	vector_ntt(&u);
100108c63c71Sbeck 	scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV1024);
100208c63c71Sbeck 	scalar_decompress(&v, kDV1024);
100308c63c71Sbeck 	scalar_inner_product(&mask, &priv->s, &u);
100408c63c71Sbeck 	scalar_inverse_ntt(&mask);
100508c63c71Sbeck 	scalar_sub(&v, &mask);
100608c63c71Sbeck 	scalar_compress(&v, 1);
100708c63c71Sbeck 	scalar_encode_1(out, &v);
100808c63c71Sbeck }
100908c63c71Sbeck 
101008c63c71Sbeck /* See section 6.3 */
101108c63c71Sbeck int
101208c63c71Sbeck MLKEM1024_decap(uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
101308c63c71Sbeck     const uint8_t *ciphertext, size_t ciphertext_len,
101408c63c71Sbeck     const struct MLKEM1024_private_key *private_key)
101508c63c71Sbeck {
101608c63c71Sbeck 	const struct private_key *priv = private_key_1024_from_external(
101708c63c71Sbeck 	    private_key);
101808c63c71Sbeck 	uint8_t expected_ciphertext[MLKEM1024_CIPHERTEXT_BYTES];
101908c63c71Sbeck 	uint8_t key_and_randomness[64];
102008c63c71Sbeck 	uint8_t failure_key[32];
102108c63c71Sbeck 	uint8_t decrypted[64];
102208c63c71Sbeck 	uint8_t mask;
102308c63c71Sbeck 	int i;
102408c63c71Sbeck 
102508c63c71Sbeck 	if (ciphertext_len != MLKEM1024_CIPHERTEXT_BYTES) {
102608c63c71Sbeck 		arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_BYTES);
102708c63c71Sbeck 		return 0;
102808c63c71Sbeck 	}
102908c63c71Sbeck 
103008c63c71Sbeck 	decrypt_cpa(decrypted, priv, ciphertext);
103108c63c71Sbeck 	memcpy(decrypted + 32, priv->pub.public_key_hash,
103208c63c71Sbeck 	    sizeof(decrypted) - 32);
103308c63c71Sbeck 	hash_g(key_and_randomness, decrypted, sizeof(decrypted));
103408c63c71Sbeck 	encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
103508c63c71Sbeck 	    key_and_randomness + 32);
103608c63c71Sbeck 	kdf(failure_key, priv->fo_failure_secret, ciphertext, ciphertext_len);
103708c63c71Sbeck 	mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext,
103808c63c71Sbeck 	    sizeof(expected_ciphertext)), 0);
103908c63c71Sbeck 	for (i = 0; i < MLKEM_SHARED_SECRET_BYTES; i++) {
104008c63c71Sbeck 		out_shared_secret[i] = constant_time_select_8(mask,
104108c63c71Sbeck 		    key_and_randomness[i], failure_key[i]);
104208c63c71Sbeck 	}
104308c63c71Sbeck 
104408c63c71Sbeck 	return 1;
104508c63c71Sbeck }
104608c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_decap);
104708c63c71Sbeck 
104808c63c71Sbeck int
104908c63c71Sbeck MLKEM1024_marshal_public_key(CBB *out,
105008c63c71Sbeck     const struct MLKEM1024_public_key *public_key)
105108c63c71Sbeck {
105208c63c71Sbeck 	return mlkem_marshal_public_key(out,
105308c63c71Sbeck 	    public_key_1024_from_external(public_key));
105408c63c71Sbeck }
105508c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_marshal_public_key);
105608c63c71Sbeck 
105708c63c71Sbeck /*
105808c63c71Sbeck  * mlkem_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
105908c63c71Sbeck  * the value of |pub->public_key_hash|.
106008c63c71Sbeck  */
106108c63c71Sbeck static int
106208c63c71Sbeck mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in)
106308c63c71Sbeck {
106408c63c71Sbeck 	CBS t_bytes;
106508c63c71Sbeck 
106608c63c71Sbeck 	if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) ||
106708c63c71Sbeck 	    !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime)) {
106808c63c71Sbeck 		return 0;
106908c63c71Sbeck 	}
107008c63c71Sbeck 	memcpy(pub->rho, CBS_data(in), sizeof(pub->rho));
107108c63c71Sbeck 	if (!CBS_skip(in, sizeof(pub->rho)))
107208c63c71Sbeck 		return 0;
107308c63c71Sbeck 	matrix_expand(&pub->m, pub->rho);
107408c63c71Sbeck 	return 1;
107508c63c71Sbeck }
107608c63c71Sbeck 
107708c63c71Sbeck int
107808c63c71Sbeck MLKEM1024_parse_public_key(struct MLKEM1024_public_key *public_key, CBS *in)
107908c63c71Sbeck {
108008c63c71Sbeck 	struct public_key *pub = public_key_1024_from_external(public_key);
108108c63c71Sbeck 	CBS orig_in = *in;
108208c63c71Sbeck 
108308c63c71Sbeck 	if (!mlkem_parse_public_key_no_hash(pub, in) ||
108408c63c71Sbeck 	    CBS_len(in) != 0) {
108508c63c71Sbeck 		return 0;
108608c63c71Sbeck 	}
108708c63c71Sbeck 	hash_h(pub->public_key_hash, CBS_data(&orig_in), CBS_len(&orig_in));
108808c63c71Sbeck 	return 1;
108908c63c71Sbeck }
109008c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_parse_public_key);
109108c63c71Sbeck 
109208c63c71Sbeck int
109308c63c71Sbeck MLKEM1024_marshal_private_key(CBB *out,
109408c63c71Sbeck     const struct MLKEM1024_private_key *private_key)
109508c63c71Sbeck {
109608c63c71Sbeck 	const struct private_key *const priv = private_key_1024_from_external(
109708c63c71Sbeck 	    private_key);
109808c63c71Sbeck 	uint8_t *s_output;
109908c63c71Sbeck 
110008c63c71Sbeck 	if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) {
110108c63c71Sbeck 		return 0;
110208c63c71Sbeck 	}
110308c63c71Sbeck 	vector_encode(s_output, &priv->s, kLog2Prime);
110408c63c71Sbeck 	if (!mlkem_marshal_public_key(out, &priv->pub) ||
110508c63c71Sbeck 	    !CBB_add_bytes(out, priv->pub.public_key_hash,
110608c63c71Sbeck 	    sizeof(priv->pub.public_key_hash)) ||
110708c63c71Sbeck 	    !CBB_add_bytes(out, priv->fo_failure_secret,
110808c63c71Sbeck 	    sizeof(priv->fo_failure_secret))) {
110908c63c71Sbeck 		return 0;
111008c63c71Sbeck 	}
111108c63c71Sbeck 	return 1;
111208c63c71Sbeck }
111308c63c71Sbeck 
111408c63c71Sbeck int
111508c63c71Sbeck MLKEM1024_parse_private_key(struct MLKEM1024_private_key *out_private_key,
111608c63c71Sbeck     CBS *in)
111708c63c71Sbeck {
111808c63c71Sbeck 	struct private_key *const priv = private_key_1024_from_external(
111908c63c71Sbeck 	    out_private_key);
112008c63c71Sbeck 	CBS s_bytes;
112108c63c71Sbeck 
112208c63c71Sbeck 	if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) ||
112308c63c71Sbeck 	    !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) ||
112408c63c71Sbeck 	    !mlkem_parse_public_key_no_hash(&priv->pub, in)) {
112508c63c71Sbeck 		return 0;
112608c63c71Sbeck 	}
112708c63c71Sbeck 	memcpy(priv->pub.public_key_hash, CBS_data(in),
112808c63c71Sbeck 	    sizeof(priv->pub.public_key_hash));
112908c63c71Sbeck 	if (!CBS_skip(in, sizeof(priv->pub.public_key_hash)))
113008c63c71Sbeck 		return 0;
113108c63c71Sbeck 	memcpy(priv->fo_failure_secret, CBS_data(in),
113208c63c71Sbeck 	    sizeof(priv->fo_failure_secret));
113308c63c71Sbeck 	if (!CBS_skip(in, sizeof(priv->fo_failure_secret)))
113408c63c71Sbeck 		return 0;
113508c63c71Sbeck 	if (CBS_len(in) != 0)
113608c63c71Sbeck 		return 0;
113708c63c71Sbeck 
113808c63c71Sbeck 	return 1;
113908c63c71Sbeck }
114008c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_parse_private_key);
1141