1 /* $OpenBSD: mlkem1024.c,v 1.6 2025/01/03 08:19:24 tb Exp $ */
2 /*
3 * Copyright (c) 2024, Google Inc.
4 * Copyright (c) 2024, Bob Beck <beck@obtuse.com>
5 *
6 * Permission to use, copy, modify, and/or distribute this software for any
7 * purpose with or without fee is hereby granted, provided that the above
8 * copyright notice and this permission notice appear in all copies.
9 *
10 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
13 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
15 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
16 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17 */
18
19 #include <assert.h>
20 #include <stdlib.h>
21 #include <string.h>
22
23 #include "bytestring.h"
24 #include "mlkem.h"
25
26 #include "sha3_internal.h"
27 #include "mlkem_internal.h"
28 #include "constant_time.h"
29 #include "crypto_internal.h"
30
31 /* Remove later */
32 #undef LCRYPTO_ALIAS
33 #define LCRYPTO_ALIAS(A)
34
35 /*
36 * See
37 * https://csrc.nist.gov/pubs/fips/203/final
38 */
39
/*
 * prf computes SHAKE256(in, out_len), the PRF of FIPS 203 section 4.1.
 * |in| is a 32-byte seed followed by one counter/domain-separation byte.
 */
static void
prf(uint8_t *out, size_t out_len, const uint8_t in[33])
{
	sha3_ctx ctx;
	shake256_init(&ctx);
	shake_update(&ctx, in, 33);
	shake_xof(&ctx);
	shake_out(&ctx, out, out_len);
}
49
/* hash_h computes H(in) = SHA3-256(in); section 4.1 of the spec. */
static void
hash_h(uint8_t out[32], const uint8_t *in, size_t len)
{
	sha3_ctx ctx;
	sha3_init(&ctx, 32);
	sha3_update(&ctx, in, len);
	sha3_final(out, &ctx);
}
59
/* hash_g computes G(in) = SHA3-512(in); section 4.1 of the spec. */
static void
hash_g(uint8_t out[64], const uint8_t *in, size_t len)
{
	sha3_ctx ctx;
	sha3_init(&ctx, 64);
	sha3_update(&ctx, in, len);
	sha3_final(out, &ctx);
}
68
/*
 * kdf computes SHAKE256(failure_secret || in); this is called 'J' in the
 * spec. It derives the shared secret returned on decapsulation failure
 * (implicit rejection), keyed by the private |failure_secret|.
 */
static void
kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES], const uint8_t failure_secret[32],
    const uint8_t *in, size_t len)
{
	sha3_ctx ctx;
	shake256_init(&ctx);
	shake_update(&ctx, failure_secret, 32);
	shake_update(&ctx, in, len);
	shake_xof(&ctx);
	shake_out(&ctx, out, MLKEM_SHARED_SECRET_BYTES);
}
81
/* DEGREE is n, the number of coefficients per polynomial (FIPS 203). */
#define DEGREE 256
/* RANK1024 is k = 4, the module rank for ML-KEM-1024. */
#define RANK1024 4

/* kBarrettMultiplier = floor(2^kBarrettShift / kPrime); see |reduce|. */
static const size_t kBarrettMultiplier = 5039;
static const unsigned kBarrettShift = 24;
/* kPrime is the ML-KEM modulus q = 3329. */
static const uint16_t kPrime = 3329;
static const int kLog2Prime = 12;
static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
/* Compression bit widths d_u and d_v for ML-KEM-1024. */
static const int kDU1024 = 11;
static const int kDV1024 = 5;

/*
 * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
 * root of unity.
 */
static const uint16_t kInverseDegree = 3303;
/* 12 bits per coefficient, DEGREE coefficients per scalar, RANK1024 scalars. */
static const size_t kEncodedVectorSize =
    (/*kLog2Prime=*/12 * DEGREE / 8) * RANK1024;
/* Size of the compressed, serialised vector u in a ciphertext. */
static const size_t kCompressedVectorSize = /*kDU1024=*/ 11 * RANK1024 * DEGREE /
    8;

/* A polynomial in R_q = Z_q[X]/(X^256 + 1), or its NTT representation. */
typedef struct scalar {
	/* On every function entry and exit, 0 <= c < kPrime. */
	uint16_t c[DEGREE];
} scalar;

typedef struct vector {
	scalar v[RANK1024];
} vector;

typedef struct matrix {
	scalar v[RANK1024][RANK1024];
} matrix;
115
/*
 * This bit of Python will be referenced in some of the following comments
 * (7-bit bit-reversal, since the tables have 128 entries):
 *
 * p = 3329
 *
 * def bitreverse(i):
 *     ret = 0
 *     for n in range(7):
 *         bit = i & 1
 *         ret <<= 1
 *         ret |= bit
 *         i >>= 1
 *     return ret
 */

/* kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)] */
static const uint16_t kNTTRoots[128] = {
	1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797,
	2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
	1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756,
	1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915,
	2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
	2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100,
	1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789,
	1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641,
	1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757,
	2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
	1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
};

/* kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */
static const uint16_t kInverseNTTRoots[128] = {
	1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543,
	2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903,
	1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855,
	2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010,
	1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132,
	1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607,
	2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230,
	2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745,
	2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482,
	1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920,
	2229, 1041, 2606, 1692, 680, 2746, 568, 3312,
};

/* kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)] */
static const uint16_t kModRoots[128] = {
	17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606,
	2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096,
	756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678,
	2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
	939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992,
	268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
	375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010,
	2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735,
	2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179,
	2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
	2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
};
175
/* reduce_once reduces 0 <= x < 2*kPrime, mod kPrime. */
static uint16_t
reduce_once(uint16_t x)
{
	assert(x < 2 * kPrime);
	const uint16_t subtracted = x - kPrime;
	/*
	 * If the subtraction underflowed (x < kPrime), bit 15 of |subtracted|
	 * is set and |mask| becomes all-ones, selecting |x|; otherwise |mask|
	 * is zero, selecting |subtracted|.
	 */
	uint16_t mask = 0u - (subtracted >> 15);

	/*
	 * Although this is a constant-time select, we omit a value barrier here.
	 * Value barriers impede auto-vectorization (likely because it forces the
	 * value to transit through a general-purpose register). On AArch64, this
	 * is a difference of 2x.
	 *
	 * We usually add value barriers to selects because Clang turns
	 * consecutive selects with the same condition into a branch instead of
	 * CMOV/CSEL. This condition does not occur in ML-KEM, so omitting it
	 * seems to be safe so far but see
	 * |scalar_centered_binomial_distribution_eta_2_with_prf|.
	 */
	return (mask & x) | (~mask & subtracted);
}
198
199 /*
200 * constant time reduce x mod kPrime using Barrett reduction. x must be less
201 * than kPrime + 2×kPrime².
202 */
203 static uint16_t
reduce(uint32_t x)204 reduce(uint32_t x)
205 {
206 uint64_t product = (uint64_t)x * kBarrettMultiplier;
207 uint32_t quotient = (uint32_t)(product >> kBarrettShift);
208 uint32_t remainder = x - quotient * kPrime;
209
210 assert(x < kPrime + 2u * kPrime * kPrime);
211 return reduce_once(remainder);
212 }
213
214 static void
scalar_zero(scalar * out)215 scalar_zero(scalar *out)
216 {
217 memset(out, 0, sizeof(*out));
218 }
219
220 static void
vector_zero(vector * out)221 vector_zero(vector *out)
222 {
223 memset(out, 0, sizeof(*out));
224 }
225
/*
 * In place number theoretic transform of a given scalar.
 * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this
 * transform leaves off the last iteration of the usual FFT code, with the 128
 * relevant roots of unity being stored in |kNTTRoots|. This means the output
 * should be seen as 128 elements in GF(3329^2), with the coefficients of the
 * elements being consecutive entries in |s->c|.
 */
static void
scalar_ntt(scalar *s)
{
	int offset = DEGREE;
	int step;
	/*
	 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
	 * with Clang 14 on Aarch64.
	 */
	for (step = 1; step < DEGREE / 2; step <<= 1) {
		int i, j, k = 0;

		offset >>= 1;
		for (i = 0; i < step; i++) {
			const uint32_t step_root = kNTTRoots[i + step];

			/* Butterfly on each pair (c[j], c[j + offset]). */
			for (j = k; j < k + offset; j++) {
				uint16_t odd, even;

				odd = reduce(step_root * s->c[j + offset]);
				even = s->c[j];
				s->c[j] = reduce_once(odd + even);
				/* + kPrime keeps the difference non-negative. */
				s->c[j + offset] = reduce_once(even - odd +
				    kPrime);
			}
			k += 2 * offset;
		}
	}
}
263
264 static void
vector_ntt(vector * a)265 vector_ntt(vector *a)
266 {
267 int i;
268
269 for (i = 0; i < RANK1024; i++) {
270 scalar_ntt(&a->v[i]);
271 }
272 }
273
/*
 * In place inverse number theoretic transform of a given scalar, with pairs of
 * entries of s->v being interpreted as elements of GF(3329^2). Just as with the
 * number theoretic transform, this leaves off the first step of the normal iFFT
 * to account for the fact that 3329 does not have a 512th root of unity, using
 * the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
 */
static void
scalar_inverse_ntt(scalar *s)
{
	int i, j, k, offset, step = DEGREE / 2;

	/*
	 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
	 * with Clang 14 on Aarch64.
	 */
	for (offset = 2; offset < DEGREE; offset <<= 1) {
		step >>= 1;
		k = 0;
		for (i = 0; i < step; i++) {
			uint32_t step_root = kInverseNTTRoots[i + step];
			/* Inverse butterfly on each pair (c[j], c[j + offset]). */
			for (j = k; j < k + offset; j++) {
				uint16_t odd, even;
				odd = s->c[j + offset];
				even = s->c[j];
				s->c[j] = reduce_once(odd + even);
				/* + kPrime keeps the difference non-negative. */
				s->c[j + offset] = reduce(step_root *
				    (even - odd + kPrime));
			}
			k += 2 * offset;
		}
	}
	/* Scale by 128^-1 mod kPrime to undo the transform's growth. */
	for (i = 0; i < DEGREE; i++) {
		s->c[i] = reduce(s->c[i] * kInverseDegree);
	}
}
310
311 static void
vector_inverse_ntt(vector * a)312 vector_inverse_ntt(vector *a)
313 {
314 int i;
315
316 for (i = 0; i < RANK1024; i++) {
317 scalar_inverse_ntt(&a->v[i]);
318 }
319 }
320
321 static void
scalar_add(scalar * lhs,const scalar * rhs)322 scalar_add(scalar *lhs, const scalar *rhs)
323 {
324 int i;
325
326 for (i = 0; i < DEGREE; i++) {
327 lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
328 }
329 }
330
331 static void
scalar_sub(scalar * lhs,const scalar * rhs)332 scalar_sub(scalar *lhs, const scalar *rhs)
333 {
334 int i;
335
336 for (i = 0; i < DEGREE; i++) {
337 lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
338 }
339 }
340
/*
 * Multiplying two scalars in the number theoretically transformed state.
 * Since 3329 does not have a 512th root of unity, this means we have to
 * interpret the 2*ith and (2*i+1)th entries of the scalar as elements of
 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
 * |kModRoots| table. Our Barrett transform only allows us to multiply two
 * reduced numbers together, so we need some intermediate reduction steps,
 * even if an uint64_t could hold 3 multiplied numbers.
 */
