xref: /freebsd/crypto/openssl/crypto/rsa/rsa_ossl.c (revision b077aed3)
1 /*
2  * Copyright 1995-2023 The OpenSSL Project Authors. All Rights Reserved.
3  *
4  * Licensed under the Apache License 2.0 (the "License").  You may not use
5  * this file except in compliance with the License.  You can obtain a copy
6  * in the file LICENSE in the source distribution or at
7  * https://www.openssl.org/source/license.html
8  */
9 
10 /*
11  * RSA low level APIs are deprecated for public use, but still ok for
12  * internal use.
13  */
14 #include "internal/deprecated.h"
15 
16 #include "internal/cryptlib.h"
17 #include "crypto/bn.h"
18 #include "rsa_local.h"
19 #include "internal/constant_time.h"
20 
21 static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
22                                   unsigned char *to, RSA *rsa, int padding);
23 static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
24                                    unsigned char *to, RSA *rsa, int padding);
25 static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
26                                   unsigned char *to, RSA *rsa, int padding);
27 static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
28                                    unsigned char *to, RSA *rsa, int padding);
29 static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *i, RSA *rsa,
30                            BN_CTX *ctx);
31 static int rsa_ossl_init(RSA *rsa);
32 static int rsa_ossl_finish(RSA *rsa);
33 static RSA_METHOD rsa_pkcs1_ossl_meth = {
34     "OpenSSL PKCS#1 RSA",
35     rsa_ossl_public_encrypt,
36     rsa_ossl_public_decrypt,     /* signature verification */
37     rsa_ossl_private_encrypt,    /* signing */
38     rsa_ossl_private_decrypt,
39     rsa_ossl_mod_exp,
40     BN_mod_exp_mont,            /* XXX probably we should not use Montgomery
41                                  * if e == 3 */
42     rsa_ossl_init,
43     rsa_ossl_finish,
44     RSA_FLAG_FIPS_METHOD,       /* flags */
45     NULL,
46     0,                          /* rsa_sign */
47     0,                          /* rsa_verify */
48     NULL,                       /* rsa_keygen */
49     NULL                        /* rsa_multi_prime_keygen */
50 };
51 
52 static const RSA_METHOD *default_RSA_meth = &rsa_pkcs1_ossl_meth;
53 
RSA_set_default_method(const RSA_METHOD * meth)54 void RSA_set_default_method(const RSA_METHOD *meth)
55 {
56     default_RSA_meth = meth;
57 }
58 
RSA_get_default_method(void)59 const RSA_METHOD *RSA_get_default_method(void)
60 {
61     return default_RSA_meth;
62 }
63 
RSA_PKCS1_OpenSSL(void)64 const RSA_METHOD *RSA_PKCS1_OpenSSL(void)
65 {
66     return &rsa_pkcs1_ossl_meth;
67 }
68 
RSA_null_method(void)69 const RSA_METHOD *RSA_null_method(void)
70 {
71     return NULL;
72 }
73 
/*
 * Public-key encryption.
 *
 * Pads the |flen| bytes at |from| into a modulus-sized block according to
 * |padding|:
 *   - RSA_PKCS1_PADDING: PKCS#1 v1.5 type 2;
 *   - RSA_PKCS1_OAEP_PADDING: OAEP with no label and NULL digest arguments
 *     (i.e. whatever defaults the padding helper applies);
 *   - RSA_NO_PADDING: none.
 * The block is then raised to e modulo n, and the result is written to |to|
 * as a big-endian number left-padded to BN_num_bytes(n) bytes.
 *
 * Returns the number of bytes written to |to|, or -1 on error (the reason
 * is placed on the error queue).
 */
