1 /*
2  * RSA key generation.
3  */
4 
5 #include <assert.h>
6 
7 #include "ssh.h"
8 #include "sshkeygen.h"
9 #include "mpint.h"
10 
/*
 * The public exponent e. It is never generated: rsa_generate() uses
 * this value for every key, and generate_rsa_prime() passes it to
 * pcs_avoid_residue_small() when searching for each prime.
 */
#define RSA_EXPONENT 65537

/*
 * Number of leading bits of each prime that are chosen up front. This
 * is the bit count passed to pcs_new_with_firstbits() together with
 * the 'firstbits' value produced by invent_firstbits().
 */
#define NFIRSTBITS 13
static void invent_firstbits(unsigned *one, unsigned *two,
                             unsigned min_separation);

/*
 * The plan for generating one RSA prime. setup_rsa_prime() fills in
 * everything except 'firstbits', which rsa_generate() sets by calling
 * invent_firstbits(); generate_rsa_prime() then consumes the result.
 */
typedef struct RSAPrimeDetails RSAPrimeDetails;
struct RSAPrimeDetails {
    /* True if subsidiary primes are generated first and used to
     * constrain the search for the main prime. */
    bool strong;
    /* Sizes in bits: 'bits' for the main prime, bitsm1 and bitsp1 for
     * the subsidiary primes m1 and p1, and bitsm1m1 for the prime
     * m1m1 used in turn to constrain m1. setup_rsa_prime() only
     * assigns the last three when strong generation is requested. */
    int bits, bitsm1m1, bitsm1, bitsp1;
    /* Leading-bits value handed to pcs_new_with_firstbits() for every
     * prime generated under this plan. */
    unsigned firstbits;
    /* Progress phases, one per prime search. The three subsidiary
     * phases are only registered when 'strong' is true. */
    ProgressPhase phase_main, phase_m1m1, phase_m1, phase_p1;
};

/*
 * Slack, in bits, that setup_rsa_prime() subtracts from a prime's
 * size before sharing the remainder out among its subsidiary primes.
 * It is also the smallest size setup_rsa_prime() will accept for the
 * innermost subsidiary prime before abandoning strong generation.
 */
#define STRONG_MARGIN (20 + NFIRSTBITS)
26 
setup_rsa_prime(int bits,bool strong,PrimeGenerationContext * pgc,ProgressReceiver * prog)27 static RSAPrimeDetails setup_rsa_prime(
28     int bits, bool strong, PrimeGenerationContext *pgc, ProgressReceiver *prog)
29 {
30     RSAPrimeDetails pd;
31     pd.bits = bits;
32     if (strong) {
33         pd.bitsm1 = (bits - STRONG_MARGIN) / 2;
34         pd.bitsp1 = (bits - STRONG_MARGIN) - pd.bitsm1;
35         pd.bitsm1m1 = (pd.bitsm1 - STRONG_MARGIN) / 2;
36         if (pd.bitsm1m1 < STRONG_MARGIN) {
37             /* Absurdly small prime, but we should at least not crash. */
38             strong = false;
39         }
40     }
41     pd.strong = strong;
42 
43     if (pd.strong) {
44         pd.phase_m1m1 = primegen_add_progress_phase(pgc, prog, pd.bitsm1m1);
45         pd.phase_m1 = primegen_add_progress_phase(pgc, prog, pd.bitsm1);
46         pd.phase_p1 = primegen_add_progress_phase(pgc, prog, pd.bitsp1);
47     }
48     pd.phase_main = primegen_add_progress_phase(pgc, prog, pd.bits);
49 
50     return pd;
51 }
52 
generate_rsa_prime(RSAPrimeDetails pd,PrimeGenerationContext * pgc,ProgressReceiver * prog)53 static mp_int *generate_rsa_prime(
54     RSAPrimeDetails pd, PrimeGenerationContext *pgc, ProgressReceiver *prog)
55 {
56     mp_int *m1m1 = NULL, *m1 = NULL, *p1 = NULL, *p = NULL;
57     PrimeCandidateSource *pcs;
58 
59     if (pd.strong) {
60         progress_start_phase(prog, pd.phase_m1m1);
61         pcs = pcs_new_with_firstbits(pd.bitsm1m1, pd.firstbits, NFIRSTBITS);
62         m1m1 = primegen_generate(pgc, pcs, prog);
63         progress_report_phase_complete(prog);
64 
65         progress_start_phase(prog, pd.phase_m1);
66         pcs = pcs_new_with_firstbits(pd.bitsm1, pd.firstbits, NFIRSTBITS);
67         pcs_require_residue_1_mod_prime(pcs, m1m1);
68         m1 = primegen_generate(pgc, pcs, prog);
69         progress_report_phase_complete(prog);
70 
71         progress_start_phase(prog, pd.phase_p1);
72         pcs = pcs_new_with_firstbits(pd.bitsp1, pd.firstbits, NFIRSTBITS);
73         p1 = primegen_generate(pgc, pcs, prog);
74         progress_report_phase_complete(prog);
75     }
76 
77     progress_start_phase(prog, pd.phase_main);
78     pcs = pcs_new_with_firstbits(pd.bits, pd.firstbits, NFIRSTBITS);
79     pcs_avoid_residue_small(pcs, RSA_EXPONENT, 1);
80     if (pd.strong) {
81         pcs_require_residue_1_mod_prime(pcs, m1);
82         mp_int *p1_minus_1 = mp_copy(p1);
83         mp_sub_integer_into(p1_minus_1, p1, 1);
84         pcs_require_residue(pcs, p1, p1_minus_1);
85         mp_free(p1_minus_1);
86     }
87     p = primegen_generate(pgc, pcs, prog);
88     progress_report_phase_complete(prog);
89 
90     if (m1m1)
91         mp_free(m1m1);
92     if (m1)
93         mp_free(m1);
94     if (p1)
95         mp_free(p1);
96 
97     return p;
98 }
99 
/*
 * Generate a new RSA key with a modulus of 'bits' bits, storing the
 * result in 'key'.
 *
 * 'strong' selects generation via subsidiary primes (see
 * setup_rsa_prime). 'pgc' supplies the prime generation strategy and
 * 'prog' receives progress reports.
 *
 * On return, 'key' has its vtable, modulus, exponent,
 * private_exponent, p, q, iqmp, bits and bytes fields filled in; the
 * key takes ownership of all the mp_ints stored in it. Always
 * returns 1.
 */