static void
scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs)
{
	int i;

	for (i = 0; i < DEGREE / 2; i++) {
		/* (a + b*X) * (c + d*X) with X^2 = kModRoots[i]. */
		uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
		uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] *
		    rhs->c[2 * i + 1];
		uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
		uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];

		/*
		 * The inner reduce() keeps the sum within the bound required
		 * by the outer reduce() (see |reduce|'s precondition).
		 */
		out->c[2 * i] =
		    reduce(real_real +
		    (uint32_t)reduce(img_img) * kModRoots[i]);
		out->c[2 * i + 1] = reduce(img_real + real_img);
	}
}
369
370 static void
vector_add(vector * lhs,const vector * rhs)371 vector_add(vector *lhs, const vector *rhs)
372 {
373 int i;
374
375 for (i = 0; i < RANK1024; i++) {
376 scalar_add(&lhs->v[i], &rhs->v[i]);
377 }
378 }
379
380 static void
matrix_mult(vector * out,const matrix * m,const vector * a)381 matrix_mult(vector *out, const matrix *m, const vector *a)
382 {
383 int i, j;
384
385 vector_zero(out);
386 for (i = 0; i < RANK1024; i++) {
387 for (j = 0; j < RANK1024; j++) {
388 scalar product;
389
390 scalar_mult(&product, &m->v[i][j], &a->v[j]);
391 scalar_add(&out->v[i], &product);
392 }
393 }
394 }
395
396 static void
matrix_mult_transpose(vector * out,const matrix * m,const vector * a)397 matrix_mult_transpose(vector *out, const matrix *m,
398 const vector *a)
399 {
400 int i, j;
401
402 vector_zero(out);
403 for (i = 0; i < RANK1024; i++) {
404 for (j = 0; j < RANK1024; j++) {
405 scalar product;
406
407 scalar_mult(&product, &m->v[j][i], &a->v[j]);
408 scalar_add(&out->v[i], &product);
409 }
410 }
411 }
412
413 static void
scalar_inner_product(scalar * out,const vector * lhs,const vector * rhs)414 scalar_inner_product(scalar *out, const vector *lhs,
415 const vector *rhs)
416 {
417 int i;
418 scalar_zero(out);
419 for (i = 0; i < RANK1024; i++) {
420 scalar product;
421
422 scalar_mult(&product, &lhs->v[i], &rhs->v[i]);
423 scalar_add(out, &product);
424 }
425 }
426
/*
 * Algorithm 6 of spec. Rejection samples a Keccak stream to get uniformly
 * distributed elements. This is used for matrix expansion and only operates on
 * public inputs, so variable-time rejection is acceptable.
 */
static void
scalar_from_keccak_vartime(scalar *out, sha3_ctx *keccak_ctx)
{
	int i, done = 0;

	while (done < DEGREE) {
		/* 168 bytes is the SHAKE128 rate (one squeezed block). */
		uint8_t block[168];

		shake_out(keccak_ctx, block, sizeof(block));
		/* Each 3 bytes yield two 12-bit candidate values. */
		for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
			uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
			uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];

			/* Keep only candidates already reduced mod kPrime. */
			if (d1 < kPrime) {
				out->c[done++] = d1;
			}
			if (d2 < kPrime && done < DEGREE) {
				out->c[done++] = d2;
			}
		}
	}
}
454
/*
 * Algorithm 7 of the spec, with eta fixed to two and the PRF call
 * included. Creates binominally distributed elements by sampling 2*|eta| bits,
 * and setting the coefficient to the count of the first bits minus the count of
 * the second bits, resulting in a centered binomial distribution. Since eta is
 * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
 * and 0 with probability 3/8.
 */
