1 /*	$NetBSD: rsa-ltm.c,v 1.1.1.1 2011/04/13 18:14:51 elric Exp $	*/
2 
3 /*
4  * Copyright (c) 2006 - 2007, 2010 Kungliga Tekniska Högskolan
5  * (Royal Institute of Technology, Stockholm, Sweden).
6  * All rights reserved.
7  *
8  * Redistribution and use in source and binary forms, with or without
9  * modification, are permitted provided that the following conditions
10  * are met:
11  *
12  * 1. Redistributions of source code must retain the above copyright
13  *    notice, this list of conditions and the following disclaimer.
14  *
15  * 2. Redistributions in binary form must reproduce the above copyright
16  *    notice, this list of conditions and the following disclaimer in the
17  *    documentation and/or other materials provided with the distribution.
18  *
19  * 3. Neither the name of the Institute nor the names of its contributors
20  *    may be used to endorse or promote products derived from this software
21  *    without specific prior written permission.
22  *
23  * THIS SOFTWARE IS PROVIDED BY THE INSTITUTE AND CONTRIBUTORS ``AS IS'' AND
24  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
25  * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
26  * ARE DISCLAIMED.  IN NO EVENT SHALL THE INSTITUTE OR CONTRIBUTORS BE LIABLE
27  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
28  * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
29  * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
30  * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
31  * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
32  * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
33  * SUCH DAMAGE.
34  */
35 
36 #include <config.h>
37 
38 #include <stdio.h>
39 #include <stdlib.h>
40 #include <krb5/krb5-types.h>
41 #include <assert.h>
42 
43 #include <rsa.h>
44 
45 #include <krb5/roken.h>
46 
47 #include "tommath.h"
48 
49 static int
50 random_num(mp_int *num, size_t len)
51 {
52     unsigned char *p;
53 
54     len = (len + 7) / 8;
55     p = malloc(len);
56     if (p == NULL)
57 	return 1;
58     if (RAND_bytes(p, len) != 1) {
59 	free(p);
60 	return 1;
61     }
62     mp_read_unsigned_bin(num, p, len);
63     free(p);
64     return 0;
65 }
66 
67 static void
68 BN2mpz(mp_int *s, const BIGNUM *bn)
69 {
70     size_t len;
71     void *p;
72 
73     len = BN_num_bytes(bn);
74     p = malloc(len);
75     BN_bn2bin(bn, p);
76     mp_read_unsigned_bin(s, p, len);
77     free(p);
78 }
79 
80 static void
81 setup_blind(mp_int *n, mp_int *b, mp_int *bi)
82 {
83     random_num(b, mp_count_bits(n));
84     mp_mod(b, n, b);
85     mp_invmod(b, n, bi);
86 }
87 
88 static void
89 blind(mp_int *in, mp_int *b, mp_int *e, mp_int *n)
90 {
91     mp_int t1;
92     mp_init(&t1);
93     /* in' = (in * b^e) mod n */
94     mp_exptmod(b, e, n, &t1);
95     mp_mul(&t1, in, in);
96     mp_mod(in, n, in);
97     mp_clear(&t1);
98 }
99 
100 static void
101 unblind(mp_int *out, mp_int *bi, mp_int *n)
102 {
103     /* out' = (out * 1/b) mod n */
104     mp_mul(out, bi, out);
105     mp_mod(out, n, out);
106 }
107 
108 static int
109 ltm_rsa_private_calculate(mp_int * in, mp_int * p,  mp_int * q,
110 			  mp_int * dmp1, mp_int * dmq1, mp_int * iqmp,
111 			  mp_int * out)
112 {
113     mp_int vp, vq, u;
114 
115     mp_init_multi(&vp, &vq, &u, NULL);
116 
117     /* vq = c ^ (d mod (q - 1)) mod q */
118     /* vp = c ^ (d mod (p - 1)) mod p */
119     mp_mod(in, p, &u);
120     mp_exptmod(&u, dmp1, p, &vp);
121     mp_mod(in, q, &u);
122     mp_exptmod(&u, dmq1, q, &vq);
123 
124     /* C2 = 1/q mod p  (iqmp) */
125     /* u = (vp - vq)C2 mod p. */
126     mp_sub(&vp, &vq, &u);
127     if (mp_isneg(&u))
128 	mp_add(&u, p, &u);
129     mp_mul(&u, iqmp, &u);
130     mp_mod(&u, p, &u);
131 
132     /* c ^ d mod n = vq + u q */
133     mp_mul(&u, q, &u);
134     mp_add(&u, &vq, out);
135 
136     mp_clear_multi(&vp, &vq, &u, NULL);
137 
138     return 0;
139 }
140 
141 /*
142  *
143  */
144 
145 static int
146 ltm_rsa_public_encrypt(int flen, const unsigned char* from,
147 			unsigned char* to, RSA* rsa, int padding)
148 {
149     unsigned char *p, *p0;
150     int res;
151     size_t size, padlen;
152     mp_int enc, dec, n, e;
153 
154     if (padding != RSA_PKCS1_PADDING)
155 	return -1;
156 
157     mp_init_multi(&n, &e, &enc, &dec, NULL);
158 
159     size = RSA_size(rsa);
160 
161     if (size < RSA_PKCS1_PADDING_SIZE || size - RSA_PKCS1_PADDING_SIZE < flen) {
162 	mp_clear_multi(&n, &e, &enc, &dec);
163 	return -2;
164     }
165 
166     BN2mpz(&n, rsa->n);
167     BN2mpz(&e, rsa->e);
168 
169     p = p0 = malloc(size - 1);
170     if (p0 == NULL) {
171 	mp_clear_multi(&e, &n, &enc, &dec, NULL);
172 	return -3;
173     }
174 
175     padlen = size - flen - 3;
176 
177     *p++ = 2;
178     if (RAND_bytes(p, padlen) != 1) {
179 	mp_clear_multi(&e, &n, &enc, &dec, NULL);
180 	free(p0);
181 	return -4;
182     }
183     while(padlen) {
184 	if (*p == 0)
185 	    *p = 1;
186 	padlen--;
187 	p++;
188     }
189     *p++ = 0;
190     memcpy(p, from, flen);
191     p += flen;
192     assert((p - p0) == size - 1);
193 
194     mp_read_unsigned_bin(&dec, p0, size - 1);
195     free(p0);
196 
197     res = mp_exptmod(&dec, &e, &n, &enc);
198 
199     mp_clear_multi(&dec, &e, &n, NULL);
200 
201     if (res != 0) {
202 	mp_clear(&enc);
203 	return -4;
204     }
205 
206     {
207 	size_t ssize;
208 	ssize = mp_unsigned_bin_size(&enc);
209 	assert(size >= ssize);
210 	mp_to_unsigned_bin(&enc, to);
211 	size = ssize;
212     }
213     mp_clear(&enc);
214 
215     return size;
216 }
217 
218 static int
219 ltm_rsa_public_decrypt(int flen, const unsigned char* from,
220 		       unsigned char* to, RSA* rsa, int padding)
221 {
222     unsigned char *p;
223     int res;
224     size_t size;
225     mp_int s, us, n, e;
226 
227     if (padding != RSA_PKCS1_PADDING)
228 	return -1;
229 
230     if (flen > RSA_size(rsa))
231 	return -2;
232 
233     mp_init_multi(&e, &n, &s, &us, NULL);
234 
235     BN2mpz(&n, rsa->n);
236     BN2mpz(&e, rsa->e);
237 
238 #if 0
239     /* Check that the exponent is larger then 3 */
240     if (mp_int_compare_value(&e, 3) <= 0) {
241 	mp_clear_multi(&e, &n, &s, &us, NULL);
242 	return -3;
243     }
244 #endif
245 
246     mp_read_unsigned_bin(&s, rk_UNCONST(from), flen);
247 
248     if (mp_cmp(&s, &n) >= 0) {
249 	mp_clear_multi(&e, &n, &s, &us, NULL);
250 	return -4;
251     }
252 
253     res = mp_exptmod(&s, &e, &n, &us);
254 
255     mp_clear_multi(&e, &n, &s, NULL);
256 
257     if (res != 0) {
258 	mp_clear(&us);
259 	return -5;
260     }
261     p = to;
262 
263 
264     size = mp_unsigned_bin_size(&us);
265     assert(size <= RSA_size(rsa));
266     mp_to_unsigned_bin(&us, p);
267 
268     mp_clear(&us);
269 
270     /* head zero was skipped by mp_to_unsigned_bin */
271     if (*p == 0)
272 	return -6;
273     if (*p != 1)
274 	return -7;
275     size--; p++;
276     while (size && *p == 0xff) {
277 	size--; p++;
278     }
279     if (size == 0 || *p != 0)
280 	return -8;
281     size--; p++;
282 
283     memmove(to, p, size);
284 
285     return size;
286 }
287 
288 static int
289 ltm_rsa_private_encrypt(int flen, const unsigned char* from,
290 			unsigned char* to, RSA* rsa, int padding)
291 {
292     unsigned char *p, *p0;
293     int res;
294     int size;
295     mp_int in, out, n, e;
296     mp_int bi, b;
297     int blinding = (rsa->flags & RSA_FLAG_NO_BLINDING) == 0;
298     int do_unblind = 0;
299 
300     if (padding != RSA_PKCS1_PADDING)
301 	return -1;
302 
303     mp_init_multi(&e, &n, &in, &out, &b, &bi, NULL);
304 
305     size = RSA_size(rsa);
306 
307     if (size < RSA_PKCS1_PADDING_SIZE || size - RSA_PKCS1_PADDING_SIZE < flen)
308 	return -2;
309 
310     p0 = p = malloc(size);
311     *p++ = 0;
312     *p++ = 1;
313     memset(p, 0xff, size - flen - 3);
314     p += size - flen - 3;
315     *p++ = 0;
316     memcpy(p, from, flen);
317     p += flen;
318     assert((p - p0) == size);
319 
320     BN2mpz(&n, rsa->n);
321     BN2mpz(&e, rsa->e);
322 
323     mp_read_unsigned_bin(&in, p0, size);
324     free(p0);
325 
326     if(mp_isneg(&in) || mp_cmp(&in, &n) >= 0) {
327 	size = -3;
328 	goto out;
329     }
330 
331     if (blinding) {
332 	setup_blind(&n, &b, &bi);
333 	blind(&in, &b, &e, &n);
334 	do_unblind = 1;
335     }
336 
337     if (rsa->p && rsa->q && rsa->dmp1 && rsa->dmq1 && rsa->iqmp) {
338 	mp_int p, q, dmp1, dmq1, iqmp;
339 
340 	mp_init_multi(&p, &q, &dmp1, &dmq1, &iqmp, NULL);
341 
342 	BN2mpz(&p, rsa->p);
343 	BN2mpz(&q, rsa->q);
344 	BN2mpz(&dmp1, rsa->dmp1);
345 	BN2mpz(&dmq1, rsa->dmq1);
346 	BN2mpz(&iqmp, rsa->iqmp);
347 
348 	res = ltm_rsa_private_calculate(&in, &p, &q, &dmp1, &dmq1, &iqmp, &out);
349 
350 	mp_clear_multi(&p, &q, &dmp1, &dmq1, &iqmp, NULL);
351 
352 	if (res != 0) {
353 	    size = -4;
354 	    goto out;
355 	}
356     } else {
357 	mp_int d;
358 
359 	BN2mpz(&d, rsa->d);
360 	res = mp_exptmod(&in, &d, &n, &out);
361 	mp_clear(&d);
362 	if (res != 0) {
363 	    size = -5;
364 	    goto out;
365 	}
366     }
367 
368     if (do_unblind)
369 	unblind(&out, &bi, &n);
370 
371     if (size > 0) {
372 	size_t ssize;
373 	ssize = mp_unsigned_bin_size(&out);
374 	assert(size >= ssize);
375 	mp_to_unsigned_bin(&out, to);
376 	size = ssize;
377     }
378 
379  out:
380     mp_clear_multi(&e, &n, &in, &out, &b, &bi, NULL);
381 
382     return size;
383 }
384 
385 static int
386 ltm_rsa_private_decrypt(int flen, const unsigned char* from,
387 			unsigned char* to, RSA* rsa, int padding)
388 {
389     unsigned char *ptr;
390     int res, size;
391     mp_int in, out, n, e, b, bi;
392     int blinding = (rsa->flags & RSA_FLAG_NO_BLINDING) == 0;
393     int do_unblind = 0;
394 
395     if (padding != RSA_PKCS1_PADDING)
396 	return -1;
397 
398     size = RSA_size(rsa);
399     if (flen > size)
400 	return -2;
401 
402     mp_init_multi(&in, &n, &e, &out, &b, &bi, NULL);
403 
404     BN2mpz(&n, rsa->n);
405     BN2mpz(&e, rsa->e);
406 
407     mp_read_unsigned_bin(&in, rk_UNCONST(from), flen);
408 
409     if(mp_isneg(&in) || mp_cmp(&in, &n) >= 0) {
410 	size = -2;
411 	goto out;
412     }
413 
414     if (blinding) {
415 	setup_blind(&n, &b, &bi);
416 	blind(&in, &b, &e, &n);
417 	do_unblind = 1;
418     }
419 
420     if (rsa->p && rsa->q && rsa->dmp1 && rsa->dmq1 && rsa->iqmp) {
421 	mp_int p, q, dmp1, dmq1, iqmp;
422 
423 	mp_init_multi(&p, &q, &dmp1, &dmq1, &iqmp, NULL);
424 
425 	BN2mpz(&p, rsa->p);
426 	BN2mpz(&q, rsa->q);
427 	BN2mpz(&dmp1, rsa->dmp1);
428 	BN2mpz(&dmq1, rsa->dmq1);
429 	BN2mpz(&iqmp, rsa->iqmp);
430 
431 	res = ltm_rsa_private_calculate(&in, &p, &q, &dmp1, &dmq1, &iqmp, &out);
432 
433 	mp_clear_multi(&p, &q, &dmp1, &dmq1, &iqmp, NULL);
434 
435 	if (res != 0) {
436 	    size = -3;
437 	    goto out;
438 	}
439 
440     } else {
441 	mp_int d;
442 
443 	if(mp_isneg(&in) || mp_cmp(&in, &n) >= 0)
444 	    return -4;
445 
446 	BN2mpz(&d, rsa->d);
447 	res = mp_exptmod(&in, &d, &n, &out);
448 	mp_clear(&d);
449 	if (res != 0) {
450 	    size = -5;
451 	    goto out;
452 	}
453     }
454 
455     if (do_unblind)
456 	unblind(&out, &bi, &n);
457 
458     ptr = to;
459     {
460 	size_t ssize;
461 	ssize = mp_unsigned_bin_size(&out);
462 	assert(size >= ssize);
463 	mp_to_unsigned_bin(&out, ptr);
464 	size = ssize;
465     }
466 
467     /* head zero was skipped by mp_int_to_unsigned */
468     if (*ptr != 2) {
469 	size = -6;
470 	goto out;
471     }
472     size--; ptr++;
473     while (size && *ptr != 0) {
474 	size--; ptr++;
475     }
476     if (size == 0)
477 	return -7;
478     size--; ptr++;
479 
480     memmove(to, ptr, size);
481 
482  out:
483     mp_clear_multi(&e, &n, &in, &out, &b, &bi, NULL);
484 
485     return size;
486 }
487 
488 static BIGNUM *
489 mpz2BN(mp_int *s)
490 {
491     size_t size;
492     BIGNUM *bn;
493     void *p;
494 
495     size = mp_unsigned_bin_size(s);
496     p = malloc(size);
497     if (p == NULL && size != 0)
498 	return NULL;
499 
500     mp_to_unsigned_bin(s, p);
501 
502     bn = BN_bin2bn(p, size, NULL);
503     free(p);
504     return bn;
505 }
506 
/*
 * Jump to the enclosing function's `out' label unless (f) == (v).
 * Wrapped in do/while (0) so it is a single statement that stays correct
 * inside an unbraced if/else.
 */
