1*9a8fba7cStb /* $OpenBSD: mlkem1024.c,v 1.6 2025/01/03 08:19:24 tb Exp $ */
208c63c71Sbeck /*
308c63c71Sbeck * Copyright (c) 2024, Google Inc.
408c63c71Sbeck * Copyright (c) 2024, Bob Beck <beck@obtuse.com>
508c63c71Sbeck *
608c63c71Sbeck * Permission to use, copy, modify, and/or distribute this software for any
708c63c71Sbeck * purpose with or without fee is hereby granted, provided that the above
808c63c71Sbeck * copyright notice and this permission notice appear in all copies.
908c63c71Sbeck *
1008c63c71Sbeck * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
1108c63c71Sbeck * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
1208c63c71Sbeck * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
1308c63c71Sbeck * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
1408c63c71Sbeck * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
1508c63c71Sbeck * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
1608c63c71Sbeck * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
1708c63c71Sbeck */
1808c63c71Sbeck
1908c63c71Sbeck #include <assert.h>
2008c63c71Sbeck #include <stdlib.h>
2108c63c71Sbeck #include <string.h>
2208c63c71Sbeck
2308c63c71Sbeck #include "bytestring.h"
24ef1019e6Stb #include "mlkem.h"
2508c63c71Sbeck
2608c63c71Sbeck #include "sha3_internal.h"
2708c63c71Sbeck #include "mlkem_internal.h"
2808c63c71Sbeck #include "constant_time.h"
2908c63c71Sbeck #include "crypto_internal.h"
3008c63c71Sbeck
3108c63c71Sbeck /* Remove later */
3208c63c71Sbeck #undef LCRYPTO_ALIAS
3308c63c71Sbeck #define LCRYPTO_ALIAS(A)
3408c63c71Sbeck
3508c63c71Sbeck /*
3608c63c71Sbeck * See
3708c63c71Sbeck * https://csrc.nist.gov/pubs/fips/203/final
3808c63c71Sbeck */
3908c63c71Sbeck
4008c63c71Sbeck static void
prf(uint8_t * out,size_t out_len,const uint8_t in[33])4108c63c71Sbeck prf(uint8_t *out, size_t out_len, const uint8_t in[33])
4208c63c71Sbeck {
4308c63c71Sbeck sha3_ctx ctx;
4408c63c71Sbeck shake256_init(&ctx);
4508c63c71Sbeck shake_update(&ctx, in, 33);
4608c63c71Sbeck shake_xof(&ctx);
4708c63c71Sbeck shake_out(&ctx, out, out_len);
4808c63c71Sbeck }
4908c63c71Sbeck
5008c63c71Sbeck /* Section 4.1 */
5108c63c71Sbeck static void
hash_h(uint8_t out[32],const uint8_t * in,size_t len)5208c63c71Sbeck hash_h(uint8_t out[32], const uint8_t *in, size_t len)
5308c63c71Sbeck {
5408c63c71Sbeck sha3_ctx ctx;
5508c63c71Sbeck sha3_init(&ctx, 32);
5608c63c71Sbeck sha3_update(&ctx, in, len);
5708c63c71Sbeck sha3_final(out, &ctx);
5808c63c71Sbeck }
5908c63c71Sbeck
6008c63c71Sbeck static void
hash_g(uint8_t out[64],const uint8_t * in,size_t len)6108c63c71Sbeck hash_g(uint8_t out[64], const uint8_t *in, size_t len)
6208c63c71Sbeck {
6308c63c71Sbeck sha3_ctx ctx;
6408c63c71Sbeck sha3_init(&ctx, 64);
6508c63c71Sbeck sha3_update(&ctx, in, len);
6608c63c71Sbeck sha3_final(out, &ctx);
6708c63c71Sbeck }
6808c63c71Sbeck
6908c63c71Sbeck /* this is called 'J' in the spec */
7008c63c71Sbeck static void
kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES],const uint8_t failure_secret[32],const uint8_t * in,size_t len)7108c63c71Sbeck kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES], const uint8_t failure_secret[32],
7208c63c71Sbeck const uint8_t *in, size_t len)
7308c63c71Sbeck {
7408c63c71Sbeck sha3_ctx ctx;
7508c63c71Sbeck shake256_init(&ctx);
7608c63c71Sbeck shake_update(&ctx, failure_secret, 32);
7708c63c71Sbeck shake_update(&ctx, in, len);
7808c63c71Sbeck shake_xof(&ctx);
7908c63c71Sbeck shake_out(&ctx, out, MLKEM_SHARED_SECRET_BYTES);
8008c63c71Sbeck }
8108c63c71Sbeck
/* Coefficients per polynomial, and the module rank of ML-KEM-1024. */
#define DEGREE 256
#define RANK1024 4

/*
 * Barrett reduction parameters: kBarrettMultiplier is
 * floor(2^kBarrettShift / kPrime).
 */
static const size_t kBarrettMultiplier = 5039;
static const unsigned kBarrettShift = 24;

/* The modulus, the bit width of a reduced coefficient, and floor(q / 2). */
static const uint16_t kPrime = 3329;
static const int kLog2Prime = 12;
static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;

/* Compression widths d_u and d_v for the ML-KEM-1024 parameter set. */
static const int kDU1024 = 11;
static const int kDV1024 = 5;

/*
 * kInverseDegree is 128^-1 mod kPrime. It is 128 rather than 256 because
 * kPrime has no 512th root of unity, so the NTT stops one layer early.
 */
static const uint16_t kInverseDegree = 3303;

/* Bytes for RANK1024 polynomials encoded at kLog2Prime bits per coefficient. */
static const size_t kEncodedVectorSize =
    (/*kLog2Prime=*/12 * DEGREE / 8) * RANK1024;

/* Bytes for RANK1024 polynomials compressed to kDU1024 bits per coefficient. */
static const size_t kCompressedVectorSize =
    /*kDU1024=*/11 * RANK1024 * DEGREE / 8;
10208c63c71Sbeck
10308c63c71Sbeck typedef struct scalar {
10408c63c71Sbeck /* On every function entry and exit, 0 <= c < kPrime. */
10508c63c71Sbeck uint16_t c[DEGREE];
10608c63c71Sbeck } scalar;
10708c63c71Sbeck
10808c63c71Sbeck typedef struct vector {
10908c63c71Sbeck scalar v[RANK1024];
11008c63c71Sbeck } vector;
11108c63c71Sbeck
11208c63c71Sbeck typedef struct matrix {
11308c63c71Sbeck scalar v[RANK1024][RANK1024];
11408c63c71Sbeck } matrix;
11508c63c71Sbeck
11608c63c71Sbeck /*
11708c63c71Sbeck * This bit of Python will be referenced in some of the following comments:
11808c63c71Sbeck *
11908c63c71Sbeck * p = 3329
12008c63c71Sbeck *
12108c63c71Sbeck * def bitreverse(i):
12208c63c71Sbeck * ret = 0
12308c63c71Sbeck * for n in range(7):
12408c63c71Sbeck * bit = i & 1
12508c63c71Sbeck * ret <<= 1
12608c63c71Sbeck * ret |= bit
12708c63c71Sbeck * i >>= 1
12808c63c71Sbeck * return ret
12908c63c71Sbeck */
13008c63c71Sbeck
/*
 * Twiddle factors for |scalar_ntt|: kNTTRoots[i] = 17^bitreverse(i) mod 3329,
 * where bitreverse reverses the low 7 bits of i.
 * (Python: [pow(17, bitreverse(i), p) for i in range(128)])
 */
static const uint16_t kNTTRoots[128] = {
	1, 1729, 2580, 3289, 2642, 630, 1897, 848,
	1062, 1919, 193, 797, 2786, 3260, 569, 1746,
	296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
	1426, 2094, 535, 2882, 2393, 2879, 1974, 821,
	289, 331, 3253, 1756, 1197, 2304, 2277, 2055,
	650, 1977, 2513, 632, 2865, 33, 1320, 1915,
	2319, 1435, 807, 452, 1438, 2868, 1534, 2402,
	2647, 2617, 1481, 648, 2474, 3110, 1227, 910,
	17, 2761, 583, 2649, 1637, 723, 2288, 1100,
	1409, 2662, 3281, 233, 756, 2156, 3015, 3050,
	1703, 1651, 2789, 1789, 1847, 952, 1461, 2687,
	939, 2308, 2437, 2388, 733, 2337, 268, 641,
	1584, 2298, 2037, 3220, 375, 2549, 2090, 1645,
	1063, 319, 2773, 757, 2099, 561, 2466, 2594,
	2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
	1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
};
14508c63c71Sbeck
/*
 * Twiddle factors for |scalar_inverse_ntt|: the modular inverses of the
 * |kNTTRoots| entries, i.e. kInverseNTTRoots[i] = 17^-bitreverse(i) mod 3329.
 * (Python: [pow(17, -bitreverse(i), p) for i in range(128)])
 */