static int rsa_ossl_public_encrypt(int flen, const unsigned char *from,
                                   unsigned char *to, RSA *rsa, int padding)
{
    BIGNUM *f, *ret;
    int i, num = 0, r = -1;
    unsigned char *buf = NULL;
    BN_CTX *ctx = NULL;

    /*
     * Every check below dereferences n and e; fail cleanly instead of
     * crashing on a key that lacks its public components.
     */
    if (rsa->n == NULL || rsa->e == NULL) {
        ERR_raise(ERR_LIB_RSA, ERR_R_PASSED_NULL_PARAMETER);
        return -1;
    }

    if (BN_num_bits(rsa->n) > OPENSSL_RSA_MAX_MODULUS_BITS) {
        ERR_raise(ERR_LIB_RSA, RSA_R_MODULUS_TOO_LARGE);
        return -1;
    }

    /* The public exponent must be strictly smaller than the modulus. */
    if (BN_ucmp(rsa->n, rsa->e) <= 0) {
        ERR_raise(ERR_LIB_RSA, RSA_R_BAD_E_VALUE);
        return -1;
    }

    /* for large moduli, enforce exponent limit */
    if (BN_num_bits(rsa->n) > OPENSSL_RSA_SMALL_MODULUS_BITS) {
        if (BN_num_bits(rsa->e) > OPENSSL_RSA_MAX_PUBEXP_BITS) {
            ERR_raise(ERR_LIB_RSA, RSA_R_BAD_E_VALUE);
            return -1;
        }
    }

    if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
        goto err;
    BN_CTX_start(ctx);
    f = BN_CTX_get(ctx);
    ret = BN_CTX_get(ctx);
    num = BN_num_bytes(rsa->n);
    buf = OPENSSL_malloc(num);
    if (ret == NULL || buf == NULL) {
        ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
        goto err;
    }

    switch (padding) {
    case RSA_PKCS1_PADDING:
        i = ossl_rsa_padding_add_PKCS1_type_2_ex(rsa->libctx, buf, num,
                                                 from, flen);
        break;
    case RSA_PKCS1_OAEP_PADDING:
        i = ossl_rsa_padding_add_PKCS1_OAEP_mgf1_ex(rsa->libctx, buf, num,
                                                    from, flen, NULL, 0,
                                                    NULL, NULL);
        break;
    case RSA_NO_PADDING:
        i = RSA_padding_add_none(buf, num, from, flen);
        break;
    default:
        ERR_raise(ERR_LIB_RSA, RSA_R_UNKNOWN_PADDING_TYPE);
        goto err;
    }
    if (i <= 0)
        goto err;

    if (BN_bin2bn(buf, num, f) == NULL)
        goto err;

    if (BN_ucmp(f, rsa->n) >= 0) {
        /* usually the padding functions would catch this */
        ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
        goto err;
    }

    if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
        if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
                                    rsa->n, ctx))
            goto err;

    if (!rsa->meth->bn_mod_exp(ret, f, rsa->e, rsa->n, ctx,
                               rsa->_method_mod_n))
        goto err;

    /*
     * BN_bn2binpad puts in leading 0 bytes if the number is less than
     * the length of the modulus.
     */
    r = BN_bn2binpad(ret, to, num);
 err:
    BN_CTX_end(ctx);
    BN_CTX_free(ctx);
    OPENSSL_clear_free(buf, num);
    return r;
}
161 
rsa_get_blinding(RSA * rsa,int * local,BN_CTX * ctx)162 static BN_BLINDING *rsa_get_blinding(RSA *rsa, int *local, BN_CTX *ctx)
163 {
164     BN_BLINDING *ret;
165 
166     if (!CRYPTO_THREAD_write_lock(rsa->lock))
167         return NULL;
168 
169     if (rsa->blinding == NULL) {
170         rsa->blinding = RSA_setup_blinding(rsa, ctx);
171     }
172 
173     ret = rsa->blinding;
174     if (ret == NULL)
175         goto err;
176 
177     if (BN_BLINDING_is_current_thread(ret)) {
178         /* rsa->blinding is ours! */
179 
180         *local = 1;
181     } else {
182         /* resort to rsa->mt_blinding instead */
183 
184         /*
185          * instructs rsa_blinding_convert(), rsa_blinding_invert() that the
186          * BN_BLINDING is shared, meaning that accesses require locks, and
187          * that the blinding factor must be stored outside the BN_BLINDING
188          */
189         *local = 0;
190 
191         if (rsa->mt_blinding == NULL) {
192             rsa->mt_blinding = RSA_setup_blinding(rsa, ctx);
193         }
194         ret = rsa->mt_blinding;
195     }
196 
197  err:
198     CRYPTO_THREAD_unlock(rsa->lock);
199     return ret;
200 }
201 
/*
 * Blind |f| in place with |b|.
 *
 * With |unblind| == NULL (thread-local blinding), the unblinding factor stays
 * inside |b| and no lock is taken. Otherwise |b| is shared: it is locked for
 * the conversion and the unblinding factor is written to |unblind|.
 *
 * Returns the BN_BLINDING_convert_ex() result, or 0 if locking fails.
 */
static int rsa_blinding_convert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind,
                                BN_CTX *ctx)
{
    int ok;

    if (unblind == NULL)
        return BN_BLINDING_convert_ex(f, NULL, b, ctx);

    if (!BN_BLINDING_lock(b))
        return 0;
    ok = BN_BLINDING_convert_ex(f, unblind, b, ctx);
    BN_BLINDING_unlock(b);

    return ok;
}
225 
/*
 * Unblind |f| in place after the private exponentiation.
 *
 * |f| is first marked BN_FLG_CONSTTIME. When |unblind| is NULL (local
 * blinding), the factor stored in |b| is used. When it is non-NULL (shared
 * blinding), the caller's factor is used and only the modulus is read from
 * |b|. Neither case needs the blinding lock.
 *
 * Returns the BN_BLINDING_invert_ex() result.
 */