int rsa_generate(RSAKey *key, int bits, bool strong,
                 PrimeGenerationContext *pgc, ProgressReceiver *prog)
{
    key->sshk.vt = &ssh_rsa;

    /* e is not generated: the same standard value is used every time. */
    mp_int *e = mp_from_integer(RSA_EXPONENT);

    /*
     * p and q are primes whose lengths add up to `bits', neither of
     * them congruent to 1 modulo e. (What we really need is for e to
     * be coprime to both p-1 and q-1, which in general is slightly
     * more fiddly to arrange; because e is prime, the simpler test
     * is enough.)
     *
     * If `bits' is odd, p gets the extra bit.
     */
    int qbits = bits / 2;
    int pbits = bits - qbits;
    assert(pbits >= qbits);

    /* Register every progress phase before announcing we're ready. */
    RSAPrimeDetails p_plan = setup_rsa_prime(pbits, strong, pgc, prog);
    RSAPrimeDetails q_plan = setup_rsa_prime(qbits, strong, pgc, prog);
    progress_ready(prog);

    /*
     * Choose the leading bits of both primes together. Passing a
     * min_separation of 2 keeps the two primes from being very close
     * to each other. (The chance of them being _dangerously_ close
     * is negligible anyway - even more so than an attacker guessing
     * a whole 256-bit session key - but making sure costs little.)
     */
    invent_firstbits(&p_plan.firstbits, &q_plan.firstbits, 2);

    mp_int *p = generate_rsa_prime(p_plan, pgc, prog);
    mp_int *q = generate_rsa_prime(q_plan, pgc, prog);

    /*
     * We want p > q. If p was generated with more bits than q then
     * that is already guaranteed, so just check it; if they have the
     * same number of bits (even key size), swap them when necessary.
     */
    if (pbits != qbits) {
        assert(mp_cmp_hs(p, q));
    } else {
        mp_cond_swap(p, q, mp_cmp_hs(q, p));
    }

    /* n = pq */
    mp_int *n = mp_mul(p, q);

    /* d = e^-1 mod (p-1)(q-1) */
    mp_int *p_less_1 = mp_copy(p);
    mp_int *q_less_1 = mp_copy(q);
    mp_sub_integer_into(p_less_1, p_less_1, 1);
    mp_sub_integer_into(q_less_1, q_less_1, 1);
    mp_int *phi = mp_mul(p_less_1, q_less_1);
    mp_free(q_less_1);
    mp_free(p_less_1);
    mp_int *d = mp_invert(e, phi);
    mp_free(phi);

    /* Hand everything over to the caller's structure. */
    key->modulus = n;
    key->exponent = e;
    key->private_exponent = d;
    key->p = p;
    key->q = q;
    key->iqmp = mp_invert(q, p);       /* q^-1 mod p */

    key->bits = mp_get_nbits(n);
    key->bytes = (key->bits + 7) / 8;

    return 1;
}
181 
/*
 * Invent a pair of values suitable for use as the 'firstbits' values
 * for the two RSA primes. Each value has its top bit set; the pair is
 * chosen so that their product also has its top bit set (that is,
 * regarding each value as a binary fraction in the range [1,2), their
 * product is at least 2), and so that the two values differ by at
 * least min_separation.
 *
 * This is used for generating RSA keys which have exactly the
 * specified number of bits rather than one fewer - if you generate an
 * a-bit and a b-bit number completely at random and multiply them
 * together, you could end up with either an (a+b-1)-bit number or an
 * (a+b)-bit number. The former happens log(2)*2-1 of the time (about
 * 39%, taking log to be the natural logarithm) and, though actually
 * harmless, every time it occurs it has a non-zero probability of
 * sparking a user email along the lines of 'Hey, I asked PuTTYgen for
 * a 2048-bit key and I only got 2047 bits! Bug!'
 */