static void
scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out,
    const uint8_t input[33])
{
	/* Each coefficient consumes 4 bits, so DEGREE * 4 / 8 = 128 bytes. */
	uint8_t entropy[128];
	int i;

	CTASSERT(sizeof(entropy) == 2 * /*kEta=*/ 2 * DEGREE / 8);
	prf(entropy, sizeof(entropy), input);

	for (i = 0; i < DEGREE; i += 2) {
		uint8_t byte = entropy[i / 2];
		uint16_t mask;
		/* Low nibble: (bit0 + bit1) - (bit2 + bit3), in [-2, 2]. */
		uint16_t value = (byte & 1) + ((byte >> 1) & 1);

		value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);

		/*
		 * Add |kPrime| if |value| underflowed. See |reduce_once| for a
		 * discussion on why the value barrier is omitted. While this
		 * could have been written reduce_once(value + kPrime), this is
		 * one extra addition and small range of |value| tempts some
		 * versions of Clang to emit a branch.
		 */
		mask = 0u - (value >> 15);
		out->c[i] = ((value + kPrime) & mask) | (value & ~mask);

		/* High nibble produces the next coefficient the same way. */
		byte >>= 4;
		value = (byte & 1) + ((byte >> 1) & 1);
		value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
		/* See above. */
		mask = 0u - (value >> 15);
		out->c[i + 1] = ((value + kPrime) & mask) | (value & ~mask);
	}
}
498
/*
 * Generates a secret vector by using
 * |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed,
 * appending and incrementing |counter| for each entry of the vector.
 * |counter| persists across calls so successive vectors use distinct PRF
 * inputs.
 */
static void
vector_generate_secret_eta_2(vector *out, uint8_t *counter,
    const uint8_t seed[32])
{
	uint8_t input[33];
	int i;

	memcpy(input, seed, 32);
	for (i = 0; i < RANK1024; i++) {
		/* PRF input is seed || counter. */
		input[32] = (*counter)++;
		scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i],
		    input);
	}
}
518
/* Expands the matrix of a seed for key generation and for encaps-CPA. */
static void
matrix_expand(matrix *out, const uint8_t rho[32])
{
	uint8_t input[34];
	int i, j;

	memcpy(input, rho, 32);
	for (i = 0; i < RANK1024; i++) {
		for (j = 0; j < RANK1024; j++) {
			sha3_ctx keccak_ctx;

			/* Entry (i, j) is sampled from SHAKE128(rho || i || j). */
			input[32] = i;
			input[33] = j;
			shake128_init(&keccak_ctx);
			shake_update(&keccak_ctx, input, sizeof(input));
			shake_xof(&keccak_ctx);
			scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
		}
	}
}
540
/* kMasks[i] selects the low i+1 bits of a value. */
static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
    0x1f, 0x3f, 0x7f, 0xff};

/*
 * scalar_encode packs the DEGREE coefficients of |s|, |bits| bits each
 * (least-significant bits first), into DEGREE * |bits| / 8 bytes at |out|.
 */
static void
scalar_encode(uint8_t *out, const scalar *s, int bits)
{
	uint8_t out_byte = 0;
	int i, out_byte_bits = 0;

	assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
	for (i = 0; i < DEGREE; i++) {
		uint16_t element = s->c[i];
		int element_bits_done = 0;

		while (element_bits_done < bits) {
			int chunk_bits = bits - element_bits_done;
			int out_bits_remaining = 8 - out_byte_bits;

			if (chunk_bits >= out_bits_remaining) {
				/* Current byte fills up: flush it. */
				chunk_bits = out_bits_remaining;
				out_byte |= (element &
				    kMasks[chunk_bits - 1]) << out_byte_bits;
				*out = out_byte;
				out++;
				out_byte_bits = 0;
				out_byte = 0;
			} else {
				out_byte |= (element &
				    kMasks[chunk_bits - 1]) << out_byte_bits;
				out_byte_bits += chunk_bits;
			}

			element_bits_done += chunk_bits;
			element >>= chunk_bits;
		}
	}

	/* Flush a final partial byte, if any. */
	if (out_byte_bits > 0) {
		*out = out_byte;
	}
}
582
583 /* scalar_encode_1 is |scalar_encode| specialised for |bits| == 1. */
584 static void
scalar_encode_1(uint8_t out[32],const scalar * s)585 scalar_encode_1(uint8_t out[32], const scalar *s)
586 {
587 int i, j;
588
589 for (i = 0; i < DEGREE; i += 8) {
590 uint8_t out_byte = 0;
591
592 for (j = 0; j < 8; j++) {
593 out_byte |= (s->c[i + j] & 1) << j;
594 }
595 *out = out_byte;
596 out++;
597 }
598 }
599
600 /*
601 * Encodes an entire vector into 32*|RANK1024|*|bits| bytes. Note that since 256
602 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
603 * whole number of bytes, so we do not need to worry about bit packing here.
604 */
605 static void
vector_encode(uint8_t * out,const vector * a,int bits)606 vector_encode(uint8_t *out, const vector *a, int bits)
607 {
608 int i;
609
610 for (i = 0; i < RANK1024; i++) {
611 scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
612 }
613 }
614
/*
 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
 * |out|. It returns one on success and zero if any parsed value is >=
 * |kPrime|.
 */
