1# Copyright (C) 2006-2007  Jeff Forcier <jeff@bitprophet.org>
2#
3# This file is part of ssh.
4#
5# 'ssh' is free software; you can redistribute it and/or modify it under the
6# terms of the GNU Lesser General Public License as published by the Free
7# Software Foundation; either version 2.1 of the License, or (at your option)
8# any later version.
9#
# 'ssh' is distributed in the hope that it will be useful, but WITHOUT ANY
11# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR
12# A PARTICULAR PURPOSE.  See the GNU Lesser General Public License for more
13# details.
14#
15# You should have received a copy of the GNU Lesser General Public License
16# along with 'ssh'; if not, write to the Free Software Foundation, Inc.,
17# 51 Franklin Street, Suite 500, Boston, MA  02110-1335  USA.
18
"""
L{SSHClient}: a high-level client for SSH sessions, plus the
L{MissingHostKeyPolicy} classes (L{AutoAddPolicy}, L{RejectPolicy},
L{WarningPolicy}) that decide what happens when a server's host key is not
already known.
"""
22
23from binascii import hexlify
24import getpass
25import os
26import socket
27import warnings
28
29from ssh.agent import Agent
30from ssh.common import *
31from ssh.dsskey import DSSKey
32from ssh.hostkeys import HostKeys
33from ssh.resource import ResourceManager
34from ssh.rsakey import RSAKey
35from ssh.ssh_exception import SSHException, BadHostKeyException
36from ssh.transport import Transport
37from ssh.util import retry_on_signal
38
39
# Default SSH server port.  SSHClient.connect() uses it as the default port and
# to decide whether a host key is looked up as "hostname" or "[hostname]:port".
SSH_PORT = 22
41
class MissingHostKeyPolicy (object):
    """
    Interface for the policy L{SSHClient} applies when a server's hostname is
    found in neither the system host keys nor the application's own keys.
    Ready-made subclasses add the key to the application's L{HostKeys} object
    automatically (L{AutoAddPolicy}), reject it (L{RejectPolicy}), or accept
    it with a warning (L{WarningPolicy}).

    A custom subclass might, for example, ask the user to verify the key.
    """

    def missing_host_key(self, client, hostname, key):
        """
        Called when an L{SSHClient} receives a key for a server that is in
        neither the system nor the local L{HostKeys} object.  Returning
        normally accepts the key; raising an exception rejects it, and that
        exception is passed on to the calling application.

        This base implementation accepts every key.
        """
61
62
class AutoAddPolicy (MissingHostKeyPolicy):
    """
    Policy that trusts an unknown server on first contact: the hostname and
    its key are added to the client's local L{HostKeys} object, which is then
    written back to disk when the client knows which file it was loaded from.
    This is used by L{SSHClient}.
    """

    def missing_host_key(self, client, hostname, key):
        key_type = key.get_name()
        client._host_keys.add(hostname, key_type, key)
        # only persist when load_host_keys() recorded a target file
        target_file = client._host_keys_filename
        if target_file is not None:
            client.save_host_keys(target_file)
        fingerprint = hexlify(key.get_fingerprint())
        client._log(DEBUG, 'Adding %s host key for %s: %s' % (key_type, hostname, fingerprint))
75
76
class RejectPolicy (MissingHostKeyPolicy):
    """
    Policy that refuses any unknown hostname & key: the rejection is logged
    at DEBUG level and an L{SSHException} is raised, which aborts the
    connection attempt.  This is used by L{SSHClient}.
    """

    def missing_host_key(self, client, hostname, key):
        key_type = key.get_name()
        fingerprint = hexlify(key.get_fingerprint())
        client._log(DEBUG, 'Rejecting %s host key for %s: %s' % (key_type, hostname, fingerprint))
        raise SSHException('Server %r not found in known_hosts' % hostname)
