1 /* This Source Code Form is subject to the terms of the Mozilla Public
2  * License, v. 2.0. If a copy of the MPL was not distributed with this
3  * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4 
5 /*
6  * RSA PKCS#1 v2.1 (RFC 3447) operations
7  */
8 
9 #ifdef FREEBL_NO_DEPEND
10 #include "stubs.h"
11 #endif
12 
13 #include "secerr.h"
14 
15 #include "blapi.h"
16 #include "secitem.h"
17 #include "blapii.h"
18 
/* Sizes and octet values for the PKCS #1 v1.5 block layout
 *   0x00 || BT || Pad || 0x00 || ActualData
 * that rsa_FormatOneBlock builds. */
/* Minimum number of pad octets between the block type and the separator. */
#define RSA_BLOCK_MIN_PAD_LEN 8
/* First octet of every formatted block. */
#define RSA_BLOCK_FIRST_OCTET 0x00
/* Pad octet used for private-key (block type 1) blocks. */
#define RSA_BLOCK_PRIVATE_PAD_OCTET 0xff
/* Separator octet that terminates the pad; it is also the value the random
 * public-key pad must never contain. */
#define RSA_BLOCK_AFTER_PAD_OCTET 0x00

/*
 * RSA block types
 *
 * The values of RSA_BlockPrivate and RSA_BlockPublic are fixed.
 * The value of RSA_BlockRaw isn't fixed by definition, but we are keeping
 * the value that NSS has been using in the past.
 * The numeric value of the private/public types is written verbatim as the
 * second octet (BT) of a formatted block.
 */
typedef enum {
    RSA_BlockPrivate = 1, /* pad for a private-key operation */
    RSA_BlockPublic = 2,  /* pad for a public-key operation */
    RSA_BlockRaw = 4      /* simply justify the block appropriately */
} RSA_BlockType;

/* Needed for RSA-PSS functions.
 * Presumably the eight zero octets that prefix M' in EMSA-PSS
 * (RFC 3447, section 9.1.1, step 5); the PSS code using it is outside
 * this view. */
static const unsigned char eightZeros[] = { 0, 0, 0, 0, 0, 0, 0, 0 };
39 
/*
 * Branch-free equality test for two octets.
 *
 * Yields 1 when a == b and 0 otherwise, with no data-dependent branches.
 * For multi-octet ranges use constantTimeCompare instead.
 *
 * Why it works: when a != b, one of (a - b) and (b - a) is negative, and
 * OR-ing a nonzero difference with its negation sets every bit from the
 * difference's lowest set bit upward. Since |a - b| <= 255 that always
 * includes bit 7, so after complementing, bit 7 is set exactly when a == b.
 */
static unsigned char
constantTimeEQ8(unsigned char a, unsigned char b)
{
    unsigned int spread = (unsigned int)(a - b) | (unsigned int)(b - a);
    unsigned char low = (unsigned char)~spread;

    return (unsigned char)(low >> 7);
}
51 
/*
 * Branch-free comparison of two octet strings of equal length.
 *
 * Returns 1 if the first len octets of a and b are identical, 0 otherwise.
 * Every octet is visited regardless of where a mismatch occurs, so the
 * running time depends only on len. A len of 0 compares equal.
 */
static unsigned char
constantTimeCompare(const unsigned char *a,
                    const unsigned char *b,
                    unsigned int len)
{
    unsigned char diffBits = 0;
    unsigned int pos;

    for (pos = 0; pos < len; pos++) {
        diffBits |= (unsigned char)(a[pos] ^ b[pos]);
    }
    /* Any differing bit anywhere leaves diffBits nonzero. */
    return constantTimeEQ8(diffBits, 0x00);
}
67 
/*
 * Branch-free select.
 *
 * Evaluates to a when c is 1 and to b when c is 0. Callers should pass
 * only 0 or 1 for c; any other value yields a bitwise blend of a and b.
 */
