xref: /openbsd/lib/libcrypto/rsa/rsa_ameth.c (revision 510d2225)
1 /* $OpenBSD: rsa_ameth.c,v 1.51 2023/11/09 08:29:53 tb Exp $ */
2 /* Written by Dr Stephen N Henson (steve@openssl.org) for the OpenSSL
3  * project 2006.
4  */
5 /* ====================================================================
6  * Copyright (c) 2006 The OpenSSL Project.  All rights reserved.
7  *
8  * Redistribution and use in source and binary forms, with or without
9  * modification, are permitted provided that the following conditions
10  * are met:
11  *
12  * 1. Redistributions of source code must retain the above copyright
13  *    notice, this list of conditions and the following disclaimer.
14  *
15  * 2. Redistributions in binary form must reproduce the above copyright
16  *    notice, this list of conditions and the following disclaimer in
17  *    the documentation and/or other materials provided with the
18  *    distribution.
19  *
20  * 3. All advertising materials mentioning features or use of this
21  *    software must display the following acknowledgment:
22  *    "This product includes software developed by the OpenSSL Project
23  *    for use in the OpenSSL Toolkit. (http://www.OpenSSL.org/)"
24  *
25  * 4. The names "OpenSSL Toolkit" and "OpenSSL Project" must not be used to
26  *    endorse or promote products derived from this software without
27  *    prior written permission. For written permission, please contact
28  *    licensing@OpenSSL.org.
29  *
30  * 5. Products derived from this software may not be called "OpenSSL"
31  *    nor may "OpenSSL" appear in their names without prior written
32  *    permission of the OpenSSL Project.
33  *
34  * 6. Redistributions of any form whatsoever must retain the following
35  *    acknowledgment:
36  *    "This product includes software developed by the OpenSSL Project
37  *    for use in the OpenSSL Toolkit (http://www.OpenSSL.org/)"
38  *
39  * THIS SOFTWARE IS PROVIDED BY THE OpenSSL PROJECT ``AS IS'' AND ANY
40  * EXPRESSED OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
41  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
42  * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE OpenSSL PROJECT OR
43  * ITS CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
44  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
45  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
46  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
47  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
48  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
49  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
50  * OF THE POSSIBILITY OF SUCH DAMAGE.
51  * ====================================================================
52  *
53  * This product includes cryptographic software written by Eric Young
54  * (eay@cryptsoft.com).  This product includes software written by Tim
55  * Hudson (tjh@cryptsoft.com).
56  *
57  */
58 
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>

#include <openssl/opensslconf.h>

#include <openssl/asn1t.h>
#include <openssl/bn.h>
#include <openssl/cms.h>
#include <openssl/err.h>
#include <openssl/rsa.h>
#include <openssl/x509.h>