static int rsa_blinding_invert(BN_BLINDING *b, BIGNUM *f, BIGNUM *unblind,
                               BN_CTX *ctx)
{
    BN_set_flags(f, BN_FLG_CONSTTIME);
    return BN_BLINDING_invert_ex(f, unblind, b, ctx);
}
240 
241 /* signing */
rsa_ossl_private_encrypt(int flen,const unsigned char * from,unsigned char * to,RSA * rsa,int padding)242 static int rsa_ossl_private_encrypt(int flen, const unsigned char *from,
243                                    unsigned char *to, RSA *rsa, int padding)
244 {
245     BIGNUM *f, *ret, *res;
246     int i, num = 0, r = -1;
247     unsigned char *buf = NULL;
248     BN_CTX *ctx = NULL;
249     int local_blinding = 0;
250     /*
251      * Used only if the blinding structure is shared. A non-NULL unblind
252      * instructs rsa_blinding_convert() and rsa_blinding_invert() to store
253      * the unblinding factor outside the blinding structure.
254      */
255     BIGNUM *unblind = NULL;
256     BN_BLINDING *blinding = NULL;
257 
258     if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
259         goto err;
260     BN_CTX_start(ctx);
261     f = BN_CTX_get(ctx);
262     ret = BN_CTX_get(ctx);
263     num = BN_num_bytes(rsa->n);
264     buf = OPENSSL_malloc(num);
265     if (ret == NULL || buf == NULL) {
266         ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
267         goto err;
268     }
269 
270     switch (padding) {
271     case RSA_PKCS1_PADDING:
272         i = RSA_padding_add_PKCS1_type_1(buf, num, from, flen);
273         break;
274     case RSA_X931_PADDING:
275         i = RSA_padding_add_X931(buf, num, from, flen);
276         break;
277     case RSA_NO_PADDING:
278         i = RSA_padding_add_none(buf, num, from, flen);
279         break;
280     default:
281         ERR_raise(ERR_LIB_RSA, RSA_R_UNKNOWN_PADDING_TYPE);
282         goto err;
283     }
284     if (i <= 0)
285         goto err;
286 
287     if (BN_bin2bn(buf, num, f) == NULL)
288         goto err;
289 
290     if (BN_ucmp(f, rsa->n) >= 0) {
291         /* usually the padding functions would catch this */
292         ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
293         goto err;
294     }
295 
296     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
297         if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
298                                     rsa->n, ctx))
299             goto err;
300 
301     if (!(rsa->flags & RSA_FLAG_NO_BLINDING)) {
302         blinding = rsa_get_blinding(rsa, &local_blinding, ctx);
303         if (blinding == NULL) {
304             ERR_raise(ERR_LIB_RSA, ERR_R_INTERNAL_ERROR);
305             goto err;
306         }
307     }
308 
309     if (blinding != NULL) {
310         if (!local_blinding && ((unblind = BN_CTX_get(ctx)) == NULL)) {
311             ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
312             goto err;
313         }
314         if (!rsa_blinding_convert(blinding, f, unblind, ctx))
315             goto err;
316     }
317 
318     if ((rsa->flags & RSA_FLAG_EXT_PKEY) ||
319         (rsa->version == RSA_ASN1_VERSION_MULTI) ||
320         ((rsa->p != NULL) &&
321          (rsa->q != NULL) &&
322          (rsa->dmp1 != NULL) && (rsa->dmq1 != NULL) && (rsa->iqmp != NULL))) {
323         if (!rsa->meth->rsa_mod_exp(ret, f, rsa, ctx))
324             goto err;
325     } else {
326         BIGNUM *d = BN_new();
327         if (d == NULL) {
328             ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
329             goto err;
330         }
331         if (rsa->d == NULL) {
332             ERR_raise(ERR_LIB_RSA, RSA_R_MISSING_PRIVATE_KEY);
333             BN_free(d);
334             goto err;
335         }
336         BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);
337 
338         if (!rsa->meth->bn_mod_exp(ret, f, d, rsa->n, ctx,
339                                    rsa->_method_mod_n)) {
340             BN_free(d);
341             goto err;
342         }
343         /* We MUST free d before any further use of rsa->d */
344         BN_free(d);
345     }
346 
347     if (blinding)
348         if (!rsa_blinding_invert(blinding, ret, unblind, ctx))
349             goto err;
350 
351     if (padding == RSA_X931_PADDING) {
352         if (!BN_sub(f, rsa->n, ret))
353             goto err;
354         if (BN_cmp(ret, f) > 0)
355             res = f;
356         else
357             res = ret;
358     } else {
359         res = ret;
360     }
361 
362     /*
363      * BN_bn2binpad puts in leading 0 bytes if the number is less than
364      * the length of the modulus.
365      */
366     r = BN_bn2binpad(res, to, num);
367  err:
368     BN_CTX_end(ctx);
369     BN_CTX_free(ctx);
370     OPENSSL_clear_free(buf, num);
371     return r;
372 }
373 
/*
 * Private-key decryption.
 *
 * Interprets the |flen| bytes at |from| as a big-endian number below n. The
 * number is raised to the private exponent modulo n (blinded unless
 * RSA_FLAG_NO_BLINDING is set). The padding selected by |padding| is then
 * removed: RSA_PKCS1_PADDING (type 2), RSA_PKCS1_OAEP_PADDING (no label) or
 * RSA_NO_PADDING. The result is written to |to|.
 *
 * Returns the recovered length, or a negative value on error. Outside the
 * FIPS provider, RSA_R_PADDING_CHECK_FAILED is always pushed after the
 * padding check. It is then removed, via a branch-free mask of r, when r is
 * non-negative.
 */
