1# Copyright (C) 2006-2007  Jeff Forcier <jeff@bitprophet.org>
2#
3# This file is part of ssh.
4#
5# 'ssh' is free software; you can redistribute it and/or modify it under the
6# terms of the GNU Lesser General Public License as published by the Free
7# Software Foundation; either version 2.1 of the License, or (at your option)
8# any later version.
9#
# 'ssh' is distributed in the hope that it will be useful, but WITHOUT ANY
11# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
12# A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
13# details.
14#
15# You should have received a copy of the GNU Lesser General Public License
16# along with 'ssh'; if not, write to the Free Software Foundation, Inc.,
17# 51 Franklin Street, Suite 500, Boston, MA  02110-1335  USA.
18
19"""
20L{HostKeys}
21"""
22
23import base64
24import binascii
25from Crypto.Hash import SHA, HMAC
26import UserDict
27
28from ssh.common import *
29from ssh.dsskey import DSSKey
30from ssh.rsakey import RSAKey
31
32
class InvalidHostKey(Exception):
    """
    Signals that a line of a known_hosts file could not be turned into a
    host key.

    @ivar line: the text of the offending line
    @ivar exc: the exception raised while decoding the key data
    """

    def __init__(self, line, exc):
        # Let the base class record (line, exc) as the exception's args.
        Exception.__init__(self, line, exc)
        self.line, self.exc = line, exc
39
40
class HostKeyEntry:
    """
    Representation of a line in an OpenSSH-style "known hosts" file.

    @ivar hostnames: the host names (possibly hashed, C{"|1|..."}) listed on
        the line
    @type hostnames: list(str)
    @ivar key: the host key, or C{None}
    @type key: L{PKey}
    @ivar valid: C{True} when both C{hostnames} and C{key} were passed to the
        constructor; it is computed once and not updated if those attributes
        change later.  Only valid entries are serialized by L{to_line}.
    @type valid: bool
    """

    def __init__(self, hostnames=None, key=None):
        self.valid = (hostnames is not None) and (key is not None)
        self.hostnames = hostnames
        self.key = key

    def from_line(cls, line):
        """
        Parses the given line of text to find the names for the host,
        the type of key, and the key data. The line is expected to be in the
        format used by the openssh known_hosts file: a comma-separated list
        of host names, the key type and the base64-encoded key, separated by
        whitespace.  Anything after the third field (e.g. a trailing comment)
        is ignored.

        Comment lines are not detected here; callers should filter them out
        (and strip the line) before calling.  A blank line yields C{None}.

        @param line: a line from an OpenSSH known_hosts file
        @type line: str
        @return: a new entry, or C{None} if the line has fewer than three
            fields or its key type is neither C{"ssh-rsa"} nor C{"ssh-dss"}
        @rtype: L{HostKeyEntry}

        @raise InvalidHostKey: if decoding the key data raises
            C{binascii.Error}
        """
        # split() with no argument treats any run of spaces/tabs as a single
        # separator; split(' ') produced empty fields for e.g. "host  ssh-rsa"
        # or tab-separated lines, and such lines were silently dropped.
        fields = line.split()
        if len(fields) < 3:
            # Bad number of fields
            return None

        names, keytype, key = fields[:3]
        names = names.split(',')

        # Decide what kind of key we're looking at and create an object
        # to hold it accordingly.
        try:
            if keytype == 'ssh-rsa':
                key = RSAKey(data=base64.decodestring(key))
            elif keytype == 'ssh-dss':
                key = DSSKey(data=base64.decodestring(key))
            else:
                return None
        except binascii.Error:
            # sys.exc_info() is valid on every Python version, unlike either
            # the "except X, e" or the "except X as e" spelling.
            import sys
            raise InvalidHostKey(line, sys.exc_info()[1])

        return cls(names, key)
    from_line = classmethod(from_line)

    def to_line(self):
        """
        Returns a string in OpenSSH known_hosts file format, or None if
        the object is not in a valid state.  A trailing newline is
        included.  All host names of the entry are written comma-joined on
        the single line.

        @rtype: str
        """
        if self.valid:
            return '%s %s %s\n' % (','.join(self.hostnames), self.key.get_name(),
                   self.key.get_base64())
        return None

    def __repr__(self):
        return '<HostKeyEntry %r: %r>' % (self.hostnames, self.key)