static int
scalar_decode(scalar *out, const uint8_t *in, int bits)
{
	uint8_t in_byte = 0;
	int i, in_byte_bits_left = 0;

	assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);

	for (i = 0; i < DEGREE; i++) {
		uint16_t element = 0;
		int element_bits_done = 0;

		/* Assemble one |bits|-wide value, low bits first. */
		while (element_bits_done < bits) {
			int chunk_bits = bits - element_bits_done;

			/* Refill the bit buffer from the next input byte. */
			if (in_byte_bits_left == 0) {
				in_byte = *in;
				in++;
				in_byte_bits_left = 8;
			}

			if (chunk_bits > in_byte_bits_left) {
				chunk_bits = in_byte_bits_left;
			}

			element |= (in_byte & kMasks[chunk_bits - 1]) <<
			    element_bits_done;
			in_byte_bits_left -= chunk_bits;
			in_byte >>= chunk_bits;

			element_bits_done += chunk_bits;
		}

		/* Reject non-canonical (unreduced) encodings. */
		if (element >= kPrime) {
			return 0;
		}
		out->c[i] = element;
	}

	return 1;
}
661
662 /* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */
663 static void
scalar_decode_1(scalar * out,const uint8_t in[32])664 scalar_decode_1(scalar *out, const uint8_t in[32])
665 {
666 int i, j;
667
668 for (i = 0; i < DEGREE; i += 8) {
669 uint8_t in_byte = *in;
670
671 in++;
672 for (j = 0; j < 8; j++) {
673 out->c[i + j] = in_byte & 1;
674 in_byte >>= 1;
675 }
676 }
677 }
678
679 /*
680 * Decodes 32*|RANK1024|*|bits| bytes from |in| into |out|. It returns one on
681 * success or zero if any parsed value is >= |kPrime|.
682 */
683 static int
vector_decode(vector * out,const uint8_t * in,int bits)684 vector_decode(vector *out, const uint8_t *in, int bits)
685 {
686 int i;
687
688 for (i = 0; i < RANK1024; i++) {
689 if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8,
690 bits)) {
691 return 0;
692 }
693 }
694 return 1;
695 }
696
/*
 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
 * numbers close to each other together. The formula used is
 * round(2^|bits|/kPrime*x) mod 2^|bits|.
 * Uses Barrett reduction to achieve constant time. Since we need both the
 * remainder (for rounding) and the quotient (as the result), we cannot use
 * |reduce| here, but need to do the Barrett reduction directly.
 */
static uint16_t
compress(uint16_t x, int bits)
{
	uint32_t shifted = (uint32_t)x << bits;
	uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
	uint32_t quotient = (uint32_t)(product >> kBarrettShift);
	uint32_t remainder = shifted - quotient * kPrime;

	/*
	 * Adjust the quotient to round correctly:
	 *   0 <= remainder <= kHalfPrime round to 0
	 *   kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
	 *   kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
	 * Both adjustments use |constant_time_lt| to stay branch-free.
	 */
	assert(remainder < 2u * kPrime);
	quotient += 1 & constant_time_lt(kHalfPrime, remainder);
	quotient += 1 & constant_time_lt(kPrime + kHalfPrime, remainder);
	return quotient & ((1 << bits) - 1);
}
724
/*
 * Decompresses |x| by using an equi-distant representative. The formula is
 * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
 * implement this logic using only bit operations. Requires |bits| >= 1.
 */
static uint16_t
decompress(uint16_t x, int bits)
{
	uint32_t product = (uint32_t)x * kPrime;
	uint32_t power = 1 << bits;
	/* This is |product| % power, since |power| is a power of 2. */
	uint32_t remainder = product & (power - 1);
	/* This is |product| / power, since |power| is a power of 2. */
	uint32_t lower = product >> bits;

	/*
	 * The rounding logic works since the first half of numbers mod |power| have a
	 * 0 as first bit, and the second half has a 1 as first bit, since |power| is
	 * a power of 2. As a 12 bit number, |remainder| is always positive, so we
	 * will shift in 0s for a right shift.
	 */
	return lower + (remainder >> (bits - 1));
}
748
749 static void
scalar_compress(scalar * s,int bits)750 scalar_compress(scalar *s, int bits)
751 {
752 int i;
753
754 for (i = 0; i < DEGREE; i++) {
755 s->c[i] = compress(s->c[i], bits);
756 }
757 }
758
759 static void
scalar_decompress(scalar * s,int bits)760 scalar_decompress(scalar *s, int bits)
761 {
762 int i;
763
764 for (i = 0; i < DEGREE; i++) {
765 s->c[i] = decompress(s->c[i], bits);
766 }
767 }
768
769 static void
vector_compress(vector * a,int bits)770 vector_compress(vector *a, int bits)
771 {
772 int i;
773
774 for (i = 0; i < RANK1024; i++) {
775 scalar_compress(&a->v[i], bits);
776 }
777 }
778
779 static void
vector_decompress(vector * a,int bits)780 vector_decompress(vector *a, int bits)
781 {
782 int i;
783
784 for (i = 0; i < RANK1024; i++) {
785 scalar_decompress(&a->v[i], bits);
786 }
787 }
788
/*
 * The expanded internal form of an ML-KEM-1024 public key: the vector t in
 * the NTT domain, the matrix seed rho, the cached hash H(ek), and the matrix
 * re-expanded from rho.
 */
struct public_key {
	vector t;
	uint8_t rho[32];
	uint8_t public_key_hash[32];
	matrix m;
};

/*
 * public_key_1024_from_external reinterprets the opaque external struct as
 * the internal representation. NOTE(review): assumes the external struct is
 * sized and aligned for struct public_key - confirm against mlkem.h.
 */
static struct public_key *
public_key_1024_from_external(const struct MLKEM1024_public_key *external)
{
	return (struct public_key *)external;
}