static int rsa_ossl_private_decrypt(int flen, const unsigned char *from,
                                    unsigned char *to, RSA *rsa, int padding)
{
    BIGNUM *f, *ret;
    int j, num = 0, r = -1;
    unsigned char *buf = NULL;
    BN_CTX *ctx = NULL;
    int local_blinding = 0;
    /*
     * Used only if the blinding structure is shared. A non-NULL unblind
     * instructs rsa_blinding_convert() and rsa_blinding_invert() to store
     * the unblinding factor outside the blinding structure.
     */
    BIGNUM *unblind = NULL;
    BN_BLINDING *blinding = NULL;

    /* Buffer sizes and range checks all come from n; refuse a key without it. */
    if (rsa->n == NULL) {
        ERR_raise(ERR_LIB_RSA, ERR_R_PASSED_NULL_PARAMETER);
        return -1;
    }

    if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
        goto err;
    BN_CTX_start(ctx);
    f = BN_CTX_get(ctx);
    ret = BN_CTX_get(ctx);
    num = BN_num_bytes(rsa->n);
    buf = OPENSSL_malloc(num);
    if (ret == NULL || buf == NULL) {
        ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
        goto err;
    }

    /*
     * This check was for equality but PGP does evil things and chops off the
     * top '0' bytes
     */
    if (flen > num) {
        ERR_raise(ERR_LIB_RSA, RSA_R_DATA_GREATER_THAN_MOD_LEN);
        goto err;
    }

    /* make data into a big number */
    if (BN_bin2bn(from, (int)flen, f) == NULL)
        goto err;

    if (BN_ucmp(f, rsa->n) >= 0) {
        ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
        goto err;
    }

    if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
        if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
                                    rsa->n, ctx))
            goto err;

    if (!(rsa->flags & RSA_FLAG_NO_BLINDING)) {
        blinding = rsa_get_blinding(rsa, &local_blinding, ctx);
        if (blinding == NULL) {
            ERR_raise(ERR_LIB_RSA, ERR_R_INTERNAL_ERROR);
            goto err;
        }
    }

    if (blinding != NULL) {
        /* Shared blinding needs a caller-owned slot for the unblinding factor. */
        if (!local_blinding && ((unblind = BN_CTX_get(ctx)) == NULL)) {
            ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
            goto err;
        }
        if (!rsa_blinding_convert(blinding, f, unblind, ctx))
            goto err;
    }

    /* do the decrypt */
    if ((rsa->flags & RSA_FLAG_EXT_PKEY) ||
        (rsa->version == RSA_ASN1_VERSION_MULTI) ||
        ((rsa->p != NULL) &&
         (rsa->q != NULL) &&
         (rsa->dmp1 != NULL) && (rsa->dmq1 != NULL) && (rsa->iqmp != NULL))) {
        if (!rsa->meth->rsa_mod_exp(ret, f, rsa, ctx))
            goto err;
    } else {
        BIGNUM *d = BN_new();
        if (d == NULL) {
            ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
            goto err;
        }
        if (rsa->d == NULL) {
            ERR_raise(ERR_LIB_RSA, RSA_R_MISSING_PRIVATE_KEY);
            BN_free(d);
            goto err;
        }
        BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);
        if (!rsa->meth->bn_mod_exp(ret, f, d, rsa->n, ctx,
                                   rsa->_method_mod_n)) {
            BN_free(d);
            goto err;
        }
        /* We MUST free d before any further use of rsa->d */
        BN_free(d);
    }

    if (blinding != NULL)
        if (!rsa_blinding_invert(blinding, ret, unblind, ctx))
            goto err;

    j = BN_bn2binpad(ret, buf, num);
    if (j < 0)
        goto err;

    switch (padding) {
    case RSA_PKCS1_PADDING:
        r = RSA_padding_check_PKCS1_type_2(to, num, buf, j, num);
        break;
    case RSA_PKCS1_OAEP_PADDING:
        r = RSA_padding_check_PKCS1_OAEP(to, num, buf, j, num, NULL, 0);
        break;
    case RSA_NO_PADDING:
        memcpy(to, buf, (r = j));
        break;
    default:
        ERR_raise(ERR_LIB_RSA, RSA_R_UNKNOWN_PADDING_TYPE);
        goto err;
    }
#ifndef FIPS_MODULE
    /*
     * This trick doesn't work in the FIPS provider because libcrypto manages
     * the error stack. Instead we opt not to put an error on the stack at all
     * in case of padding failure in the FIPS provider.
     *
     * The error is pushed unconditionally, then cleared (mask == 1) exactly
     * when r is non-negative, without branching on r.
     */
    ERR_raise(ERR_LIB_RSA, RSA_R_PADDING_CHECK_FAILED);
    err_clear_last_constant_time(1 & ~constant_time_msb(r));
