# Copyright (C) 2006-2007 Jeff Forcier <jeff@bitprophet.org>
#
# This file is part of ssh.
#
# 'ssh' is free software; you can redistribute it and/or modify it under the
# terms of the GNU Lesser General Public License as published by the Free
# Software Foundation; either version 2.1 of the License, or (at your option)
# any later version.
#
# 'ssh' is distrubuted in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
# details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with 'ssh'; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Suite 500, Boston, MA 02110-1335 USA.

"""
L{HostKeys} and L{HostKeyEntry}: reading, querying and writing
OpenSSH-style "known hosts" files.
"""

import base64
import binascii
import sys
import UserDict

from Crypto.Hash import SHA, HMAC

from ssh.common import *
from ssh.dsskey import DSSKey
from ssh.rsakey import RSAKey


class InvalidHostKey(Exception):
    """
    Raised when the key field of a known_hosts line cannot be decoded.

    @ivar line: the offending known_hosts line
    @ivar exc: the exception raised while decoding the key field
    """

    def __init__(self, line, exc):
        self.line = line
        self.exc = exc
        self.args = (line, exc)


class HostKeyEntry:
    """
    Representation of a line in an OpenSSH-style "known hosts" file.
    """

    def __init__(self, hostnames=None, key=None):
        # An entry is only serializable (see to_line) when it has both
        # hostnames and a key.  This flag is computed once here; it is not
        # refreshed if self.key is reassigned later.
        self.valid = (hostnames is not None) and (key is not None)
        self.hostnames = hostnames
        self.key = key

    def from_line(cls, line):
        """
        Parses the given line of text to find the names for the host,
        the type of key, and the key data. The line is expected to be in the
        format used by the openssh known_hosts file.

        Lines are expected to not have leading or trailing whitespace.
        We don't bother to check for comments or empty lines.  All of
        that should be taken care of before sending the line to us.

        Fields may be separated by any run of spaces or tabs; anything after
        the third field (such as a trailing comment) is ignored.

        @param line: a line from an OpenSSH known_hosts file
        @type line: str
        @return: a new entry, or C{None} if the line has fewer than three
            fields or a key type other than C{ssh-rsa} / C{ssh-dss}
        @rtype: L{HostKeyEntry}

        @raise InvalidHostKey: if the key field is not valid base64
        """
        # split() with no separator collapses runs of whitespace, so lines
        # that use tabs or several spaces between fields still parse.
        fields = line.split()
        if len(fields) < 3:
            # Bad number of fields
            return None
        names, keytype, key = fields[:3]
        names = names.split(',')

        # Decide what kind of key we're looking at and create an object
        # to hold it accordingly.
        try:
            if keytype == 'ssh-rsa':
                key = RSAKey(data=base64.decodestring(key))
            elif keytype == 'ssh-dss':
                key = DSSKey(data=base64.decodestring(key))
            else:
                return None
        except binascii.Error:
            # sys.exc_info() instead of "except ..., e" keeps this module
            # parseable by every Python version.
            raise InvalidHostKey(line, sys.exc_info()[1])

        return cls(names, key)
    from_line = classmethod(from_line)

    def to_line(self):
        """
        Returns a string in OpenSSH known_hosts file format, or None if
        the object is not in a valid state.  A trailing newline is
        included.  All hostnames of the entry are comma-joined on one line.
        """
        if self.valid:
            return '%s %s %s\n' % (','.join(self.hostnames), self.key.get_name(),
                                   self.key.get_base64())
        return None

    def __repr__(self):
        return '<HostKeyEntry %r: %r>' % (self.hostnames, self.key)


