1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3"""
4*sshtunnel* - Initiate SSH tunnels via a remote gateway.
5
6``sshtunnel`` works by opening a port forwarding SSH connection in the
7background, using threads.
8
The connection(s) are closed when explicitly calling the
:meth:`SSHTunnelForwarder.stop` method or when using it as a context manager.
11
12"""
13
14import os
15import sys
16import socket
17import getpass
18import logging
19import argparse
20import warnings
21import threading
22from select import select
23from binascii import hexlify
24
25import paramiko
26
27if sys.version_info[0] < 3:  # pragma: no cover
28    import Queue as queue
29    import SocketServer as socketserver
30    string_types = basestring,  # noqa
31    input_ = raw_input  # noqa
32else:
33    import queue
34    import socketserver
35    string_types = str
36    input_ = input
37
38
__version__ = '0.1.5'
__author__ = 'pahaz'


#: default level if no logger passed (ERROR)
DEFAULT_LOGLEVEL = logging.ERROR
#: Timeout (seconds) for tunnel connection
TUNNEL_TIMEOUT = 1.0
#: Use daemon threads in connections
_DAEMON = False
#: Custom log level (numerically below DEBUG) used for per-connection tracing
TRACE_LEVEL = 1
logging.addLevelName(TRACE_LEVEL, 'TRACE')
#: Next id handed out by get_connection_id(); only touched while holding _LOCK
_CONNECTION_COUNTER = 1
_LOCK = threading.Lock()
#: Timeout (seconds) for the connection to the SSH gateway, ``None`` to disable
SSH_TIMEOUT = None
#: Deprecated keyword argument -> keyword argument that supersedes it
DEPRECATIONS = dict(
    ssh_address='ssh_address_or_host',
    ssh_host='ssh_address_or_host',
    ssh_private_key='ssh_pkey',
    raise_exception_if_any_forwarder_have_a_problem='mute_exceptions',
)

if os.name != 'posix':
    DEFAULT_SSH_DIRECTORY = '~/ssh'
    # No UNIX domain socket support here; keep the name defined anyway
    # (local UNIX socket binds are rejected by check_address on this platform)
    UnixStreamServer = socketserver.TCPServer
else:
    DEFAULT_SSH_DIRECTORY = '~/.ssh'
    UnixStreamServer = socketserver.UnixStreamServer

#: Path of optional ssh configuration file
SSH_CONFIG_FILE = os.path.join(DEFAULT_SSH_DIRECTORY, 'config')
69
70########################
71#                      #
72#       Utils          #
73#                      #
74########################
75
76
def check_host(host):
    """
    Check that the host part of an address is a string

    Arguments:
        host (str):
            IP address or hostname

    Raises:
        AssertionError:
            raised when ``host`` is not a string

    Example:
        >>> check_host('127.0.0.1')
    """
    # Raise explicitly instead of using ``assert`` so the validation is not
    # stripped when Python runs with optimizations enabled (``python -O``)
    if not isinstance(host, string_types):
        raise AssertionError('IP is not a string ({0})'.format(
            type(host).__name__
        ))
81
82
def check_port(port):
    """
    Check that ``port`` is a valid TCP port number

    Arguments:
        port (int):
            port number in the range ``0..65535`` (``0`` is accepted)

    Raises:
        AssertionError:
            raised when ``port`` is not an integer or is out of range

    Example:
        >>> check_port(22)
    """
    # Explicit raises (not ``assert``) so the checks survive ``python -O``
    if not isinstance(port, int):
        raise AssertionError('PORT is not a number')
    if port < 0:
        raise AssertionError('PORT < 0 ({0})'.format(port))
    # TCP port numbers are 16-bit; larger values can never be bound/connected
    if port > 65535:
        raise AssertionError('PORT > 65535 ({0})'.format(port))
86
87
def check_address(address):
    """
    Check if the format of the address is correct

    Arguments:
        address (tuple):
            (``str``, ``int``) representing an IP address and port,
            respectively

            .. note::
                alternatively a local ``address`` can be a ``str`` when working
                with UNIX domain sockets, if supported by the platform
    Raises:
        ValueError:
            raised when address has an incorrect format: not a tuple or
            string, a tuple with fewer than two elements, a UNIX domain
            socket on a non-POSIX platform, or a socket path that neither
            exists nor lives in a writable directory
        AssertionError:
            raised (by :func:`check_host` / :func:`check_port`) when the
            host or port inside a tuple is invalid

    Example:
        >>> check_address(('127.0.0.1', 22))
    """
    if isinstance(address, tuple):
        # Without this, indexing below fails with an unhelpful IndexError
        if len(address) < 2:
            raise ValueError('ADDRESS tuple must contain a host and a port '
                             '({0})'.format(address))
        check_host(address[0])
        check_port(address[1])
    elif isinstance(address, string_types):
        if os.name != 'posix':
            raise ValueError('Platform does not support UNIX domain sockets')
        # A bare file name has an empty dirname and refers to the current
        # directory; os.access('') would always report it as not writable
        socket_dir = os.path.dirname(address) or os.curdir
        if not (os.path.exists(address) or os.access(socket_dir, os.W_OK)):
            raise ValueError('ADDRESS not a valid UNIX domain socket ({0})'
                             .format(address))
    else:
        raise ValueError('ADDRESS is not a tuple, string, or character buffer '
                         '({0})'.format(type(address).__name__))
120
121
def check_addresses(address_list, is_remote=False):
    """
    Check if the format of the addresses is correct

    Arguments:
        address_list (list[tuple]):
            Sequence of (``str``, ``int``) pairs, each representing an IP
            address and port respectively

            .. note::
                when supported by the platform, one or more of the elements in
                the list can be of type ``str``, representing a valid UNIX
                domain socket

        is_remote (boolean):
            Whether or not the addresses are remote ones; UNIX domain
            sockets are not allowed in that case
    Raises:
        AssertionError:
            raised when ``address_list`` contains an element that is neither
            a tuple nor a string, or a UNIX domain socket while
            ``is_remote`` is set
        ValueError:
            raised when any address in the list has an incorrect format

    Example:

        >>> check_addresses([('127.0.0.1', 22), ('127.0.0.1', 2222)])
    """
    # Materialize once: a one-shot iterable (e.g. a generator) would
    # otherwise be consumed by the type checks before being validated
    addresses = list(address_list)

    # Explicit raises (rather than ``assert``) survive ``python -O``
    if not all(isinstance(x, (tuple, string_types)) for x in addresses):
        raise AssertionError('Addresses must be tuples or strings')
    if is_remote and any(isinstance(x, string_types) for x in addresses):
        raise AssertionError('UNIX domain sockets not allowed for remote '
                             'addresses')

    for address in addresses:
        check_address(address)
155
156
def create_logger(logger=None,
                  loglevel=None,
                  capture_warnings=True,
                  add_paramiko_handler=True):
    """
    Attach or create a new logger and add a console handler if not present

    Arguments:

        logger (Optional[logging.Logger]):
            :class:`logging.Logger` instance; a new one is created if this
            argument is empty

        loglevel (Optional[str or int]):
            :class:`logging.Logger`'s level, either as a string (i.e.
            ``ERROR``) or in numeric format (10 == ``DEBUG``)

            .. note:: a value of 1 == ``TRACE`` enables Tracing mode

        capture_warnings (boolean):
            Enable/disable capturing the events logged by the warnings module
            into ``logger``'s handlers

            Default: True

            .. note:: ignored in python 2.6

        add_paramiko_handler (boolean):
            Whether or not add a console handler for ``paramiko.transport``'s
            logger if no handler present

            Default: True
    Return:
        :class:`logging.Logger`

    .. note::
        Calling this function repeatedly for the same logger does not attach
        the same handler to the ``py.warnings`` logger more than once
    """
    logger = logger or logging.getLogger(
        '{0}.SSHTunnelForwarder'.format(__name__)
    )
    if not any(isinstance(x, logging.Handler) for x in logger.handlers):
        logger.setLevel(loglevel or DEFAULT_LOGLEVEL)
        console_handler = logging.StreamHandler()
        _add_handler(logger,
                     handler=console_handler,
                     loglevel=loglevel or DEFAULT_LOGLEVEL)
    if loglevel:  # override if loglevel was set
        logger.setLevel(loglevel)
        for handler in logger.handlers:
            handler.setLevel(loglevel)

    if add_paramiko_handler:
        _check_paramiko_handlers(logger=logger)

    if capture_warnings and sys.version_info >= (2, 7):
        logging.captureWarnings(True)
        pywarnings = logging.getLogger('py.warnings')
        # This function runs more than once for the same logger (e.g. each
        # forward server calls it); blindly extending the handler list would
        # make every captured warning be emitted once per call
        for handler in logger.handlers:
            if handler not in pywarnings.handlers:
                pywarnings.addHandler(handler)
    return logger
