1 /*
2  * RSA implementation for PuTTY.
3  */
4 
5 #include <stdio.h>
6 #include <stdlib.h>
7 #include <string.h>
8 #include <assert.h>
9 
10 #include "ssh.h"
11 #include "mpint.h"
12 #include "misc.h"
13 
/*
 * Read an SSH-1 RSA public key from 'src': a 32-bit bit count,
 * followed by the exponent and the modulus (each read with
 * get_mp_ssh1) in the order selected by 'order'.
 *
 * If 'rsa' is non-NULL the decoded values are stored in it, along
 * with the modulus length in bytes; the two integers then belong to
 * the key. If 'rsa' is NULL the data is still consumed from 'src',
 * but the integers are freed again immediately.
 */
void BinarySource_get_rsa_ssh1_pub(
    BinarySource *src, RSAKey *rsa, RsaSsh1Order order)
{
    unsigned nbits = get_uint32(src);

    /* The two formats differ only in which integer comes first. */
    bool exponent_first = (order == RSA_SSH1_EXPONENT_FIRST);
    mp_int *first = get_mp_ssh1(src);
    mp_int *second = get_mp_ssh1(src);
    mp_int *exponent = exponent_first ? first : second;
    mp_int *modulus = exponent_first ? second : first;

    if (!rsa) {
        /* Caller only wanted to skip over the key. */
        mp_free(exponent);
        mp_free(modulus);
        return;
    }

    rsa->bits = nbits;
    rsa->exponent = exponent;
    rsa->modulus = modulus;
    rsa->bytes = (mp_get_nbits(modulus) + 7) / 8;
}
39 
/*
 * Read one integer from 'src' with get_mp_ssh1 and store it as the
 * private exponent of 'rsa'. 'rsa' must not be NULL.
 */
void BinarySource_get_rsa_ssh1_priv(
    BinarySource *src, RSAKey *rsa)
{
    mp_int *d = get_mp_ssh1(src);
    rsa->private_exponent = d;
}
45 
rsa_components(RSAKey * rsa)46 key_components *rsa_components(RSAKey *rsa)
47 {
48     key_components *kc = key_components_new();
49     key_components_add_text(kc, "key_type", "RSA");
50     key_components_add_mp(kc, "public_modulus", rsa->modulus);
51     key_components_add_mp(kc, "public_exponent", rsa->exponent);
52     if (rsa->private_exponent) {
53         key_components_add_mp(kc, "private_exponent", rsa->private_exponent);
54         key_components_add_mp(kc, "private_p", rsa->p);
55         key_components_add_mp(kc, "private_q", rsa->q);
56         key_components_add_mp(kc, "private_inverse_q_mod_p", rsa->iqmp);
57     }
58     return kc;
59 }
60 
BinarySource_get_rsa_ssh1_priv_agent(BinarySource * src)61 RSAKey *BinarySource_get_rsa_ssh1_priv_agent(BinarySource *src)
62 {
63     RSAKey *rsa = snew(RSAKey);
64     memset(rsa, 0, sizeof(RSAKey));
65 
66     get_rsa_ssh1_pub(src, rsa, RSA_SSH1_MODULUS_FIRST);
67     get_rsa_ssh1_priv(src, rsa);
68 
69     /* SSH-1 names p and q the other way round, i.e. we have the
70      * inverse of p mod q and not of q mod p. We swap the names,
71      * because our internal RSA wants iqmp. */
72     rsa->iqmp = get_mp_ssh1(src);
73     rsa->q = get_mp_ssh1(src);
74     rsa->p = get_mp_ssh1(src);
75 
76     return rsa;
77 }
78 
/*
 * RSA-encrypt 'length' bytes at the start of 'data', in place, using
 * the public part of 'key'. The buffer must have room for key->bytes
 * bytes, because the padded plaintext and the ciphertext both occupy
 * that much space.
 *
 * The plaintext is laid out as 00 02 <nonzero random padding> 00
 * <payload> before the modular exponentiation.
 *
 * Returns false (leaving 'data' untouched) if the key is too short
 * to hold the payload, the three framing bytes and at least one
 * byte of padding; true otherwise.
 */
