1 /* 2 * Copyright 2005-2021 The OpenSSL Project Authors. All Rights Reserved. 3 * 4 * Licensed under the Apache License 2.0 (the "License"). You may not use 5 * this file except in compliance with the License. You can obtain a copy 6 * in the file LICENSE in the source distribution or at 7 * https://www.openssl.org/source/license.html 8 */ 9 10 /* 11 * RSA low level APIs are deprecated for public use, but still ok for 12 * internal use. 13 */ 14 #include "internal/deprecated.h" 15 16 #include <stdio.h> 17 #include "internal/cryptlib.h" 18 #include <openssl/bn.h> 19 #include <openssl/rsa.h> 20 #include <openssl/evp.h> 21 #include <openssl/rand.h> 22 #include <openssl/sha.h> 23 #include "rsa_local.h" 24 25 static const unsigned char zeroes[] = { 0, 0, 0, 0, 0, 0, 0, 0 }; 26 27 #if defined(_MSC_VER) && defined(_ARM_) 28 # pragma optimize("g", off) 29 #endif 30 31 int RSA_verify_PKCS1_PSS(RSA *rsa, const unsigned char *mHash, 32 const EVP_MD *Hash, const unsigned char *EM, 33 int sLen) 34 { 35 return RSA_verify_PKCS1_PSS_mgf1(rsa, mHash, Hash, NULL, EM, sLen); 36 } 37 38 int RSA_verify_PKCS1_PSS_mgf1(RSA *rsa, const unsigned char *mHash, 39 const EVP_MD *Hash, const EVP_MD *mgf1Hash, 40 const unsigned char *EM, int sLen) 41 { 42 int i; 43 int ret = 0; 44 int hLen, maskedDBLen, MSBits, emLen; 45 const unsigned char *H; 46 unsigned char *DB = NULL; 47 EVP_MD_CTX *ctx = EVP_MD_CTX_new(); 48 unsigned char H_[EVP_MAX_MD_SIZE]; 49 50 if (ctx == NULL) 51 goto err; 52 53 if (mgf1Hash == NULL) 54 mgf1Hash = Hash; 55 56 hLen = EVP_MD_get_size(Hash); 57 if (hLen < 0) 58 goto err; 59 /*- 60 * Negative sLen has special meanings: 61 * -1 sLen == hLen 62 * -2 salt length is autorecovered from signature 63 * -3 salt length is maximized 64 * -N reserved 65 */ 66 if (sLen == RSA_PSS_SALTLEN_DIGEST) { 67 sLen = hLen; 68 } else if (sLen < RSA_PSS_SALTLEN_MAX) { 69 ERR_raise(ERR_LIB_RSA, RSA_R_SLEN_CHECK_FAILED); 70 goto err; 71 } 72 73 MSBits = (BN_num_bits(rsa->n) - 1) & 0x7; 74 emLen = RSA_size(rsa); 75 if (EM[0] & (0xFF << MSBits)) { 76 ERR_raise(ERR_LIB_RSA, RSA_R_FIRST_OCTET_INVALID); 77 goto err; 78 } 79 if (MSBits == 0) { 80 EM++; 81 emLen--; 82 } 83 if (emLen < hLen + 2) { 84 ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE); 85 goto err; 86 } 87 if (sLen == RSA_PSS_SALTLEN_MAX) { 88 sLen = emLen - hLen - 2; 89 } else if (sLen > emLen - hLen - 2) { /* sLen can be small negative */ 90 ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE); 91 goto err; 92 } 93 if (EM[emLen - 1] != 0xbc) { 94 ERR_raise(ERR_LIB_RSA, RSA_R_LAST_OCTET_INVALID); 95 goto err; 96 } 97 maskedDBLen = emLen - hLen - 1; 98 H = EM + maskedDBLen; 99 DB = OPENSSL_malloc(maskedDBLen); 100 if (DB == NULL) { 101 ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE); 102 goto err; 103 } 104 if (PKCS1_MGF1(DB, maskedDBLen, H, hLen, mgf1Hash) < 0) 105 goto err; 106 for (i = 0; i < maskedDBLen; i++) 107 DB[i] ^= EM[i]; 108 if (MSBits) 109 DB[0] &= 0xFF >> (8 - MSBits); 110 for (i = 0; DB[i] == 0 && i < (maskedDBLen - 1); i++) ; 111 if (DB[i++] != 0x1) { 112 ERR_raise(ERR_LIB_RSA, RSA_R_SLEN_RECOVERY_FAILED); 113 goto err; 114 } 115 if (sLen != RSA_PSS_SALTLEN_AUTO && (maskedDBLen - i) != sLen) { 116 ERR_raise_data(ERR_LIB_RSA, RSA_R_SLEN_CHECK_FAILED, 117 "expected: %d retrieved: %d", sLen, 118 maskedDBLen - i); 119 goto err; 120 } 121 if (!EVP_DigestInit_ex(ctx, Hash, NULL) 122 || !EVP_DigestUpdate(ctx, zeroes, sizeof(zeroes)) 123 || !EVP_DigestUpdate(ctx, mHash, hLen)) 124 goto err; 125 if (maskedDBLen - i) { 126 if (!EVP_DigestUpdate(ctx, DB + i, maskedDBLen - i)) 127 goto err; 128 } 129 if (!EVP_DigestFinal_ex(ctx, H_, NULL)) 130 goto err; 131 if (memcmp(H_, H, hLen)) { 132 ERR_raise(ERR_LIB_RSA, RSA_R_BAD_SIGNATURE); 133 ret = 0; 134 } else { 135 ret = 1; 136 } 137 138 err: 139 OPENSSL_free(DB); 140 EVP_MD_CTX_free(ctx); 141 142 return ret; 143 144 } 145 146 int RSA_padding_add_PKCS1_PSS(RSA *rsa, unsigned char *EM, 147 const unsigned char *mHash, 148 const EVP_MD *Hash, int sLen) 149 { 150 return RSA_padding_add_PKCS1_PSS_mgf1(rsa, EM, mHash, Hash, NULL, sLen); 151 } 152 153 int RSA_padding_add_PKCS1_PSS_mgf1(RSA *rsa, unsigned char *EM, 154 const unsigned char *mHash, 155 const EVP_MD *Hash, const EVP_MD *mgf1Hash, 156 int sLen) 157 { 158 int i; 159 int ret = 0; 160 int hLen, maskedDBLen, MSBits, emLen; 161 unsigned char *H, *salt = NULL, *p; 162 EVP_MD_CTX *ctx = NULL; 163 164 if (mgf1Hash == NULL) 165 mgf1Hash = Hash; 166 167 hLen = EVP_MD_get_size(Hash); 168 if (hLen < 0) 169 goto err; 170 /*- 171 * Negative sLen has special meanings: 172 * -1 sLen == hLen 173 * -2 salt length is maximized 174 * -3 same as above (on signing) 175 * -N reserved 176 */ 177 if (sLen == RSA_PSS_SALTLEN_DIGEST) { 178 sLen = hLen; 179 } else if (sLen == RSA_PSS_SALTLEN_MAX_SIGN) { 180 sLen = RSA_PSS_SALTLEN_MAX; 181 } else if (sLen < RSA_PSS_SALTLEN_MAX) { 182 ERR_raise(ERR_LIB_RSA, RSA_R_SLEN_CHECK_FAILED); 183 goto err; 184 } 185 186 MSBits = (BN_num_bits(rsa->n) - 1) & 0x7; 187 emLen = RSA_size(rsa); 188 if (MSBits == 0) { 189 *EM++ = 0; 190 emLen--; 191 } 192 if (emLen < hLen + 2) { 193 ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE); 194 goto err; 195 } 196 if (sLen == RSA_PSS_SALTLEN_MAX) { 197 sLen = emLen - hLen - 2; 198 } else if (sLen > emLen - hLen - 2) { 199 ERR_raise(ERR_LIB_RSA, RSA_R_DATA_TOO_LARGE_FOR_KEY_SIZE); 200 goto err; 201 } 202 if (sLen > 0) { 203 salt = OPENSSL_malloc(sLen); 204 if (salt == NULL) { 205 ERR_raise(ERR_LIB_RSA, ERR_R_MALLOC_FAILURE); 206 goto err; 207 } 208 if (RAND_bytes_ex(rsa->libctx, salt, sLen, 0) <= 0) 209 goto err; 210 } 211 maskedDBLen = emLen - hLen - 1; 212 H = EM + maskedDBLen; 213 ctx = EVP_MD_CTX_new(); 214 if (ctx == NULL) 215 goto err; 216 if (!EVP_DigestInit_ex(ctx, Hash, NULL) 217 || !EVP_DigestUpdate(ctx, zeroes, sizeof(zeroes)) 218 || !EVP_DigestUpdate(ctx, mHash, hLen)) 219 goto err; 220 if (sLen && !EVP_DigestUpdate(ctx, salt, sLen)) 221 goto err; 222 if (!EVP_DigestFinal_ex(ctx, H, NULL)) 223 goto err; 224 225 /* Generate dbMask in place then perform XOR on it */ 226 if (PKCS1_MGF1(EM, maskedDBLen, H, hLen, mgf1Hash)) 227 goto err; 228 229 p = EM; 230 231 /* 232 * Initial PS XORs with all zeroes which is a NOP so just update pointer. 233 * Note from a test above this value is guaranteed to be non-negative. 234 */ 235 p += emLen - sLen - hLen - 2; 236 *p++ ^= 0x1; 237 if (sLen > 0) { 238 for (i = 0; i < sLen; i++) 239 *p++ ^= salt[i]; 240 } 241 if (MSBits) 242 EM[0] &= 0xFF >> (8 - MSBits); 243 244 /* H is already in place so just set final 0xbc */ 245 246 EM[emLen - 1] = 0xbc; 247 248 ret = 1; 249 250 err: 251 EVP_MD_CTX_free(ctx); 252 OPENSSL_clear_free(salt, (size_t)sLen); /* salt != NULL implies sLen > 0 */ 253 254 return ret; 255 256 } 257 258 /* 259 * The defaults for PSS restrictions are defined in RFC 8017, A.2.3 RSASSA-PSS 260 * (https://tools.ietf.org/html/rfc8017#appendix-A.2.3): 261 * 262 * If the default values of the hashAlgorithm, maskGenAlgorithm, and 263 * trailerField fields of RSASSA-PSS-params are used, then the algorithm 264 * identifier will have the following value: 265 * 266 * rSASSA-PSS-Default-Identifier RSASSA-AlgorithmIdentifier ::= { 267 * algorithm id-RSASSA-PSS, 268 * parameters RSASSA-PSS-params : { 269 * hashAlgorithm sha1, 270 * maskGenAlgorithm mgf1SHA1, 271 * saltLength 20, 272 * trailerField trailerFieldBC 273 * } 274 * } 275 * 276 * RSASSA-AlgorithmIdentifier ::= AlgorithmIdentifier { 277 * {PKCS1Algorithms} 278 * } 279 */ 280 static const RSA_PSS_PARAMS_30 default_RSASSA_PSS_params = { 281 NID_sha1, /* default hashAlgorithm */ 282 { 283 NID_mgf1, /* default maskGenAlgorithm */ 284 NID_sha1 /* default MGF1 hash */ 285 }, 286 20, /* default saltLength */ 287 1 /* default trailerField (0xBC) */ 288 }; 289 290 int ossl_rsa_pss_params_30_set_defaults(RSA_PSS_PARAMS_30 *rsa_pss_params) 291 { 292 if (rsa_pss_params == NULL) 293 return 0; 294 *rsa_pss_params = default_RSASSA_PSS_params; 295 return 1; 296 } 297 298 int ossl_rsa_pss_params_30_is_unrestricted(const RSA_PSS_PARAMS_30 *rsa_pss_params) 299 { 300 static RSA_PSS_PARAMS_30 pss_params_cmp = { 0, }; 301 302 return rsa_pss_params == NULL 303 || memcmp(rsa_pss_params, &pss_params_cmp, 304 sizeof(*rsa_pss_params)) == 0; 305 } 306 307 int ossl_rsa_pss_params_30_copy(RSA_PSS_PARAMS_30 *to, 308 const RSA_PSS_PARAMS_30 *from) 309 { 310 memcpy(to, from, sizeof(*to)); 311 return 1; 312 } 313 314 int ossl_rsa_pss_params_30_set_hashalg(RSA_PSS_PARAMS_30 *rsa_pss_params, 315 int hashalg_nid) 316 { 317 if (rsa_pss_params == NULL) 318 return 0; 319 rsa_pss_params->hash_algorithm_nid = hashalg_nid; 320 return 1; 321 } 322 323 int ossl_rsa_pss_params_30_set_maskgenalg(RSA_PSS_PARAMS_30 *rsa_pss_params, 324 int maskgenalg_nid) 325 { 326 if (rsa_pss_params == NULL) 327 return 0; 328 rsa_pss_params->mask_gen.algorithm_nid = maskgenalg_nid; 329 return 1; 330 } 331 332 int ossl_rsa_pss_params_30_set_maskgenhashalg(RSA_PSS_PARAMS_30 *rsa_pss_params, 333 int maskgenhashalg_nid) 334 { 335 if (rsa_pss_params == NULL) 336 return 0; 337 rsa_pss_params->mask_gen.hash_algorithm_nid = maskgenhashalg_nid; 338 return 1; 339 } 340 341 int ossl_rsa_pss_params_30_set_saltlen(RSA_PSS_PARAMS_30 *rsa_pss_params, 342 int saltlen) 343 { 344 if (rsa_pss_params == NULL) 345 return 0; 346 rsa_pss_params->salt_len = saltlen; 347 return 1; 348 } 349 350 int ossl_rsa_pss_params_30_set_trailerfield(RSA_PSS_PARAMS_30 *rsa_pss_params, 351 int trailerfield) 352 { 353 if (rsa_pss_params == NULL) 354 return 0; 355 rsa_pss_params->trailer_field = trailerfield; 356 return 1; 357 } 358 359 int ossl_rsa_pss_params_30_hashalg(const RSA_PSS_PARAMS_30 *rsa_pss_params) 360 { 361 if (rsa_pss_params == NULL) 362 return default_RSASSA_PSS_params.hash_algorithm_nid; 363 return rsa_pss_params->hash_algorithm_nid; 364 } 365 366 int ossl_rsa_pss_params_30_maskgenalg(const RSA_PSS_PARAMS_30 *rsa_pss_params) 367 { 368 if (rsa_pss_params == NULL) 369 return default_RSASSA_PSS_params.mask_gen.algorithm_nid; 370 return rsa_pss_params->mask_gen.algorithm_nid; 371 } 372 373 int ossl_rsa_pss_params_30_maskgenhashalg(const RSA_PSS_PARAMS_30 *rsa_pss_params) 374 { 375 if (rsa_pss_params == NULL) 376 return default_RSASSA_PSS_params.hash_algorithm_nid; 377 return rsa_pss_params->mask_gen.hash_algorithm_nid; 378 } 379 380 int ossl_rsa_pss_params_30_saltlen(const RSA_PSS_PARAMS_30 *rsa_pss_params) 381 { 382 if (rsa_pss_params == NULL) 383 return default_RSASSA_PSS_params.salt_len; 384 return rsa_pss_params->salt_len; 385 } 386 387 int ossl_rsa_pss_params_30_trailerfield(const RSA_PSS_PARAMS_30 *rsa_pss_params) 388 { 389 if (rsa_pss_params == NULL) 390 return default_RSASSA_PSS_params.trailer_field; 391 return rsa_pss_params->trailer_field; 392 } 393 394 #if defined(_MSC_VER) 395 # pragma optimize("",on) 396 #endif 397