1 /*
2  * Copyright 2020-2021 The OpenSSL Project Authors. All Rights Reserved.
3  * Copyright (c) 2020, Intel Corporation. All Rights Reserved.
4  *
5  * Licensed under the Apache License 2.0 (the "License").  You may not use
6  * this file except in compliance with the License.  You can obtain a copy
7  * in the file LICENSE in the source distribution or at
8  * https://www.openssl.org/source/license.html
9  *
10  *
11  * Originally written by Ilya Albrekht, Sergey Kirillov and Andrey Matyukov
12  * Intel Corporation
13  *
14  */
15 
16 #include <openssl/opensslconf.h>
17 #include "rsaz_exp.h"
18 
19 #ifndef RSAZ_ENABLED
20 NON_EMPTY_TRANSLATION_UNIT
21 #else
22 # include <assert.h>
23 # include <string.h>
24 
/*
 * ALIGN64: request 64-byte alignment for a declaration. Expands to nothing
 * on compilers for which no alignment syntax is known.
 */
# if defined(__GNUC__)
#  define ALIGN64 __attribute__((aligned(64)))
# elif defined(_MSC_VER)
#  define ALIGN64 __declspec(align(64))
# else
#  define ALIGN64
# endif

/*
 * ALIGN_OF(ptr, boundary): return |ptr| (as unsigned char *) advanced to the
 * next multiple of |boundary|, which must be a power of two.
 *
 * The result is always strictly past |ptr|: an already-aligned pointer is
 * advanced by a full |boundary|, so the pointer moves by 1..|boundary| bytes
 * and callers must over-allocate by |boundary| bytes.
 *
 * Both arguments are fully parenthesized so compound expressions expand
 * correctly. |ptr| is evaluated twice; do not pass expressions with side
 * effects.
 */
# define ALIGN_OF(ptr, boundary) \
    ((unsigned char *)(ptr) + \
     ((boundary) - (((size_t)(ptr)) & ((boundary) - 1))))
35 
/* Internal radix: every redundant digit carries 52 significant bits */
# define DIGIT_SIZE (52)
/* Mask that keeps the low DIGIT_SIZE bits of a 64-bit word */
# define DIGIT_MASK ((((uint64_t)1) << DIGIT_SIZE) - 1)

/* Round a bit count |x| up to a count of (2^log2_bits)-bit units */
# define BITS_TO_UNITS_POW2(x, log2_bits) \
    (((x) + ((1 << (log2_bits)) - 1)) >> (log2_bits))