bool rsa_ssh1_encrypt(unsigned char *data, int length, RSAKey *key)
{
    if (key->bytes < length + 4)
        return false;                  /* RSA key too short! */

    /* Shift the payload to the end of the block, then frame it. */
    int payload_start = key->bytes - length;
    int separator_pos = payload_start - 1;
    memmove(data + payload_start, data, length);
    data[0] = 0;
    data[1] = 2;

    /*
     * Fill the gap with nonzero padding bytes, reasonably uniformly
     * and without a retry loop. Start from a random integer in
     * [0,2^n) for a generously large n. Each step multiplies it by
     * 255, giving a value in [0,255*2^n); the top 8 bits are then a
     * number in [0,255), which becomes a padding byte in [1,256)
     * after adding 1. Those top bits are masked off again before the
     * next step.
     *
     * (Think of it as fixed point: repeatedly multiplying a uniform
     * binary fraction by k and removing the integer part yields a
     * stream of integers in [0,k). This is the same thing scaled up
     * by a power of 2 so no fractions are needed.)
     */
    size_t npad = key->bytes - length - 3;
    size_t random_bits = (npad + 16) * 8;
    mp_int *randval = mp_new(random_bits + 8);
    {
        mp_int *seed = mp_random_bits(random_bits);
        mp_copy_into(randval, seed);
        mp_free(seed);
    }
    for (int pos = 2; pos < separator_pos; pos++) {
        mp_mul_integer_into(randval, randval, 255);
        uint8_t top = mp_get_byte(randval, random_bits / 8);
        assert(top != 255);
        data[pos] = top + 1;
        mp_reduce_mod_2to(randval, random_bits);
    }
    mp_free(randval);
    data[separator_pos] = 0;

    /* Raise the padded block to the public exponent mod n. */
    mp_int *plain = mp_from_bytes_be(make_ptrlen(data, key->bytes));
    mp_int *cipher = mp_modpow(plain, key->exponent, key->modulus);

    /* Write the result back over the buffer, big-endian. */
    for (int k = 0; k < key->bytes; k++)
        data[k] = mp_get_byte(cipher, key->bytes - 1 - k);

    mp_free(plain);
    mp_free(cipher);

    return true;
}
139 
140 /*
141  * Compute (base ^ exp) % mod, provided mod == p * q, with p,q
142  * distinct primes, and iqmp is the multiplicative inverse of q mod p.
143  * Uses Chinese Remainder Theorem to speed computation up over the
144  * obvious implementation of a single big modpow.
145  */
crt_modpow(mp_int * base,mp_int * exp,mp_int * mod,mp_int * p,mp_int * q,mp_int * iqmp)146 static mp_int *crt_modpow(mp_int *base, mp_int *exp, mp_int *mod,
147                           mp_int *p, mp_int *q, mp_int *iqmp)
148 {
149     mp_int *pm1, *qm1, *pexp, *qexp, *presult, *qresult;
150     mp_int *diff, *multiplier, *ret0, *ret;
151 
152     /*
153      * Reduce the exponent mod phi(p) and phi(q), to save time when
154      * exponentiating mod p and mod q respectively. Of course, since p
155      * and q are prime, phi(p) == p-1 and similarly for q.
156      */
157     pm1 = mp_copy(p);
158     mp_sub_integer_into(pm1, pm1, 1);
159     qm1 = mp_copy(q);
160     mp_sub_integer_into(qm1, qm1, 1);
161     pexp = mp_mod(exp, pm1);
162     qexp = mp_mod(exp, qm1);
163 
164     /*
165      * Do the two modpows.
166      */
167     mp_int *base_mod_p = mp_mod(base, p);
168     presult = mp_modpow(base_mod_p, pexp, p);
169     mp_free(base_mod_p);
170     mp_int *base_mod_q = mp_mod(base, q);
171     qresult = mp_modpow(base_mod_q, qexp, q);
172     mp_free(base_mod_q);
173 
174     /*
175      * Recombine the results. We want a value which is congruent to
176      * qresult mod q, and to presult mod p.
177      *
178      * We know that iqmp * q is congruent to 1 * mod p (by definition
179      * of iqmp) and to 0 mod q (obviously). So we start with qresult
180      * (which is congruent to qresult mod both primes), and add on
181      * (presult-qresult) * (iqmp * q) which adjusts it to be congruent
182      * to presult mod p without affecting its value mod q.
183      *
184      * (If presult-qresult < 0, we add p to it to keep it positive.)
185      */
186     unsigned presult_too_small = mp_cmp_hs(qresult, presult);
187     mp_cond_add_into(presult, presult, p, presult_too_small);
188 
189     diff = mp_sub(presult, qresult);
190     multiplier = mp_mul(iqmp, q);
191     ret0 = mp_mul(multiplier, diff);
192     mp_add_into(ret0, ret0, qresult);
193 
194     /*
195      * Finally, reduce the result mod n.
196      */
197     ret = mp_mod(ret0, mod);
198 
199     /*
200      * Free all the intermediate results before returning.
201      */
202     mp_free(pm1);
203     mp_free(qm1);
204     mp_free(pexp);
205     mp_free(qexp);
206     mp_free(presult);
207     mp_free(qresult);
208     mp_free(diff);
209     mp_free(multiplier);
210     mp_free(ret0);
211 
212     return ret;
213 }
214 
215 /*
216  * Wrapper on crt_modpow that looks up all the right values from an
217  * RSAKey.
218  */
rsa_privkey_op(mp_int * input,RSAKey * key)219 static mp_int *rsa_privkey_op(mp_int *input, RSAKey *key)
220 {
221     return crt_modpow(input, key->private_exponent,
222                       key->modulus, key->p, key->q, key->iqmp);
223 }
224 
/*
 * Raw SSH-1 RSA decryption: the private-key operation applied to
 * 'input', with no padding removed. Returns whatever rsa_privkey_op
 * returns.
 */
mp_int *rsa_ssh1_decrypt(mp_int *input, RSAKey *key)
{
    mp_int *plain = rsa_privkey_op(input, key);
    return plain;
}
229 
/*
 * Decrypt 'input' with the private key and strip the PKCS#1 framing
 * (00 02 <padding> 00) from the result. On success the bytes after
 * the terminating zero are appended to 'outbuf' and true is
 * returned. If the framing is wrong, false is returned and 'outbuf'
 * is left alone.
 */
bool rsa_ssh1_decrypt_pkcs1(mp_int *input, RSAKey *key,
                            strbuf *outbuf)
{
    strbuf *plain = strbuf_new_nm();

    /* Serialise the decrypted integer big-endian, modulus-length. */
    mp_int *decrypted = rsa_ssh1_decrypt(input, key);
    size_t nbytes = (mp_get_nbits(key->modulus) + 7) / 8;
    for (size_t k = 0; k < nbytes; k++)
        put_byte(plain, mp_get_byte(decrypted, nbytes - 1 - k));
    mp_free(decrypted);

    BinarySource src[1];
    BinarySource_BARE_INIT(src, plain->u, plain->len);

    /* Check PKCS#1 formatting prefix */
    bool prefix_ok = get_byte(src) == 0 && get_byte(src) == 2;

    /* Skip padding until the zero separator, or give up at end of
     * data. */
    bool found_separator = false;
    while (prefix_ok && !found_separator) {
        unsigned char b = get_byte(src);
        if (get_err(src))
            break;
        found_separator = (b == 0);
    }

    /* Everything else is the payload */
    if (found_separator)
        put_data(outbuf, get_ptr(src), get_avail(src));

    strbuf_free(plain);
    return found_separator;
}
265 
/*
 * Append "0x" followed by the hex form of 'x' to 'sb', preceded by a
 * comma if 'sb' already holds something. The temporary hex string is
 * wiped before it is freed.
 */
static void append_hex_to_strbuf(strbuf *sb, mp_int *x)
{
    char *digits = mp_get_hex(x);
    size_t ndigits = strlen(digits);

    if (sb->len > 0)
        put_byte(sb, ',');
    put_data(sb, "0x", 2);
    put_data(sb, digits, ndigits);

    smemclr(digits, ndigits);
    sfree(digits);
}
277 
/*
 * Format the public key as "0x<exponent>,0x<modulus>" in hex and
 * return it as a string obtained from strbuf_to_str.
 */
