/* $OpenBSD: mlkem768.c,v 1.7 2025/01/03 08:19:24 tb Exp $ */
/*
 * Copyright (c) 2024, Google Inc.
 * Copyright (c) 2024, Bob Beck <beck@obtuse.com>
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
 * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
 * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
 * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

#include <assert.h>
#include <stdlib.h>
#include <string.h>

#include "bytestring.h"
#include "mlkem.h"

#include "sha3_internal.h"
#include "mlkem_internal.h"
#include "constant_time.h"
#include "crypto_internal.h"

/* Remove later */
#undef LCRYPTO_ALIAS
#define LCRYPTO_ALIAS(A)

/*
 * See
 * https://csrc.nist.gov/pubs/fips/203/final
 */

static void
prf(uint8_t *out, size_t out_len, const uint8_t in[33])
{
	sha3_ctx ctx;
	shake256_init(&ctx);
	shake_update(&ctx, in, 33);
	shake_xof(&ctx);
	shake_out(&ctx, out, out_len);
}

/* Section 4.1 */
static void
hash_h(uint8_t out[32], const uint8_t *in, size_t len)
{
	sha3_ctx ctx;
	sha3_init(&ctx, 32);
	sha3_update(&ctx, in, len);
	sha3_final(out, &ctx);
}

static void
hash_g(uint8_t out[64], const uint8_t *in, size_t len)
{
	sha3_ctx ctx;
	sha3_init(&ctx, 64);
	sha3_update(&ctx, in, len);
	sha3_final(out, &ctx);
}

/* this is called 'J' in the spec */
static void
kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES], const uint8_t failure_secret[32],
    const uint8_t *in, size_t len)
{
	sha3_ctx ctx;
	shake256_init(&ctx);
	shake_update(&ctx, failure_secret, 32);
	shake_update(&ctx, in, len);
	shake_xof(&ctx);
	shake_out(&ctx, out, MLKEM_SHARED_SECRET_BYTES);
}

#define DEGREE 256
#define RANK768 3

static const size_t kBarrettMultiplier = 5039;
static const unsigned kBarrettShift = 24;
static const uint16_t kPrime = 3329;
static const int kLog2Prime = 12;
static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
static const int kDU768 = 10;
static const int kDV768 = 4;
/*
 * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
 * root of unity.
 */
static const uint16_t kInverseDegree = 3303;
static const size_t kEncodedVectorSize =
    (/*kLog2Prime=*/12 * DEGREE / 8) * RANK768;
static const size_t kCompressedVectorSize = /*kDU768=*/ 10 * RANK768 * DEGREE /
    8;
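
/*
 * Note on the constants above: kBarrettMultiplier is
 * floor(2^kBarrettShift / kPrime) = floor(2^24 / 3329) = 5039, so
 * (x * 5039) >> 24 approximates x / 3329 from below. The derived sizes work
 * out to kEncodedVectorSize = 12 * 256 / 8 * 3 = 1152 bytes and
 * kCompressedVectorSize = 10 * 3 * 256 / 8 = 960 bytes.
 */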

typedef struct scalar {
	/* On every function entry and exit, 0 <= c < kPrime. */
	uint16_t c[DEGREE];
} scalar;

typedef struct vector {
	scalar v[RANK768];
} vector;

typedef struct matrix {
	scalar v[RANK768][RANK768];
} matrix;

/*
 * This bit of Python will be referenced in some of the following comments:
 *
 * p = 3329
 *
 * def bitreverse(i):
 *     ret = 0
 *     for n in range(7):
 *         bit = i & 1
 *         ret <<= 1
 *         ret |= bit
 *         i >>= 1
 *     return ret
 */
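
/*
 * For example, bitreverse(1) reverses 0b0000001 over seven bits, giving
 * 0b1000000 = 64, so kNTTRoots[1] below is pow(17, 64, 3329) = 1729.
 */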

/* kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)] */
static const uint16_t kNTTRoots[128] = {
	1, 1729, 2580, 3289, 2642, 630, 1897, 848, 1062, 1919, 193, 797,
	2786, 3260, 569, 1746, 296, 2447, 1339, 1476, 3046, 56, 2240, 1333,
	1426, 2094, 535, 2882, 2393, 2879, 1974, 821, 289, 331, 3253, 1756,
	1197, 2304, 2277, 2055, 650, 1977, 2513, 632, 2865, 33, 1320, 1915,
	2319, 1435, 807, 452, 1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
	2474, 3110, 1227, 910, 17, 2761, 583, 2649, 1637, 723, 2288, 1100,
	1409, 2662, 3281, 233, 756, 2156, 3015, 3050, 1703, 1651, 2789, 1789,
	1847, 952, 1461, 2687, 939, 2308, 2437, 2388, 733, 2337, 268, 641,
	1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, 1063, 319, 2773, 757,
	2099, 561, 2466, 2594, 2804, 1092, 403, 1026, 1143, 2150, 2775, 886,
	1722, 1212, 1874, 1029, 2110, 2935, 885, 2154,
};

/* kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */
static const uint16_t kInverseNTTRoots[128] = {
	1, 1600, 40, 749, 2481, 1432, 2699, 687, 1583, 2760, 69, 543,
	2532, 3136, 1410, 2267, 2508, 1355, 450, 936, 447, 2794, 1235, 1903,
	1996, 1089, 3273, 283, 1853, 1990, 882, 3033, 2419, 2102, 219, 855,
	2681, 1848, 712, 682, 927, 1795, 461, 1891, 2877, 2522, 1894, 1010,
	1414, 2009, 3296, 464, 2697, 816, 1352, 2679, 1274, 1052, 1025, 2132,
	1573, 76, 2998, 3040, 1175, 2444, 394, 1219, 2300, 1455, 2117, 1607,
	2443, 554, 1179, 2186, 2303, 2926, 2237, 525, 735, 863, 2768, 1230,
	2572, 556, 3010, 2266, 1684, 1239, 780, 2954, 109, 1292, 1031, 1745,
	2688, 3061, 992, 2596, 941, 892, 1021, 2390, 642, 1868, 2377, 1482,
	1540, 540, 1678, 1626, 279, 314, 1173, 2573, 3096, 48, 667, 1920,
	2229, 1041, 2606, 1692, 680, 2746, 568, 3312,
};

/* kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)] */
static const uint16_t kModRoots[128] = {
	17, 3312, 2761, 568, 583, 2746, 2649, 680, 1637, 1692, 723, 2606,
	2288, 1041, 1100, 2229, 1409, 1920, 2662, 667, 3281, 48, 233, 3096,
	756, 2573, 2156, 1173, 3015, 314, 3050, 279, 1703, 1626, 1651, 1678,
	2789, 540, 1789, 1540, 1847, 1482, 952, 2377, 1461, 1868, 2687, 642,
	939, 2390, 2308, 1021, 2437, 892, 2388, 941, 733, 2596, 2337, 992,
	268, 3061, 641, 2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
	375, 2954, 2549, 780, 2090, 1239, 1645, 1684, 1063, 2266, 319, 3010,
	2773, 556, 757, 2572, 2099, 1230, 561, 2768, 2466, 863, 2594, 735,
	2804, 525, 1092, 2237, 403, 2926, 1026, 2303, 1143, 2186, 2150, 1179,
	2775, 554, 886, 2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
	2110, 1219, 2935, 394, 885, 2444, 2154, 1175,
};

/* reduce_once reduces 0 <= x < 2*kPrime, mod kPrime. */
static uint16_t
reduce_once(uint16_t x)
{
	assert(x < 2 * kPrime);
	const uint16_t subtracted = x - kPrime;
	uint16_t mask = 0u - (subtracted >> 15);

	/*
	 * Although this is a constant-time select, we omit a value barrier
	 * here. Value barriers impede auto-vectorization (likely because they
	 * force the value to transit through a general-purpose register). On
	 * AArch64, this is a difference of 2x.
	 *
	 * We usually add value barriers to selects because Clang turns
	 * consecutive selects with the same condition into a branch instead of
	 * CMOV/CSEL. This condition does not occur in ML-KEM, so omitting the
	 * barrier seems to be safe so far, but see
	 * |scalar_centered_binomial_distribution_eta_2_with_prf|.
	 */
	return (mask & x) | (~mask & subtracted);
}
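
/*
 * Worked example: reduce_once(5000) computes subtracted = 5000 - 3329 = 1671;
 * the top bit of |subtracted| is clear, so mask = 0 and 1671 is returned.
 * For reduce_once(1234), subtracted underflows to 63441, its top bit is set,
 * mask = 0xffff, and the original 1234 is returned unchanged.
 */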

/*
 * Constant-time reduction of x mod kPrime using Barrett reduction. x must be
 * less than kPrime + 2 * kPrime^2.
 */
static uint16_t
reduce(uint32_t x)
{
	uint64_t product = (uint64_t)x * kBarrettMultiplier;
	uint32_t quotient = (uint32_t)(product >> kBarrettShift);
	uint32_t remainder = x - quotient * kPrime;

	assert(x < kPrime + 2u * kPrime * kPrime);
	return reduce_once(remainder);
}
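
/*
 * Worked example: reduce(10000) computes product = 10000 * 5039 = 50390000,
 * quotient = 50390000 >> 24 = 3, and remainder = 10000 - 3 * 3329 = 13,
 * which reduce_once leaves as is: 10000 mod 3329 = 13.
 */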

static void
scalar_zero(scalar *out)
{
	memset(out, 0, sizeof(*out));
}

static void
vector_zero(vector *out)
{
	memset(out, 0, sizeof(*out));
}

/*
 * In place number theoretic transform of a given scalar.
 * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this
 * transform leaves off the last iteration of the usual FFT code, with the 128
 * relevant roots of unity being stored in |kNTTRoots|. This means the output
 * should be seen as 128 elements in GF(3329^2), with the coefficients of the
 * elements being consecutive entries in |s->c|.
 */
static void
scalar_ntt(scalar *s)
{
	int offset = DEGREE;
	int step;
	/*
	 * `int` is used here because using `size_t` throughout caused a ~5%
	 * slowdown with Clang 14 on Aarch64.
	 */
	for (step = 1; step < DEGREE / 2; step <<= 1) {
		int i, j, k = 0;

		offset >>= 1;
		for (i = 0; i < step; i++) {
			const uint32_t step_root = kNTTRoots[i + step];

			for (j = k; j < k + offset; j++) {
				uint16_t odd, even;

				odd = reduce(step_root * s->c[j + offset]);
				even = s->c[j];
				s->c[j] = reduce_once(odd + even);
				s->c[j + offset] = reduce_once(even - odd +
				    kPrime);
			}
			k += 2 * offset;
		}
	}
}

static void
vector_ntt(vector *a)
{
	int i;

	for (i = 0; i < RANK768; i++) {
		scalar_ntt(&a->v[i]);
	}
}

/*
 * In place inverse number theoretic transform of a given scalar, with pairs of
 * entries of s->v being interpreted as elements of GF(3329^2). Just as with the
 * number theoretic transform, this leaves off the first step of the normal
 * iFFT to account for the fact that 3329 does not have a 512th root of unity,
 * using the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
 */
static void
scalar_inverse_ntt(scalar *s)
{
	int i, j, k, offset, step = DEGREE / 2;

	/*
	 * `int` is used here because using `size_t` throughout caused a ~5%
	 * slowdown with Clang 14 on Aarch64.
	 */
	for (offset = 2; offset < DEGREE; offset <<= 1) {
		step >>= 1;
		k = 0;
		for (i = 0; i < step; i++) {
			uint32_t step_root = kInverseNTTRoots[i + step];
			for (j = k; j < k + offset; j++) {
				uint16_t odd, even;
				odd = s->c[j + offset];
				even = s->c[j];
				s->c[j] = reduce_once(odd + even);
				s->c[j + offset] = reduce(step_root *
				    (even - odd + kPrime));
			}
			k += 2 * offset;
		}
	}
	for (i = 0; i < DEGREE; i++) {
		s->c[i] = reduce(s->c[i] * kInverseDegree);
	}
}

static void
vector_inverse_ntt(vector *a)
{
	int i;

	for (i = 0; i < RANK768; i++) {
		scalar_inverse_ntt(&a->v[i]);
	}
}

static void
scalar_add(scalar *lhs, const scalar *rhs)
{
	int i;

	for (i = 0; i < DEGREE; i++) {
		lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
	}
}

static void
scalar_sub(scalar *lhs, const scalar *rhs)
{
	int i;

	for (i = 0; i < DEGREE; i++) {
		lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
	}
}

/*
 * Multiplying two scalars in the number theoretically transformed state.
 * Since 3329 does not have a 512th root of unity, this means we have to
 * interpret the 2*ith and (2*i+1)th entries of the scalar as elements of
 * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
 * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
 * |kModRoots| table. Our Barrett transform only allows us to multiply two
 * reduced numbers together, so we need some intermediate reduction steps,
 * even if a uint64_t could hold three multiplied numbers.
 */
static void
scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs)
{
	int i;

	for (i = 0; i < DEGREE / 2; i++) {
		uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
		uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] *
		    rhs->c[2 * i + 1];
		uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
		uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];

		out->c[2 * i] =
		    reduce(real_real +
		    (uint32_t)reduce(img_img) * kModRoots[i]);
		out->c[2 * i + 1] = reduce(img_real + real_img);
	}
}
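
/*
 * Spelled out, with r = kModRoots[i] = 17^(2*bitreverse(i)+1): the product of
 * a + b*X and c + d*X in GF(3329)[X]/(X^2 - r) is
 * (a*c + b*d*r) + (a*d + b*c)*X, which is exactly what the loop above
 * computes, coefficient pair by coefficient pair.
 */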

static void
vector_add(vector *lhs, const vector *rhs)
{
	int i;

	for (i = 0; i < RANK768; i++) {
		scalar_add(&lhs->v[i], &rhs->v[i]);
	}
}

static void
matrix_mult(vector *out, const matrix *m, const vector *a)
{
	int i, j;

	vector_zero(out);
	for (i = 0; i < RANK768; i++) {
		for (j = 0; j < RANK768; j++) {
			scalar product;

			scalar_mult(&product, &m->v[i][j], &a->v[j]);
			scalar_add(&out->v[i], &product);
		}
	}
}

static void
matrix_mult_transpose(vector *out, const matrix *m, const vector *a)
{
	int i, j;

	vector_zero(out);
	for (i = 0; i < RANK768; i++) {
		for (j = 0; j < RANK768; j++) {
			scalar product;

			scalar_mult(&product, &m->v[j][i], &a->v[j]);
			scalar_add(&out->v[i], &product);
		}
	}
}

static void
scalar_inner_product(scalar *out, const vector *lhs, const vector *rhs)
{
	int i;
	scalar_zero(out);
	for (i = 0; i < RANK768; i++) {
		scalar product;

		scalar_mult(&product, &lhs->v[i], &rhs->v[i]);
		scalar_add(out, &product);
	}
}

/*
 * Algorithm 6 of the spec. Rejection samples a Keccak stream to get uniformly
 * distributed elements. This is used for matrix expansion and only operates on
 * public inputs.
 */
static void
scalar_from_keccak_vartime(scalar *out, sha3_ctx *keccak_ctx)
{
	int i, done = 0;

	while (done < DEGREE) {
		uint8_t block[168];

		shake_out(keccak_ctx, block, sizeof(block));
		for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
			uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
			uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];

			if (d1 < kPrime) {
				out->c[done++] = d1;
			}
			if (d2 < kPrime && done < DEGREE) {
				out->c[done++] = d2;
			}
		}
	}
}
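
/*
 * Sampling example: the three bytes 0x01 0x02 0x03 yield the candidates
 * d1 = 0x01 + 256 * (0x02 % 16) = 513 and d2 = 0x02 / 16 + 16 * 0x03 = 48,
 * both below kPrime and therefore accepted. Any candidate >= 3329 is simply
 * skipped, which is why this function is variable-time; that is fine here
 * because it only ever runs on public inputs.
 */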

/*
 * Algorithm 7 of the spec, with eta fixed to two and the PRF call included.
 * Creates binomially distributed elements by sampling 2*|eta| bits, and
 * setting the coefficient to the count of the first |eta| bits minus the
 * count of the second |eta| bits, resulting in a centered binomial
 * distribution. Since eta is two, this gives -2/2 each with probability 1/16,
 * -1/1 each with probability 1/4, and 0 with probability 3/8.
 */
static void
scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out,
    const uint8_t input[33])
{
	uint8_t entropy[128];
	int i;

	CTASSERT(sizeof(entropy) == 2 * /*kEta=*/ 2 * DEGREE / 8);
	prf(entropy, sizeof(entropy), input);

	for (i = 0; i < DEGREE; i += 2) {
		uint8_t byte = entropy[i / 2];
		uint16_t mask;
		uint16_t value = (byte & 1) + ((byte >> 1) & 1);

		value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);

		/*
		 * Add |kPrime| if |value| underflowed. See |reduce_once| for a
		 * discussion on why the value barrier is omitted. While this
		 * could have been written reduce_once(value + kPrime), this is
		 * one extra addition and the small range of |value| tempts
		 * some versions of Clang to emit a branch.
		 */
		mask = 0u - (value >> 15);
		out->c[i] = ((value + kPrime) & mask) | (value & ~mask);

		byte >>= 4;
		value = (byte & 1) + ((byte >> 1) & 1);
		value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
		/* See above. */
		mask = 0u - (value >> 15);
		out->c[i + 1] = ((value + kPrime) & mask) | (value & ~mask);
	}
}
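
/*
 * Example: an entropy byte with low nibble 0b1101 gives
 * value = (1 + 0) - (1 + 1) = -1, which underflows the uint16_t; adding
 * kPrime stores the coefficient 3328 = -1 mod 3329.
 */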

/*
 * Generates a secret vector by using
 * |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given
 * seed, appending and incrementing |counter| for each entry of the vector.
 */
static void
vector_generate_secret_eta_2(vector *out, uint8_t *counter,
    const uint8_t seed[32])
{
	uint8_t input[33];
	int i;

	memcpy(input, seed, 32);
	for (i = 0; i < RANK768; i++) {
		input[32] = (*counter)++;
		scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i],
		    input);
	}
}

/* Expands the matrix from a seed, for key generation and for encapsulation. */
static void
matrix_expand(matrix *out, const uint8_t rho[32])
{
	uint8_t input[34];
	int i, j;

	memcpy(input, rho, 32);
	for (i = 0; i < RANK768; i++) {
		for (j = 0; j < RANK768; j++) {
			sha3_ctx keccak_ctx;

			input[32] = i;
			input[33] = j;
			shake128_init(&keccak_ctx);
			shake_update(&keccak_ctx, input, sizeof(input));
			shake_xof(&keccak_ctx);
			scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
		}
	}
}

static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
    0x1f, 0x3f, 0x7f, 0xff};

static void
scalar_encode(uint8_t *out, const scalar *s, int bits)
{
	uint8_t out_byte = 0;
	int i, out_byte_bits = 0;

	assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
	for (i = 0; i < DEGREE; i++) {
		uint16_t element = s->c[i];
		int element_bits_done = 0;

		while (element_bits_done < bits) {
			int chunk_bits = bits - element_bits_done;
			int out_bits_remaining = 8 - out_byte_bits;

			if (chunk_bits >= out_bits_remaining) {
				chunk_bits = out_bits_remaining;
				out_byte |= (element &
				    kMasks[chunk_bits - 1]) << out_byte_bits;
				*out = out_byte;
				out++;
				out_byte_bits = 0;
				out_byte = 0;
			} else {
				out_byte |= (element &
				    kMasks[chunk_bits - 1]) << out_byte_bits;
				out_byte_bits += chunk_bits;
			}

			element_bits_done += chunk_bits;
			element >>= chunk_bits;
		}
	}

	if (out_byte_bits > 0) {
		*out = out_byte;
	}
}
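
/*
 * Packing example with bits = 12: the coefficients 0xabc and 0xdef are
 * emitted as the three bytes 0xbc, 0xfa and 0xde: the low eight bits of the
 * first element, then its top nibble combined with the low nibble of the
 * second element, then the second element's top eight bits.
 */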

/* scalar_encode_1 is |scalar_encode| specialised for |bits| == 1. */
static void
scalar_encode_1(uint8_t out[32], const scalar *s)
{
	int i, j;

	for (i = 0; i < DEGREE; i += 8) {
		uint8_t out_byte = 0;

		for (j = 0; j < 8; j++) {
			out_byte |= (s->c[i + j] & 1) << j;
		}
		*out = out_byte;
		out++;
	}
}

/*
 * Encodes an entire vector into 32*|RANK768|*|bits| bytes. Note that since 256
 * (DEGREE) is divisible by 8, the individual vector entries will always fill a
 * whole number of bytes, so we do not need to worry about bit packing here.
 */
static void
vector_encode(uint8_t *out, const vector *a, int bits)
{
	int i;

	for (i = 0; i < RANK768; i++) {
		scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
	}
}

/*
 * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
 * |out|. It returns one on success and zero if any parsed value is >=
 * |kPrime|.
 */
static int
scalar_decode(scalar *out, const uint8_t *in, int bits)
{
	uint8_t in_byte = 0;
	int i, in_byte_bits_left = 0;

	assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);

	for (i = 0; i < DEGREE; i++) {
		uint16_t element = 0;
		int element_bits_done = 0;

		while (element_bits_done < bits) {
			int chunk_bits = bits - element_bits_done;

			if (in_byte_bits_left == 0) {
				in_byte = *in;
				in++;
				in_byte_bits_left = 8;
			}

			if (chunk_bits > in_byte_bits_left) {
				chunk_bits = in_byte_bits_left;
			}

			element |= (in_byte & kMasks[chunk_bits - 1]) <<
			    element_bits_done;
			in_byte_bits_left -= chunk_bits;
			in_byte >>= chunk_bits;

			element_bits_done += chunk_bits;
		}

		if (element >= kPrime) {
			return 0;
		}
		out->c[i] = element;
	}

	return 1;
}

/* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */
static void
scalar_decode_1(scalar *out, const uint8_t in[32])
{
	int i, j;

	for (i = 0; i < DEGREE; i += 8) {
		uint8_t in_byte = *in;

		in++;
		for (j = 0; j < 8; j++) {
			out->c[i + j] = in_byte & 1;
			in_byte >>= 1;
		}
	}
}

/*
 * Decodes 32*|RANK768|*|bits| bytes from |in| into |out|. It returns one on
 * success or zero if any parsed value is >= |kPrime|.
 */
static int
vector_decode(vector *out, const uint8_t *in, int bits)
{
	int i;

	for (i = 0; i < RANK768; i++) {
		if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8,
		    bits)) {
			return 0;
		}
	}
	return 1;
}

/*
 * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
 * numbers close to each other together. The formula used is
 * round(2^|bits|/kPrime*x) mod 2^|bits|.
 * Uses Barrett reduction to achieve constant time. Since we need both the
 * remainder (for rounding) and the quotient (as the result), we cannot use
 * |reduce| here, but need to do the Barrett reduction directly.
 */
static uint16_t
compress(uint16_t x, int bits)
{
	uint32_t shifted = (uint32_t)x << bits;
	uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
	uint32_t quotient = (uint32_t)(product >> kBarrettShift);
	uint32_t remainder = shifted - quotient * kPrime;

	/*
	 * Adjust the quotient to round correctly:
	 * 0 <= remainder <= kHalfPrime round to 0
	 * kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
	 * kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
	 */
	assert(remainder < 2u * kPrime);
	quotient += 1 & constant_time_lt(kHalfPrime, remainder);
	quotient += 1 & constant_time_lt(kPrime + kHalfPrime, remainder);
	return quotient & ((1 << bits) - 1);
}
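
/*
 * Worked example: compress(1000, 4) should be round(16 / 3329 * 1000) =
 * round(4.8) = 5. The code computes shifted = 16000, quotient = 4 and
 * remainder = 16000 - 4 * 3329 = 2684; since kHalfPrime < 2684 <=
 * kPrime + kHalfPrime, the quotient is rounded up once, giving 5.
 */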

/*
 * Decompresses |x| by using an equi-distant representative. The formula is
 * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
 * implement this logic using only bit operations.
 */
static uint16_t
decompress(uint16_t x, int bits)
{
	uint32_t product = (uint32_t)x * kPrime;
	uint32_t power = 1 << bits;
	/* This is |product| % power, since |power| is a power of 2. */
	uint32_t remainder = product & (power - 1);
	/* This is |product| / power, since |power| is a power of 2. */
	uint32_t lower = product >> bits;

	/*
	 * The rounding logic works since the first half of numbers mod |power|
	 * have a 0 as first bit, and the second half has a 1 as first bit,
	 * since |power| is a power of 2. As a 12 bit number, |remainder| is
	 * always positive, so we will shift in 0s for a right shift.
	 */
	return lower + (remainder >> (bits - 1));
}
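
/*
 * Worked example: decompress(5, 4) is round(3329 / 16 * 5) = round(1040.3) =
 * 1040. The code computes product = 16645, lower = 1040 and remainder = 5;
 * the top bit of the four-bit remainder is 0, so nothing is added. Note that
 * decompress(compress(x, bits), bits) only approximates x: the compression
 * is lossy.
 */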

static void
scalar_compress(scalar *s, int bits)
{
	int i;

	for (i = 0; i < DEGREE; i++) {
		s->c[i] = compress(s->c[i], bits);
	}
}

static void
scalar_decompress(scalar *s, int bits)
{
	int i;

	for (i = 0; i < DEGREE; i++) {
		s->c[i] = decompress(s->c[i], bits);
	}
}

static void
vector_compress(vector *a, int bits)
{
	int i;

	for (i = 0; i < RANK768; i++) {
		scalar_compress(&a->v[i], bits);
	}
}

static void
vector_decompress(vector *a, int bits)
{
	int i;

	for (i = 0; i < RANK768; i++) {
		scalar_decompress(&a->v[i], bits);
	}
}

struct public_key {
	vector t;
	uint8_t rho[32];
	uint8_t public_key_hash[32];
	matrix m;
};

static struct public_key *
public_key_768_from_external(const struct MLKEM768_public_key *external)
{
	return (struct public_key *)external;
}

struct private_key {
	struct public_key pub;
	vector s;
	uint8_t fo_failure_secret[32];
};

static struct private_key *
private_key_768_from_external(const struct MLKEM768_private_key *external)
{
	return (struct private_key *)external;
}

/*
 * Calls |MLKEM768_generate_key_external_entropy| with random bytes from
 * |arc4random_buf|.
 */
void
MLKEM768_generate_key(uint8_t out_encoded_public_key[MLKEM768_PUBLIC_KEY_BYTES],
    uint8_t optional_out_seed[MLKEM_SEED_BYTES],
    struct MLKEM768_private_key *out_private_key)
{
	uint8_t entropy_buf[MLKEM_SEED_BYTES];
	uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed :
	    entropy_buf;

	arc4random_buf(entropy, MLKEM_SEED_BYTES);
	MLKEM768_generate_key_external_entropy(out_encoded_public_key,
	    out_private_key, entropy);
}
LCRYPTO_ALIAS(MLKEM768_generate_key);

int
MLKEM768_private_key_from_seed(struct MLKEM768_private_key *out_private_key,
    const uint8_t *seed, size_t seed_len)
{
	uint8_t public_key_bytes[MLKEM768_PUBLIC_KEY_BYTES];

	if (seed_len != MLKEM_SEED_BYTES) {
		return 0;
	}
	MLKEM768_generate_key_external_entropy(public_key_bytes,
	    out_private_key, seed);

	return 1;
}
LCRYPTO_ALIAS(MLKEM768_private_key_from_seed);

static int
mlkem_marshal_public_key(CBB *out, const struct public_key *pub)
{
	uint8_t *vector_output;

	if (!CBB_add_space(out, &vector_output, kEncodedVectorSize)) {
		return 0;
	}
	vector_encode(vector_output, &pub->t, kLog2Prime);
	if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
		return 0;
	}
	return 1;
}

void
MLKEM768_generate_key_external_entropy(
    uint8_t out_encoded_public_key[MLKEM768_PUBLIC_KEY_BYTES],
    struct MLKEM768_private_key *out_private_key,
    const uint8_t entropy[MLKEM_SEED_BYTES])
{
	struct private_key *priv = private_key_768_from_external(
	    out_private_key);
	uint8_t augmented_seed[33];
	uint8_t *rho, *sigma;
	uint8_t counter = 0;
	uint8_t hashed[64];
	vector error;
	CBB cbb;

	memcpy(augmented_seed, entropy, 32);
	augmented_seed[32] = RANK768;
	hash_g(hashed, augmented_seed, 33);
	rho = hashed;
	sigma = hashed + 32;
	memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
	matrix_expand(&priv->pub.m, rho);
	vector_generate_secret_eta_2(&priv->s, &counter, sigma);
	vector_ntt(&priv->s);
	vector_generate_secret_eta_2(&error, &counter, sigma);
	vector_ntt(&error);
	matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
	vector_add(&priv->pub.t, &error);

	/* XXX - error checking */
	CBB_init_fixed(&cbb, out_encoded_public_key, MLKEM768_PUBLIC_KEY_BYTES);
	if (!mlkem_marshal_public_key(&cbb, &priv->pub)) {
		abort();
	}
	CBB_cleanup(&cbb);

	hash_h(priv->pub.public_key_hash, out_encoded_public_key,
	    MLKEM768_PUBLIC_KEY_BYTES);
	memcpy(priv->fo_failure_secret, entropy + 32, 32);
}

void
MLKEM768_public_from_private(struct MLKEM768_public_key *out_public_key,
    const struct MLKEM768_private_key *private_key)
{
	struct public_key *const pub = public_key_768_from_external(
	    out_public_key);
	const struct private_key *const priv = private_key_768_from_external(
	    private_key);

	*pub = priv->pub;
}
LCRYPTO_ALIAS(MLKEM768_public_from_private);

/*
 * Encrypts a message with given randomness to the ciphertext in |out|. Without
 * applying the Fujisaki-Okamoto transform this would not result in a CCA
 * secure scheme, since lattice schemes are vulnerable to decryption failure
 * oracles.
 */
static void
encrypt_cpa(uint8_t out[MLKEM768_CIPHERTEXT_BYTES],
    const struct public_key *pub, const uint8_t message[32],
    const uint8_t randomness[32])
{
	scalar expanded_message, scalar_error;
	vector secret, error, u;
	uint8_t counter = 0;
	uint8_t input[33];
	scalar v;

	vector_generate_secret_eta_2(&secret, &counter, randomness);
	vector_ntt(&secret);
	vector_generate_secret_eta_2(&error, &counter, randomness);
	memcpy(input, randomness, 32);
	input[32] = counter;
	scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error,
	    input);
	matrix_mult(&u, &pub->m, &secret);
	vector_inverse_ntt(&u);
	vector_add(&u, &error);
	scalar_inner_product(&v, &pub->t, &secret);
	scalar_inverse_ntt(&v);
	scalar_add(&v, &scalar_error);
	scalar_decode_1(&expanded_message, message);
	scalar_decompress(&expanded_message, 1);
	scalar_add(&v, &expanded_message);
	vector_compress(&u, kDU768);
	vector_encode(out, &u, kDU768);
	scalar_compress(&v, kDV768);
	scalar_encode(out + kCompressedVectorSize, &v, kDV768);
}

/*
 * Calls |MLKEM768_encap_external_entropy| with random bytes from
 * |arc4random_buf|.
 */
void
MLKEM768_encap(uint8_t out_ciphertext[MLKEM768_CIPHERTEXT_BYTES],
    uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
    const struct MLKEM768_public_key *public_key)
{
	uint8_t entropy[MLKEM_ENCAP_ENTROPY];

	arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY);
	MLKEM768_encap_external_entropy(out_ciphertext, out_shared_secret,
	    public_key, entropy);
}
LCRYPTO_ALIAS(MLKEM768_encap);

/* See section 6.2 of the spec. */
void
MLKEM768_encap_external_entropy(
    uint8_t out_ciphertext[MLKEM768_CIPHERTEXT_BYTES],
    uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
    const struct MLKEM768_public_key *public_key,
    const uint8_t entropy[MLKEM_ENCAP_ENTROPY])
{
	const struct public_key *pub = public_key_768_from_external(public_key);
	uint8_t key_and_randomness[64];
	uint8_t input[64];

	memcpy(input, entropy, MLKEM_ENCAP_ENTROPY);
	memcpy(input + MLKEM_ENCAP_ENTROPY, pub->public_key_hash,
	    sizeof(input) - MLKEM_ENCAP_ENTROPY);
	hash_g(key_and_randomness, input, sizeof(input));
	encrypt_cpa(out_ciphertext, pub, entropy, key_and_randomness + 32);
	memcpy(out_shared_secret, key_and_randomness, 32);
}

static void
decrypt_cpa(uint8_t out[32], const struct private_key *priv,
    const uint8_t ciphertext[MLKEM768_CIPHERTEXT_BYTES])
{
	scalar mask, v;
	vector u;

	vector_decode(&u, ciphertext, kDU768);
	vector_decompress(&u, kDU768);
	vector_ntt(&u);
	scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV768);
	scalar_decompress(&v, kDV768);
	scalar_inner_product(&mask, &priv->s, &u);
	scalar_inverse_ntt(&mask);
	scalar_sub(&v, &mask);
	scalar_compress(&v, 1);
	scalar_encode_1(out, &v);
}

/* See section 6.3 of the spec. */
int
MLKEM768_decap(uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
    const uint8_t *ciphertext, size_t ciphertext_len,
    const struct MLKEM768_private_key *private_key)
{
	const struct private_key *priv = private_key_768_from_external(
	    private_key);
	uint8_t expected_ciphertext[MLKEM768_CIPHERTEXT_BYTES];
	uint8_t key_and_randomness[64];
	uint8_t failure_key[32];
	uint8_t decrypted[64];
	uint8_t mask;
	int i;

	if (ciphertext_len != MLKEM768_CIPHERTEXT_BYTES) {
		arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_BYTES);
		return 0;
	}

	decrypt_cpa(decrypted, priv, ciphertext);
	memcpy(decrypted + 32, priv->pub.public_key_hash,
	    sizeof(decrypted) - 32);
	hash_g(key_and_randomness, decrypted, sizeof(decrypted));
	encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
	    key_and_randomness + 32);
	kdf(failure_key, priv->fo_failure_secret, ciphertext, ciphertext_len);
	mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext,
	    sizeof(expected_ciphertext)), 0);
	for (i = 0; i < MLKEM_SHARED_SECRET_BYTES; i++) {
		out_shared_secret[i] = constant_time_select_8(mask,
		    key_and_randomness[i], failure_key[i]);
	}

	return 1;
}
LCRYPTO_ALIAS(MLKEM768_decap);

int
MLKEM768_marshal_public_key(CBB *out,
    const struct MLKEM768_public_key *public_key)
{
	return mlkem_marshal_public_key(out,
	    public_key_768_from_external(public_key));
}
LCRYPTO_ALIAS(MLKEM768_marshal_public_key);

/*
 * mlkem_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
 * the value of |pub->public_key_hash|.
 */
static int
mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in)
{
	CBS t_bytes;

	if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) ||
	    !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime)) {
		return 0;
	}
	memcpy(pub->rho, CBS_data(in), sizeof(pub->rho));
	if (!CBS_skip(in, sizeof(pub->rho)))
		return 0;
	matrix_expand(&pub->m, pub->rho);
	return 1;
}

int
MLKEM768_parse_public_key(struct MLKEM768_public_key *public_key, CBS *in)
{
	struct public_key *pub = public_key_768_from_external(public_key);
	CBS orig_in = *in;

	if (!mlkem_parse_public_key_no_hash(pub, in) ||
	    CBS_len(in) != 0) {
		return 0;
	}
	hash_h(pub->public_key_hash, CBS_data(&orig_in), CBS_len(&orig_in));
	return 1;
}
LCRYPTO_ALIAS(MLKEM768_parse_public_key);

int
MLKEM768_marshal_private_key(CBB *out,
    const struct MLKEM768_private_key *private_key)
{
	const struct private_key *const priv = private_key_768_from_external(
	    private_key);
	uint8_t *s_output;

	if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) {
		return 0;
	}
	vector_encode(s_output, &priv->s, kLog2Prime);
	if (!mlkem_marshal_public_key(out, &priv->pub) ||
	    !CBB_add_bytes(out, priv->pub.public_key_hash,
	    sizeof(priv->pub.public_key_hash)) ||
	    !CBB_add_bytes(out, priv->fo_failure_secret,
	    sizeof(priv->fo_failure_secret))) {
		return 0;
	}
	return 1;
}

int
MLKEM768_parse_private_key(struct MLKEM768_private_key *out_private_key,
    CBS *in)
{
	struct private_key *const priv = private_key_768_from_external(
	    out_private_key);
	CBS s_bytes;

	if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) ||
	    !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) ||
	    !mlkem_parse_public_key_no_hash(&priv->pub, in)) {
		return 0;
	}
	memcpy(priv->pub.public_key_hash, CBS_data(in),
	    sizeof(priv->pub.public_key_hash));
	if (!CBS_skip(in, sizeof(priv->pub.public_key_hash)))
		return 0;
	memcpy(priv->fo_failure_secret, CBS_data(in),
	    sizeof(priv->fo_failure_secret));
	if (!CBS_skip(in, sizeof(priv->fo_failure_secret)))
		return 0;
	if (CBS_len(in) != 0)
		return 0;

	return 1;
}
LCRYPTO_ALIAS(MLKEM768_parse_private_key);

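/*
 * A minimal usage sketch of the public API above, compiled only when
 * MLKEM768_EXAMPLE is defined (a hypothetical flag, not part of the build):
 * generate a key pair, encapsulate to the public key, then decapsulate and
 * check that both sides derive the same shared secret.
 */
#ifdef MLKEM768_EXAMPLE
int
main(void)
{
	uint8_t pub_bytes[MLKEM768_PUBLIC_KEY_BYTES];
	uint8_t ct[MLKEM768_CIPHERTEXT_BYTES];
	uint8_t ss_a[MLKEM_SHARED_SECRET_BYTES];
	uint8_t ss_b[MLKEM_SHARED_SECRET_BYTES];
	struct MLKEM768_private_key priv;
	struct MLKEM768_public_key pub;

	/* Generate a key pair; passing NULL means the seed is not exported. */
	MLKEM768_generate_key(pub_bytes, NULL, &priv);
	MLKEM768_public_from_private(&pub, &priv);

	/* Encapsulate to the public key, then decapsulate the ciphertext. */
	MLKEM768_encap(ct, ss_a, &pub);
	if (!MLKEM768_decap(ss_b, ct, sizeof(ct), &priv))
		return 1;

	/* Both sides must agree on the 32-byte shared secret. */
	return memcmp(ss_a, ss_b, sizeof(ss_a)) != 0;
}
#endif /* MLKEM768_EXAMPLE */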