/*
 * For a given first value 'a', return the smallest second value 'b'
 * that may be paired with it, clamped to be at most 'hi'.
 *
 * 'b' has to satisfy two lower bounds: a*b must be at least 2*lo*lo,
 * and b must be at least a + min_separation. If the larger of the two
 * bounds exceeds 'hi', then 'hi' itself is returned, which callers
 * treat as "no valid b" because the valid range is [b_min, hi).
 *
 * 'a' must be nonzero.
 */
static inline unsigned firstbits_b_min(
    unsigned a, unsigned lo, unsigned hi, unsigned min_separation)
{
    /* Smallest b with a*b >= 2*lo*lo: a division rounded upwards. */
    const unsigned product_bound = (2*lo*lo + a - 1) / a;
    /* Smallest b that keeps the required distance above a. */
    const unsigned separation_bound = a + min_separation;

    const unsigned lower_bound = (product_bound < separation_bound ?
                                  separation_bound : product_bound);

    /* Never report more than the upper limit. */
    return lower_bound > hi ? hi : lower_bound;
}
210 
invent_firstbits(unsigned * one,unsigned * two,unsigned min_separation)211 static void invent_firstbits(unsigned *one, unsigned *two,
212                              unsigned min_separation)
213 {
214     /*
215      * We'll pick 12 initial bits (number selected at random) for each
216      * prime, not counting the leading 1. So we want to return two
217      * values in the range [2^12,2^13) whose product is at least 2^25.
218      *
219      * Strategy: count up all the viable pairs, then select a random
220      * number in that range and use it to pick a pair.
221      *
222      * To keep things simple, we'll ensure a < b, and randomly swap
223      * them at the end.
224      */
225     const unsigned lo = 1<<12, hi = 1<<13, minproduct = 2*lo*lo;
226     unsigned a, b;
227 
228     /*
229      * Count up the number of prefixes of b that would be valid for
230      * each prefix of a.
231      */
232     mp_int *total = mp_new(32);
233     for (a = lo; a < hi; a++) {
234         unsigned b_min = firstbits_b_min(a, lo, hi, min_separation);
235         mp_add_integer_into(total, total, hi - b_min);
236     }
237 
238     /*
239      * Make up a random number in the range [0,2*total).
240      */
241     mp_int *mlo = mp_from_integer(0), *mhi = mp_new(32);
242     mp_lshift_fixed_into(mhi, total, 1);
243     mp_int *randval = mp_random_in_range(mlo, mhi);
244     mp_free(mlo);
245     mp_free(mhi);
246 
247     /*
248      * Use the low bit of randval as our swap indicator, leaving the
249      * rest of it in the range [0,total).
250      */
251     unsigned swap = mp_get_bit(randval, 0);
252     mp_rshift_fixed_into(randval, randval, 1);
253 
254     /*
255      * Now do the same counting loop again to make the actual choice.
256      */
257     a = b = 0;
258     for (unsigned a_candidate = lo; a_candidate < hi; a_candidate++) {
259         unsigned b_min = firstbits_b_min(a_candidate, lo, hi, min_separation);
260         unsigned limit = hi - b_min;
261 
262         unsigned b_candidate = b_min + mp_get_integer(randval);
263         unsigned use_it = 1 ^ mp_hs_integer(randval, limit);
264         a ^= (a ^ a_candidate) & -use_it;
265         b ^= (b ^ b_candidate) & -use_it;
266 
267         mp_sub_integer_into(randval, randval, limit);
268     }
269 
270     mp_free(randval);
271     mp_free(total);
272 
273     /*
274      * Check everything came out right.
275      */
276     assert(lo <= a);
277     assert(a < hi);
278     assert(lo <= b);
279     assert(b < hi);
280     assert(a * b >= minproduct);
281     assert(b >= a + min_separation);
282 
283     /*
284      * Last-minute optional swap of a and b.
285      */
286     unsigned diff = (a ^ b) & (-swap);
287     a ^= diff;
288     b ^= diff;
289 
290     *one = a;
291     *two = b;
292 }
293