char *rsastr_fmt(RSAKey *key)
{
    mp_int *parts[2] = { key->exponent, key->modulus };
    strbuf *text = strbuf_new();

    for (size_t k = 0; k < sizeof(parts) / sizeof(parts[0]); k++)
        append_hex_to_strbuf(text, parts[k]);

    return strbuf_to_str(text);
}
287 
288 /*
289  * Generate a fingerprint string for the key. Compatible with the
290  * OpenSSH fingerprint code.
291  */
rsa_ssh1_fingerprint(RSAKey * key)292 char *rsa_ssh1_fingerprint(RSAKey *key)
293 {
294     unsigned char digest[16];
295     strbuf *out;
296     int i;
297 
298     /*
299      * The hash preimage for SSH-1 key fingerprinting consists of the
300      * modulus and exponent _without_ any preceding length field -
301      * just the minimum number of bytes to represent each integer,
302      * stored big-endian, concatenated with no marker at the division
303      * between them.
304      */
305 
306     ssh_hash *hash = ssh_hash_new(&ssh_md5);
307     for (size_t i = (mp_get_nbits(key->modulus) + 7) / 8; i-- > 0 ;)
308         put_byte(hash, mp_get_byte(key->modulus, i));
309     for (size_t i = (mp_get_nbits(key->exponent) + 7) / 8; i-- > 0 ;)
310         put_byte(hash, mp_get_byte(key->exponent, i));
311     ssh_hash_final(hash, digest);
312 
313     out = strbuf_new();
314     strbuf_catf(out, "%"SIZEu" ", mp_get_nbits(key->modulus));
315     for (i = 0; i < 16; i++)
316         strbuf_catf(out, "%s%02x", i ? ":" : "", digest[i]);
317     if (key->comment)
318         strbuf_catf(out, " %s", key->comment);
319     return strbuf_to_str(out);
320 }
321 
322 /*
323  * Wrap the output of rsa_ssh1_fingerprint up into the same kind of
324  * structure that comes from ssh2_all_fingerprints.
325  */
rsa_ssh1_fake_all_fingerprints(RSAKey * key)326 char **rsa_ssh1_fake_all_fingerprints(RSAKey *key)
327 {
328     char **ret = snewn(SSH_N_FPTYPES, char *);
329     for (unsigned i = 0; i < SSH_N_FPTYPES; i++)
330         ret[i] = NULL;
331     ret[SSH_FPTYPE_MD5] = rsa_ssh1_fingerprint(key);
332     return ret;
333 }
334 
335 /*
336  * Verify that the public data in an RSA key matches the private
337  * data. We also check the private data itself: we ensure that p >
338  * q and that iqmp really is the inverse of q mod p.
339  */
rsa_verify(RSAKey * key)340 bool rsa_verify(RSAKey *key)
341 {
342     mp_int *n, *ed, *pm1, *qm1;
343     unsigned ok = 1;
344 
345     /* Preliminary checks: p,q can't be 0 or 1. (Of course no other
346      * very small value is any good either, but these are the values
347      * we _must_ check for to avoid assertion failures further down
348      * this function.) */
349     if (!(mp_hs_integer(key->p, 2) & mp_hs_integer(key->q, 2)))
350         return false;
351 
352     /* n must equal pq. */
353     n = mp_mul(key->p, key->q);
354     ok &= mp_cmp_eq(n, key->modulus);
355     mp_free(n);
356 
357     /* e * d must be congruent to 1, modulo (p-1) and modulo (q-1). */
358     pm1 = mp_copy(key->p);
359     mp_sub_integer_into(pm1, pm1, 1);
360     ed = mp_modmul(key->exponent, key->private_exponent, pm1);
361     mp_free(pm1);
362     ok &= mp_eq_integer(ed, 1);
363     mp_free(ed);
364 
365     qm1 = mp_copy(key->q);
366     mp_sub_integer_into(qm1, qm1, 1);
367     ed = mp_modmul(key->exponent, key->private_exponent, qm1);
368     mp_free(qm1);
369     ok &= mp_eq_integer(ed, 1);
370     mp_free(ed);
371 
372     /*
373      * Ensure p > q.
374      *
375      * I have seen key blobs in the wild which were generated with
376      * p < q, so instead of rejecting the key in this case we
377      * should instead flip them round into the canonical order of
378      * p > q. This also involves regenerating iqmp.
379      */
380     mp_int *p_new = mp_max(key->p, key->q);
381     mp_int *q_new = mp_min(key->p, key->q);
382     mp_free(key->p);
383     mp_free(key->q);
384     mp_free(key->iqmp);
385     key->p = p_new;
386     key->q = q_new;
387     key->iqmp = mp_invert(key->q, key->p);
388 
389     return ok;
390 }
391 
/*
 * Write the SSH-1 public key blob for 'key' to 'bs': the modulus bit
 * count as a uint32, then the exponent and modulus (each written
 * with put_mp_ssh1) in the order selected by 'order'.
 */
void rsa_ssh1_public_blob(BinarySink *bs, RSAKey *key,
                          RsaSsh1Order order)
{
    bool exponent_first = (order == RSA_SSH1_EXPONENT_FIRST);
    mp_int *first = exponent_first ? key->exponent : key->modulus;
    mp_int *second = exponent_first ? key->modulus : key->exponent;

    put_uint32(bs, mp_get_nbits(key->modulus));
    put_mp_ssh1(bs, first);
    put_mp_ssh1(bs, second);
}
404 
/*
 * Write the full SSH-1 private key for 'key' to 'bs': the public
 * blob (modulus first), then the private exponent, iqmp, q and p.
 * This is the same field order that
 * BinarySource_get_rsa_ssh1_priv_agent reads back.
 */
void rsa_ssh1_private_blob_agent(BinarySink *bs, RSAKey *key)
{
    rsa_ssh1_public_blob(bs, key, RSA_SSH1_MODULUS_FIRST);

    mp_int *const private_parts[] = {
        key->private_exponent, key->iqmp, key->q, key->p,
    };
    size_t nparts = sizeof(private_parts) / sizeof(private_parts[0]);
    for (size_t k = 0; k < nparts; k++)
        put_mp_ssh1(bs, private_parts[k]);
}
413 
414 /* Given an SSH-1 public key blob, determine its length. */
rsa_ssh1_public_blob_len(ptrlen data)415 int rsa_ssh1_public_blob_len(ptrlen data)
416 {
417     BinarySource src[1];
418 
419     BinarySource_BARE_INIT_PL(src, data);
420 
421     /* Expect a length word, then exponent and modulus. (It doesn't
422      * even matter which order.) */
423     get_uint32(src);
424     mp_free(get_mp_ssh1(src));
425     mp_free(get_mp_ssh1(src));
426 
427     if (get_err(src))
428         return -1;
429 
430     /* Return the number of bytes consumed. */
431     return src->pos;
432 }
433 
/*
 * Free the private parts of 'key' (private exponent, p, q, iqmp)
 * and reset each pointer to NULL. Fields that are already NULL are
 * skipped. The public parts and the RSAKey itself are not touched.
 */