/*
 * The expanded internal form of a private key: the public half, the secret
 * vector s in the NTT domain (see |MLKEM1024_generate_key_external_entropy|),
 * and the implicit-rejection secret used by |kdf|.
 */
struct private_key {
	struct public_key pub;
	vector s;
	uint8_t fo_failure_secret[32];
};

/* See |public_key_1024_from_external|; same aliasing assumption applies. */
static struct private_key *
private_key_1024_from_external(const struct MLKEM1024_private_key *external)
{
	return (struct private_key *)external;
}
813
/*
 * Calls |MLKEM1024_generate_key_external_entropy| with random bytes from
 * arc4random_buf(). If |optional_out_seed| is non-NULL, the generating seed
 * is written there so the key can later be recreated with
 * |MLKEM1024_private_key_from_seed|.
 */
void
MLKEM1024_generate_key(uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
    uint8_t optional_out_seed[MLKEM_SEED_BYTES],
    struct MLKEM1024_private_key *out_private_key)
{
	uint8_t entropy_buf[MLKEM_SEED_BYTES];
	uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed :
	    entropy_buf;

	arc4random_buf(entropy, MLKEM_SEED_BYTES);
	MLKEM1024_generate_key_external_entropy(out_encoded_public_key,
	    out_private_key, entropy);
}
LCRYPTO_ALIAS(MLKEM1024_generate_key);
832
/*
 * MLKEM1024_private_key_from_seed recreates a private key from a
 * |MLKEM_SEED_BYTES| seed; the re-derived encoded public key is discarded.
 * Returns 1 on success, 0 if |seed_len| is wrong.
 */
int
MLKEM1024_private_key_from_seed(struct MLKEM1024_private_key *out_private_key,
    const uint8_t *seed, size_t seed_len)
{
	uint8_t public_key_bytes[MLKEM1024_PUBLIC_KEY_BYTES];

	if (seed_len != MLKEM_SEED_BYTES) {
		return 0;
	}
	MLKEM1024_generate_key_external_entropy(public_key_bytes,
	    out_private_key, seed);

	return 1;
}
LCRYPTO_ALIAS(MLKEM1024_private_key_from_seed);
848
/*
 * mlkem_marshal_public_key serialises |pub| as the vector t (12 bits per
 * coefficient) followed by the 32-byte seed rho. Returns 1 on success,
 * 0 on CBB failure.
 */
static int
mlkem_marshal_public_key(CBB *out, const struct public_key *pub)
{
	uint8_t *vector_output;

	if (!CBB_add_space(out, &vector_output, kEncodedVectorSize)) {
		return 0;
	}
	vector_encode(vector_output, &pub->t, kLog2Prime);
	if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
		return 0;
	}
	return 1;
}
863
/*
 * Key generation with caller-supplied randomness: |entropy| is d || z,
 * two 32-byte halves. Writes the encoded public key to
 * |out_encoded_public_key| and fills in |out_private_key|.
 */
void
MLKEM1024_generate_key_external_entropy(
    uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
    struct MLKEM1024_private_key *out_private_key,
    const uint8_t entropy[MLKEM_SEED_BYTES])
{
	struct private_key *priv = private_key_1024_from_external(
	    out_private_key);
	uint8_t augmented_seed[33];
	uint8_t *rho, *sigma;
	uint8_t counter = 0;
	uint8_t hashed[64];
	vector error;
	CBB cbb;

	/* (rho, sigma) = G(d || k), with the rank k as domain separator. */
	memcpy(augmented_seed, entropy, 32);
	augmented_seed[32] = RANK1024;
	hash_g(hashed, augmented_seed, 33);
	rho = hashed;
	sigma = hashed + 32;
	memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
	matrix_expand(&priv->pub.m, rho);
	/* t = transpose(m) * s + e, all in the NTT domain. */
	vector_generate_secret_eta_2(&priv->s, &counter, sigma);
	vector_ntt(&priv->s);
	vector_generate_secret_eta_2(&error, &counter, sigma);
	vector_ntt(&error);
	matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
	vector_add(&priv->pub.t, &error);

	/* XXX - error checking. */
	CBB_init_fixed(&cbb, out_encoded_public_key, MLKEM1024_PUBLIC_KEY_BYTES);
	if (!mlkem_marshal_public_key(&cbb, &priv->pub)) {
		abort();
	}
	CBB_cleanup(&cbb);

	/* Cache H(ek) and keep the implicit-rejection secret z. */
	hash_h(priv->pub.public_key_hash, out_encoded_public_key,
	    MLKEM1024_PUBLIC_KEY_BYTES);
	memcpy(priv->fo_failure_secret, entropy + 32, 32);
}
904
/*
 * MLKEM1024_public_from_private copies the public half out of an expanded
 * private key.
 */
void
MLKEM1024_public_from_private(struct MLKEM1024_public_key *out_public_key,
    const struct MLKEM1024_private_key *private_key)
{
	struct public_key *const pub = public_key_1024_from_external(
	    out_public_key);
	const struct private_key *const priv = private_key_1024_from_external(
	    private_key);

	*pub = priv->pub;
}
LCRYPTO_ALIAS(MLKEM1024_public_from_private);
917
/*
 * Encrypts a message with given randomness to the ciphertext in |out|. Without
 * applying the Fujisaki-Okamoto transform this would not result in a CCA secure
 * scheme, since lattice schemes are vulnerable to decryption failure oracles.
 */