#endif

 err:
    BN_CTX_end(ctx);
    BN_CTX_free(ctx);
    OPENSSL_clear_free(buf, num);
    return r;
}
509 
510 /* signature verification */
rsa_ossl_public_decrypt(int flen,const unsigned char * from,unsigned char * to,RSA * rsa,int padding)511 static int rsa_ossl_public_decrypt(int flen, const unsigned char *from,
512                                   unsigned char *to, RSA *rsa, int padding)
513 {
514     BIGNUM *f, *ret;
515     int i, num = 0, r = -1;
516     unsigned char *buf = NULL;
517     BN_CTX *ctx = NULL;
518 
519     if (BN_num_bits(rsa->n) > OPENSSL_RSA_MAX_MODULUS_BITS) {
520         ERR_raise(ERR_LIB_RSA, RSA_R_MODULUS_TOO_LARGE);
521         return -1;
522     }
523 
524     if (BN_ucmp(rsa->n, rsa->e) <= 0) {
525         ERR_raise(ERR_LIB_RSA, RSA_R_BAD_E_VALUE);
526         return -1;
527     }
528 
529     /* for large moduli, enforce exponent limit */
530     if (BN_num_bits(rsa->n) > OPENSSL_RSA_SMALL_MODULUS_BITS) {
531         if (BN_num_bits(rsa->e) > OPENSSL_RSA_MAX_PUBEXP_BITS) {
532             ERR_raise(ERR_LIB_RSA, RSA_R_BAD_E_VALUE);
533             return -1;
534         }
535     }
536 
537     if ((ctx = BN_CTX_new_ex(rsa->libctx)) == NULL)
538         goto err;
539     BN_CTX_start(ctx);
540     f = BN_CTX_get(ctx);
541     ret = BN_CTX_get(ctx);
542     num = BN_num_bytes(rsa->n);
543     buf = OPENSSL_malloc(num);
544     if (ret == NULL || buf == NULL) {
545         ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE);
546         goto err;
547     }
548 
549     /*
550      * This check was for equality but PGP does evil things and chops off the
551      * top '0' bytes
552      */
553     if (flen > num) {
554         ERR_raise(ERR_LIB_RSA, RSA_R_DATA_GREATER_THAN_MOD_LEN);
555         goto err;
556     }
557 
558     if (BN_bin2bn(from, flen, f) == NULL)
559         goto err;
560 
561     if (BN_ucmp(f, rsa->n) >= 0) {
562         ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE_FOR_MODULUS);
563         goto err;
564     }
565 
566     if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
567         if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
568                                     rsa->n, ctx))
569             goto err;
570 
571     if (!rsa->meth->bn_mod_exp(ret, f, rsa->e, rsa->n, ctx,
572                                rsa->_method_mod_n))
573         goto err;
574 
575     if ((padding == RSA_X931_PADDING) && ((bn_get_words(ret)[0] & 0xf) != 12))
576         if (!BN_sub(ret, rsa->n, ret))
577             goto err;
578 
579     i = BN_bn2binpad(ret, buf, num);
580     if (i < 0)
581         goto err;
582 
583     switch (padding) {
584     case RSA_PKCS1_PADDING:
585         r = RSA_padding_check_PKCS1_type_1(to, num, buf, i, num);
586         break;
587     case RSA_X931_PADDING:
588         r = RSA_padding_check_X931(to, num, buf, i, num);
589         break;
590     case RSA_NO_PADDING:
591         memcpy(to, buf, (r = i));
592         break;
593     default:
594         ERR_raise(ERR_LIB_RSA, RSA_R_UNKNOWN_PADDING_TYPE);
595         goto err;
596     }
597     if (r < 0)
598         ERR_raise(ERR_LIB_RSA, RSA_R_PADDING_CHECK_FAILED);
599 
600  err:
601     BN_CTX_end(ctx);
602     BN_CTX_free(ctx);
603     OPENSSL_clear_free(buf, num);
604     return r;
605 }
606 
/*
 * CRT private-key exponentiation: sets r0 to the private-key result for
 * input I, computed from p, q, dmp1, dmq1 and iqmp. Multi-prime keys also
 * use each extra prime's r, d, t and pp.
 *
 * Two computation paths:
 *   - RSA_FLAG_CACHE_PRIVATE set, default BN_mod_exp_mont, two primes of
 *     equal bit length ("smooth"): a constant-time Montgomery-domain path.
 *   - Otherwise: the classic recombination
 *     r0 = ((m_p - m_q) * iqmp mod p) * q + m_q, plus the multi-prime terms.
 *
 * When e and n are present, the result is checked by raising it to e
 * modulo n. If that does not give back I (mod n), the CRT result is
 * discarded and I^d mod n is recomputed directly with rsa->d.
 *
 * Returns 1 on success, 0 on error (including missing CRT components, or a
 * failed check when rsa->d is unavailable for the fallback).
 */