void freersapriv(RSAKey *key)
{
    mp_int **const fields[] = {
        &key->private_exponent, &key->p, &key->q, &key->iqmp,
    };
    size_t nfields = sizeof(fields) / sizeof(fields[0]);

    for (size_t k = 0; k < nfields; k++) {
        mp_int **field = fields[k];
        if (!*field)
            continue;
        mp_free(*field);
        *field = NULL;
    }
}
453 
/*
 * Free everything 'key' points to: the private parts (via
 * freersapriv), the modulus, the exponent and the comment. Each
 * pointer is reset to NULL afterwards. The RSAKey structure itself
 * is not freed.
 */
void freersakey(RSAKey *key)
{
    freersapriv(key);

    mp_int **const public_fields[] = { &key->modulus, &key->exponent };
    size_t nfields = sizeof(public_fields) / sizeof(public_fields[0]);
    for (size_t k = 0; k < nfields; k++) {
        mp_int **field = public_fields[k];
        if (!*field)
            continue;
        mp_free(*field);
        *field = NULL;
    }

    if (key->comment) {
        sfree(key->comment);
        key->comment = NULL;
    }
}
470 
471 /* ----------------------------------------------------------------------
472  * Implementation of the ssh-rsa signing key type family.
473  */
474 
/*
 * Per-algorithm data reached through the 'extra' pointer of each RSA
 * ssh_keyalg.
 */
struct ssh2_rsa_extra {
    /* Flags that rsa2_sign ORs into the flags it is given, and that
     * rsa2_verify uses to select the signature hash. */
    unsigned signflags;
};

static void rsa2_freekey(ssh_key *key);   /* forward reference */
480 
rsa2_new_pub(const ssh_keyalg * self,ptrlen data)481 static ssh_key *rsa2_new_pub(const ssh_keyalg *self, ptrlen data)
482 {
483     BinarySource src[1];
484     RSAKey *rsa;
485 
486     BinarySource_BARE_INIT_PL(src, data);
487     if (!ptrlen_eq_string(get_string(src), "ssh-rsa"))
488         return NULL;
489 
490     rsa = snew(RSAKey);
491     rsa->sshk.vt = self;
492     rsa->exponent = get_mp_ssh2(src);
493     rsa->modulus = get_mp_ssh2(src);
494     rsa->private_exponent = NULL;
495     rsa->p = rsa->q = rsa->iqmp = NULL;
496     rsa->comment = NULL;
497 
498     if (get_err(src)) {
499         rsa2_freekey(&rsa->sshk);
500         return NULL;
501     }
502 
503     return &rsa->sshk;
504 }
505 
/*
 * Destroy an RSA ssh_key: free the contents of the enclosing RSAKey
 * with freersakey, then free the RSAKey structure itself.
 */
static void rsa2_freekey(ssh_key *key)
{
    RSAKey *owner = container_of(key, RSAKey, sshk);

    freersakey(owner);
    sfree(owner);
}
512 
/*
 * Return the rsastr_fmt() representation of the RSAKey that
 * contains 'key'.
 */
static char *rsa2_cache_str(ssh_key *key)
{
    return rsastr_fmt(container_of(key, RSAKey, sshk));
}
518 
rsa2_components(ssh_key * key)519 static key_components *rsa2_components(ssh_key *key)
520 {
521     RSAKey *rsa = container_of(key, RSAKey, sshk);
522     return rsa_components(rsa);
523 }
524 
/*
 * Write the SSH-2 public blob for 'key' to 'bs': the string
 * "ssh-rsa", then the exponent, then the modulus. This is the
 * layout rsa2_new_pub parses.
 */
static void rsa2_public_blob(ssh_key *key, BinarySink *bs)
{
    RSAKey *owner = container_of(key, RSAKey, sshk);
    mp_int *e = owner->exponent;
    mp_int *n = owner->modulus;

    put_stringz(bs, "ssh-rsa");
    put_mp_ssh2(bs, e);
    put_mp_ssh2(bs, n);
}
533 
/*
 * Write the private half of 'key' to 'bs' in the order
 * rsa2_new_priv reads it: private exponent, p, q, iqmp.
 */
static void rsa2_private_blob(ssh_key *key, BinarySink *bs)
{
    RSAKey *owner = container_of(key, RSAKey, sshk);
    mp_int *const parts[] = {
        owner->private_exponent, owner->p, owner->q, owner->iqmp,
    };

    for (size_t k = 0; k < sizeof(parts) / sizeof(parts[0]); k++)
        put_mp_ssh2(bs, parts[k]);
}
543 
rsa2_new_priv(const ssh_keyalg * self,ptrlen pub,ptrlen priv)544 static ssh_key *rsa2_new_priv(const ssh_keyalg *self,
545                                ptrlen pub, ptrlen priv)
546 {
547     BinarySource src[1];
548     ssh_key *sshk;
549     RSAKey *rsa;
550 
551     sshk = rsa2_new_pub(self, pub);
552     if (!sshk)
553         return NULL;
554 
555     rsa = container_of(sshk, RSAKey, sshk);
556     BinarySource_BARE_INIT_PL(src, priv);
557     rsa->private_exponent = get_mp_ssh2(src);
558     rsa->p = get_mp_ssh2(src);
559     rsa->q = get_mp_ssh2(src);
560     rsa->iqmp = get_mp_ssh2(src);
561 
562     if (get_err(src) || !rsa_verify(rsa)) {
563         rsa2_freekey(&rsa->sshk);
564         return NULL;
565     }
566 
567     return &rsa->sshk;
568 }
569 
rsa2_new_priv_openssh(const ssh_keyalg * self,BinarySource * src)570 static ssh_key *rsa2_new_priv_openssh(const ssh_keyalg *self,
571                                       BinarySource *src)
572 {
573     RSAKey *rsa;
574 
575     rsa = snew(RSAKey);
576     rsa->sshk.vt = &ssh_rsa;
577     rsa->comment = NULL;
578 
579     rsa->modulus = get_mp_ssh2(src);
580     rsa->exponent = get_mp_ssh2(src);
581     rsa->private_exponent = get_mp_ssh2(src);
582     rsa->iqmp = get_mp_ssh2(src);
583     rsa->p = get_mp_ssh2(src);
584     rsa->q = get_mp_ssh2(src);
585 
586     if (get_err(src) || !rsa_verify(rsa)) {
587         rsa2_freekey(&rsa->sshk);
588         return NULL;
589     }
590 
591     return &rsa->sshk;
592 }
593 
/*
 * Write 'key' to 'bs' in the order rsa2_new_priv_openssh reads it:
 * modulus, exponent, private exponent, iqmp, p, q.
 */