214
215
def _add_handler(logger, handler=None, loglevel=None):
    """
    Add a handler to an existing logging.Logger object

    Arguments:
        logger (logging.Logger):
            logger the handler is attached to
        handler (Optional[logging.Handler]):
            handler to configure and attach; a new
            :class:`logging.StreamHandler` is created when omitted
        loglevel (Optional[str or int]):
            level set on the handler, :const:`DEFAULT_LOGLEVEL` when empty.
            At ``DEBUG`` level or below a more verbose format (thread name,
            line number and module) is used
    """
    if handler is None:
        # ``handler`` is optional in the signature; default to the console
        handler = logging.StreamHandler()
    handler.setLevel(loglevel or DEFAULT_LOGLEVEL)
    if handler.level <= logging.DEBUG:
        fmt = ('%(asctime)s| %(levelname)-4.3s|%(threadName)10.9s/'
               '%(lineno)04d@%(module)-10.9s| %(message)s')
    else:
        fmt = '%(asctime)s| %(levelname)-8s| %(message)s'
    handler.setFormatter(logging.Formatter(fmt))
    logger.addHandler(handler)
230
231
def _check_paramiko_handlers(logger=None):
    """
    Add a console handler for paramiko.transport's logger if not present

    Arguments:
        logger (Optional[logging.Logger]):
            when given, its handlers are attached to the
            ``paramiko.transport`` logger as well; otherwise a dedicated
            console handler with a ``PARAMIKO:`` prefixed format is created
    """
    paramiko_logger = logging.getLogger('paramiko.transport')
    if paramiko_logger.handlers:
        return
    if logger:
        # Share the handler objects, not the list: assigning
        # ``logger.handlers`` directly would make both loggers use one list,
        # so adding a handler to either logger would silently add it to both
        for handler in logger.handlers:
            paramiko_logger.addHandler(handler)
    else:
        console_handler = logging.StreamHandler()
        console_handler.setFormatter(
            logging.Formatter('%(asctime)s | %(levelname)-8s| PARAMIKO: '
                              '%(lineno)03d@%(module)-10s| %(message)s')
        )
        paramiko_logger.addHandler(console_handler)
247
248
def address_to_str(address):
    """
    Render an address for log and error messages

    Tuples are shown as ``host:port`` (only their first two elements are
    used); anything else, such as a UNIX domain socket path, is passed
    through :func:`str`.
    """
    if not isinstance(address, tuple):
        return str(address)
    return '{0[0]}:{0[1]}'.format(address)
253
254
def get_connection_id():
    """
    Return the next connection number

    The module-level counter is read and incremented while ``_LOCK`` is
    held, so concurrent callers never obtain the same number.
    """
    global _CONNECTION_COUNTER
    _LOCK.acquire()
    try:
        connection_id = _CONNECTION_COUNTER
        _CONNECTION_COUNTER = connection_id + 1
    finally:
        _LOCK.release()
    return connection_id
261
262
def _remove_none_values(dictionary):
    """
    Remove dictionary keys whose value is None

    The dictionary is modified in place.

    Return:
        list: one ``None`` per removed key (the values that were removed)
    """
    # Collect the keys first: a dict must not change size while iterated.
    # A plain loop replaces ``map()`` called purely for its side effects.
    none_keys = [key for key, value in dictionary.items() if value is None]
    for key in none_keys:
        del dictionary[key]
    return [None] * len(none_keys)
267
268########################
269#                      #
270#       Errors         #
271#                      #
272########################
273
274
class BaseSSHTunnelForwarderError(Exception):
    """ Exception raised by :class:`SSHTunnelForwarder` errors

    The message is taken from the ``value`` keyword argument or, failing
    that, from the first positional argument (empty string if neither is
    given).
    """

    def __init__(self, *args, **kwargs):
        self.value = kwargs.pop('value', args[0] if args else '')

    def __str__(self):
        # str() must return a string even if a non-string (an int, another
        # exception, ...) was passed as the message
        return str(self.value)
283
284
class HandlerSSHTunnelForwarderError(BaseSSHTunnelForwarderError):
    """
    Exception raised on the tunnel forwarder (request handler) side, e.g.
    when the SSH server rejects opening a forwarding channel
    """
288
289
290########################
291#                      #
292#       Handlers       #
293#                      #
294########################
295
296
class _ForwardHandler(socketserver.BaseRequestHandler):
    """ Base handler for tunnel connections

    Per-tunnel subclasses fill in the class attributes below; each accepted
    local connection is relayed through a ``direct-tcpip`` channel opened on
    ``ssh_transport`` towards ``remote_address``.
    """
    remote_address = None  # tunnel destination requested from the SSH server
    ssh_transport = None  # transport used to open the forwarding channels
    logger = None
    info = None  # per-connection label used in log messages

    def _redirect(self, chan):
        """
        Relay data between the local socket (``self.request``) and ``chan``

        Stops when the local client closes its side, when the channel is
        signalled readable but has no buffered data, or when the channel
        stops being active.
        """
        while chan.active:
            # 5 s timeout: wake up periodically to re-check ``chan.active``
            rqst, _, _ = select([self.request, chan], [], [], 5)
            if self.request in rqst:
                data = self.request.recv(1024)
                if not data:
                    break
                # hexlify() and message formatting are only worth paying for
                # when TRACE output is actually enabled
                if self.logger.isEnabledFor(TRACE_LEVEL):
                    self.logger.log(TRACE_LEVEL,
                                    '>>> OUT {0} send to {1}: {2} >>>'.format(
                                        self.info,
                                        self.remote_address,
                                        hexlify(data)
                                    ))
                chan.sendall(data)
            if chan in rqst:
                # readable but nothing buffered: treated as end of the stream
                if not chan.recv_ready():
                    break
                data = chan.recv(1024)
                if self.logger.isEnabledFor(TRACE_LEVEL):
                    self.logger.log(
                        TRACE_LEVEL,
                        '<<< IN {0} recv: {1} <<<'.format(self.info,
                                                          hexlify(data))
                    )
                self.request.sendall(data)

    def handle(self):
        """
        Open a ``direct-tcpip`` channel to :attr:`remote_address` and relay
        the accepted connection through it

        Raises:
            HandlerSSHTunnelForwarderError:
                raised when the channel could not be opened (rejected by the
                SSH server)
        """
        uid = get_connection_id()
        self.info = '#{0} <-- {1}'.format(uid, self.client_address or
                                          self.server.local_address)
        src_address = self.request.getpeername()
        if not isinstance(src_address, tuple):
            # e.g. a UNIX domain socket peer: use a placeholder (host, port)
            # as the channel's source address
            src_address = ('dummy', 12345)
        try:
            chan = self.ssh_transport.open_channel(
                kind='direct-tcpip',
                dest_addr=self.remote_address,
                src_addr=src_address,
                timeout=TUNNEL_TIMEOUT
            )
        except paramiko.SSHException:
            chan = None
        if chan is None:
            msg = '{0} to {1} was rejected by the SSH server'.format(
                self.info,
                self.remote_address
            )
            self.logger.log(TRACE_LEVEL, msg)
            raise HandlerSSHTunnelForwarderError(msg)

        self.logger.log(TRACE_LEVEL, '{0} connected'.format(self.info))
        try:
            self._redirect(chan)
        except socket.error:
            # Sometimes a RST is sent and a socket error is raised, treat this
            # exception. It was seen that a 3way FIN is processed later on, so
            # no need to make an ordered close of the connection here or raise
            # the exception beyond this point...
            self.logger.log(TRACE_LEVEL, '{0} sending RST'.format(self.info))
        except Exception as e:
            self.logger.log(TRACE_LEVEL,
                            '{0} error: {1}'.format(self.info, repr(e)))
        finally:
            chan.close()
            self.request.close()
            self.logger.log(TRACE_LEVEL,
                            '{0} connection closed.'.format(self.info))
369
370
class _ForwardServer(socketserver.TCPServer):  # Not Threading
    """
    Non-threading version of the forward server

    Failed connections are logged and reported by putting ``False`` on
    :attr:`tunnel_ok`.
    """
    allow_reuse_address = True  # faster rebinding

    def __init__(self, *args, **kwargs):
        self.logger = create_logger(kwargs.pop('logger', None))
        self.tunnel_ok = queue.Queue()
        socketserver.TCPServer.__init__(self, *args, **kwargs)

    def handle_error(self, request, client_address):
        """
        Log a failed tunnel connection and report it through ``tunnel_ok``

        Invoked by :mod:`socketserver` when the request handler raised.
        """
        exc = sys.exc_info()[1]
        try:
            local_side = request.getsockname()
        except socket.error:
            # The request socket may already be closed; fall back to the
            # listening address so ``tunnel_ok`` is still signalled below
            local_side = self.server_address
        self.logger.error('Could not establish connection from {0} to remote '
                          'side of the tunnel'.format(local_side))
        self.logger.debug('Tunnel connection error: {0!r}'.format(exc))
        self.tunnel_ok.put(False)

    @property
    def local_address(self):
        return self.server_address

    @property
    def local_host(self):
        return self.server_address[0]

    @property
    def local_port(self):
        return self.server_address[1]

    @property
    def remote_address(self):
        return self.RequestHandlerClass.remote_address

    @property
    def remote_host(self):
        return self.RequestHandlerClass.remote_address[0]

    @property
    def remote_port(self):
        return self.RequestHandlerClass.remote_address[1]