101
102
class HostKeys (UserDict.DictMixin):
    """
    Representation of an openssh-style "known hosts" file.  Host keys can be
    read from one or more files, and then individual hosts can be looked up to
    verify server keys during SSH negotiation.

    A HostKeys object can be treated like a dict; any dict lookup is equivalent
    to calling L{lookup}.

    @since: 1.5.3
    """

    def __init__(self, filename=None):
        """
        Create a new HostKeys object, optionally loading keys from an openssh
        style host-key file.

        @param filename: filename to load host keys from, or C{None}
        @type filename: str
        """
        # emulate a dict of { hostname: { keytype: PKey } } on top of an
        # ordered list of HostKeyEntry objects
        self._entries = []
        if filename is not None:
            self.load(filename)

    def add(self, hostname, keytype, key):
        """
        Add a host key entry to the table.  If an existing entry lists
        C{hostname} with a key of type C{keytype}, the first such entry has
        its key replaced; otherwise a new single-host entry is appended.

        @param hostname: the hostname (or IP) to add
        @type hostname: str
        @param keytype: key type (C{"ssh-rsa"} or C{"ssh-dss"})
        @type keytype: str
        @param key: the key to add
        @type key: L{PKey}
        """
        for e in self._entries:
            # keyless placeholder entries (from self[host] = {}) have no type
            if e.key is None:
                continue
            if (hostname in e.hostnames) and (e.key.get_name() == keytype):
                e.key = key
                return
        self._entries.append(HostKeyEntry([hostname], key))

    def load(self, filename):
        """
        Read a file of known SSH host keys, in the format used by openssh.
        This type of file unfortunately doesn't exist on Windows, but on
        posix, it will usually be stored in
        C{os.path.expanduser("~/.ssh/known_hosts")}.

        If this method is called multiple times, the host keys are merged,
        not cleared: each usable line is appended as a new entry and existing
        entries are left untouched (no de-duplication is done).  When a host
        ends up with several keys of the same type, L{lookup} reports the one
        that was loaded first.

        Blank lines, lines starting with C{#}, and lines for which
        L{HostKeyEntry.from_line} returns C{None} are skipped.

        @param filename: name of the file to read host keys from
        @type filename: str

        @raise IOError: if there was an error reading the file
        @raise InvalidHostKey: if a line's key data cannot be decoded; entries
            read before that line have already been added
        """
        f = open(filename, 'r')
        try:
            for line in f:
                line = line.strip()
                if (len(line) == 0) or (line[0] == '#'):
                    continue
                e = HostKeyEntry.from_line(line)
                if e is not None:
                    self._entries.append(e)
        finally:
            # close even if reading or parsing raises
            f.close()

    def save(self, filename):
        """
        Save host keys into a file, in the format used by openssh.  Entries
        are written in their current order (the file order, for keys that
        were loaded from a file), one line per entry; an entry that listed
        several host names is written back as a single combined line.
        Comments, blank lines and any text after the key on a loaded line are
        not preserved, and entries that L{HostKeyEntry.to_line} cannot
        serialize (e.g. entries without a key) are skipped.

        @param filename: name of the file to write
        @type filename: str

        @raise IOError: if there was an error writing the file

        @since: 1.6.1
        """
        f = open(filename, 'w')
        try:
            for e in self._entries:
                line = e.to_line()
                if line:
                    f.write(line)
        finally:
            # close (and flush) even if a write raises
            f.close()

    def lookup(self, hostname):
        """
        Find a hostkey entry for a given hostname or IP.  If no entry is found,
        C{None} is returned.  Otherwise a dict-like object mapping keytype to
        key is returned; assigning into it updates this HostKeys object.
        Hashed names (C{"|1|..."}) are matched by hashing C{hostname} with the
        stored salt.

        @param hostname: the hostname (or IP) to lookup
        @type hostname: str
        @return: keys associated with this host (or C{None})
        @rtype: dict(str, L{PKey})
        """
        class SubDict (UserDict.DictMixin):
            def __init__(self, hostname, entries, hostkeys):
                self._hostname = hostname
                self._entries = entries
                self._hostkeys = hostkeys

            def __getitem__(self, key):
                for e in self._entries:
                    # skip keyless placeholders, as __setitem__/keys() do
                    if e.key is None:
                        continue
                    if e.key.get_name() == key:
                        return e.key
                raise KeyError(key)

            def __setitem__(self, key, val):
                for e in self._entries:
                    if e.key is None:
                        continue
                    if e.key.get_name() == key:
                        # replace
                        e.key = val
                        break
                else:
                    # add a new one, visible both here and in the parent
                    e = HostKeyEntry([self._hostname], val)
                    self._entries.append(e)
                    self._hostkeys._entries.append(e)

            def keys(self):
                return [e.key.get_name() for e in self._entries if e.key is not None]

        entries = []
        for e in self._entries:
            for h in e.hostnames:
                if (h.startswith('|1|') and (self.hash_host(hostname, h) == h)) or (h == hostname):
                    entries.append(e)
                    # one match per entry is enough; an entry naming the host
                    # twice (e.g. hashed and in clear) must not be added twice
                    break
        if len(entries) == 0:
            return None
        return SubDict(hostname, entries, self)

    def check(self, hostname, key):
        """
        Return True if the given key is associated with the given hostname
        in this dictionary.

        @param hostname: hostname (or IP) of the SSH server
        @type hostname: str
        @param key: the key to check
        @type key: L{PKey}
        @return: C{True} if the key is associated with the hostname; C{False}
            if not
        @rtype: bool
        """
        k = self.lookup(hostname)
        if k is None:
            return False
        host_key = k.get(key.get_name(), None)
        if host_key is None:
            return False
        return str(host_key) == str(key)

    def clear(self):
        """
        Remove all host keys from the dictionary.
        """
        self._entries = []

    def __getitem__(self, key):
        ret = self.lookup(key)
        if ret is None:
            raise KeyError(key)
        return ret

    def __setitem__(self, hostname, entry):
        # don't use this please.
        # An empty mapping records the host with no key (placeholder entry).
        if len(entry) == 0:
            self._entries.append(HostKeyEntry([hostname], None))
            return
        for key_type in entry.keys():
            found = False
            for e in self._entries:
                # placeholder entries have no key type to compare
                if e.key is None:
                    continue
                if (hostname in e.hostnames) and (e.key.get_name() == key_type):
                    # replace
                    e.key = entry[key_type]
                    found = True
            if not found:
                self._entries.append(HostKeyEntry([hostname], entry[key_type]))

    def keys(self):
        # Every host name in entry order, without duplicates; the dict makes
        # the membership test constant-time while the list keeps the order.
        ret = []
        seen = {}
        for e in self._entries:
            for h in e.hostnames:
                if h not in seen:
                    seen[h] = True
                    ret.append(h)
        return ret

    def values(self):
        # one lookup() result per host name reported by keys()
        ret = []
        for k in self.keys():
            ret.append(self.lookup(k))
        return ret

    def hash_host(hostname, salt=None):
        """
        Return a "hashed" form of the hostname, as used by openssh when storing
        hashed hostnames in the known_hosts file.

        @param hostname: the hostname to hash
        @type hostname: str
        @param salt: optional base64-encoded salt, or an existing
            C{"|1|salt|hash"} string whose salt should be reused; it must
            decode to C{SHA.digest_size} bytes.  When C{None}, a fresh salt is
            read from C{rng}.
        @type salt: str
        @return: the hashed hostname, in the form C{"|1|salt|hash"}
        @rtype: str
        """
        if salt is None:
            salt = rng.read(SHA.digest_size)
        else:
            if salt.startswith('|1|'):
                salt = salt.split('|')[2]
            salt = base64.decodestring(salt)
        assert len(salt) == SHA.digest_size
        hmac = HMAC.HMAC(salt, hostname, SHA).digest()
        hostkey = '|1|%s|%s' % (base64.encodestring(salt), base64.encodestring(hmac))
        # encodestring() inserts newlines; the result must be a single token
        return hostkey.replace('\n', '')
    hash_host = staticmethod(hash_host)
328
329