static void rsa2_openssh_blob(ssh_key *key, BinarySink *bs)
{
    RSAKey *owner = container_of(key, RSAKey, sshk);
    mp_int *const parts[] = {
        owner->modulus, owner->exponent, owner->private_exponent,
        owner->iqmp, owner->p, owner->q,
    };

    for (size_t k = 0; k < sizeof(parts) / sizeof(parts[0]); k++)
        put_mp_ssh2(bs, parts[k]);
}
605 
/*
 * Return the number of bits in the modulus of the public key blob
 * 'pub', or -1 if rsa2_new_pub cannot parse it. The temporary key
 * built to find out is freed before returning.
 */
static int rsa2_pubkey_bits(const ssh_keyalg *self, ptrlen pub)
{
    ssh_key *parsed = rsa2_new_pub(self, pub);
    if (!parsed)
        return -1;

    RSAKey *rsa = container_of(parsed, RSAKey, sshk);
    int nbits = mp_get_nbits(rsa->modulus);
    rsa2_freekey(&rsa->sshk);

    return nbits;
}
622 
rsa2_hash_alg_for_flags(unsigned flags,const char ** protocol_id_out)623 static inline const ssh_hashalg *rsa2_hash_alg_for_flags(
624     unsigned flags, const char **protocol_id_out)
625 {
626     const ssh_hashalg *halg;
627     const char *protocol_id;
628 
629     if (flags & SSH_AGENT_RSA_SHA2_256) {
630         halg = &ssh_sha256;
631         protocol_id = "rsa-sha2-256";
632     } else if (flags & SSH_AGENT_RSA_SHA2_512) {
633         halg = &ssh_sha512;
634         protocol_id = "rsa-sha2-512";
635     } else {
636         halg = &ssh_sha1;
637         protocol_id = "ssh-rsa";
638     }
639 
640     if (protocol_id_out)
641         *protocol_id_out = protocol_id;
642 
643     return halg;
644 }
645 
rsa_pkcs1_prefix_for_hash(const ssh_hashalg * halg)646 static inline ptrlen rsa_pkcs1_prefix_for_hash(const ssh_hashalg *halg)
647 {
648     if (halg == &ssh_sha1) {
649         /*
650          * This is the magic ASN.1/DER prefix that goes in the decoded
651          * signature, between the string of FFs and the actual SHA-1
652          * hash value. The meaning of it is:
653          *
654          * 00 -- this marks the end of the FFs; not part of the ASN.1
655          * bit itself
656          *
657          * 30 21 -- a constructed SEQUENCE of length 0x21
658          *    30 09 -- a constructed sub-SEQUENCE of length 9
659          *       06 05 -- an object identifier, length 5
660          *          2B 0E 03 02 1A -- object id { 1 3 14 3 2 26 }
661          *                            (the 1,3 comes from 0x2B = 43 = 40*1+3)
662          *       05 00 -- NULL
663          *    04 14 -- a primitive OCTET STRING of length 0x14
664          *       [0x14 bytes of hash data follows]
665          *
666          * The object id in the middle there is listed as `id-sha1' in
667          * ftp://ftp.rsasecurity.com/pub/pkcs/pkcs-1/pkcs-1v2-1d2.asn
668          * (the ASN module for PKCS #1) and its expanded form is as
669          * follows:
670          *
671          * id-sha1                OBJECT IDENTIFIER ::= {
672          *    iso(1) identified-organization(3) oiw(14) secsig(3)
673          *    algorithms(2) 26 }
674          */
675         static const unsigned char sha1_asn1_prefix[] = {
676             0x00, 0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2B,
677             0x0E, 0x03, 0x02, 0x1A, 0x05, 0x00, 0x04, 0x14,
678         };
679         return PTRLEN_FROM_CONST_BYTES(sha1_asn1_prefix);
680     }
681 
682     if (halg == &ssh_sha256) {
683         /*
684          * A similar piece of ASN.1 used for signatures using SHA-256,
685          * in the same format but differing only in various length
686          * fields and OID.
687          */
688         static const unsigned char sha256_asn1_prefix[] = {
689             0x00, 0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60,
690             0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01,
691             0x05, 0x00, 0x04, 0x20,
692         };
693         return PTRLEN_FROM_CONST_BYTES(sha256_asn1_prefix);
694     }
695 
696     if (halg == &ssh_sha512) {
697         /*
698          * And one more for SHA-512.
699          */
700         static const unsigned char sha512_asn1_prefix[] = {
701             0x00, 0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60,
702             0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03,
703             0x05, 0x00, 0x04, 0x40,
704         };
705         return PTRLEN_FROM_CONST_BYTES(sha512_asn1_prefix);
706     }
707 
708     unreachable("bad hash algorithm for RSA PKCS#1");
709 }
710 
rsa_pkcs1_length_of_fixed_parts(const ssh_hashalg * halg)711 static inline size_t rsa_pkcs1_length_of_fixed_parts(const ssh_hashalg *halg)
712 {
713     ptrlen asn1_prefix = rsa_pkcs1_prefix_for_hash(halg);
714     return halg->hlen + asn1_prefix.len + 2;
715 }
716 
/*
 * Build the 'nbytes'-long block that gets signed:
 *
 *   00 01 FF..FF <prefix from rsa_pkcs1_prefix_for_hash> <hash of data>
 *
 * 'nbytes' must be at least rsa_pkcs1_length_of_fixed_parts(halg);
 * this is asserted. Returns a buffer allocated with snewn, which
 * the caller must free.
 */