static const uint16_t kInverseNTTRoots[128] = {
	1, 1600, 40, 749, 2481, 1432, 2699, 687,
	1583, 2760, 69, 543, 2532, 3136, 1410, 2267,
	2508, 1355, 450, 936, 447, 2794, 1235, 1903,
	1996, 1089, 3273, 283, 1853, 1990, 882, 3033,
	2419, 2102, 219, 855, 2681, 1848, 712, 682,
	927, 1795, 461, 1891, 2877, 2522, 1894, 1010,
	1414, 2009, 3296, 464, 2697, 816, 1352, 2679,
	1274, 1052, 1025, 2132, 1573, 76, 2998, 3040,
	1175, 2444, 394, 1219, 2300, 1455, 2117, 1607,
	2443, 554, 1179, 2186, 2303, 2926, 2237, 525,
	735, 863, 2768, 1230, 2572, 556, 3010, 2266,
	1684, 1239, 780, 2954, 109, 1292, 1031, 1745,
	2688, 3061, 992, 2596, 941, 892, 1021, 2390,
	642, 1868, 2377, 1482, 1540, 540, 1678, 1626,
	279, 314, 1173, 2573, 3096, 48, 667, 1920,
	2229, 1041, 2606, 1692, 680, 2746, 568, 3312,
};
16008c63c71Sbeck
/*
 * Moduli for |scalar_mult|: kModRoots[i] = 17^(2*bitreverse(i)+1) mod 3329,
 * the constant gamma in the quotient ring GF(3329)[X]/(X^2 - gamma) used
 * for the i-th coefficient pair.
 * (Python: [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)])
 */