static void
encrypt_cpa(uint8_t out[MLKEM1024_CIPHERTEXT_BYTES],
    const struct public_key *pub, const uint8_t message[32],
    const uint8_t randomness[32])
{
	scalar expanded_message, scalar_error;
	vector secret, error, u;
	uint8_t counter = 0;
	uint8_t input[33];
	scalar v;

	/* Sample the ephemeral secret and the error vector/scalar. */
	vector_generate_secret_eta_2(&secret, &counter, randomness);
	vector_ntt(&secret);
	vector_generate_secret_eta_2(&error, &counter, randomness);
	memcpy(input, randomness, 32);
	input[32] = counter;
	scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error,
	    input);
	/* u = m * secret + error, converted back out of the NTT domain. */
	matrix_mult(&u, &pub->m, &secret);
	vector_inverse_ntt(&u);
	vector_add(&u, &error);
	/* v = <t, secret> + scalar_error + decompress(message, 1). */
	scalar_inner_product(&v, &pub->t, &secret);
	scalar_inverse_ntt(&v);
	scalar_add(&v, &scalar_error);
	scalar_decode_1(&expanded_message, message);
	scalar_decompress(&expanded_message, 1);
	scalar_add(&v, &expanded_message);
	/* Serialise compressed u then compressed v into the ciphertext. */
	vector_compress(&u, kDU1024);
	vector_encode(out, &u, kDU1024);
	scalar_compress(&v, kDV1024);
	scalar_encode(out + kCompressedVectorSize, &v, kDV1024);
}
955
/* Calls |MLKEM1024_encap_external_entropy| with random bytes. */
void
MLKEM1024_encap(uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
    uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
    const struct MLKEM1024_public_key *public_key)
{
	uint8_t entropy[MLKEM_ENCAP_ENTROPY];

	arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY);
	MLKEM1024_encap_external_entropy(out_ciphertext, out_shared_secret,
	    public_key, entropy);
}
LCRYPTO_ALIAS(MLKEM1024_encap);
969
/* See section 6.2 of the spec. */
void
MLKEM1024_encap_external_entropy(
    uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
    uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
    const struct MLKEM1024_public_key *public_key,
    const uint8_t entropy[MLKEM_ENCAP_ENTROPY])
{
	const struct public_key *pub = public_key_1024_from_external(public_key);
	uint8_t key_and_randomness[64];
	uint8_t input[64];

	/* (K, r) = G(m || H(ek)), with m = |entropy|. */
	memcpy(input, entropy, MLKEM_ENCAP_ENTROPY);
	memcpy(input + MLKEM_ENCAP_ENTROPY, pub->public_key_hash,
	    sizeof(input) - MLKEM_ENCAP_ENTROPY);
	hash_g(key_and_randomness, input, sizeof(input));
	/* Encrypt m under randomness r (second half of G's output). */
	encrypt_cpa(out_ciphertext, pub, entropy, key_and_randomness + 32);
	/* The shared secret is K, the first half of G's output. */
	memcpy(out_shared_secret, key_and_randomness, 32);
}
989
990 static void
decrypt_cpa(uint8_t out[32],const struct private_key * priv,const uint8_t ciphertext[MLKEM1024_CIPHERTEXT_BYTES])991 decrypt_cpa(uint8_t out[32], const struct private_key *priv,
992 const uint8_t ciphertext[MLKEM1024_CIPHERTEXT_BYTES])
993 {
994 scalar mask, v;
995 vector u;
996
997 vector_decode(&u, ciphertext, kDU1024);
998 vector_decompress(&u, kDU1024);
999 vector_ntt(&u);
1000 scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV1024);
1001 scalar_decompress(&v, kDV1024);
1002 scalar_inner_product(&mask, &priv->s, &u);
1003 scalar_inverse_ntt(&mask);
1004 scalar_sub(&v, &mask);
1005 scalar_compress(&v, 1);
1006 scalar_encode_1(out, &v);
1007 }
1008
/* See section 6.3 */
/*
 * Decapsulates |ciphertext| with |private_key| and writes the shared secret
 * to |out_shared_secret|.  Uses the Fujisaki-Okamoto transform with implicit
 * rejection: if re-encryption of the decrypted message does not reproduce the
 * ciphertext, a deterministic pseudorandom "failure key" derived from
 * |fo_failure_secret| and the ciphertext is returned instead, selected in
 * constant time so callers cannot distinguish the two cases.
 *
 * Returns 1 on success.  Returns 0 only when |ciphertext_len| is wrong, in
 * which case |out_shared_secret| is filled with random bytes.
 */