static unsigned char *rsa_pkcs1_signature_string(
    size_t nbytes, const ssh_hashalg *halg, ptrlen data)
{
    size_t fixed_len = rsa_pkcs1_length_of_fixed_parts(halg);
    assert(nbytes >= fixed_len);
    size_t ff_count = nbytes - fixed_len;

    ptrlen prefix = rsa_pkcs1_prefix_for_hash(halg);

    unsigned char *block = snewn(nbytes, unsigned char);
    unsigned char *cursor = block;

    /* Leading 00 01. */
    *cursor++ = 0;
    *cursor++ = 1;

    /* FF padding fills whatever the fixed parts leave over. */
    memset(cursor, 0xFF, ff_count);
    cursor += ff_count;

    memcpy(cursor, prefix.ptr, prefix.len);
    cursor += prefix.len;

    /* The hash of the data occupies the rest of the block. */
    ssh_hash *hasher = ssh_hash_new(halg);
    put_datapl(hasher, data);
    ssh_hash_final(hasher, cursor);

    return block;
}
741 
/*
 * Verify the SSH-2 signature blob 'sig' over 'data' with the public
 * part of 'key'. The hash is chosen from the key algorithm's
 * signflags, and the signature's type string must equal the key
 * algorithm's ssh_id.
 *
 * Returns true if the signature is valid, false otherwise. A key
 * too short to encode a signature with the chosen hash, or a
 * malformed blob, also gives false.
 */
static bool rsa2_verify(ssh_key *key, ptrlen sig, ptrlen data)
{
    RSAKey *rsa = container_of(key, RSAKey, sshk);
    const struct ssh2_rsa_extra *extra =
        (const struct ssh2_rsa_extra *)key->vt->extra;
    const ssh_hashalg *halg = rsa2_hash_alg_for_flags(extra->signflags, NULL);

    /* Start by making sure the key is even long enough to encode a
     * signature. If not, everything fails to verify. */
    size_t nbytes = (mp_get_nbits(rsa->modulus) + 7) / 8;
    if (nbytes < rsa_pkcs1_length_of_fixed_parts(halg))
        return false;

    BinarySource src[1];
    BinarySource_BARE_INIT_PL(src, sig);
    ptrlen sig_type = get_string(src);
    /*
     * RFC 4253 section 6.6: the signature integer in an ssh-rsa
     * signature is 'without lengths or padding'. So we should _not_
     * see the usual leading zero byte when the top bit of the first
     * byte is set, although we tolerate one in case the other end
     * has BUG_SSH2_RSA_PADDING. get_mp_ssh2 enforces the leading-byte
     * scheme, so it can't be used here. get_string followed by
     * mp_from_bytes_be accepts anything.
     */
    ptrlen sig_integer = get_string(src);
    if (get_err(src) || !ptrlen_eq_string(sig_type, key->vt->ssh_id))
        return false;

    /* Apply the public-key operation to the signature. */
    mp_int *signature = mp_from_bytes_be(sig_integer);
    mp_int *decoded = mp_modpow(signature, rsa->exponent, rsa->modulus);
    mp_free(signature);

    /*
     * Compare against the block we would have signed. Differences
     * are accumulated across every byte rather than stopping at the
     * first mismatch.
     */
    unsigned char *expected = rsa_pkcs1_signature_string(nbytes, halg, data);
    unsigned mismatch = 0;
    for (size_t k = nbytes; k-- > 0 ;)
        mismatch |= expected[k] ^ mp_get_byte(decoded, nbytes - 1 - k);
    smemclr(expected, nbytes);
    sfree(expected);
    mp_free(decoded);

    return mismatch == 0;
}
791 
/*
 * Sign 'data' with the private part of 'key' and write the SSH-2
 * signature to 'bs': the signature algorithm name as a string, then
 * the signature integer as a length-prefixed big-endian byte string
 * of minimal length.
 *
 * The hash is chosen from 'flags' combined with the key algorithm's
 * own signflags. The key must be long enough for that hash;
 * otherwise the assertion in rsa_pkcs1_signature_string fails.
 */
static void rsa2_sign(ssh_key *key, ptrlen data,
                      unsigned flags, BinarySink *bs)
{
    RSAKey *rsa = container_of(key, RSAKey, sshk);
    const struct ssh2_rsa_extra *extra =
        (const struct ssh2_rsa_extra *)key->vt->extra;

    const char *alg_name;
    const ssh_hashalg *halg =
        rsa2_hash_alg_for_flags(flags | extra->signflags, &alg_name);

    /* Build the padded block and turn it into an integer, wiping
     * the byte form as soon as it has been converted. */
    size_t block_len = (mp_get_nbits(rsa->modulus) + 7) / 8;
    unsigned char *block = rsa_pkcs1_signature_string(block_len, halg, data);
    mp_int *message = mp_from_bytes_be(make_ptrlen(block, block_len));
    smemclr(block, block_len);
    sfree(block);

    mp_int *signature = rsa_privkey_op(message, rsa);
    mp_free(message);

    put_stringz(bs, alg_name);

    /* The integer is written in as few bytes as it needs, most
     * significant byte first. */
    size_t sig_len = (mp_get_nbits(signature) + 7) / 8;
    put_uint32(bs, sig_len);
    for (size_t k = sig_len; k-- > 0 ;)
        put_byte(bs, mp_get_byte(signature, k));

    mp_free(signature);
}
826 
/*
 * Check whether 'key' can generate a signature under 'flags'.
 * Returns NULL if it can, or a string from dupprintf explaining why
 * not.
 *
 * The key algorithm's own signflags are ORed into 'flags' before
 * the hash is chosen, exactly as rsa2_sign does. Without that, a
 * key held under the rsa-sha2-256 or rsa-sha2-512 algorithm and
 * queried with flags == 0 would be checked against SHA-1's smaller
 * length requirement. It would then be reported as valid, and
 * rsa2_sign would fail the assertion in rsa_pkcs1_signature_string.
 */