#define CHECK(f, v) do { if ((f) != (v)) goto out; } while (0)
508 
509 static int
510 ltm_rsa_generate_key(RSA *rsa, int bits, BIGNUM *e, BN_GENCB *cb)
511 {
512     mp_int el, p, q, n, d, dmp1, dmq1, iqmp, t1, t2, t3;
513     int counter, ret, bitsp;
514 
515     if (bits < 789)
516 	return -1;
517 
518     bitsp = (bits + 1) / 2;
519 
520     ret = -1;
521 
522     mp_init_multi(&el, &p, &q, &n, &d,
523 		  &dmp1, &dmq1, &iqmp,
524 		  &t1, &t2, &t3, NULL);
525 
526     BN2mpz(&el, e);
527 
528     /* generate p and q so that p != q and bits(pq) ~ bits */
529     counter = 0;
530     do {
531 	BN_GENCB_call(cb, 2, counter++);
532 	CHECK(random_num(&p, bitsp), 0);
533 	CHECK(mp_find_prime(&p), MP_YES);
534 
535 	mp_sub_d(&p, 1, &t1);
536 	mp_gcd(&t1, &el, &t2);
537     } while(mp_cmp_d(&t2, 1) != 0);
538 
539     BN_GENCB_call(cb, 3, 0);
540 
541     counter = 0;
542     do {
543 	BN_GENCB_call(cb, 2, counter++);
544 	CHECK(random_num(&q, bits - bitsp), 0);
545 	CHECK(mp_find_prime(&q), MP_YES);
546 
547 	if (mp_cmp(&p, &q) == 0) /* don't let p and q be the same */
548 	    continue;
549 
550 	mp_sub_d(&q, 1, &t1);
551 	mp_gcd(&t1, &el, &t2);
552     } while(mp_cmp_d(&t2, 1) != 0);
553 
554     /* make p > q */
555     if (mp_cmp(&p, &q) < 0) {
556 	mp_int c;
557 	c = p;
558 	p = q;
559 	q = c;
560     }
561 
562     BN_GENCB_call(cb, 3, 1);
563 
564     /* calculate n,  		n = p * q */
565     mp_mul(&p, &q, &n);
566 
567     /* calculate d, 		d = 1/e mod (p - 1)(q - 1) */
568     mp_sub_d(&p, 1, &t1);
569     mp_sub_d(&q, 1, &t2);
570     mp_mul(&t1, &t2, &t3);
571     mp_invmod(&el, &t3, &d);
572 
573     /* calculate dmp1		dmp1 = d mod (p-1) */
574     mp_mod(&d, &t1, &dmp1);
575     /* calculate dmq1		dmq1 = d mod (q-1) */
576     mp_mod(&d, &t2, &dmq1);
577     /* calculate iqmp 		iqmp = 1/q mod p */
578     mp_invmod(&q, &p, &iqmp);
579 
580     /* fill in RSA key */
581 
582     rsa->e = mpz2BN(&el);
583     rsa->p = mpz2BN(&p);
584     rsa->q = mpz2BN(&q);
585     rsa->n = mpz2BN(&n);
586     rsa->d = mpz2BN(&d);
587     rsa->dmp1 = mpz2BN(&dmp1);
588     rsa->dmq1 = mpz2BN(&dmq1);
589     rsa->iqmp = mpz2BN(&iqmp);
590 
591     ret = 1;
592 
593 out:
594     mp_clear_multi(&el, &p, &q, &n, &d,
595 		   &dmp1, &dmq1, &iqmp,
596 		   &t1, &t2, &t3, NULL);
597 
598     return ret;
599 }
600 
601 static int
602 ltm_rsa_init(RSA *rsa)
603 {
604     return 1;
605 }
606 
607 static int
608 ltm_rsa_finish(RSA *rsa)
609 {
610     return 1;
611 }
612 
/*
 * RSA_METHOD table that plugs this file's libtommath-based operations into
 * hcrypto.  The initializer is positional, so entries must follow the
 * member order of RSA_METHOD (presumably declared via rsa.h -- confirm).
 * The trailing comments name the value placed in each slot, not the
 * struct member names.
 */
const RSA_METHOD hc_rsa_ltm_method = {
    "hcrypto ltm RSA",		/* method name */
    ltm_rsa_public_encrypt,
    ltm_rsa_public_decrypt,
    ltm_rsa_private_encrypt,
    ltm_rsa_private_decrypt,
    NULL,			/* NOTE(review): unused callback slot -- confirm meaning */
    NULL,			/* NOTE(review): unused callback slot -- confirm meaning */
    ltm_rsa_init,
    ltm_rsa_finish,
    0,				/* NOTE(review): integer member, likely flags -- confirm */
    NULL,			/* NOTE(review): unused slot -- confirm meaning */
    NULL,			/* NOTE(review): unused slot -- confirm meaning */
    NULL,			/* NOTE(review): unused slot -- confirm meaning */
    ltm_rsa_generate_key
};
629 
630 const RSA_METHOD *
631 RSA_ltm_method(void)
632 {
633     return &hc_rsa_ltm_method;
634 }
635