#include <stddef.h>
#include <stdint.h>
#include <immintrin.h>
#include <string.h>
#include "align.h"
#include "params.h"
#include "indcpa.h"
#include "polyvec.h"
#include "poly.h"
#include "ntt.h"
#include "cbd.h"
#include "rejsample.h"
#include "symmetric.h"
#include "randombytes.h"

/*************************************************
* Name:        pack_pk
*
* Description: Serialize the public key as concatenation of the
*              serialized vector of polynomials pk and the
*              public seed used to generate the matrix A.
*              The polynomial coefficients in pk are assumed to
*              lie in the interval [0,q], i.e. pk must be reduced
*              by polyvec_reduce().
*
* Arguments:   - uint8_t *r: pointer to the output serialized public key
*              - polyvec *pk: pointer to the input public-key polyvec
*              - const uint8_t *seed: pointer to the input public seed
**************************************************/
static void pack_pk(uint8_t r[KYBER_INDCPA_PUBLICKEYBYTES],
                    polyvec *pk,
                    const uint8_t seed[KYBER_SYMBYTES])
{
  polyvec_tobytes(r, pk);
  memcpy(r+KYBER_POLYVECBYTES, seed, KYBER_SYMBYTES);
}

/*************************************************
* Name:        unpack_pk
*
* Description: De-serialize public key from a byte array;
*              approximate inverse of pack_pk
*
* Arguments:   - polyvec *pk: pointer to output public-key polynomial vector
*              - uint8_t *seed: pointer to output seed to generate matrix A
*              - const uint8_t *packedpk: pointer to input serialized public key
**************************************************/
static void unpack_pk(polyvec *pk,
                      uint8_t seed[KYBER_SYMBYTES],
                      const uint8_t packedpk[KYBER_INDCPA_PUBLICKEYBYTES])
{
  polyvec_frombytes(pk, packedpk);
  memcpy(seed, packedpk+KYBER_POLYVECBYTES, KYBER_SYMBYTES);
}
55 
56 /*************************************************
57 * Name:        pack_sk
58 *
59 * Description: Serialize the secret key.
60 *              The polynomial coefficients in sk are assumed to
61 *              lie in the invertal [0,q], i.e. sk must be reduced
62 *              by polyvec_reduce().
63 *
64 * Arguments:   - uint8_t *r: pointer to output serialized secret key
65 *              - polyvec *sk: pointer to input vector of polynomials (secret key)
66 **************************************************/
pack_sk(uint8_t r[KYBER_INDCPA_SECRETKEYBYTES],polyvec * sk)67 static void pack_sk(uint8_t r[KYBER_INDCPA_SECRETKEYBYTES], polyvec *sk)
68 {
69   polyvec_tobytes(r, sk);
70 }
71 

/*************************************************
* Name:        unpack_sk
*
* Description: De-serialize the secret key; inverse of pack_sk
*
* Arguments:   - polyvec *sk: pointer to output vector of polynomials (secret key)
*              - const uint8_t *packedsk: pointer to input serialized secret key
**************************************************/
static void unpack_sk(polyvec *sk, const uint8_t packedsk[KYBER_INDCPA_SECRETKEYBYTES])
{
  polyvec_frombytes(sk, packedsk);
}

/*************************************************
* Name:        pack_ciphertext
*
* Description: Serialize the ciphertext as concatenation of the
*              compressed and serialized vector of polynomials b
*              and the compressed and serialized polynomial v.
*              The polynomial coefficients in b and v are assumed to
*              lie in the interval [0,q], i.e. b and v must be reduced
*              by polyvec_reduce() and poly_reduce(), respectively.
*
* Arguments:   - uint8_t *r: pointer to the output serialized ciphertext
*              - polyvec *b: pointer to the input vector of polynomials b
*              - poly *v: pointer to the input polynomial v
**************************************************/
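/* For reference (parameter values from the Kyber specification): b is stored
 * compressed to KYBER_POLYVECCOMPRESSEDBYTES (d_u = 10 bits per coefficient
 * for KYBER_K = 2 or 3, d_u = 11 for KYBER_K = 4) and v is stored compressed
 * to KYBER_POLYCOMPRESSEDBYTES (d_v = 4 or 5 bits, respectively). */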
static void pack_ciphertext(uint8_t r[KYBER_INDCPA_BYTES], polyvec *b, poly *v)
{
  polyvec_compress(r, b);
  poly_compress(r+KYBER_POLYVECCOMPRESSEDBYTES, v);
}

/*************************************************
* Name:        unpack_ciphertext
*
* Description: De-serialize and decompress ciphertext from a byte array;
*              approximate inverse of pack_ciphertext
*
* Arguments:   - polyvec *b: pointer to the output vector of polynomials b
*              - poly *v: pointer to the output polynomial v
*              - const uint8_t *c: pointer to the input serialized ciphertext
**************************************************/
static void unpack_ciphertext(polyvec *b, poly *v, const uint8_t c[KYBER_INDCPA_BYTES])
{
  polyvec_decompress(b, c);
  poly_decompress(v, c+KYBER_POLYVECCOMPRESSEDBYTES);
}

/*************************************************
* Name:        rej_uniform
*
* Description: Run rejection sampling on uniform random bytes to generate
*              uniform random integers mod q
*
* Arguments:   - int16_t *r: pointer to output array
*              - unsigned int len: requested number of 16-bit integers (uniform mod q)
*              - const uint8_t *buf: pointer to input buffer (assumed to be uniformly random bytes)
*              - unsigned int buflen: length of input buffer in bytes
*
* Returns number of sampled 16-bit integers (at most len)
**************************************************/
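/* Illustrative example of the 3-byte parse below: for buf = {0x01, 0xD2, 0xFF},
 * val0 = (0x01 | 0xD2 << 8) & 0xFFF = 0x201 = 513  (< KYBER_Q, accepted) and
 * val1 = (0xD2 >> 4 | 0xFF << 4) & 0xFFF = 0xFFD = 4093  (>= KYBER_Q, rejected),
 * so every 3-byte group yields up to two 12-bit candidate coefficients. */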
static unsigned int rej_uniform(int16_t *r,
                                unsigned int len,
                                const uint8_t *buf,
                                unsigned int buflen)
{
  unsigned int ctr, pos;
  uint16_t val0, val1;

  ctr = pos = 0;
  while(ctr < len && pos <= buflen - 3) {  // buflen is always at least 3
    val0 = ((buf[pos+0] >> 0) | ((uint16_t)buf[pos+1] << 8)) & 0xFFF;
    val1 = ((buf[pos+1] >> 4) | ((uint16_t)buf[pos+2] << 4)) & 0xFFF;
    pos += 3;

    if(val0 < KYBER_Q)
      r[ctr++] = val0;
    if(ctr < len && val1 < KYBER_Q)
      r[ctr++] = val1;
  }

  return ctr;
}

#define gen_a(A,B)  gen_matrix(A,B,0)
#define gen_at(A,B) gen_matrix(A,B,1)

/*************************************************
* Name:        gen_matrix
*
* Description: Deterministically generate matrix A (or the transpose of A)
*              from a seed. Entries of the matrix are polynomials that look
*              uniformly random. Performs rejection sampling on the output
*              of an XOF.
*
* Arguments:   - polyvec *a: pointer to output matrix A
*              - const uint8_t *seed: pointer to input seed
*              - int transposed: boolean deciding whether A or A^T is generated
**************************************************/
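/* Domain separation as implemented below: entry A[i][j] is sampled from
 * XOF(seed || j || i), i.e. the two bytes appended to the 32-byte seed are the
 * column index followed by the row index; when transposed is nonzero the two
 * index bytes are swapped so that A^T is generated instead. */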
#ifdef KYBER_90S
void gen_matrix(polyvec *a, const uint8_t seed[32], int transposed)
{
  unsigned int ctr, i, j, k;
  unsigned int buflen, off;
  uint64_t nonce = 0;
  ALIGNED_UINT8(REJ_UNIFORM_AVX_NBLOCKS*AES256CTR_BLOCKBYTES) buf;
  aes256ctr_ctx state;

  aes256ctr_init(&state, seed, 0);

  for(i=0;i<KYBER_K;i++) {
    for(j=0;j<KYBER_K;j++) {
      if(transposed)
        nonce = (j << 8) | i;
      else
        nonce = (i << 8) | j;

      state.n = _mm_loadl_epi64((__m128i *)&nonce);
      aes256ctr_squeezeblocks(buf.coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state);
      buflen = REJ_UNIFORM_AVX_NBLOCKS*AES256CTR_BLOCKBYTES;
      ctr = rej_uniform_avx(a[i].vec[j].coeffs, buf.coeffs);

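      /* Not enough coefficients yet: keep squeezing. The buflen % 3 leftover
       * bytes are moved to the front of the buffer so that the 3-byte groups
       * parsed by rej_uniform are never split across two squeeze calls. */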
      while(ctr < KYBER_N) {
        off = buflen % 3;
        for(k = 0; k < off; k++)
          buf.coeffs[k] = buf.coeffs[buflen - off + k];
        aes256ctr_squeezeblocks(buf.coeffs + off, 1, &state);
        buflen = off + AES256CTR_BLOCKBYTES;
        ctr += rej_uniform(a[i].vec[j].coeffs + ctr, KYBER_N - ctr, buf.coeffs, buflen);
      }

      poly_nttunpack(&a[i].vec[j]);
    }
  }
}
#else
#if KYBER_K == 2
void gen_matrix(polyvec *a, const uint8_t seed[32], int transposed)
{
  unsigned int ctr0, ctr1, ctr2, ctr3;
  ALIGNED_UINT8(REJ_UNIFORM_AVX_NBLOCKS*SHAKE128_RATE) buf[4];
  __m256i f;
  shake128x4incctx state;

  f = _mm256_loadu_si256((__m256i *)seed);
  _mm256_store_si256(buf[0].vec, f);
  _mm256_store_si256(buf[1].vec, f);
  _mm256_store_si256(buf[2].vec, f);
  _mm256_store_si256(buf[3].vec, f);

  if(transposed) {
    buf[0].coeffs[32] = 0;
    buf[0].coeffs[33] = 0;
    buf[1].coeffs[32] = 0;
    buf[1].coeffs[33] = 1;
    buf[2].coeffs[32] = 1;
    buf[2].coeffs[33] = 0;
    buf[3].coeffs[32] = 1;
    buf[3].coeffs[33] = 1;
  }
  else {
    buf[0].coeffs[32] = 0;
    buf[0].coeffs[33] = 0;
    buf[1].coeffs[32] = 1;
    buf[1].coeffs[33] = 0;
    buf[2].coeffs[32] = 0;
    buf[2].coeffs[33] = 1;
    buf[3].coeffs[32] = 1;
    buf[3].coeffs[33] = 1;
  }

  shake128x4_inc_init(&state);
  shake128x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 34);
  shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state);

  ctr0 = rej_uniform_avx(a[0].vec[0].coeffs, buf[0].coeffs);
  ctr1 = rej_uniform_avx(a[0].vec[1].coeffs, buf[1].coeffs);
  ctr2 = rej_uniform_avx(a[1].vec[0].coeffs, buf[2].coeffs);
  ctr3 = rej_uniform_avx(a[1].vec[1].coeffs, buf[3].coeffs);

  while(ctr0 < KYBER_N || ctr1 < KYBER_N || ctr2 < KYBER_N || ctr3 < KYBER_N) {
    shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 1, &state);

    ctr0 += rej_uniform(a[0].vec[0].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE);
    ctr1 += rej_uniform(a[0].vec[1].coeffs + ctr1, KYBER_N - ctr1, buf[1].coeffs, SHAKE128_RATE);
    ctr2 += rej_uniform(a[1].vec[0].coeffs + ctr2, KYBER_N - ctr2, buf[2].coeffs, SHAKE128_RATE);
    ctr3 += rej_uniform(a[1].vec[1].coeffs + ctr3, KYBER_N - ctr3, buf[3].coeffs, SHAKE128_RATE);
  }

  poly_nttunpack(&a[0].vec[0]);
  poly_nttunpack(&a[0].vec[1]);
  poly_nttunpack(&a[1].vec[0]);
  poly_nttunpack(&a[1].vec[1]);
  shake128x4_inc_ctx_release(&state);
}
#elif KYBER_K == 3
void gen_matrix(polyvec *a, const uint8_t seed[32], int transposed)
{
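  /* KYBER_K == 3 requires 9 matrix entries; the first 8 are produced by the
   * two 4-way batched SHAKE128 calls below, and the last entry a[2][2] by a
   * single 1-way SHAKE128 call at the end. */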
  unsigned int ctr0, ctr1, ctr2, ctr3;
  ALIGNED_UINT8(REJ_UNIFORM_AVX_NBLOCKS*SHAKE128_RATE) buf[4];
  __m256i f;
  shake128x4incctx state;
  shake128incctx state1x;

  f = _mm256_loadu_si256((__m256i *)seed);
  _mm256_store_si256(buf[0].vec, f);
  _mm256_store_si256(buf[1].vec, f);
  _mm256_store_si256(buf[2].vec, f);
  _mm256_store_si256(buf[3].vec, f);

  if(transposed) {
    buf[0].coeffs[32] = 0;
    buf[0].coeffs[33] = 0;
    buf[1].coeffs[32] = 0;
    buf[1].coeffs[33] = 1;
    buf[2].coeffs[32] = 0;
    buf[2].coeffs[33] = 2;
    buf[3].coeffs[32] = 1;
    buf[3].coeffs[33] = 0;
  }
  else {
    buf[0].coeffs[32] = 0;
    buf[0].coeffs[33] = 0;
    buf[1].coeffs[32] = 1;
    buf[1].coeffs[33] = 0;
    buf[2].coeffs[32] = 2;
    buf[2].coeffs[33] = 0;
    buf[3].coeffs[32] = 0;
    buf[3].coeffs[33] = 1;
  }

  shake128x4_inc_init(&state);
  shake128x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 34);
  shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state);

  ctr0 = rej_uniform_avx(a[0].vec[0].coeffs, buf[0].coeffs);
  ctr1 = rej_uniform_avx(a[0].vec[1].coeffs, buf[1].coeffs);
  ctr2 = rej_uniform_avx(a[0].vec[2].coeffs, buf[2].coeffs);
  ctr3 = rej_uniform_avx(a[1].vec[0].coeffs, buf[3].coeffs);

  while(ctr0 < KYBER_N || ctr1 < KYBER_N || ctr2 < KYBER_N || ctr3 < KYBER_N) {
    shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 1, &state);

    ctr0 += rej_uniform(a[0].vec[0].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE);
    ctr1 += rej_uniform(a[0].vec[1].coeffs + ctr1, KYBER_N - ctr1, buf[1].coeffs, SHAKE128_RATE);
    ctr2 += rej_uniform(a[0].vec[2].coeffs + ctr2, KYBER_N - ctr2, buf[2].coeffs, SHAKE128_RATE);
    ctr3 += rej_uniform(a[1].vec[0].coeffs + ctr3, KYBER_N - ctr3, buf[3].coeffs, SHAKE128_RATE);
  }

  poly_nttunpack(&a[0].vec[0]);
  poly_nttunpack(&a[0].vec[1]);
  poly_nttunpack(&a[0].vec[2]);
  poly_nttunpack(&a[1].vec[0]);

  f = _mm256_loadu_si256((__m256i *)seed);
  _mm256_store_si256(buf[0].vec, f);
  _mm256_store_si256(buf[1].vec, f);
  _mm256_store_si256(buf[2].vec, f);
  _mm256_store_si256(buf[3].vec, f);

  if(transposed) {
    buf[0].coeffs[32] = 1;
    buf[0].coeffs[33] = 1;
    buf[1].coeffs[32] = 1;
    buf[1].coeffs[33] = 2;
    buf[2].coeffs[32] = 2;
    buf[2].coeffs[33] = 0;
    buf[3].coeffs[32] = 2;
    buf[3].coeffs[33] = 1;
  }
  else {
    buf[0].coeffs[32] = 1;
    buf[0].coeffs[33] = 1;
    buf[1].coeffs[32] = 2;
    buf[1].coeffs[33] = 1;
    buf[2].coeffs[32] = 0;
    buf[2].coeffs[33] = 2;
    buf[3].coeffs[32] = 1;
    buf[3].coeffs[33] = 2;
  }

  shake128x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 34);
  shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state);

  ctr0 = rej_uniform_avx(a[1].vec[1].coeffs, buf[0].coeffs);
  ctr1 = rej_uniform_avx(a[1].vec[2].coeffs, buf[1].coeffs);
  ctr2 = rej_uniform_avx(a[2].vec[0].coeffs, buf[2].coeffs);
  ctr3 = rej_uniform_avx(a[2].vec[1].coeffs, buf[3].coeffs);

  while(ctr0 < KYBER_N || ctr1 < KYBER_N || ctr2 < KYBER_N || ctr3 < KYBER_N) {
    shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 1, &state);

    ctr0 += rej_uniform(a[1].vec[1].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE);
    ctr1 += rej_uniform(a[1].vec[2].coeffs + ctr1, KYBER_N - ctr1, buf[1].coeffs, SHAKE128_RATE);
    ctr2 += rej_uniform(a[2].vec[0].coeffs + ctr2, KYBER_N - ctr2, buf[2].coeffs, SHAKE128_RATE);
    ctr3 += rej_uniform(a[2].vec[1].coeffs + ctr3, KYBER_N - ctr3, buf[3].coeffs, SHAKE128_RATE);
  }
  shake128x4_inc_ctx_release(&state);

  poly_nttunpack(&a[1].vec[1]);
  poly_nttunpack(&a[1].vec[2]);
  poly_nttunpack(&a[2].vec[0]);
  poly_nttunpack(&a[2].vec[1]);

  f = _mm256_loadu_si256((__m256i *)seed);
  _mm256_store_si256(buf[0].vec, f);
  buf[0].coeffs[32] = 2;
  buf[0].coeffs[33] = 2;

  shake128_inc_init(&state1x);
  shake128_absorb_once(&state1x, buf[0].coeffs, 34);
  shake128_squeezeblocks(buf[0].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state1x);
  ctr0 = rej_uniform_avx(a[2].vec[2].coeffs, buf[0].coeffs);
  while(ctr0 < KYBER_N) {
    shake128_squeezeblocks(buf[0].coeffs, 1, &state1x);
    ctr0 += rej_uniform(a[2].vec[2].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE);
  }
  shake128_inc_ctx_release(&state1x);

  poly_nttunpack(&a[2].vec[2]);
}
#elif KYBER_K == 4
void gen_matrix(polyvec *a, const uint8_t seed[32], int transposed)
{
  unsigned int i, ctr0, ctr1, ctr2, ctr3;
  ALIGNED_UINT8(REJ_UNIFORM_AVX_NBLOCKS*SHAKE128_RATE) buf[4];
  __m256i f;
  shake128x4incctx state;
  shake128x4_inc_init(&state);

  for(i=0;i<4;i++) {
    f = _mm256_loadu_si256((__m256i *)seed);
    _mm256_store_si256(buf[0].vec, f);
    _mm256_store_si256(buf[1].vec, f);
    _mm256_store_si256(buf[2].vec, f);
    _mm256_store_si256(buf[3].vec, f);

    if(transposed) {
      buf[0].coeffs[32] = i;
      buf[0].coeffs[33] = 0;
      buf[1].coeffs[32] = i;
      buf[1].coeffs[33] = 1;
      buf[2].coeffs[32] = i;
      buf[2].coeffs[33] = 2;
      buf[3].coeffs[32] = i;
      buf[3].coeffs[33] = 3;
    }
    else {
      buf[0].coeffs[32] = 0;
      buf[0].coeffs[33] = i;
      buf[1].coeffs[32] = 1;
      buf[1].coeffs[33] = i;
      buf[2].coeffs[32] = 2;
      buf[2].coeffs[33] = i;
      buf[3].coeffs[32] = 3;
      buf[3].coeffs[33] = i;
    }

    shake128x4_absorb_once(&state, buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 34);
    shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, REJ_UNIFORM_AVX_NBLOCKS, &state);

    ctr0 = rej_uniform_avx(a[i].vec[0].coeffs, buf[0].coeffs);
    ctr1 = rej_uniform_avx(a[i].vec[1].coeffs, buf[1].coeffs);
    ctr2 = rej_uniform_avx(a[i].vec[2].coeffs, buf[2].coeffs);
    ctr3 = rej_uniform_avx(a[i].vec[3].coeffs, buf[3].coeffs);

    while(ctr0 < KYBER_N || ctr1 < KYBER_N || ctr2 < KYBER_N || ctr3 < KYBER_N) {
      shake128x4_squeezeblocks(buf[0].coeffs, buf[1].coeffs, buf[2].coeffs, buf[3].coeffs, 1, &state);

      ctr0 += rej_uniform(a[i].vec[0].coeffs + ctr0, KYBER_N - ctr0, buf[0].coeffs, SHAKE128_RATE);
      ctr1 += rej_uniform(a[i].vec[1].coeffs + ctr1, KYBER_N - ctr1, buf[1].coeffs, SHAKE128_RATE);
      ctr2 += rej_uniform(a[i].vec[2].coeffs + ctr2, KYBER_N - ctr2, buf[2].coeffs, SHAKE128_RATE);
      ctr3 += rej_uniform(a[i].vec[3].coeffs + ctr3, KYBER_N - ctr3, buf[3].coeffs, SHAKE128_RATE);
    }

    poly_nttunpack(&a[i].vec[0]);
    poly_nttunpack(&a[i].vec[1]);
    poly_nttunpack(&a[i].vec[2]);
    poly_nttunpack(&a[i].vec[3]);
  }
  shake128x4_inc_ctx_release(&state);
}
#endif
#endif

/*************************************************
* Name:        indcpa_keypair
*
* Description: Generates public and private key for the CPA-secure
*              public-key encryption scheme underlying Kyber
*
* Arguments:   - uint8_t *pk: pointer to output public key
*                             (of length KYBER_INDCPA_PUBLICKEYBYTES bytes)
*              - uint8_t *sk: pointer to output private key
*                             (of length KYBER_INDCPA_SECRETKEYBYTES bytes)
**************************************************/
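/* In the notation of the Kyber specification, this computes, entirely in the
 * NTT domain, t^ = A^ o s^ + e^ with s and e sampled from a centered binomial
 * distribution; the secret key stores s^ = NTT(s) and the public key stores
 * t^ together with the public seed rho used to expand the matrix A. */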
void indcpa_keypair(uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES],
                    uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES])
{
  unsigned int i;
  uint8_t buf[2*KYBER_SYMBYTES];
  const uint8_t *publicseed = buf;
  const uint8_t *noiseseed = buf + KYBER_SYMBYTES;
  polyvec a[KYBER_K], e, pkpv, skpv;

  randombytes(buf, KYBER_SYMBYTES);
  hash_g(buf, buf, KYBER_SYMBYTES);

  gen_a(a, publicseed);

#ifdef KYBER_90S
#define NOISE_NBLOCKS ((KYBER_ETA1*KYBER_N/4)/AES256CTR_BLOCKBYTES) /* Assumes divisibility */
  uint64_t nonce = 0;
  ALIGNED_UINT8(NOISE_NBLOCKS*AES256CTR_BLOCKBYTES+32) coins; // +32 bytes as required by poly_cbd_eta1
  aes256ctr_ctx state;
  aes256ctr_init(&state, noiseseed, nonce++);
  for(i=0;i<KYBER_K;i++) {
    aes256ctr_squeezeblocks(coins.coeffs, NOISE_NBLOCKS, &state);
    state.n = _mm_loadl_epi64((__m128i *)&nonce);
    nonce += 1;
    poly_cbd_eta1(&skpv.vec[i], coins.vec);
  }
  for(i=0;i<KYBER_K;i++) {
    aes256ctr_squeezeblocks(coins.coeffs, NOISE_NBLOCKS, &state);
    state.n = _mm_loadl_epi64((__m128i *)&nonce);
    nonce += 1;
    poly_cbd_eta1(&e.vec[i], coins.vec);
  }
#else
#if KYBER_K == 2
  poly_getnoise_eta1_4x(skpv.vec+0, skpv.vec+1, e.vec+0, e.vec+1, noiseseed, 0, 1, 2, 3);
#elif KYBER_K == 3
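  /* poly_getnoise_eta1_4x always produces four polynomials per call; for
   * KYBER_K == 3 only the six polynomials of s and e are needed, so the two
   * spare lanes of the second call write into pkpv.vec[0] and pkpv.vec[1],
   * which serve purely as scratch and are overwritten further below. */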
  poly_getnoise_eta1_4x(skpv.vec+0, skpv.vec+1, skpv.vec+2, e.vec+0, noiseseed, 0, 1, 2, 3);
  poly_getnoise_eta1_4x(e.vec+1, e.vec+2, pkpv.vec+0, pkpv.vec+1, noiseseed, 4, 5, 6, 7);
#elif KYBER_K == 4
  poly_getnoise_eta1_4x(skpv.vec+0, skpv.vec+1, skpv.vec+2, skpv.vec+3, noiseseed, 0, 1, 2, 3);
  poly_getnoise_eta1_4x(e.vec+0, e.vec+1, e.vec+2, e.vec+3, noiseseed, 4, 5, 6, 7);
#endif
#endif

  polyvec_ntt(&skpv);
  polyvec_reduce(&skpv);
  polyvec_ntt(&e);

  // matrix-vector multiplication
  for(i=0;i<KYBER_K;i++) {
    polyvec_basemul_acc_montgomery(&pkpv.vec[i], &a[i], &skpv);
    poly_tomont(&pkpv.vec[i]);
  }

  polyvec_add(&pkpv, &pkpv, &e);
  polyvec_reduce(&pkpv);

  pack_sk(sk, &skpv);
  pack_pk(pk, &pkpv, publicseed);
}

/*************************************************
* Name:        indcpa_enc
*
* Description: Encryption function of the CPA-secure
*              public-key encryption scheme underlying Kyber.
*
* Arguments:   - uint8_t *c: pointer to output ciphertext
*                            (of length KYBER_INDCPA_BYTES bytes)
*              - const uint8_t *m: pointer to input message
*                                  (of length KYBER_INDCPA_MSGBYTES bytes)
*              - const uint8_t *pk: pointer to input public key
*                                   (of length KYBER_INDCPA_PUBLICKEYBYTES)
*              - const uint8_t *coins: pointer to input random coins used as seed
*                                      (of length KYBER_SYMBYTES) to deterministically
*                                      generate all randomness
**************************************************/
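/* In the notation of the Kyber specification, this computes
 *   u = InvNTT(NTT(A)^T o NTT(r)) + e1   and
 *   v = InvNTT(NTT(t)^T o NTT(r)) + e2 + Decompress_1(m),
 * with r, e1, e2 freshly sampled noise, and outputs the ciphertext
 * c = (Compress_du(u), Compress_dv(v)). */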
void indcpa_enc(uint8_t c[KYBER_INDCPA_BYTES],
                const uint8_t m[KYBER_INDCPA_MSGBYTES],
                const uint8_t pk[KYBER_INDCPA_PUBLICKEYBYTES],
                const uint8_t coins[KYBER_SYMBYTES])
{
  unsigned int i;
  uint8_t seed[KYBER_SYMBYTES];
  polyvec sp, pkpv, ep, at[KYBER_K], b;
  poly v, k, epp;

  unpack_pk(&pkpv, seed, pk);
  poly_frommsg(&k, m);
  gen_at(at, seed);

#ifdef KYBER_90S
#define NOISE_NBLOCKS ((KYBER_ETA1*KYBER_N/4)/AES256CTR_BLOCKBYTES) /* Assumes divisibility */
#define CIPHERTEXTNOISE_NBLOCKS ((KYBER_ETA2*KYBER_N/4)/AES256CTR_BLOCKBYTES) /* Assumes divisibility */
  uint64_t nonce = 0;
  ALIGNED_UINT8(NOISE_NBLOCKS*AES256CTR_BLOCKBYTES+32) buf; /* +32 bytes as required by poly_cbd_eta1 */
  aes256ctr_ctx state;
  aes256ctr_init(&state, coins, nonce++);
  for(i=0;i<KYBER_K;i++) {
    aes256ctr_squeezeblocks(buf.coeffs, NOISE_NBLOCKS, &state);
    state.n = _mm_loadl_epi64((__m128i *)&nonce);
    nonce += 1;
    poly_cbd_eta1(&sp.vec[i], buf.vec);
  }
  for(i=0;i<KYBER_K;i++) {
    aes256ctr_squeezeblocks(buf.coeffs, CIPHERTEXTNOISE_NBLOCKS, &state);
    state.n = _mm_loadl_epi64((__m128i *)&nonce);
    nonce += 1;
    poly_cbd_eta2(&ep.vec[i], buf.vec);
  }
  aes256ctr_squeezeblocks(buf.coeffs, CIPHERTEXTNOISE_NBLOCKS, &state);
  state.n = _mm_loadl_epi64((__m128i *)&nonce);
  nonce += 1;
  poly_cbd_eta2(&epp, buf.vec);
#else
#if KYBER_K == 2
  poly_getnoise_eta1122_4x(sp.vec+0, sp.vec+1, ep.vec+0, ep.vec+1, coins, 0, 1, 2, 3);
  poly_getnoise_eta2(&epp, coins, 4);
#elif KYBER_K == 3
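  /* As in indcpa_keypair, the 4-way noise sampler fills one spare lane here:
   * b.vec[0] is scratch only and is overwritten by the matrix-vector
   * multiplication below. */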
  poly_getnoise_eta1_4x(sp.vec+0, sp.vec+1, sp.vec+2, ep.vec+0, coins, 0, 1, 2, 3);
  poly_getnoise_eta1_4x(ep.vec+1, ep.vec+2, &epp, b.vec+0, coins, 4, 5, 6, 7);
#elif KYBER_K == 4
  poly_getnoise_eta1_4x(sp.vec+0, sp.vec+1, sp.vec+2, sp.vec+3, coins, 0, 1, 2, 3);
  poly_getnoise_eta1_4x(ep.vec+0, ep.vec+1, ep.vec+2, ep.vec+3, coins, 4, 5, 6, 7);
  poly_getnoise_eta2(&epp, coins, 8);
#endif
#endif

  polyvec_ntt(&sp);

  // matrix-vector multiplication
  for(i=0;i<KYBER_K;i++)
    polyvec_basemul_acc_montgomery(&b.vec[i], &at[i], &sp);
  polyvec_basemul_acc_montgomery(&v, &pkpv, &sp);

  polyvec_invntt_tomont(&b);
  poly_invntt_tomont(&v);

  polyvec_add(&b, &b, &ep);
  poly_add(&v, &v, &epp);
  poly_add(&v, &v, &k);
  polyvec_reduce(&b);
  poly_reduce(&v);

  pack_ciphertext(c, &b, &v);
}

/*************************************************
* Name:        indcpa_dec
*
* Description: Decryption function of the CPA-secure
*              public-key encryption scheme underlying Kyber.
*
* Arguments:   - uint8_t *m: pointer to output decrypted message
*                            (of length KYBER_INDCPA_MSGBYTES)
*              - const uint8_t *c: pointer to input ciphertext
*                                  (of length KYBER_INDCPA_BYTES)
*              - const uint8_t *sk: pointer to input secret key
*                                   (of length KYBER_INDCPA_SECRETKEYBYTES)
**************************************************/
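/* In the notation of the Kyber specification, this recovers
 *   m = Compress_1(v - InvNTT(NTT(s)^T o NTT(u))),
 * where (u, v) is the decompressed ciphertext and s the secret key. */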
void indcpa_dec(uint8_t m[KYBER_INDCPA_MSGBYTES],
                const uint8_t c[KYBER_INDCPA_BYTES],
                const uint8_t sk[KYBER_INDCPA_SECRETKEYBYTES])
{
  polyvec b, skpv;
  poly v, mp;

  unpack_ciphertext(&b, &v, c);
  unpack_sk(&skpv, sk);

  polyvec_ntt(&b);
  polyvec_basemul_acc_montgomery(&mp, &skpv, &b);
  poly_invntt_tomont(&mp);

  poly_sub(&mp, &v, &mp);
  poly_reduce(&mp);

  poly_tomsg(m, &mp);
}