static char *rsa2_invalid(ssh_key *key, unsigned flags)
{
    RSAKey *rsa = container_of(key, RSAKey, sshk);
    size_t bits = mp_get_nbits(rsa->modulus), nbytes = (bits + 7) / 8;
    const char *sign_alg_name;

    const struct ssh2_rsa_extra *extra =
        (const struct ssh2_rsa_extra *)key->vt->extra;
    flags |= extra->signflags;

    const ssh_hashalg *halg = rsa2_hash_alg_for_flags(flags, &sign_alg_name);
    if (nbytes < rsa_pkcs1_length_of_fixed_parts(halg)) {
        return dupprintf(
            "%"SIZEu"-bit RSA key is too short to generate %s signatures",
            bits, sign_alg_name);
    }

    return NULL;
}
841 
842 static const struct ssh2_rsa_extra
843     rsa_extra = { 0 },
844     rsa_sha256_extra = { SSH_AGENT_RSA_SHA2_256 },
845     rsa_sha512_extra = { SSH_AGENT_RSA_SHA2_512 };
846 
/*
 * Designated initialisers shared by every RSA ssh_keyalg: the
 * method pointers and the cache_id. The expansion has no trailing
 * comma, so the user supplies one.
 */
#define COMMON_KEYALG_FIELDS                    \
    .cache_id = "rsa2",                         \
    .new_pub = rsa2_new_pub,                    \
    .new_priv = rsa2_new_priv,                  \
    .new_priv_openssh = rsa2_new_priv_openssh,  \
    .freekey = rsa2_freekey,                    \
    .public_blob = rsa2_public_blob,            \
    .private_blob = rsa2_private_blob,          \
    .openssh_blob = rsa2_openssh_blob,          \
    .invalid = rsa2_invalid,                    \
    .sign = rsa2_sign,                          \
    .verify = rsa2_verify,                      \
    .cache_str = rsa2_cache_str,                \
    .components = rsa2_components,              \
    .pubkey_bits = rsa2_pubkey_bits
862 
863 const ssh_keyalg ssh_rsa = {
864     COMMON_KEYALG_FIELDS,
865     .ssh_id = "ssh-rsa",
866     .supported_flags = SSH_AGENT_RSA_SHA2_256 | SSH_AGENT_RSA_SHA2_512,
867     .extra = &rsa_extra,
868 };
869 
870 const ssh_keyalg ssh_rsa_sha256 = {
871     COMMON_KEYALG_FIELDS,
872     .ssh_id = "rsa-sha2-256",
873     .supported_flags = 0,
874     .extra = &rsa_sha256_extra,
875 };
876 
877 const ssh_keyalg ssh_rsa_sha512 = {
878     COMMON_KEYALG_FIELDS,
879     .ssh_id = "rsa-sha2-512",
880     .supported_flags = 0,
881     .extra = &rsa_sha512_extra,
882 };
883 
ssh_rsakex_newkey(ptrlen data)884 RSAKey *ssh_rsakex_newkey(ptrlen data)
885 {
886     ssh_key *sshk = rsa2_new_pub(&ssh_rsa, data);
887     if (!sshk)
888         return NULL;
889     return container_of(sshk, RSAKey, sshk);
890 }
891 
/*
 * Free a key returned by ssh_rsakex_newkey, including the RSAKey
 * structure itself (via rsa2_freekey).
 */
void ssh_rsakex_freekey(RSAKey *key)
{
    ssh_key *embedded = &key->sshk;
    rsa2_freekey(embedded);
}
896 
/*
 * Return mp_get_nbits() of the key's modulus, converted to int.
 */