static int rsa_ossl_mod_exp(BIGNUM *r0, const BIGNUM *I, RSA *rsa, BN_CTX *ctx)
{
    BIGNUM *r1, *m1, *vrfy;
    int ret = 0, smooth = 0;
#ifndef FIPS_MODULE
    BIGNUM *r2, *m[RSA_MAX_PRIME_NUM - 2];
    int i, ex_primes = 0;
    RSA_PRIME_INFO *pinfo;
#endif

    /*
     * Every path below dereferences all five CRT components. Callers may
     * reach here via RSA_FLAG_EXT_PKEY or the multi-prime version without
     * them, so fail cleanly instead of dereferencing NULL.
     */
    if (rsa->p == NULL || rsa->q == NULL || rsa->dmp1 == NULL
            || rsa->dmq1 == NULL || rsa->iqmp == NULL) {
        ERR_raise(ERR_LIB_RSA, RSA_R_MISSING_PRIVATE_KEY);
        return 0;
    }

    BN_CTX_start(ctx);

    r1 = BN_CTX_get(ctx);
#ifndef FIPS_MODULE
    r2 = BN_CTX_get(ctx);
#endif
    m1 = BN_CTX_get(ctx);
    vrfy = BN_CTX_get(ctx);
    if (vrfy == NULL)
        goto err;

#ifndef FIPS_MODULE
    if (rsa->version == RSA_ASN1_VERSION_MULTI
        && ((ex_primes = sk_RSA_PRIME_INFO_num(rsa->prime_infos)) <= 0
             || ex_primes > RSA_MAX_PRIME_NUM - 2))
        goto err;
#endif

    if (rsa->flags & RSA_FLAG_CACHE_PRIVATE) {
        BIGNUM *factor = BN_new();

        if (factor == NULL)
            goto err;

        /*
         * Make sure BN_mod_inverse in Montgomery initialization uses the
         * BN_FLG_CONSTTIME flag
         */
        if (!(BN_with_flags(factor, rsa->p, BN_FLG_CONSTTIME),
              BN_MONT_CTX_set_locked(&rsa->_method_mod_p, rsa->lock,
                                     factor, ctx))
            || !(BN_with_flags(factor, rsa->q, BN_FLG_CONSTTIME),
                 BN_MONT_CTX_set_locked(&rsa->_method_mod_q, rsa->lock,
                                        factor, ctx))) {
            BN_free(factor);
            goto err;
        }
#ifndef FIPS_MODULE
        for (i = 0; i < ex_primes; i++) {
            pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
            BN_with_flags(factor, pinfo->r, BN_FLG_CONSTTIME);
            if (!BN_MONT_CTX_set_locked(&pinfo->m, rsa->lock, factor, ctx)) {
                BN_free(factor);
                goto err;
            }
        }
#endif
        /*
         * We MUST free |factor| before any further use of the prime factors
         */
        BN_free(factor);

        smooth = (rsa->meth->bn_mod_exp == BN_mod_exp_mont)
#ifndef FIPS_MODULE
                 && (ex_primes == 0)
#endif
                 && (BN_num_bits(rsa->q) == BN_num_bits(rsa->p));
    }

    if (rsa->flags & RSA_FLAG_CACHE_PUBLIC)
        if (!BN_MONT_CTX_set_locked(&rsa->_method_mod_n, rsa->lock,
                                    rsa->n, ctx))
            goto err;

    if (smooth) {
        /*
         * Conversion from Montgomery domain, a.k.a. Montgomery reduction,
         * accepts values in [0-m*2^w) range. w is m's bit width rounded up
         * to limb width. So that at the very least if |I| is fully reduced,
         * i.e. less than p*q, we can count on from-to round to perform
         * below modulo operations on |I|. Unlike BN_mod it's constant time.
         */
        if (/* m1 = I mod q */
            !bn_from_mont_fixed_top(m1, I, rsa->_method_mod_q, ctx)
            || !bn_to_mont_fixed_top(m1, m1, rsa->_method_mod_q, ctx)
            /* r1 = I mod p */
            || !bn_from_mont_fixed_top(r1, I, rsa->_method_mod_p, ctx)
            || !bn_to_mont_fixed_top(r1, r1, rsa->_method_mod_p, ctx)
            /*
             * Use parallel exponentiations optimization if possible,
             * otherwise fallback to two sequential exponentiations:
             *    m1 = m1^dmq1 mod q
             *    r1 = r1^dmp1 mod p
             */
            || !BN_mod_exp_mont_consttime_x2(m1, m1, rsa->dmq1, rsa->q,
                                             rsa->_method_mod_q,
                                             r1, r1, rsa->dmp1, rsa->p,
                                             rsa->_method_mod_p,
                                             ctx)
            /* r1 = (r1 - m1) mod p */
            /*
             * bn_mod_sub_fixed_top is not regular modular subtraction,
             * it can tolerate subtrahend to be larger than modulus, but
             * not bit-wise wider. This makes up for uncommon q>p case,
             * when |m1| can be larger than |rsa->p|.
             */
            || !bn_mod_sub_fixed_top(r1, r1, m1, rsa->p)

            /* r1 = r1 * iqmp mod p */
            || !bn_to_mont_fixed_top(r1, r1, rsa->_method_mod_p, ctx)
            || !bn_mul_mont_fixed_top(r1, r1, rsa->iqmp, rsa->_method_mod_p,
                                      ctx)
            /* r0 = r1 * q + m1 */
            || !bn_mul_fixed_top(r0, r1, rsa->q, ctx)
            || !bn_mod_add_fixed_top(r0, r0, m1, rsa->n))
            goto err;

        goto tail;
    }

    /* compute I mod q */
    {
        BIGNUM *c = BN_new();
        if (c == NULL)
            goto err;
        BN_with_flags(c, I, BN_FLG_CONSTTIME);

        if (!BN_mod(r1, c, rsa->q, ctx)) {
            BN_free(c);
            goto err;
        }

        {
            BIGNUM *dmq1 = BN_new();
            if (dmq1 == NULL) {
                BN_free(c);
                goto err;
            }
            BN_with_flags(dmq1, rsa->dmq1, BN_FLG_CONSTTIME);

            /* compute r1^dmq1 mod q */
            if (!rsa->meth->bn_mod_exp(m1, r1, dmq1, rsa->q, ctx,
                                       rsa->_method_mod_q)) {
                BN_free(c);
                BN_free(dmq1);
                goto err;
            }
            /* We MUST free dmq1 before any further use of rsa->dmq1 */
            BN_free(dmq1);
        }

        /* compute I mod p */
        if (!BN_mod(r1, c, rsa->p, ctx)) {
            BN_free(c);
            goto err;
        }
        /* We MUST free c before any further use of I */
        BN_free(c);
    }

    {
        BIGNUM *dmp1 = BN_new();
        if (dmp1 == NULL)
            goto err;
        BN_with_flags(dmp1, rsa->dmp1, BN_FLG_CONSTTIME);

        /* compute r1^dmp1 mod p */
        if (!rsa->meth->bn_mod_exp(r0, r1, dmp1, rsa->p, ctx,
                                   rsa->_method_mod_p)) {
            BN_free(dmp1);
            goto err;
        }
        /* We MUST free dmp1 before any further use of rsa->dmp1 */
        BN_free(dmp1);
    }

#ifndef FIPS_MODULE
    if (ex_primes > 0) {
        BIGNUM *di = BN_new(), *cc = BN_new();

        if (cc == NULL || di == NULL) {
            BN_free(cc);
            BN_free(di);
            goto err;
        }

        for (i = 0; i < ex_primes; i++) {
            /* prepare m_i */
            if ((m[i] = BN_CTX_get(ctx)) == NULL) {
                BN_free(cc);
                BN_free(di);
                goto err;
            }

            pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);

            /* prepare c and d_i */
            BN_with_flags(cc, I, BN_FLG_CONSTTIME);
            BN_with_flags(di, pinfo->d, BN_FLG_CONSTTIME);

            if (!BN_mod(r1, cc, pinfo->r, ctx)) {
                BN_free(cc);
                BN_free(di);
                goto err;
            }
            /* compute r1 ^ d_i mod r_i */
            if (!rsa->meth->bn_mod_exp(m[i], r1, di, pinfo->r, ctx, pinfo->m)) {
                BN_free(cc);
                BN_free(di);
                goto err;
            }
        }

        BN_free(cc);
        BN_free(di);
    }