static unsigned int
constantTimeCondition(unsigned int c,
                      unsigned int a,
                      unsigned int b)
{
    /* 0u - c is all ones when c == 1 and all zeros when c == 0. */
    unsigned int pickA = 0u - c;

    return (pickA & a) | (~pickA & b);
}
79 
80 static unsigned int
rsa_modulusLen(SECItem * modulus)81 rsa_modulusLen(SECItem *modulus)
82 {
83     unsigned char byteZero = modulus->data[0];
84     unsigned int modLen = modulus->len - !byteZero;
85     return modLen;
86 }
87 
88 /*
89  * Format one block of data for public/private key encryption using
90  * the rules defined in PKCS #1.
91  */
92 static unsigned char *
rsa_FormatOneBlock(unsigned modulusLen,RSA_BlockType blockType,SECItem * data)93 rsa_FormatOneBlock(unsigned modulusLen,
94                    RSA_BlockType blockType,
95                    SECItem *data)
96 {
97     unsigned char *block;
98     unsigned char *bp;
99     int padLen;
100     int i, j;
101     SECStatus rv;
102 
103     block = (unsigned char *)PORT_Alloc(modulusLen);
104     if (block == NULL)
105         return NULL;
106 
107     bp = block;
108 
109     /*
110      * All RSA blocks start with two octets:
111      *  0x00 || BlockType
112      */
113     *bp++ = RSA_BLOCK_FIRST_OCTET;
114     *bp++ = (unsigned char)blockType;
115 
116     switch (blockType) {
117 
118         /*
119        * Blocks intended for private-key operation.
120        */
121         case RSA_BlockPrivate: /* preferred method */
122             /*
123          * 0x00 || BT || Pad || 0x00 || ActualData
124          *   1      1   padLen    1      data->len
125          * Pad is either all 0x00 or all 0xff bytes, depending on blockType.
126          */
127             padLen = modulusLen - data->len - 3;
128             PORT_Assert(padLen >= RSA_BLOCK_MIN_PAD_LEN);
129             if (padLen < RSA_BLOCK_MIN_PAD_LEN) {
130                 PORT_Free(block);
131                 return NULL;
132             }
133             PORT_Memset(bp, RSA_BLOCK_PRIVATE_PAD_OCTET, padLen);
134             bp += padLen;
135             *bp++ = RSA_BLOCK_AFTER_PAD_OCTET;
136             PORT_Memcpy(bp, data->data, data->len);
137             break;
138 
139         /*
140          * Blocks intended for public-key operation.
141          */
142         case RSA_BlockPublic:
143             /*
144              * 0x00 || BT || Pad || 0x00 || ActualData
145              *   1      1   padLen    1      data->len
146              * Pad is all non-zero random bytes.
147              *
148              * Build the block left to right.
149              * Fill the entire block from Pad to the end with random bytes.
150              * Use the bytes after Pad as a supply of extra random bytes from
151              * which to find replacements for the zero bytes in Pad.
152              * If we need more than that, refill the bytes after Pad with
153              * new random bytes as necessary.
154              */
155             padLen = modulusLen - (data->len + 3);
156             PORT_Assert(padLen >= RSA_BLOCK_MIN_PAD_LEN);
157             if (padLen < RSA_BLOCK_MIN_PAD_LEN) {
158                 PORT_Free(block);
159                 return NULL;
160             }
161             j = modulusLen - 2;
162             rv = RNG_GenerateGlobalRandomBytes(bp, j);
163             if (rv == SECSuccess) {
164                 for (i = 0; i < padLen;) {
165                     unsigned char repl;
166                     /* Pad with non-zero random data. */
167                     if (bp[i] != RSA_BLOCK_AFTER_PAD_OCTET) {
168                         ++i;
169                         continue;
170                     }
171                     if (j <= padLen) {
172                         rv = RNG_GenerateGlobalRandomBytes(bp + padLen,
173                                                            modulusLen - (2 + padLen));
174                         if (rv != SECSuccess)
175                             break;
176                         j = modulusLen - 2;
177                     }
178                     do {
179                         repl = bp[--j];
180                     } while (repl == RSA_BLOCK_AFTER_PAD_OCTET && j > padLen);
181                     if (repl != RSA_BLOCK_AFTER_PAD_OCTET) {
182                         bp[i++] = repl;
183                     }
184                 }
185             }
186             if (rv != SECSuccess) {
187                 PORT_Free(block);
188                 PORT_SetError(SEC_ERROR_LIBRARY_FAILURE);
189                 return NULL;
190             }
191             bp += padLen;
192             *bp++ = RSA_BLOCK_AFTER_PAD_OCTET;
193             PORT_Memcpy(bp, data->data, data->len);
194             break;
195 
196         default:
197             PORT_Assert(0);
198             PORT_Free(block);
199             return NULL;
200     }
201 
202     return block;
203 }
204 
205 static SECStatus
rsa_FormatBlock(SECItem * result,unsigned modulusLen,RSA_BlockType blockType,SECItem * data)206 rsa_FormatBlock(SECItem *result,
207                 unsigned modulusLen,
208                 RSA_BlockType blockType,
209                 SECItem *data)
210 {
211     switch (blockType) {
212         case RSA_BlockPrivate:
213         case RSA_BlockPublic:
214             /*
215              * 0x00 || BT || Pad || 0x00 || ActualData
216              *
217              * The "3" below is the first octet + the second octet + the 0x00
218              * octet that always comes just before the ActualData.
219              */
220             PORT_Assert(data->len <= (modulusLen - (3 + RSA_BLOCK_MIN_PAD_LEN)));
221 
222             result->data = rsa_FormatOneBlock(modulusLen, blockType, data);
223             if (result->data == NULL) {
224                 result->len = 0;
225                 return SECFailure;
226             }
227             result->len = modulusLen;
228 
229             break;
230 
231         case RSA_BlockRaw:
232             /*
233              * Pad || ActualData
234              * Pad is zeros. The application is responsible for recovering
235              * the actual data.
236              */
237             if (data->len > modulusLen) {
238                 return SECFailure;
239             }
240             result->data = (unsigned char *)PORT_ZAlloc(modulusLen);
241             result->len = modulusLen;
242             PORT_Memcpy(result->data + (modulusLen - data->len),
243                         data->data, data->len);
244             break;
245 
246         default:
247             PORT_Assert(0);
248             result->data = NULL;
249             result->len = 0;
250             return SECFailure;
251     }
252 
253     return SECSuccess;
254 }
255 
256 /*
257  * Mask generation function MGF1 as defined in PKCS #1 v2.1 / RFC 3447.
258  */
259 static SECStatus
MGF1(HASH_HashType hashAlg,unsigned char * mask,unsigned int maskLen,const unsigned char * mgfSeed,unsigned int mgfSeedLen)260 MGF1(HASH_HashType hashAlg,
261      unsigned char *mask,
262      unsigned int maskLen,
263      const unsigned char *mgfSeed,
264      unsigned int mgfSeedLen)
265 {
266     unsigned int digestLen;
267     PRUint32 counter;
268     PRUint32 rounds;
269     unsigned char *tempHash;
270     unsigned char *temp;
271     const SECHashObject *hash;
272     void *hashContext;
273     unsigned char C[4];
274 
275     hash = HASH_GetRawHashObject(hashAlg);
276     if (hash == NULL)
277         return SECFailure;
278 
279     hashContext = (*hash->create)();
280     rounds = (maskLen + hash->length - 1) / hash->length;
281     for (counter = 0; counter < rounds; counter++) {
282         C[0] = (unsigned char)((counter >> 24) & 0xff);
283         C[1] = (unsigned char)((counter >> 16) & 0xff);
284         C[2] = (unsigned char)((counter >> 8) & 0xff);
285         C[3] = (unsigned char)(counter & 0xff);
286 
287         /* This could be optimized when the clone functions in
288          * rawhash.c are implemented. */
289         (*hash->begin)(hashContext);
290         (*hash->update)(hashContext, mgfSeed, mgfSeedLen);
291         (*hash->update)(hashContext, C, sizeof C);
292 
293         tempHash = mask + counter * hash->length;
294         if (counter != (rounds - 1)) {
295             (*hash->end)(hashContext, tempHash, &digestLen, hash->length);
296         } else { /* we're in the last round and need to cut the hash */
297             temp = (unsigned char *)PORT_Alloc(hash->length);
298             (*hash->end)(hashContext, temp, &digestLen, hash->length);
299             PORT_Memcpy(tempHash, temp, maskLen - counter * hash->length);
300             PORT_Free(temp);
301         }
302     }
303     (*hash->destroy)(hashContext, PR_TRUE);
304 
305     return SECSuccess;
306 }
307 
308 /* XXX Doesn't set error code */
309 SECStatus
RSA_SignRaw(RSAPrivateKey * key,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * data,unsigned int dataLen)310 RSA_SignRaw(RSAPrivateKey *key,
311             unsigned char *output,
312             unsigned int *outputLen,
313             unsigned int maxOutputLen,
314             const unsigned char *data,
315             unsigned int dataLen)
316 {
317     SECStatus rv = SECSuccess;
318     unsigned int modulusLen = rsa_modulusLen(&key->modulus);
319     SECItem formatted;
320     SECItem unformatted;
321 
322     if (maxOutputLen < modulusLen)
323         return SECFailure;
324 
325     unformatted.len = dataLen;
326     unformatted.data = (unsigned char *)data;
327     formatted.data = NULL;
328     rv = rsa_FormatBlock(&formatted, modulusLen, RSA_BlockRaw, &unformatted);
329     if (rv != SECSuccess)
330         goto done;
331 
332     rv = RSA_PrivateKeyOpDoubleChecked(key, output, formatted.data);
333     *outputLen = modulusLen;
334 
335 done:
336     if (formatted.data != NULL)
337         PORT_ZFree(formatted.data, modulusLen);
338     return rv;
339 }
340 
341 /* XXX Doesn't set error code */
342 SECStatus
RSA_CheckSignRaw(RSAPublicKey * key,const unsigned char * sig,unsigned int sigLen,const unsigned char * hash,unsigned int hashLen)343 RSA_CheckSignRaw(RSAPublicKey *key,
344                  const unsigned char *sig,
345                  unsigned int sigLen,
346                  const unsigned char *hash,
347                  unsigned int hashLen)
348 {
349     SECStatus rv;
350     unsigned int modulusLen = rsa_modulusLen(&key->modulus);
351     unsigned char *buffer;
352 
353     if (sigLen != modulusLen)
354         goto failure;
355     if (hashLen > modulusLen)
356         goto failure;
357 
358     buffer = (unsigned char *)PORT_Alloc(modulusLen + 1);
359     if (!buffer)
360         goto failure;
361 
362     rv = RSA_PublicKeyOp(key, buffer, sig);
363     if (rv != SECSuccess)
364         goto loser;
365 
366     /*
367      * make sure we get the same results
368      */
369     /* XXX(rsleevi): Constant time */
370     /* NOTE: should we verify the leading zeros? */
371     if (PORT_Memcmp(buffer + (modulusLen - hashLen), hash, hashLen) != 0)
372         goto loser;
373 
374     PORT_Free(buffer);
375     return SECSuccess;
376 
377 loser:
378     PORT_Free(buffer);
379 failure:
380     return SECFailure;
381 }
382 
383 /* XXX Doesn't set error code */
384 SECStatus
RSA_CheckSignRecoverRaw(RSAPublicKey * key,unsigned char * data,unsigned int * dataLen,unsigned int maxDataLen,const unsigned char * sig,unsigned int sigLen)385 RSA_CheckSignRecoverRaw(RSAPublicKey *key,
386                         unsigned char *data,
387                         unsigned int *dataLen,
388                         unsigned int maxDataLen,
389                         const unsigned char *sig,
390                         unsigned int sigLen)
391 {
392     SECStatus rv;
393     unsigned int modulusLen = rsa_modulusLen(&key->modulus);
394 
395     if (sigLen != modulusLen)
396         goto failure;
397     if (maxDataLen < modulusLen)
398         goto failure;
399 
400     rv = RSA_PublicKeyOp(key, data, sig);
401     if (rv != SECSuccess)
402         goto failure;
403 
404     *dataLen = modulusLen;
405     return SECSuccess;
406 
407 failure:
408     return SECFailure;
409 }
410 
411 /* XXX Doesn't set error code */
412 SECStatus
RSA_EncryptRaw(RSAPublicKey * key,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)413 RSA_EncryptRaw(RSAPublicKey *key,
414                unsigned char *output,
415                unsigned int *outputLen,
416                unsigned int maxOutputLen,
417                const unsigned char *input,
418                unsigned int inputLen)
419 {
420     SECStatus rv;
421     unsigned int modulusLen = rsa_modulusLen(&key->modulus);
422     SECItem formatted;
423     SECItem unformatted;
424 
425     formatted.data = NULL;
426     if (maxOutputLen < modulusLen)
427         goto failure;
428 
429     unformatted.len = inputLen;
430     unformatted.data = (unsigned char *)input;
431     formatted.data = NULL;
432     rv = rsa_FormatBlock(&formatted, modulusLen, RSA_BlockRaw, &unformatted);
433     if (rv != SECSuccess)
434         goto failure;
435 
436     rv = RSA_PublicKeyOp(key, output, formatted.data);
437     if (rv != SECSuccess)
438         goto failure;
439 
440     PORT_ZFree(formatted.data, modulusLen);
441     *outputLen = modulusLen;
442     return SECSuccess;
443 
444 failure:
445     if (formatted.data != NULL)
446         PORT_ZFree(formatted.data, modulusLen);
447     return SECFailure;
448 }
449 
450 /* XXX Doesn't set error code */
451 SECStatus
RSA_DecryptRaw(RSAPrivateKey * key,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)452 RSA_DecryptRaw(RSAPrivateKey *key,
453                unsigned char *output,
454                unsigned int *outputLen,
455                unsigned int maxOutputLen,
456                const unsigned char *input,
457                unsigned int inputLen)
458 {
459     SECStatus rv;
460     unsigned int modulusLen = rsa_modulusLen(&key->modulus);
461 
462     if (modulusLen > maxOutputLen)
463         goto failure;
464     if (inputLen != modulusLen)
465         goto failure;
466 
467     rv = RSA_PrivateKeyOp(key, output, input);
468     if (rv != SECSuccess)
469         goto failure;
470 
471     *outputLen = modulusLen;
472     return SECSuccess;
473 
474 failure:
475     return SECFailure;
476 }
477 
478 /*
479  * Decodes an EME-OAEP encoded block, validating the encoding in constant
480  * time.
481  * Described in RFC 3447, section 7.1.2.
482  * input contains the encoded block, after decryption.
483  * label is the optional value L that was associated with the message.
484  * On success, the original message and message length will be stored in
485  * output and outputLen.
486  */
487 static SECStatus
eme_oaep_decode(unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen,HASH_HashType hashAlg,HASH_HashType maskHashAlg,const unsigned char * label,unsigned int labelLen)488 eme_oaep_decode(unsigned char *output,
489                 unsigned int *outputLen,
490                 unsigned int maxOutputLen,
491                 const unsigned char *input,
492                 unsigned int inputLen,
493                 HASH_HashType hashAlg,
494                 HASH_HashType maskHashAlg,
495                 const unsigned char *label,
496                 unsigned int labelLen)
497 {
498     const SECHashObject *hash;
499     void *hashContext;
500     SECStatus rv = SECFailure;
501     unsigned char labelHash[HASH_LENGTH_MAX];
502     unsigned int i;
503     unsigned int maskLen;
504     unsigned int paddingOffset;
505     unsigned char *mask = NULL;
506     unsigned char *tmpOutput = NULL;
507     unsigned char isGood;
508     unsigned char foundPaddingEnd;
509 
510     hash = HASH_GetRawHashObject(hashAlg);
511 
512     /* 1.c */
513     if (inputLen < (hash->length * 2) + 2) {
514         PORT_SetError(SEC_ERROR_INPUT_LEN);
515         return SECFailure;
516     }
517 
518     /* Step 3.a - Generate lHash */
519     hashContext = (*hash->create)();
520     if (hashContext == NULL) {
521         PORT_SetError(SEC_ERROR_NO_MEMORY);
522         return SECFailure;
523     }
524     (*hash->begin)(hashContext);
525     if (labelLen > 0)
526         (*hash->update)(hashContext, label, labelLen);
527     (*hash->end)(hashContext, labelHash, &i, sizeof(labelHash));
528     (*hash->destroy)(hashContext, PR_TRUE);
529 
530     tmpOutput = (unsigned char *)PORT_Alloc(inputLen);
531     if (tmpOutput == NULL) {
532         PORT_SetError(SEC_ERROR_NO_MEMORY);
533         goto done;
534     }
535 
536     maskLen = inputLen - hash->length - 1;
537     mask = (unsigned char *)PORT_Alloc(maskLen);
538     if (mask == NULL) {
539         PORT_SetError(SEC_ERROR_NO_MEMORY);
540         goto done;
541     }
542 
543     PORT_Memcpy(tmpOutput, input, inputLen);
544 
545     /* 3.c - Generate seedMask */
546     MGF1(maskHashAlg, mask, hash->length, &tmpOutput[1 + hash->length],
547          inputLen - hash->length - 1);
548     /* 3.d - Unmask seed */
549     for (i = 0; i < hash->length; ++i)
550         tmpOutput[1 + i] ^= mask[i];
551 
552     /* 3.e - Generate dbMask */
553     MGF1(maskHashAlg, mask, maskLen, &tmpOutput[1], hash->length);
554     /* 3.f - Unmask DB */
555     for (i = 0; i < maskLen; ++i)
556         tmpOutput[1 + hash->length + i] ^= mask[i];
557 
558     /* 3.g - Compare Y, lHash, and PS in constant time
559      * Warning: This code is timing dependent and must not disclose which of
560      * these were invalid.
561      */
562     paddingOffset = 0;
563     isGood = 1;
564     foundPaddingEnd = 0;
565 
566     /* Compare Y */
567     isGood &= constantTimeEQ8(0x00, tmpOutput[0]);
568 
569     /* Compare lHash and lHash' */
570     isGood &= constantTimeCompare(&labelHash[0],
571                                   &tmpOutput[1 + hash->length],
572                                   hash->length);
573 
574     /* Compare that the padding is zero or more zero octets, followed by a
575      * 0x01 octet */
576     for (i = 1 + (hash->length * 2); i < inputLen; ++i) {
577         unsigned char isZero = constantTimeEQ8(0x00, tmpOutput[i]);
578         unsigned char isOne = constantTimeEQ8(0x01, tmpOutput[i]);
579         /* non-constant time equivalent:
580          * if (tmpOutput[i] == 0x01 && !foundPaddingEnd)
581          *     paddingOffset = i;
582          */
583         paddingOffset = constantTimeCondition(isOne & ~foundPaddingEnd, i,
584                                               paddingOffset);
585         /* non-constant time equivalent:
586          * if (tmpOutput[i] == 0x01)
587          *    foundPaddingEnd = true;
588          *
589          * Note: This may yield false positives, as it will be set whenever
590          * a 0x01 byte is encountered. If there was bad padding (eg:
591          * 0x03 0x02 0x01), foundPaddingEnd will still be set to true, and
592          * paddingOffset will still be set to 2.
593          */
594         foundPaddingEnd = constantTimeCondition(isOne, 1, foundPaddingEnd);
595         /* non-constant time equivalent:
596          * if (tmpOutput[i] != 0x00 && tmpOutput[i] != 0x01 &&
597          *     !foundPaddingEnd) {
598          *    isGood = false;
599          * }
600          *
601          * Note: This may yield false positives, as a message (and padding)
602          * that is entirely zeros will result in isGood still being true. Thus
603          * it's necessary to check foundPaddingEnd is positive below.
604          */
605         isGood = constantTimeCondition(~foundPaddingEnd & ~isZero, 0, isGood);
606     }
607 
608     /* While both isGood and foundPaddingEnd may have false positives, they
609      * cannot BOTH have false positives. If both are not true, then an invalid
610      * message was received. Note, this comparison must still be done in constant
611      * time so as not to leak either condition.
612      */
613     if (!(isGood & foundPaddingEnd)) {
614         PORT_SetError(SEC_ERROR_BAD_DATA);
615         goto done;
616     }
617 
618     /* End timing dependent code */
619 
620     ++paddingOffset; /* Skip the 0x01 following the end of PS */
621 
622     *outputLen = inputLen - paddingOffset;
623     if (*outputLen > maxOutputLen) {
624         PORT_SetError(SEC_ERROR_OUTPUT_LEN);
625         goto done;
626     }
627 
628     if (*outputLen)
629         PORT_Memcpy(output, &tmpOutput[paddingOffset], *outputLen);
630     rv = SECSuccess;
631 
632 done:
633     if (mask)
634         PORT_ZFree(mask, maskLen);
635     if (tmpOutput)
636         PORT_ZFree(tmpOutput, inputLen);
637     return rv;
638 }
639 
640 /*
641  * Generate an EME-OAEP encoded block for encryption
642  * Described in RFC 3447, section 7.1.1
643  * We use input instead of M for the message to be encrypted
644  * label is the optional value L to be associated with the message.
645  */
646 static SECStatus
eme_oaep_encode(unsigned char * em,unsigned int emLen,const unsigned char * input,unsigned int inputLen,HASH_HashType hashAlg,HASH_HashType maskHashAlg,const unsigned char * label,unsigned int labelLen,const unsigned char * seed,unsigned int seedLen)647 eme_oaep_encode(unsigned char *em,
648                 unsigned int emLen,
649                 const unsigned char *input,
650                 unsigned int inputLen,
651                 HASH_HashType hashAlg,
652                 HASH_HashType maskHashAlg,
653                 const unsigned char *label,
654                 unsigned int labelLen,
655                 const unsigned char *seed,
656                 unsigned int seedLen)
657 {
658     const SECHashObject *hash;
659     void *hashContext;
660     SECStatus rv;
661     unsigned char *mask;
662     unsigned int reservedLen;
663     unsigned int dbMaskLen;
664     unsigned int i;
665 
666     hash = HASH_GetRawHashObject(hashAlg);
667     PORT_Assert(seed == NULL || seedLen == hash->length);
668 
669     /* Step 1.b */
670     reservedLen = (2 * hash->length) + 2;
671     if (emLen < reservedLen || inputLen > (emLen - reservedLen)) {
672         PORT_SetError(SEC_ERROR_INPUT_LEN);
673         return SECFailure;
674     }
675 
676     /*
677      * From RFC 3447, Section 7.1
678      *                      +----------+---------+-------+
679      *                 DB = |  lHash   |    PS   |   M   |
680      *                      +----------+---------+-------+
681      *                                     |
682      *           +----------+              V
683      *           |   seed   |--> MGF ---> xor
684      *           +----------+              |
685      *                 |                   |
686      *        +--+     V                   |
687      *        |00|    xor <----- MGF <-----|
688      *        +--+     |                   |
689      *          |      |                   |
690      *          V      V                   V
691      *        +--+----------+----------------------------+
692      *  EM =  |00|maskedSeed|          maskedDB          |
693      *        +--+----------+----------------------------+
694      *
695      * We use mask to hold the result of the MGF functions, and all other
696      * values are generated in their final resting place.
697      */
698     *em = 0x00;
699 
700     /* Step 2.a - Generate lHash */
701     hashContext = (*hash->create)();
702     if (hashContext == NULL) {
703         PORT_SetError(SEC_ERROR_NO_MEMORY);
704         return SECFailure;
705     }
706     (*hash->begin)(hashContext);
707     if (labelLen > 0)
708         (*hash->update)(hashContext, label, labelLen);
709     (*hash->end)(hashContext, &em[1 + hash->length], &i, hash->length);
710     (*hash->destroy)(hashContext, PR_TRUE);
711 
712     /* Step 2.b - Generate PS */
713     if (emLen - reservedLen - inputLen > 0) {
714         PORT_Memset(em + 1 + (hash->length * 2), 0x00,
715                     emLen - reservedLen - inputLen);
716     }
717 
718     /* Step 2.c. - Generate DB
719      * DB = lHash || PS || 0x01 || M
720      * Note that PS and lHash have already been placed into em at their
721      * appropriate offsets. This just copies M into place
722      */
723     em[emLen - inputLen - 1] = 0x01;
724     if (inputLen)
725         PORT_Memcpy(em + emLen - inputLen, input, inputLen);
726 
727     if (seed == NULL) {
728         /* Step 2.d - Generate seed */
729         rv = RNG_GenerateGlobalRandomBytes(em + 1, hash->length);
730         if (rv != SECSuccess) {
731             return rv;
732         }
733     } else {
734         /* For Known Answer Tests, copy the supplied seed. */
735         PORT_Memcpy(em + 1, seed, seedLen);
736     }
737 
738     /* Step 2.e - Generate dbMask*/
739     dbMaskLen = emLen - hash->length - 1;
740     mask = (unsigned char *)PORT_Alloc(dbMaskLen);
741     if (mask == NULL) {
742         PORT_SetError(SEC_ERROR_NO_MEMORY);
743         return SECFailure;
744     }
745     MGF1(maskHashAlg, mask, dbMaskLen, em + 1, hash->length);
746     /* Step 2.f - Compute maskedDB*/
747     for (i = 0; i < dbMaskLen; ++i)
748         em[1 + hash->length + i] ^= mask[i];
749 
750     /* Step 2.g - Generate seedMask */
751     MGF1(maskHashAlg, mask, hash->length, &em[1 + hash->length], dbMaskLen);
752     /* Step 2.h - Compute maskedSeed */
753     for (i = 0; i < hash->length; ++i)
754         em[1 + i] ^= mask[i];
755 
756     PORT_ZFree(mask, dbMaskLen);
757     return SECSuccess;
758 }
759 
760 SECStatus
RSA_EncryptOAEP(RSAPublicKey * key,HASH_HashType hashAlg,HASH_HashType maskHashAlg,const unsigned char * label,unsigned int labelLen,const unsigned char * seed,unsigned int seedLen,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)761 RSA_EncryptOAEP(RSAPublicKey *key,
762                 HASH_HashType hashAlg,
763                 HASH_HashType maskHashAlg,
764                 const unsigned char *label,
765                 unsigned int labelLen,
766                 const unsigned char *seed,
767                 unsigned int seedLen,
768                 unsigned char *output,
769                 unsigned int *outputLen,
770                 unsigned int maxOutputLen,
771                 const unsigned char *input,
772                 unsigned int inputLen)
773 {
774     SECStatus rv = SECFailure;
775     unsigned int modulusLen = rsa_modulusLen(&key->modulus);
776     unsigned char *oaepEncoded = NULL;
777 
778     if (maxOutputLen < modulusLen) {
779         PORT_SetError(SEC_ERROR_OUTPUT_LEN);
780         return SECFailure;
781     }
782 
783     if ((hashAlg == HASH_AlgNULL) || (maskHashAlg == HASH_AlgNULL)) {
784         PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
785         return SECFailure;
786     }
787 
788     if ((labelLen == 0 && label != NULL) ||
789         (labelLen > 0 && label == NULL)) {
790         PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
791         return SECFailure;
792     }
793 
794     oaepEncoded = (unsigned char *)PORT_Alloc(modulusLen);
795     if (oaepEncoded == NULL) {
796         PORT_SetError(SEC_ERROR_NO_MEMORY);
797         return SECFailure;
798     }
799     rv = eme_oaep_encode(oaepEncoded, modulusLen, input, inputLen,
800                          hashAlg, maskHashAlg, label, labelLen, seed, seedLen);
801     if (rv != SECSuccess)
802         goto done;
803 
804     rv = RSA_PublicKeyOp(key, output, oaepEncoded);
805     if (rv != SECSuccess)
806         goto done;
807     *outputLen = modulusLen;
808 
809 done:
810     PORT_Free(oaepEncoded);
811     return rv;
812 }
813 
814 SECStatus
RSA_DecryptOAEP(RSAPrivateKey * key,HASH_HashType hashAlg,HASH_HashType maskHashAlg,const unsigned char * label,unsigned int labelLen,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)815 RSA_DecryptOAEP(RSAPrivateKey *key,
816                 HASH_HashType hashAlg,
817                 HASH_HashType maskHashAlg,
818                 const unsigned char *label,
819                 unsigned int labelLen,
820                 unsigned char *output,
821                 unsigned int *outputLen,
822                 unsigned int maxOutputLen,
823                 const unsigned char *input,
824                 unsigned int inputLen)
825 {
826     SECStatus rv = SECFailure;
827     unsigned int modulusLen = rsa_modulusLen(&key->modulus);
828     unsigned char *oaepEncoded = NULL;
829 
830     if ((hashAlg == HASH_AlgNULL) || (maskHashAlg == HASH_AlgNULL)) {
831         PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
832         return SECFailure;
833     }
834 
835     if (inputLen != modulusLen) {
836         PORT_SetError(SEC_ERROR_INPUT_LEN);
837         return SECFailure;
838     }
839 
840     if ((labelLen == 0 && label != NULL) ||
841         (labelLen > 0 && label == NULL)) {
842         PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
843         return SECFailure;
844     }
845 
846     oaepEncoded = (unsigned char *)PORT_Alloc(modulusLen);
847     if (oaepEncoded == NULL) {
848         PORT_SetError(SEC_ERROR_NO_MEMORY);
849         return SECFailure;
850     }
851 
852     rv = RSA_PrivateKeyOpDoubleChecked(key, oaepEncoded, input);
853     if (rv != SECSuccess) {
854         goto done;
855     }
856     rv = eme_oaep_decode(output, outputLen, maxOutputLen, oaepEncoded,
857                          modulusLen, hashAlg, maskHashAlg, label,
858                          labelLen);
859 
860 done:
861     if (oaepEncoded)
862         PORT_ZFree(oaepEncoded, modulusLen);
863     return rv;
864 }
865 
866 /* XXX Doesn't set error code */
867 SECStatus
RSA_EncryptBlock(RSAPublicKey * key,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)868 RSA_EncryptBlock(RSAPublicKey *key,
869                  unsigned char *output,
870                  unsigned int *outputLen,
871                  unsigned int maxOutputLen,
872                  const unsigned char *input,
873                  unsigned int inputLen)
874 {
875     SECStatus rv;
876     unsigned int modulusLen = rsa_modulusLen(&key->modulus);
877     SECItem formatted;
878     SECItem unformatted;
879 
880     formatted.data = NULL;
881     if (maxOutputLen < modulusLen)
882         goto failure;
883 
884     unformatted.len = inputLen;
885     unformatted.data = (unsigned char *)input;
886     formatted.data = NULL;
887     rv = rsa_FormatBlock(&formatted, modulusLen, RSA_BlockPublic,
888                          &unformatted);
889     if (rv != SECSuccess)
890         goto failure;
891 
892     rv = RSA_PublicKeyOp(key, output, formatted.data);
893     if (rv != SECSuccess)
894         goto failure;
895 
896     PORT_ZFree(formatted.data, modulusLen);
897     *outputLen = modulusLen;
898     return SECSuccess;
899 
900 failure:
901     if (formatted.data != NULL)
902         PORT_ZFree(formatted.data, modulusLen);
903     return SECFailure;
904 }
905 
906 /* XXX Doesn't set error code */
907 SECStatus
RSA_DecryptBlock(RSAPrivateKey * key,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)908 RSA_DecryptBlock(RSAPrivateKey *key,
909                  unsigned char *output,
910                  unsigned int *outputLen,
911                  unsigned int maxOutputLen,
912                  const unsigned char *input,
913                  unsigned int inputLen)
914 {
915     SECStatus rv;
916     unsigned int modulusLen = rsa_modulusLen(&key->modulus);
917     unsigned int i;
918     unsigned char *buffer;
919 
920     if (inputLen != modulusLen)
921         goto failure;
922 
923     buffer = (unsigned char *)PORT_Alloc(modulusLen + 1);
924     if (!buffer)
925         goto failure;
926 
927     rv = RSA_PrivateKeyOp(key, buffer, input);
928     if (rv != SECSuccess)
929         goto loser;
930 
931     /* XXX(rsleevi): Constant time */
932     if (buffer[0] != RSA_BLOCK_FIRST_OCTET ||
933         buffer[1] != (unsigned char)RSA_BlockPublic) {
934         goto loser;
935     }
936     *outputLen = 0;
937     for (i = 2; i < modulusLen; i++) {
938         if (buffer[i] == RSA_BLOCK_AFTER_PAD_OCTET) {
939             *outputLen = modulusLen - i - 1;
940             break;
941         }
942     }
943     if (*outputLen == 0)
944         goto loser;
945     if (*outputLen > maxOutputLen)
946         goto loser;
947 
948     PORT_Memcpy(output, buffer + modulusLen - *outputLen, *outputLen);
949 
950     PORT_Free(buffer);
951     return SECSuccess;
952 
953 loser:
954     PORT_Free(buffer);
955 failure:
956     return SECFailure;
957 }
958 
959 /*
960  * Encode a RSA-PSS signature.
961  * Described in RFC 3447, section 9.1.1.
962  * We use mHash instead of M as input.
963  * emBits from the RFC is just modBits - 1, see section 8.1.1.
964  * We only support MGF1 as the MGF.
965  *
966  * NOTE: this code assumes modBits is a multiple of 8.
967  */
968 static SECStatus
emsa_pss_encode(unsigned char * em,unsigned int emLen,const unsigned char * mHash,HASH_HashType hashAlg,HASH_HashType maskHashAlg,const unsigned char * salt,unsigned int saltLen)969 emsa_pss_encode(unsigned char *em,
970                 unsigned int emLen,
971                 const unsigned char *mHash,
972                 HASH_HashType hashAlg,
973                 HASH_HashType maskHashAlg,
974                 const unsigned char *salt,
975                 unsigned int saltLen)
976 {
977     const SECHashObject *hash;
978     void *hash_context;
979     unsigned char *dbMask;
980     unsigned int dbMaskLen;
981     unsigned int i;
982     SECStatus rv;
983 
984     hash = HASH_GetRawHashObject(hashAlg);
985     dbMaskLen = emLen - hash->length - 1;
986 
987     /* Step 3 */
988     if (emLen < hash->length + saltLen + 2) {
989         PORT_SetError(SEC_ERROR_OUTPUT_LEN);
990         return SECFailure;
991     }
992 
993     /* Step 4 */
994     if (salt == NULL) {
995         rv = RNG_GenerateGlobalRandomBytes(&em[dbMaskLen - saltLen], saltLen);
996         if (rv != SECSuccess) {
997             return rv;
998         }
999     } else {
1000         PORT_Memcpy(&em[dbMaskLen - saltLen], salt, saltLen);
1001     }
1002 
1003     /* Step 5 + 6 */
1004     /* Compute H and store it at its final location &em[dbMaskLen]. */
1005     hash_context = (*hash->create)();
1006     if (hash_context == NULL) {
1007         PORT_SetError(SEC_ERROR_NO_MEMORY);
1008         return SECFailure;
1009     }
1010     (*hash->begin)(hash_context);
1011     (*hash->update)(hash_context, eightZeros, 8);
1012     (*hash->update)(hash_context, mHash, hash->length);
1013     (*hash->update)(hash_context, &em[dbMaskLen - saltLen], saltLen);
1014     (*hash->end)(hash_context, &em[dbMaskLen], &i, hash->length);
1015     (*hash->destroy)(hash_context, PR_TRUE);
1016 
1017     /* Step 7 + 8 */
1018     PORT_Memset(em, 0, dbMaskLen - saltLen - 1);
1019     em[dbMaskLen - saltLen - 1] = 0x01;
1020 
1021     /* Step 9 */
1022     dbMask = (unsigned char *)PORT_Alloc(dbMaskLen);
1023     if (dbMask == NULL) {
1024         PORT_SetError(SEC_ERROR_NO_MEMORY);
1025         return SECFailure;
1026     }
1027     MGF1(maskHashAlg, dbMask, dbMaskLen, &em[dbMaskLen], hash->length);
1028 
1029     /* Step 10 */
1030     for (i = 0; i < dbMaskLen; i++)
1031         em[i] ^= dbMask[i];
1032     PORT_Free(dbMask);
1033 
1034     /* Step 11 */
1035     em[0] &= 0x7f;
1036 
1037     /* Step 12 */
1038     em[emLen - 1] = 0xbc;
1039 
1040     return SECSuccess;
1041 }
1042 
1043 /*
1044  * Verify a RSA-PSS signature.
1045  * Described in RFC 3447, section 9.1.2.
1046  * We use mHash instead of M as input.
1047  * emBits from the RFC is just modBits - 1, see section 8.1.2.
1048  * We only support MGF1 as the MGF.
1049  *
1050  * NOTE: this code assumes modBits is a multiple of 8.
1051  */
1052 static SECStatus
emsa_pss_verify(const unsigned char * mHash,const unsigned char * em,unsigned int emLen,HASH_HashType hashAlg,HASH_HashType maskHashAlg,unsigned int saltLen)1053 emsa_pss_verify(const unsigned char *mHash,
1054                 const unsigned char *em,
1055                 unsigned int emLen,
1056                 HASH_HashType hashAlg,
1057                 HASH_HashType maskHashAlg,
1058                 unsigned int saltLen)
1059 {
1060     const SECHashObject *hash;
1061     void *hash_context;
1062     unsigned char *db;
1063     unsigned char *H_; /* H' from the RFC */
1064     unsigned int i;
1065     unsigned int dbMaskLen;
1066     SECStatus rv;
1067 
1068     hash = HASH_GetRawHashObject(hashAlg);
1069     dbMaskLen = emLen - hash->length - 1;
1070 
1071     /* Step 3 + 4 + 6 */
1072     if ((emLen < (hash->length + saltLen + 2)) ||
1073         (em[emLen - 1] != 0xbc) ||
1074         ((em[0] & 0x80) != 0)) {
1075         PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1076         return SECFailure;
1077     }
1078 
1079     /* Step 7 */
1080     db = (unsigned char *)PORT_Alloc(dbMaskLen);
1081     if (db == NULL) {
1082         PORT_SetError(SEC_ERROR_NO_MEMORY);
1083         return SECFailure;
1084     }
1085     /* &em[dbMaskLen] points to H, used as mgfSeed */
1086     MGF1(maskHashAlg, db, dbMaskLen, &em[dbMaskLen], hash->length);
1087 
1088     /* Step 8 */
1089     for (i = 0; i < dbMaskLen; i++) {
1090         db[i] ^= em[i];
1091     }
1092 
1093     /* Step 9 */
1094     db[0] &= 0x7f;
1095 
1096     /* Step 10 */
1097     for (i = 0; i < (dbMaskLen - saltLen - 1); i++) {
1098         if (db[i] != 0) {
1099             PORT_Free(db);
1100             PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1101             return SECFailure;
1102         }
1103     }
1104     if (db[dbMaskLen - saltLen - 1] != 0x01) {
1105         PORT_Free(db);
1106         PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1107         return SECFailure;
1108     }
1109 
1110     /* Step 12 + 13 */
1111     H_ = (unsigned char *)PORT_Alloc(hash->length);
1112     if (H_ == NULL) {
1113         PORT_Free(db);
1114         PORT_SetError(SEC_ERROR_NO_MEMORY);
1115         return SECFailure;
1116     }
1117     hash_context = (*hash->create)();
1118     if (hash_context == NULL) {
1119         PORT_Free(db);
1120         PORT_Free(H_);
1121         PORT_SetError(SEC_ERROR_NO_MEMORY);
1122         return SECFailure;
1123     }
1124     (*hash->begin)(hash_context);
1125     (*hash->update)(hash_context, eightZeros, 8);
1126     (*hash->update)(hash_context, mHash, hash->length);
1127     (*hash->update)(hash_context, &db[dbMaskLen - saltLen], saltLen);
1128     (*hash->end)(hash_context, H_, &i, hash->length);
1129     (*hash->destroy)(hash_context, PR_TRUE);
1130 
1131     PORT_Free(db);
1132 
1133     /* Step 14 */
1134     if (PORT_Memcmp(H_, &em[dbMaskLen], hash->length) != 0) {
1135         PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1136         rv = SECFailure;
1137     } else {
1138         rv = SECSuccess;
1139     }
1140 
1141     PORT_Free(H_);
1142     return rv;
1143 }
1144 
1145 SECStatus
RSA_SignPSS(RSAPrivateKey * key,HASH_HashType hashAlg,HASH_HashType maskHashAlg,const unsigned char * salt,unsigned int saltLength,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)1146 RSA_SignPSS(RSAPrivateKey *key,
1147             HASH_HashType hashAlg,
1148             HASH_HashType maskHashAlg,
1149             const unsigned char *salt,
1150             unsigned int saltLength,
1151             unsigned char *output,
1152             unsigned int *outputLen,
1153             unsigned int maxOutputLen,
1154             const unsigned char *input,
1155             unsigned int inputLen)
1156 {
1157     SECStatus rv = SECSuccess;
1158     unsigned int modulusLen = rsa_modulusLen(&key->modulus);
1159     unsigned char *pssEncoded = NULL;
1160 
1161     if (maxOutputLen < modulusLen) {
1162         PORT_SetError(SEC_ERROR_OUTPUT_LEN);
1163         return SECFailure;
1164     }
1165 
1166     if ((hashAlg == HASH_AlgNULL) || (maskHashAlg == HASH_AlgNULL)) {
1167         PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
1168         return SECFailure;
1169     }
1170 
1171     pssEncoded = (unsigned char *)PORT_Alloc(modulusLen);
1172     if (pssEncoded == NULL) {
1173         PORT_SetError(SEC_ERROR_NO_MEMORY);
1174         return SECFailure;
1175     }
1176     rv = emsa_pss_encode(pssEncoded, modulusLen, input, hashAlg,
1177                          maskHashAlg, salt, saltLength);
1178     if (rv != SECSuccess)
1179         goto done;
1180 
1181     rv = RSA_PrivateKeyOpDoubleChecked(key, output, pssEncoded);
1182     *outputLen = modulusLen;
1183 
1184 done:
1185     PORT_Free(pssEncoded);
1186     return rv;
1187 }
1188 
1189 SECStatus
RSA_CheckSignPSS(RSAPublicKey * key,HASH_HashType hashAlg,HASH_HashType maskHashAlg,unsigned int saltLength,const unsigned char * sig,unsigned int sigLen,const unsigned char * hash,unsigned int hashLen)1190 RSA_CheckSignPSS(RSAPublicKey *key,
1191                  HASH_HashType hashAlg,
1192                  HASH_HashType maskHashAlg,
1193                  unsigned int saltLength,
1194                  const unsigned char *sig,
1195                  unsigned int sigLen,
1196                  const unsigned char *hash,
1197                  unsigned int hashLen)
1198 {
1199     SECStatus rv;
1200     unsigned int modulusLen = rsa_modulusLen(&key->modulus);
1201     unsigned char *buffer;
1202 
1203     if (sigLen != modulusLen) {
1204         PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1205         return SECFailure;
1206     }
1207 
1208     if ((hashAlg == HASH_AlgNULL) || (maskHashAlg == HASH_AlgNULL)) {
1209         PORT_SetError(SEC_ERROR_INVALID_ALGORITHM);
1210         return SECFailure;
1211     }
1212 
1213     buffer = (unsigned char *)PORT_Alloc(modulusLen);
1214     if (!buffer) {
1215         PORT_SetError(SEC_ERROR_NO_MEMORY);
1216         return SECFailure;
1217     }
1218 
1219     rv = RSA_PublicKeyOp(key, buffer, sig);
1220     if (rv != SECSuccess) {
1221         PORT_Free(buffer);
1222         PORT_SetError(SEC_ERROR_BAD_SIGNATURE);
1223         return SECFailure;
1224     }
1225 
1226     rv = emsa_pss_verify(hash, buffer, modulusLen, hashAlg,
1227                          maskHashAlg, saltLength);
1228     PORT_Free(buffer);
1229 
1230     return rv;
1231 }
1232 
1233 /* XXX Doesn't set error code */
1234 SECStatus
RSA_Sign(RSAPrivateKey * key,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * input,unsigned int inputLen)1235 RSA_Sign(RSAPrivateKey *key,
1236          unsigned char *output,
1237          unsigned int *outputLen,
1238          unsigned int maxOutputLen,
1239          const unsigned char *input,
1240          unsigned int inputLen)
1241 {
1242     SECStatus rv = SECSuccess;
1243     unsigned int modulusLen = rsa_modulusLen(&key->modulus);
1244     SECItem formatted;
1245     SECItem unformatted;
1246 
1247     if (maxOutputLen < modulusLen)
1248         return SECFailure;
1249 
1250     unformatted.len = inputLen;
1251     unformatted.data = (unsigned char *)input;
1252     formatted.data = NULL;
1253     rv = rsa_FormatBlock(&formatted, modulusLen, RSA_BlockPrivate,
1254                          &unformatted);
1255     if (rv != SECSuccess)
1256         goto done;
1257 
1258     rv = RSA_PrivateKeyOpDoubleChecked(key, output, formatted.data);
1259     *outputLen = modulusLen;
1260 
1261     goto done;
1262 
1263 done:
1264     if (formatted.data != NULL)
1265         PORT_ZFree(formatted.data, modulusLen);
1266     return rv;
1267 }
1268 
1269 /* XXX Doesn't set error code */
1270 SECStatus
RSA_CheckSign(RSAPublicKey * key,const unsigned char * sig,unsigned int sigLen,const unsigned char * data,unsigned int dataLen)1271 RSA_CheckSign(RSAPublicKey *key,
1272               const unsigned char *sig,
1273               unsigned int sigLen,
1274               const unsigned char *data,
1275               unsigned int dataLen)
1276 {
1277     SECStatus rv;
1278     unsigned int modulusLen = rsa_modulusLen(&key->modulus);
1279     unsigned int i;
1280     unsigned char *buffer;
1281 
1282     if (sigLen != modulusLen)
1283         goto failure;
1284     /*
1285      * 0x00 || BT || Pad || 0x00 || ActualData
1286      *
1287      * The "3" below is the first octet + the second octet + the 0x00
1288      * octet that always comes just before the ActualData.
1289      */
1290     if (dataLen > modulusLen - (3 + RSA_BLOCK_MIN_PAD_LEN))
1291         goto failure;
1292 
1293     buffer = (unsigned char *)PORT_Alloc(modulusLen + 1);
1294     if (!buffer)
1295         goto failure;
1296 
1297     rv = RSA_PublicKeyOp(key, buffer, sig);
1298     if (rv != SECSuccess)
1299         goto loser;
1300 
1301     /*
1302      * check the padding that was used
1303      */
1304     if (buffer[0] != RSA_BLOCK_FIRST_OCTET ||
1305         buffer[1] != (unsigned char)RSA_BlockPrivate) {
1306         goto loser;
1307     }
1308     for (i = 2; i < modulusLen - dataLen - 1; i++) {
1309         if (buffer[i] != RSA_BLOCK_PRIVATE_PAD_OCTET)
1310             goto loser;
1311     }
1312     if (buffer[i] != RSA_BLOCK_AFTER_PAD_OCTET)
1313         goto loser;
1314 
1315     /*
1316      * make sure we get the same results
1317      */
1318     if (PORT_Memcmp(buffer + modulusLen - dataLen, data, dataLen) != 0)
1319         goto loser;
1320 
1321     PORT_Free(buffer);
1322     return SECSuccess;
1323 
1324 loser:
1325     PORT_Free(buffer);
1326 failure:
1327     return SECFailure;
1328 }
1329 
1330 /* XXX Doesn't set error code */
1331 SECStatus
RSA_CheckSignRecover(RSAPublicKey * key,unsigned char * output,unsigned int * outputLen,unsigned int maxOutputLen,const unsigned char * sig,unsigned int sigLen)1332 RSA_CheckSignRecover(RSAPublicKey *key,
1333                      unsigned char *output,
1334                      unsigned int *outputLen,
1335                      unsigned int maxOutputLen,
1336                      const unsigned char *sig,
1337                      unsigned int sigLen)
1338 {
1339     SECStatus rv;
1340     unsigned int modulusLen = rsa_modulusLen(&key->modulus);
1341     unsigned int i;
1342     unsigned char *buffer;
1343 
1344     if (sigLen != modulusLen)
1345         goto failure;
1346 
1347     buffer = (unsigned char *)PORT_Alloc(modulusLen + 1);
1348     if (!buffer)
1349         goto failure;
1350 
1351     rv = RSA_PublicKeyOp(key, buffer, sig);
1352     if (rv != SECSuccess)
1353         goto loser;
1354     *outputLen = 0;
1355 
1356     /*
1357      * check the padding that was used
1358      */
1359     if (buffer[0] != RSA_BLOCK_FIRST_OCTET ||
1360         buffer[1] != (unsigned char)RSA_BlockPrivate) {
1361         goto loser;
1362     }
1363     for (i = 2; i < modulusLen; i++) {
1364         if (buffer[i] == RSA_BLOCK_AFTER_PAD_OCTET) {
1365             *outputLen = modulusLen - i - 1;
1366             break;
1367         }
1368         if (buffer[i] != RSA_BLOCK_PRIVATE_PAD_OCTET)
1369             goto loser;
1370     }
1371     if (*outputLen == 0)
1372         goto loser;
1373     if (*outputLen > maxOutputLen)
1374         goto loser;
1375 
1376     PORT_Memcpy(output, buffer + modulusLen - *outputLen, *outputLen);
1377 
1378     PORT_Free(buffer);
1379     return SECSuccess;
1380 
1381 loser:
1382     PORT_Free(buffer);
1383 failure:
1384     return SECFailure;
1385 }
1386