411
412
class _ThreadingForwardServer(socketserver.ThreadingMixIn, _ForwardServer):
    """
    Forward server that serves each connection in its own thread, so a
    single tunnel can carry concurrent connections
    """
    # ThreadingMixIn option: when True the per-connection threads are
    # daemonic and do not keep the interpreter alive at exit
    daemon_threads = _DAEMON
419
420
class _UnixStreamForwardServer(UnixStreamServer):
    """
    Serve over UNIX domain sockets (does not work on Windows)

    Failed connections are logged and reported by putting ``False`` on
    :attr:`tunnel_ok`, like the TCP forward server does.
    """

    def __init__(self, *args, **kwargs):
        self.logger = create_logger(kwargs.pop('logger', None))
        self.tunnel_ok = queue.Queue()
        UnixStreamServer.__init__(self, *args, **kwargs)

    def handle_error(self, request, client_address):
        """
        Log a failed tunnel connection and report it through ``tunnel_ok``

        Without this override :mod:`socketserver` only prints a traceback to
        stderr and ``tunnel_ok`` never learns about the failure.
        """
        exc = sys.exc_info()[1]
        try:
            local_side = request.getsockname()
        except socket.error:
            # The request socket may already be closed; fall back to the
            # listening address so ``tunnel_ok`` is still signalled below
            local_side = self.server_address
        self.logger.error('Could not establish connection from {0} to remote '
                          'side of the tunnel'.format(local_side))
        self.logger.debug('Tunnel connection error: {0!r}'.format(exc))
        self.tunnel_ok.put(False)

    @property
    def local_address(self):
        return self.server_address

    @property
    def local_host(self):
        return None

    @property
    def local_port(self):
        return None

    @property
    def remote_address(self):
        return self.RequestHandlerClass.remote_address

    @property
    def remote_host(self):
        return self.RequestHandlerClass.remote_address[0]

    @property
    def remote_port(self):
        return self.RequestHandlerClass.remote_address[1]
454
455
class _ThreadingUnixStreamForwardServer(socketserver.ThreadingMixIn,
                                        _UnixStreamForwardServer):
    """
    UNIX domain socket forward server that handles each connection in a
    separate thread, allowing concurrent connections to the same tunnel
    """
    # ThreadingMixIn option: True makes the per-connection threads daemonic,
    # so they do not block interpreter exit
    daemon_threads = _DAEMON