#endif

    if (!BN_sub(r0, r0, m1))
        goto err;
    /*
     * This will help stop the size of r0 increasing, which does affect the
     * multiply if it optimised for a power of 2 size
     */
    if (BN_is_negative(r0))
        if (!BN_add(r0, r0, rsa->p))
            goto err;

    if (!BN_mul(r1, r0, rsa->iqmp, ctx))
        goto err;

    {
        BIGNUM *pr1 = BN_new();
        if (pr1 == NULL)
            goto err;
        BN_with_flags(pr1, r1, BN_FLG_CONSTTIME);

        if (!BN_mod(r0, pr1, rsa->p, ctx)) {
            BN_free(pr1);
            goto err;
        }
        /* We MUST free pr1 before any further use of r1 */
        BN_free(pr1);
    }

    /*
     * If p < q it is occasionally possible for the correction of adding 'p'
     * if r0 is negative above to leave the result still negative. This can
     * break the private key operations: the following second correction
     * should *always* correct this rare occurrence. This will *never* happen
     * with OpenSSL generated keys because they ensure p > q [steve]
     */
    if (BN_is_negative(r0))
        if (!BN_add(r0, r0, rsa->p))
            goto err;
    if (!BN_mul(r1, r0, rsa->q, ctx))
        goto err;
    if (!BN_add(r0, r1, m1))
        goto err;

#ifndef FIPS_MODULE
    /* add m_i to m in multi-prime case */
    if (ex_primes > 0) {
        BIGNUM *pr2 = BN_new();

        if (pr2 == NULL)
            goto err;

        for (i = 0; i < ex_primes; i++) {
            pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
            if (!BN_sub(r1, m[i], r0)) {
                BN_free(pr2);
                goto err;
            }

            if (!BN_mul(r2, r1, pinfo->t, ctx)) {
                BN_free(pr2);
                goto err;
            }

            BN_with_flags(pr2, r2, BN_FLG_CONSTTIME);

            if (!BN_mod(r1, pr2, pinfo->r, ctx)) {
                BN_free(pr2);
                goto err;
            }

            if (BN_is_negative(r1))
                if (!BN_add(r1, r1, pinfo->r)) {
                    BN_free(pr2);
                    goto err;
                }
            if (!BN_mul(r1, r1, pinfo->pp, ctx)) {
                BN_free(pr2);
                goto err;
            }
            if (!BN_add(r0, r0, r1)) {
                BN_free(pr2);
                goto err;
            }
        }
        BN_free(pr2);
    }
