1*67abc7a1Stb /* $OpenBSD: mlkem1024.c,v 1.4 2024/12/18 10:55:21 tb Exp $ */ 208c63c71Sbeck /* 308c63c71Sbeck * Copyright (c) 2024, Google Inc. 408c63c71Sbeck * Copyright (c) 2024, Bob Beck <beck@obtuse.com> 508c63c71Sbeck * 608c63c71Sbeck * Permission to use, copy, modify, and/or distribute this software for any 708c63c71Sbeck * purpose with or without fee is hereby granted, provided that the above 808c63c71Sbeck * copyright notice and this permission notice appear in all copies. 908c63c71Sbeck * 1008c63c71Sbeck * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES 1108c63c71Sbeck * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF 1208c63c71Sbeck * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY 1308c63c71Sbeck * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES 1408c63c71Sbeck * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION 1508c63c71Sbeck * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN 1608c63c71Sbeck * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 1708c63c71Sbeck */ 1808c63c71Sbeck 1908c63c71Sbeck #include <openssl/mlkem.h> 2008c63c71Sbeck 2108c63c71Sbeck #include <assert.h> 2208c63c71Sbeck #include <stdlib.h> 2308c63c71Sbeck #include <string.h> 2408c63c71Sbeck 2508c63c71Sbeck #include "bytestring.h" 2608c63c71Sbeck 2708c63c71Sbeck #include "sha3_internal.h" 2808c63c71Sbeck #include "mlkem_internal.h" 2908c63c71Sbeck #include "constant_time.h" 3008c63c71Sbeck #include "crypto_internal.h" 3108c63c71Sbeck 3208c63c71Sbeck /* Remove later */ 3308c63c71Sbeck #undef LCRYPTO_ALIAS 3408c63c71Sbeck #define LCRYPTO_ALIAS(A) 3508c63c71Sbeck 3608c63c71Sbeck /* 3708c63c71Sbeck * See 3808c63c71Sbeck * https://csrc.nist.gov/pubs/fips/203/final 3908c63c71Sbeck */ 4008c63c71Sbeck 4108c63c71Sbeck static void 4208c63c71Sbeck prf(uint8_t *out, size_t out_len, const uint8_t in[33]) 4308c63c71Sbeck { 4408c63c71Sbeck sha3_ctx ctx; 4508c63c71Sbeck shake256_init(&ctx); 4608c63c71Sbeck shake_update(&ctx, in, 33); 4708c63c71Sbeck shake_xof(&ctx); 4808c63c71Sbeck shake_out(&ctx, out, out_len); 4908c63c71Sbeck } 5008c63c71Sbeck 5108c63c71Sbeck /* Section 4.1 */ 5208c63c71Sbeck static void 5308c63c71Sbeck hash_h(uint8_t out[32], const uint8_t *in, size_t len) 5408c63c71Sbeck { 5508c63c71Sbeck sha3_ctx ctx; 5608c63c71Sbeck sha3_init(&ctx, 32); 5708c63c71Sbeck sha3_update(&ctx, in, len); 5808c63c71Sbeck sha3_final(out, &ctx); 5908c63c71Sbeck } 6008c63c71Sbeck 6108c63c71Sbeck static void 6208c63c71Sbeck hash_g(uint8_t out[64], const uint8_t *in, size_t len) 6308c63c71Sbeck { 6408c63c71Sbeck sha3_ctx ctx; 6508c63c71Sbeck sha3_init(&ctx, 64); 6608c63c71Sbeck sha3_update(&ctx, in, len); 6708c63c71Sbeck sha3_final(out, &ctx); 6808c63c71Sbeck } 6908c63c71Sbeck 7008c63c71Sbeck /* this is called 'J' in the spec */ 7108c63c71Sbeck static void 7208c63c71Sbeck kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES], const uint8_t failure_secret[32], 7308c63c71Sbeck const uint8_t *in, size_t len) 7408c63c71Sbeck { 7508c63c71Sbeck sha3_ctx ctx; 7608c63c71Sbeck shake256_init(&ctx); 7708c63c71Sbeck shake_update(&ctx, failure_secret, 32); 7808c63c71Sbeck shake_update(&ctx, in, len); 7908c63c71Sbeck shake_xof(&ctx); 8008c63c71Sbeck shake_out(&ctx, out, MLKEM_SHARED_SECRET_BYTES); 8108c63c71Sbeck } 8208c63c71Sbeck 8308c63c71Sbeck #define DEGREE 256 8408c63c71Sbeck #define RANK1024 4 8508c63c71Sbeck 8608c63c71Sbeck static const size_t kBarrettMultiplier = 5039; 8708c63c71Sbeck static const unsigned kBarrettShift = 24; 8808c63c71Sbeck static const uint16_t kPrime = 3329; 8908c63c71Sbeck static const int kLog2Prime = 12; 9008c63c71Sbeck static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2; 9108c63c71Sbeck static const int kDU1024 = 11; 9208c63c71Sbeck static const int kDV1024 = 5; 9308c63c71Sbeck 9408c63c71Sbeck /* 9508c63c71Sbeck * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th 9608c63c71Sbeck * root of unity. 9708c63c71Sbeck */ 9808c63c71Sbeck static const uint16_t kInverseDegree = 3303; 9908c63c71Sbeck static const size_t kEncodedVectorSize = 10008c63c71Sbeck (/*kLog2Prime=*/12 * DEGREE / 8) * RANK1024; 10108c63c71Sbeck static const size_t kCompressedVectorSize = /*kDU1024=*/ 11 * RANK1024 * DEGREE / 10208c63c71Sbeck 8; 10308c63c71Sbeck 10408c63c71Sbeck typedef struct scalar { 10508c63c71Sbeck /* On every function entry and exit, 0 <= c < kPrime. */ 10608c63c71Sbeck uint16_t c[DEGREE]; 10708c63c71Sbeck } scalar; 10808c63c71Sbeck 10908c63c71Sbeck typedef struct vector { 11008c63c71Sbeck scalar v[RANK1024]; 11108c63c71Sbeck } vector; 11208c63c71Sbeck 11308c63c71Sbeck typedef struct matrix { 11408c63c71Sbeck scalar v[RANK1024][RANK1024]; 11508c63c71Sbeck } matrix; 11608c63c71Sbeck 11708c63c71Sbeck /* 11808c63c71Sbeck * This bit of Python will be referenced in some of the following comments: 11908c63c71Sbeck * 12008c63c71Sbeck * p = 3329 12108c63c71Sbeck * 12208c63c71Sbeck * def bitreverse(i): 12308c63c71Sbeck * ret = 0 12408c63c71Sbeck * for n in range(7): 12508c63c71Sbeck * bit = i & 1 12608c63c71Sbeck * ret <<= 1 12708c63c71Sbeck * ret |= bit 12808c63c71Sbeck * i >>= 1 12908c63c71Sbeck * return ret 13008c63c71Sbeck */ 13108c63c71Sbeck 13208c63c71Sbeck /* kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)] */ 13308c63c71Sbeck static const uint16_t kNTTRoots[128] = { 13408c63c71Sbeck 1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797, 13508c63c71Sbeck 2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, 13608c63c71Sbeck 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756, 13708c63c71Sbeck 1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915, 13808c63c71Sbeck 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648, 13908c63c71Sbeck 2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100, 14008c63c71Sbeck 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789, 14108c63c71Sbeck 1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641, 14208c63c71Sbeck 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757, 14308c63c71Sbeck 2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, 14408c63c71Sbeck 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154, 14508c63c71Sbeck }; 14608c63c71Sbeck 14708c63c71Sbeck /* kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */ 14808c63c71Sbeck static const uint16_t kInverseNTTRoots[128] = { 14908c63c71Sbeck 1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543, 15008c63c71Sbeck 2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903, 15108c63c71Sbeck 1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855, 15208c63c71Sbeck 2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010, 15308c63c71Sbeck 1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132, 15408c63c71Sbeck 1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607, 15508c63c71Sbeck 2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230, 15608c63c71Sbeck 2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745, 15708c63c71Sbeck 2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482, 15808c63c71Sbeck 1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920, 15908c63c71Sbeck 2229, 1041, 2606, 1692, 680, 2746, 568, 3312, 16008c63c71Sbeck }; 16108c63c71Sbeck 16208c63c71Sbeck /* kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)] */ 16308c63c71Sbeck static const uint16_t kModRoots[128] = { 16408c63c71Sbeck 17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606, 16508c63c71Sbeck 2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096, 16608c63c71Sbeck 756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678, 16708c63c71Sbeck 2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642, 16808c63c71Sbeck 939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992, 16908c63c71Sbeck 268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109, 17008c63c71Sbeck 375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010, 17108c63c71Sbeck 2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735, 17208c63c71Sbeck 2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179, 17308c63c71Sbeck 2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300, 17408c63c71Sbeck 2110, 1219, 2935, 394, 885, 2444, 2154, 1175, 17508c63c71Sbeck }; 17608c63c71Sbeck 17708c63c71Sbeck /* reduce_once reduces 0 <= x < 2*kPrime, mod kPrime. */ 17808c63c71Sbeck static uint16_t 17908c63c71Sbeck reduce_once(uint16_t x) 18008c63c71Sbeck { 18108c63c71Sbeck assert(x < 2 * kPrime); 18208c63c71Sbeck const uint16_t subtracted = x - kPrime; 18308c63c71Sbeck uint16_t mask = 0u - (subtracted >> 15); 1849ee6f1feSbeck 18508c63c71Sbeck /* 1869ee6f1feSbeck * Although this is a constant-time select, we omit a value barrier here. 1879ee6f1feSbeck * Value barriers impede auto-vectorization (likely because it forces the 1889ee6f1feSbeck * value to transit through a general-purpose register). On AArch64, this 1899ee6f1feSbeck * is a difference of 2x. 1909ee6f1feSbeck * 1919ee6f1feSbeck * We usually add value barriers to selects because Clang turns 1929ee6f1feSbeck * consecutive selects with the same condition into a branch instead of 1939ee6f1feSbeck * CMOV/CSEL. This condition does not occur in ML-KEM, so omitting it 1949ee6f1feSbeck * seems to be safe so far but see 1959ee6f1feSbeck * |scalar_centered_binomial_distribution_eta_2_with_prf|. 19608c63c71Sbeck */ 19708c63c71Sbeck return (mask & x) | (~mask & subtracted); 19808c63c71Sbeck } 19908c63c71Sbeck 20008c63c71Sbeck /* 20108c63c71Sbeck * constant time reduce x mod kPrime using Barrett reduction. x must be less 20208c63c71Sbeck * than kPrime + 2×kPrime². 20308c63c71Sbeck */ 20408c63c71Sbeck static uint16_t 20508c63c71Sbeck reduce(uint32_t x) 20608c63c71Sbeck { 20708c63c71Sbeck uint64_t product = (uint64_t)x * kBarrettMultiplier; 20808c63c71Sbeck uint32_t quotient = (uint32_t)(product >> kBarrettShift); 20908c63c71Sbeck uint32_t remainder = x - quotient * kPrime; 21008c63c71Sbeck 21108c63c71Sbeck assert(x < kPrime + 2u * kPrime * kPrime); 21208c63c71Sbeck return reduce_once(remainder); 21308c63c71Sbeck } 21408c63c71Sbeck 21508c63c71Sbeck static void 21608c63c71Sbeck scalar_zero(scalar *out) 21708c63c71Sbeck { 21808c63c71Sbeck memset(out, 0, sizeof(*out)); 21908c63c71Sbeck } 22008c63c71Sbeck 22108c63c71Sbeck static void 22208c63c71Sbeck vector_zero(vector *out) 22308c63c71Sbeck { 22408c63c71Sbeck memset(out, 0, sizeof(*out)); 22508c63c71Sbeck } 22608c63c71Sbeck 22708c63c71Sbeck /* 22808c63c71Sbeck * In place number theoretic transform of a given scalar. 22908c63c71Sbeck * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this 23008c63c71Sbeck * transform leaves off the last iteration of the usual FFT code, with the 128 23108c63c71Sbeck * relevant roots of unity being stored in |kNTTRoots|. This means the output 23208c63c71Sbeck * should be seen as 128 elements in GF(3329^2), with the coefficients of the 23308c63c71Sbeck * elements being consecutive entries in |s->c|. 23408c63c71Sbeck */ 23508c63c71Sbeck static void 23608c63c71Sbeck scalar_ntt(scalar *s) 23708c63c71Sbeck { 23808c63c71Sbeck int offset = DEGREE; 23908c63c71Sbeck int step; 24008c63c71Sbeck /* 24108c63c71Sbeck * `int` is used here because using `size_t` throughout caused a ~5% slowdown 24208c63c71Sbeck * with Clang 14 on Aarch64. 24308c63c71Sbeck */ 24408c63c71Sbeck for (step = 1; step < DEGREE / 2; step <<= 1) { 24508c63c71Sbeck int i, j, k = 0; 24608c63c71Sbeck 24708c63c71Sbeck offset >>= 1; 24808c63c71Sbeck for (i = 0; i < step; i++) { 24908c63c71Sbeck const uint32_t step_root = kNTTRoots[i + step]; 25008c63c71Sbeck 25108c63c71Sbeck for (j = k; j < k + offset; j++) { 25208c63c71Sbeck uint16_t odd, even; 25308c63c71Sbeck 25408c63c71Sbeck odd = reduce(step_root * s->c[j + offset]); 25508c63c71Sbeck even = s->c[j]; 25608c63c71Sbeck s->c[j] = reduce_once(odd + even); 25708c63c71Sbeck s->c[j + offset] = reduce_once(even - odd + 25808c63c71Sbeck kPrime); 25908c63c71Sbeck } 26008c63c71Sbeck k += 2 * offset; 26108c63c71Sbeck } 26208c63c71Sbeck } 26308c63c71Sbeck } 26408c63c71Sbeck 26508c63c71Sbeck static void 26608c63c71Sbeck vector_ntt(vector *a) 26708c63c71Sbeck { 26808c63c71Sbeck int i; 26908c63c71Sbeck 27008c63c71Sbeck for (i = 0; i < RANK1024; i++) { 27108c63c71Sbeck scalar_ntt(&a->v[i]); 27208c63c71Sbeck } 27308c63c71Sbeck } 27408c63c71Sbeck 27508c63c71Sbeck /* 27608c63c71Sbeck * In place inverse number theoretic transform of a given scalar, with pairs of 27708c63c71Sbeck * entries of s->v being interpreted as elements of GF(3329^2). Just as with the 27808c63c71Sbeck * number theoretic transform, this leaves off the first step of the normal iFFT 27908c63c71Sbeck * to account for the fact that 3329 does not have a 512th root of unity, using 28008c63c71Sbeck * the precomputed 128 roots of unity stored in |kInverseNTTRoots|. 28108c63c71Sbeck */ 28208c63c71Sbeck static void 28308c63c71Sbeck scalar_inverse_ntt(scalar *s) 28408c63c71Sbeck { 28508c63c71Sbeck int i, j, k, offset, step = DEGREE / 2; 28608c63c71Sbeck 28708c63c71Sbeck /* 28808c63c71Sbeck * `int` is used here because using `size_t` throughout caused a ~5% slowdown 28908c63c71Sbeck * with Clang 14 on Aarch64. 29008c63c71Sbeck */ 29108c63c71Sbeck for (offset = 2; offset < DEGREE; offset <<= 1) { 29208c63c71Sbeck step >>= 1; 29308c63c71Sbeck k = 0; 29408c63c71Sbeck for (i = 0; i < step; i++) { 29508c63c71Sbeck uint32_t step_root = kInverseNTTRoots[i + step]; 29608c63c71Sbeck for (j = k; j < k + offset; j++) { 29708c63c71Sbeck uint16_t odd, even; 29808c63c71Sbeck odd = s->c[j + offset]; 29908c63c71Sbeck even = s->c[j]; 30008c63c71Sbeck s->c[j] = reduce_once(odd + even); 30108c63c71Sbeck s->c[j + offset] = reduce(step_root * 30208c63c71Sbeck (even - odd + kPrime)); 30308c63c71Sbeck } 30408c63c71Sbeck k += 2 * offset; 30508c63c71Sbeck } 30608c63c71Sbeck } 30708c63c71Sbeck for (i = 0; i < DEGREE; i++) { 30808c63c71Sbeck s->c[i] = reduce(s->c[i] * kInverseDegree); 30908c63c71Sbeck } 31008c63c71Sbeck } 31108c63c71Sbeck 31208c63c71Sbeck static void 31308c63c71Sbeck vector_inverse_ntt(vector *a) 31408c63c71Sbeck { 31508c63c71Sbeck int i; 31608c63c71Sbeck 31708c63c71Sbeck for (i = 0; i < RANK1024; i++) { 31808c63c71Sbeck scalar_inverse_ntt(&a->v[i]); 31908c63c71Sbeck } 32008c63c71Sbeck } 32108c63c71Sbeck 32208c63c71Sbeck static void 32308c63c71Sbeck scalar_add(scalar *lhs, const scalar *rhs) 32408c63c71Sbeck { 32508c63c71Sbeck int i; 32608c63c71Sbeck 32708c63c71Sbeck for (i = 0; i < DEGREE; i++) { 32808c63c71Sbeck lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]); 32908c63c71Sbeck } 33008c63c71Sbeck } 33108c63c71Sbeck 33208c63c71Sbeck static void 33308c63c71Sbeck scalar_sub(scalar *lhs, const scalar *rhs) 33408c63c71Sbeck { 33508c63c71Sbeck int i; 33608c63c71Sbeck 33708c63c71Sbeck for (i = 0; i < DEGREE; i++) { 33808c63c71Sbeck lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime); 33908c63c71Sbeck } 34008c63c71Sbeck } 34108c63c71Sbeck 34208c63c71Sbeck /* 34308c63c71Sbeck * Multiplying two scalars in the number theoretically transformed state. Since 34408c63c71Sbeck * 3329 does not have a 512th root of unity, this means we have to interpret 34508c63c71Sbeck * the 2*ith and (2*i+1)th entries of the scalar as elements of GF(3329)[X]/(X^2 34608c63c71Sbeck * - 17^(2*bitreverse(i)+1)) The value of 17^(2*bitreverse(i)+1) mod 3329 is 34708c63c71Sbeck * stored in the precomputed |kModRoots| table. Note that our Barrett transform 34808c63c71Sbeck * only allows us to multipy two reduced numbers together, so we need some 34908c63c71Sbeck * intermediate reduction steps, even if an uint64_t could hold 3 multiplied 35008c63c71Sbeck * numbers. 35108c63c71Sbeck */ 35208c63c71Sbeck static void 35308c63c71Sbeck scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs) 35408c63c71Sbeck { 35508c63c71Sbeck int i; 35608c63c71Sbeck 35708c63c71Sbeck for (i = 0; i < DEGREE / 2; i++) { 35808c63c71Sbeck uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i]; 35908c63c71Sbeck uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] * 36008c63c71Sbeck rhs->c[2 * i + 1]; 36108c63c71Sbeck uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1]; 36208c63c71Sbeck uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i]; 36308c63c71Sbeck 36408c63c71Sbeck out->c[2 * i] = 36508c63c71Sbeck reduce(real_real + 36608c63c71Sbeck (uint32_t)reduce(img_img) * kModRoots[i]); 36708c63c71Sbeck out->c[2 * i + 1] = reduce(img_real + real_img); 36808c63c71Sbeck } 36908c63c71Sbeck } 37008c63c71Sbeck 37108c63c71Sbeck static void 37208c63c71Sbeck vector_add(vector *lhs, const vector *rhs) 37308c63c71Sbeck { 37408c63c71Sbeck int i; 37508c63c71Sbeck 37608c63c71Sbeck for (i = 0; i < RANK1024; i++) { 37708c63c71Sbeck scalar_add(&lhs->v[i], &rhs->v[i]); 37808c63c71Sbeck } 37908c63c71Sbeck } 38008c63c71Sbeck 38108c63c71Sbeck static void 38208c63c71Sbeck matrix_mult(vector *out, const matrix *m, const vector *a) 38308c63c71Sbeck { 38408c63c71Sbeck int i, j; 38508c63c71Sbeck 38608c63c71Sbeck vector_zero(out); 38708c63c71Sbeck for (i = 0; i < RANK1024; i++) { 38808c63c71Sbeck for (j = 0; j < RANK1024; j++) { 38908c63c71Sbeck scalar product; 39008c63c71Sbeck 39108c63c71Sbeck scalar_mult(&product, &m->v[i][j], &a->v[j]); 39208c63c71Sbeck scalar_add(&out->v[i], &product); 39308c63c71Sbeck } 39408c63c71Sbeck } 39508c63c71Sbeck } 39608c63c71Sbeck 39708c63c71Sbeck static void 39808c63c71Sbeck matrix_mult_transpose(vector *out, const matrix *m, 39908c63c71Sbeck const vector *a) 40008c63c71Sbeck { 40108c63c71Sbeck int i, j; 40208c63c71Sbeck 40308c63c71Sbeck vector_zero(out); 40408c63c71Sbeck for (i = 0; i < RANK1024; i++) { 40508c63c71Sbeck for (j = 0; j < RANK1024; j++) { 40608c63c71Sbeck scalar product; 40708c63c71Sbeck 40808c63c71Sbeck scalar_mult(&product, &m->v[j][i], &a->v[j]); 40908c63c71Sbeck scalar_add(&out->v[i], &product); 41008c63c71Sbeck } 41108c63c71Sbeck } 41208c63c71Sbeck } 41308c63c71Sbeck 41408c63c71Sbeck static void 41508c63c71Sbeck scalar_inner_product(scalar *out, const vector *lhs, 41608c63c71Sbeck const vector *rhs) 41708c63c71Sbeck { 41808c63c71Sbeck int i; 41908c63c71Sbeck scalar_zero(out); 42008c63c71Sbeck for (i = 0; i < RANK1024; i++) { 42108c63c71Sbeck scalar product; 42208c63c71Sbeck 42308c63c71Sbeck scalar_mult(&product, &lhs->v[i], &rhs->v[i]); 42408c63c71Sbeck scalar_add(out, &product); 42508c63c71Sbeck } 42608c63c71Sbeck } 42708c63c71Sbeck 42808c63c71Sbeck /* 42908c63c71Sbeck * Algorithm 6 of spec. Rejection samples a Keccak stream to get uniformly 43008c63c71Sbeck * distributed elements. This is used for matrix expansion and only operates on 43108c63c71Sbeck * public inputs. 43208c63c71Sbeck */ 43308c63c71Sbeck static void 43408c63c71Sbeck scalar_from_keccak_vartime(scalar *out, sha3_ctx *keccak_ctx) 43508c63c71Sbeck { 43608c63c71Sbeck int i, done = 0; 43708c63c71Sbeck 43808c63c71Sbeck while (done < DEGREE) { 43908c63c71Sbeck uint8_t block[168]; 44008c63c71Sbeck 44108c63c71Sbeck shake_out(keccak_ctx, block, sizeof(block)); 44208c63c71Sbeck for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) { 44308c63c71Sbeck uint16_t d1 = block[i] + 256 * (block[i + 1] % 16); 44408c63c71Sbeck uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2]; 44508c63c71Sbeck 44608c63c71Sbeck if (d1 < kPrime) { 44708c63c71Sbeck out->c[done++] = d1; 44808c63c71Sbeck } 44908c63c71Sbeck if (d2 < kPrime && done < DEGREE) { 45008c63c71Sbeck out->c[done++] = d2; 45108c63c71Sbeck } 45208c63c71Sbeck } 45308c63c71Sbeck } 45408c63c71Sbeck } 45508c63c71Sbeck 45608c63c71Sbeck /* 45708c63c71Sbeck * Algorithm 7 of the spec, with eta fixed to two and the PRF call 45808c63c71Sbeck * included. Creates binominally distributed elements by sampling 2*|eta| bits, 45908c63c71Sbeck * and setting the coefficient to the count of the first bits minus the count of 46008c63c71Sbeck * the second bits, resulting in a centered binomial distribution. Since eta is 46108c63c71Sbeck * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4, 46208c63c71Sbeck * and 0 with probability 3/8. 46308c63c71Sbeck */ 46408c63c71Sbeck static void 46508c63c71Sbeck scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out, 46608c63c71Sbeck const uint8_t input[33]) 46708c63c71Sbeck { 46808c63c71Sbeck uint8_t entropy[128]; 46908c63c71Sbeck int i; 47008c63c71Sbeck 47108c63c71Sbeck CTASSERT(sizeof(entropy) == 2 * /*kEta=*/ 2 * DEGREE / 8); 47208c63c71Sbeck prf(entropy, sizeof(entropy), input); 47308c63c71Sbeck 47408c63c71Sbeck for (i = 0; i < DEGREE; i += 2) { 47508c63c71Sbeck uint8_t byte = entropy[i / 2]; 4769ee6f1feSbeck uint16_t mask; 4779ee6f1feSbeck uint16_t value = (byte & 1) + ((byte >> 1) & 1); 47808c63c71Sbeck 47908c63c71Sbeck value -= ((byte >> 2) & 1) + ((byte >> 3) & 1); 480*67abc7a1Stb 4819ee6f1feSbeck /* 4829ee6f1feSbeck * Add |kPrime| if |value| underflowed. See |reduce_once| for a 4839ee6f1feSbeck * discussion on why the value barrier is omitted. While this 4849ee6f1feSbeck * could have been written reduce_once(value + kPrime), this is 4859ee6f1feSbeck * one extra addition and small range of |value| tempts some 4869ee6f1feSbeck * versions of Clang to emit a branch. 4879ee6f1feSbeck */ 4889ee6f1feSbeck mask = 0u - (value >> 15); 4899ee6f1feSbeck out->c[i] = ((value + kPrime) & mask) | (value & ~mask); 49008c63c71Sbeck 49108c63c71Sbeck byte >>= 4; 4929ee6f1feSbeck value = (byte & 1) + ((byte >> 1) & 1); 49308c63c71Sbeck value -= ((byte >> 2) & 1) + ((byte >> 3) & 1); 4949ee6f1feSbeck /* See above. */ 4959ee6f1feSbeck mask = 0u - (value >> 15); 4969ee6f1feSbeck out->c[i + 1] = ((value + kPrime) & mask) | (value & ~mask); 49708c63c71Sbeck } 49808c63c71Sbeck } 49908c63c71Sbeck 50008c63c71Sbeck /* 50108c63c71Sbeck * Generates a secret vector by using 50208c63c71Sbeck * |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed 50308c63c71Sbeck * appending and incrementing |counter| for entry of the vector. 50408c63c71Sbeck */ 50508c63c71Sbeck static void 50608c63c71Sbeck vector_generate_secret_eta_2(vector *out, uint8_t *counter, 50708c63c71Sbeck const uint8_t seed[32]) 50808c63c71Sbeck { 50908c63c71Sbeck uint8_t input[33]; 51008c63c71Sbeck int i; 51108c63c71Sbeck 51208c63c71Sbeck memcpy(input, seed, 32); 51308c63c71Sbeck for (i = 0; i < RANK1024; i++) { 51408c63c71Sbeck input[32] = (*counter)++; 51508c63c71Sbeck scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i], 51608c63c71Sbeck input); 51708c63c71Sbeck } 51808c63c71Sbeck } 51908c63c71Sbeck 52008c63c71Sbeck /* Expands the matrix of a seed for key generation and for encaps-CPA. */ 52108c63c71Sbeck static void 52208c63c71Sbeck matrix_expand(matrix *out, const uint8_t rho[32]) 52308c63c71Sbeck { 52408c63c71Sbeck uint8_t input[34]; 52508c63c71Sbeck int i, j; 52608c63c71Sbeck 52708c63c71Sbeck memcpy(input, rho, 32); 52808c63c71Sbeck for (i = 0; i < RANK1024; i++) { 52908c63c71Sbeck for (j = 0; j < RANK1024; j++) { 53008c63c71Sbeck sha3_ctx keccak_ctx; 53108c63c71Sbeck 53208c63c71Sbeck input[32] = i; 53308c63c71Sbeck input[33] = j; 53408c63c71Sbeck shake128_init(&keccak_ctx); 53508c63c71Sbeck shake_update(&keccak_ctx, input, sizeof(input)); 53608c63c71Sbeck shake_xof(&keccak_ctx); 53708c63c71Sbeck scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx); 53808c63c71Sbeck } 53908c63c71Sbeck } 54008c63c71Sbeck } 54108c63c71Sbeck 54208c63c71Sbeck static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f, 54308c63c71Sbeck 0x1f, 0x3f, 0x7f, 0xff}; 54408c63c71Sbeck 54508c63c71Sbeck static void 54608c63c71Sbeck scalar_encode(uint8_t *out, const scalar *s, int bits) 54708c63c71Sbeck { 54808c63c71Sbeck uint8_t out_byte = 0; 54908c63c71Sbeck int i, out_byte_bits = 0; 55008c63c71Sbeck 55108c63c71Sbeck assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1); 55208c63c71Sbeck for (i = 0; i < DEGREE; i++) { 55308c63c71Sbeck uint16_t element = s->c[i]; 55408c63c71Sbeck int element_bits_done = 0; 55508c63c71Sbeck 55608c63c71Sbeck while (element_bits_done < bits) { 55708c63c71Sbeck int chunk_bits = bits - element_bits_done; 55808c63c71Sbeck int out_bits_remaining = 8 - out_byte_bits; 55908c63c71Sbeck 56008c63c71Sbeck if (chunk_bits >= out_bits_remaining) { 56108c63c71Sbeck chunk_bits = out_bits_remaining; 56208c63c71Sbeck out_byte |= (element & 56308c63c71Sbeck kMasks[chunk_bits - 1]) << out_byte_bits; 56408c63c71Sbeck *out = out_byte; 56508c63c71Sbeck out++; 56608c63c71Sbeck out_byte_bits = 0; 56708c63c71Sbeck out_byte = 0; 56808c63c71Sbeck } else { 56908c63c71Sbeck out_byte |= (element & 57008c63c71Sbeck kMasks[chunk_bits - 1]) << out_byte_bits; 57108c63c71Sbeck out_byte_bits += chunk_bits; 57208c63c71Sbeck } 57308c63c71Sbeck 57408c63c71Sbeck element_bits_done += chunk_bits; 57508c63c71Sbeck element >>= chunk_bits; 57608c63c71Sbeck } 57708c63c71Sbeck } 57808c63c71Sbeck 57908c63c71Sbeck if (out_byte_bits > 0) { 58008c63c71Sbeck *out = out_byte; 58108c63c71Sbeck } 58208c63c71Sbeck } 58308c63c71Sbeck 58408c63c71Sbeck /* scalar_encode_1 is |scalar_encode| specialised for |bits| == 1. */ 58508c63c71Sbeck static void 58608c63c71Sbeck scalar_encode_1(uint8_t out[32], const scalar *s) 58708c63c71Sbeck { 58808c63c71Sbeck int i, j; 58908c63c71Sbeck 59008c63c71Sbeck for (i = 0; i < DEGREE; i += 8) { 59108c63c71Sbeck uint8_t out_byte = 0; 59208c63c71Sbeck 59308c63c71Sbeck for (j = 0; j < 8; j++) { 59408c63c71Sbeck out_byte |= (s->c[i + j] & 1) << j; 59508c63c71Sbeck } 59608c63c71Sbeck *out = out_byte; 59708c63c71Sbeck out++; 59808c63c71Sbeck } 59908c63c71Sbeck } 60008c63c71Sbeck 60108c63c71Sbeck /* 60208c63c71Sbeck * Encodes an entire vector into 32*|RANK1024|*|bits| bytes. Note that since 256 60308c63c71Sbeck * (DEGREE) is divisible by 8, the individual vector entries will always fill a 60408c63c71Sbeck * whole number of bytes, so we do not need to worry about bit packing here. 60508c63c71Sbeck */ 60608c63c71Sbeck static void 60708c63c71Sbeck vector_encode(uint8_t *out, const vector *a, int bits) 60808c63c71Sbeck { 60908c63c71Sbeck int i; 61008c63c71Sbeck 61108c63c71Sbeck for (i = 0; i < RANK1024; i++) { 61208c63c71Sbeck scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits); 61308c63c71Sbeck } 61408c63c71Sbeck } 61508c63c71Sbeck 61608c63c71Sbeck /* 61708c63c71Sbeck * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in 61808c63c71Sbeck * |out|. It returns one on success and zero if any parsed value is >= 61908c63c71Sbeck * |kPrime|. 62008c63c71Sbeck */ 62108c63c71Sbeck static int 62208c63c71Sbeck scalar_decode(scalar *out, const uint8_t *in, int bits) 62308c63c71Sbeck { 62408c63c71Sbeck uint8_t in_byte = 0; 62508c63c71Sbeck int i, in_byte_bits_left = 0; 62608c63c71Sbeck 62708c63c71Sbeck assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1); 62808c63c71Sbeck 62908c63c71Sbeck for (i = 0; i < DEGREE; i++) { 63008c63c71Sbeck uint16_t element = 0; 63108c63c71Sbeck int element_bits_done = 0; 63208c63c71Sbeck 63308c63c71Sbeck while (element_bits_done < bits) { 63408c63c71Sbeck int chunk_bits = bits - element_bits_done; 63508c63c71Sbeck 63608c63c71Sbeck if (in_byte_bits_left == 0) { 63708c63c71Sbeck in_byte = *in; 63808c63c71Sbeck in++; 63908c63c71Sbeck in_byte_bits_left = 8; 64008c63c71Sbeck } 64108c63c71Sbeck 64208c63c71Sbeck if (chunk_bits > in_byte_bits_left) { 64308c63c71Sbeck chunk_bits = in_byte_bits_left; 64408c63c71Sbeck } 64508c63c71Sbeck 64608c63c71Sbeck element |= (in_byte & kMasks[chunk_bits - 1]) << 64708c63c71Sbeck element_bits_done; 64808c63c71Sbeck in_byte_bits_left -= chunk_bits; 64908c63c71Sbeck in_byte >>= chunk_bits; 65008c63c71Sbeck 65108c63c71Sbeck element_bits_done += chunk_bits; 65208c63c71Sbeck } 65308c63c71Sbeck 65408c63c71Sbeck if (element >= kPrime) { 65508c63c71Sbeck return 0; 65608c63c71Sbeck } 65708c63c71Sbeck out->c[i] = element; 65808c63c71Sbeck } 65908c63c71Sbeck 66008c63c71Sbeck return 1; 66108c63c71Sbeck } 66208c63c71Sbeck 66308c63c71Sbeck /* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */ 66408c63c71Sbeck static void 66508c63c71Sbeck scalar_decode_1(scalar *out, const uint8_t in[32]) 66608c63c71Sbeck { 66708c63c71Sbeck int i, j; 66808c63c71Sbeck 66908c63c71Sbeck for (i = 0; i < DEGREE; i += 8) { 67008c63c71Sbeck uint8_t in_byte = *in; 67108c63c71Sbeck 67208c63c71Sbeck in++; 67308c63c71Sbeck for (j = 0; j < 8; j++) { 67408c63c71Sbeck out->c[i + j] = in_byte & 1; 67508c63c71Sbeck in_byte >>= 1; 67608c63c71Sbeck } 67708c63c71Sbeck } 67808c63c71Sbeck } 67908c63c71Sbeck 68008c63c71Sbeck /* 68108c63c71Sbeck * Decodes 32*|RANK1024|*|bits| bytes from |in| into |out|. It returns one on 68208c63c71Sbeck * success or zero if any parsed value is >= |kPrime|. 68308c63c71Sbeck */ 68408c63c71Sbeck static int 68508c63c71Sbeck vector_decode(vector *out, const uint8_t *in, int bits) 68608c63c71Sbeck { 68708c63c71Sbeck int i; 68808c63c71Sbeck 68908c63c71Sbeck for (i = 0; i < RANK1024; i++) { 69008c63c71Sbeck if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8, 69108c63c71Sbeck bits)) { 69208c63c71Sbeck return 0; 69308c63c71Sbeck } 69408c63c71Sbeck } 69508c63c71Sbeck return 1; 69608c63c71Sbeck } 69708c63c71Sbeck 69808c63c71Sbeck /* 69908c63c71Sbeck * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping 70008c63c71Sbeck * numbers close to each other together. The formula used is 70108c63c71Sbeck * round(2^|bits|/kPrime*x) mod 2^|bits|. 70208c63c71Sbeck * Uses Barrett reduction to achieve constant time. Since we need both the 70308c63c71Sbeck * remainder (for rounding) and the quotient (as the result), we cannot use 70408c63c71Sbeck * |reduce| here, but need to do the Barrett reduction directly. 70508c63c71Sbeck */ 70608c63c71Sbeck static uint16_t 70708c63c71Sbeck compress(uint16_t x, int bits) 70808c63c71Sbeck { 70908c63c71Sbeck uint32_t shifted = (uint32_t)x << bits; 71008c63c71Sbeck uint64_t product = (uint64_t)shifted * kBarrettMultiplier; 71108c63c71Sbeck uint32_t quotient = (uint32_t)(product >> kBarrettShift); 71208c63c71Sbeck uint32_t remainder = shifted - quotient * kPrime; 71308c63c71Sbeck 71408c63c71Sbeck /* 71508c63c71Sbeck * Adjust the quotient to round correctly: 71608c63c71Sbeck * 0 <= remainder <= kHalfPrime round to 0 71708c63c71Sbeck * kHalfPrime < remainder <= kPrime + kHalfPrime round to 1 71808c63c71Sbeck * kPrime + kHalfPrime < remainder < 2 * kPrime round to 2 71908c63c71Sbeck */ 72008c63c71Sbeck assert(remainder < 2u * kPrime); 72108c63c71Sbeck quotient += 1 & constant_time_lt(kHalfPrime, remainder); 72208c63c71Sbeck quotient += 1 & constant_time_lt(kPrime + kHalfPrime, remainder); 72308c63c71Sbeck return quotient & ((1 << bits) - 1); 72408c63c71Sbeck } 72508c63c71Sbeck 72608c63c71Sbeck /* 72708c63c71Sbeck * Decompresses |x| by using an equi-distant representative. The formula is 72808c63c71Sbeck * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to 72908c63c71Sbeck * implement this logic using only bit operations. 73008c63c71Sbeck */ 73108c63c71Sbeck static uint16_t 73208c63c71Sbeck decompress(uint16_t x, int bits) 73308c63c71Sbeck { 73408c63c71Sbeck uint32_t product = (uint32_t)x * kPrime; 73508c63c71Sbeck uint32_t power = 1 << bits; 73608c63c71Sbeck /* This is |product| % power, since |power| is a power of 2. */ 73708c63c71Sbeck uint32_t remainder = product & (power - 1); 73808c63c71Sbeck /* This is |product| / power, since |power| is a power of 2. */ 73908c63c71Sbeck uint32_t lower = product >> bits; 74008c63c71Sbeck 74108c63c71Sbeck /* 74208c63c71Sbeck * The rounding logic works since the first half of numbers mod |power| have a 74308c63c71Sbeck * 0 as first bit, and the second half has a 1 as first bit, since |power| is 74408c63c71Sbeck * a power of 2. As a 12 bit number, |remainder| is always positive, so we 74508c63c71Sbeck * will shift in 0s for a right shift. 74608c63c71Sbeck */ 74708c63c71Sbeck return lower + (remainder >> (bits - 1)); 74808c63c71Sbeck } 74908c63c71Sbeck 75008c63c71Sbeck static void 75108c63c71Sbeck scalar_compress(scalar *s, int bits) 75208c63c71Sbeck { 75308c63c71Sbeck int i; 75408c63c71Sbeck 75508c63c71Sbeck for (i = 0; i < DEGREE; i++) { 75608c63c71Sbeck s->c[i] = compress(s->c[i], bits); 75708c63c71Sbeck } 75808c63c71Sbeck } 75908c63c71Sbeck 76008c63c71Sbeck static void 76108c63c71Sbeck scalar_decompress(scalar *s, int bits) 76208c63c71Sbeck { 76308c63c71Sbeck int i; 76408c63c71Sbeck 76508c63c71Sbeck for (i = 0; i < DEGREE; i++) { 76608c63c71Sbeck s->c[i] = decompress(s->c[i], bits); 76708c63c71Sbeck } 76808c63c71Sbeck } 76908c63c71Sbeck 77008c63c71Sbeck static void 77108c63c71Sbeck vector_compress(vector *a, int bits) 77208c63c71Sbeck { 77308c63c71Sbeck int i; 77408c63c71Sbeck 77508c63c71Sbeck for (i = 0; i < RANK1024; i++) { 77608c63c71Sbeck scalar_compress(&a->v[i], bits); 77708c63c71Sbeck } 77808c63c71Sbeck } 77908c63c71Sbeck 78008c63c71Sbeck static void 78108c63c71Sbeck vector_decompress(vector *a, int bits) 78208c63c71Sbeck { 78308c63c71Sbeck int i; 78408c63c71Sbeck 78508c63c71Sbeck for (i = 0; i < RANK1024; i++) { 78608c63c71Sbeck scalar_decompress(&a->v[i], bits); 78708c63c71Sbeck } 78808c63c71Sbeck } 78908c63c71Sbeck 79008c63c71Sbeck struct public_key { 79108c63c71Sbeck vector t; 79208c63c71Sbeck uint8_t rho[32]; 79308c63c71Sbeck uint8_t public_key_hash[32]; 79408c63c71Sbeck matrix m; 79508c63c71Sbeck }; 79608c63c71Sbeck 79708c63c71Sbeck static struct public_key * 79808c63c71Sbeck public_key_1024_from_external(const struct MLKEM1024_public_key *external) 79908c63c71Sbeck { 80008c63c71Sbeck return (struct public_key *)external; 80108c63c71Sbeck } 80208c63c71Sbeck 80308c63c71Sbeck struct private_key { 80408c63c71Sbeck struct public_key pub; 80508c63c71Sbeck vector s; 80608c63c71Sbeck uint8_t fo_failure_secret[32]; 80708c63c71Sbeck }; 80808c63c71Sbeck 80908c63c71Sbeck static struct private_key * 81008c63c71Sbeck private_key_1024_from_external(const struct MLKEM1024_private_key *external) 81108c63c71Sbeck { 81208c63c71Sbeck return (struct private_key *)external; 81308c63c71Sbeck } 81408c63c71Sbeck 81508c63c71Sbeck /* 81608c63c71Sbeck * Calls |MLKEM1024_generate_key_external_entropy| with random bytes from 81708c63c71Sbeck * |RAND_bytes|. 81808c63c71Sbeck */ 81908c63c71Sbeck void 82008c63c71Sbeck MLKEM1024_generate_key(uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES], 82108c63c71Sbeck uint8_t optional_out_seed[MLKEM_SEED_BYTES], 82208c63c71Sbeck struct MLKEM1024_private_key *out_private_key) 82308c63c71Sbeck { 82408c63c71Sbeck uint8_t entropy_buf[MLKEM_SEED_BYTES]; 82508c63c71Sbeck uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed : 82608c63c71Sbeck entropy_buf; 82708c63c71Sbeck 82808c63c71Sbeck arc4random_buf(entropy, MLKEM_SEED_BYTES); 82908c63c71Sbeck MLKEM1024_generate_key_external_entropy(out_encoded_public_key, 83008c63c71Sbeck out_private_key, entropy); 83108c63c71Sbeck } 83208c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_generate_key); 83308c63c71Sbeck 83408c63c71Sbeck int 83508c63c71Sbeck MLKEM1024_private_key_from_seed(struct MLKEM1024_private_key *out_private_key, 83608c63c71Sbeck const uint8_t *seed, size_t seed_len) 83708c63c71Sbeck { 83808c63c71Sbeck uint8_t public_key_bytes[MLKEM1024_PUBLIC_KEY_BYTES]; 83908c63c71Sbeck 84008c63c71Sbeck if (seed_len != MLKEM_SEED_BYTES) { 84108c63c71Sbeck return 0; 84208c63c71Sbeck } 84308c63c71Sbeck MLKEM1024_generate_key_external_entropy(public_key_bytes, 84408c63c71Sbeck out_private_key, seed); 84508c63c71Sbeck 84608c63c71Sbeck return 1; 84708c63c71Sbeck } 84808c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_private_key_from_seed); 84908c63c71Sbeck 85008c63c71Sbeck static int 85108c63c71Sbeck mlkem_marshal_public_key(CBB *out, const struct public_key *pub) 85208c63c71Sbeck { 85308c63c71Sbeck uint8_t *vector_output; 85408c63c71Sbeck 85508c63c71Sbeck if (!CBB_add_space(out, &vector_output, kEncodedVectorSize)) { 85608c63c71Sbeck return 0; 85708c63c71Sbeck } 85808c63c71Sbeck vector_encode(vector_output, &pub->t, kLog2Prime); 85908c63c71Sbeck if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) { 86008c63c71Sbeck return 0; 86108c63c71Sbeck } 86208c63c71Sbeck return 1; 86308c63c71Sbeck } 86408c63c71Sbeck 86508c63c71Sbeck void 86608c63c71Sbeck MLKEM1024_generate_key_external_entropy( 86708c63c71Sbeck uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES], 86808c63c71Sbeck struct MLKEM1024_private_key *out_private_key, 86908c63c71Sbeck const uint8_t entropy[MLKEM_SEED_BYTES]) 87008c63c71Sbeck { 87108c63c71Sbeck struct private_key *priv = private_key_1024_from_external( 87208c63c71Sbeck out_private_key); 87308c63c71Sbeck uint8_t augmented_seed[33]; 87408c63c71Sbeck uint8_t *rho, *sigma; 87508c63c71Sbeck uint8_t counter = 0; 87608c63c71Sbeck uint8_t hashed[64]; 87708c63c71Sbeck vector error; 87808c63c71Sbeck CBB cbb; 87908c63c71Sbeck 88008c63c71Sbeck memcpy(augmented_seed, entropy, 32); 88108c63c71Sbeck augmented_seed[32] = RANK1024; 88208c63c71Sbeck hash_g(hashed, augmented_seed, 33); 88308c63c71Sbeck rho = hashed; 88408c63c71Sbeck sigma = hashed + 32; 88508c63c71Sbeck memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho)); 88608c63c71Sbeck matrix_expand(&priv->pub.m, rho); 88708c63c71Sbeck vector_generate_secret_eta_2(&priv->s, &counter, sigma); 88808c63c71Sbeck vector_ntt(&priv->s); 88908c63c71Sbeck vector_generate_secret_eta_2(&error, &counter, sigma); 89008c63c71Sbeck vector_ntt(&error); 89108c63c71Sbeck matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s); 89208c63c71Sbeck vector_add(&priv->pub.t, &error); 89308c63c71Sbeck 894516824a3Stb /* XXX - error checking. */ 89508c63c71Sbeck CBB_init_fixed(&cbb, out_encoded_public_key, MLKEM1024_PUBLIC_KEY_BYTES); 89608c63c71Sbeck if (!mlkem_marshal_public_key(&cbb, &priv->pub)) { 89708c63c71Sbeck abort(); 89808c63c71Sbeck } 899516824a3Stb CBB_cleanup(&cbb); 90008c63c71Sbeck 90108c63c71Sbeck hash_h(priv->pub.public_key_hash, out_encoded_public_key, 90208c63c71Sbeck MLKEM1024_PUBLIC_KEY_BYTES); 90308c63c71Sbeck memcpy(priv->fo_failure_secret, entropy + 32, 32); 90408c63c71Sbeck } 90508c63c71Sbeck 90608c63c71Sbeck void 90708c63c71Sbeck MLKEM1024_public_from_private(struct MLKEM1024_public_key *out_public_key, 90808c63c71Sbeck const struct MLKEM1024_private_key *private_key) 90908c63c71Sbeck { 91008c63c71Sbeck struct public_key *const pub = public_key_1024_from_external( 91108c63c71Sbeck out_public_key); 91208c63c71Sbeck const struct private_key *const priv = private_key_1024_from_external( 91308c63c71Sbeck private_key); 91408c63c71Sbeck 91508c63c71Sbeck *pub = priv->pub; 91608c63c71Sbeck } 91708c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_public_from_private); 91808c63c71Sbeck 91908c63c71Sbeck /* 92008c63c71Sbeck * Encrypts a message with given randomness to the ciphertext in |out|. Without 92108c63c71Sbeck * applying the Fujisaki-Okamoto transform this would not result in a CCA secure 92208c63c71Sbeck * scheme, since lattice schemes are vulnerable to decryption failure oracles. 92308c63c71Sbeck */ 92408c63c71Sbeck static void 92508c63c71Sbeck encrypt_cpa(uint8_t out[MLKEM1024_CIPHERTEXT_BYTES], 92608c63c71Sbeck const struct public_key *pub, const uint8_t message[32], 92708c63c71Sbeck const uint8_t randomness[32]) 92808c63c71Sbeck { 92908c63c71Sbeck scalar expanded_message, scalar_error; 93008c63c71Sbeck vector secret, error, u; 93108c63c71Sbeck uint8_t counter = 0; 93208c63c71Sbeck uint8_t input[33]; 93308c63c71Sbeck scalar v; 93408c63c71Sbeck 93508c63c71Sbeck vector_generate_secret_eta_2(&secret, &counter, randomness); 93608c63c71Sbeck vector_ntt(&secret); 93708c63c71Sbeck vector_generate_secret_eta_2(&error, &counter, randomness); 93808c63c71Sbeck memcpy(input, randomness, 32); 93908c63c71Sbeck input[32] = counter; 94008c63c71Sbeck scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error, 94108c63c71Sbeck input); 94208c63c71Sbeck matrix_mult(&u, &pub->m, &secret); 94308c63c71Sbeck vector_inverse_ntt(&u); 94408c63c71Sbeck vector_add(&u, &error); 94508c63c71Sbeck scalar_inner_product(&v, &pub->t, &secret); 94608c63c71Sbeck scalar_inverse_ntt(&v); 94708c63c71Sbeck scalar_add(&v, &scalar_error); 94808c63c71Sbeck scalar_decode_1(&expanded_message, message); 94908c63c71Sbeck scalar_decompress(&expanded_message, 1); 95008c63c71Sbeck scalar_add(&v, &expanded_message); 95108c63c71Sbeck vector_compress(&u, kDU1024); 95208c63c71Sbeck vector_encode(out, &u, kDU1024); 95308c63c71Sbeck scalar_compress(&v, kDV1024); 95408c63c71Sbeck scalar_encode(out + kCompressedVectorSize, &v, kDV1024); 95508c63c71Sbeck } 95608c63c71Sbeck 95708c63c71Sbeck /* Calls MLKEM1024_encap_external_entropy| with random bytes */ 95808c63c71Sbeck void 95908c63c71Sbeck MLKEM1024_encap(uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES], 96008c63c71Sbeck uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES], 96108c63c71Sbeck const struct MLKEM1024_public_key *public_key) 96208c63c71Sbeck { 96308c63c71Sbeck uint8_t entropy[MLKEM_ENCAP_ENTROPY]; 96408c63c71Sbeck 96508c63c71Sbeck arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY); 96608c63c71Sbeck MLKEM1024_encap_external_entropy(out_ciphertext, out_shared_secret, 96708c63c71Sbeck public_key, entropy); 96808c63c71Sbeck } 96908c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_encap); 97008c63c71Sbeck 97108c63c71Sbeck /* See section 6.2 of the spec. */ 97208c63c71Sbeck void 97308c63c71Sbeck MLKEM1024_encap_external_entropy( 97408c63c71Sbeck uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES], 97508c63c71Sbeck uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES], 97608c63c71Sbeck const struct MLKEM1024_public_key *public_key, 97708c63c71Sbeck const uint8_t entropy[MLKEM_ENCAP_ENTROPY]) 97808c63c71Sbeck { 97908c63c71Sbeck const struct public_key *pub = public_key_1024_from_external(public_key); 98008c63c71Sbeck uint8_t key_and_randomness[64]; 98108c63c71Sbeck uint8_t input[64]; 98208c63c71Sbeck 98308c63c71Sbeck memcpy(input, entropy, MLKEM_ENCAP_ENTROPY); 98408c63c71Sbeck memcpy(input + MLKEM_ENCAP_ENTROPY, pub->public_key_hash, 98508c63c71Sbeck sizeof(input) - MLKEM_ENCAP_ENTROPY); 98608c63c71Sbeck hash_g(key_and_randomness, input, sizeof(input)); 98708c63c71Sbeck encrypt_cpa(out_ciphertext, pub, entropy, key_and_randomness + 32); 98808c63c71Sbeck memcpy(out_shared_secret, key_and_randomness, 32); 98908c63c71Sbeck } 99008c63c71Sbeck 99108c63c71Sbeck static void 99208c63c71Sbeck decrypt_cpa(uint8_t out[32], const struct private_key *priv, 99308c63c71Sbeck const uint8_t ciphertext[MLKEM1024_CIPHERTEXT_BYTES]) 99408c63c71Sbeck { 99508c63c71Sbeck scalar mask, v; 99608c63c71Sbeck vector u; 99708c63c71Sbeck 99808c63c71Sbeck vector_decode(&u, ciphertext, kDU1024); 99908c63c71Sbeck vector_decompress(&u, kDU1024); 100008c63c71Sbeck vector_ntt(&u); 100108c63c71Sbeck scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV1024); 100208c63c71Sbeck scalar_decompress(&v, kDV1024); 100308c63c71Sbeck scalar_inner_product(&mask, &priv->s, &u); 100408c63c71Sbeck scalar_inverse_ntt(&mask); 100508c63c71Sbeck scalar_sub(&v, &mask); 100608c63c71Sbeck scalar_compress(&v, 1); 100708c63c71Sbeck scalar_encode_1(out, &v); 100808c63c71Sbeck } 100908c63c71Sbeck 101008c63c71Sbeck /* See section 6.3 */ 101108c63c71Sbeck int 101208c63c71Sbeck MLKEM1024_decap(uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES], 101308c63c71Sbeck const uint8_t *ciphertext, size_t ciphertext_len, 101408c63c71Sbeck const struct MLKEM1024_private_key *private_key) 101508c63c71Sbeck { 101608c63c71Sbeck const struct private_key *priv = private_key_1024_from_external( 101708c63c71Sbeck private_key); 101808c63c71Sbeck uint8_t expected_ciphertext[MLKEM1024_CIPHERTEXT_BYTES]; 101908c63c71Sbeck uint8_t key_and_randomness[64]; 102008c63c71Sbeck uint8_t failure_key[32]; 102108c63c71Sbeck uint8_t decrypted[64]; 102208c63c71Sbeck uint8_t mask; 102308c63c71Sbeck int i; 102408c63c71Sbeck 102508c63c71Sbeck if (ciphertext_len != MLKEM1024_CIPHERTEXT_BYTES) { 102608c63c71Sbeck arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_BYTES); 102708c63c71Sbeck return 0; 102808c63c71Sbeck } 102908c63c71Sbeck 103008c63c71Sbeck decrypt_cpa(decrypted, priv, ciphertext); 103108c63c71Sbeck memcpy(decrypted + 32, priv->pub.public_key_hash, 103208c63c71Sbeck sizeof(decrypted) - 32); 103308c63c71Sbeck hash_g(key_and_randomness, decrypted, sizeof(decrypted)); 103408c63c71Sbeck encrypt_cpa(expected_ciphertext, &priv->pub, decrypted, 103508c63c71Sbeck key_and_randomness + 32); 103608c63c71Sbeck kdf(failure_key, priv->fo_failure_secret, ciphertext, ciphertext_len); 103708c63c71Sbeck mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext, 103808c63c71Sbeck sizeof(expected_ciphertext)), 0); 103908c63c71Sbeck for (i = 0; i < MLKEM_SHARED_SECRET_BYTES; i++) { 104008c63c71Sbeck out_shared_secret[i] = constant_time_select_8(mask, 104108c63c71Sbeck key_and_randomness[i], failure_key[i]); 104208c63c71Sbeck } 104308c63c71Sbeck 104408c63c71Sbeck return 1; 104508c63c71Sbeck } 104608c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_decap); 104708c63c71Sbeck 104808c63c71Sbeck int 104908c63c71Sbeck MLKEM1024_marshal_public_key(CBB *out, 105008c63c71Sbeck const struct MLKEM1024_public_key *public_key) 105108c63c71Sbeck { 105208c63c71Sbeck return mlkem_marshal_public_key(out, 105308c63c71Sbeck public_key_1024_from_external(public_key)); 105408c63c71Sbeck } 105508c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_marshal_public_key); 105608c63c71Sbeck 105708c63c71Sbeck /* 105808c63c71Sbeck * mlkem_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate 105908c63c71Sbeck * the value of |pub->public_key_hash|. 106008c63c71Sbeck */ 106108c63c71Sbeck static int 106208c63c71Sbeck mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in) 106308c63c71Sbeck { 106408c63c71Sbeck CBS t_bytes; 106508c63c71Sbeck 106608c63c71Sbeck if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) || 106708c63c71Sbeck !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime)) { 106808c63c71Sbeck return 0; 106908c63c71Sbeck } 107008c63c71Sbeck memcpy(pub->rho, CBS_data(in), sizeof(pub->rho)); 107108c63c71Sbeck if (!CBS_skip(in, sizeof(pub->rho))) 107208c63c71Sbeck return 0; 107308c63c71Sbeck matrix_expand(&pub->m, pub->rho); 107408c63c71Sbeck return 1; 107508c63c71Sbeck } 107608c63c71Sbeck 107708c63c71Sbeck int 107808c63c71Sbeck MLKEM1024_parse_public_key(struct MLKEM1024_public_key *public_key, CBS *in) 107908c63c71Sbeck { 108008c63c71Sbeck struct public_key *pub = public_key_1024_from_external(public_key); 108108c63c71Sbeck CBS orig_in = *in; 108208c63c71Sbeck 108308c63c71Sbeck if (!mlkem_parse_public_key_no_hash(pub, in) || 108408c63c71Sbeck CBS_len(in) != 0) { 108508c63c71Sbeck return 0; 108608c63c71Sbeck } 108708c63c71Sbeck hash_h(pub->public_key_hash, CBS_data(&orig_in), CBS_len(&orig_in)); 108808c63c71Sbeck return 1; 108908c63c71Sbeck } 109008c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_parse_public_key); 109108c63c71Sbeck 109208c63c71Sbeck int 109308c63c71Sbeck MLKEM1024_marshal_private_key(CBB *out, 109408c63c71Sbeck const struct MLKEM1024_private_key *private_key) 109508c63c71Sbeck { 109608c63c71Sbeck const struct private_key *const priv = private_key_1024_from_external( 109708c63c71Sbeck private_key); 109808c63c71Sbeck uint8_t *s_output; 109908c63c71Sbeck 110008c63c71Sbeck if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) { 110108c63c71Sbeck return 0; 110208c63c71Sbeck } 110308c63c71Sbeck vector_encode(s_output, &priv->s, kLog2Prime); 110408c63c71Sbeck if (!mlkem_marshal_public_key(out, &priv->pub) || 110508c63c71Sbeck !CBB_add_bytes(out, priv->pub.public_key_hash, 110608c63c71Sbeck sizeof(priv->pub.public_key_hash)) || 110708c63c71Sbeck !CBB_add_bytes(out, priv->fo_failure_secret, 110808c63c71Sbeck sizeof(priv->fo_failure_secret))) { 110908c63c71Sbeck return 0; 111008c63c71Sbeck } 111108c63c71Sbeck return 1; 111208c63c71Sbeck } 111308c63c71Sbeck 111408c63c71Sbeck int 111508c63c71Sbeck MLKEM1024_parse_private_key(struct MLKEM1024_private_key *out_private_key, 111608c63c71Sbeck CBS *in) 111708c63c71Sbeck { 111808c63c71Sbeck struct private_key *const priv = private_key_1024_from_external( 111908c63c71Sbeck out_private_key); 112008c63c71Sbeck CBS s_bytes; 112108c63c71Sbeck 112208c63c71Sbeck if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) || 112308c63c71Sbeck !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) || 112408c63c71Sbeck !mlkem_parse_public_key_no_hash(&priv->pub, in)) { 112508c63c71Sbeck return 0; 112608c63c71Sbeck } 112708c63c71Sbeck memcpy(priv->pub.public_key_hash, CBS_data(in), 112808c63c71Sbeck sizeof(priv->pub.public_key_hash)); 112908c63c71Sbeck if (!CBS_skip(in, sizeof(priv->pub.public_key_hash))) 113008c63c71Sbeck return 0; 113108c63c71Sbeck memcpy(priv->fo_failure_secret, CBS_data(in), 113208c63c71Sbeck sizeof(priv->fo_failure_secret)); 113308c63c71Sbeck if (!CBS_skip(in, sizeof(priv->fo_failure_secret))) 113408c63c71Sbeck return 0; 113508c63c71Sbeck if (CBS_len(in) != 0) 113608c63c71Sbeck return 0; 113708c63c71Sbeck 113808c63c71Sbeck return 1; 113908c63c71Sbeck } 114008c63c71Sbeck LCRYPTO_ALIAS(MLKEM1024_parse_private_key); 1141