463
464
465class SSHTunnelForwarder(object):
466    """
467    **SSH tunnel class**
468
469        - Initialize a SSH tunnel to a remote host according to the input
470          arguments
471
472        - Optionally:
473            + Read an SSH configuration file (typically ``~/.ssh/config``)
474            + Load keys from a running SSH agent (i.e. Pageant, GNOME Keyring)
475
476    Raises:
477
478        :class:`.BaseSSHTunnelForwarderError`:
479            raised by SSHTunnelForwarder class methods
480
481        :class:`.HandlerSSHTunnelForwarderError`:
482            raised by tunnel forwarder threads
483
484            .. note::
485                    Attributes ``mute_exceptions`` and
486                    ``raise_exception_if_any_forwarder_have_a_problem``
487                    (deprecated) may be used to silence most exceptions raised
488                    from this class
489
490    Keyword Arguments:
491
492        ssh_address_or_host (tuple or str):
493            IP or hostname of ``REMOTE GATEWAY``. It may be a two-element
494            tuple (``str``, ``int``) representing IP and port respectively,
495            or a ``str`` representing the IP address only
496
497            .. versionadded:: 0.0.4
498
499        ssh_config_file (str):
500            SSH configuration file that will be read. If explicitly set to
501            ``None``, parsing of this configuration is omitted
502
503            Default: :const:`SSH_CONFIG_FILE`
504
505            .. versionadded:: 0.0.4
506
507        ssh_host_key (str):
508            Representation of a line in an OpenSSH-style "known hosts"
509            file.
510
511            ``REMOTE GATEWAY``'s key fingerprint will be compared to this
512            host key in order to prevent against SSH server spoofing.
513            Important when using passwords in order not to accidentally
514            do a login attempt to a wrong (perhaps an attacker's) machine
515
516        ssh_username (str):
517            Username to authenticate as in ``REMOTE SERVER``
518
519            Default: current local user name
520
521        ssh_password (str):
522            Text representing the password used to connect to ``REMOTE
523            SERVER`` or for unlocking a private key.
524
525            .. note::
526                Avoid coding secret password directly in the code, since this
527                may be visible and make your service vulnerable to attacks
528
529        ssh_port (int):
530            Optional port number of the SSH service on ``REMOTE GATEWAY``,
            when ``ssh_address_or_host`` is a ``str`` representing the
532            IP part of ``REMOTE GATEWAY``'s address
533
534            Default: 22
535
536        ssh_pkey (str or paramiko.PKey):
537            **Private** key file name (``str``) to obtain the public key
538            from or a **public** key (:class:`paramiko.pkey.PKey`)
539
540        ssh_private_key_password (str):
541            Password for an encrypted ``ssh_pkey``
542
543            .. note::
544                Avoid coding secret password directly in the code, since this
545                may be visible and make your service vulnerable to attacks
546
547        ssh_proxy (socket-like object or tuple):
548            Proxy where all SSH traffic will be passed through.
549            It might be for example a :class:`paramiko.proxy.ProxyCommand`
550            instance.
551            See either the :class:`paramiko.transport.Transport`'s sock
552            parameter documentation or ``ProxyCommand`` in ``ssh_config(5)``
553            for more information.
554
555            It is also possible to specify the proxy address as a tuple of
556            type (``str``, ``int``) representing proxy's IP and port
557
558            .. note::
559                Ignored if ``ssh_proxy_enabled`` is False
560
561            .. versionadded:: 0.0.5
562
563        ssh_proxy_enabled (boolean):
564            Enable/disable SSH proxy. If True and user's
565            ``ssh_config_file`` contains a ``ProxyCommand`` directive
566            that matches the specified ``ssh_address_or_host``,
567            a :class:`paramiko.proxy.ProxyCommand` object will be created where
568            all SSH traffic will be passed through
569
570            Default: ``True``
571
572            .. versionadded:: 0.0.4
573
574        local_bind_address (tuple):
575            Local tuple in the format (``str``, ``int``) representing the
576            IP and port of the local side of the tunnel. Both elements in
577            the tuple are optional so both ``('', 8000)`` and
578            ``('10.0.0.1', )`` are valid values
579
580            Default: ``('0.0.0.0', RANDOM_PORT)``
581
582            .. versionchanged:: 0.0.8
583                Added the ability to use a UNIX domain socket as local bind
584                address
585
586        local_bind_addresses (list[tuple]):
587            In case more than one tunnel is established at once, a list
588            of tuples (in the same format as ``local_bind_address``)
589            can be specified, such as [(ip1, port_1), (ip_2, port2), ...]
590
591            Default: ``[local_bind_address]``
592
593            .. versionadded:: 0.0.4
594
595        remote_bind_address (tuple):
596            Remote tuple in the format (``str``, ``int``) representing the
597            IP and port of the remote side of the tunnel.
598
599        remote_bind_addresses (list[tuple]):
600            In case more than one tunnel is established at once, a list
601            of tuples (in the same format as ``remote_bind_address``)
602            can be specified, such as [(ip1, port_1), (ip_2, port2), ...]
603
604            Default: ``[remote_bind_address]``
605
606            .. versionadded:: 0.0.4
607
608        allow_agent (boolean):
609            Enable/disable load of keys from an SSH agent
610
611            Default: ``True``
612
613            .. versionadded:: 0.0.8
614
615        host_pkey_directories (list):
616            Look for pkeys in folders on this list, for example ['~/.ssh'].
617
618            Default: ``None`` (disabled)
619
620            .. versionadded:: 0.1.4
621
622        compression (boolean):
623            Turn on/off transport compression. By default compression is
624            disabled since it may negatively affect interactive sessions
625
626            Default: ``False``
627
628            .. versionadded:: 0.0.8
629
630        logger (logging.Logger):
631            logging instance for sshtunnel and paramiko
632
633            Default: :class:`logging.Logger` instance with a single
634            :class:`logging.StreamHandler` handler and
635            :const:`DEFAULT_LOGLEVEL` level
636
637            .. versionadded:: 0.0.3
638
639        mute_exceptions (boolean):
640            Allow silencing :class:`BaseSSHTunnelForwarderError` or
641            :class:`HandlerSSHTunnelForwarderError` exceptions when enabled
642
643            Default: ``False``
644
645            .. versionadded:: 0.0.8
646
647        set_keepalive (float):
648            Interval in seconds defining the period in which, if no data
649            was sent over the connection, a *'keepalive'* packet will be
650            sent (and ignored by the remote host). This can be useful to
651            keep connections alive over a NAT
652
653            Default: 0.0 (no keepalive packets are sent)
654
655            .. versionadded:: 0.0.7
656
657        threaded (boolean):
658            Allow concurrent connections over a single tunnel
659
660            Default: ``True``
661
662            .. versionadded:: 0.0.3
663
664        ssh_address (str):
665            Superseded by ``ssh_address_or_host``, tuple of type (str, int)
666            representing the IP and port of ``REMOTE SERVER``
667
668            .. deprecated:: 0.0.4
669
670        ssh_host (str):
671            Superseded by ``ssh_address_or_host``, tuple of type
672            (str, int) representing the IP and port of ``REMOTE SERVER``
673
674            .. deprecated:: 0.0.4
675
676        ssh_private_key (str or paramiko.PKey):
677            Superseded by ``ssh_pkey``, which can represent either a
678            **private** key file name (``str``) or a **public** key
679            (:class:`paramiko.pkey.PKey`)
680
681            .. deprecated:: 0.0.8
682
683        raise_exception_if_any_forwarder_have_a_problem (boolean):
684            Allow silencing :class:`BaseSSHTunnelForwarderError` or
685            :class:`HandlerSSHTunnelForwarderError` exceptions when set to
686            False
687
688            Default: ``True``
689
690            .. versionadded:: 0.0.4
691
692            .. deprecated:: 0.0.8 (use ``mute_exceptions`` instead)
693
694    Attributes:
695
696        tunnel_is_up (dict):
697            Describe whether or not the other side of the tunnel was reported
698            to be up (and we must close it) or not (skip shutting down that
699            tunnel)
700
701            .. note::
702                This attribute should not be modified
703
704            .. note::
705                When :attr:`.skip_tunnel_checkup` is disabled or the local bind
706                is a UNIX socket, the value will always be ``True``
707
708            **Example**::
709
710                {('127.0.0.1', 55550): True,   # this tunnel is up
711                 ('127.0.0.1', 55551): False}  # this one isn't
712
713            where 55550 and 55551 are the local bind ports
714
715        skip_tunnel_checkup (boolean):
716            Disable tunnel checkup (default for backwards compatibility).
717
718            .. versionadded:: 0.1.0
719
720    """
721    skip_tunnel_checkup = True
722    daemon_forward_servers = _DAEMON  #: flag tunnel threads in daemon mode
723    daemon_transport = _DAEMON  #: flag SSH transport thread in daemon mode
724
725    def local_is_up(self, target):
726        """
727        Check if a tunnel is up (remote target's host is reachable on TCP
728        target's port)
729
730        Arguments:
731            target (tuple):
732                tuple of type (``str``, ``int``) indicating the listen IP
733                address and port
734        Return:
735            boolean
736
737        .. deprecated:: 0.1.0
738            Replaced by :meth:`.check_tunnels()` and :attr:`.tunnel_is_up`
739        """
740        try:
741            check_address(target)
742        except ValueError:
743            self.logger.warning('Target must be a tuple (IP, port), where IP '
744                                'is a string (i.e. "192.168.0.1") and port is '
745                                'an integer (i.e. 40000). Alternatively '
746                                'target can be a valid UNIX domain socket.')
747            return False
748
749        if self.skip_tunnel_checkup:  # force tunnel check at this point
750            self.skip_tunnel_checkup = False
751            self.check_tunnels()
752            self.skip_tunnel_checkup = True  # roll it back
753        return self.tunnel_is_up.get(target, True)
754
755    def _make_ssh_forward_handler_class(self, remote_address_):
756        """
757        Make SSH Handler class
758        """
759        class Handler(_ForwardHandler):
760            remote_address = remote_address_
761            ssh_transport = self._transport
762            logger = self.logger
763        return Handler
764
765    def _make_ssh_forward_server_class(self, remote_address_):
766        return _ThreadingForwardServer if self._threaded else _ForwardServer
767
768    def _make_unix_ssh_forward_server_class(self, remote_address_):
769        return _ThreadingUnixStreamForwardServer if \
770            self._threaded else _UnixStreamForwardServer
771
772    def _make_ssh_forward_server(self, remote_address, local_bind_address):
773        """
774        Make SSH forward proxy Server class
775        """
776        _Handler = self._make_ssh_forward_handler_class(remote_address)
777        try:
778            if isinstance(local_bind_address, string_types):
779                forward_maker_class = self._make_unix_ssh_forward_server_class
780            else:
781                forward_maker_class = self._make_ssh_forward_server_class
782            _Server = forward_maker_class(remote_address)
783            ssh_forward_server = _Server(
784                local_bind_address,
785                _Handler,
786                logger=self.logger,
787            )
788
789            if ssh_forward_server:
790                ssh_forward_server.daemon_threads = self.daemon_forward_servers
791                self._server_list.append(ssh_forward_server)
792                self.tunnel_is_up[ssh_forward_server.server_address] = False
793            else:
794                self._raise(
795                    BaseSSHTunnelForwarderError,
796                    'Problem setting up ssh {0} <> {1} forwarder. You can '
797                    'suppress this exception by using the `mute_exceptions`'
798                    'argument'.format(address_to_str(local_bind_address),
799                                      address_to_str(remote_address))
800                )
801        except IOError:
802            self._raise(
803                BaseSSHTunnelForwarderError,
804                "Couldn't open tunnel {0} <> {1} might be in use or "
805                "destination not reachable".format(
806                    address_to_str(local_bind_address),
807                    address_to_str(remote_address)
808                )
809            )
810
811    def __init__(
812            self,
813            ssh_address_or_host=None,
814            ssh_config_file=SSH_CONFIG_FILE,
815            ssh_host_key=None,
816            ssh_password=None,
817            ssh_pkey=None,
818            ssh_private_key_password=None,
819            ssh_proxy=None,
820            ssh_proxy_enabled=True,
821            ssh_username=None,
822            local_bind_address=None,
823            local_bind_addresses=None,
824            logger=None,
825            mute_exceptions=False,
826            remote_bind_address=None,
827            remote_bind_addresses=None,
828            set_keepalive=0.0,
829            threaded=True,  # old version False
830            compression=None,
831            allow_agent=True,  # look for keys from an SSH agent
832            host_pkey_directories=None,  # look for keys in ~/.ssh
833            *args,
834            **kwargs  # for backwards compatibility
835    ):
836        self.logger = logger or create_logger()
837
838        # Ensure paramiko.transport has a console handler
839        _check_paramiko_handlers(logger=logger)
840
841        self.ssh_host_key = ssh_host_key
842        self.set_keepalive = set_keepalive
843        self._server_list = []  # reset server list
844        self.tunnel_is_up = {}  # handle tunnel status
845        self._threaded = threaded
846        self.is_alive = False
847        # Check if deprecated arguments ssh_address or ssh_host were used
848        for deprecated_argument in ['ssh_address', 'ssh_host']:
849            ssh_address_or_host = self._process_deprecated(ssh_address_or_host,
850                                                           deprecated_argument,
851                                                           kwargs)
852        # other deprecated arguments
853        ssh_pkey = self._process_deprecated(ssh_pkey,
854                                            'ssh_private_key',
855                                            kwargs)
856
857        self._raise_fwd_exc = self._process_deprecated(
858            None,
859            'raise_exception_if_any_forwarder_have_a_problem',
860            kwargs) or not mute_exceptions
861
862        if isinstance(ssh_address_or_host, tuple):
863            check_address(ssh_address_or_host)
864            (ssh_host, ssh_port) = ssh_address_or_host
865        else:
866            ssh_host = ssh_address_or_host
867            ssh_port = kwargs.pop('ssh_port', None)
868
869        if kwargs:
870            raise ValueError('Unknown arguments: {0}'.format(kwargs))
871
872        # remote binds
873        self._remote_binds = self._get_binds(remote_bind_address,
874                                             remote_bind_addresses,
875                                             is_remote=True)
876        # local binds
877        self._local_binds = self._get_binds(local_bind_address,
878                                            local_bind_addresses)
879        self._local_binds = self._consolidate_binds(self._local_binds,
880                                                    self._remote_binds)
881
882        (self.ssh_host,
883         self.ssh_username,
884         ssh_pkey,  # still needs to go through _consolidate_auth
885         self.ssh_port,
886         self.ssh_proxy,
887         self.compression) = self._read_ssh_config(
888             ssh_host,
889             ssh_config_file,
890             ssh_username,
891             ssh_pkey,
892             ssh_port,
893             ssh_proxy if ssh_proxy_enabled else None,
894             compression,
895             self.logger
896        )
897
898        (self.ssh_password, self.ssh_pkeys) = self._consolidate_auth(
899            ssh_password=ssh_password,
900            ssh_pkey=ssh_pkey,
901            ssh_pkey_password=ssh_private_key_password,
902            allow_agent=allow_agent,
903            host_pkey_directories=host_pkey_directories,
904            logger=self.logger
905        )
906
907        check_host(self.ssh_host)
908        check_port(self.ssh_port)
909
910        self.logger.info("Connecting to gateway: {0}:{1} as user '{2}'"
911                         .format(self.ssh_host,
912                                 self.ssh_port,
913                                 self.ssh_username))
914
915        self.logger.debug('Concurrent connections allowed: {0}'
916                          .format(self._threaded))
917
918    @staticmethod
919    def _read_ssh_config(ssh_host,
920                         ssh_config_file,
921                         ssh_username=None,
922                         ssh_pkey=None,
923                         ssh_port=None,
924                         ssh_proxy=None,
925                         compression=None,
926                         logger=None):
927        """
928        Read ssh_config_file and tries to look for user (ssh_username),
929        identityfile (ssh_pkey), port (ssh_port) and proxycommand
930        (ssh_proxy) entries for ssh_host
931        """
932        ssh_config = paramiko.SSHConfig()
933        if not ssh_config_file:  # handle case where it's an empty string
934            ssh_config_file = None
935
936        # Try to read SSH_CONFIG_FILE
937        try:
938            # open the ssh config file
939            with open(os.path.expanduser(ssh_config_file), 'r') as f:
940                ssh_config.parse(f)
941            # looks for information for the destination system
942            hostname_info = ssh_config.lookup(ssh_host)
943            # gather settings for user, port and identity file
944            # last resort: use the 'login name' of the user
945            ssh_username = (
946                ssh_username or
947                hostname_info.get('user')
948            )
949            ssh_pkey = (
950                ssh_pkey or
951                hostname_info.get('identityfile', [None])[0]
952            )
953            ssh_host = hostname_info.get('hostname')
954            ssh_port = ssh_port or hostname_info.get('port')
955
956            proxycommand = hostname_info.get('proxycommand')
957            ssh_proxy = ssh_proxy or (paramiko.ProxyCommand(proxycommand) if
958                                      proxycommand else None)
959            if compression is None:
960                compression = hostname_info.get('compression', '')
961                compression = True if compression.upper() == 'YES' else False
962        except IOError:
963            if logger:
964                logger.warning(
965                    'Could not read SSH configuration file: {0}'
966                    .format(ssh_config_file)
967                )
968        except (AttributeError, TypeError):  # ssh_config_file is None
969            if logger:
970                logger.info('Skipping loading of ssh configuration file')
971        finally:
972            return (ssh_host,
973                    ssh_username or getpass.getuser(),
974                    ssh_pkey,
975                    int(ssh_port) if ssh_port else 22,  # fallback value
976                    ssh_proxy,
977                    compression)
978
979    @staticmethod
980    def get_agent_keys(logger=None):
981        """ Load public keys from any available SSH agent
982
983        Arguments:
984            logger (Optional[logging.Logger])
985
986        Return:
987            list
988        """
989        paramiko_agent = paramiko.Agent()
990        agent_keys = paramiko_agent.get_keys()
991        if logger:
992            logger.info('{0} keys loaded from agent'.format(len(agent_keys)))
993        return list(agent_keys)
994
995    @staticmethod
996    def get_keys(logger=None, host_pkey_directories=None, allow_agent=False):
997        """
998        Load public keys from any available SSH agent or local
999        .ssh directory.
1000
1001        Arguments:
1002            logger (Optional[logging.Logger])
1003
1004            host_pkey_directories (Optional[list[str]]):
1005                List of local directories where host SSH pkeys in the format
1006                "id_*" are searched. For example, ['~/.ssh']
1007
1008                .. versionadded:: 0.1.0
1009
1010            allow_agent (Optional[boolean]):
1011                Whether or not load keys from agent
1012
1013                Default: False
1014
1015        Return:
1016            list
1017        """
1018        keys = SSHTunnelForwarder.get_agent_keys(logger=logger) \
1019            if allow_agent else []
1020
1021        if host_pkey_directories is not None:
1022            paramiko_key_types = {'rsa': paramiko.RSAKey,
1023                                  'dsa': paramiko.DSSKey,
1024                                  'ecdsa': paramiko.ECDSAKey,
1025                                  'ed25519': paramiko.Ed25519Key}
1026            for directory in host_pkey_directories or [DEFAULT_SSH_DIRECTORY]:
1027                for keytype in paramiko_key_types.keys():
1028                    ssh_pkey_expanded = os.path.expanduser(
1029                        os.path.join(directory, 'id_{}'.format(keytype))
1030                    )
1031                    if os.path.isfile(ssh_pkey_expanded):
1032                        ssh_pkey = SSHTunnelForwarder.read_private_key_file(
1033                            pkey_file=ssh_pkey_expanded,
1034                            logger=logger,
1035                            key_type=paramiko_key_types[keytype]
1036                        )
1037                        if ssh_pkey:
1038                            keys.append(ssh_pkey)
1039        if logger:
1040            logger.info('{0} keys loaded from host directory'.format(
1041                len(keys))
1042            )
1043
1044        return keys
1045
1046    @staticmethod
1047    def _consolidate_binds(local_binds, remote_binds):
1048        """
1049        Fill local_binds with defaults when no value/s were specified,
1050        leaving paramiko to decide in which local port the tunnel will be open
1051        """
1052        count = len(remote_binds) - len(local_binds)
1053        if count < 0:
1054            raise ValueError('Too many local bind addresses '
1055                             '(local_bind_addresses > remote_bind_addresses)')
1056        local_binds.extend([('0.0.0.0', 0) for x in range(count)])
1057        return local_binds
1058
1059    @staticmethod
1060    def _consolidate_auth(ssh_password=None,
1061                          ssh_pkey=None,
1062                          ssh_pkey_password=None,
1063                          allow_agent=True,
1064                          host_pkey_directories=None,
1065                          logger=None):
1066        """
1067        Get sure authentication information is in place.
1068        ``ssh_pkey`` may be of classes:
1069            - ``str`` - in this case it represents a private key file; public
1070            key will be obtained from it
1071            - ``paramiko.Pkey`` - it will be transparently added to loaded keys
1072
1073        """
1074        ssh_loaded_pkeys = SSHTunnelForwarder.get_keys(
1075            logger=logger,
1076            host_pkey_directories=host_pkey_directories,
1077            allow_agent=allow_agent
1078        )
1079
1080        if isinstance(ssh_pkey, string_types):
1081            ssh_pkey_expanded = os.path.expanduser(ssh_pkey)
1082            if os.path.exists(ssh_pkey_expanded):
1083                ssh_pkey = SSHTunnelForwarder.read_private_key_file(
1084                    pkey_file=ssh_pkey_expanded,
1085                    pkey_password=ssh_pkey_password or ssh_password,
1086                    logger=logger
1087                )
1088            elif logger:
1089                logger.warning('Private key file not found: {0}'
1090                               .format(ssh_pkey))
1091        if isinstance(ssh_pkey, paramiko.pkey.PKey):
1092            ssh_loaded_pkeys.insert(0, ssh_pkey)
1093
1094        if not ssh_password and not ssh_loaded_pkeys:
1095            raise ValueError('No password or public key available!')
1096        return (ssh_password, ssh_loaded_pkeys)
1097
1098    def _raise(self, exception=BaseSSHTunnelForwarderError, reason=None):
1099        if self._raise_fwd_exc:
1100            raise exception(reason)
1101        else:
1102            self.logger.error(repr(exception(reason)))
1103
1104    def _get_transport(self):
1105        """ Return the SSH transport to the remote gateway """
1106        if self.ssh_proxy:
1107            if isinstance(self.ssh_proxy, paramiko.proxy.ProxyCommand):
1108                proxy_repr = repr(self.ssh_proxy.cmd[1])
1109            else:
1110                proxy_repr = repr(self.ssh_proxy)
1111            self.logger.debug('Connecting via proxy: {0}'.format(proxy_repr))
1112            _socket = self.ssh_proxy
1113        else:
1114            _socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
1115        if isinstance(_socket, socket.socket):
1116            _socket.settimeout(SSH_TIMEOUT)
1117            _socket.connect((self.ssh_host, self.ssh_port))
1118        transport = paramiko.Transport(_socket)
1119        transport.set_keepalive(self.set_keepalive)
1120        transport.use_compression(compress=self.compression)
1121        transport.daemon = self.daemon_transport
1122
1123        return transport
1124
1125    def _create_tunnels(self):
1126        """
1127        Create SSH tunnels on top of a transport to the remote gateway
1128        """
1129        if not self.is_active:
1130            try:
1131                self._connect_to_gateway()
1132            except socket.gaierror:  # raised by paramiko.Transport
1133                msg = 'Could not resolve IP address for {0}, aborting!' \
1134                    .format(self.ssh_host)
1135                self.logger.error(msg)
1136                return
1137            except (paramiko.SSHException, socket.error) as e:
1138                template = 'Could not connect to gateway {0}:{1} : {2}'
1139                msg = template.format(self.ssh_host, self.ssh_port, e.args[0])
1140                self.logger.error(msg)
1141                return
1142        for (rem, loc) in zip(self._remote_binds, self._local_binds):
1143            try:
1144                self._make_ssh_forward_server(rem, loc)
1145            except BaseSSHTunnelForwarderError as e:
1146                msg = 'Problem setting SSH Forwarder up: {0}'.format(e.value)
1147                self.logger.error(msg)
1148
1149    @staticmethod
1150    def _get_binds(bind_address, bind_addresses, is_remote=False):
1151        addr_kind = 'remote' if is_remote else 'local'
1152
1153        if not bind_address and not bind_addresses:
1154            if is_remote:
1155                raise ValueError("No {0} bind addresses specified. Use "
1156                                 "'{0}_bind_address' or '{0}_bind_addresses'"
1157                                 " argument".format(addr_kind))
1158            else:
1159                return []
1160        elif bind_address and bind_addresses:
1161            raise ValueError("You can't use both '{0}_bind_address' and "
1162                             "'{0}_bind_addresses' arguments. Use one of "
1163                             "them.".format(addr_kind))
1164        if bind_address:
1165            bind_addresses = [bind_address]
1166        if not is_remote:
1167            # Add random port if missing in local bind
1168            for (i, local_bind) in enumerate(bind_addresses):
1169                if isinstance(local_bind, tuple) and len(local_bind) == 1:
1170                    bind_addresses[i] = (local_bind[0], 0)
1171        check_addresses(bind_addresses, is_remote)
1172        return bind_addresses
1173
1174    @staticmethod
1175    def _process_deprecated(attrib, deprecated_attrib, kwargs):
1176        """
1177        Processes optional deprecate arguments
1178        """
1179        if deprecated_attrib not in DEPRECATIONS:
1180            raise ValueError('{0} not included in deprecations list'
1181                             .format(deprecated_attrib))
1182        if deprecated_attrib in kwargs:
1183            warnings.warn("'{0}' is DEPRECATED use '{1}' instead"
1184                          .format(deprecated_attrib,
1185                                  DEPRECATIONS[deprecated_attrib]),
1186                          DeprecationWarning)
1187            if attrib:
1188                raise ValueError("You can't use both '{0}' and '{1}'. "
1189                                 "Please only use one of them"
1190                                 .format(deprecated_attrib,
1191                                         DEPRECATIONS[deprecated_attrib]))
1192            else:
1193                return kwargs.pop(deprecated_attrib)
1194        return attrib
1195
1196    @staticmethod
1197    def read_private_key_file(pkey_file,
1198                              pkey_password=None,
1199                              key_type=None,
1200                              logger=None):
1201        """
1202        Get SSH Public key from a private key file, given an optional password
1203
1204        Arguments:
1205            pkey_file (str):
1206                File containing a private key (RSA, DSS or ECDSA)
1207        Keyword Arguments:
1208            pkey_password (Optional[str]):
1209                Password to decrypt the private key
1210            logger (Optional[logging.Logger])
1211        Return:
1212            paramiko.Pkey
1213        """
1214        ssh_pkey = None
1215        for pkey_class in (key_type,) if key_type else (
1216            paramiko.RSAKey,
1217            paramiko.DSSKey,
1218            paramiko.ECDSAKey,
1219            paramiko.Ed25519Key
1220        ):
1221            try:
1222                ssh_pkey = pkey_class.from_private_key_file(
1223                    pkey_file,
1224                    password=pkey_password
1225                )
1226                if logger:
1227                    logger.debug('Private key file ({0}, {1}) successfully '
1228                                 'loaded'.format(pkey_file, pkey_class))
1229                break
1230            except paramiko.PasswordRequiredException:
1231                if logger:
1232                    logger.error('Password is required for key {0}'
1233                                 .format(pkey_file))
1234                break
1235            except paramiko.SSHException:
1236                if logger:
1237                    logger.debug('Private key file ({0}) could not be loaded '
1238                                 'as type {1} or bad password'
1239                                 .format(pkey_file, pkey_class))
1240        return ssh_pkey
1241
    def _check_tunnel(self, _srv):
        """ Check if tunnel is already established

        Stores the outcome in ``self.tunnel_is_up[_srv.local_address]``.

        With ``skip_tunnel_checkup`` set, the tunnel is simply marked as up.
        Otherwise a test connection is opened to the local side of the
        tunnel and ``_srv.tunnel_ok`` is polled:

        - connecting to the local side fails (``socket.error``): down
        - a value arrives on ``_srv.tunnel_ok`` before the timeout: that value
          is stored and the tunnel is logged as DOWN. Presumably the forward
          server only posts failures to this queue; TODO confirm against
          the server class.
        - nothing arrives before the timeout (``queue.Empty``): up

        The probe socket is always closed.
        """
        if self.skip_tunnel_checkup:
            self.tunnel_is_up[_srv.local_address] = True
            return
        self.logger.info('Checking tunnel to: {0}'.format(_srv.remote_address))
        if isinstance(_srv.local_address, string_types):  # UNIX stream
            s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        else:
            s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.settimeout(TUNNEL_TIMEOUT)
        try:
            # Windows raises WinError 10049 if trying to connect to 0.0.0.0
            connect_to = ('127.0.0.1', _srv.local_port) \
                if _srv.local_host == '0.0.0.0' else _srv.local_address
            s.connect(connect_to)
            # Wait slightly longer than the probe's own TUNNEL_TIMEOUT for a
            # report from the forward server. Any value received here is
            # treated as an error report (see docstring).
            self.tunnel_is_up[_srv.local_address] = _srv.tunnel_ok.get(
                timeout=TUNNEL_TIMEOUT * 1.1
            )
            self.logger.debug(
                'Tunnel to {0} is DOWN'.format(_srv.remote_address)
            )
        except socket.error:
            # nothing is listening on the local side
            self.logger.debug(
                'Tunnel to {0} is DOWN'.format(_srv.remote_address)
            )
            self.tunnel_is_up[_srv.local_address] = False

        except queue.Empty:
            # no problem reported within the timeout: consider it up
            self.logger.debug(
                'Tunnel to {0} is UP'.format(_srv.remote_address)
            )
            self.tunnel_is_up[_srv.local_address] = True
        finally:
            s.close()