#endif

 tail:
    if (rsa->e && rsa->n) {
        if (rsa->meth->bn_mod_exp == BN_mod_exp_mont) {
            if (!BN_mod_exp_mont(vrfy, r0, rsa->e, rsa->n, ctx,
                                 rsa->_method_mod_n))
                goto err;
        } else {
            bn_correct_top(r0);
            if (!rsa->meth->bn_mod_exp(vrfy, r0, rsa->e, rsa->n, ctx,
                                       rsa->_method_mod_n))
                goto err;
        }
        /*
         * If 'I' was greater than (or equal to) rsa->n, the operation will
         * be equivalent to using 'I mod n'. However, the result of the
         * verify will *always* be less than 'n' so we don't check for
         * absolute equality, just congruency.
         */
        if (!BN_sub(vrfy, vrfy, I))
            goto err;
        if (BN_is_zero(vrfy)) {
            bn_correct_top(r0);
            ret = 1;
            goto err;   /* not actually error */
        }
        if (!BN_mod(vrfy, vrfy, rsa->n, ctx))
            goto err;
        if (BN_is_negative(vrfy))
            if (!BN_add(vrfy, vrfy, rsa->n))
                goto err;
        if (!BN_is_zero(vrfy)) {
            /*
             * 'I' and 'vrfy' aren't congruent mod n. Don't leak
             * miscalculated CRT output, just do a raw (slower) mod_exp and
             * return that instead.
             */
            BIGNUM *d;

            /*
             * The fallback needs the full private exponent. Without it there
             * is no trustworthy result to return, so fail rather than
             * dereference NULL in BN_with_flags().
             */
            if (rsa->d == NULL) {
                ERR_raise(ERR_LIB_RSA, RSA_R_MISSING_PRIVATE_KEY);
                goto err;
            }

            d = BN_new();
            if (d == NULL)
                goto err;
            BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);

            if (!rsa->meth->bn_mod_exp(r0, I, d, rsa->n, ctx,
                                       rsa->_method_mod_n)) {
                BN_free(d);
                goto err;
            }
            /* We MUST free d before any further use of rsa->d */
            BN_free(d);
        }
    }
    /*
     * It's unfortunate that we have to bn_correct_top(r0). What hopefully
     * saves the day is that correction is highly unlikely, and private key
     * operations are customarily performed on blinded message. Which means
     * that attacker won't observe correlation with chosen plaintext.
     * Secondly, remaining code would still handle it in same computational
     * time and even conceal memory access pattern around corrected top.
     */
    bn_correct_top(r0);
    ret = 1;
 err:
    BN_CTX_end(ctx);
    return ret;
}
977 
/*
 * Method init hook. It enables caching of the public (n) and private
 * (p, q, extra primes) Montgomery contexts, which the operations in this
 * file populate lazily. Always succeeds.
 */
static int rsa_ossl_init(RSA *rsa)
{
    rsa->flags = rsa->flags | RSA_FLAG_CACHE_PRIVATE | RSA_FLAG_CACHE_PUBLIC;
    return 1;
}
983 
/*
 * Method finish hook. It releases the Montgomery contexts cached on |rsa|:
 * the ones for n, p and q, plus (outside the FIPS module) one per extra
 * prime of a multi-prime key.
 *
 * Each pointer is reset to NULL after it is freed. This hook does not free
 * |rsa| itself, so the object can outlive the call (NOTE(review): e.g. a
 * method switch via RSA_set_method() - confirm). The caches in this file
 * are filled through BN_MONT_CTX_set_locked(&ptr, ...) on those same
 * fields. Leaving the old values behind would let a later operation pick
 * up freed memory, or a second finish free it twice.
 *
 * Always returns 1.
 */
static int rsa_ossl_finish(RSA *rsa)
{
#ifndef FIPS_MODULE
    int i;
    RSA_PRIME_INFO *pinfo;

    for (i = 0; i < sk_RSA_PRIME_INFO_num(rsa->prime_infos); i++) {
        pinfo = sk_RSA_PRIME_INFO_value(rsa->prime_infos, i);
        BN_MONT_CTX_free(pinfo->m);
        pinfo->m = NULL;
    }
#endif

    BN_MONT_CTX_free(rsa->_method_mod_n);
    rsa->_method_mod_n = NULL;
    BN_MONT_CTX_free(rsa->_method_mod_p);
    rsa->_method_mod_p = NULL;
    BN_MONT_CTX_free(rsa->_method_mod_q);
    rsa->_method_mod_q = NULL;
    return 1;
}
1001