1# Copyright (C) 2011  Jeff Forcier <jeff@bitprophet.org>
2#
3# This file is part of ssh.
4#
5# 'ssh' is free software; you can redistribute it and/or modify it under the
6# terms of the GNU Lesser General Public License as published by the Free
7# Software Foundation; either version 2.1 of the License, or (at your option)
8# any later version.
9#
# 'ssh' is distributed in the hope that it will be useful, but WITHOUT ANY
11# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
12# A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
13# details.
14#
15# You should have received a copy of the GNU Lesser General Public License
16# along with 'ssh'; if not, write to the Free Software Foundation, Inc.,
17# 51 Franklin Street, Suite 500, Boston, MA  02110-1335  USA.
18
19"""
20L{RSAKey}
21"""
22
23from Crypto.PublicKey import RSA
24from Crypto.Hash import SHA, MD5
25from Crypto.Cipher import DES3
26
27from ssh.common import *
28from ssh import util
29from ssh.message import Message
30from ssh.ber import BER, BERException
31from ssh.pkey import PKey
32from ssh.ssh_exception import SSHException
33
34
class RSAKey (PKey):
    """
    Representation of an RSA key which can be used to sign and verify SSH2
    data.

    A public key carries only C{e} and C{n}.  A private key additionally
    carries C{d} and, when known, the prime factors C{p} and C{q}; the
    factors are required to write the key back out as a private key file.
    """

    def __init__(self, msg=None, data=None, filename=None, password=None, vals=None, file_obj=None):
        """
        Create an RSA key from exactly one source.  Sources are checked in
        this order: C{file_obj}, C{filename}, C{vals}, then C{msg}/C{data}.

        @param msg: SSH2 message positioned at an C{ssh-rsa} public key blob.
        @type msg: L{Message}
        @param data: raw public key blob; wrapped in a L{Message} when C{msg}
            is not given.
        @type data: str
        @param filename: path of a private key file to load.
        @type filename: str
        @param password: passphrase used to decrypt a private key file.
        @type password: str
        @param vals: C{(e, n)} pair used to build a public key directly.
        @type vals: tuple
        @param file_obj: file-like object to read a private key from.
        @type file_obj: file

        @raise SSHException: if no source is given, or the message does not
            start with the C{ssh-rsa} key type.
        """
        self.n = None
        self.e = None
        self.d = None
        self.p = None
        self.q = None
        if file_obj is not None:
            self._from_private_key(file_obj, password)
            return
        if filename is not None:
            self._from_private_key_file(filename, password)
            return
        if (msg is None) and (data is not None):
            msg = Message(data)
        if vals is not None:
            self.e, self.n = vals
        else:
            if msg is None:
                raise SSHException('Key object may not be empty')
            if msg.get_string() != 'ssh-rsa':
                raise SSHException('Invalid key')
            # SSH2 public key blob layout: string "ssh-rsa", mpint e, mpint n
            self.e = msg.get_mpint()
            self.n = msg.get_mpint()
        self.size = util.bit_length(self.n)

    def __str__(self):
        """
        Return the SSH2 public key blob: the string C{ssh-rsa} followed by
        the mpints C{e} and C{n}.
        """
        m = Message()
        m.add_string('ssh-rsa')
        m.add_mpint(self.e)
        m.add_mpint(self.n)
        return str(m)

    def __hash__(self):
        """
        Hash on the key type and the public components C{e} and C{n}.
        """
        h = hash(self.get_name())
        h = h * 37 + hash(self.e)
        h = h * 37 + hash(self.n)
        return hash(h)

    def get_name(self):
        """
        @return: the SSH2 key type name, C{ssh-rsa}.
        @rtype: str
        """
        return 'ssh-rsa'

    def get_bits(self):
        """
        @return: bit length of the modulus C{n}.
        @rtype: int
        """
        return self.size

    def can_sign(self):
        """
        @return: True if the private exponent C{d} is available.
        @rtype: bool
        """
        return self.d is not None

    def sign_ssh_data(self, rpool, data):
        """
        Sign C{data} with this private key.  The SHA-1 digest of the data is
        padded with EMSA-PKCS1-v1_5 and then raised to the private exponent.

        @param rpool: not used by this implementation.
        @param data: the bytes to sign.
        @type data: str
        @return: a message holding the string C{ssh-rsa} followed by the
            signature blob.
        @rtype: L{Message}
        """
        digest = SHA.new(data).digest()
        rsa = RSA.construct((long(self.n), long(self.e), long(self.d)))
        sig = util.deflate_long(rsa.sign(self._pkcs1imify(digest), '')[0], 0)
        m = Message()
        m.add_string('ssh-rsa')
        m.add_string(sig)
        return m

    def verify_ssh_sig(self, data, msg):
        """
        Check an C{ssh-rsa} signature over C{data} using the public key.

        @param data: the bytes that were signed.
        @type data: str
        @param msg: message positioned at the signature (type string, then
            signature blob).
        @type msg: L{Message}
        @return: True if the signature is valid; False if it is invalid or
            its type is not C{ssh-rsa}.
        @rtype: bool
        """
        if msg.get_string() != 'ssh-rsa':
            return False
        sig = util.inflate_long(msg.get_string(), True)
        # verify the signature by SHA'ing the data and encrypting it using the
        # public key.  some wackiness ensues where we "pkcs1imify" the 20-byte
        # hash into a string as long as the RSA key.
        hash_obj = util.inflate_long(self._pkcs1imify(SHA.new(data).digest()), True)
        rsa = RSA.construct((long(self.n), long(self.e)))
        return rsa.verify(hash_obj, (sig,))

    def _encode_key(self):
        """
        BER-encode the private key as an RSAPrivateKey sequence.

        @raise SSHException: if C{d}, C{p} or C{q} is missing, or BER
            encoding fails.
        """
        # d is needed too: without it the CRT exponents below would fail
        # with a bare TypeError instead of a meaningful error.
        if (self.d is None) or (self.p is None) or (self.q is None):
            raise SSHException('Not enough key info to write private key file')
        # RSAPrivateKey = { version = 0, n, e, d, p, q,
        #                   d mod p-1, d mod q-1, q**-1 mod p }
        keylist = [ 0, self.n, self.e, self.d, self.p, self.q,
                    self.d % (self.p - 1), self.d % (self.q - 1),
                    util.mod_inverse(self.q, self.p) ]
        try:
            b = BER()
            b.encode(keylist)
        except BERException:
            raise SSHException('Unable to create ber encoding of key')
        return str(b)

    def write_private_key_file(self, filename, password=None):
        """
        Write this private key to C{filename}, encrypted with C{password}
        when one is given.
        """
        self._write_private_key_file('RSA', filename, self._encode_key(), password)

    def write_private_key(self, file_obj, password=None):
        """
        Write this private key to the file-like object C{file_obj},
        encrypted with C{password} when one is given.
        """
        self._write_private_key('RSA', file_obj, self._encode_key(), password)

    @staticmethod
    def generate(bits, progress_func=None):
        """
        Generate a new private RSA key.  This factory function can be used to
        generate a new host key or authentication key.

        @param bits: number of bits the generated key should be.
        @type bits: int
        @param progress_func: an optional function to call at key points in
            key generation (used by C{pyCrypto.PublicKey}).
        @type progress_func: function
        @return: new private key
        @rtype: L{RSAKey}
        """
        rsa = RSA.generate(bits, rng.read, progress_func)
        key = RSAKey(vals=(rsa.e, rsa.n))
        key.d = rsa.d
        key.p = rsa.p
        key.q = rsa.q
        return key


    ###  internals...


    def _pkcs1imify(self, data):
        """
        turn a 20-byte SHA1 hash into a blob of data as large as the key's N,
        using PKCS1's \"emsa-pkcs1-v1_5\" encoding.  totally bizarre.
        """
        # DER-encoded DigestInfo prefix identifying SHA-1 (OID 1.3.14.3.2.26)
        SHA1_DIGESTINFO = '\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'
        size = len(util.deflate_long(self.n, 0))
        filler = '\xff' * (size - len(SHA1_DIGESTINFO) - len(data) - 3)
        return '\x00\x01' + filler + '\x00' + SHA1_DIGESTINFO + data

    def _from_private_key_file(self, filename, password):
        """
        Load the private key stored in C{filename}.
        """
        data = self._read_private_key_file('RSA', filename, password)
        self._decode_key(data)

    def _from_private_key(self, file_obj, password):
        """
        Load the private key read from the file-like object C{file_obj}.
        """
        data = self._read_private_key('RSA', file_obj, password)
        self._decode_key(data)

    def _decode_key(self, data):
        """
        Populate this key from a BER-encoded RSAPrivateKey sequence.

        @raise SSHException: if C{data} cannot be parsed, or is not a
            version-0 sequence with at least C{n}, C{e} and C{d}.
        """
        # private key file contains:
        # RSAPrivateKey = { version = 0, n, e, d, p, q, d mod p-1, d mod q-1, q**-1 mod p }
        try:
            keylist = BER(data).decode()
        except BERException:
            raise SSHException('Unable to parse key file')
        if (not isinstance(keylist, list)) or (len(keylist) < 4) or (keylist[0] != 0):
            raise SSHException('Not a valid RSA private key file (bad ber encoding)')
        self.n = keylist[1]
        self.e = keylist[2]
        self.d = keylist[3]
        # p and q are only needed to re-encode the key (see _encode_key).
        # The length check above admits sequences shorter than 6, so read
        # them only when present instead of crashing with IndexError.
        if len(keylist) >= 6:
            self.p = keylist[4]
            self.q = keylist[5]
        else:
            self.p = None
            self.q = None
        self.size = util.bit_length(self.n)
186