1277
1278    def check_tunnels(self):
1279        """
1280        Check that if all tunnels are established and populates
1281        :attr:`.tunnel_is_up`
1282        """
1283        for _srv in self._server_list:
1284            self._check_tunnel(_srv)
1285
1286    def start(self):
1287        """ Start the SSH tunnels """
1288        if self.is_alive:
1289            self.logger.warning('Already started!')
1290            return
1291        self._create_tunnels()
1292        if not self.is_active:
1293            self._raise(BaseSSHTunnelForwarderError,
1294                        reason='Could not establish session to SSH gateway')
1295        for _srv in self._server_list:
1296            thread = threading.Thread(
1297                target=self._serve_forever_wrapper,
1298                args=(_srv, ),
1299                name='Srv-{0}'.format(address_to_str(_srv.local_port))
1300            )
1301            thread.daemon = self.daemon_forward_servers
1302            thread.start()
1303            self._check_tunnel(_srv)
1304        self.is_alive = any(self.tunnel_is_up.values())
1305        if not self.is_alive:
1306            self._raise(HandlerSSHTunnelForwarderError,
1307                        'An error occurred while opening tunnels.')
1308
1309    def stop(self):
1310        """
1311        Shut the tunnel down.
1312
1313        .. note:: This **had** to be handled with care before ``0.1.0``:
1314
1315            - if a port redirection is opened
1316            - the destination is not reachable
1317            - we attempt a connection to that tunnel (``SYN`` is sent and
1318              acknowledged, then a ``FIN`` packet is sent and never
1319              acknowledged... weird)
1320            - we try to shutdown: it will not succeed until ``FIN_WAIT_2`` and
1321              ``CLOSE_WAIT`` time out.
1322
1323        .. note::
1324            Handle these scenarios with :attr:`.tunnel_is_up`: if False, server
1325            ``shutdown()`` will be skipped on that tunnel
1326        """
1327        self.logger.info('Closing all open connections...')
1328        opened_address_text = ', '.join(
1329            (address_to_str(k.local_address) for k in self._server_list)
1330        ) or 'None'
1331        self.logger.debug('Listening tunnels: ' + opened_address_text)
1332        self._stop_transport()
1333        self._server_list = []  # reset server list
1334        self.tunnel_is_up = {}  # reset tunnel status
1335
1336    def close(self):
1337        """ Stop the an active tunnel, alias to :meth:`.stop` """
1338        self.stop()
1339
1340    def restart(self):
1341        """ Restart connection to the gateway and tunnels """
1342        self.stop()
1343        self.start()
1344
1345    def _connect_to_gateway(self):
1346        """
1347        Open connection to SSH gateway
1348         - First try with all keys loaded from an SSH agent (if allowed)
1349         - Then with those passed directly or read from ~/.ssh/config
1350         - As last resort, try with a provided password
1351        """
1352        for key in self.ssh_pkeys:
1353            self.logger.debug('Trying to log in with key: {0}'
1354                              .format(hexlify(key.get_fingerprint())))
1355            try:
1356                self._transport = self._get_transport()
1357                self._transport.connect(hostkey=self.ssh_host_key,
1358                                        username=self.ssh_username,
1359                                        pkey=key)
1360                if self._transport.is_alive:
1361                    return
1362            except paramiko.AuthenticationException:
1363                self.logger.debug('Authentication error')
1364                self._stop_transport()
1365
1366        if self.ssh_password:  # avoid conflict using both pass and pkey
1367            self.logger.debug('Trying to log in with password: {0}'
1368                              .format('*' * len(self.ssh_password)))
1369            try:
1370                self._transport = self._get_transport()
1371                self._transport.connect(hostkey=self.ssh_host_key,
1372                                        username=self.ssh_username,
1373                                        password=self.ssh_password)
1374                if self._transport.is_alive:
1375                    return
1376            except paramiko.AuthenticationException:
1377                self.logger.debug('Authentication error')
1378                self._stop_transport()
1379
1380        self.logger.error('Could not open connection to gateway')
1381
1382    def _serve_forever_wrapper(self, _srv, poll_interval=0.1):
1383        """
1384        Wrapper for the server created for a SSH forward
1385        """
1386        self.logger.info('Opening tunnel: {0} <> {1}'.format(
1387            address_to_str(_srv.local_address),
1388            address_to_str(_srv.remote_address))
1389        )
1390        _srv.serve_forever(poll_interval)  # blocks until finished
1391
1392        self.logger.info('Tunnel: {0} <> {1} released'.format(
1393            address_to_str(_srv.local_address),
1394            address_to_str(_srv.remote_address))
1395        )
1396
1397    def _stop_transport(self):
1398        """ Close the underlying transport when nothing more is needed """
1399        try:
1400            self._check_is_started()
1401        except (BaseSSHTunnelForwarderError,
1402                HandlerSSHTunnelForwarderError) as e:
1403            self.logger.warning(e)
1404        for _srv in self._server_list:
1405            tunnel = _srv.local_address
1406            if self.tunnel_is_up[tunnel]:
1407                self.logger.info('Shutting down tunnel {0}'.format(tunnel))
1408                _srv.shutdown()
1409            _srv.server_close()
1410            # clean up the UNIX domain socket if we're using one
1411            if isinstance(_srv, _UnixStreamForwardServer):
1412                try:
1413                    os.unlink(_srv.local_address)
1414                except Exception as e:
1415                    self.logger.error('Unable to unlink socket {0}: {1}'
1416                                      .format(self.local_address, repr(e)))
1417        self.is_alive = False
1418        if self.is_active:
1419            self._transport.close()
1420            self._transport.stop_thread()
1421        self.logger.debug('Transport is closed')
1422
1423    @property
1424    def local_bind_port(self):
1425        # BACKWARDS COMPATIBILITY
1426        self._check_is_started()
1427        if len(self._server_list) != 1:
1428            raise BaseSSHTunnelForwarderError(
1429                'Use .local_bind_ports property for more than one tunnel'
1430            )
1431        return self.local_bind_ports[0]
1432
1433    @property
1434    def local_bind_host(self):
1435        # BACKWARDS COMPATIBILITY
1436        self._check_is_started()
1437        if len(self._server_list) != 1:
1438            raise BaseSSHTunnelForwarderError(
1439                'Use .local_bind_hosts property for more than one tunnel'
1440            )
1441        return self.local_bind_hosts[0]
1442
1443    @property
1444    def local_bind_address(self):
1445        # BACKWARDS COMPATIBILITY
1446        self._check_is_started()
1447        if len(self._server_list) != 1:
1448            raise BaseSSHTunnelForwarderError(
1449                'Use .local_bind_addresses property for more than one tunnel'
1450            )
1451        return self.local_bind_addresses[0]
1452
1453    @property
1454    def local_bind_ports(self):
1455        """
1456        Return a list containing the ports of local side of the TCP tunnels
1457        """
1458        self._check_is_started()
1459        return [_server.local_port for _server in self._server_list if
1460                _server.local_port is not None]
1461
1462    @property
1463    def local_bind_hosts(self):
1464        """
1465        Return a list containing the IP addresses listening for the tunnels
1466        """
1467        self._check_is_started()
1468        return [_server.local_host for _server in self._server_list if
1469                _server.local_host is not None]
1470
1471    @property
1472    def local_bind_addresses(self):
1473        """
1474        Return a list of (IP, port) pairs for the local side of the tunnels
1475        """
1476        self._check_is_started()
1477        return [_server.local_address for _server in self._server_list]
1478
1479    @property
1480    def tunnel_bindings(self):
1481        """
1482        Return a dictionary containing the active local<>remote tunnel_bindings
1483        """
1484        return dict((_server.remote_address, _server.local_address) for
1485                    _server in self._server_list if
1486                    self.tunnel_is_up[_server.local_address])
1487
1488    @property
1489    def is_active(self):
1490        """ Return True if the underlying SSH transport is up """
1491        if (
1492            '_transport' in self.__dict__ and
1493            self._transport.is_active()
1494        ):
1495            return True
1496        return False
1497
1498    def _check_is_started(self):
1499        if not self.is_active:  # underlying transport not alive
1500            msg = 'Server is not started. Please .start() first!'
1501            raise BaseSSHTunnelForwarderError(msg)
1502        if not self.is_alive:
1503            msg = 'Tunnels are not started. Please .start() first!'
1504            raise HandlerSSHTunnelForwarderError(msg)
1505
1506    def __str__(self):
1507        credentials = {
1508            'password': self.ssh_password,
1509            'pkeys': [(key.get_name(), hexlify(key.get_fingerprint()))
1510                      for key in self.ssh_pkeys]
1511            if any(self.ssh_pkeys) else None
1512        }
1513        _remove_none_values(credentials)
1514        template = os.linesep.join(['{0} object',
1515                                    'ssh gateway: {1}:{2}',
1516                                    'proxy: {3}',
1517                                    'username: {4}',
1518                                    'authentication: {5}',
1519                                    'hostkey: {6}',
1520                                    'status: {7}started',
1521                                    'keepalive messages: {8}',
1522                                    'tunnel connection check: {9}',
1523                                    'concurrent connections: {10}allowed',
1524                                    'compression: {11}requested',
1525                                    'logging level: {12}',
1526                                    'local binds: {13}',
1527                                    'remote binds: {14}'])
1528        return (template.format(
1529            self.__class__,
1530            self.ssh_host,
1531            self.ssh_port,
1532            self.ssh_proxy.cmd[1] if self.ssh_proxy else 'no',
1533            self.ssh_username,
1534            credentials,
1535            self.ssh_host_key if self.ssh_host_key else'not checked',
1536            '' if self.is_alive else 'not ',
1537            'disabled' if not self.set_keepalive else
1538            'every {0} sec'.format(self.set_keepalive),
1539            'disabled' if self.skip_tunnel_checkup else 'enabled',
1540            '' if self._threaded else 'not ',
1541            '' if self.compression else 'not ',
1542            logging.getLevelName(self.logger.level),
1543            self._local_binds,
1544            self._remote_binds,
1545        ))
1546
1547    def __repr__(self):
1548        return self.__str__()
1549
1550    def __enter__(self):
1551        try:
1552            self.start()
1553            return self
1554        except KeyboardInterrupt:
1555            self.__exit__()
1556
1557    def __exit__(self, *args):
1558        self._stop_transport()
1559
1560
def open_tunnel(*args, **kwargs):
    """
    Open an SSH Tunnel, wrapper for :class:`SSHTunnelForwarder`

    Arguments:
        destination (Optional[tuple]):
            SSH server's IP address and port in the format
            (``ssh_address``, ``ssh_port``)

    Keyword Arguments:
        debug_level (Optional[int or str]):
            log level for :class:`logging.Logger` instance, i.e. ``DEBUG``

        ssh_port (Optional[int]):
            Port of the SSH server. Used when the server address is given
            as a plain host name rather than a ``(host, port)`` tuple; if
            omitted, the port from the SSH configuration file (or 22) is
            used.

        skip_tunnel_checkup (boolean):
            Enable/disable the local side check and populate
            :attr:`~SSHTunnelForwarder.tunnel_is_up`

            Default: True

            .. versionadded:: 0.1.0

        block_on_close (boolean):
            Wait until all connections are done during close. This sets
            :attr:`~SSHTunnelForwarder.daemon_forward_servers` and
            :attr:`~SSHTunnelForwarder.daemon_transport` to
            ``not block_on_close``.

            Default: False (the module-level ``_DAEMON`` value)

    .. note::
        A value of ``debug_level`` set to 1 == ``TRACE`` enables tracing mode
    .. note::
        See :class:`SSHTunnelForwarder` for keyword arguments

    **Example**::

        from sshtunnel import open_tunnel

        with open_tunnel(SERVER,
                         ssh_username=SSH_USER,
                         ssh_port=22,
                         ssh_password=SSH_PASSWORD,
                         remote_bind_address=(REMOTE_HOST, REMOTE_PORT),
                         local_bind_address=('', LOCAL_PORT)) as server:
            def do_something(port):
                pass

            print("LOCAL PORTS:", server.local_bind_port)

            do_something(server.local_bind_port)
    """
    # Attach a console handler to the logger or create one if not passed
    kwargs['logger'] = create_logger(logger=kwargs.get('logger', None),
                                     loglevel=kwargs.pop('debug_level', None))

    ssh_address_or_host = kwargs.pop('ssh_address_or_host', None)
    # Check if deprecated arguments ssh_address or ssh_host were used
    for deprecated_argument in ['ssh_address', 'ssh_host']:
        ssh_address_or_host = SSHTunnelForwarder._process_deprecated(
            ssh_address_or_host,
            deprecated_argument,
            kwargs
        )

    ssh_port = kwargs.pop('ssh_port', None)
    skip_tunnel_checkup = kwargs.pop('skip_tunnel_checkup', True)
    block_on_close = kwargs.pop('block_on_close', _DAEMON)
    if not args:
        args = (ssh_address_or_host, )
    # SSHTunnelForwarder only reads the ``ssh_port`` keyword when the
    # address is not a (host, port) tuple. Forward it in that case instead
    # of silently dropping it (e.g. host passed positionally + ssh_port=...).
    if ssh_port is not None and not isinstance(args[0], tuple):
        kwargs['ssh_port'] = ssh_port
    forwarder = SSHTunnelForwarder(*args, **kwargs)
    forwarder.skip_tunnel_checkup = skip_tunnel_checkup
    forwarder.daemon_forward_servers = not block_on_close
    forwarder.daemon_transport = not block_on_close
    return forwarder