static const uint16_t kModRoots[128] = {
	17, 3312, 2761, 568, 583, 2746, 2649, 680,
	1637, 1692, 723, 2606, 2288, 1041, 1100, 2229,
	1409, 1920, 2662, 667, 3281, 48, 233, 3096,
	756, 2573, 2156, 1173, 3015, 314, 3050, 279,
	1703, 1626, 1651, 1678, 2789, 540, 1789, 1540,
	1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
	939, 2390, 2308, 1021, 2437, 892, 2388, 941,
	733, 2596, 2337, 992, 268, 3061, 641, 2688,
	1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
	375, 2954, 2549, 780, 2090, 1239, 1645, 1684,
	1063, 2266, 319, 3010, 2773, 556, 757, 2572,
	2099, 1230, 561, 2768, 2466, 863, 2594, 735,
	2804, 525, 1092, 2237, 403, 2926, 1026, 2303,
	1143, 2186, 2150, 1179, 2775, 554, 886, 2443,
	1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
	2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
};
17508c63c71Sbeck
17608c63c71Sbeck /* reduce_once reduces 0 <= x < 2*kPrime, mod kPrime. */
17708c63c71Sbeck static uint16_t
reduce_once(uint16_t x)17808c63c71Sbeck reduce_once(uint16_t x)
17908c63c71Sbeck {
18008c63c71Sbeck assert(x < 2 * kPrime);
18108c63c71Sbeck const uint16_t subtracted = x - kPrime;
18208c63c71Sbeck uint16_t mask = 0u - (subtracted >> 15);
1839ee6f1feSbeck
18408c63c71Sbeck /*
1859ee6f1feSbeck * Although this is a constant-time select, we omit a value barrier here.
1869ee6f1feSbeck * Value barriers impede auto-vectorization (likely because it forces the
1879ee6f1feSbeck * value to transit through a general-purpose register). On AArch64, this
1889ee6f1feSbeck * is a difference of 2x.
1899ee6f1feSbeck *
1909ee6f1feSbeck * We usually add value barriers to selects because Clang turns
1919ee6f1feSbeck * consecutive selects with the same condition into a branch instead of
1929ee6f1feSbeck * CMOV/CSEL. This condition does not occur in ML-KEM, so omitting it
1939ee6f1feSbeck * seems to be safe so far but see
1949ee6f1feSbeck * |scalar_centered_binomial_distribution_eta_2_with_prf|.
19508c63c71Sbeck */
19608c63c71Sbeck return (mask & x) | (~mask & subtracted);
19708c63c71Sbeck }
19808c63c71Sbeck
19908c63c71Sbeck /*
20008c63c71Sbeck * constant time reduce x mod kPrime using Barrett reduction. x must be less
20108c63c71Sbeck * than kPrime + 2×kPrime².
20208c63c71Sbeck */
20308c63c71Sbeck static uint16_t
reduce(uint32_t x)20408c63c71Sbeck reduce(uint32_t x)
20508c63c71Sbeck {
20608c63c71Sbeck uint64_t product = (uint64_t)x * kBarrettMultiplier;
20708c63c71Sbeck uint32_t quotient = (uint32_t)(product >> kBarrettShift);
20808c63c71Sbeck uint32_t remainder = x - quotient * kPrime;
20908c63c71Sbeck
21008c63c71Sbeck assert(x < kPrime + 2u * kPrime * kPrime);
21108c63c71Sbeck return reduce_once(remainder);
21208c63c71Sbeck }
21308c63c71Sbeck
21408c63c71Sbeck static void
scalar_zero(scalar * out)21508c63c71Sbeck scalar_zero(scalar *out)
21608c63c71Sbeck {
21708c63c71Sbeck memset(out, 0, sizeof(*out));
21808c63c71Sbeck }
21908c63c71Sbeck
22008c63c71Sbeck static void
vector_zero(vector * out)22108c63c71Sbeck vector_zero(vector *out)
22208c63c71Sbeck {
22308c63c71Sbeck memset(out, 0, sizeof(*out));
22408c63c71Sbeck }
22508c63c71Sbeck
22608c63c71Sbeck /*
22708c63c71Sbeck * In place number theoretic transform of a given scalar.
22808c63c71Sbeck * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this
22908c63c71Sbeck * transform leaves off the last iteration of the usual FFT code, with the 128
23008c63c71Sbeck * relevant roots of unity being stored in |kNTTRoots|. This means the output
23108c63c71Sbeck * should be seen as 128 elements in GF(3329^2), with the coefficients of the
23208c63c71Sbeck * elements being consecutive entries in |s->c|.
23308c63c71Sbeck */
23408c63c71Sbeck static void
scalar_ntt(scalar * s)23508c63c71Sbeck scalar_ntt(scalar *s)
23608c63c71Sbeck {
23708c63c71Sbeck int offset = DEGREE;
23808c63c71Sbeck int step;
23908c63c71Sbeck /*
24008c63c71Sbeck * `int` is used here because using `size_t` throughout caused a ~5% slowdown
24108c63c71Sbeck * with Clang 14 on Aarch64.
24208c63c71Sbeck */
24308c63c71Sbeck for (step = 1; step < DEGREE / 2; step <<= 1) {
24408c63c71Sbeck int i, j, k = 0;
24508c63c71Sbeck
24608c63c71Sbeck offset >>= 1;
24708c63c71Sbeck for (i = 0; i < step; i++) {
24808c63c71Sbeck const uint32_t step_root = kNTTRoots[i + step];
24908c63c71Sbeck
25008c63c71Sbeck for (j = k; j < k + offset; j++) {
25108c63c71Sbeck uint16_t odd, even;
25208c63c71Sbeck
25308c63c71Sbeck odd = reduce(step_root * s->c[j + offset]);
25408c63c71Sbeck even = s->c[j];
25508c63c71Sbeck s->c[j] = reduce_once(odd + even);
25608c63c71Sbeck s->c[j + offset] = reduce_once(even - odd +
25708c63c71Sbeck kPrime);
25808c63c71Sbeck }
25908c63c71Sbeck k += 2 * offset;
26008c63c71Sbeck }
26108c63c71Sbeck }
26208c63c71Sbeck }
26308c63c71Sbeck
26408c63c71Sbeck static void
vector_ntt(vector * a)26508c63c71Sbeck vector_ntt(vector *a)
26608c63c71Sbeck {
26708c63c71Sbeck int i;
26808c63c71Sbeck
26908c63c71Sbeck for (i = 0; i < RANK1024; i++) {
27008c63c71Sbeck scalar_ntt(&a->v[i]);
27108c63c71Sbeck }
27208c63c71Sbeck }
27308c63c71Sbeck
27408c63c71Sbeck /*
27508c63c71Sbeck * In place inverse number theoretic transform of a given scalar, with pairs of
27608c63c71Sbeck * entries of s->v being interpreted as elements of GF(3329^2). Just as with the
27708c63c71Sbeck * number theoretic transform, this leaves off the first step of the normal iFFT
27808c63c71Sbeck * to account for the fact that 3329 does not have a 512th root of unity, using
27908c63c71Sbeck * the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
28008c63c71Sbeck */
28108c63c71Sbeck static void
scalar_inverse_ntt(scalar * s)28208c63c71Sbeck scalar_inverse_ntt(scalar *s)
28308c63c71Sbeck {
28408c63c71Sbeck int i, j, k, offset, step = DEGREE / 2;
28508c63c71Sbeck
28608c63c71Sbeck /*
28708c63c71Sbeck * `int` is used here because using `size_t` throughout caused a ~5% slowdown
28808c63c71Sbeck * with Clang 14 on Aarch64.
28908c63c71Sbeck */
29008c63c71Sbeck for (offset = 2; offset < DEGREE; offset <<= 1) {
29108c63c71Sbeck step >>= 1;
29208c63c71Sbeck k = 0;
29308c63c71Sbeck for (i = 0; i < step; i++) {
29408c63c71Sbeck uint32_t step_root = kInverseNTTRoots[i + step];
29508c63c71Sbeck for (j = k; j < k + offset; j++) {
29608c63c71Sbeck uint16_t odd, even;
29708c63c71Sbeck odd = s->c[j + offset];
29808c63c71Sbeck even = s->c[j];
29908c63c71Sbeck s->c[j] = reduce_once(odd + even);
30008c63c71Sbeck s->c[j + offset] = reduce(step_root *
30108c63c71Sbeck (even - odd + kPrime));
30208c63c71Sbeck }
30308c63c71Sbeck k += 2 * offset;
30408c63c71Sbeck }
30508c63c71Sbeck }
30608c63c71Sbeck for (i = 0; i < DEGREE; i++) {
30708c63c71Sbeck s->c[i] = reduce(s->c[i] * kInverseDegree);
30808c63c71Sbeck }
30908c63c71Sbeck }
31008c63c71Sbeck
31108c63c71Sbeck static void
vector_inverse_ntt(vector * a)31208c63c71Sbeck vector_inverse_ntt(vector *a)
31308c63c71Sbeck {
31408c63c71Sbeck int i;
31508c63c71Sbeck
31608c63c71Sbeck for (i = 0; i < RANK1024; i++) {
31708c63c71Sbeck scalar_inverse_ntt(&a->v[i]);
31808c63c71Sbeck }
31908c63c71Sbeck }
32008c63c71Sbeck
32108c63c71Sbeck static void
scalar_add(scalar * lhs,const scalar * rhs)32208c63c71Sbeck scalar_add(scalar *lhs, const scalar *rhs)
32308c63c71Sbeck {
32408c63c71Sbeck int i;
32508c63c71Sbeck
32608c63c71Sbeck for (i = 0; i < DEGREE; i++) {
32708c63c71Sbeck lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
32808c63c71Sbeck }
32908c63c71Sbeck }
33008c63c71Sbeck
33108c63c71Sbeck static void
scalar_sub(scalar * lhs,const scalar * rhs)33208c63c71Sbeck scalar_sub(scalar *lhs, const scalar *rhs)
33308c63c71Sbeck {
33408c63c71Sbeck int i;
33508c63c71Sbeck
33608c63c71Sbeck for (i = 0; i < DEGREE; i++) {
33708c63c71Sbeck lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
33808c63c71Sbeck }
33908c63c71Sbeck }
34008c63c71Sbeck
34108c63c71Sbeck /*
342*9a8fba7cStb * Multiplying two scalars in the number theoretically transformed state.
343*9a8fba7cStb * Since 3329 does not have a 512th root of unity, this means we have to
344*9a8fba7cStb * interpret the 2*ith and (2*i+1)th entries of the scalar as elements of
345*9a8fba7cStb * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
346*9a8fba7cStb * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
347*9a8fba7cStb * |kModRoots| table. Our Barrett transform only allows us to multiply two
348*9a8fba7cStb * reduced numbers together, so we need some intermediate reduction steps,
349*9a8fba7cStb * even if an uint64_t could hold 3 multiplied numbers.
35008c63c71Sbeck */
35108c63c71Sbeck static void
scalar_mult(scalar * out,const scalar * lhs,const scalar * rhs)35208c63c71Sbeck scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs)
35308c63c71Sbeck {
35408c63c71Sbeck int i;
35508c63c71Sbeck
35608c63c71Sbeck for (i = 0; i < DEGREE / 2; i++) {
35708c63c71Sbeck uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
35808c63c71Sbeck uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] *
35908c63c71Sbeck rhs->c[2 * i + 1];
36008c63c71Sbeck uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
36108c63c71Sbeck uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
36208c63c71Sbeck
36308c63c71Sbeck out->c[2 * i] =
36408c63c71Sbeck reduce(real_real +
36508c63c71Sbeck (uint32_t)reduce(img_img) * kModRoots[i]);
36608c63c71Sbeck out->c[2 * i + 1] = reduce(img_real + real_img);
36708c63c71Sbeck }
36808c63c71Sbeck }
36908c63c71Sbeck
37008c63c71Sbeck static void
vector_add(vector * lhs,const vector * rhs)37108c63c71Sbeck vector_add(vector *lhs, const vector *rhs)
37208c63c71Sbeck {
37308c63c71Sbeck int i;
37408c63c71Sbeck
37508c63c71Sbeck for (i = 0; i < RANK1024; i++) {
37608c63c71Sbeck scalar_add(&lhs->v[i], &rhs->v[i]);
37708c63c71Sbeck }
37808c63c71Sbeck }
37908c63c71Sbeck
38008c63c71Sbeck static void
matrix_mult(vector * out,const matrix * m,const vector * a)38108c63c71Sbeck matrix_mult(vector *out, const matrix *m, const vector *a)
38208c63c71Sbeck {
38308c63c71Sbeck int i, j;
38408c63c71Sbeck
38508c63c71Sbeck vector_zero(out);
38608c63c71Sbeck for (i = 0; i < RANK1024; i++) {
38708c63c71Sbeck for (j = 0; j < RANK1024; j++) {
38808c63c71Sbeck scalar product;
38908c63c71Sbeck
39008c63c71Sbeck scalar_mult(&product, &m->v[i][j], &a->v[j]);
39108c63c71Sbeck scalar_add(&out->v[i], &product);
39208c63c71Sbeck }
39308c63c71Sbeck }
39408c63c71Sbeck }
39508c63c71Sbeck
39608c63c71Sbeck static void
matrix_mult_transpose(vector * out,const matrix * m,const vector * a)39708c63c71Sbeck matrix_mult_transpose(vector *out, const matrix *m,
39808c63c71Sbeck const vector *a)
39908c63c71Sbeck {
40008c63c71Sbeck int i, j;
40108c63c71Sbeck
40208c63c71Sbeck vector_zero(out);
40308c63c71Sbeck for (i = 0; i < RANK1024; i++) {
40408c63c71Sbeck for (j = 0; j < RANK1024; j++) {
40508c63c71Sbeck scalar product;
40608c63c71Sbeck
40708c63c71Sbeck scalar_mult(&product, &m->v[j][i], &a->v[j]);
40808c63c71Sbeck scalar_add(&out->v[i], &product);
40908c63c71Sbeck }
41008c63c71Sbeck }
41108c63c71Sbeck }
41208c63c71Sbeck
41308c63c71Sbeck static void
scalar_inner_product(scalar * out,const vector * lhs,const vector * rhs)41408c63c71Sbeck scalar_inner_product(scalar *out, const vector *lhs,
41508c63c71Sbeck const vector *rhs)
41608c63c71Sbeck {
41708c63c71Sbeck int i;
41808c63c71Sbeck scalar_zero(out);
41908c63c71Sbeck for (i = 0; i < RANK1024; i++) {
42008c63c71Sbeck scalar product;
42108c63c71Sbeck
42208c63c71Sbeck scalar_mult(&product, &lhs->v[i], &rhs->v[i]);
42308c63c71Sbeck scalar_add(out, &product);
42408c63c71Sbeck }
42508c63c71Sbeck }
42608c63c71Sbeck
42708c63c71Sbeck /*
42808c63c71Sbeck * Algorithm 6 of spec. Rejection samples a Keccak stream to get uniformly
42908c63c71Sbeck * distributed elements. This is used for matrix expansion and only operates on
43008c63c71Sbeck * public inputs.
43108c63c71Sbeck */
43208c63c71Sbeck static void
scalar_from_keccak_vartime(scalar * out,sha3_ctx * keccak_ctx)43308c63c71Sbeck scalar_from_keccak_vartime(scalar *out, sha3_ctx *keccak_ctx)
43408c63c71Sbeck {
43508c63c71Sbeck int i, done = 0;
43608c63c71Sbeck
43708c63c71Sbeck while (done < DEGREE) {
43808c63c71Sbeck uint8_t block[168];
43908c63c71Sbeck
44008c63c71Sbeck shake_out(keccak_ctx, block, sizeof(block));
44108c63c71Sbeck for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
44208c63c71Sbeck uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
44308c63c71Sbeck uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
44408c63c71Sbeck
44508c63c71Sbeck if (d1 < kPrime) {
44608c63c71Sbeck out->c[done++] = d1;
44708c63c71Sbeck }
44808c63c71Sbeck if (d2 < kPrime && done < DEGREE) {
44908c63c71Sbeck out->c[done++] = d2;
45008c63c71Sbeck }
45108c63c71Sbeck }
45208c63c71Sbeck }
45308c63c71Sbeck }
45408c63c71Sbeck
45508c63c71Sbeck /*
45608c63c71Sbeck * Algorithm 7 of the spec, with eta fixed to two and the PRF call
45708c63c71Sbeck * included. Creates binominally distributed elements by sampling 2*|eta| bits,
45808c63c71Sbeck * and setting the coefficient to the count of the first bits minus the count of
45908c63c71Sbeck * the second bits, resulting in a centered binomial distribution. Since eta is
46008c63c71Sbeck * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
46108c63c71Sbeck * and 0 with probability 3/8.
46208c63c71Sbeck */
46308c63c71Sbeck static void
scalar_centered_binomial_distribution_eta_2_with_prf(scalar * out,const uint8_t input[33])46408c63c71Sbeck scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out,
46508c63c71Sbeck const uint8_t input[33])
46608c63c71Sbeck {
46708c63c71Sbeck uint8_t entropy[128];
46808c63c71Sbeck int i;
46908c63c71Sbeck
47008c63c71Sbeck CTASSERT(sizeof(entropy) == 2 * /*kEta=*/ 2 * DEGREE / 8);
47108c63c71Sbeck prf(entropy, sizeof(entropy), input);
47208c63c71Sbeck
47308c63c71Sbeck for (i = 0; i < DEGREE; i += 2) {
47408c63c71Sbeck uint8_t byte = entropy[i / 2];
4759ee6f1feSbeck uint16_t mask;
4769ee6f1feSbeck uint16_t value = (byte & 1) + ((byte >> 1) & 1);
47708c63c71Sbeck
47808c63c71Sbeck value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
47967abc7a1Stb
4809ee6f1feSbeck /*
4819ee6f1feSbeck * Add |kPrime| if |value| underflowed. See |reduce_once| for a
4829ee6f1feSbeck * discussion on why the value barrier is omitted. While this
4839ee6f1feSbeck * could have been written reduce_once(value + kPrime), this is
4849ee6f1feSbeck * one extra addition and small range of |value| tempts some
4859ee6f1feSbeck * versions of Clang to emit a branch.
4869ee6f1feSbeck */
4879ee6f1feSbeck mask = 0u - (value >> 15);
4889ee6f1feSbeck out->c[i] = ((value + kPrime) & mask) | (value & ~mask);
48908c63c71Sbeck
49008c63c71Sbeck byte >>= 4;
4919ee6f1feSbeck value = (byte & 1) + ((byte >> 1) & 1);
49208c63c71Sbeck value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
4939ee6f1feSbeck /* See above. */
4949ee6f1feSbeck mask = 0u - (value >> 15);
4959ee6f1feSbeck out->c[i + 1] = ((value + kPrime) & mask) | (value & ~mask);
49608c63c71Sbeck }
49708c63c71Sbeck }
49808c63c71Sbeck
49908c63c71Sbeck /*
50008c63c71Sbeck * Generates a secret vector by using
50108c63c71Sbeck * |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed
50208c63c71Sbeck * appending and incrementing |counter| for entry of the vector.
50308c63c71Sbeck */
50408c63c71Sbeck static void
vector_generate_secret_eta_2(vector * out,uint8_t * counter,const uint8_t seed[32])50508c63c71Sbeck vector_generate_secret_eta_2(vector *out, uint8_t *counter,
50608c63c71Sbeck const uint8_t seed[32])
50708c63c71Sbeck {
50808c63c71Sbeck uint8_t input[33];
50908c63c71Sbeck int i;
51008c63c71Sbeck
51108c63c71Sbeck memcpy(input, seed, 32);
51208c63c71Sbeck for (i = 0; i < RANK1024; i++) {
51308c63c71Sbeck input[32] = (*counter)++;
51408c63c71Sbeck scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i],
51508c63c71Sbeck input);
51608c63c71Sbeck }
51708c63c71Sbeck }
51808c63c71Sbeck
51908c63c71Sbeck /* Expands the matrix of a seed for key generation and for encaps-CPA. */
52008c63c71Sbeck static void
matrix_expand(matrix * out,const uint8_t rho[32])52108c63c71Sbeck matrix_expand(matrix *out, const uint8_t rho[32])
52208c63c71Sbeck {
52308c63c71Sbeck uint8_t input[34];
52408c63c71Sbeck int i, j;
52508c63c71Sbeck
52608c63c71Sbeck memcpy(input, rho, 32);
52708c63c71Sbeck for (i = 0; i < RANK1024; i++) {
52808c63c71Sbeck for (j = 0; j < RANK1024; j++) {
52908c63c71Sbeck sha3_ctx keccak_ctx;
53008c63c71Sbeck
53108c63c71Sbeck input[32] = i;
53208c63c71Sbeck input[33] = j;
53308c63c71Sbeck shake128_init(&keccak_ctx);
53408c63c71Sbeck shake_update(&keccak_ctx, input, sizeof(input));
53508c63c71Sbeck shake_xof(&keccak_ctx);
53608c63c71Sbeck scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
53708c63c71Sbeck }
53808c63c71Sbeck }
53908c63c71Sbeck }
54008c63c71Sbeck
/* kMasks[n - 1] keeps the low n bits of a byte: (1 << n) - 1 for n = 1..8. */
static const uint8_t kMasks[8] = {
	0x01, 0x03, 0x07, 0x0f, 0x1f, 0x3f, 0x7f, 0xff,
};
54308c63c71Sbeck
54408c63c71Sbeck static void
scalar_encode(uint8_t * out,const scalar * s,int bits)54508c63c71Sbeck scalar_encode(uint8_t *out, const scalar *s, int bits)
54608c63c71Sbeck {
54708c63c71Sbeck uint8_t out_byte = 0;
54808c63c71Sbeck int i, out_byte_bits = 0;
54908c63c71Sbeck
55008c63c71Sbeck assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
55108c63c71Sbeck for (i = 0; i < DEGREE; i++) {
55208c63c71Sbeck uint16_t element = s->c[i];
55308c63c71Sbeck int element_bits_done = 0;
55408c63c71Sbeck
55508c63c71Sbeck while (element_bits_done < bits) {
55608c63c71Sbeck int chunk_bits = bits - element_bits_done;
55708c63c71Sbeck int out_bits_remaining = 8 - out_byte_bits;
55808c63c71Sbeck
55908c63c71Sbeck if (chunk_bits >= out_bits_remaining) {
56008c63c71Sbeck chunk_bits = out_bits_remaining;
56108c63c71Sbeck out_byte |= (element &
56208c63c71Sbeck kMasks[chunk_bits - 1]) << out_byte_bits;
56308c63c71Sbeck *out = out_byte;
56408c63c71Sbeck out++;
56508c63c71Sbeck out_byte_bits = 0;
56608c63c71Sbeck out_byte = 0;
56708c63c71Sbeck } else {
56808c63c71Sbeck out_byte |= (element &
56908c63c71Sbeck kMasks[chunk_bits - 1]) << out_byte_bits;
57008c63c71Sbeck out_byte_bits += chunk_bits;
57108c63c71Sbeck }
57208c63c71Sbeck
57308c63c71Sbeck element_bits_done += chunk_bits;
57408c63c71Sbeck element >>= chunk_bits;
57508c63c71Sbeck }
57608c63c71Sbeck }
57708c63c71Sbeck
57808c63c71Sbeck if (out_byte_bits > 0) {
57908c63c71Sbeck *out = out_byte;
58008c63c71Sbeck }
58108c63c71Sbeck }
58208c63c71Sbeck
58308c63c71Sbeck /* scalar_encode_1 is |scalar_encode| specialised for |bits| == 1. */
58408c63c71Sbeck static void
scalar_encode_1(uint8_t out[32],const scalar * s)58508c63c71Sbeck scalar_encode_1(uint8_t out[32], const scalar *s)
58608c63c71Sbeck {
58708c63c71Sbeck int i, j;
58808c63c71Sbeck
58908c63c71Sbeck for (i = 0; i < DEGREE; i += 8) {
59008c63c71Sbeck uint8_t out_byte = 0;
59108c63c71Sbeck
59208c63c71Sbeck for (j = 0; j < 8; j++) {
59308c63c71Sbeck out_byte |= (s->c[i + j] & 1) << j;
59408c63c71Sbeck }
59508c63c71Sbeck *out = out_byte;
59608c63c71Sbeck out++;
59708c63c71Sbeck }
59808c63c71Sbeck }
59908c63c71Sbeck
60008c63c71Sbeck /*
60108c63c71Sbeck * Encodes an entire vector into 32*|RANK1024|*|bits| bytes. Note that since 256
60208c63c71Sbeck * (DEGREE) is divisible by 8, the individual vector entries will always fill a
60308c63c71Sbeck * whole number of bytes, so we do not need to worry about bit packing here.
60408c63c71Sbeck */
60508c63c71Sbeck static void
vector_encode(uint8_t * out,const vector * a,int bits)60608c63c71Sbeck vector_encode(uint8_t *out, const vector *a, int bits)
60708c63c71Sbeck {
60808c63c71Sbeck int i;
60908c63c71Sbeck
61008c63c71Sbeck for (i = 0; i < RANK1024; i++) {
61108c63c71Sbeck scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
61208c63c71Sbeck }
61308c63c71Sbeck }
61408c63c71Sbeck
61508c63c71Sbeck /*
61608c63c71Sbeck * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
61708c63c71Sbeck * |out|. It returns one on success and zero if any parsed value is >=
61808c63c71Sbeck * |kPrime|.
61908c63c71Sbeck */
62008c63c71Sbeck static int
scalar_decode(scalar * out,const uint8_t * in,int bits)62108c63c71Sbeck scalar_decode(scalar *out, const uint8_t *in, int bits)
62208c63c71Sbeck {
62308c63c71Sbeck uint8_t in_byte = 0;
62408c63c71Sbeck int i, in_byte_bits_left = 0;
62508c63c71Sbeck
62608c63c71Sbeck assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);
62708c63c71Sbeck
62808c63c71Sbeck for (i = 0; i < DEGREE; i++) {
62908c63c71Sbeck uint16_t element = 0;
63008c63c71Sbeck int element_bits_done = 0;
63108c63c71Sbeck
63208c63c71Sbeck while (element_bits_done < bits) {
63308c63c71Sbeck int chunk_bits = bits - element_bits_done;
63408c63c71Sbeck
63508c63c71Sbeck if (in_byte_bits_left == 0) {
63608c63c71Sbeck in_byte = *in;
63708c63c71Sbeck in++;
63808c63c71Sbeck in_byte_bits_left = 8;
63908c63c71Sbeck }
64008c63c71Sbeck
64108c63c71Sbeck if (chunk_bits > in_byte_bits_left) {
64208c63c71Sbeck chunk_bits = in_byte_bits_left;
64308c63c71Sbeck }
64408c63c71Sbeck
64508c63c71Sbeck element |= (in_byte & kMasks[chunk_bits - 1]) <<
64608c63c71Sbeck element_bits_done;
64708c63c71Sbeck in_byte_bits_left -= chunk_bits;
64808c63c71Sbeck in_byte >>= chunk_bits;
64908c63c71Sbeck
65008c63c71Sbeck element_bits_done += chunk_bits;
65108c63c71Sbeck }
65208c63c71Sbeck
65308c63c71Sbeck if (element >= kPrime) {
65408c63c71Sbeck return 0;
65508c63c71Sbeck }
65608c63c71Sbeck out->c[i] = element;
65708c63c71Sbeck }
65808c63c71Sbeck
65908c63c71Sbeck return 1;
66008c63c71Sbeck }
66108c63c71Sbeck
66208c63c71Sbeck /* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */
66308c63c71Sbeck static void
scalar_decode_1(scalar * out,const uint8_t in[32])66408c63c71Sbeck scalar_decode_1(scalar *out, const uint8_t in[32])
66508c63c71Sbeck {
66608c63c71Sbeck int i, j;
66708c63c71Sbeck
66808c63c71Sbeck for (i = 0; i < DEGREE; i += 8) {
66908c63c71Sbeck uint8_t in_byte = *in;
67008c63c71Sbeck
67108c63c71Sbeck in++;
67208c63c71Sbeck for (j = 0; j < 8; j++) {
67308c63c71Sbeck out->c[i + j] = in_byte & 1;
67408c63c71Sbeck in_byte >>= 1;
67508c63c71Sbeck }
67608c63c71Sbeck }
67708c63c71Sbeck }
67808c63c71Sbeck
67908c63c71Sbeck /*
68008c63c71Sbeck * Decodes 32*|RANK1024|*|bits| bytes from |in| into |out|. It returns one on
68108c63c71Sbeck * success or zero if any parsed value is >= |kPrime|.
68208c63c71Sbeck */
68308c63c71Sbeck static int
vector_decode(vector * out,const uint8_t * in,int bits)68408c63c71Sbeck vector_decode(vector *out, const uint8_t *in, int bits)
68508c63c71Sbeck {
68608c63c71Sbeck int i;
68708c63c71Sbeck
68808c63c71Sbeck for (i = 0; i < RANK1024; i++) {
68908c63c71Sbeck if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8,
69008c63c71Sbeck bits)) {
69108c63c71Sbeck return 0;
69208c63c71Sbeck }
69308c63c71Sbeck }
69408c63c71Sbeck return 1;
69508c63c71Sbeck }
69608c63c71Sbeck
69708c63c71Sbeck /*
69808c63c71Sbeck * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
69908c63c71Sbeck * numbers close to each other together. The formula used is
70008c63c71Sbeck * round(2^|bits|/kPrime*x) mod 2^|bits|.
70108c63c71Sbeck * Uses Barrett reduction to achieve constant time. Since we need both the
70208c63c71Sbeck * remainder (for rounding) and the quotient (as the result), we cannot use
70308c63c71Sbeck * |reduce| here, but need to do the Barrett reduction directly.
70408c63c71Sbeck */
70508c63c71Sbeck static uint16_t
compress(uint16_t x,int bits)70608c63c71Sbeck compress(uint16_t x, int bits)
70708c63c71Sbeck {
70808c63c71Sbeck uint32_t shifted = (uint32_t)x << bits;
70908c63c71Sbeck uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
71008c63c71Sbeck uint32_t quotient = (uint32_t)(product >> kBarrettShift);
71108c63c71Sbeck uint32_t remainder = shifted - quotient * kPrime;
71208c63c71Sbeck
71308c63c71Sbeck /*
71408c63c71Sbeck * Adjust the quotient to round correctly:
71508c63c71Sbeck * 0 <= remainder <= kHalfPrime round to 0
71608c63c71Sbeck * kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
71708c63c71Sbeck * kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
71808c63c71Sbeck */
71908c63c71Sbeck assert(remainder < 2u * kPrime);
72008c63c71Sbeck quotient += 1 & constant_time_lt(kHalfPrime, remainder);
72108c63c71Sbeck quotient += 1 & constant_time_lt(kPrime + kHalfPrime, remainder);
72208c63c71Sbeck return quotient & ((1 << bits) - 1);
72308c63c71Sbeck }
72408c63c71Sbeck
72508c63c71Sbeck /*
72608c63c71Sbeck * Decompresses |x| by using an equi-distant representative. The formula is
72708c63c71Sbeck * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
72808c63c71Sbeck * implement this logic using only bit operations.
72908c63c71Sbeck */
73008c63c71Sbeck static uint16_t
decompress(uint16_t x,int bits)73108c63c71Sbeck decompress(uint16_t x, int bits)
73208c63c71Sbeck {
73308c63c71Sbeck uint32_t product = (uint32_t)x * kPrime;
73408c63c71Sbeck uint32_t power = 1 << bits;
73508c63c71Sbeck /* This is |product| % power, since |power| is a power of 2. */
73608c63c71Sbeck uint32_t remainder = product & (power - 1);
73708c63c71Sbeck /* This is |product| / power, since |power| is a power of 2. */
73808c63c71Sbeck uint32_t lower = product >> bits;
73908c63c71Sbeck
74008c63c71Sbeck /*
74108c63c71Sbeck * The rounding logic works since the first half of numbers mod |power| have a
74208c63c71Sbeck * 0 as first bit, and the second half has a 1 as first bit, since |power| is
74308c63c71Sbeck * a power of 2. As a 12 bit number, |remainder| is always positive, so we
74408c63c71Sbeck * will shift in 0s for a right shift.
74508c63c71Sbeck */
74608c63c71Sbeck return lower + (remainder >> (bits - 1));
74708c63c71Sbeck }
74808c63c71Sbeck
74908c63c71Sbeck static void
scalar_compress(scalar * s,int bits)75008c63c71Sbeck scalar_compress(scalar *s, int bits)
75108c63c71Sbeck {
75208c63c71Sbeck int i;
75308c63c71Sbeck
75408c63c71Sbeck for (i = 0; i < DEGREE; i++) {
75508c63c71Sbeck s->c[i] = compress(s->c[i], bits);
75608c63c71Sbeck }
75708c63c71Sbeck }
75808c63c71Sbeck
75908c63c71Sbeck static void
scalar_decompress(scalar * s,int bits)76008c63c71Sbeck scalar_decompress(scalar *s, int bits)
76108c63c71Sbeck {
76208c63c71Sbeck int i;
76308c63c71Sbeck
76408c63c71Sbeck for (i = 0; i < DEGREE; i++) {
76508c63c71Sbeck s->c[i] = decompress(s->c[i], bits);
76608c63c71Sbeck }
76708c63c71Sbeck }
76808c63c71Sbeck
76908c63c71Sbeck static void
vector_compress(vector * a,int bits)77008c63c71Sbeck vector_compress(vector *a, int bits)
77108c63c71Sbeck {
77208c63c71Sbeck int i;
77308c63c71Sbeck
77408c63c71Sbeck for (i = 0; i < RANK1024; i++) {
77508c63c71Sbeck scalar_compress(&a->v[i], bits);
77608c63c71Sbeck }
77708c63c71Sbeck }
77808c63c71Sbeck
77908c63c71Sbeck static void
vector_decompress(vector * a,int bits)78008c63c71Sbeck vector_decompress(vector *a, int bits)
78108c63c71Sbeck {
78208c63c71Sbeck int i;
78308c63c71Sbeck
78408c63c71Sbeck for (i = 0; i < RANK1024; i++) {
78508c63c71Sbeck scalar_decompress(&a->v[i], bits);
78608c63c71Sbeck }
78708c63c71Sbeck }
78808c63c71Sbeck
/*
 * Internal representation of an ML-KEM-1024 public key.  The opaque
 * |struct MLKEM1024_public_key| is cast to this type by
 * public_key_1024_from_external(), so the field layout must stay stable.
 * NOTE(review): the external type is assumed to be at least this large and
 * suitably aligned; confirm against its definition in mlkem.h.
 */
