xref: /openbsd/lib/libcrypto/mlkem/mlkem1024.c (revision 9a8fba7c)
1 /* $OpenBSD: mlkem1024.c,v 1.6 2025/01/03 08:19:24 tb Exp $ */
2 /*
3  * Copyright (c) 2024, Google Inc.
4  * Copyright (c) 2024, Bob Beck <beck@obtuse.com>
5  *
6  * Permission to use, copy, modify, and/or distribute this software for any
7  * purpose with or without fee is hereby granted, provided that the above
8  * copyright notice and this permission notice appear in all copies.
9  *
10  * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
11  * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
12  * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
13  * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
14  * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
15  * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
16  * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
17  */
18 
19 #include <assert.h>
20 #include <stdlib.h>
21 #include <string.h>
22 
23 #include "bytestring.h"
24 #include "mlkem.h"
25 
26 #include "sha3_internal.h"
27 #include "mlkem_internal.h"
28 #include "constant_time.h"
29 #include "crypto_internal.h"
30 
31 /* Remove later */
32 #undef LCRYPTO_ALIAS
33 #define LCRYPTO_ALIAS(A)
34 
35 /*
36  * See
37  * https://csrc.nist.gov/pubs/fips/203/final
38  */
39 
40 static void
prf(uint8_t * out,size_t out_len,const uint8_t in[33])41 prf(uint8_t *out, size_t out_len, const uint8_t in[33])
42 {
43 	sha3_ctx ctx;
44 	shake256_init(&ctx);
45 	shake_update(&ctx, in, 33);
46 	shake_xof(&ctx);
47 	shake_out(&ctx, out, out_len);
48 }
49 
50 /* Section 4.1 */
51 static void
hash_h(uint8_t out[32],const uint8_t * in,size_t len)52 hash_h(uint8_t out[32], const uint8_t *in, size_t len)
53 {
54 	sha3_ctx ctx;
55 	sha3_init(&ctx, 32);
56 	sha3_update(&ctx, in, len);
57 	sha3_final(out, &ctx);
58 }
59 
60 static void
hash_g(uint8_t out[64],const uint8_t * in,size_t len)61 hash_g(uint8_t out[64], const uint8_t *in, size_t len)
62 {
63 	sha3_ctx ctx;
64 	sha3_init(&ctx, 64);
65 	sha3_update(&ctx, in, len);
66 	sha3_final(out, &ctx);
67 }
68 
69 /* this is called 'J' in the spec */
70 static void
kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES],const uint8_t failure_secret[32],const uint8_t * in,size_t len)71 kdf(uint8_t out[MLKEM_SHARED_SECRET_BYTES], const uint8_t failure_secret[32],
72     const uint8_t *in, size_t len)
73 {
74 	sha3_ctx ctx;
75 	shake256_init(&ctx);
76 	shake_update(&ctx, failure_secret, 32);
77 	shake_update(&ctx, in, len);
78 	shake_xof(&ctx);
79 	shake_out(&ctx, out, MLKEM_SHARED_SECRET_BYTES);
80 }
81 
82 #define DEGREE 256
83 #define RANK1024 4
84 
85 static const size_t kBarrettMultiplier = 5039;
86 static const unsigned kBarrettShift = 24;
87 static const uint16_t kPrime = 3329;
88 static const int kLog2Prime = 12;
89 static const uint16_t kHalfPrime = (/*kPrime=*/3329 - 1) / 2;
90 static const int kDU1024 = 11;
91 static const int kDV1024 = 5;
92 
93 /*
94  * kInverseDegree is 128^-1 mod 3329; 128 because kPrime does not have a 512th
95  * root of unity.
96  */
97 static const uint16_t kInverseDegree = 3303;
98 static const size_t kEncodedVectorSize =
99     (/*kLog2Prime=*/12 * DEGREE / 8) * RANK1024;
100 static const size_t kCompressedVectorSize = /*kDU1024=*/ 11 * RANK1024 * DEGREE /
101     8;
102 
103 typedef struct scalar {
104 	/* On every function entry and exit, 0 <= c < kPrime. */
105 	uint16_t c[DEGREE];
106 } scalar;
107 
108 typedef struct vector {
109 	scalar v[RANK1024];
110 } vector;
111 
112 typedef struct matrix {
113 	scalar v[RANK1024][RANK1024];
114 } matrix;
115 
116 /*
117  * This bit of Python will be referenced in some of the following comments:
118  *
119  *  p = 3329
120  *
121  * def bitreverse(i):
122  *     ret = 0
123  *     for n in range(7):
124  *         bit = i & 1
125  *         ret <<= 1
126  *         ret |= bit
127  *         i >>= 1
128  *     return ret
129  */
130 
131 /* kNTTRoots = [pow(17, bitreverse(i), p) for i in range(128)] */
132 static const uint16_t kNTTRoots[128] = {
133 	1,    1729, 2580, 3289, 2642, 630,  1897, 848,  1062, 1919, 193,  797,
134 	2786, 3260, 569,  1746, 296,  2447, 1339, 1476, 3046, 56,   2240, 1333,
135 	1426, 2094, 535,  2882, 2393, 2879, 1974, 821,  289,  331,  3253, 1756,
136 	1197, 2304, 2277, 2055, 650,  1977, 2513, 632,  2865, 33,   1320, 1915,
137 	2319, 1435, 807,  452,  1438, 2868, 1534, 2402, 2647, 2617, 1481, 648,
138 	2474, 3110, 1227, 910,  17,   2761, 583,  2649, 1637, 723,  2288, 1100,
139 	1409, 2662, 3281, 233,  756,  2156, 3015, 3050, 1703, 1651, 2789, 1789,
140 	1847, 952,  1461, 2687, 939,  2308, 2437, 2388, 733,  2337, 268,  641,
141 	1584, 2298, 2037, 3220, 375,  2549, 2090, 1645, 1063, 319,  2773, 757,
142 	2099, 561,  2466, 2594, 2804, 1092, 403,  1026, 1143, 2150, 2775, 886,
143 	1722, 1212, 1874, 1029, 2110, 2935, 885,  2154,
144 };
145 
146 /* kInverseNTTRoots = [pow(17, -bitreverse(i), p) for i in range(128)] */
147 static const uint16_t kInverseNTTRoots[128] = {
148 	1,    1600, 40,   749,  2481, 1432, 2699, 687,  1583, 2760, 69,   543,
149 	2532, 3136, 1410, 2267, 2508, 1355, 450,  936,  447,  2794, 1235, 1903,
150 	1996, 1089, 3273, 283,  1853, 1990, 882,  3033, 2419, 2102, 219,  855,
151 	2681, 1848, 712,  682,  927,  1795, 461,  1891, 2877, 2522, 1894, 1010,
152 	1414, 2009, 3296, 464,  2697, 816,  1352, 2679, 1274, 1052, 1025, 2132,
153 	1573, 76,   2998, 3040, 1175, 2444, 394,  1219, 2300, 1455, 2117, 1607,
154 	2443, 554,  1179, 2186, 2303, 2926, 2237, 525,  735,  863,  2768, 1230,
155 	2572, 556,  3010, 2266, 1684, 1239, 780,  2954, 109,  1292, 1031, 1745,
156 	2688, 3061, 992,  2596, 941,  892,  1021, 2390, 642,  1868, 2377, 1482,
157 	1540, 540,  1678, 1626, 279,  314,  1173, 2573, 3096, 48,   667,  1920,
158 	2229, 1041, 2606, 1692, 680,  2746, 568,  3312,
159 };
160 
161 /* kModRoots = [pow(17, 2*bitreverse(i) + 1, p) for i in range(128)] */
162 static const uint16_t kModRoots[128] = {
163 	17,   3312, 2761, 568,  583,  2746, 2649, 680,  1637, 1692, 723,  2606,
164 	2288, 1041, 1100, 2229, 1409, 1920, 2662, 667,  3281, 48,   233,  3096,
165 	756,  2573, 2156, 1173, 3015, 314,  3050, 279,  1703, 1626, 1651, 1678,
166 	2789, 540,  1789, 1540, 1847, 1482, 952,  2377, 1461, 1868, 2687, 642,
167 	939,  2390, 2308, 1021, 2437, 892,  2388, 941,  733,  2596, 2337, 992,
168 	268,  3061, 641,  2688, 1584, 1745, 2298, 1031, 2037, 1292, 3220, 109,
169 	375,  2954, 2549, 780,  2090, 1239, 1645, 1684, 1063, 2266, 319,  3010,
170 	2773, 556,  757,  2572, 2099, 1230, 561,  2768, 2466, 863,  2594, 735,
171 	2804, 525,  1092, 2237, 403,  2926, 1026, 2303, 1143, 2186, 2150, 1179,
172 	2775, 554,  886,  2443, 1722, 1607, 1212, 2117, 1874, 1455, 1029, 2300,
173 	2110, 1219, 2935, 394,  885,  2444, 2154, 1175,
174 };
175 
176 /* reduce_once reduces 0 <= x < 2*kPrime, mod kPrime. */
177 static uint16_t
reduce_once(uint16_t x)178 reduce_once(uint16_t x)
179 {
180 	assert(x < 2 * kPrime);
181 	const uint16_t subtracted = x - kPrime;
182 	uint16_t mask = 0u - (subtracted >> 15);
183 
184 	/*
185 	 * Although this is a constant-time select, we omit a value barrier here.
186 	 * Value barriers impede auto-vectorization (likely because it forces the
187 	 * value to transit through a general-purpose register). On AArch64, this
188 	 * is a difference of 2x.
189 	 *
190 	 * We usually add value barriers to selects because Clang turns
191          * consecutive selects with the same condition into a branch instead of
192 	 * CMOV/CSEL. This condition does not occur in ML-KEM, so omitting it
193          * seems to be safe so far  but see
194          * |scalar_centered_binomial_distribution_eta_2_with_prf|.
195 	 */
196 	return (mask & x) | (~mask & subtracted);
197 }
198 
199 /*
200  * constant time reduce x mod kPrime using Barrett reduction. x must be less
201  * than kPrime + 2×kPrime².
202  */
203 static uint16_t
reduce(uint32_t x)204 reduce(uint32_t x)
205 {
206 	uint64_t product = (uint64_t)x * kBarrettMultiplier;
207 	uint32_t quotient = (uint32_t)(product >> kBarrettShift);
208 	uint32_t remainder = x - quotient * kPrime;
209 
210 	assert(x < kPrime + 2u * kPrime * kPrime);
211 	return reduce_once(remainder);
212 }
213 
214 static void
scalar_zero(scalar * out)215 scalar_zero(scalar *out)
216 {
217 	memset(out, 0, sizeof(*out));
218 }
219 
220 static void
vector_zero(vector * out)221 vector_zero(vector *out)
222 {
223 	memset(out, 0, sizeof(*out));
224 }
225 
226 /*
227  * In place number theoretic transform of a given scalar.
228  * Note that MLKEM's kPrime 3329 does not have a 512th root of unity, so this
229  * transform leaves off the last iteration of the usual FFT code, with the 128
230  * relevant roots of unity being stored in |kNTTRoots|. This means the output
231  * should be seen as 128 elements in GF(3329^2), with the coefficients of the
232  * elements being consecutive entries in |s->c|.
233  */
234 static void
scalar_ntt(scalar * s)235 scalar_ntt(scalar *s)
236 {
237 	int offset = DEGREE;
238 	int step;
239 	/*
240 	 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
241 	 * with Clang 14 on Aarch64.
242 	 */
243 	for (step = 1; step < DEGREE / 2; step <<= 1) {
244 		int i, j, k = 0;
245 
246 		offset >>= 1;
247 		for (i = 0; i < step; i++) {
248 			const uint32_t step_root = kNTTRoots[i + step];
249 
250 			for (j = k; j < k + offset; j++) {
251 				uint16_t odd, even;
252 
253 				odd = reduce(step_root * s->c[j + offset]);
254 				even = s->c[j];
255 				s->c[j] = reduce_once(odd + even);
256 				s->c[j + offset] = reduce_once(even - odd +
257 				    kPrime);
258 			}
259 			k += 2 * offset;
260 		}
261 	}
262 }
263 
264 static void
vector_ntt(vector * a)265 vector_ntt(vector *a)
266 {
267 	int i;
268 
269 	for (i = 0; i < RANK1024; i++) {
270 		scalar_ntt(&a->v[i]);
271 	}
272 }
273 
274 /*
275  * In place inverse number theoretic transform of a given scalar, with pairs of
276  * entries of s->v being interpreted as elements of GF(3329^2). Just as with the
277  * number theoretic transform, this leaves off the first step of the normal iFFT
278  * to account for the fact that 3329 does not have a 512th root of unity, using
279  * the precomputed 128 roots of unity stored in |kInverseNTTRoots|.
280  */
281 static void
scalar_inverse_ntt(scalar * s)282 scalar_inverse_ntt(scalar *s)
283 {
284 	int i, j, k, offset, step = DEGREE / 2;
285 
286 	/*
287 	 * `int` is used here because using `size_t` throughout caused a ~5% slowdown
288 	 * with Clang 14 on Aarch64.
289 	 */
290 	for (offset = 2; offset < DEGREE; offset <<= 1) {
291 		step >>= 1;
292 		k = 0;
293 		for (i = 0; i < step; i++) {
294 			uint32_t step_root = kInverseNTTRoots[i + step];
295 			for (j = k; j < k + offset; j++) {
296 				uint16_t odd, even;
297 				odd = s->c[j + offset];
298 				even = s->c[j];
299 				s->c[j] = reduce_once(odd + even);
300 				s->c[j + offset] = reduce(step_root *
301 				    (even - odd + kPrime));
302 			}
303 			k += 2 * offset;
304 		}
305 	}
306 	for (i = 0; i < DEGREE; i++) {
307 		s->c[i] = reduce(s->c[i] * kInverseDegree);
308 	}
309 }
310 
311 static void
vector_inverse_ntt(vector * a)312 vector_inverse_ntt(vector *a)
313 {
314 	int i;
315 
316 	for (i = 0; i < RANK1024; i++) {
317 		scalar_inverse_ntt(&a->v[i]);
318 	}
319 }
320 
321 static void
scalar_add(scalar * lhs,const scalar * rhs)322 scalar_add(scalar *lhs, const scalar *rhs)
323 {
324 	int i;
325 
326 	for (i = 0; i < DEGREE; i++) {
327 		lhs->c[i] = reduce_once(lhs->c[i] + rhs->c[i]);
328 	}
329 }
330 
331 static void
scalar_sub(scalar * lhs,const scalar * rhs)332 scalar_sub(scalar *lhs, const scalar *rhs)
333 {
334 	int i;
335 
336 	for (i = 0; i < DEGREE; i++) {
337 		lhs->c[i] = reduce_once(lhs->c[i] - rhs->c[i] + kPrime);
338 	}
339 }
340 
341 /*
342  * Multiplying two scalars in the number theoretically transformed state.
343  * Since 3329 does not have a 512th root of unity, this means we have to
344  * interpret the 2*ith and (2*i+1)th entries of the scalar as elements of
345  * GF(3329)[X]/(X^2 - 17^(2*bitreverse(i)+1)).
346  * The value of 17^(2*bitreverse(i)+1) mod 3329 is stored in the precomputed
347  * |kModRoots| table. Our Barrett transform only allows us to multiply two
348  * reduced numbers together, so we need some intermediate reduction steps,
349  * even if an uint64_t could hold 3 multiplied numbers.
350  */
351 static void
scalar_mult(scalar * out,const scalar * lhs,const scalar * rhs)352 scalar_mult(scalar *out, const scalar *lhs, const scalar *rhs)
353 {
354 	int i;
355 
356 	for (i = 0; i < DEGREE / 2; i++) {
357 		uint32_t real_real = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i];
358 		uint32_t img_img = (uint32_t)lhs->c[2 * i + 1] *
359 		    rhs->c[2 * i + 1];
360 		uint32_t real_img = (uint32_t)lhs->c[2 * i] * rhs->c[2 * i + 1];
361 		uint32_t img_real = (uint32_t)lhs->c[2 * i + 1] * rhs->c[2 * i];
362 
363 		out->c[2 * i] =
364 		    reduce(real_real +
365 		    (uint32_t)reduce(img_img) * kModRoots[i]);
366 		out->c[2 * i + 1] = reduce(img_real + real_img);
367 	}
368 }
369 
370 static void
vector_add(vector * lhs,const vector * rhs)371 vector_add(vector *lhs, const vector *rhs)
372 {
373 	int i;
374 
375 	for (i = 0; i < RANK1024; i++) {
376 		scalar_add(&lhs->v[i], &rhs->v[i]);
377 	}
378 }
379 
380 static void
matrix_mult(vector * out,const matrix * m,const vector * a)381 matrix_mult(vector *out, const matrix *m, const vector *a)
382 {
383 	int i, j;
384 
385 	vector_zero(out);
386 	for (i = 0; i < RANK1024; i++) {
387 		for (j = 0; j < RANK1024; j++) {
388 			scalar product;
389 
390 			scalar_mult(&product, &m->v[i][j], &a->v[j]);
391 			scalar_add(&out->v[i], &product);
392 		}
393 	}
394 }
395 
396 static void
matrix_mult_transpose(vector * out,const matrix * m,const vector * a)397 matrix_mult_transpose(vector *out, const matrix *m,
398     const vector *a)
399 {
400 	int i, j;
401 
402 	vector_zero(out);
403 	for (i = 0; i < RANK1024; i++) {
404 		for (j = 0; j < RANK1024; j++) {
405 			scalar product;
406 
407 			scalar_mult(&product, &m->v[j][i], &a->v[j]);
408 			scalar_add(&out->v[i], &product);
409 		}
410 	}
411 }
412 
413 static void
scalar_inner_product(scalar * out,const vector * lhs,const vector * rhs)414 scalar_inner_product(scalar *out, const vector *lhs,
415     const vector *rhs)
416 {
417 	int i;
418 	scalar_zero(out);
419 	for (i = 0; i < RANK1024; i++) {
420 		scalar product;
421 
422 		scalar_mult(&product, &lhs->v[i], &rhs->v[i]);
423 		scalar_add(out, &product);
424 	}
425 }
426 
427 /*
428  * Algorithm 6 of spec. Rejection samples a Keccak stream to get uniformly
429  * distributed elements. This is used for matrix expansion and only operates on
430  * public inputs.
431  */
432 static void
scalar_from_keccak_vartime(scalar * out,sha3_ctx * keccak_ctx)433 scalar_from_keccak_vartime(scalar *out, sha3_ctx *keccak_ctx)
434 {
435 	int i, done = 0;
436 
437 	while (done < DEGREE) {
438 		uint8_t block[168];
439 
440 		shake_out(keccak_ctx, block, sizeof(block));
441 		for (i = 0; i < sizeof(block) && done < DEGREE; i += 3) {
442 			uint16_t d1 = block[i] + 256 * (block[i + 1] % 16);
443 			uint16_t d2 = block[i + 1] / 16 + 16 * block[i + 2];
444 
445 			if (d1 < kPrime) {
446 				out->c[done++] = d1;
447 			}
448 			if (d2 < kPrime && done < DEGREE) {
449 				out->c[done++] = d2;
450 			}
451 		}
452 	}
453 }
454 
455 /*
456  * Algorithm 7 of the spec, with eta fixed to two and the PRF call
457  * included. Creates binominally distributed elements by sampling 2*|eta| bits,
458  * and setting the coefficient to the count of the first bits minus the count of
459  * the second bits, resulting in a centered binomial distribution. Since eta is
460  * two this gives -2/2 with a probability of 1/16, -1/1 with probability 1/4,
461  * and 0 with probability 3/8.
462  */
463 static void
scalar_centered_binomial_distribution_eta_2_with_prf(scalar * out,const uint8_t input[33])464 scalar_centered_binomial_distribution_eta_2_with_prf(scalar *out,
465     const uint8_t input[33])
466 {
467 	uint8_t entropy[128];
468 	int i;
469 
470 	CTASSERT(sizeof(entropy) == 2 * /*kEta=*/ 2 * DEGREE / 8);
471 	prf(entropy, sizeof(entropy), input);
472 
473 	for (i = 0; i < DEGREE; i += 2) {
474 		uint8_t byte = entropy[i / 2];
475 		uint16_t mask;
476 		uint16_t value = (byte & 1) + ((byte >> 1) & 1);
477 
478 		value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
479 
480 		/*
481 		 * Add |kPrime| if |value| underflowed. See |reduce_once| for a
482 		 * discussion on why the value barrier is omitted. While this
483 		 * could have been written reduce_once(value + kPrime), this is
484 		 * one extra addition and small range of |value| tempts some
485 		 * versions of Clang to emit a branch.
486 		 */
487 		mask = 0u - (value >> 15);
488 		out->c[i] = ((value + kPrime) & mask) | (value & ~mask);
489 
490 		byte >>= 4;
491 		value = (byte & 1) + ((byte >> 1) & 1);
492 		value -= ((byte >> 2) & 1) + ((byte >> 3) & 1);
493 		/*  See above. */
494 		mask = 0u - (value >> 15);
495 		out->c[i + 1] = ((value + kPrime) & mask) | (value & ~mask);
496 	}
497 }
498 
499 /*
500  * Generates a secret vector by using
501  * |scalar_centered_binomial_distribution_eta_2_with_prf|, using the given seed
502  * appending and incrementing |counter| for entry of the vector.
503  */
504 static void
vector_generate_secret_eta_2(vector * out,uint8_t * counter,const uint8_t seed[32])505 vector_generate_secret_eta_2(vector *out, uint8_t *counter,
506     const uint8_t seed[32])
507 {
508 	uint8_t input[33];
509 	int i;
510 
511 	memcpy(input, seed, 32);
512 	for (i = 0; i < RANK1024; i++) {
513 		input[32] = (*counter)++;
514 		scalar_centered_binomial_distribution_eta_2_with_prf(&out->v[i],
515 		    input);
516 	}
517 }
518 
519 /* Expands the matrix of a seed for key generation and for encaps-CPA. */
520 static void
matrix_expand(matrix * out,const uint8_t rho[32])521 matrix_expand(matrix *out, const uint8_t rho[32])
522 {
523 	uint8_t input[34];
524 	int i, j;
525 
526 	memcpy(input, rho, 32);
527 	for (i = 0; i < RANK1024; i++) {
528 		for (j = 0; j < RANK1024; j++) {
529 			sha3_ctx keccak_ctx;
530 
531 			input[32] = i;
532 			input[33] = j;
533 			shake128_init(&keccak_ctx);
534 			shake_update(&keccak_ctx, input, sizeof(input));
535 			shake_xof(&keccak_ctx);
536 			scalar_from_keccak_vartime(&out->v[i][j], &keccak_ctx);
537 		}
538 	}
539 }
540 
541 static const uint8_t kMasks[8] = {0x01, 0x03, 0x07, 0x0f,
542 	0x1f, 0x3f, 0x7f, 0xff};
543 
544 static void
scalar_encode(uint8_t * out,const scalar * s,int bits)545 scalar_encode(uint8_t *out, const scalar *s, int bits)
546 {
547 	uint8_t out_byte = 0;
548 	int i, out_byte_bits = 0;
549 
550 	assert(bits <= (int)sizeof(*s->c) * 8 && bits != 1);
551 	for (i = 0; i < DEGREE; i++) {
552 		uint16_t element = s->c[i];
553 		int element_bits_done = 0;
554 
555 		while (element_bits_done < bits) {
556 			int chunk_bits = bits - element_bits_done;
557 			int out_bits_remaining = 8 - out_byte_bits;
558 
559 			if (chunk_bits >= out_bits_remaining) {
560 				chunk_bits = out_bits_remaining;
561 				out_byte |= (element &
562 				    kMasks[chunk_bits - 1]) << out_byte_bits;
563 				*out = out_byte;
564 				out++;
565 				out_byte_bits = 0;
566 				out_byte = 0;
567 			} else {
568 				out_byte |= (element &
569 				    kMasks[chunk_bits - 1]) << out_byte_bits;
570 				out_byte_bits += chunk_bits;
571 			}
572 
573 			element_bits_done += chunk_bits;
574 			element >>= chunk_bits;
575 		}
576 	}
577 
578 	if (out_byte_bits > 0) {
579 		*out = out_byte;
580 	}
581 }
582 
583 /* scalar_encode_1 is |scalar_encode| specialised for |bits| == 1. */
584 static void
scalar_encode_1(uint8_t out[32],const scalar * s)585 scalar_encode_1(uint8_t out[32], const scalar *s)
586 {
587 	int i, j;
588 
589 	for (i = 0; i < DEGREE; i += 8) {
590 		uint8_t out_byte = 0;
591 
592 		for (j = 0; j < 8; j++) {
593 			out_byte |= (s->c[i + j] & 1) << j;
594 		}
595 		*out = out_byte;
596 		out++;
597 	}
598 }
599 
600 /*
601  * Encodes an entire vector into 32*|RANK1024|*|bits| bytes. Note that since 256
602  * (DEGREE) is divisible by 8, the individual vector entries will always fill a
603  * whole number of bytes, so we do not need to worry about bit packing here.
604  */
605 static void
vector_encode(uint8_t * out,const vector * a,int bits)606 vector_encode(uint8_t *out, const vector *a, int bits)
607 {
608 	int i;
609 
610 	for (i = 0; i < RANK1024; i++) {
611 		scalar_encode(out + i * bits * DEGREE / 8, &a->v[i], bits);
612 	}
613 }
614 
615 /*
616  * scalar_decode parses |DEGREE * bits| bits from |in| into |DEGREE| values in
617  * |out|. It returns one on success and zero if any parsed value is >=
618  * |kPrime|.
619  */
620 static int
scalar_decode(scalar * out,const uint8_t * in,int bits)621 scalar_decode(scalar *out, const uint8_t *in, int bits)
622 {
623 	uint8_t in_byte = 0;
624 	int i, in_byte_bits_left = 0;
625 
626 	assert(bits <= (int)sizeof(*out->c) * 8 && bits != 1);
627 
628 	for (i = 0; i < DEGREE; i++) {
629 		uint16_t element = 0;
630 		int element_bits_done = 0;
631 
632 		while (element_bits_done < bits) {
633 			int chunk_bits = bits - element_bits_done;
634 
635 			if (in_byte_bits_left == 0) {
636 				in_byte = *in;
637 				in++;
638 				in_byte_bits_left = 8;
639 			}
640 
641 			if (chunk_bits > in_byte_bits_left) {
642 				chunk_bits = in_byte_bits_left;
643 			}
644 
645 			element |= (in_byte & kMasks[chunk_bits - 1]) <<
646 			    element_bits_done;
647 			in_byte_bits_left -= chunk_bits;
648 			in_byte >>= chunk_bits;
649 
650 			element_bits_done += chunk_bits;
651 		}
652 
653 		if (element >= kPrime) {
654 			return 0;
655 		}
656 		out->c[i] = element;
657 	}
658 
659 	return 1;
660 }
661 
662 /* scalar_decode_1 is |scalar_decode| specialised for |bits| == 1. */
663 static void
scalar_decode_1(scalar * out,const uint8_t in[32])664 scalar_decode_1(scalar *out, const uint8_t in[32])
665 {
666 	int i, j;
667 
668 	for (i = 0; i < DEGREE; i += 8) {
669 		uint8_t in_byte = *in;
670 
671 		in++;
672 		for (j = 0; j < 8; j++) {
673 			out->c[i + j] = in_byte & 1;
674 			in_byte >>= 1;
675 		}
676 	}
677 }
678 
679 /*
680  * Decodes 32*|RANK1024|*|bits| bytes from |in| into |out|. It returns one on
681  * success or zero if any parsed value is >= |kPrime|.
682  */
683 static int
vector_decode(vector * out,const uint8_t * in,int bits)684 vector_decode(vector *out, const uint8_t *in, int bits)
685 {
686 	int i;
687 
688 	for (i = 0; i < RANK1024; i++) {
689 		if (!scalar_decode(&out->v[i], in + i * bits * DEGREE / 8,
690 		    bits)) {
691 			return 0;
692 		}
693 	}
694 	return 1;
695 }
696 
697 /*
698  * Compresses (lossily) an input |x| mod 3329 into |bits| many bits by grouping
699  * numbers close to each other together. The formula used is
700  * round(2^|bits|/kPrime*x) mod 2^|bits|.
701  * Uses Barrett reduction to achieve constant time. Since we need both the
702  * remainder (for rounding) and the quotient (as the result), we cannot use
703  * |reduce| here, but need to do the Barrett reduction directly.
704  */
705 static uint16_t
compress(uint16_t x,int bits)706 compress(uint16_t x, int bits)
707 {
708 	uint32_t shifted = (uint32_t)x << bits;
709 	uint64_t product = (uint64_t)shifted * kBarrettMultiplier;
710 	uint32_t quotient = (uint32_t)(product >> kBarrettShift);
711 	uint32_t remainder = shifted - quotient * kPrime;
712 
713 	/*
714 	 * Adjust the quotient to round correctly:
715 	 * 0 <= remainder <= kHalfPrime round to 0
716 	 * kHalfPrime < remainder <= kPrime + kHalfPrime round to 1
717 	 * kPrime + kHalfPrime < remainder < 2 * kPrime round to 2
718 	 */
719 	assert(remainder < 2u * kPrime);
720 	quotient += 1 & constant_time_lt(kHalfPrime, remainder);
721 	quotient += 1 & constant_time_lt(kPrime + kHalfPrime, remainder);
722 	return quotient & ((1 << bits) - 1);
723 }
724 
725 /*
726  * Decompresses |x| by using an equi-distant representative. The formula is
727  * round(kPrime/2^|bits|*x). Note that 2^|bits| being the divisor allows us to
728  * implement this logic using only bit operations.
729  */
730 static uint16_t
decompress(uint16_t x,int bits)731 decompress(uint16_t x, int bits)
732 {
733 	uint32_t product = (uint32_t)x * kPrime;
734 	uint32_t power = 1 << bits;
735 	/* This is |product| % power, since |power| is a power of 2. */
736 	uint32_t remainder = product & (power - 1);
737 	/* This is |product| / power, since |power| is a power of 2. */
738 	uint32_t lower = product >> bits;
739 
740 	/*
741 	 * The rounding logic works since the first half of numbers mod |power| have a
742 	 * 0 as first bit, and the second half has a 1 as first bit, since |power| is
743 	 * a power of 2. As a 12 bit number, |remainder| is always positive, so we
744 	 * will shift in 0s for a right shift.
745 	 */
746 	return lower + (remainder >> (bits - 1));
747 }
748 
749 static void
scalar_compress(scalar * s,int bits)750 scalar_compress(scalar *s, int bits)
751 {
752 	int i;
753 
754 	for (i = 0; i < DEGREE; i++) {
755 		s->c[i] = compress(s->c[i], bits);
756 	}
757 }
758 
759 static void
scalar_decompress(scalar * s,int bits)760 scalar_decompress(scalar *s, int bits)
761 {
762 	int i;
763 
764 	for (i = 0; i < DEGREE; i++) {
765 		s->c[i] = decompress(s->c[i], bits);
766 	}
767 }
768 
769 static void
vector_compress(vector * a,int bits)770 vector_compress(vector *a, int bits)
771 {
772 	int i;
773 
774 	for (i = 0; i < RANK1024; i++) {
775 		scalar_compress(&a->v[i], bits);
776 	}
777 }
778 
779 static void
vector_decompress(vector * a,int bits)780 vector_decompress(vector *a, int bits)
781 {
782 	int i;
783 
784 	for (i = 0; i < RANK1024; i++) {
785 		scalar_decompress(&a->v[i], bits);
786 	}
787 }
788 
789 struct public_key {
790 	vector t;
791 	uint8_t rho[32];
792 	uint8_t public_key_hash[32];
793 	matrix m;
794 };
795 
796 static struct public_key *
public_key_1024_from_external(const struct MLKEM1024_public_key * external)797 public_key_1024_from_external(const struct MLKEM1024_public_key *external)
798 {
799 	return (struct public_key *)external;
800 }
801 
802 struct private_key {
803 	struct public_key pub;
804 	vector s;
805 	uint8_t fo_failure_secret[32];
806 };
807 
808 static struct private_key *
private_key_1024_from_external(const struct MLKEM1024_private_key * external)809 private_key_1024_from_external(const struct MLKEM1024_private_key *external)
810 {
811 	return (struct private_key *)external;
812 }
813 
814 /*
815  * Calls |MLKEM1024_generate_key_external_entropy| with random bytes from
816  * |RAND_bytes|.
817  */
818 void
MLKEM1024_generate_key(uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],uint8_t optional_out_seed[MLKEM_SEED_BYTES],struct MLKEM1024_private_key * out_private_key)819 MLKEM1024_generate_key(uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
820     uint8_t optional_out_seed[MLKEM_SEED_BYTES],
821     struct MLKEM1024_private_key *out_private_key)
822 {
823 	uint8_t entropy_buf[MLKEM_SEED_BYTES];
824 	uint8_t *entropy = optional_out_seed != NULL ? optional_out_seed :
825 	    entropy_buf;
826 
827 	arc4random_buf(entropy, MLKEM_SEED_BYTES);
828 	MLKEM1024_generate_key_external_entropy(out_encoded_public_key,
829 	    out_private_key, entropy);
830 }
831 LCRYPTO_ALIAS(MLKEM1024_generate_key);
832 
833 int
MLKEM1024_private_key_from_seed(struct MLKEM1024_private_key * out_private_key,const uint8_t * seed,size_t seed_len)834 MLKEM1024_private_key_from_seed(struct MLKEM1024_private_key *out_private_key,
835     const uint8_t *seed, size_t seed_len)
836 {
837 	uint8_t public_key_bytes[MLKEM1024_PUBLIC_KEY_BYTES];
838 
839 	if (seed_len != MLKEM_SEED_BYTES) {
840 		return 0;
841 	}
842 	MLKEM1024_generate_key_external_entropy(public_key_bytes,
843 	    out_private_key, seed);
844 
845 	return 1;
846 }
847 LCRYPTO_ALIAS(MLKEM1024_private_key_from_seed);
848 
849 static int
mlkem_marshal_public_key(CBB * out,const struct public_key * pub)850 mlkem_marshal_public_key(CBB *out, const struct public_key *pub)
851 {
852 	uint8_t *vector_output;
853 
854 	if (!CBB_add_space(out, &vector_output, kEncodedVectorSize)) {
855 		return 0;
856 	}
857 	vector_encode(vector_output, &pub->t, kLog2Prime);
858 	if (!CBB_add_bytes(out, pub->rho, sizeof(pub->rho))) {
859 		return 0;
860 	}
861 	return 1;
862 }
863 
864 void
MLKEM1024_generate_key_external_entropy(uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],struct MLKEM1024_private_key * out_private_key,const uint8_t entropy[MLKEM_SEED_BYTES])865 MLKEM1024_generate_key_external_entropy(
866     uint8_t out_encoded_public_key[MLKEM1024_PUBLIC_KEY_BYTES],
867     struct MLKEM1024_private_key *out_private_key,
868     const uint8_t entropy[MLKEM_SEED_BYTES])
869 {
870 	struct private_key *priv = private_key_1024_from_external(
871 	    out_private_key);
872 	uint8_t augmented_seed[33];
873 	uint8_t *rho, *sigma;
874 	uint8_t counter = 0;
875 	uint8_t hashed[64];
876 	vector error;
877 	CBB cbb;
878 
879 	memcpy(augmented_seed, entropy, 32);
880 	augmented_seed[32] = RANK1024;
881 	hash_g(hashed, augmented_seed, 33);
882 	rho = hashed;
883 	sigma = hashed + 32;
884 	memcpy(priv->pub.rho, hashed, sizeof(priv->pub.rho));
885 	matrix_expand(&priv->pub.m, rho);
886 	vector_generate_secret_eta_2(&priv->s, &counter, sigma);
887 	vector_ntt(&priv->s);
888 	vector_generate_secret_eta_2(&error, &counter, sigma);
889 	vector_ntt(&error);
890 	matrix_mult_transpose(&priv->pub.t, &priv->pub.m, &priv->s);
891 	vector_add(&priv->pub.t, &error);
892 
893 	/* XXX - error checking. */
894 	CBB_init_fixed(&cbb, out_encoded_public_key, MLKEM1024_PUBLIC_KEY_BYTES);
895 	if (!mlkem_marshal_public_key(&cbb, &priv->pub)) {
896 		abort();
897 	}
898 	CBB_cleanup(&cbb);
899 
900 	hash_h(priv->pub.public_key_hash, out_encoded_public_key,
901 	    MLKEM1024_PUBLIC_KEY_BYTES);
902 	memcpy(priv->fo_failure_secret, entropy + 32, 32);
903 }
904 
905 void
MLKEM1024_public_from_private(struct MLKEM1024_public_key * out_public_key,const struct MLKEM1024_private_key * private_key)906 MLKEM1024_public_from_private(struct MLKEM1024_public_key *out_public_key,
907     const struct MLKEM1024_private_key *private_key)
908 {
909 	struct public_key *const pub = public_key_1024_from_external(
910 	    out_public_key);
911 	const struct private_key *const priv = private_key_1024_from_external(
912 	    private_key);
913 
914 	*pub = priv->pub;
915 }
916 LCRYPTO_ALIAS(MLKEM1024_public_from_private);
917 
918 /*
919  * Encrypts a message with given randomness to the ciphertext in |out|. Without
920  * applying the Fujisaki-Okamoto transform this would not result in a CCA secure
921  * scheme, since lattice schemes are vulnerable to decryption failure oracles.
922  */
923 static void
encrypt_cpa(uint8_t out[MLKEM1024_CIPHERTEXT_BYTES],const struct public_key * pub,const uint8_t message[32],const uint8_t randomness[32])924 encrypt_cpa(uint8_t out[MLKEM1024_CIPHERTEXT_BYTES],
925     const struct public_key *pub, const uint8_t message[32],
926     const uint8_t randomness[32])
927 {
928 	scalar expanded_message, scalar_error;
929 	vector secret, error, u;
930 	uint8_t counter = 0;
931 	uint8_t input[33];
932 	scalar v;
933 
934 	vector_generate_secret_eta_2(&secret, &counter, randomness);
935 	vector_ntt(&secret);
936 	vector_generate_secret_eta_2(&error, &counter, randomness);
937 	memcpy(input, randomness, 32);
938 	input[32] = counter;
939 	scalar_centered_binomial_distribution_eta_2_with_prf(&scalar_error,
940 	    input);
941 	matrix_mult(&u, &pub->m, &secret);
942 	vector_inverse_ntt(&u);
943 	vector_add(&u, &error);
944 	scalar_inner_product(&v, &pub->t, &secret);
945 	scalar_inverse_ntt(&v);
946 	scalar_add(&v, &scalar_error);
947 	scalar_decode_1(&expanded_message, message);
948 	scalar_decompress(&expanded_message, 1);
949 	scalar_add(&v, &expanded_message);
950 	vector_compress(&u, kDU1024);
951 	vector_encode(out, &u, kDU1024);
952 	scalar_compress(&v, kDV1024);
953 	scalar_encode(out + kCompressedVectorSize, &v, kDV1024);
954 }
955 
956 /* Calls MLKEM1024_encap_external_entropy| with random bytes */
957 void
MLKEM1024_encap(uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],const struct MLKEM1024_public_key * public_key)958 MLKEM1024_encap(uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
959     uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
960     const struct MLKEM1024_public_key *public_key)
961 {
962 	uint8_t entropy[MLKEM_ENCAP_ENTROPY];
963 
964 	arc4random_buf(entropy, MLKEM_ENCAP_ENTROPY);
965 	MLKEM1024_encap_external_entropy(out_ciphertext, out_shared_secret,
966 	    public_key, entropy);
967 }
968 LCRYPTO_ALIAS(MLKEM1024_encap);
969 
970 /* See section 6.2 of the spec. */
971 void
MLKEM1024_encap_external_entropy(uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],const struct MLKEM1024_public_key * public_key,const uint8_t entropy[MLKEM_ENCAP_ENTROPY])972 MLKEM1024_encap_external_entropy(
973     uint8_t out_ciphertext[MLKEM1024_CIPHERTEXT_BYTES],
974     uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
975     const struct MLKEM1024_public_key *public_key,
976     const uint8_t entropy[MLKEM_ENCAP_ENTROPY])
977 {
978 	const struct public_key *pub = public_key_1024_from_external(public_key);
979 	uint8_t key_and_randomness[64];
980 	uint8_t input[64];
981 
982 	memcpy(input, entropy, MLKEM_ENCAP_ENTROPY);
983 	memcpy(input + MLKEM_ENCAP_ENTROPY, pub->public_key_hash,
984 	    sizeof(input) - MLKEM_ENCAP_ENTROPY);
985 	hash_g(key_and_randomness, input, sizeof(input));
986 	encrypt_cpa(out_ciphertext, pub, entropy, key_and_randomness + 32);
987 	memcpy(out_shared_secret, key_and_randomness, 32);
988 }
989 
990 static void
decrypt_cpa(uint8_t out[32],const struct private_key * priv,const uint8_t ciphertext[MLKEM1024_CIPHERTEXT_BYTES])991 decrypt_cpa(uint8_t out[32], const struct private_key *priv,
992     const uint8_t ciphertext[MLKEM1024_CIPHERTEXT_BYTES])
993 {
994 	scalar mask, v;
995 	vector u;
996 
997 	vector_decode(&u, ciphertext, kDU1024);
998 	vector_decompress(&u, kDU1024);
999 	vector_ntt(&u);
1000 	scalar_decode(&v, ciphertext + kCompressedVectorSize, kDV1024);
1001 	scalar_decompress(&v, kDV1024);
1002 	scalar_inner_product(&mask, &priv->s, &u);
1003 	scalar_inverse_ntt(&mask);
1004 	scalar_sub(&v, &mask);
1005 	scalar_compress(&v, 1);
1006 	scalar_encode_1(out, &v);
1007 }
1008 
1009 /* See section 6.3 */
1010 int
MLKEM1024_decap(uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],const uint8_t * ciphertext,size_t ciphertext_len,const struct MLKEM1024_private_key * private_key)1011 MLKEM1024_decap(uint8_t out_shared_secret[MLKEM_SHARED_SECRET_BYTES],
1012     const uint8_t *ciphertext, size_t ciphertext_len,
1013     const struct MLKEM1024_private_key *private_key)
1014 {
1015 	const struct private_key *priv = private_key_1024_from_external(
1016 	    private_key);
1017 	uint8_t expected_ciphertext[MLKEM1024_CIPHERTEXT_BYTES];
1018 	uint8_t key_and_randomness[64];
1019 	uint8_t failure_key[32];
1020 	uint8_t decrypted[64];
1021 	uint8_t mask;
1022 	int i;
1023 
1024 	if (ciphertext_len != MLKEM1024_CIPHERTEXT_BYTES) {
1025 		arc4random_buf(out_shared_secret, MLKEM_SHARED_SECRET_BYTES);
1026 		return 0;
1027 	}
1028 
1029 	decrypt_cpa(decrypted, priv, ciphertext);
1030 	memcpy(decrypted + 32, priv->pub.public_key_hash,
1031 	    sizeof(decrypted) - 32);
1032 	hash_g(key_and_randomness, decrypted, sizeof(decrypted));
1033 	encrypt_cpa(expected_ciphertext, &priv->pub, decrypted,
1034 	    key_and_randomness + 32);
1035 	kdf(failure_key, priv->fo_failure_secret, ciphertext, ciphertext_len);
1036 	mask = constant_time_eq_int_8(memcmp(ciphertext, expected_ciphertext,
1037 	    sizeof(expected_ciphertext)), 0);
1038 	for (i = 0; i < MLKEM_SHARED_SECRET_BYTES; i++) {
1039 		out_shared_secret[i] = constant_time_select_8(mask,
1040 		    key_and_randomness[i], failure_key[i]);
1041 	}
1042 
1043 	return 1;
1044 }
1045 LCRYPTO_ALIAS(MLKEM1024_decap);
1046 
1047 int
MLKEM1024_marshal_public_key(CBB * out,const struct MLKEM1024_public_key * public_key)1048 MLKEM1024_marshal_public_key(CBB *out,
1049     const struct MLKEM1024_public_key *public_key)
1050 {
1051 	return mlkem_marshal_public_key(out,
1052 	    public_key_1024_from_external(public_key));
1053 }
1054 LCRYPTO_ALIAS(MLKEM1024_marshal_public_key);
1055 
1056 /*
1057  * mlkem_parse_public_key_no_hash parses |in| into |pub| but doesn't calculate
1058  * the value of |pub->public_key_hash|.
1059  */
1060 static int
mlkem_parse_public_key_no_hash(struct public_key * pub,CBS * in)1061 mlkem_parse_public_key_no_hash(struct public_key *pub, CBS *in)
1062 {
1063 	CBS t_bytes;
1064 
1065 	if (!CBS_get_bytes(in, &t_bytes, kEncodedVectorSize) ||
1066 	    !vector_decode(&pub->t, CBS_data(&t_bytes), kLog2Prime)) {
1067 		return 0;
1068 	}
1069 	memcpy(pub->rho, CBS_data(in), sizeof(pub->rho));
1070 	if (!CBS_skip(in, sizeof(pub->rho)))
1071 		return 0;
1072 	matrix_expand(&pub->m, pub->rho);
1073 	return 1;
1074 }
1075 
1076 int
MLKEM1024_parse_public_key(struct MLKEM1024_public_key * public_key,CBS * in)1077 MLKEM1024_parse_public_key(struct MLKEM1024_public_key *public_key, CBS *in)
1078 {
1079 	struct public_key *pub = public_key_1024_from_external(public_key);
1080 	CBS orig_in = *in;
1081 
1082 	if (!mlkem_parse_public_key_no_hash(pub, in) ||
1083 	    CBS_len(in) != 0) {
1084 		return 0;
1085 	}
1086 	hash_h(pub->public_key_hash, CBS_data(&orig_in), CBS_len(&orig_in));
1087 	return 1;
1088 }
1089 LCRYPTO_ALIAS(MLKEM1024_parse_public_key);
1090 
1091 int
MLKEM1024_marshal_private_key(CBB * out,const struct MLKEM1024_private_key * private_key)1092 MLKEM1024_marshal_private_key(CBB *out,
1093     const struct MLKEM1024_private_key *private_key)
1094 {
1095 	const struct private_key *const priv = private_key_1024_from_external(
1096 	    private_key);
1097 	uint8_t *s_output;
1098 
1099 	if (!CBB_add_space(out, &s_output, kEncodedVectorSize)) {
1100 		return 0;
1101 	}
1102 	vector_encode(s_output, &priv->s, kLog2Prime);
1103 	if (!mlkem_marshal_public_key(out, &priv->pub) ||
1104 	    !CBB_add_bytes(out, priv->pub.public_key_hash,
1105 	    sizeof(priv->pub.public_key_hash)) ||
1106 	    !CBB_add_bytes(out, priv->fo_failure_secret,
1107 	    sizeof(priv->fo_failure_secret))) {
1108 		return 0;
1109 	}
1110 	return 1;
1111 }
1112 
1113 int
MLKEM1024_parse_private_key(struct MLKEM1024_private_key * out_private_key,CBS * in)1114 MLKEM1024_parse_private_key(struct MLKEM1024_private_key *out_private_key,
1115     CBS *in)
1116 {
1117 	struct private_key *const priv = private_key_1024_from_external(
1118 	    out_private_key);
1119 	CBS s_bytes;
1120 
1121 	if (!CBS_get_bytes(in, &s_bytes, kEncodedVectorSize) ||
1122 	    !vector_decode(&priv->s, CBS_data(&s_bytes), kLog2Prime) ||
1123 	    !mlkem_parse_public_key_no_hash(&priv->pub, in)) {
1124 		return 0;
1125 	}
1126 	memcpy(priv->pub.public_key_hash, CBS_data(in),
1127 	    sizeof(priv->pub.public_key_hash));
1128 	if (!CBS_skip(in, sizeof(priv->pub.public_key_hash)))
1129 		return 0;
1130 	memcpy(priv->fo_failure_secret, CBS_data(in),
1131 	    sizeof(priv->fo_failure_secret));
1132 	if (!CBS_skip(in, sizeof(priv->fo_failure_secret)))
1133 		return 0;
1134 	if (CBS_len(in) != 0)
1135 		return 0;
1136 
1137 	return 1;
1138 }
1139 LCRYPTO_ALIAS(MLKEM1024_parse_private_key);
1140