int ssh_rsakex_klen(RSAKey *rsa)
{
    mp_int *n = rsa->modulus;
    return mp_get_nbits(n);
}
901 
oaep_mask(const ssh_hashalg * h,void * seed,int seedlen,void * vdata,int datalen)902 static void oaep_mask(const ssh_hashalg *h, void *seed, int seedlen,
903                       void *vdata, int datalen)
904 {
905     unsigned char *data = (unsigned char *)vdata;
906     unsigned count = 0;
907 
908     ssh_hash *s = ssh_hash_new(h);
909 
910     while (datalen > 0) {
911         int i, max = (datalen > h->hlen ? h->hlen : datalen);
912         unsigned char hash[MAX_HASH_LEN];
913 
914         ssh_hash_reset(s);
915         assert(h->hlen <= MAX_HASH_LEN);
916         put_data(s, seed, seedlen);
917         put_uint32(s, count);
918         ssh_hash_digest(s, hash);
919         count++;
920 
921         for (i = 0; i < max; i++)
922             data[i] ^= hash[i];
923 
924         data += max;
925         datalen -= max;
926     }
927 
928     ssh_hash_free(s);
929 }
930 
ssh_rsakex_encrypt(RSAKey * rsa,const ssh_hashalg * h,ptrlen in)931 strbuf *ssh_rsakex_encrypt(RSAKey *rsa, const ssh_hashalg *h, ptrlen in)
932 {
933     mp_int *b1, *b2;
934     int k, i;
935     char *p;
936     const int HLEN = h->hlen;
937 
938     /*
939      * Here we encrypt using RSAES-OAEP. Essentially this means:
940      *
941      *  - we have a SHA-based `mask generation function' which
942      *    creates a pseudo-random stream of mask data
943      *    deterministically from an input chunk of data.
944      *
945      *  - we have a random chunk of data called a seed.
946      *
947      *  - we use the seed to generate a mask which we XOR with our
948      *    plaintext.
949      *
950      *  - then we use _the masked plaintext_ to generate a mask
951      *    which we XOR with the seed.
952      *
953      *  - then we concatenate the masked seed and the masked
954      *    plaintext, and RSA-encrypt that lot.
955      *
956      * The result is that the data input to the encryption function
957      * is random-looking and (hopefully) contains no exploitable
958      * structure such as PKCS1-v1_5 does.
959      *
960      * For a precise specification, see RFC 3447, section 7.1.1.
961      * Some of the variable names below are derived from that, so
962      * it'd probably help to read it anyway.
963      */
964 
965     /* k denotes the length in octets of the RSA modulus. */
966     k = (7 + mp_get_nbits(rsa->modulus)) / 8;
967 
968     /* The length of the input data must be at most k - 2hLen - 2. */
969     assert(in.len > 0 && in.len <= k - 2*HLEN - 2);
970 
971     /* The length of the output data wants to be precisely k. */
972     strbuf *toret = strbuf_new_nm();
973     int outlen = k;
974     unsigned char *out = strbuf_append(toret, outlen);
975 
976     /*
977      * Now perform EME-OAEP encoding. First set up all the unmasked
978      * output data.
979      */
980     /* Leading byte zero. */
981     out[0] = 0;
982     /* At position 1, the seed: HLEN bytes of random data. */
983     random_read(out + 1, HLEN);
984     /* At position 1+HLEN, the data block DB, consisting of: */
985     /* The hash of the label (we only support an empty label here) */
986     hash_simple(h, PTRLEN_LITERAL(""), out + HLEN + 1);
987     /* A bunch of zero octets */
988     memset(out + 2*HLEN + 1, 0, outlen - (2*HLEN + 1));
989     /* A single 1 octet, followed by the input message data. */
990     out[outlen - in.len - 1] = 1;
991     memcpy(out + outlen - in.len, in.ptr, in.len);
992 
993     /*
994      * Now use the seed data to mask the block DB.
995      */
996     oaep_mask(h, out+1, HLEN, out+HLEN+1, outlen-HLEN-1);
997 
998     /*
999      * And now use the masked DB to mask the seed itself.
1000      */
1001     oaep_mask(h, out+HLEN+1, outlen-HLEN-1, out+1, HLEN);
1002 
1003     /*
1004      * Now `out' contains precisely the data we want to
1005      * RSA-encrypt.
1006      */
1007     b1 = mp_from_bytes_be(make_ptrlen(out, outlen));
1008     b2 = mp_modpow(b1, rsa->exponent, rsa->modulus);
1009     p = (char *)out;
1010     for (i = outlen; i--;) {
1011         *p++ = mp_get_byte(b2, i);
1012     }
1013     mp_free(b1);
1014     mp_free(b2);
1015 
1016     /*
1017      * And we're done.
1018      */
1019     return toret;
1020 }
1021 
ssh_rsakex_decrypt(RSAKey * rsa,const ssh_hashalg * h,ptrlen ciphertext)1022 mp_int *ssh_rsakex_decrypt(
1023     RSAKey *rsa, const ssh_hashalg *h, ptrlen ciphertext)
1024 {
1025     mp_int *b1, *b2;
1026     int outlen, i;
1027     unsigned char *out;
1028     unsigned char labelhash[64];
1029     BinarySource src[1];
1030     const int HLEN = h->hlen;
1031 
1032     /*
1033      * Decryption side of the RSA key exchange operation.
1034      */
1035 
1036     /* The length of the encrypted data should be exactly the length
1037      * in octets of the RSA modulus.. */
1038     outlen = (7 + mp_get_nbits(rsa->modulus)) / 8;
1039     if (ciphertext.len != outlen)
1040         return NULL;
1041 
1042     /* Do the RSA decryption, and extract the result into a byte array. */
1043     b1 = mp_from_bytes_be(ciphertext);
1044     b2 = rsa_privkey_op(b1, rsa);
1045     out = snewn(outlen, unsigned char);
1046     for (i = 0; i < outlen; i++)
1047         out[i] = mp_get_byte(b2, outlen-1-i);
1048     mp_free(b1);
1049     mp_free(b2);
1050 
1051     /* Do the OAEP masking operations, in the reverse order from encryption */
1052     oaep_mask(h, out+HLEN+1, outlen-HLEN-1, out+1, HLEN);
1053     oaep_mask(h, out+1, HLEN, out+HLEN+1, outlen-HLEN-1);
1054 
1055     /* Check the leading byte is zero. */
1056     if (out[0] != 0) {
1057         sfree(out);
1058         return NULL;
1059     }
1060     /* Check the label hash at position 1+HLEN */
1061     assert(HLEN <= lenof(labelhash));
1062     hash_simple(h, PTRLEN_LITERAL(""), labelhash);
1063     if (memcmp(out + HLEN + 1, labelhash, HLEN)) {
1064         sfree(out);
1065         return NULL;
1066     }
1067     /* Expect zero bytes followed by a 1 byte */
1068     for (i = 1 + 2 * HLEN; i < outlen; i++) {
1069         if (out[i] == 1) {
1070             i++;  /* skip over the 1 byte */
1071             break;
1072         } else if (out[i] != 0) {
1073             sfree(out);
1074             return NULL;
1075         }
1076     }
1077     /* And what's left is the input message data, which should be
1078      * encoded as an ordinary SSH-2 mpint. */
1079     BinarySource_BARE_INIT(src, out + i, outlen - i);
1080     b1 = get_mp_ssh2(src);
1081     sfree(out);
1082     if (get_err(src) || get_avail(src) != 0) {
1083         mp_free(b1);
1084         return NULL;
1085     }
1086 
1087     /* Success! */
1088     return b1;
1089 }
1090 
/*
 * Per-method extra data for the two RSA key exchange methods.
 * NOTE(review): the single initialised field looks like the minimum
 * acceptable RSA modulus size in bits, matching the "rsa1024" and
 * "rsa2048" prefixes of the method names below -- confirm against the
 * definition of struct ssh_rsa_kex_extra.
 */
static const struct ssh_rsa_kex_extra ssh_rsa_kex_extra_sha1 = { 1024 };
static const struct ssh_rsa_kex_extra ssh_rsa_kex_extra_sha256 = { 2048 };

/* RSA key exchange method "rsa1024-sha1", paired with the SHA-1 hash. */
static const ssh_kex ssh_rsa_kex_sha1 = {
    "rsa1024-sha1", NULL, KEXTYPE_RSA,
    &ssh_sha1, &ssh_rsa_kex_extra_sha1,
};

/* RSA key exchange method "rsa2048-sha256", paired with the SHA-256 hash. */
static const ssh_kex ssh_rsa_kex_sha256 = {
    "rsa2048-sha256", NULL, KEXTYPE_RSA,
    &ssh_sha256, &ssh_rsa_kex_extra_sha256,
};

/*
 * The RSA key exchange methods, with the SHA-256 variant listed ahead
 * of the SHA-1 one.
 * NOTE(review): presumably list order expresses preference -- confirm
 * with the code that consumes ssh_kexes.
 */
static const ssh_kex *const rsa_kex_list[] = {
    &ssh_rsa_kex_sha256,
    &ssh_rsa_kex_sha1
};

/* Exported handle on the list above: element count, then the array. */
const ssh_kexes ssh_rsa_kex = { lenof(rsa_kex_list), rsa_kex_list };
1110