1 #include "os.h"
2 #include <mp.h>
3 #include <libsec.h>
4 
5 RSApriv*
rsafill(mpint * n,mpint * e,mpint * d,mpint * p,mpint * q)6 rsafill(mpint *n, mpint *e, mpint *d, mpint *p, mpint *q)
7 {
8 	mpint *c2, *kq, *kp, *x;
9 	RSApriv *rsa;
10 
11 	// make sure we're not being hoodwinked
12 	if(!probably_prime(p, 10) || !probably_prime(q, 10)){
13 		werrstr("rsafill: p or q not prime");
14 		return nil;
15 	}
16 	x = mpnew(0);
17 	mpmul(p, q, x);
18 	if(mpcmp(n, x) != 0){
19 		werrstr("rsafill: n != p*q");
20 		mpfree(x);
21 		return nil;
22 	}
23 	c2 = mpnew(0);
24 	mpsub(p, mpone, c2);
25 	mpsub(q, mpone, x);
26 	mpmul(c2, x, x);
27 	mpmul(e, d, c2);
28 	mpmod(c2, x, x);
29 	if(mpcmp(x, mpone) != 0){
30 		werrstr("rsafill: e*d != 1 mod (p-1)*(q-1)");
31 		mpfree(x);
32 		mpfree(c2);
33 		return nil;
34 	}
35 
36 	// compute chinese remainder coefficient
37 	mpinvert(p, q, c2);
38 
39 	// for crt a**k mod p == (a**(k mod p-1)) mod p
40 	kq = mpnew(0);
41 	kp = mpnew(0);
42 	mpsub(p, mpone, x);
43 	mpmod(d, x, kp);
44 	mpsub(q, mpone, x);
45 	mpmod(d, x, kq);
46 
47 	rsa = rsaprivalloc();
48 	rsa->pub.ek = mpcopy(e);
49 	rsa->pub.n = mpcopy(n);
50 	rsa->dk = mpcopy(d);
51 	rsa->kp = kp;
52 	rsa->kq = kq;
53 	rsa->p = mpcopy(p);
54 	rsa->q = mpcopy(q);
55 	rsa->c2 = c2;
56 
57 	mpfree(x);
58 
59 	return rsa;
60 }
61 
62