1# Copyright (C) 2006-2007 Jeff Forcier <jeff@bitprophet.org> 2# 3# This file is part of ssh. 4# 5# 'ssh' is free software; you can redistribute it and/or modify it under the 6# terms of the GNU Lesser General Public License as published by the Free 7# Software Foundation; either version 2.1 of the License, or (at your option) 8# any later version. 9# 10# 'ssh' is distrubuted in the hope that it will be useful, but WITHOUT ANY 11# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR 12# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more 13# details. 14# 15# You should have received a copy of the GNU Lesser General Public License 16# along with 'ssh'; if not, write to the Free Software Foundation, Inc., 17# 51 Franklin Street, Suite 500, Boston, MA 02110-1335 USA. 18 19""" 20L{SSHClient}. 21""" 22 23from binascii import hexlify 24import getpass 25import os 26import socket 27import warnings 28 29from ssh.agent import Agent 30from ssh.common import * 31from ssh.dsskey import DSSKey 32from ssh.hostkeys import HostKeys 33from ssh.resource import ResourceManager 34from ssh.rsakey import RSAKey 35from ssh.ssh_exception import SSHException, BadHostKeyException 36from ssh.transport import Transport 37from ssh.util import retry_on_signal 38 39 40SSH_PORT = 22 41 42class MissingHostKeyPolicy (object): 43 """ 44 Interface for defining the policy that L{SSHClient} should use when the 45 SSH server's hostname is not in either the system host keys or the 46 application's keys. Pre-made classes implement policies for automatically 47 adding the key to the application's L{HostKeys} object (L{AutoAddPolicy}), 48 and for automatically rejecting the key (L{RejectPolicy}). 49 50 This function may be used to ask the user to verify the key, for example. 51 """ 52 53 def missing_host_key(self, client, hostname, key): 54 """ 55 Called when an L{SSHClient} receives a server key for a server that 56 isn't in either the system or local L{HostKeys} object. To accept 57 the key, simply return. To reject, raised an exception (which will 58 be passed to the calling application). 59 """ 60 pass 61 62 63class AutoAddPolicy (MissingHostKeyPolicy): 64 """ 65 Policy for automatically adding the hostname and new host key to the 66 local L{HostKeys} object, and saving it. This is used by L{SSHClient}. 67 """ 68 69 def missing_host_key(self, client, hostname, key): 70 client._host_keys.add(hostname, key.get_name(), key) 71 if client._host_keys_filename is not None: 72 client.save_host_keys(client._host_keys_filename) 73 client._log(DEBUG, 'Adding %s host key for %s: %s' % 74 (key.get_name(), hostname, hexlify(key.get_fingerprint()))) 75 76 77class RejectPolicy (MissingHostKeyPolicy): 78 """ 79 Policy for automatically rejecting the unknown hostname & key. This is 80 used by L{SSHClient}. 81 """ 82 83 def missing_host_key(self, client, hostname, key): 84 client._log(DEBUG, 'Rejecting %s host key for %s: %s' % 85 (key.get_name(), hostname, hexlify(key.get_fingerprint()))) 86 raise SSHException('Server %r not found in known_hosts' % hostname) 87 88 89class WarningPolicy (MissingHostKeyPolicy): 90 """ 91 Policy for logging a python-style warning for an unknown host key, but 92 accepting it. This is used by L{SSHClient}. 93 """ 94 def missing_host_key(self, client, hostname, key): 95 warnings.warn('Unknown %s host key for %s: %s' % 96 (key.get_name(), hostname, hexlify(key.get_fingerprint()))) 97 98 99class SSHClient (object): 100 """ 101 A high-level representation of a session with an SSH server. This class 102 wraps L{Transport}, L{Channel}, and L{SFTPClient} to take care of most 103 aspects of authenticating and opening channels. A typical use case is:: 104 105 client = SSHClient() 106 client.load_system_host_keys() 107 client.connect('ssh.example.com') 108 stdin, stdout, stderr = client.exec_command('ls -l') 109 110 You may pass in explicit overrides for authentication and server host key 111 checking. The default mechanism is to try to use local key files or an 112 SSH agent (if one is running). 113 114 @since: 1.6 115 """ 116 117 def __init__(self): 118 """ 119 Create a new SSHClient. 120 """ 121 self._system_host_keys = HostKeys() 122 self._host_keys = HostKeys() 123 self._host_keys_filename = None 124 self._log_channel = None 125 self._policy = RejectPolicy() 126 self._transport = None 127 self._agent = None 128 129 def load_system_host_keys(self, filename=None): 130 """ 131 Load host keys from a system (read-only) file. Host keys read with 132 this method will not be saved back by L{save_host_keys}. 133 134 This method can be called multiple times. Each new set of host keys 135 will be merged with the existing set (new replacing old if there are 136 conflicts). 137 138 If C{filename} is left as C{None}, an attempt will be made to read 139 keys from the user's local "known hosts" file, as used by OpenSSH, 140 and no exception will be raised if the file can't be read. This is 141 probably only useful on posix. 142 143 @param filename: the filename to read, or C{None} 144 @type filename: str 145 146 @raise IOError: if a filename was provided and the file could not be 147 read 148 """ 149 if filename is None: 150 # try the user's .ssh key file, and mask exceptions 151 filename = os.path.expanduser('~/.ssh/known_hosts') 152 try: 153 self._system_host_keys.load(filename) 154 except IOError: 155 pass 156 return 157 self._system_host_keys.load(filename) 158 159 def load_host_keys(self, filename): 160 """ 161 Load host keys from a local host-key file. Host keys read with this 162 method will be checked I{after} keys loaded via L{load_system_host_keys}, 163 but will be saved back by L{save_host_keys} (so they can be modified). 164 The missing host key policy L{AutoAddPolicy} adds keys to this set and 165 saves them, when connecting to a previously-unknown server. 166 167 This method can be called multiple times. Each new set of host keys 168 will be merged with the existing set (new replacing old if there are 169 conflicts). When automatically saving, the last hostname is used. 170 171 @param filename: the filename to read 172 @type filename: str 173 174 @raise IOError: if the filename could not be read 175 """ 176 self._host_keys_filename = filename 177 self._host_keys.load(filename) 178 179 def save_host_keys(self, filename): 180 """ 181 Save the host keys back to a file. Only the host keys loaded with 182 L{load_host_keys} (plus any added directly) will be saved -- not any 183 host keys loaded with L{load_system_host_keys}. 184 185 @param filename: the filename to save to 186 @type filename: str 187 188 @raise IOError: if the file could not be written 189 """ 190 f = open(filename, 'w') 191 f.write('# SSH host keys collected by ssh\n') 192 for hostname, keys in self._host_keys.iteritems(): 193 for keytype, key in keys.iteritems(): 194 f.write('%s %s %s\n' % (hostname, keytype, key.get_base64())) 195 f.close() 196 197 def get_host_keys(self): 198 """ 199 Get the local L{HostKeys} object. This can be used to examine the 200 local host keys or change them. 201 202 @return: the local host keys 203 @rtype: L{HostKeys} 204 """ 205 return self._host_keys 206 207 def set_log_channel(self, name): 208 """ 209 Set the channel for logging. The default is C{"ssh.transport"} 210 but it can be set to anything you want. 211 212 @param name: new channel name for logging 213 @type name: str 214 """ 215 self._log_channel = name 216 217 def set_missing_host_key_policy(self, policy): 218 """ 219 Set the policy to use when connecting to a server that doesn't have a 220 host key in either the system or local L{HostKeys} objects. The 221 default policy is to reject all unknown servers (using L{RejectPolicy}). 222 You may substitute L{AutoAddPolicy} or write your own policy class. 223 224 @param policy: the policy to use when receiving a host key from a 225 previously-unknown server 226 @type policy: L{MissingHostKeyPolicy} 227 """ 228 self._policy = policy 229 230 def connect(self, hostname, port=SSH_PORT, username=None, password=None, pkey=None, 231 key_filename=None, timeout=None, allow_agent=True, look_for_keys=True, 232 compress=False): 233 """ 234 Connect to an SSH server and authenticate to it. The server's host key 235 is checked against the system host keys (see L{load_system_host_keys}) 236 and any local host keys (L{load_host_keys}). If the server's hostname 237 is not found in either set of host keys, the missing host key policy 238 is used (see L{set_missing_host_key_policy}). The default policy is 239 to reject the key and raise an L{SSHException}. 240 241 Authentication is attempted in the following order of priority: 242 243 - The C{pkey} or C{key_filename} passed in (if any) 244 - Any key we can find through an SSH agent 245 - Any "id_rsa" or "id_dsa" key discoverable in C{~/.ssh/} 246 - Plain username/password auth, if a password was given 247 248 If a private key requires a password to unlock it, and a password is 249 passed in, that password will be used to attempt to unlock the key. 250 251 @param hostname: the server to connect to 252 @type hostname: str 253 @param port: the server port to connect to 254 @type port: int 255 @param username: the username to authenticate as (defaults to the 256 current local username) 257 @type username: str 258 @param password: a password to use for authentication or for unlocking 259 a private key 260 @type password: str 261 @param pkey: an optional private key to use for authentication 262 @type pkey: L{PKey} 263 @param key_filename: the filename, or list of filenames, of optional 264 private key(s) to try for authentication 265 @type key_filename: str or list(str) 266 @param timeout: an optional timeout (in seconds) for the TCP connect 267 @type timeout: float 268 @param allow_agent: set to False to disable connecting to the SSH agent 269 @type allow_agent: bool 270 @param look_for_keys: set to False to disable searching for discoverable 271 private key files in C{~/.ssh/} 272 @type look_for_keys: bool 273 @param compress: set to True to turn on compression 274 @type compress: bool 275 276 @raise BadHostKeyException: if the server's host key could not be 277 verified 278 @raise AuthenticationException: if authentication failed 279 @raise SSHException: if there was any other error connecting or 280 establishing an SSH session 281 @raise socket.error: if a socket error occurred while connecting 282 """ 283 for (family, socktype, proto, canonname, sockaddr) in socket.getaddrinfo(hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM): 284 if socktype == socket.SOCK_STREAM: 285 af = family 286 addr = sockaddr 287 break 288 else: 289 # some OS like AIX don't indicate SOCK_STREAM support, so just guess. :( 290 af, _, _, _, addr = socket.getaddrinfo(hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM) 291 sock = socket.socket(af, socket.SOCK_STREAM) 292 if timeout is not None: 293 try: 294 sock.settimeout(timeout) 295 except: 296 pass 297 retry_on_signal(lambda: sock.connect(addr)) 298 t = self._transport = Transport(sock) 299 t.use_compression(compress=compress) 300 if self._log_channel is not None: 301 t.set_log_channel(self._log_channel) 302 t.start_client() 303 ResourceManager.register(self, t) 304 305 server_key = t.get_remote_server_key() 306 keytype = server_key.get_name() 307 308 if port == SSH_PORT: 309 server_hostkey_name = hostname 310 else: 311 server_hostkey_name = "[%s]:%d" % (hostname, port) 312 our_server_key = self._system_host_keys.get(server_hostkey_name, {}).get(keytype, None) 313 if our_server_key is None: 314 our_server_key = self._host_keys.get(server_hostkey_name, {}).get(keytype, None) 315 if our_server_key is None: 316 # will raise exception if the key is rejected; let that fall out 317 self._policy.missing_host_key(self, server_hostkey_name, server_key) 318 # if the callback returns, assume the key is ok 319 our_server_key = server_key 320 321 if server_key != our_server_key: 322 raise BadHostKeyException(hostname, server_key, our_server_key) 323 324 if username is None: 325 username = getpass.getuser() 326 327 if key_filename is None: 328 key_filenames = [] 329 elif isinstance(key_filename, (str, unicode)): 330 key_filenames = [ key_filename ] 331 else: 332 key_filenames = key_filename 333 self._auth(username, password, pkey, key_filenames, allow_agent, look_for_keys) 334 335 def close(self): 336 """ 337 Close this SSHClient and its underlying L{Transport}. 338 """ 339 if self._transport is None: 340 return 341 self._transport.close() 342 self._transport = None 343 344 if self._agent != None: 345 self._agent.close() 346 self._agent = None 347 348 def exec_command(self, command, bufsize=-1): 349 """ 350 Execute a command on the SSH server. A new L{Channel} is opened and 351 the requested command is executed. The command's input and output 352 streams are returned as python C{file}-like objects representing 353 stdin, stdout, and stderr. 354 355 @param command: the command to execute 356 @type command: str 357 @param bufsize: interpreted the same way as by the built-in C{file()} function in python 358 @type bufsize: int 359 @return: the stdin, stdout, and stderr of the executing command 360 @rtype: tuple(L{ChannelFile}, L{ChannelFile}, L{ChannelFile}) 361 362 @raise SSHException: if the server fails to execute the command 363 """ 364 chan = self._transport.open_session() 365 chan.exec_command(command) 366 stdin = chan.makefile('wb', bufsize) 367 stdout = chan.makefile('rb', bufsize) 368 stderr = chan.makefile_stderr('rb', bufsize) 369 return stdin, stdout, stderr 370 371 def invoke_shell(self, term='vt100', width=80, height=24): 372 """ 373 Start an interactive shell session on the SSH server. A new L{Channel} 374 is opened and connected to a pseudo-terminal using the requested 375 terminal type and size. 376 377 @param term: the terminal type to emulate (for example, C{"vt100"}) 378 @type term: str 379 @param width: the width (in characters) of the terminal window 380 @type width: int 381 @param height: the height (in characters) of the terminal window 382 @type height: int 383 @return: a new channel connected to the remote shell 384 @rtype: L{Channel} 385 386 @raise SSHException: if the server fails to invoke a shell 387 """ 388 chan = self._transport.open_session() 389 chan.get_pty(term, width, height) 390 chan.invoke_shell() 391 return chan 392 393 def open_sftp(self): 394 """ 395 Open an SFTP session on the SSH server. 396 397 @return: a new SFTP session object 398 @rtype: L{SFTPClient} 399 """ 400 return self._transport.open_sftp_client() 401 402 def get_transport(self): 403 """ 404 Return the underlying L{Transport} object for this SSH connection. 405 This can be used to perform lower-level tasks, like opening specific 406 kinds of channels. 407 408 @return: the Transport for this connection 409 @rtype: L{Transport} 410 """ 411 return self._transport 412 413 def _auth(self, username, password, pkey, key_filenames, allow_agent, look_for_keys): 414 """ 415 Try, in order: 416 417 - The key passed in, if one was passed in. 418 - Any key we can find through an SSH agent (if allowed). 419 - Any "id_rsa" or "id_dsa" key discoverable in ~/.ssh/ (if allowed). 420 - Plain username/password auth, if a password was given. 421 422 (The password might be needed to unlock a private key.) 423 424 The password is required for two-factor authentication. 425 """ 426 saved_exception = None 427 two_factor = False 428 allowed_types = [] 429 430 if pkey is not None: 431 try: 432 self._log(DEBUG, 'Trying SSH key %s' % hexlify(pkey.get_fingerprint())) 433 allowed_types = self._transport.auth_publickey(username, pkey) 434 two_factor = (allowed_types == ['password']) 435 if not two_factor: 436 return 437 except SSHException, e: 438 saved_exception = e 439 440 if not two_factor: 441 for key_filename in key_filenames: 442 for pkey_class in (RSAKey, DSSKey): 443 try: 444 key = pkey_class.from_private_key_file(key_filename, password) 445 self._log(DEBUG, 'Trying key %s from %s' % (hexlify(key.get_fingerprint()), key_filename)) 446 self._transport.auth_publickey(username, key) 447 two_factor = (allowed_types == ['password']) 448 if not two_factor: 449 return 450 break 451 except SSHException, e: 452 saved_exception = e 453 454 if not two_factor and allow_agent: 455 if self._agent == None: 456 self._agent = Agent() 457 458 for key in self._agent.get_keys(): 459 try: 460 self._log(DEBUG, 'Trying SSH agent key %s' % hexlify(key.get_fingerprint())) 461 # for 2-factor auth a successfully auth'd key will result in ['password'] 462 allowed_types = self._transport.auth_publickey(username, key) 463 two_factor = (allowed_types == ['password']) 464 if not two_factor: 465 return 466 break 467 except SSHException, e: 468 saved_exception = e 469 470 if not two_factor: 471 keyfiles = [] 472 rsa_key = os.path.expanduser('~/.ssh/id_rsa') 473 dsa_key = os.path.expanduser('~/.ssh/id_dsa') 474 if os.path.isfile(rsa_key): 475 keyfiles.append((RSAKey, rsa_key)) 476 if os.path.isfile(dsa_key): 477 keyfiles.append((DSSKey, dsa_key)) 478 # look in ~/ssh/ for windows users: 479 rsa_key = os.path.expanduser('~/ssh/id_rsa') 480 dsa_key = os.path.expanduser('~/ssh/id_dsa') 481 if os.path.isfile(rsa_key): 482 keyfiles.append((RSAKey, rsa_key)) 483 if os.path.isfile(dsa_key): 484 keyfiles.append((DSSKey, dsa_key)) 485 486 if not look_for_keys: 487 keyfiles = [] 488 489 for pkey_class, filename in keyfiles: 490 try: 491 key = pkey_class.from_private_key_file(filename, password) 492 self._log(DEBUG, 'Trying discovered key %s in %s' % (hexlify(key.get_fingerprint()), filename)) 493 # for 2-factor auth a successfully auth'd key will result in ['password'] 494 allowed_types = self._transport.auth_publickey(username, key) 495 two_factor = (allowed_types == ['password']) 496 if not two_factor: 497 return 498 break 499 except SSHException, e: 500 saved_exception = e 501 except IOError, e: 502 saved_exception = e 503 504 if password is not None: 505 try: 506 self._transport.auth_password(username, password) 507 return 508 except SSHException, e: 509 saved_exception = e 510 elif two_factor: 511 raise SSHException('Two-factor authentication requires a password') 512 513 # if we got an auth-failed exception earlier, re-raise it 514 if saved_exception is not None: 515 raise saved_exception 516 raise SSHException('No authentication methods available') 517 518 def _log(self, level, msg): 519 self._transport._log(level, msg) 520 521