class HostKeys (UserDict.DictMixin):
    """
    Representation of an openssh-style "known hosts" file.  Host keys can be
    read from one or more files, and then individual hosts can be looked up to
    verify server keys during SSH negotiation.

    A HostKeys object can be treated like a dict; any dict lookup is equivalent
    to calling L{lookup}.

    @since: 1.5.3
    """

    def __init__(self, filename=None):
        """
        Create a new HostKeys object, optionally loading keys from an openssh
        style host-key file.

        @param filename: filename to load host keys from, or C{None}
        @type filename: str
        """
        # Flat list of HostKeyEntry objects in load/insertion order; the
        # dict-like view { hostname: { keytype: PKey } } is built on demand
        # by lookup().
        self._entries = []
        if filename is not None:
            self.load(filename)

    def add(self, hostname, keytype, key):
        """
        Add a host key entry to the table.  Any existing entry for a
        C{(hostname, keytype)} pair will be replaced.

        @param hostname: the hostname (or IP) to add
        @type hostname: str
        @param keytype: key type (C{"ssh-rsa"} or C{"ssh-dss"})
        @type keytype: str
        @param key: the key to add
        @type key: L{PKey}
        """
        for e in self._entries:
            # Placeholder entries (created by __setitem__ with an empty dict)
            # have no key and cannot match a key type.
            if e.key is None:
                continue
            if (hostname in e.hostnames) and (e.key.get_name() == keytype):
                e.key = key
                return
        self._entries.append(HostKeyEntry([hostname], key))

    def load(self, filename):
        """
        Read a file of known SSH host keys, in the format used by openssh.
        This type of file unfortunately doesn't exist on Windows, but on
        posix, it will usually be stored in
        C{os.path.expanduser("~/.ssh/known_hosts")}.

        If this method is called multiple times, the host keys are merged,
        not cleared: every parsed line is appended as a new entry, and
        existing entries are neither replaced nor de-duplicated.  Blank
        lines, comment lines (starting with C{#}), lines with fewer than
        three fields and lines with an unsupported key type are skipped.

        @param filename: name of the file to read host keys from
        @type filename: str

        @raise IOError: if there was an error reading the file
        @raise InvalidHostKey: if a line's key field cannot be decoded
        """
        f = open(filename, 'r')
        try:
            for line in f:
                line = line.strip()
                if (len(line) == 0) or (line[0] == '#'):
                    continue
                e = HostKeyEntry.from_line(line)
                if e is not None:
                    self._entries.append(e)
        finally:
            # Close the file even when from_line raises InvalidHostKey.
            f.close()

    def save(self, filename):
        """
        Save host keys into a file, in the format used by openssh.  Entries
        are written in the order they are held, so the order of keys loaded
        from a file is preserved.  Each entry becomes one line with all of
        its hostnames comma-joined; entries that are not valid (see
        L{HostKeyEntry.to_line}) are skipped.

        @param filename: name of the file to write
        @type filename: str

        @raise IOError: if there was an error writing the file

        @since: 1.6.1
        """
        f = open(filename, 'w')
        try:
            for e in self._entries:
                line = e.to_line()
                if line:
                    f.write(line)
        finally:
            f.close()

    def lookup(self, hostname):
        """
        Find a hostkey entry for a given hostname or IP.  If no entry is found,
        C{None} is returned.  Otherwise a dictionary of keytype to key is
        returned.  The keytype will be either C{"ssh-rsa"} or C{"ssh-dss"}.

        Hashed entries (C{|1|salt|hash}) are matched by re-hashing
        C{hostname} with the entry's salt.

        @param hostname: the hostname (or IP) to lookup
        @type hostname: str
        @return: keys associated with this host (or C{None})
        @rtype: dict(str, L{PKey})
        """
        class SubDict (UserDict.DictMixin):
            # Dict-like view over the entries that matched one hostname.
            # Additions made through it are also appended to the owning
            # HostKeys object.
            def __init__(self, hostname, entries, hostkeys):
                self._hostname = hostname
                self._entries = entries
                self._hostkeys = hostkeys

            def __getitem__(self, key):
                for e in self._entries:
                    # skip keyless placeholder entries
                    if (e.key is not None) and (e.key.get_name() == key):
                        return e.key
                raise KeyError(key)

            def __setitem__(self, key, val):
                for e in self._entries:
                    if e.key is None:
                        continue
                    if e.key.get_name() == key:
                        # replace
                        e.key = val
                        break
                else:
                    # add a new one
                    e = HostKeyEntry([self._hostname], val)
                    self._entries.append(e)
                    self._hostkeys._entries.append(e)

            def keys(self):
                return [e.key.get_name() for e in self._entries
                        if e.key is not None]

        entries = []
        for e in self._entries:
            for h in e.hostnames:
                if (h.startswith('|1|') and (self.hash_host(hostname, h) == h)) or (h == hostname):
                    entries.append(e)
                    # One matching name is enough; without this break an
                    # entry listing the host twice would be added twice.
                    break
        if len(entries) == 0:
            return None
        return SubDict(hostname, entries, self)

    def check(self, hostname, key):
        """
        Return True if the given key is associated with the given hostname
        in this dictionary.

        @param hostname: hostname (or IP) of the SSH server
        @type hostname: str
        @param key: the key to check
        @type key: L{PKey}
        @return: C{True} if the key is associated with the hostname; C{False}
            if not
        @rtype: bool
        """
        k = self.lookup(hostname)
        if k is None:
            return False
        host_key = k.get(key.get_name(), None)
        if host_key is None:
            return False
        # Keys are compared by their serialized form.
        return str(host_key) == str(key)

    def clear(self):
        """
        Remove all host keys from the dictionary.
        """
        self._entries = []

    def __getitem__(self, key):
        ret = self.lookup(key)
        if ret is None:
            raise KeyError(key)
        return ret

    def __setitem__(self, hostname, entry):
        # don't use this please.
        # `entry` maps keytype -> PKey.  An empty mapping records a keyless
        # placeholder entry for `hostname`; otherwise every existing entry
        # for (hostname, keytype) is updated, or a new entry is appended.
        if len(entry) == 0:
            self._entries.append(HostKeyEntry([hostname], None))
            return
        for key_type in entry.keys():
            found = False
            for e in self._entries:
                # keyless placeholders cannot match a key type
                if e.key is None:
                    continue
                if (hostname in e.hostnames) and (e.key.get_name() == key_type):
                    # replace
                    e.key = entry[key_type]
                    found = True
            if not found:
                self._entries.append(HostKeyEntry([hostname], entry[key_type]))

    def keys(self):
        # Every distinct hostname as stored (hashed names included), in
        # first-seen order.  A dict serves as the "seen" set for O(1)
        # membership tests without relying on the set builtin.
        ret = []
        seen = {}
        for e in self._entries:
            for h in e.hostnames:
                if h not in seen:
                    seen[h] = True
                    ret.append(h)
        return ret

    def values(self):
        ret = []
        for k in self.keys():
            ret.append(self.lookup(k))
        return ret

    def hash_host(hostname, salt=None):
        """
        Return a "hashed" form of the hostname, as used by openssh when storing
        hashed hostnames in the known_hosts file: an HMAC-SHA1 of the
        hostname keyed with the salt.

        @param hostname: the hostname to hash
        @type hostname: str
        @param salt: optional salt to use when hashing.  Either a
            base64-encoded salt or an existing hashed hostname of the form
            C{|1|salt|hash} (whose salt is reused).  Once decoded it must be
            20 bytes long (the SHA-1 digest size).  If C{None}, a random
            salt is generated.
        @type salt: str
        @return: the hashed hostname, as C{|1|<base64 salt>|<base64 hmac>}
        @rtype: str
        """
        if salt is None:
            # `rng` presumably comes from the ssh.common wildcard import.
            salt = rng.read(SHA.digest_size)
        else:
            if salt.startswith('|1|'):
                salt = salt.split('|')[2]
            salt = base64.decodestring(salt)
        assert len(salt) == SHA.digest_size
        hmac = HMAC.HMAC(salt, hostname, SHA).digest()
        hostkey = '|1|%s|%s' % (base64.encodestring(salt), base64.encodestring(hmac))
        # encodestring() inserts newlines; strip them for a one-line result.
        return hostkey.replace('\n', '')
    hash_host = staticmethod(hash_host)