struct public_key {
	vector t;			/* t, kept in NTT form */
	uint8_t rho[32];		/* seed from which |m| is expanded */
	uint8_t public_key_hash[32];	/* hash_h() of the encoded public key */
	matrix m;			/* matrix_expand() of |rho| */
};
79508c63c71Sbeck
/*
 * Returns the internal view of an opaque public key.  The const qualifier of
 * |external| is discarded; callers handing in a const key must only read
 * through the returned pointer.
 */
static struct public_key *
public_key_1024_from_external(const struct MLKEM1024_public_key *external)
{
	struct public_key *internal;

	internal = (struct public_key *)external;
	return internal;
}
80108c63c71Sbeck
/*
 * Internal representation of an ML-KEM-1024 private key.  The opaque
 * |struct MLKEM1024_private_key| is cast to this type by
 * private_key_1024_from_external(), so the field layout must stay stable.
 */
struct private_key {
	struct public_key pub;		/* the matching public key */
	vector s;			/* secret vector, kept in NTT form */
	uint8_t fo_failure_secret[32];	/* kdf() input for the rejection key in decap */
};
80708c63c71Sbeck
/*
 * Returns the internal view of an opaque private key.  The const qualifier of
 * |external| is discarded; callers handing in a const key must only read
 * through the returned pointer.
 */