1636
1637
1638def _bindlist(input_str):
1639    """ Define type of data expected for remote and local bind address lists
1640        Returns a tuple (ip_address, port) whose elements are (str, int)
1641    """
1642    try:
1643        ip_port = input_str.split(':')
1644        if len(ip_port) == 1:
1645            _ip = ip_port[0]
1646            _port = None
1647        else:
1648            (_ip, _port) = ip_port
1649        if not _ip and not _port:
1650            raise AssertionError
1651        elif not _port:
1652            _port = '22'  # default port if not given
1653        return _ip, int(_port)
1654    except ValueError:
1655        raise argparse.ArgumentTypeError(
1656            'Address tuple must be of type IP_ADDRESS:PORT'
1657        )
1658    except AssertionError:
1659        raise argparse.ArgumentTypeError("Both IP:PORT can't be missing!")
1660
1661
def _parse_arguments(args=None):
    """
    Build the sshtunnel command-line interface and parse ``args`` with it.

    Args:
        args (list or None):
            Argument strings to parse. ``None`` lets :mod:`argparse` fall
            back to ``sys.argv[1:]``.

    Returns:
        dict: the parsed options, keyed by each option's ``dest`` name.
    """
    # Each entry is (flags, add_argument keyword options). The tuple order
    # is the registration order, which is also the order shown by --help.
    cli_options = (
        (('ssh_address',), dict(
            type=str,
            help='SSH server IP address (GW for SSH tunnels)\n'
                 'set with "-- ssh_address" if immediately after '
                 '-R or -L',
        )),
        (('-U', '--username'), dict(
            type=str,
            dest='ssh_username',
            help='SSH server account username',
        )),
        (('-p', '--server_port'), dict(
            type=int,
            dest='ssh_port',
            default=22,
            help='SSH server TCP port (default: 22)',
        )),
        (('-P', '--password'), dict(
            type=str,
            dest='ssh_password',
            help='SSH server account password',
        )),
        (('-R', '--remote_bind_address'), dict(
            type=_bindlist,
            nargs='+',
            default=[],
            metavar='IP:PORT',
            required=True,
            dest='remote_bind_addresses',
            help='Remote bind address sequence: '
                 'ip_1:port_1 ip_2:port_2 ... ip_n:port_n\n'
                 'Equivalent to ssh -Lxxxx:IP_ADDRESS:PORT\n'
                 'If port is omitted, defaults to 22.\n'
                 'Example: -R 10.10.10.10: 10.10.10.10:5900',
        )),
        (('-L', '--local_bind_address'), dict(
            type=_bindlist,
            nargs='*',
            dest='local_bind_addresses',
            metavar='IP:PORT',
            help='Local bind address sequence: '
                 'ip_1:port_1 ip_2:port_2 ... ip_n:port_n\n'
                 'Elements may also be valid UNIX socket domains: \n'
                 '/tmp/foo.sock /tmp/bar.sock ... /tmp/baz.sock\n'
                 'Equivalent to ssh -LPORT:xxxxxxxxx:xxxx, '
                 'being the local IP address optional.\n'
                 'By default it will listen in all interfaces '
                 '(0.0.0.0) and choose a random port.\n'
                 'Example: -L :40000',
        )),
        (('-k', '--ssh_host_key'), dict(
            type=str,
            help="Gateway's host key",
        )),
        (('-K', '--private_key_file'), dict(
            dest='ssh_private_key',
            metavar='KEY_FILE',
            type=str,
            help='RSA/DSS/ECDSA private key file',
        )),
        (('-S', '--private_key_password'), dict(
            dest='ssh_private_key_password',
            metavar='KEY_PASSWORD',
            type=str,
            help='RSA/DSS/ECDSA private key password',
        )),
        (('-t', '--threaded'), dict(
            action='store_true',
            help='Allow concurrent connections to each tunnel',
        )),
        (('-v', '--verbose'), dict(
            action='count',
            default=0,
            help='Increase output verbosity (default: {0})'.format(
                logging.getLevelName(DEFAULT_LOGLEVEL)
            ),
        )),
        (('-V', '--version'), dict(
            action='version',
            version='%(prog)s {version}'.format(version=__version__),
            help='Show version number and quit',
        )),
        (('-x', '--proxy'), dict(
            type=_bindlist,
            dest='ssh_proxy',
            metavar='IP:PORT',
            help='IP and port of SSH proxy to destination',
        )),
        (('-c', '--config'), dict(
            type=str,
            default=SSH_CONFIG_FILE,
            dest='ssh_config_file',
            help='SSH configuration file, defaults to {0}'.format(
                SSH_CONFIG_FILE
            ),
        )),
        (('-z', '--compress'), dict(
            action='store_true',
            dest='compression',
            help='Request server for compression over SSH transport',
        )),
        (('-n', '--noagent'), dict(
            action='store_false',
            dest='allow_agent',
            help='Disable looking for keys from an SSH agent',
        )),
        (('-d', '--host_pkey_directories'), dict(
            nargs='*',
            dest='host_pkey_directories',
            metavar='FOLDER',
            help='List of directories where SSH pkeys (in the format `id_*`) '
                 'may be found',
        )),
    )

    cli = argparse.ArgumentParser(
        description='Pure python ssh tunnel utils\nVersion {0}'.format(
            __version__
        ),
        formatter_class=argparse.RawTextHelpFormatter,
    )
    for flags, options in cli_options:
        cli.add_argument(*flags, **options)

    parsed = cli.parse_args(args)
    return vars(parsed)