int
MLKEM1024_decap(uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
    const uint8_t *ciphertext, size_t ciphertext_len,
    const struct MLKEM1024_private_key *private_key)
{
	const struct private_key *priv = private_key_1024_from_external(
	    private_key);
	uint8_t expected_ciphertext[MLKEM1024_CIPHERTEXT_BYTES];
	uint8_t key_and_randomness[64];
	uint8_t failure_key[32];
	uint8_t decrypted[64];
	uint8_t mask;
	int i;

	if (ciphertext_len != MLKEM1024_CIPHERTEXT_BYTES) {
		arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_BYTES);
		return 0;
	}

	/* m' = Decrypt(s, c); then (K', r') = G(m' || H(pk)). */
	decrypt_cpa(decrypted, priv, ciphertext);
	memcpy(decrypted + 32, priv->pub.public_key_hash,
	    sizeof(decrypted) - 32);
	hash_g(key_and_randomness, decrypted, sizeof(decrypted));
	/* Re-encrypt m' with r' to check the ciphertext is well-formed. */
	encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
	    key_and_randomness + 32);
	/* Implicit-rejection key, always computed to keep timing uniform. */
	kdf(failure_key, priv->fo_failure_secret, ciphertext, ciphertext_len);
	/*
	 * NOTE(review): plain memcmp may exit early; both compared buffers
	 * are ciphertext-shaped, but expected_ciphertext is secret-derived —
	 * confirm a timing-safe comparison is not required here.
	 */
	mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext,
	    sizeof(expected_ciphertext)), 0);
	/* Select real key or failure key without a secret-dependent branch. */
	for (i = 0; i < MLKEM_SHARED_SECRET_BYTES; i++) {
		out_shared_secret[i] = constant_time_select_8(mask,
		    key_and_randomness[i], failure_key[i]);
	}

	return 1;
}
LCRYPTO_ALIAS(MLKEM1024_decap);
1046
1047 int
MLKEM1024_marshal_public_key(CBB * out,const struct MLKEM1024_public_key * public_key)1048 MLKEM1024_marshal_public_key(CBB *out,
1049 const struct MLKEM1024_public_key *public_key)
1050 {
1051 return mlkem_marshal_public_key(out,
1052 public_key_1024_from_external(public_key));
1053 }
1054 LCRYPTO_ALIAS(MLKEM1024_marshal_public_key);
1055
1056 /*
1057 * mlkem_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
1058 * the value of |pub->public_key_hash|.
1059 */
1060 static int
mlkem_parse_public_key_no_hash(struct public_key * pub,CBS * in)1061 mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in)
1062 {
1063 CBS t_bytes;
1064
1065 if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) ||
1066 !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime)) {
1067 return 0;
1068 }
1069 memcpy(pub->rho, CBS_data(in), sizeof(pub->rho));
1070 if (!CBS_skip(in, sizeof(pub->rho)))
1071 return 0;
1072 matrix_expand(&pub->m, pub->rho);
1073 return 1;
1074 }
1075
1076 int
MLKEM1024_parse_public_key(struct MLKEM1024_public_key * public_key,CBS * in)1077 MLKEM1024_parse_public_key(struct MLKEM1024_public_key *public_key, CBS *in)
1078 {
1079 struct public_key *pub = public_key_1024_from_external(public_key);
1080 CBS orig_in = *in;
1081
1082 if (!mlkem_parse_public_key_no_hash(pub, in) ||
1083 CBS_len(in) != 0) {
1084 return 0;
1085 }
1086 hash_h(pub->public_key_hash, CBS_data(&orig_in), CBS_len(&orig_in));
1087 return 1;
1088 }
1089 LCRYPTO_ALIAS(MLKEM1024_parse_public_key);
1090
1091 int
MLKEM1024_marshal_private_key(CBB * out,const struct MLKEM1024_private_key * private_key)1092 MLKEM1024_marshal_private_key(CBB *out,
1093 const struct MLKEM1024_private_key *private_key)
1094 {
1095 const struct private_key *const priv = private_key_1024_from_external(
1096 private_key);
1097 uint8_t *s_output;
1098
1099 if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) {
1100 return 0;
1101 }
1102 vector_encode(s_output, &priv->s, kLog2Prime);
1103 if (!mlkem_marshal_public_key(out, &priv->pub) ||
1104 !CBB_add_bytes(out, priv->pub.public_key_hash,
1105 sizeof(priv->pub.public_key_hash)) ||
1106 !CBB_add_bytes(out, priv->fo_failure_secret,
1107 sizeof(priv->fo_failure_secret))) {
1108 return 0;
1109 }
1110 return 1;
1111 }
1112
1113 int
MLKEM1024_parse_private_key(struct MLKEM1024_private_key * out_private_key,CBS * in)1114 MLKEM1024_parse_private_key(struct MLKEM1024_private_key *out_private_key,
1115 CBS *in)
1116 {
1117 struct private_key *const priv = private_key_1024_from_external(
1118 out_private_key);
1119 CBS s_bytes;
1120
1121 if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) ||
1122 !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) ||
1123 !mlkem_parse_public_key_no_hash(&priv->pub, in)) {
1124 return 0;
1125 }
1126 memcpy(priv->pub.public_key_hash, CBS_data(in),
1127 sizeof(priv->pub.public_key_hash));
1128 if (!CBS_skip(in, sizeof(priv->pub.public_key_hash)))
1129 return 0;
1130 memcpy(priv->fo_failure_secret, CBS_data(in),
1131 sizeof(priv->fo_failure_secret));
1132 if (!CBS_skip(in, sizeof(priv->fo_failure_secret)))
1133 return 0;
1134 if (CBS_len(in) != 0)
1135 return 0;
1136
1137 return 1;
1138 }
1139 LCRYPTO_ALIAS(MLKEM1024_parse_private_key);
1140