static struct private_key *
private_key_1024_from_external(const struct MLKEM1024_private_key *external)
{
	struct private_key *internal;

	internal = (struct private_key *)external;
	return internal;
}
81308c63c71Sbeck
81408c63c71Sbeck /*
81508c63c71Sbeck * Calls |MLKEM1024_generate_key_external_entropy| with random bytes from
81608c63c71Sbeck * |RAND_bytes|.
81708c63c71Sbeck */
81808c63c71Sbeck void
MLKEM1024_generate_key(uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],uint8_t optional_out_seed[MLKEM_SEED_BYTES],struct MLKEM1024_private_key * out_private_key)81908c63c71Sbeck MLKEM1024_generate_key(uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
82008c63c71Sbeck uint8_t optional_out_seed[MLKEM_SEED_BYTES],
82108c63c71Sbeck struct MLKEM1024_private_key *out_private_key)
82208c63c71Sbeck {
82308c63c71Sbeck uint8_t entropy_buf[MLKEM_SEED_BYTES];
82408c63c71Sbeck uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed :
82508c63c71Sbeck entropy_buf;
82608c63c71Sbeck
82708c63c71Sbeck arc4random_buf(entropy, MLKEM_SEED_BYTES);
82808c63c71Sbeck MLKEM1024_generate_key_external_entropy(out_encoded_public_key,
82908c63c71Sbeck out_private_key, entropy);
83008c63c71Sbeck }
83108c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_generate_key);
83208c63c71Sbeck
83308c63c71Sbeck int
MLKEM1024_private_key_from_seed(struct MLKEM1024_private_key * out_private_key,const uint8_t * seed,size_t seed_len)83408c63c71Sbeck MLKEM1024_private_key_from_seed(struct MLKEM1024_private_key *out_private_key,
83508c63c71Sbeck const uint8_t *seed, size_t seed_len)
83608c63c71Sbeck {
83708c63c71Sbeck uint8_t public_key_bytes[MLKEM1024_PUBLIC_KEY_BYTES];
83808c63c71Sbeck
83908c63c71Sbeck if (seed_len != MLKEM_SEED_BYTES) {
84008c63c71Sbeck return 0;
84108c63c71Sbeck }
84208c63c71Sbeck MLKEM1024_generate_key_external_entropy(public_key_bytes,
84308c63c71Sbeck out_private_key, seed);
84408c63c71Sbeck
84508c63c71Sbeck return 1;
84608c63c71Sbeck }
84708c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_private_key_from_seed);
84808c63c71Sbeck
84908c63c71Sbeck static int
mlkem_marshal_public_key(CBB * out,const struct public_key * pub)85008c63c71Sbeck mlkem_marshal_public_key(CBB *out, const struct public_key *pub)
85108c63c71Sbeck {
85208c63c71Sbeck uint8_t *vector_output;
85308c63c71Sbeck
85408c63c71Sbeck if (!CBB_add_space(out, &vector_output, kEncodedVectorSize)) {
85508c63c71Sbeck return 0;
85608c63c71Sbeck }
85708c63c71Sbeck vector_encode(vector_output, &pub->t, kLog2Prime);
85808c63c71Sbeck if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
85908c63c71Sbeck return 0;
86008c63c71Sbeck }
86108c63c71Sbeck return 1;
86208c63c71Sbeck }
86308c63c71Sbeck
86408c63c71Sbeck void
MLKEM1024_generate_key_external_entropy(uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],struct MLKEM1024_private_key * out_private_key,const uint8_t entropy[MLKEM_SEED_BYTES])86508c63c71Sbeck MLKEM1024_generate_key_external_entropy(
86608c63c71Sbeck uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
86708c63c71Sbeck struct MLKEM1024_private_key *out_private_key,
86808c63c71Sbeck const uint8_t entropy[MLKEM_SEED_BYTES])
86908c63c71Sbeck {
87008c63c71Sbeck struct private_key *priv = private_key_1024_from_external(
87108c63c71Sbeck out_private_key);
87208c63c71Sbeck uint8_t augmented_seed[33];
87308c63c71Sbeck uint8_t *rho, *sigma;
87408c63c71Sbeck uint8_t counter = 0;
87508c63c71Sbeck uint8_t hashed[64];
87608c63c71Sbeck vector error;
87708c63c71Sbeck CBB cbb;
87808c63c71Sbeck
87908c63c71Sbeck memcpy(augmented_seed, entropy, 32);
88008c63c71Sbeck augmented_seed[32] = RANK1024;
88108c63c71Sbeck hash_g(hashed, augmented_seed, 33);
88208c63c71Sbeck rho = hashed;
88308c63c71Sbeck sigma = hashed + 32;
88408c63c71Sbeck memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
88508c63c71Sbeck matrix_expand(&priv->pub.m, rho);
88608c63c71Sbeck vector_generate_secret_eta_2(&priv->s, &counter, sigma);
88708c63c71Sbeck vector_ntt(&priv->s);
88808c63c71Sbeck vector_generate_secret_eta_2(&error, &counter, sigma);
88908c63c71Sbeck vector_ntt(&error);
89008c63c71Sbeck matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
89108c63c71Sbeck vector_add(&priv->pub.t, &error);
89208c63c71Sbeck
893516824a3Stb /* XXX - error checking. */
89408c63c71Sbeck CBB_init_fixed(&cbb, out_encoded_public_key, MLKEM1024_PUBLIC_KEY_BYTES);
89508c63c71Sbeck if (!mlkem_marshal_public_key(&cbb, &priv->pub)) {
89608c63c71Sbeck abort();
89708c63c71Sbeck }
898516824a3Stb CBB_cleanup(&cbb);
89908c63c71Sbeck
90008c63c71Sbeck hash_h(priv->pub.public_key_hash, out_encoded_public_key,
90108c63c71Sbeck MLKEM1024_PUBLIC_KEY_BYTES);
90208c63c71Sbeck memcpy(priv->fo_failure_secret, entropy + 32, 32);
90308c63c71Sbeck }
90408c63c71Sbeck
90508c63c71Sbeck void
MLKEM1024_public_from_private(struct MLKEM1024_public_key * out_public_key,const struct MLKEM1024_private_key * private_key)90608c63c71Sbeck MLKEM1024_public_from_private(struct MLKEM1024_public_key *out_public_key,
90708c63c71Sbeck const struct MLKEM1024_private_key *private_key)
90808c63c71Sbeck {
90908c63c71Sbeck struct public_key *const pub = public_key_1024_from_external(
91008c63c71Sbeck out_public_key);
91108c63c71Sbeck const struct private_key *const priv = private_key_1024_from_external(
91208c63c71Sbeck private_key);
91308c63c71Sbeck
91408c63c71Sbeck *pub = priv->pub;
91508c63c71Sbeck }
91608c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_public_from_private);
91708c63c71Sbeck
91808c63c71Sbeck /*
91908c63c71Sbeck * Encrypts a message with given randomness to the ciphertext in |out|. Without
92008c63c71Sbeck * applying the Fujisaki-Okamoto transform this would not result in a CCA secure
92108c63c71Sbeck * scheme, since lattice schemes are vulnerable to decryption failure oracles.
92208c63c71Sbeck */
92308c63c71Sbeck static void
encrypt_cpa(uint8_t out[MLKEM1024_CIPHERTEXT_BYTES],const struct public_key * pub,const uint8_t message[32],const uint8_t randomness[32])92408c63c71Sbeck encrypt_cpa(uint8_t out[MLKEM1024_CIPHERTEXT_BYTES],
92508c63c71Sbeck const struct public_key *pub, const uint8_t message[32],
92608c63c71Sbeck const uint8_t randomness[32])
92708c63c71Sbeck {
92808c63c71Sbeck scalar expanded_message, scalar_error;
92908c63c71Sbeck vector secret, error, u;
93008c63c71Sbeck uint8_t counter = 0;
93108c63c71Sbeck uint8_t input[33];
93208c63c71Sbeck scalar v;
93308c63c71Sbeck
93408c63c71Sbeck vector_generate_secret_eta_2(&secret, &counter, randomness);
93508c63c71Sbeck vector_ntt(&secret);
93608c63c71Sbeck vector_generate_secret_eta_2(&error, &counter, randomness);
93708c63c71Sbeck memcpy(input, randomness, 32);
93808c63c71Sbeck input[32] = counter;
93908c63c71Sbeck scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error,
94008c63c71Sbeck input);
94108c63c71Sbeck matrix_mult(&u, &pub->m, &secret);
94208c63c71Sbeck vector_inverse_ntt(&u);
94308c63c71Sbeck vector_add(&u, &error);
94408c63c71Sbeck scalar_inner_product(&v, &pub->t, &secret);
94508c63c71Sbeck scalar_inverse_ntt(&v);
94608c63c71Sbeck scalar_add(&v, &scalar_error);
94708c63c71Sbeck scalar_decode_1(&expanded_message, message);
94808c63c71Sbeck scalar_decompress(&expanded_message, 1);
94908c63c71Sbeck scalar_add(&v, &expanded_message);
95008c63c71Sbeck vector_compress(&u, kDU1024);
95108c63c71Sbeck vector_encode(out, &u, kDU1024);
95208c63c71Sbeck scalar_compress(&v, kDV1024);
95308c63c71Sbeck scalar_encode(out + kCompressedVectorSize, &v, kDV1024);
95408c63c71Sbeck }
95508c63c71Sbeck
95608c63c71Sbeck /* Calls MLKEM1024_encap_external_entropy| with random bytes */
95708c63c71Sbeck void
MLKEM1024_encap(uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],const struct MLKEM1024_public_key * public_key)95808c63c71Sbeck MLKEM1024_encap(uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
95908c63c71Sbeck uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
96008c63c71Sbeck const struct MLKEM1024_public_key *public_key)
96108c63c71Sbeck {
96208c63c71Sbeck uint8_t entropy[MLKEM_ENCAP_ENTROPY];
96308c63c71Sbeck
96408c63c71Sbeck arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY);
96508c63c71Sbeck MLKEM1024_encap_external_entropy(out_ciphertext, out_shared_secret,
96608c63c71Sbeck public_key, entropy);
96708c63c71Sbeck }
96808c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_encap);
96908c63c71Sbeck
97008c63c71Sbeck /* See section 6.2 of the spec. */
97108c63c71Sbeck void
MLKEM1024_encap_external_entropy(uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],const struct MLKEM1024_public_key * public_key,const uint8_t entropy[MLKEM_ENCAP_ENTROPY])97208c63c71Sbeck MLKEM1024_encap_external_entropy(
97308c63c71Sbeck uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
97408c63c71Sbeck uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
97508c63c71Sbeck const struct MLKEM1024_public_key *public_key,
97608c63c71Sbeck const uint8_t entropy[MLKEM_ENCAP_ENTROPY])
97708c63c71Sbeck {
97808c63c71Sbeck const struct public_key *pub = public_key_1024_from_external(public_key);
97908c63c71Sbeck uint8_t key_and_randomness[64];
98008c63c71Sbeck uint8_t input[64];
98108c63c71Sbeck
98208c63c71Sbeck memcpy(input, entropy, MLKEM_ENCAP_ENTROPY);
98308c63c71Sbeck memcpy(input + MLKEM_ENCAP_ENTROPY, pub->public_key_hash,
98408c63c71Sbeck sizeof(input) - MLKEM_ENCAP_ENTROPY);
98508c63c71Sbeck hash_g(key_and_randomness, input, sizeof(input));
98608c63c71Sbeck encrypt_cpa(out_ciphertext, pub, entropy, key_and_randomness + 32);
98708c63c71Sbeck memcpy(out_shared_secret, key_and_randomness, 32);
98808c63c71Sbeck }
98908c63c71Sbeck
99008c63c71Sbeck static void
decrypt_cpa(uint8_t out[32],const struct private_key * priv,const uint8_t ciphertext[MLKEM1024_CIPHERTEXT_BYTES])99108c63c71Sbeck decrypt_cpa(uint8_t out[32], const struct private_key *priv,
99208c63c71Sbeck const uint8_t ciphertext[MLKEM1024_CIPHERTEXT_BYTES])
99308c63c71Sbeck {
99408c63c71Sbeck scalar mask, v;
99508c63c71Sbeck vector u;
99608c63c71Sbeck
99708c63c71Sbeck vector_decode(&u, ciphertext, kDU1024);
99808c63c71Sbeck vector_decompress(&u, kDU1024);
99908c63c71Sbeck vector_ntt(&u);
100008c63c71Sbeck scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV1024);
100108c63c71Sbeck scalar_decompress(&v, kDV1024);
100208c63c71Sbeck scalar_inner_product(&mask, &priv->s, &u);
100308c63c71Sbeck scalar_inverse_ntt(&mask);
100408c63c71Sbeck scalar_sub(&v, &mask);
100508c63c71Sbeck scalar_compress(&v, 1);
100608c63c71Sbeck scalar_encode_1(out, &v);
100708c63c71Sbeck }
100808c63c71Sbeck
100908c63c71Sbeck /* See section 6.3 */
101008c63c71Sbeck int
MLKEM1024_decap(uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],const uint8_t * ciphertext,size_t ciphertext_len,const struct MLKEM1024_private_key * private_key)101108c63c71Sbeck MLKEM1024_decap(uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
101208c63c71Sbeck const uint8_t *ciphertext, size_t ciphertext_len,
101308c63c71Sbeck const struct MLKEM1024_private_key *private_key)
101408c63c71Sbeck {
101508c63c71Sbeck const struct private_key *priv = private_key_1024_from_external(
101608c63c71Sbeck private_key);
101708c63c71Sbeck uint8_t expected_ciphertext[MLKEM1024_CIPHERTEXT_BYTES];
101808c63c71Sbeck uint8_t key_and_randomness[64];
101908c63c71Sbeck uint8_t failure_key[32];
102008c63c71Sbeck uint8_t decrypted[64];
102108c63c71Sbeck uint8_t mask;
102208c63c71Sbeck int i;
102308c63c71Sbeck
102408c63c71Sbeck if (ciphertext_len != MLKEM1024_CIPHERTEXT_BYTES) {
102508c63c71Sbeck arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_BYTES);
102608c63c71Sbeck return 0;
102708c63c71Sbeck }
102808c63c71Sbeck
102908c63c71Sbeck decrypt_cpa(decrypted, priv, ciphertext);
103008c63c71Sbeck memcpy(decrypted + 32, priv->pub.public_key_hash,
103108c63c71Sbeck sizeof(decrypted) - 32);
103208c63c71Sbeck hash_g(key_and_randomness, decrypted, sizeof(decrypted));
103308c63c71Sbeck encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
103408c63c71Sbeck key_and_randomness + 32);
103508c63c71Sbeck kdf(failure_key, priv->fo_failure_secret, ciphertext, ciphertext_len);
103608c63c71Sbeck mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext,
103708c63c71Sbeck sizeof(expected_ciphertext)), 0);
103808c63c71Sbeck for (i = 0; i < MLKEM_SHARED_SECRET_BYTES; i++) {
103908c63c71Sbeck out_shared_secret[i] = constant_time_select_8(mask,
104008c63c71Sbeck key_and_randomness[i], failure_key[i]);
104108c63c71Sbeck }
104208c63c71Sbeck
104308c63c71Sbeck return 1;
104408c63c71Sbeck }
104508c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_decap);
104608c63c71Sbeck
104708c63c71Sbeck int
MLKEM1024_marshal_public_key(CBB * out,const struct MLKEM1024_public_key * public_key)104808c63c71Sbeck MLKEM1024_marshal_public_key(CBB *out,
104908c63c71Sbeck const struct MLKEM1024_public_key *public_key)
105008c63c71Sbeck {
105108c63c71Sbeck return mlkem_marshal_public_key(out,
105208c63c71Sbeck public_key_1024_from_external(public_key));
105308c63c71Sbeck }
105408c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_marshal_public_key);
105508c63c71Sbeck
105608c63c71Sbeck /*
105708c63c71Sbeck * mlkem_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
105808c63c71Sbeck * the value of |pub->public_key_hash|.
105908c63c71Sbeck */
106008c63c71Sbeck static int
mlkem_parse_public_key_no_hash(struct public_key * pub,CBS * in)106108c63c71Sbeck mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in)
106208c63c71Sbeck {
106308c63c71Sbeck CBS t_bytes;
106408c63c71Sbeck
106508c63c71Sbeck if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) ||
106608c63c71Sbeck !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime)) {
106708c63c71Sbeck return 0;
106808c63c71Sbeck }
106908c63c71Sbeck memcpy(pub->rho, CBS_data(in), sizeof(pub->rho));
107008c63c71Sbeck if (!CBS_skip(in, sizeof(pub->rho)))
107108c63c71Sbeck return 0;
107208c63c71Sbeck matrix_expand(&pub->m, pub->rho);
107308c63c71Sbeck return 1;
107408c63c71Sbeck }
107508c63c71Sbeck
107608c63c71Sbeck int
MLKEM1024_parse_public_key(struct MLKEM1024_public_key * public_key,CBS * in)107708c63c71Sbeck MLKEM1024_parse_public_key(struct MLKEM1024_public_key *public_key, CBS *in)
107808c63c71Sbeck {
107908c63c71Sbeck struct public_key *pub = public_key_1024_from_external(public_key);
108008c63c71Sbeck CBS orig_in = *in;
108108c63c71Sbeck
108208c63c71Sbeck if (!mlkem_parse_public_key_no_hash(pub, in) ||
108308c63c71Sbeck CBS_len(in) != 0) {
108408c63c71Sbeck return 0;
108508c63c71Sbeck }
108608c63c71Sbeck hash_h(pub->public_key_hash, CBS_data(&orig_in), CBS_len(&orig_in));
108708c63c71Sbeck return 1;
108808c63c71Sbeck }
108908c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_parse_public_key);
109008c63c71Sbeck
109108c63c71Sbeck int
MLKEM1024_marshal_private_key(CBB * out,const struct MLKEM1024_private_key * private_key)109208c63c71Sbeck MLKEM1024_marshal_private_key(CBB *out,
109308c63c71Sbeck const struct MLKEM1024_private_key *private_key)
109408c63c71Sbeck {
109508c63c71Sbeck const struct private_key *const priv = private_key_1024_from_external(
109608c63c71Sbeck private_key);
109708c63c71Sbeck uint8_t *s_output;
109808c63c71Sbeck
109908c63c71Sbeck if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) {
110008c63c71Sbeck return 0;
110108c63c71Sbeck }
110208c63c71Sbeck vector_encode(s_output, &priv->s, kLog2Prime);
110308c63c71Sbeck if (!mlkem_marshal_public_key(out, &priv->pub) ||
110408c63c71Sbeck !CBB_add_bytes(out, priv->pub.public_key_hash,
110508c63c71Sbeck sizeof(priv->pub.public_key_hash)) ||
110608c63c71Sbeck !CBB_add_bytes(out, priv->fo_failure_secret,
110708c63c71Sbeck sizeof(priv->fo_failure_secret))) {
110808c63c71Sbeck return 0;
110908c63c71Sbeck }
111008c63c71Sbeck return 1;
111108c63c71Sbeck }
111208c63c71Sbeck
111308c63c71Sbeck int
MLKEM1024_parse_private_key(struct MLKEM1024_private_key * out_private_key,CBS * in)111408c63c71Sbeck MLKEM1024_parse_private_key(struct MLKEM1024_private_key *out_private_key,
111508c63c71Sbeck CBS *in)
111608c63c71Sbeck {
111708c63c71Sbeck struct private_key *const priv = private_key_1024_from_external(
111808c63c71Sbeck out_private_key);
111908c63c71Sbeck CBS s_bytes;
112008c63c71Sbeck
112108c63c71Sbeck if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) ||
112208c63c71Sbeck !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) ||
112308c63c71Sbeck !mlkem_parse_public_key_no_hash(&priv->pub, in)) {
112408c63c71Sbeck return 0;
112508c63c71Sbeck }
112608c63c71Sbeck memcpy(priv->pub.public_key_hash, CBS_data(in),
112708c63c71Sbeck sizeof(priv->pub.public_key_hash));
112808c63c71Sbeck if (!CBS_skip(in, sizeof(priv->pub.public_key_hash)))
112908c63c71Sbeck return 0;
113008c63c71Sbeck memcpy(priv->fo_failure_secret, CBS_data(in),
113108c63c71Sbeck sizeof(priv->fo_failure_secret));
113208c63c71Sbeck if (!CBS_skip(in, sizeof(priv->fo_failure_secret)))
113308c63c71Sbeck return 0;
113408c63c71Sbeck if (CBS_len(in) != 0)
113508c63c71Sbeck return 0;
113608c63c71Sbeck
113708c63c71Sbeck return 1;
113808c63c71Sbeck }
113908c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_parse_private_key);
1140