/* Bytes, resp. 64-bit words, needed to hold |x| bits */
# define BITS2WORD8_SIZE(x)  BITS_TO_UNITS_POW2(x, 3)
# define BITS2WORD64_SIZE(x) BITS_TO_UNITS_POW2(x, 6)
43 
44 static ossl_inline uint64_t get_digit52(const uint8_t *in, int in_len);
45 static ossl_inline void put_digit52(uint8_t *out, int out_len, uint64_t digit);
46 static void to_words52(BN_ULONG *out, int out_len, const BN_ULONG *in,
47                        int in_bitsize);
48 static void from_words52(BN_ULONG *bn_out, int out_bitsize, const BN_ULONG *in);
49 static ossl_inline void set_bit(BN_ULONG *a, int idx);
50 
51 /* Number of |digit_size|-bit digits in |bitsize|-bit value */
52 static ossl_inline int number_of_digits(int bitsize, int digit_size)
53 {
54     return (bitsize + digit_size - 1) / digit_size;
55 }
56 
57 typedef void (*AMM52)(BN_ULONG *res, const BN_ULONG *base,
58                       const BN_ULONG *exp, const BN_ULONG *m, BN_ULONG k0);
59 typedef void (*EXP52_x2)(BN_ULONG *res, const BN_ULONG *base,
60                          const BN_ULONG *exp[2], const BN_ULONG *m,
61                          const BN_ULONG *rr, const BN_ULONG k0[2]);
62 
63 /*
64  * For details of the methods declared below please refer to
65  *    crypto/bn/asm/rsaz-avx512.pl
66  *
67  * Naming notes:
68  *  amm = Almost Montgomery Multiplication
69  *  ams = Almost Montgomery Squaring
70  *  52x20 - data represented as array of 20 digits in 52-bit radix
71  *  _x1_/_x2_ - 1 or 2 independent inputs/outputs
72  *  _256 suffix - uses 256-bit (AVX512VL) registers
73  */
74 
75 /*AMM = Almost Montgomery Multiplication. */
76 void ossl_rsaz_amm52x20_x1_256(BN_ULONG *res, const BN_ULONG *base,
77                                const BN_ULONG *exp, const BN_ULONG *m,
78                                BN_ULONG k0);
79 static void RSAZ_exp52x20_x2_256(BN_ULONG *res, const BN_ULONG *base,
80                                  const BN_ULONG *exp[2], const BN_ULONG *m,
81                                  const BN_ULONG *rr, const BN_ULONG k0[2]);
82 void ossl_rsaz_amm52x20_x2_256(BN_ULONG *out, const BN_ULONG *a,
83                                const BN_ULONG *b, const BN_ULONG *m,
84                                const BN_ULONG k0[2]);
85 void ossl_extract_multiplier_2x20_win5(BN_ULONG *red_Y,
86                                        const BN_ULONG *red_table,
87                                        int red_table_idx, int tbl_idx);
88 
89 /*
90  * Dual Montgomery modular exponentiation using prime moduli of the
91  * same bit size, optimized with AVX512 ISA.
92  *
93  * Input and output parameters for each exponentiation are independent and
94  * denoted here by index |i|, i = 1..2.
95  *
96  * Input and output are all in regular 2^64 radix.
97  *
98  * Each moduli shall be |factor_size| bit size.
99  *
100  * NOTE: currently only 2x1024 case is supported.
101  *
102  *  [out] res|i|      - result of modular exponentiation: array of qword values
103  *                      in regular (2^64) radix. Size of array shall be enough
104  *                      to hold |factor_size| bits.
105  *  [in]  base|i|     - base
106  *  [in]  exp|i|      - exponent
107  *  [in]  m|i|        - moduli
108  *  [in]  rr|i|       - Montgomery parameter RR = R^2 mod m|i|
109  *  [in]  k0_|i|      - Montgomery parameter k0 = -1/m|i| mod 2^64
110  *  [in]  factor_size - moduli bit size
111  *
112  * \return 0 in case of failure,
113  *         1 in case of success.
114  */
115 int ossl_rsaz_mod_exp_avx512_x2(BN_ULONG *res1,
116                                 const BN_ULONG *base1,
117                                 const BN_ULONG *exp1,
118                                 const BN_ULONG *m1,
119                                 const BN_ULONG *rr1,
120                                 BN_ULONG k0_1,
121                                 BN_ULONG *res2,
122                                 const BN_ULONG *base2,
123                                 const BN_ULONG *exp2,
124                                 const BN_ULONG *m2,
125                                 const BN_ULONG *rr2,
126                                 BN_ULONG k0_2,
127                                 int factor_size)
128 {
129     int ret = 0;
130 
131     /*
132      * Number of word-size (BN_ULONG) digits to store exponent in redundant
133      * representation.
134      */
135     int exp_digits = number_of_digits(factor_size + 2, DIGIT_SIZE);
136     int coeff_pow = 4 * (DIGIT_SIZE * exp_digits - factor_size);
137     BN_ULONG *base1_red, *m1_red, *rr1_red;
138     BN_ULONG *base2_red, *m2_red, *rr2_red;
139     BN_ULONG *coeff_red;
140     BN_ULONG *storage = NULL;
141     BN_ULONG *storage_aligned = NULL;
142     BN_ULONG storage_len_bytes = 7 * exp_digits * sizeof(BN_ULONG);
143 
144     /* AMM = Almost Montgomery Multiplication */
145     AMM52 amm = NULL;
146     /* Dual (2-exps in parallel) exponentiation */
147     EXP52_x2 exp_x2 = NULL;
148 
149     const BN_ULONG *exp[2] = {0};
150     BN_ULONG k0[2] = {0};
151 
152     /* Only 1024-bit factor size is supported now */
153     switch (factor_size) {
154     case 1024:
155         amm = ossl_rsaz_amm52x20_x1_256;
156         exp_x2 = RSAZ_exp52x20_x2_256;
157         break;
158     default:
159         goto err;
160     }
161 
162     storage = (BN_ULONG *)OPENSSL_malloc(storage_len_bytes + 64);
163     if (storage == NULL)
164         goto err;
165     storage_aligned = (BN_ULONG *)ALIGN_OF(storage, 64);
166 
167     /* Memory layout for red(undant) representations */
168     base1_red = storage_aligned;
169     base2_red = storage_aligned + 1 * exp_digits;
170     m1_red    = storage_aligned + 2 * exp_digits;
171     m2_red    = storage_aligned + 3 * exp_digits;
172     rr1_red   = storage_aligned + 4 * exp_digits;
173     rr2_red   = storage_aligned + 5 * exp_digits;
174     coeff_red = storage_aligned + 6 * exp_digits;
175 
176     /* Convert base_i, m_i, rr_i, from regular to 52-bit radix */
177     to_words52(base1_red, exp_digits, base1, factor_size);
178     to_words52(base2_red, exp_digits, base2, factor_size);
179     to_words52(m1_red, exp_digits, m1, factor_size);
180     to_words52(m2_red, exp_digits, m2, factor_size);
181     to_words52(rr1_red, exp_digits, rr1, factor_size);
182     to_words52(rr2_red, exp_digits, rr2, factor_size);
183 
184     /*
185      * Compute target domain Montgomery converters RR' for each modulus
186      * based on precomputed original domain's RR.
187      *
188      * RR -> RR' transformation steps:
189      *  (1) coeff = 2^k
190      *  (2) t = AMM(RR,RR) = RR^2 / R' mod m
191      *  (3) RR' = AMM(t, coeff) = RR^2 * 2^k / R'^2 mod m
192      * where
193      *  k = 4 * (52 * digits52 - modlen)
194      *  R  = 2^(64 * ceil(modlen/64)) mod m
195      *  RR = R^2 mod M
196      *  R' = 2^(52 * ceil(modlen/52)) mod m
197      *
198      *  modlen = 1024: k = 64, RR = 2^2048 mod m, RR' = 2^2080 mod m
199      */
200     memset(coeff_red, 0, exp_digits * sizeof(BN_ULONG));
201     /* (1) in reduced domain representation */
202     set_bit(coeff_red, 64 * (int)(coeff_pow / 52) + coeff_pow % 52);
203 
204     amm(rr1_red, rr1_red, rr1_red, m1_red, k0_1);     /* (2) for m1 */
205     amm(rr1_red, rr1_red, coeff_red, m1_red, k0_1);   /* (3) for m1 */
206 
207     amm(rr2_red, rr2_red, rr2_red, m2_red, k0_2);     /* (2) for m2 */
208     amm(rr2_red, rr2_red, coeff_red, m2_red, k0_2);   /* (3) for m2 */
209 
210     exp[0] = exp1;
211     exp[1] = exp2;
212 
213     k0[0] = k0_1;
214     k0[1] = k0_2;
215 
216     exp_x2(rr1_red, base1_red, exp, m1_red, rr1_red, k0);
217 
218     /* Convert rr_i back to regular radix */
219     from_words52(res1, factor_size, rr1_red);
220     from_words52(res2, factor_size, rr2_red);
221 
222     ret = 1;
223 err:
224     if (storage != NULL) {
225         OPENSSL_cleanse(storage, storage_len_bytes);
226         OPENSSL_free(storage);
227     }
228     return ret;
229 }
230 
231 /*
232  * Dual 1024-bit w-ary modular exponentiation using prime moduli of the same
233  * bit size using Almost Montgomery Multiplication, optimized with AVX512_IFMA
234  * ISA.
235  *
236  * The parameter w (window size) = 5.
237  *
238  *  [out] res      - result of modular exponentiation: 2x20 qword
239  *                   values in 2^52 radix.
240  *  [in]  base     - base (2x20 qword values in 2^52 radix)
241  *  [in]  exp      - array of 2 pointers to 16 qword values in 2^64 radix.
242  *                   Exponent is not converted to redundant representation.
243  *  [in]  m        - moduli (2x20 qword values in 2^52 radix)
244  *  [in]  rr       - Montgomery parameter for 2 moduli: RR = 2^2080 mod m.
245  *                   (2x20 qword values in 2^52 radix)
246  *  [in]  k0       - Montgomery parameter for 2 moduli: k0 = -1/m mod 2^64
247  *
248  * \return (void).
249  */
250 static void RSAZ_exp52x20_x2_256(BN_ULONG *out,          /* [2][20] */
251                                  const BN_ULONG *base,   /* [2][20] */
252                                  const BN_ULONG *exp[2], /* 2x16    */
253                                  const BN_ULONG *m,      /* [2][20] */
254                                  const BN_ULONG *rr,     /* [2][20] */
255                                  const BN_ULONG k0[2])
256 {
257 # define BITSIZE_MODULUS (1024)
258 # define EXP_WIN_SIZE (5)
259 # define EXP_WIN_MASK ((1U << EXP_WIN_SIZE) - 1)
260 /*
261  * Number of digits (64-bit words) in redundant representation to handle
262  * modulus bits
263  */
264 # define RED_DIGITS (20)
265 # define EXP_DIGITS (16)
266 # define DAMM ossl_rsaz_amm52x20_x2_256
267 /*
268  * Squaring is done using multiplication now. That can be a subject of
269  * optimization in future.
270  */
271 # define DAMS(r,a,m,k0) \
272               ossl_rsaz_amm52x20_x2_256((r),(a),(a),(m),(k0))
273 
274     /* Allocate stack for red(undant) result Y and multiplier X */
275     ALIGN64 BN_ULONG red_Y[2][RED_DIGITS];
276     ALIGN64 BN_ULONG red_X[2][RED_DIGITS];
277 
278     /* Allocate expanded exponent */
279     ALIGN64 BN_ULONG expz[2][EXP_DIGITS + 1];
280 
281     /* Pre-computed table of base powers */
282     ALIGN64 BN_ULONG red_table[1U << EXP_WIN_SIZE][2][RED_DIGITS];
283 
284     int idx;
285 
286     memset(red_Y, 0, sizeof(red_Y));
287     memset(red_table, 0, sizeof(red_table));
288     memset(red_X, 0, sizeof(red_X));
289 
290     /*
291      * Compute table of powers base^i, i = 0, ..., (2^EXP_WIN_SIZE) - 1
292      *   table[0] = mont(x^0) = mont(1)
293      *   table[1] = mont(x^1) = mont(x)
294      */
295     red_X[0][0] = 1;
296     red_X[1][0] = 1;
297     DAMM(red_table[0][0], (const BN_ULONG*)red_X, rr, m, k0);
298     DAMM(red_table[1][0], base,  rr, m, k0);
299 
300     for (idx = 1; idx < (int)((1U << EXP_WIN_SIZE) / 2); idx++) {
301         DAMS(red_table[2 * idx + 0][0], red_table[1 * idx][0], m, k0);
302         DAMM(red_table[2 * idx + 1][0], red_table[2 * idx][0], red_table[1][0], m, k0);
303     }
304 
305     /* Copy and expand exponents */
306     memcpy(expz[0], exp[0], EXP_DIGITS * sizeof(BN_ULONG));
307     expz[0][EXP_DIGITS] = 0;
308     memcpy(expz[1], exp[1], EXP_DIGITS * sizeof(BN_ULONG));
309     expz[1][EXP_DIGITS] = 0;
310 
311     /* Exponentiation */
312     {
313         int rem = BITSIZE_MODULUS % EXP_WIN_SIZE;
314         int delta = rem ? rem : EXP_WIN_SIZE;
315         BN_ULONG table_idx_mask = EXP_WIN_MASK;
316 
317         int exp_bit_no = BITSIZE_MODULUS - delta;
318         int exp_chunk_no = exp_bit_no / 64;
319         int exp_chunk_shift = exp_bit_no % 64;
320 
321         /* Process 1-st exp window - just init result */
322         BN_ULONG red_table_idx_0 = expz[0][exp_chunk_no];
323         BN_ULONG red_table_idx_1 = expz[1][exp_chunk_no];
324         /*
325          * The function operates with fixed moduli sizes divisible by 64,
326          * thus table index here is always in supported range [0, EXP_WIN_SIZE).
327          */
328         red_table_idx_0 >>= exp_chunk_shift;
329         red_table_idx_1 >>= exp_chunk_shift;
330 
331         ossl_extract_multiplier_2x20_win5(red_Y[0], (const BN_ULONG*)red_table,
332                                           (int)red_table_idx_0, 0);
333         ossl_extract_multiplier_2x20_win5(red_Y[1], (const BN_ULONG*)red_table,
334                                           (int)red_table_idx_1, 1);
335 
336         /* Process other exp windows */
337         for (exp_bit_no -= EXP_WIN_SIZE; exp_bit_no >= 0; exp_bit_no -= EXP_WIN_SIZE) {
338             /* Extract pre-computed multiplier from the table */
339             {
340                 BN_ULONG T;
341 
342                 exp_chunk_no = exp_bit_no / 64;
343                 exp_chunk_shift = exp_bit_no % 64;
344                 {
345                     red_table_idx_0 = expz[0][exp_chunk_no];
346                     T = expz[0][exp_chunk_no + 1];
347 
348                     red_table_idx_0 >>= exp_chunk_shift;
349                     /*
350                      * Get additional bits from then next quadword
351                      * when 64-bit boundaries are crossed.
352                      */
353                     if (exp_chunk_shift > 64 - EXP_WIN_SIZE) {
354                         T <<= (64 - exp_chunk_shift);
355                         red_table_idx_0 ^= T;
356                     }
357                     red_table_idx_0 &= table_idx_mask;
358 
359                     ossl_extract_multiplier_2x20_win5(red_X[0],
360                                                       (const BN_ULONG*)red_table,
361                                                       (int)red_table_idx_0, 0);
362                 }
363                 {
364                     red_table_idx_1 = expz[1][exp_chunk_no];
365                     T = expz[1][exp_chunk_no + 1];
366 
367                     red_table_idx_1 >>= exp_chunk_shift;
368                     /*
369                      * Get additional bits from then next quadword
370                      * when 64-bit boundaries are crossed.
371                      */
372                     if (exp_chunk_shift > 64 - EXP_WIN_SIZE) {
373                         T <<= (64 - exp_chunk_shift);
374                         red_table_idx_1 ^= T;
375                     }
376                     red_table_idx_1 &= table_idx_mask;
377 
378                     ossl_extract_multiplier_2x20_win5(red_X[1],
379                                                       (const BN_ULONG*)red_table,
380                                                       (int)red_table_idx_1, 1);
381                 }
382             }
383 
384             /* Series of squaring */
385             DAMS((BN_ULONG*)red_Y, (const BN_ULONG*)red_Y, m, k0);
386             DAMS((BN_ULONG*)red_Y, (const BN_ULONG*)red_Y, m, k0);
387             DAMS((BN_ULONG*)red_Y, (const BN_ULONG*)red_Y, m, k0);
388             DAMS((BN_ULONG*)red_Y, (const BN_ULONG*)red_Y, m, k0);
389             DAMS((BN_ULONG*)red_Y, (const BN_ULONG*)red_Y, m, k0);
390 
391             DAMM((BN_ULONG*)red_Y, (const BN_ULONG*)red_Y, (const BN_ULONG*)red_X, m, k0);
392         }
393     }
394 
395     /*
396      *
397      * NB: After the last AMM of exponentiation in Montgomery domain, the result
398      * may be 1025-bit, but the conversion out of Montgomery domain performs an
399      * AMM(x,1) which guarantees that the final result is less than |m|, so no
400      * conditional subtraction is needed here. See "Efficient Software
401      * Implementations of Modular Exponentiation" (by Shay Gueron) paper for details.
402      */
403 
404     /* Convert result back in regular 2^52 domain */
405     memset(red_X, 0, sizeof(red_X));
406     red_X[0][0] = 1;
407     red_X[1][0] = 1;
408     DAMM(out, (const BN_ULONG*)red_Y, (const BN_ULONG*)red_X, m, k0);
409 
410     /* Clear exponents */
411     OPENSSL_cleanse(expz, sizeof(expz));
412     OPENSSL_cleanse(red_Y, sizeof(red_Y));
413 
414 # undef DAMS
415 # undef DAMM
416 # undef EXP_DIGITS
417 # undef RED_DIGITS
418 # undef EXP_WIN_MASK
419 # undef EXP_WIN_SIZE
420 # undef BITSIZE_MODULUS
421 }
422 
423 static ossl_inline uint64_t get_digit52(const uint8_t *in, int in_len)
424 {
425     uint64_t digit = 0;
426 
427     assert(in != NULL);
428 
429     for (; in_len > 0; in_len--) {
430         digit <<= 8;
431         digit += (uint64_t)(in[in_len - 1]);
432     }
433     return digit;
434 }
435 
436 /*
437  * Convert array of words in regular (base=2^64) representation to array of
438  * words in redundant (base=2^52) one.
439  */
440 static void to_words52(BN_ULONG *out, int out_len,
441                        const BN_ULONG *in, int in_bitsize)
442 {
443     uint8_t *in_str = NULL;
444 
445     assert(out != NULL);
446     assert(in != NULL);
447     /* Check destination buffer capacity */
448     assert(out_len >= number_of_digits(in_bitsize, DIGIT_SIZE));
449 
450     in_str = (uint8_t *)in;
451 
452     for (; in_bitsize >= (2 * DIGIT_SIZE); in_bitsize -= (2 * DIGIT_SIZE), out += 2) {
453         out[0] = (*(uint64_t *)in_str) & DIGIT_MASK;
454         in_str += 6;
455         out[1] = ((*(uint64_t *)in_str) >> 4) & DIGIT_MASK;
456         in_str += 7;
457         out_len -= 2;
458     }
459 
460     if (in_bitsize > DIGIT_SIZE) {
461         uint64_t digit = get_digit52(in_str, 7);
462 
463         out[0] = digit & DIGIT_MASK;
464         in_str += 6;
465         in_bitsize -= DIGIT_SIZE;
466         digit = get_digit52(in_str, BITS2WORD8_SIZE(in_bitsize));
467         out[1] = digit >> 4;
468         out += 2;
469         out_len -= 2;
470     } else if (in_bitsize > 0) {
471         out[0] = get_digit52(in_str, BITS2WORD8_SIZE(in_bitsize));
472         out++;
473         out_len--;
474     }
475 
476     while (out_len > 0) {
477         *out = 0;
478         out_len--;
479         out++;
480     }
481 }
482 
483 static ossl_inline void put_digit52(uint8_t *pStr, int strLen, uint64_t digit)
484 {
485     assert(pStr != NULL);
486 
487     for (; strLen > 0; strLen--) {
488         *pStr++ = (uint8_t)(digit & 0xFF);
489         digit >>= 8;
490     }
491 }
492 
493 /*
494  * Convert array of words in redundant (base=2^52) representation to array of
495  * words in regular (base=2^64) one.
496  */
497 static void from_words52(BN_ULONG *out, int out_bitsize, const BN_ULONG *in)
498 {
499     int i;
500     int out_len = BITS2WORD64_SIZE(out_bitsize);
501 
502     assert(out != NULL);
503     assert(in != NULL);
504 
505     for (i = 0; i < out_len; i++)
506         out[i] = 0;
507 
508     {
509         uint8_t *out_str = (uint8_t *)out;
510 
511         for (; out_bitsize >= (2 * DIGIT_SIZE); out_bitsize -= (2 * DIGIT_SIZE), in += 2) {
512             (*(uint64_t *)out_str) = in[0];
513             out_str += 6;
514             (*(uint64_t *)out_str) ^= in[1] << 4;
515             out_str += 7;
516         }
517 
518         if (out_bitsize > DIGIT_SIZE) {
519             put_digit52(out_str, 7, in[0]);
520             out_str += 6;
521             out_bitsize -= DIGIT_SIZE;
522             put_digit52(out_str, BITS2WORD8_SIZE(out_bitsize),
523                         (in[1] << 4 | in[0] >> 48));
524         } else if (out_bitsize) {
525             put_digit52(out_str, BITS2WORD8_SIZE(out_bitsize), in[0]);
526         }
527     }
528 }
529 
530 /*
531  * Set bit at index |idx| in the words array |a|.
532  * It does not do any boundaries checks, make sure the index is valid before
533  * calling the function.
534  */
535 static ossl_inline void set_bit(BN_ULONG *a, int idx)
536 {
537     assert(a != NULL);
538 
539     {
540         int i, j;
541 
542         i = idx / BN_BITS2;
543         j = idx % BN_BITS2;
544         a[i] |= (((BN_ULONG)1) << j);
545     }
546 }
547 
548 #endif
549