#include "asn1_local.h"
#include "bn_local.h"
#include "cryptlib.h"
#include "evp_local.h"
#include "rsa_local.h"
#include "x509_local.h"
76 
77 #ifndef OPENSSL_NO_CMS
78 static int rsa_cms_sign(CMS_SignerInfo *si);
79 static int rsa_cms_verify(CMS_SignerInfo *si);
80 static int rsa_cms_decrypt(CMS_RecipientInfo *ri);
81 static int rsa_cms_encrypt(CMS_RecipientInfo *ri);
82 #endif
83 
84 static RSA_PSS_PARAMS *rsa_pss_decode(const X509_ALGOR *alg);
85 
86 static int rsa_alg_set_pkcs1_padding(X509_ALGOR *alg);
87 
88 /* Set any parameters associated with pkey */
89 static int
90 rsa_param_encode(const EVP_PKEY *pkey, ASN1_STRING **pstr, int *pstrtype)
91 {
92 	const RSA *rsa = pkey->pkey.rsa;
93 
94 	*pstr = NULL;
95 
96 	/* If RSA it's just NULL type */
97 	if (pkey->ameth->pkey_id != EVP_PKEY_RSA_PSS) {
98 		*pstrtype = V_ASN1_NULL;
99 		return 1;
100 	}
101 
102 	/* If no PSS parameters we omit parameters entirely */
103 	if (rsa->pss == NULL) {
104 		*pstrtype = V_ASN1_UNDEF;
105 		return 1;
106 	}
107 
108 	/* Encode PSS parameters */
109 	if (ASN1_item_pack(rsa->pss, &RSA_PSS_PARAMS_it, pstr) == NULL)
110 		return 0;
111 
112 	*pstrtype = V_ASN1_SEQUENCE;
113 	return 1;
114 }
115 
116 /* Decode any parameters and set them in RSA structure */
117 static int
118 rsa_param_decode(RSA *rsa, const X509_ALGOR *alg)
119 {
120 	const ASN1_OBJECT *algoid;
121 	const void *algp;
122 	int algptype;
123 
124 	X509_ALGOR_get0(&algoid, &algptype, &algp, alg);
125 	if (OBJ_obj2nid(algoid) != EVP_PKEY_RSA_PSS)
126 		return 1;
127 	if (algptype == V_ASN1_UNDEF)
128 		return 1;
129 	if (algptype != V_ASN1_SEQUENCE) {
130 		RSAerror(RSA_R_INVALID_PSS_PARAMETERS);
131 		return 0;
132 	}
133 	rsa->pss = rsa_pss_decode(alg);
134 	if (rsa->pss == NULL)
135 		return 0;
136 	return 1;
137 }
138 
139 static int
140 rsa_pub_encode(X509_PUBKEY *pk, const EVP_PKEY *pkey)
141 {
142 	ASN1_STRING *str = NULL;
143 	int strtype;
144 	unsigned char *penc = NULL;
145 	int penclen = 0;
146 	ASN1_OBJECT *aobj;
147 
148 	if (!rsa_param_encode(pkey, &str, &strtype))
149 		goto err;
150 	if ((penclen = i2d_RSAPublicKey(pkey->pkey.rsa, &penc)) <= 0) {
151 		penclen = 0;
152 		goto err;
153 	}
154 	if ((aobj = OBJ_nid2obj(pkey->ameth->pkey_id)) == NULL)
155 		goto err;
156 	if (!X509_PUBKEY_set0_param(pk, aobj, strtype, str, penc, penclen))
157 		goto err;
158 
159 	return 1;
160 
161  err:
162 	ASN1_STRING_free(str);
163 	freezero(penc, penclen);
164 
165 	return 0;
166 }
167 
168 static int
169 rsa_pub_decode(EVP_PKEY *pkey, X509_PUBKEY *pubkey)
170 {
171 	const unsigned char *p;
172 	int pklen;
173 	X509_ALGOR *alg;
174 	RSA *rsa = NULL;
175 
176 	if (!X509_PUBKEY_get0_param(NULL, &p, &pklen, &alg, pubkey))
177 		return 0;
178 	if ((rsa = d2i_RSAPublicKey(NULL, &p, pklen)) == NULL) {
179 		RSAerror(ERR_R_RSA_LIB);
180 		return 0;
181 	}
182 	if (!rsa_param_decode(rsa, alg)) {
183 		RSA_free(rsa);
184 		return 0;
185 	}
186 	if (!EVP_PKEY_assign(pkey, pkey->ameth->pkey_id, rsa)) {
187 		RSA_free(rsa);
188 		return 0;
189 	}
190 	return 1;
191 }
192 
193 static int
194 rsa_pub_cmp(const EVP_PKEY *a, const EVP_PKEY *b)
195 {
196 	if (BN_cmp(b->pkey.rsa->n, a->pkey.rsa->n) != 0 ||
197 	    BN_cmp(b->pkey.rsa->e, a->pkey.rsa->e) != 0)
198 		return 0;
199 
200 	return 1;
201 }
202 
203 static int
204 old_rsa_priv_decode(EVP_PKEY *pkey, const unsigned char **pder, int derlen)
205 {
206 	RSA *rsa;
207 
208 	if ((rsa = d2i_RSAPrivateKey(NULL, pder, derlen)) == NULL) {
209 		RSAerror(ERR_R_RSA_LIB);
210 		return 0;
211 	}
212 	EVP_PKEY_assign(pkey, pkey->ameth->pkey_id, rsa);
213 	return 1;
214 }
215 
216 static int
217 old_rsa_priv_encode(const EVP_PKEY *pkey, unsigned char **pder)
218 {
219 	return i2d_RSAPrivateKey(pkey->pkey.rsa, pder);
220 }
221 
222 static int
223 rsa_priv_encode(PKCS8_PRIV_KEY_INFO *p8, const EVP_PKEY *pkey)
224 {
225 	ASN1_STRING *str = NULL;
226 	ASN1_OBJECT *aobj;
227 	int strtype;
228 	unsigned char *rk = NULL;
229 	int rklen = 0;
230 
231 	if (!rsa_param_encode(pkey, &str, &strtype))
232 		goto err;
233 	if ((rklen = i2d_RSAPrivateKey(pkey->pkey.rsa, &rk)) <= 0) {
234 		RSAerror(ERR_R_MALLOC_FAILURE);
235 		rklen = 0;
236 		goto err;
237 	}
238 	if ((aobj = OBJ_nid2obj(pkey->ameth->pkey_id)) == NULL)
239 		goto err;
240 	if (!PKCS8_pkey_set0(p8, aobj, 0, strtype, str, rk, rklen)) {
241 		RSAerror(ERR_R_MALLOC_FAILURE);
242 		goto err;
243 	}
244 
245 	return 1;
246 
247  err:
248 	ASN1_STRING_free(str);
249 	freezero(rk, rklen);
250 
251 	return 0;
252 }
253 
254 static int
255 rsa_priv_decode(EVP_PKEY *pkey, const PKCS8_PRIV_KEY_INFO *p8)
256 {
257 	const unsigned char *p;
258 	RSA *rsa;
259 	int pklen;
260 	const X509_ALGOR *alg;
261 
262 	if (!PKCS8_pkey_get0(NULL, &p, &pklen, &alg, p8))
263 		return 0;
264 	rsa = d2i_RSAPrivateKey(NULL, &p, pklen);
265 	if (rsa == NULL) {
266 		RSAerror(ERR_R_RSA_LIB);
267 		return 0;
268 	}
269 	if (!rsa_param_decode(rsa, alg)) {
270 		RSA_free(rsa);
271 		return 0;
272 	}
273 	EVP_PKEY_assign(pkey, pkey->ameth->pkey_id, rsa);
274 
275 	return 1;
276 }
277 
278 static int
279 rsa_size(const EVP_PKEY *pkey)
280 {
281 	return RSA_size(pkey->pkey.rsa);
282 }
283 
284 static int
285 rsa_bits(const EVP_PKEY *pkey)
286 {
287 	return BN_num_bits(pkey->pkey.rsa->n);
288 }
289 
290 static int
291 rsa_security_bits(const EVP_PKEY *pkey)
292 {
293 	return RSA_security_bits(pkey->pkey.rsa);
294 }
295 
296 static void
297 rsa_free(EVP_PKEY *pkey)
298 {
299 	RSA_free(pkey->pkey.rsa);
300 }
301 
302 static X509_ALGOR *
303 rsa_mgf1_decode(X509_ALGOR *alg)
304 {
305 	if (OBJ_obj2nid(alg->algorithm) != NID_mgf1)
306 		return NULL;
307 
308 	return ASN1_TYPE_unpack_sequence(&X509_ALGOR_it, alg->parameter);
309 }
310 
311 static RSA_PSS_PARAMS *
312 rsa_pss_decode(const X509_ALGOR *alg)
313 {
314 	RSA_PSS_PARAMS *pss;
315 
316 	pss = ASN1_TYPE_unpack_sequence(&RSA_PSS_PARAMS_it, alg->parameter);
317 	if (pss == NULL)
318 		return NULL;
319 
320 	if (pss->maskGenAlgorithm != NULL) {
321 		pss->maskHash = rsa_mgf1_decode(pss->maskGenAlgorithm);
322 		if (pss->maskHash == NULL) {
323 			RSA_PSS_PARAMS_free(pss);
324 			return NULL;
325 		}
326 	}
327 
328 	return pss;
329 }
330 
331 static int
332 rsa_pss_param_print(BIO *bp, int pss_key, RSA_PSS_PARAMS *pss, int indent)
333 {
334 	int rv = 0;
335 	X509_ALGOR *maskHash = NULL;
336 
337 	if (!BIO_indent(bp, indent, 128))
338 		goto err;
339 	if (pss_key) {
340 		if (pss == NULL) {
341 			if (BIO_puts(bp, "No PSS parameter restrictions\n") <= 0)
342 				return 0;
343 			return 1;
344 		} else {
345 			if (BIO_puts(bp, "PSS parameter restrictions:") <= 0)
346 				return 0;
347 		}
348 	} else if (pss == NULL) {
349 		if (BIO_puts(bp,"(INVALID PSS PARAMETERS)\n") <= 0)
350 			return 0;
351 		return 1;
352 	}
353 	if (BIO_puts(bp, "\n") <= 0)
354 		goto err;
355 	if (pss_key)
356 		indent += 2;
357 	if (!BIO_indent(bp, indent, 128))
358 		goto err;
359 	if (BIO_puts(bp, "Hash Algorithm: ") <= 0)
360 		goto err;
361 
362 	if (pss->hashAlgorithm) {
363 		if (i2a_ASN1_OBJECT(bp, pss->hashAlgorithm->algorithm) <= 0)
364 			goto err;
365 	} else if (BIO_puts(bp, "sha1 (default)") <= 0) {
366 		goto err;
367 	}
368 
369 	if (BIO_puts(bp, "\n") <= 0)
370 		goto err;
371 
372 	if (!BIO_indent(bp, indent, 128))
373 		goto err;
374 
375 	if (BIO_puts(bp, "Mask Algorithm: ") <= 0)
376 		goto err;
377 	if (pss->maskGenAlgorithm) {
378 		if (i2a_ASN1_OBJECT(bp, pss->maskGenAlgorithm->algorithm) <= 0)
379 			goto err;
380 		if (BIO_puts(bp, " with ") <= 0)
381 			goto err;
382 		maskHash = rsa_mgf1_decode(pss->maskGenAlgorithm);
383 		if (maskHash != NULL) {
384 			if (i2a_ASN1_OBJECT(bp, maskHash->algorithm) <= 0)
385 				goto err;
386 		} else if (BIO_puts(bp, "INVALID") <= 0) {
387 			goto err;
388 		}
389 	} else if (BIO_puts(bp, "mgf1 with sha1 (default)") <= 0) {
390 		goto err;
391 	}
392 	BIO_puts(bp, "\n");
393 
394 	if (!BIO_indent(bp, indent, 128))
395 		goto err;
396 	if (BIO_printf(bp, "%s Salt Length: 0x", pss_key ? "Minimum" : "") <= 0)
397 		goto err;
398 	if (pss->saltLength) {
399 		if (i2a_ASN1_INTEGER(bp, pss->saltLength) <= 0)
400 			goto err;
401 	} else if (BIO_puts(bp, "14 (default)") <= 0) {
402 		goto err;
403 	}
404 	BIO_puts(bp, "\n");
405 
406 	if (!BIO_indent(bp, indent, 128))
407 		goto err;
408 	if (BIO_puts(bp, "Trailer Field: 0x") <= 0)
409 		goto err;
410 	if (pss->trailerField) {
411 		if (i2a_ASN1_INTEGER(bp, pss->trailerField) <= 0)
412 			goto err;
413 	} else if (BIO_puts(bp, "BC (default)") <= 0) {
414 		goto err;
415 	}
416 	BIO_puts(bp, "\n");
417 
418 	rv = 1;
419 
420  err:
421 	X509_ALGOR_free(maskHash);
422 	return rv;
423 
424 }
425 
426 static int
427 pkey_rsa_print(BIO *bp, const EVP_PKEY *pkey, int off, int priv)
428 {
429 	const RSA *x = pkey->pkey.rsa;
430 	char *str;
431 	const char *s;
432 	int ret = 0, mod_len = 0;
433 
434 	if (x->n != NULL)
435 		mod_len = BN_num_bits(x->n);
436 
437 	if (!BIO_indent(bp, off, 128))
438 		goto err;
439 
440 	if (BIO_printf(bp, "%s ", pkey_is_pss(pkey) ?  "RSA-PSS" : "RSA") <= 0)
441 		goto err;
442 
443 	if (priv && x->d != NULL) {
444 		if (BIO_printf(bp, "Private-Key: (%d bit)\n", mod_len) <= 0)
445 			goto err;
446 		str = "modulus:";
447 		s = "publicExponent:";
448 	} else {
449 		if (BIO_printf(bp, "Public-Key: (%d bit)\n", mod_len) <= 0)
450 			goto err;
451 		str = "Modulus:";
452 		s = "Exponent:";
453 	}
454 	if (!bn_printf(bp, x->n, off, "%s", str))
455 		goto err;
456 	if (!bn_printf(bp, x->e, off, "%s", s))
457 		goto err;
458 	if (priv) {
459 		if (!bn_printf(bp, x->d, off, "privateExponent:"))
460 			goto err;
461 		if (!bn_printf(bp, x->p, off, "prime1:"))
462 			goto err;
463 		if (!bn_printf(bp, x->q, off, "prime2:"))
464 			goto err;
465 		if (!bn_printf(bp, x->dmp1, off, "exponent1:"))
466 			goto err;
467 		if (!bn_printf(bp, x->dmq1, off, "exponent2:"))
468 			goto err;
469 		if (!bn_printf(bp, x->iqmp, off, "coefficient:"))
470 			goto err;
471 	}
472 	if (pkey_is_pss(pkey) && !rsa_pss_param_print(bp, 1, x->pss, off))
473 		goto err;
474 	ret = 1;
475  err:
476 	return ret;
477 }
478 
479 static int
480 rsa_pub_print(BIO *bp, const EVP_PKEY *pkey, int indent, ASN1_PCTX *ctx)
481 {
482 	return pkey_rsa_print(bp, pkey, indent, 0);
483 }
484 
485 static int
486 rsa_priv_print(BIO *bp, const EVP_PKEY *pkey, int indent, ASN1_PCTX *ctx)
487 {
488 	return pkey_rsa_print(bp, pkey, indent, 1);
489 }
490 
491 static int
492 rsa_sig_print(BIO *bp, const X509_ALGOR *sigalg, const ASN1_STRING *sig,
493     int indent, ASN1_PCTX *pctx)
494 {
495 	if (OBJ_obj2nid(sigalg->algorithm) == EVP_PKEY_RSA_PSS) {
496 		int rv;
497 		RSA_PSS_PARAMS *pss = rsa_pss_decode(sigalg);
498 
499 		rv = rsa_pss_param_print(bp, 0, pss, indent);
500 		RSA_PSS_PARAMS_free(pss);
501 		if (!rv)
502 			return 0;
503 	} else if (!sig && BIO_puts(bp, "\n") <= 0) {
504 		return 0;
505 	}
506 	if (sig)
507 		return X509_signature_dump(bp, sig, indent);
508 	return 1;
509 }
510 
511 static int
512 rsa_pkey_ctrl(EVP_PKEY *pkey, int op, long arg1, void *arg2)
513 {
514 	X509_ALGOR *alg = NULL;
515 	const EVP_MD *md;
516 	const EVP_MD *mgf1md;
517 	int min_saltlen;
518 
519 	switch (op) {
520 	case ASN1_PKEY_CTRL_PKCS7_SIGN:
521 		if (arg1 == 0)
522 			PKCS7_SIGNER_INFO_get0_algs(arg2, NULL, NULL, &alg);
523 		break;
524 
525 	case ASN1_PKEY_CTRL_PKCS7_ENCRYPT:
526 		if (pkey_is_pss(pkey))
527 			return -2;
528 		if (arg1 == 0)
529 			PKCS7_RECIP_INFO_get0_alg(arg2, &alg);
530 		break;
531 #ifndef OPENSSL_NO_CMS
532 	case ASN1_PKEY_CTRL_CMS_SIGN:
533 		if (arg1 == 0)
534 			return rsa_cms_sign(arg2);
535 		else if (arg1 == 1)
536 			return rsa_cms_verify(arg2);
537 		break;
538 
539 	case ASN1_PKEY_CTRL_CMS_ENVELOPE:
540 		if (pkey_is_pss(pkey))
541 			return -2;
542 		if (arg1 == 0)
543 			return rsa_cms_encrypt(arg2);
544 		else if (arg1 == 1)
545 			return rsa_cms_decrypt(arg2);
546 		break;
547 
548 	case ASN1_PKEY_CTRL_CMS_RI_TYPE:
549 		if (pkey_is_pss(pkey))
550 			return -2;
551 		*(int *)arg2 = CMS_RECIPINFO_TRANS;
552 		return 1;
553 #endif
554 
555 	case ASN1_PKEY_CTRL_DEFAULT_MD_NID:
556 		if (pkey->pkey.rsa->pss != NULL) {
557 			if (!rsa_pss_get_param(pkey->pkey.rsa->pss, &md, &mgf1md,
558 			    &min_saltlen)) {
559 				RSAerror(ERR_R_INTERNAL_ERROR);
560 				return 0;
561 			}
562 			*(int *)arg2 = EVP_MD_type(md);
563 			/* Return of 2 indicates this MD is mandatory */
564 			return 2;
565 		}
566 		*(int *)arg2 = NID_sha256;
567 		return 1;
568 
569 	default:
570 		return -2;
571 	}
572 
573 	if (alg != NULL)
574 		return rsa_alg_set_pkcs1_padding(alg);
575 
576 	return 1;
577 }
578 
579 static int
580 rsa_md_to_algor(const EVP_MD *md, X509_ALGOR **out_alg)
581 {
582 	X509_ALGOR *alg = NULL;
583 	int ret = 0;
584 
585 	X509_ALGOR_free(*out_alg);
586 	*out_alg = NULL;
587 
588 	/* RFC 8017 - default hash is SHA-1 and hence omitted. */
589 	if (md == NULL || EVP_MD_type(md) == NID_sha1)
590 		goto done;
591 
592 	if ((alg = X509_ALGOR_new()) == NULL)
593 		goto err;
594 	if (!X509_ALGOR_set_evp_md(alg, md))
595 		goto err;
596 
597  done:
598 	*out_alg = alg;
599 	alg = NULL;
600 
601 	ret = 1;
602 
603  err:
604 	X509_ALGOR_free(alg);
605 
606 	return ret;
607 }
608 
609 /*
610  * RFC 8017, A.2.1 and A.2.3 - encode maskGenAlgorithm for RSAES-OAEP
611  * and RSASSA-PSS. The default is mgfSHA1 and hence omitted.
612  */
613 static int
614 rsa_mgf1md_to_maskGenAlgorithm(const EVP_MD *mgf1md, X509_ALGOR **out_alg)
615 {
616 	X509_ALGOR *alg = NULL;
617 	X509_ALGOR *inner_alg = NULL;
618 	ASN1_STRING *astr = NULL;
619 	int ret = 0;
620 
621 	X509_ALGOR_free(*out_alg);
622 	*out_alg = NULL;
623 
624 	if (mgf1md == NULL || EVP_MD_type(mgf1md) == NID_sha1)
625 		goto done;
626 
627 	if ((inner_alg = X509_ALGOR_new()) == NULL)
628 		goto err;
629 	if (!X509_ALGOR_set_evp_md(inner_alg, mgf1md))
630 		goto err;
631 	if ((astr = ASN1_item_pack(inner_alg, &X509_ALGOR_it, NULL)) == NULL)
632 		goto err;
633 
634 	if ((alg = X509_ALGOR_new()) == NULL)
635 		goto err;
636 	if (!X509_ALGOR_set0_by_nid(alg, NID_mgf1, V_ASN1_SEQUENCE, astr))
637 		goto err;
638 	astr = NULL;
639 
640  done:
641 	*out_alg = alg;
642 	alg = NULL;
643 
644 	ret = 1;
645 
646  err:
647 	X509_ALGOR_free(alg);
648 	X509_ALGOR_free(inner_alg);
649 	ASN1_STRING_free(astr);
650 
651 	return ret;
652 }
653 
654 /* Convert algorithm ID to EVP_MD, defaults to SHA1. */
655 static const EVP_MD *
656 rsa_algor_to_md(X509_ALGOR *alg)
657 {
658 	const EVP_MD *md;
659 
660 	if (!alg)
661 		return EVP_sha1();
662 	md = EVP_get_digestbyobj(alg->algorithm);
663 	if (md == NULL)
664 		RSAerror(RSA_R_UNKNOWN_DIGEST);
665 	return md;
666 }
667 
668 /*
669  * Convert EVP_PKEY_CTX in PSS mode into corresponding algorithm parameter,
670  * suitable for setting an AlgorithmIdentifier.
671  */
672 static RSA_PSS_PARAMS *
673 rsa_ctx_to_pss(EVP_PKEY_CTX *pkey_ctx)
674 {
675 	const EVP_MD *sigmd, *mgf1md;
676 	EVP_PKEY *pk = EVP_PKEY_CTX_get0_pkey(pkey_ctx);
677 	int saltlen;
678 
679 	if (EVP_PKEY_CTX_get_signature_md(pkey_ctx, &sigmd) <= 0)
680 		return NULL;
681 	if (EVP_PKEY_CTX_get_rsa_mgf1_md(pkey_ctx, &mgf1md) <= 0)
682 		return NULL;
683 	if (!EVP_PKEY_CTX_get_rsa_pss_saltlen(pkey_ctx, &saltlen))
684 		return NULL;
685 	if (saltlen == -1) {
686 		saltlen = EVP_MD_size(sigmd);
687 	} else if (saltlen == -2 || saltlen == -3) {
688 		saltlen = EVP_PKEY_size(pk) - EVP_MD_size(sigmd) - 2;
689 		if ((EVP_PKEY_bits(pk) & 0x7) == 1)
690 			saltlen--;
691 		if (saltlen < 0)
692 			return NULL;
693 	}
694 
695 	return rsa_pss_params_create(sigmd, mgf1md, saltlen);
696 }
697 
698 RSA_PSS_PARAMS *
699 rsa_pss_params_create(const EVP_MD *sigmd, const EVP_MD *mgf1md, int saltlen)
700 {
701 	RSA_PSS_PARAMS *pss = NULL;
702 
703 	if (mgf1md == NULL)
704 		mgf1md = sigmd;
705 
706 	if ((pss = RSA_PSS_PARAMS_new()) == NULL)
707 		goto err;
708 
709 	if (!rsa_md_to_algor(sigmd, &pss->hashAlgorithm))
710 		goto err;
711 	if (!rsa_mgf1md_to_maskGenAlgorithm(mgf1md, &pss->maskGenAlgorithm))
712 		goto err;
713 
714 	/* Translate mgf1md to X509_ALGOR in decoded form for internal use. */
715 	if (!rsa_md_to_algor(mgf1md, &pss->maskHash))
716 		goto err;
717 
718 	/* RFC 8017, A.2.3 - default saltLength is SHA_DIGEST_LENGTH. */
719 	if (saltlen != SHA_DIGEST_LENGTH) {
720 		if ((pss->saltLength = ASN1_INTEGER_new()) == NULL)
721 			goto err;
722 		if (!ASN1_INTEGER_set(pss->saltLength, saltlen))
723 			goto err;
724 	}
725 
726 	return pss;
727 
728  err:
729 	RSA_PSS_PARAMS_free(pss);
730 
731 	return NULL;
732 }
733 
734 /*
735  * From PSS AlgorithmIdentifier set public key parameters. If pkey isn't NULL
736  * then the EVP_MD_CTX is setup and initialised. If it is NULL parameters are
737  * passed to pkey_ctx instead.
738  */
739 
740 static int
741 rsa_pss_to_ctx(EVP_MD_CTX *ctx, EVP_PKEY_CTX *pkey_ctx,
742     X509_ALGOR *sigalg, EVP_PKEY *pkey)
743 {
744 	int rv = -1;
745 	int saltlen;
746 	const EVP_MD *mgf1md = NULL, *md = NULL;
747 	RSA_PSS_PARAMS *pss;
748 
749 	/* Sanity check: make sure it is PSS */
750 	if (OBJ_obj2nid(sigalg->algorithm) != EVP_PKEY_RSA_PSS) {
751 		RSAerror(RSA_R_UNSUPPORTED_SIGNATURE_TYPE);
752 		return -1;
753 	}
754 	/* Decode PSS parameters */
755 	pss = rsa_pss_decode(sigalg);
756 
757 	if (!rsa_pss_get_param(pss, &md, &mgf1md, &saltlen)) {
758 		RSAerror(RSA_R_INVALID_PSS_PARAMETERS);
759 		goto err;
760 	}
761 
762 	/* We have all parameters now set up context */
763 	if (pkey) {
764 		if (!EVP_DigestVerifyInit(ctx, &pkey_ctx, md, NULL, pkey))
765 			goto err;
766 	} else {
767 		const EVP_MD *checkmd;
768 		if (EVP_PKEY_CTX_get_signature_md(pkey_ctx, &checkmd) <= 0)
769 			goto err;
770 		if (EVP_MD_type(md) != EVP_MD_type(checkmd)) {
771 			RSAerror(RSA_R_DIGEST_DOES_NOT_MATCH);
772 			goto err;
773 		}
774 	}
775 
776 	if (EVP_PKEY_CTX_set_rsa_padding(pkey_ctx, RSA_PKCS1_PSS_PADDING) <= 0)
777 		goto err;
778 
779 	if (EVP_PKEY_CTX_set_rsa_pss_saltlen(pkey_ctx, saltlen) <= 0)
780 		goto err;
781 
782 	if (EVP_PKEY_CTX_set_rsa_mgf1_md(pkey_ctx, mgf1md) <= 0)
783 		goto err;
784 	/* Carry on */
785 	rv = 1;
786 
787  err:
788 	RSA_PSS_PARAMS_free(pss);
789 	return rv;
790 }
791 
792 int
793 rsa_pss_get_param(const RSA_PSS_PARAMS *pss, const EVP_MD **pmd,
794     const EVP_MD **pmgf1md, int *psaltlen)
795 {
796 	if (pss == NULL)
797 		return 0;
798 	*pmd = rsa_algor_to_md(pss->hashAlgorithm);
799 	if (*pmd == NULL)
800 		return 0;
801 	*pmgf1md = rsa_algor_to_md(pss->maskHash);
802 	if (*pmgf1md == NULL)
803 		return 0;
804 	if (pss->saltLength) {
805 		*psaltlen = ASN1_INTEGER_get(pss->saltLength);
806 		if (*psaltlen < 0) {
807 			RSAerror(RSA_R_INVALID_SALT_LENGTH);
808 			return 0;
809 		}
810 	} else {
811 		*psaltlen = 20;
812 	}
813 
814 	/*
815 	 * low-level routines support only trailer field 0xbc (value 1) and
816 	 * PKCS#1 says we should reject any other value anyway.
817 	 */
818 	if (pss->trailerField && ASN1_INTEGER_get(pss->trailerField) != 1) {
819 		RSAerror(RSA_R_INVALID_TRAILER);
820 		return 0;
821 	}
822 
823 	return 1;
824 }
825 
826 #ifndef OPENSSL_NO_CMS
827 static int
828 rsa_cms_verify(CMS_SignerInfo *si)
829 {
830 	int nid, nid2;
831 	X509_ALGOR *alg;
832 	EVP_PKEY_CTX *pkey_ctx = CMS_SignerInfo_get0_pkey_ctx(si);
833 
834 	CMS_SignerInfo_get0_algs(si, NULL, NULL, NULL, &alg);
835 	nid = OBJ_obj2nid(alg->algorithm);
836 	if (nid == EVP_PKEY_RSA_PSS)
837 		return rsa_pss_to_ctx(NULL, pkey_ctx, alg, NULL);
838 	/* Only PSS allowed for PSS keys */
839 	if (pkey_ctx_is_pss(pkey_ctx)) {
840 		RSAerror(RSA_R_ILLEGAL_OR_UNSUPPORTED_PADDING_MODE);
841 		return 0;
842 	}
843 	if (nid == NID_rsaEncryption)
844 		return 1;
845 	/* Workaround for some implementation that use a signature OID */
846 	if (OBJ_find_sigid_algs(nid, NULL, &nid2)) {
847 		if (nid2 == NID_rsaEncryption)
848 			return 1;
849 	}
850 	return 0;
851 }
852 #endif
853 
854 /*
855  * Customised RSA item verification routine. This is called when a signature
856  * is encountered requiring special handling. We currently only handle PSS.
857  */
858 static int
859 rsa_item_verify(EVP_MD_CTX *ctx, const ASN1_ITEM *it, void *asn,
860     X509_ALGOR *sigalg, ASN1_BIT_STRING *sig, EVP_PKEY *pkey)
861 {
862 	/* Sanity check: make sure it is PSS */
863 	if (OBJ_obj2nid(sigalg->algorithm) != EVP_PKEY_RSA_PSS) {
864 		RSAerror(RSA_R_UNSUPPORTED_SIGNATURE_TYPE);
865 		return -1;
866 	}
867 	if (rsa_pss_to_ctx(ctx, NULL, sigalg, pkey) > 0) {
868 		/* Carry on */
869 		return 2;
870 	}
871 	return -1;
872 }
873 
874 static int
875 rsa_alg_set_pkcs1_padding(X509_ALGOR *alg)
876 {
877 	return X509_ALGOR_set0_by_nid(alg, NID_rsaEncryption, V_ASN1_NULL, NULL);
878 }
879 
880 static int
881 rsa_alg_set_pss_padding(X509_ALGOR *alg, EVP_PKEY_CTX *pkey_ctx)
882 {
883 	RSA_PSS_PARAMS *pss = NULL;
884 	ASN1_STRING *astr = NULL;
885 	int ret = 0;
886 
887 	if (pkey_ctx == NULL)
888 		goto err;
889 
890 	if ((pss = rsa_ctx_to_pss(pkey_ctx)) == NULL)
891 		goto err;
892 	if ((astr = ASN1_item_pack(pss, &RSA_PSS_PARAMS_it, NULL)) == NULL)
893 		goto err;
894 	if (!X509_ALGOR_set0_by_nid(alg, EVP_PKEY_RSA_PSS, V_ASN1_SEQUENCE, astr))
895 		goto err;
896 	astr = NULL;
897 
898 	ret = 1;
899 
900  err:
901 	ASN1_STRING_free(astr);
902 	RSA_PSS_PARAMS_free(pss);
903 
904 	return ret;
905 }
906 
907 #ifndef OPENSSL_NO_CMS
908 static int
909 rsa_alg_set_oaep_padding(X509_ALGOR *alg, EVP_PKEY_CTX *pkey_ctx)
910 {
911 	const EVP_MD *md, *mgf1md;
912 	RSA_OAEP_PARAMS *oaep = NULL;
913 	ASN1_STRING *astr = NULL;
914 	ASN1_OCTET_STRING *ostr = NULL;
915 	unsigned char *label;
916 	int labellen;
917 	int ret = 0;
918 
919 	if (EVP_PKEY_CTX_get_rsa_oaep_md(pkey_ctx, &md) <= 0)
920 		goto err;
921 	if (EVP_PKEY_CTX_get_rsa_mgf1_md(pkey_ctx, &mgf1md) <= 0)
922 		goto err;
923 	labellen = EVP_PKEY_CTX_get0_rsa_oaep_label(pkey_ctx, &label);
924 	if (labellen < 0)
925 		goto err;
926 
927 	if ((oaep = RSA_OAEP_PARAMS_new()) == NULL)
928 		goto err;
929 
930 	if (!rsa_md_to_algor(md, &oaep->hashFunc))
931 		goto err;
932 	if (!rsa_mgf1md_to_maskGenAlgorithm(mgf1md, &oaep->maskGenFunc))
933 		goto err;
934 
935 	/* XXX - why do we not set oaep->maskHash here? */
936 
937 	if (labellen > 0) {
938 		if ((oaep->pSourceFunc = X509_ALGOR_new()) == NULL)
939 			goto err;
940 		if ((ostr = ASN1_OCTET_STRING_new()) == NULL)
941 			goto err;
942 		if (!ASN1_OCTET_STRING_set(ostr, label, labellen))
943 			goto err;
944 		if (!X509_ALGOR_set0_by_nid(oaep->pSourceFunc, NID_pSpecified,
945 		    V_ASN1_OCTET_STRING, ostr))
946 			goto err;
947 		ostr = NULL;
948 	}
949 
950 	if ((astr = ASN1_item_pack(oaep, &RSA_OAEP_PARAMS_it, NULL)) == NULL)
951 		goto err;
952 	if (!X509_ALGOR_set0_by_nid(alg, NID_rsaesOaep, V_ASN1_SEQUENCE, astr))
953 		goto err;
954 	astr = NULL;
955 
956 	ret = 1;
957 
958  err:
959 	RSA_OAEP_PARAMS_free(oaep);
960 	ASN1_STRING_free(astr);
961 	ASN1_OCTET_STRING_free(ostr);
962 
963 	return ret;
964 }
965 
966 static int
967 rsa_cms_sign(CMS_SignerInfo *si)
968 {
969 	EVP_PKEY_CTX *pkey_ctx;
970 	X509_ALGOR *alg;
971 	int pad_mode = RSA_PKCS1_PADDING;
972 
973 	if ((pkey_ctx = CMS_SignerInfo_get0_pkey_ctx(si)) != NULL) {
974 		if (EVP_PKEY_CTX_get_rsa_padding(pkey_ctx, &pad_mode) <= 0)
975 			return 0;
976 	}
977 
978 	CMS_SignerInfo_get0_algs(si, NULL, NULL, NULL, &alg);
979 	if (pad_mode == RSA_PKCS1_PADDING)
980 		return rsa_alg_set_pkcs1_padding(alg);
981 	if (pad_mode == RSA_PKCS1_PSS_PADDING)
982 		return rsa_alg_set_pss_padding(alg, pkey_ctx);
983 
984 	return 0;
985 }
986 #endif
987 
988 static int
989 rsa_item_sign(EVP_MD_CTX *ctx, const ASN1_ITEM *it, void *asn,
990     X509_ALGOR *alg1, X509_ALGOR *alg2, ASN1_BIT_STRING *sig)
991 {
992 	EVP_PKEY_CTX *pkey_ctx = ctx->pctx;
993 	int pad_mode;
994 
995 	if (EVP_PKEY_CTX_get_rsa_padding(pkey_ctx, &pad_mode) <= 0)
996 		return 0;
997 	if (pad_mode == RSA_PKCS1_PADDING)
998 		return 2;
999 	if (pad_mode == RSA_PKCS1_PSS_PADDING) {
1000 		if (!rsa_alg_set_pss_padding(alg1, pkey_ctx))
1001 			return 0;
1002 		if (alg2 != NULL) {
1003 			if (!rsa_alg_set_pss_padding(alg2, pkey_ctx))
1004 				return 0;
1005 		}
1006 		return 3;
1007 	}
1008 	return 2;
1009 }
1010 
1011 static int
1012 rsa_pkey_check(const EVP_PKEY *pkey)
1013 {
1014 	return RSA_check_key(pkey->pkey.rsa);
1015 }
1016 
1017 #ifndef OPENSSL_NO_CMS
1018 static RSA_OAEP_PARAMS *
1019 rsa_oaep_decode(const X509_ALGOR *alg)
1020 {
1021 	RSA_OAEP_PARAMS *oaep;
1022 
1023 	oaep = ASN1_TYPE_unpack_sequence(&RSA_OAEP_PARAMS_it, alg->parameter);
1024 	if (oaep == NULL)
1025 		return NULL;
1026 
1027 	if (oaep->maskGenFunc != NULL) {
1028 		oaep->maskHash = rsa_mgf1_decode(oaep->maskGenFunc);
1029 		if (oaep->maskHash == NULL) {
1030 			RSA_OAEP_PARAMS_free(oaep);
1031 			return NULL;
1032 		}
1033 	}
1034 	return oaep;
1035 }
1036 
1037 static int
1038 rsa_cms_decrypt(CMS_RecipientInfo *ri)
1039 {
1040 	EVP_PKEY_CTX *pkctx;
1041 	X509_ALGOR *cmsalg;
1042 	int nid;
1043 	int rv = -1;
1044 	unsigned char *label = NULL;
1045 	int labellen = 0;
1046 	const EVP_MD *mgf1md = NULL, *md = NULL;
1047 	RSA_OAEP_PARAMS *oaep;
1048 
1049 	pkctx = CMS_RecipientInfo_get0_pkey_ctx(ri);
1050 	if (pkctx == NULL)
1051 		return 0;
1052 	if (!CMS_RecipientInfo_ktri_get0_algs(ri, NULL, NULL, &cmsalg))
1053 		return -1;
1054 	nid = OBJ_obj2nid(cmsalg->algorithm);
1055 	if (nid == NID_rsaEncryption)
1056 		return 1;
1057 	if (nid != NID_rsaesOaep) {
1058 		RSAerror(RSA_R_UNSUPPORTED_ENCRYPTION_TYPE);
1059 		return -1;
1060 	}
1061 	/* Decode OAEP parameters */
1062 	oaep = rsa_oaep_decode(cmsalg);
1063 
1064 	if (oaep == NULL) {
1065 		RSAerror(RSA_R_INVALID_OAEP_PARAMETERS);
1066 		goto err;
1067 	}
1068 
1069 	mgf1md = rsa_algor_to_md(oaep->maskHash);
1070 	if (mgf1md == NULL)
1071 		goto err;
1072 	md = rsa_algor_to_md(oaep->hashFunc);
1073 	if (md == NULL)
1074 		goto err;
1075 
1076 	if (oaep->pSourceFunc != NULL) {
1077 		X509_ALGOR *plab = oaep->pSourceFunc;
1078 
1079 		if (OBJ_obj2nid(plab->algorithm) != NID_pSpecified) {
1080 			RSAerror(RSA_R_UNSUPPORTED_LABEL_SOURCE);
1081 			goto err;
1082 		}
1083 		if (plab->parameter->type != V_ASN1_OCTET_STRING) {
1084 			RSAerror(RSA_R_INVALID_LABEL);
1085 			goto err;
1086 		}
1087 
1088 		label = plab->parameter->value.octet_string->data;
1089 
1090 		/* Stop label being freed when OAEP parameters are freed */
1091 		/* XXX - this leaks label on error... */
1092 		plab->parameter->value.octet_string->data = NULL;
1093 		labellen = plab->parameter->value.octet_string->length;
1094 	}
1095 
1096 	if (EVP_PKEY_CTX_set_rsa_padding(pkctx, RSA_PKCS1_OAEP_PADDING) <= 0)
1097 		goto err;
1098 	if (EVP_PKEY_CTX_set_rsa_oaep_md(pkctx, md) <= 0)
1099 		goto err;
1100 	if (EVP_PKEY_CTX_set_rsa_mgf1_md(pkctx, mgf1md) <= 0)
1101 		goto err;
1102 	if (EVP_PKEY_CTX_set0_rsa_oaep_label(pkctx, label, labellen) <= 0)
1103 		goto err;
1104 
1105 	rv = 1;
1106 
1107  err:
1108 	RSA_OAEP_PARAMS_free(oaep);
1109 	return rv;
1110 }
1111 
1112 static int
1113 rsa_cms_encrypt(CMS_RecipientInfo *ri)
1114 {
1115 	X509_ALGOR *alg;
1116 	EVP_PKEY_CTX *pkey_ctx;
1117 	int pad_mode = RSA_PKCS1_PADDING;
1118 
1119 	if ((pkey_ctx = CMS_RecipientInfo_get0_pkey_ctx(ri)) != NULL) {
1120 		if (EVP_PKEY_CTX_get_rsa_padding(pkey_ctx, &pad_mode) <= 0)
1121 			return 0;
1122 	}
1123 
1124 	if (!CMS_RecipientInfo_ktri_get0_algs(ri, NULL, NULL, &alg))
1125 		return 0;
1126 	if (pad_mode == RSA_PKCS1_PADDING)
1127 		return rsa_alg_set_pkcs1_padding(alg);
1128 	if (pad_mode == RSA_PKCS1_OAEP_PADDING)
1129 		return rsa_alg_set_oaep_padding(alg, pkey_ctx);
1130 
1131 	return 0;
1132 }
1133 #endif
1134 
1135 const EVP_PKEY_ASN1_METHOD rsa_asn1_meths[] = {
1136 	{
1137 		.pkey_id = EVP_PKEY_RSA,
1138 		.pkey_base_id = EVP_PKEY_RSA,
1139 		.pkey_flags = ASN1_PKEY_SIGPARAM_NULL,
1140 
1141 		.pem_str = "RSA",
1142 		.info = "OpenSSL RSA method",
1143 
1144 		.pub_decode = rsa_pub_decode,
1145 		.pub_encode = rsa_pub_encode,
1146 		.pub_cmp = rsa_pub_cmp,
1147 		.pub_print = rsa_pub_print,
1148 
1149 		.priv_decode = rsa_priv_decode,
1150 		.priv_encode = rsa_priv_encode,
1151 		.priv_print = rsa_priv_print,
1152 
1153 		.pkey_size = rsa_size,
1154 		.pkey_bits = rsa_bits,
1155 		.pkey_security_bits = rsa_security_bits,
1156 
1157 		.sig_print = rsa_sig_print,
1158 
1159 		.pkey_free = rsa_free,
1160 		.pkey_ctrl = rsa_pkey_ctrl,
1161 		.old_priv_decode = old_rsa_priv_decode,
1162 		.old_priv_encode = old_rsa_priv_encode,
1163 		.item_verify = rsa_item_verify,
1164 		.item_sign = rsa_item_sign,
1165 
1166 		.pkey_check = rsa_pkey_check,
1167 	},
1168 
1169 	{
1170 		.pkey_id = EVP_PKEY_RSA2,
1171 		.pkey_base_id = EVP_PKEY_RSA,
1172 		.pkey_flags = ASN1_PKEY_ALIAS,
1173 
1174 		.pkey_check = rsa_pkey_check,
1175 	},
1176 };
1177 
1178 const EVP_PKEY_ASN1_METHOD rsa_pss_asn1_meth = {
1179 	.pkey_id = EVP_PKEY_RSA_PSS,
1180 	.pkey_base_id = EVP_PKEY_RSA_PSS,
1181 	.pkey_flags = ASN1_PKEY_SIGPARAM_NULL,
1182 
1183 	.pem_str = "RSA-PSS",
1184 	.info = "OpenSSL RSA-PSS method",
1185 
1186 	.pub_decode = rsa_pub_decode,
1187 	.pub_encode = rsa_pub_encode,
1188 	.pub_cmp = rsa_pub_cmp,
1189 	.pub_print = rsa_pub_print,
1190 
1191 	.priv_decode = rsa_priv_decode,
1192 	.priv_encode = rsa_priv_encode,
1193 	.priv_print = rsa_priv_print,
1194 
1195 	.pkey_size = rsa_size,
1196 	.pkey_bits = rsa_bits,
1197 	.pkey_security_bits = rsa_security_bits,
1198 
1199 	.sig_print = rsa_sig_print,
1200 
1201 	.pkey_free = rsa_free,
1202 	.pkey_ctrl = rsa_pkey_ctrl,
1203 	.item_verify = rsa_item_verify,
1204 	.item_sign = rsa_item_sign
1205 };
1206