1817
1818
def _cli_main(args=None):
    """ Pass input arguments to open_tunnel

        Mandatory: ssh_address, -R (remote bind address list)

        Optional:
        -U (username) we may gather it from SSH_CONFIG_FILE or current username
        -p (server_port), defaults to 22
        -P (password)
        -L (local_bind_address); when omitted, listens on 0.0.0.0 with a
           random port (as stated in the -L help text)
        -k (ssh_host_key)
        -K (private_key_file), may be gathered from SSH_CONFIG_FILE
        -S (private_key_password)
        -t (threaded), allow concurrent connections over tunnels
        -v (verbose), repeatable to raise the loglevel from ERROR:
           -v WARNING, -vv INFO, -vvv DEBUG, -vvvv (or more) TRACE
        -V (version)
        -x (proxy), ProxyCommand's IP:PORT, may be gathered from config file
        -c (config), ssh configuration file (defaults to SSH_CONFIG_FILE)
        -z (compress)
        -n (noagent), disable looking for keys from an Agent
        -d (host_pkey_directories), look for keys on these folders

        Args:
            args (list or None): CLI argument strings; ``None`` lets
                argparse use ``sys.argv[1:]``.

        While the tunnel is alive this blocks at a prompt until <Enter> is
        pressed. <Ctrl-C> at that prompt is handled the same way: the tunnel
        is closed by leaving the ``with`` block and the function returns
        normally instead of printing a ``KeyboardInterrupt`` traceback.
    """
    arguments = _parse_arguments(args)
    # Remove all "None" input values
    _remove_none_values(arguments)
    # Each -v raises the level one step; anything past -vvvv stays at TRACE
    verbosity = min(arguments.pop('verbose'), 4)
    levels = [logging.ERROR,
              logging.WARNING,
              logging.INFO,
              logging.DEBUG,
              TRACE_LEVEL]
    arguments.setdefault('debug_level', levels[verbosity])
    with open_tunnel(**arguments) as tunnel:
        if tunnel.is_alive:
            prompt = '''

            Press <Ctrl-C> or <Enter> to stop!

            '''
            try:
                input_(prompt)
            except KeyboardInterrupt:
                # Ctrl-C is an advertised way to stop: fall through so the
                # ``with`` block closes the tunnel without a traceback
                pass
1858
1859
# Script entry point: run the command-line interface with sys.argv
if __name__ == '__main__':  # pragma: no cover
    _cli_main(args=None)
1862