87
88
class WarningPolicy (MissingHostKeyPolicy):
    """
    Policy that accepts an unknown host key but emits a python-style warning
    (through the C{warnings} module) describing it.  The key is not added to
    any L{HostKeys} object.  This is used by L{SSHClient}.
    """

    def missing_host_key(self, client, hostname, key):
        key_type = key.get_name()
        fingerprint = hexlify(key.get_fingerprint())
        message = 'Unknown %s host key for %s: %s' % (key_type, hostname, fingerprint)
        warnings.warn(message)
97
98
99class SSHClient (object):
100    """
101    A high-level representation of a session with an SSH server.  This class
102    wraps L{Transport}, L{Channel}, and L{SFTPClient} to take care of most
103    aspects of authenticating and opening channels.  A typical use case is::
104
105        client = SSHClient()
106        client.load_system_host_keys()
107        client.connect('ssh.example.com')
108        stdin, stdout, stderr = client.exec_command('ls -l')
109
110    You may pass in explicit overrides for authentication and server host key
111    checking.  The default mechanism is to try to use local key files or an
112    SSH agent (if one is running).
113
114    @since: 1.6
115    """
116
117    def __init__(self):
118        """
119        Create a new SSHClient.
120        """
121        self._system_host_keys = HostKeys()
122        self._host_keys = HostKeys()
123        self._host_keys_filename = None
124        self._log_channel = None
125        self._policy = RejectPolicy()
126        self._transport = None
127        self._agent = None
128
129    def load_system_host_keys(self, filename=None):
130        """
131        Load host keys from a system (read-only) file.  Host keys read with
132        this method will not be saved back by L{save_host_keys}.
133
134        This method can be called multiple times.  Each new set of host keys
135        will be merged with the existing set (new replacing old if there are
136        conflicts).
137
138        If C{filename} is left as C{None}, an attempt will be made to read
139        keys from the user's local "known hosts" file, as used by OpenSSH,
140        and no exception will be raised if the file can't be read.  This is
141        probably only useful on posix.
142
143        @param filename: the filename to read, or C{None}
144        @type filename: str
145
146        @raise IOError: if a filename was provided and the file could not be
147            read
148        """
149        if filename is None:
150            # try the user's .ssh key file, and mask exceptions
151            filename = os.path.expanduser('~/.ssh/known_hosts')
152            try:
153                self._system_host_keys.load(filename)
154            except IOError:
155                pass
156            return
157        self._system_host_keys.load(filename)
158
159    def load_host_keys(self, filename):
160        """
161        Load host keys from a local host-key file.  Host keys read with this
162        method will be checked I{after} keys loaded via L{load_system_host_keys},
163        but will be saved back by L{save_host_keys} (so they can be modified).
164        The missing host key policy L{AutoAddPolicy} adds keys to this set and
165        saves them, when connecting to a previously-unknown server.
166
167        This method can be called multiple times.  Each new set of host keys
168        will be merged with the existing set (new replacing old if there are
169        conflicts).  When automatically saving, the last hostname is used.
170
171        @param filename: the filename to read
172        @type filename: str
173
174        @raise IOError: if the filename could not be read
175        """
176        self._host_keys_filename = filename
177        self._host_keys.load(filename)
178
179    def save_host_keys(self, filename):
180        """
181        Save the host keys back to a file.  Only the host keys loaded with
182        L{load_host_keys} (plus any added directly) will be saved -- not any
183        host keys loaded with L{load_system_host_keys}.
184
185        @param filename: the filename to save to
186        @type filename: str
187
188        @raise IOError: if the file could not be written
189        """
190        f = open(filename, 'w')
191        f.write('# SSH host keys collected by ssh\n')
192        for hostname, keys in self._host_keys.iteritems():
193            for keytype, key in keys.iteritems():
194                f.write('%s %s %s\n' % (hostname, keytype, key.get_base64()))
195        f.close()
196
197    def get_host_keys(self):
198        """
199        Get the local L{HostKeys} object.  This can be used to examine the
200        local host keys or change them.
201
202        @return: the local host keys
203        @rtype: L{HostKeys}
204        """
205        return self._host_keys
206
207    def set_log_channel(self, name):
208        """
209        Set the channel for logging.  The default is C{"ssh.transport"}
210        but it can be set to anything you want.
211
212        @param name: new channel name for logging
213        @type name: str
214        """
215        self._log_channel = name
216
217    def set_missing_host_key_policy(self, policy):
218        """
219        Set the policy to use when connecting to a server that doesn't have a
220        host key in either the system or local L{HostKeys} objects.  The
221        default policy is to reject all unknown servers (using L{RejectPolicy}).
222        You may substitute L{AutoAddPolicy} or write your own policy class.
223
224        @param policy: the policy to use when receiving a host key from a
225            previously-unknown server
226        @type policy: L{MissingHostKeyPolicy}
227        """
228        self._policy = policy
229
230    def connect(self, hostname, port=SSH_PORT, username=None, password=None, pkey=None,
231                key_filename=None, timeout=None, allow_agent=True, look_for_keys=True,
232                compress=False):
233        """
234        Connect to an SSH server and authenticate to it.  The server's host key
235        is checked against the system host keys (see L{load_system_host_keys})
236        and any local host keys (L{load_host_keys}).  If the server's hostname
237        is not found in either set of host keys, the missing host key policy
238        is used (see L{set_missing_host_key_policy}).  The default policy is
239        to reject the key and raise an L{SSHException}.
240
241        Authentication is attempted in the following order of priority:
242
243            - The C{pkey} or C{key_filename} passed in (if any)
244            - Any key we can find through an SSH agent
245            - Any "id_rsa" or "id_dsa" key discoverable in C{~/.ssh/}
246            - Plain username/password auth, if a password was given
247
248        If a private key requires a password to unlock it, and a password is
249        passed in, that password will be used to attempt to unlock the key.
250
251        @param hostname: the server to connect to
252        @type hostname: str
253        @param port: the server port to connect to
254        @type port: int
255        @param username: the username to authenticate as (defaults to the
256            current local username)
257        @type username: str
258        @param password: a password to use for authentication or for unlocking
259            a private key
260        @type password: str
261        @param pkey: an optional private key to use for authentication
262        @type pkey: L{PKey}
263        @param key_filename: the filename, or list of filenames, of optional
264            private key(s) to try for authentication
265        @type key_filename: str or list(str)
266        @param timeout: an optional timeout (in seconds) for the TCP connect
267        @type timeout: float
268        @param allow_agent: set to False to disable connecting to the SSH agent
269        @type allow_agent: bool
270        @param look_for_keys: set to False to disable searching for discoverable
271            private key files in C{~/.ssh/}
272        @type look_for_keys: bool
273        @param compress: set to True to turn on compression
274        @type compress: bool
275
276        @raise BadHostKeyException: if the server's host key could not be
277            verified
278        @raise AuthenticationException: if authentication failed
279        @raise SSHException: if there was any other error connecting or
280            establishing an SSH session
281        @raise socket.error: if a socket error occurred while connecting
282        """
283        for (family, socktype, proto, canonname, sockaddr) in socket.getaddrinfo(hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
284            if socktype == socket.SOCK_STREAM:
285                af = family
286                addr = sockaddr
287                break
288        else:
289            # some OS like AIX don't indicate SOCK_STREAM support, so just guess. :(
290            af, _, _, _, addr = socket.getaddrinfo(hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM)
291        sock = socket.socket(af, socket.SOCK_STREAM)
292        if timeout is not None:
293            try:
294                sock.settimeout(timeout)
295            except:
296                pass
297        retry_on_signal(lambda: sock.connect(addr))
298        t = self._transport = Transport(sock)
299        t.use_compression(compress=compress)
300        if self._log_channel is not None:
301            t.set_log_channel(self._log_channel)
302        t.start_client()
303        ResourceManager.register(self, t)
304
305        server_key = t.get_remote_server_key()
306        keytype = server_key.get_name()
307
308        if port == SSH_PORT:
309            server_hostkey_name = hostname
310        else:
311            server_hostkey_name = "[%s]:%d" % (hostname, port)
312        our_server_key = self._system_host_keys.get(server_hostkey_name, {}).get(keytype, None)
313        if our_server_key is None:
314            our_server_key = self._host_keys.get(server_hostkey_name, {}).get(keytype, None)
315        if our_server_key is None:
316            # will raise exception if the key is rejected; let that fall out
317            self._policy.missing_host_key(self, server_hostkey_name, server_key)
318            # if the callback returns, assume the key is ok
319            our_server_key = server_key
320
321        if server_key != our_server_key:
322            raise BadHostKeyException(hostname, server_key, our_server_key)
323
324        if username is None:
325            username = getpass.getuser()
326
327        if key_filename is None:
328            key_filenames = []
329        elif isinstance(key_filename, (str, unicode)):
330            key_filenames = [ key_filename ]
331        else:
332            key_filenames = key_filename
333        self._auth(username, password, pkey, key_filenames, allow_agent, look_for_keys)
334
335    def close(self):
336        """
337        Close this SSHClient and its underlying L{Transport}.
338        """
339        if self._transport is None:
340            return
341        self._transport.close()
342        self._transport = None
343
344        if self._agent != None:
345            self._agent.close()
346            self._agent = None
347
348    def exec_command(self, command, bufsize=-1):
349        """
350        Execute a command on the SSH server.  A new L{Channel} is opened and
351        the requested command is executed.  The command's input and output
352        streams are returned as python C{file}-like objects representing
353        stdin, stdout, and stderr.
354
355        @param command: the command to execute
356        @type command: str
357        @param bufsize: interpreted the same way as by the built-in C{file()} function in python
358        @type bufsize: int
359        @return: the stdin, stdout, and stderr of the executing command
360        @rtype: tuple(L{ChannelFile}, L{ChannelFile}, L{ChannelFile})
361
362        @raise SSHException: if the server fails to execute the command
363        """
364        chan = self._transport.open_session()
365        chan.exec_command(command)
366        stdin = chan.makefile('wb', bufsize)
367        stdout = chan.makefile('rb', bufsize)
368        stderr = chan.makefile_stderr('rb', bufsize)
369        return stdin, stdout, stderr
370
371    def invoke_shell(self, term='vt100', width=80, height=24):
372        """
373        Start an interactive shell session on the SSH server.  A new L{Channel}
374        is opened and connected to a pseudo-terminal using the requested
375        terminal type and size.
376
377        @param term: the terminal type to emulate (for example, C{"vt100"})
378        @type term: str
379        @param width: the width (in characters) of the terminal window
380        @type width: int
381        @param height: the height (in characters) of the terminal window
382        @type height: int
383        @return: a new channel connected to the remote shell
384        @rtype: L{Channel}
385
386        @raise SSHException: if the server fails to invoke a shell
387        """
388        chan = self._transport.open_session()
389        chan.get_pty(term, width, height)
390        chan.invoke_shell()
391        return chan
392
393    def open_sftp(self):
394        """
395        Open an SFTP session on the SSH server.
396
397        @return: a new SFTP session object
398        @rtype: L{SFTPClient}
399        """
400        return self._transport.open_sftp_client()
401
402    def get_transport(self):
403        """
404        Return the underlying L{Transport} object for this SSH connection.
405        This can be used to perform lower-level tasks, like opening specific
406        kinds of channels.
407
408        @return: the Transport for this connection
409        @rtype: L{Transport}
410        """
411        return self._transport
412
413    def _auth(self, username, password, pkey, key_filenames, allow_agent, look_for_keys):
414        """
415        Try, in order:
416
417            - The key passed in, if one was passed in.
418            - Any key we can find through an SSH agent (if allowed).
419            - Any "id_rsa" or "id_dsa" key discoverable in ~/.ssh/ (if allowed).
420            - Plain username/password auth, if a password was given.
421
422        (The password might be needed to unlock a private key.)
423
424        The password is required for two-factor authentication.
425        """
426        saved_exception = None
427        two_factor = False
428        allowed_types = []
429
430        if pkey is not None:
431            try:
432                self._log(DEBUG, 'Trying SSH key %s' % hexlify(pkey.get_fingerprint()))
433                allowed_types = self._transport.auth_publickey(username, pkey)
434                two_factor = (allowed_types == ['password'])
435                if not two_factor:
436                    return
437            except SSHException, e:
438                saved_exception = e
439
440        if not two_factor:
441            for key_filename in key_filenames:
442                for pkey_class in (RSAKey, DSSKey):
443                    try:
444                        key = pkey_class.from_private_key_file(key_filename, password)
445                        self._log(DEBUG, 'Trying key %s from %s' % (hexlify(key.get_fingerprint()), key_filename))
446                        self._transport.auth_publickey(username, key)
447                        two_factor = (allowed_types == ['password'])
448                        if not two_factor:
449                            return
450                        break
451                    except SSHException, e:
452                        saved_exception = e
453
454        if not two_factor and allow_agent:
455            if self._agent == None:
456                self._agent = Agent()
457
458            for key in self._agent.get_keys():
459                try:
460                    self._log(DEBUG, 'Trying SSH agent key %s' % hexlify(key.get_fingerprint()))
461                    # for 2-factor auth a successfully auth'd key will result in ['password']
462                    allowed_types = self._transport.auth_publickey(username, key)
463                    two_factor = (allowed_types == ['password'])
464                    if not two_factor:
465                        return
466                    break
467                except SSHException, e:
468                    saved_exception = e
469
470        if not two_factor:
471            keyfiles = []
472            rsa_key = os.path.expanduser('~/.ssh/id_rsa')
473            dsa_key = os.path.expanduser('~/.ssh/id_dsa')
474            if os.path.isfile(rsa_key):
475                keyfiles.append((RSAKey, rsa_key))
476            if os.path.isfile(dsa_key):
477                keyfiles.append((DSSKey, dsa_key))
478            # look in ~/ssh/ for windows users:
479            rsa_key = os.path.expanduser('~/ssh/id_rsa')
480            dsa_key = os.path.expanduser('~/ssh/id_dsa')
481            if os.path.isfile(rsa_key):
482                keyfiles.append((RSAKey, rsa_key))
483            if os.path.isfile(dsa_key):
484                keyfiles.append((DSSKey, dsa_key))
485
486            if not look_for_keys:
487                keyfiles = []
488
489            for pkey_class, filename in keyfiles:
490                try:
491                    key = pkey_class.from_private_key_file(filename, password)
492                    self._log(DEBUG, 'Trying discovered key %s in %s' % (hexlify(key.get_fingerprint()), filename))
493                    # for 2-factor auth a successfully auth'd key will result in ['password']
494                    allowed_types = self._transport.auth_publickey(username, key)
495                    two_factor = (allowed_types == ['password'])
496                    if not two_factor:
497                        return
498                    break
499                except SSHException, e:
500                    saved_exception = e
501                except IOError, e:
502                    saved_exception = e
503
504        if password is not None:
505            try:
506                self._transport.auth_password(username, password)
507                return
508            except SSHException, e:
509                saved_exception = e
510        elif two_factor:
511            raise SSHException('Two-factor authentication requires a password')
512
513        # if we got an auth-failed exception earlier, re-raise it
514        if saved_exception is not None:
515            raise saved_exception
516        raise SSHException('No authentication methods available')
517
518    def _log(self, level, msg):
519        self._transport._log(level, msg)
520
521