1import paramiko
2import queue
3import urllib.parse
4import requests.adapters
5import logging
6import os
7import signal
8import socket
9import subprocess
10
11from docker.transport.basehttpadapter import BaseHTTPAdapter
12from .. import constants
13
14import http.client as httplib
15
16try:
17    import requests.packages.urllib3 as urllib3
18except ImportError:
19    import urllib3
20
# Alias for urllib3's LRU dict-like container (it lives in the private
# `urllib3._collections` module). SSHHTTPAdapter uses it to hold its cached
# connection pools, capped at `pool_connections` entries.
RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer
22
23
24class SSHSocket(socket.socket):
25    def __init__(self, host):
26        super().__init__(
27            socket.AF_INET, socket.SOCK_STREAM)
28        self.host = host
29        self.port = None
30        self.user = None
31        if ':' in self.host:
32            self.host, self.port = self.host.split(':')
33        if '@' in self.host:
34            self.user, self.host = self.host.split('@')
35
36        self.proc = None
37
38    def connect(self, **kwargs):
39        args = ['ssh']
40        if self.user:
41            args = args + ['-l', self.user]
42
43        if self.port:
44            args = args + ['-p', self.port]
45
46        args = args + ['--', self.host, 'docker system dial-stdio']
47
48        preexec_func = None
49        if not constants.IS_WINDOWS_PLATFORM:
50            def f():
51                signal.signal(signal.SIGINT, signal.SIG_IGN)
52            preexec_func = f
53
54        env = dict(os.environ)
55
56        # drop LD_LIBRARY_PATH and SSL_CERT_FILE
57        env.pop('LD_LIBRARY_PATH', None)
58        env.pop('SSL_CERT_FILE', None)
59
60        self.proc = subprocess.Popen(
61            ' '.join(args),
62            env=env,
63            shell=True,
64            stdout=subprocess.PIPE,
65            stdin=subprocess.PIPE,
66            preexec_fn=None if constants.IS_WINDOWS_PLATFORM else preexec_func)
67
68    def _write(self, data):
69        if not self.proc or self.proc.stdin.closed:
70            raise Exception('SSH subprocess not initiated.'
71                            'connect() must be called first.')
72        written = self.proc.stdin.write(data)
73        self.proc.stdin.flush()
74        return written
75
76    def sendall(self, data):
77        self._write(data)
78
79    def send(self, data):
80        return self._write(data)
81
82    def recv(self, n):
83        if not self.proc:
84            raise Exception('SSH subprocess not initiated.'
85                            'connect() must be called first.')
86        return self.proc.stdout.read(n)
87
88    def makefile(self, mode):
89        if not self.proc:
90            self.connect()
91        self.proc.stdout.channel = self
92
93        return self.proc.stdout
94
95    def close(self):
96        if not self.proc or self.proc.stdin.closed:
97            return
98        self.proc.stdin.write(b'\n\n')
99        self.proc.stdin.flush()
100        self.proc.terminate()
101
102
class SSHConnection(httplib.HTTPConnection):
    """HTTP connection whose socket is a tunnel to ``docker system dial-stdio``.

    If a paramiko transport is supplied, each ``connect()`` opens a fresh
    session channel on it. Otherwise an ``SSHSocket`` spawns the local
    ``ssh`` binary for ``host``.
    """

    def __init__(self, ssh_transport=None, timeout=60, host=None):
        # 'localhost' is only a placeholder: connect() is overridden below,
        # so the HTTP host is never used to open a network connection.
        super().__init__('localhost', timeout=timeout)
        self.ssh_host = host
        self.timeout = timeout
        self.ssh_transport = ssh_transport

    def connect(self):
        """Open the tunnel and install it as ``self.sock``."""
        if not self.ssh_transport:
            channel = SSHSocket(self.ssh_host)
            channel.settimeout(self.timeout)
            channel.connect()
        else:
            channel = self.ssh_transport.open_session()
            channel.settimeout(self.timeout)
            channel.exec_command('docker system dial-stdio')

        self.sock = channel
123
124
class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool):
    """urllib3 connection pool that hands out ``SSHConnection`` objects.

    When ``ssh_client`` is given, all connections share its paramiko
    transport. Otherwise each connection shells out to ``ssh`` for ``host``.
    """

    scheme = 'ssh'

    def __init__(self, ssh_client=None, timeout=60, maxsize=10, host=None):
        super().__init__('localhost', timeout=timeout, maxsize=maxsize)
        self.timeout = timeout
        self.ssh_transport = ssh_client.get_transport() if ssh_client else None
        self.ssh_host = host

    def _new_conn(self):
        """Create a fresh, not-yet-connected ``SSHConnection``."""
        return SSHConnection(
            ssh_transport=self.ssh_transport,
            timeout=self.timeout,
            host=self.ssh_host,
        )

    def _get_conn(self, timeout):
        """Take a pooled connection, or create a new one.

        Overridden because urllib3's version calls ``fileno()`` on reused
        connections to check liveness, which quickly exhausts the fd limit
        with SSH channels. This version skips that check.

        :raises ClosedPoolError: the pool has been closed (``self.pool`` is
            ``None``).
        :raises EmptyPoolError: the pool is blocking and no connection
            became available within ``timeout``.
        """
        try:
            pooled = self.pool.get(block=self.block, timeout=timeout)
        except AttributeError:  # self.pool is None once the pool is closed
            raise urllib3.exceptions.ClosedPoolError(self, "Pool is closed.")
        except queue.Empty:
            if self.block:
                raise urllib3.exceptions.EmptyPoolError(
                    self,
                    "Pool reached maximum size and no more "
                    "connections are allowed."
                )
            pooled = None  # non-blocking pool: fall through to a new one

        if pooled:
            return pooled
        return self._new_conn()
162
163
class SSHHTTPAdapter(BaseHTTPAdapter):
    """requests transport adapter for ``ssh://`` Docker hosts.

    Two modes:

    * default (``shell_out=False``): a paramiko ``SSHClient`` is created and
      connected in ``__init__``. Connection pools share its transport and
      are cached per URL in ``self.pools``.
    * ``shell_out=True``: no paramiko client is created. Connections spawn
      the local ``ssh`` binary via ``SSHSocket``.
    """

    __attrs__ = requests.adapters.HTTPAdapter.__attrs__ + [
        'pools', 'timeout', 'ssh_client', 'ssh_params', 'max_pool_size'
    ]

    def __init__(self, base_url, timeout=60,
                 pool_connections=constants.DEFAULT_NUM_POOLS,
                 max_pool_size=constants.DEFAULT_MAX_POOL_SIZE,
                 shell_out=False):
        """
        :param base_url: Docker host URL, e.g. ``ssh://user@host:22``.
        :param timeout: timeout handed to every connection pool.
        :param pool_connections: capacity of the ``self.pools`` cache.
        :param max_pool_size: ``maxsize`` of each connection pool.
        :param shell_out: use the system ``ssh`` binary instead of paramiko.
        """
        self.ssh_client = None
        if not shell_out:
            self._create_paramiko_client(base_url)
            self._connect()

        # SSHSocket expects "[user@]host[:port]", so drop the scheme.
        self.ssh_host = base_url
        if base_url.startswith('ssh://'):
            self.ssh_host = base_url[len('ssh://'):]

        self.timeout = timeout
        self.max_pool_size = max_pool_size
        # Pools evicted from the cache are closed by dispose_func.
        self.pools = RecentlyUsedContainer(
            pool_connections, dispose_func=lambda p: p.close()
        )
        super().__init__()

    def _create_paramiko_client(self, base_url):
        """Create ``self.ssh_client`` and the ``self.ssh_params`` for connect().

        Hostname, port and username come from ``base_url``. They are then
        overlaid with the matching ``~/.ssh/config`` entry, if one exists:

        * ProxyCommand becomes the ``sock``.
        * HostName replaces the hostname.
        * Port and User apply only when the URL omits them.
        * IdentityFile becomes ``key_filename``.
        """
        logging.getLogger("paramiko").setLevel(logging.WARNING)
        self.ssh_client = paramiko.SSHClient()
        base_url = urllib.parse.urlparse(base_url)
        self.ssh_params = {
            "hostname": base_url.hostname,
            "port": base_url.port,
            "username": base_url.username
            }
        ssh_config_file = os.path.expanduser("~/.ssh/config")
        if os.path.exists(ssh_config_file):
            conf = paramiko.SSHConfig()
            with open(ssh_config_file) as f:
                conf.parse(f)
            host_config = conf.lookup(base_url.hostname)
            if 'proxycommand' in host_config:
                # The command comes from the looked-up host entry.
                self.ssh_params["sock"] = paramiko.ProxyCommand(
                    host_config['proxycommand']
                )
            if 'hostname' in host_config:
                self.ssh_params['hostname'] = host_config['hostname']
            if base_url.port is None and 'port' in host_config:
                # ssh_config values are strings; paramiko's connect() takes
                # an int port.
                self.ssh_params['port'] = int(host_config['port'])
            if base_url.username is None and 'user' in host_config:
                self.ssh_params['username'] = host_config['user']
            if 'identityfile' in host_config:
                self.ssh_params['key_filename'] = host_config['identityfile']

        self.ssh_client.load_system_host_keys()
        # NOTE(review): WarningPolicy logs a warning but ACCEPTS unknown host
        # keys. Consider RejectPolicy; kept as-is for compatibility.
        self.ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy())

    def _connect(self):
        """Connect the paramiko client (no-op in shell-out mode)."""
        if self.ssh_client:
            self.ssh_client.connect(**self.ssh_params)

    def get_connection(self, url, proxies=None):
        """Return the connection pool for ``url``.

        In shell-out mode a new, uncached pool is returned on every call.
        Otherwise pools are cached per URL, and the paramiko client is
        reconnected first if its transport is gone. ``proxies`` is ignored.
        """
        if not self.ssh_client:
            # NOTE(review): not cached, presumably because SSHSocket.makefile
            # hands out the subprocess stdout, which the HTTP response closes,
            # so the tunnel cannot be reused. Confirm before caching.
            return SSHConnectionPool(
                ssh_client=self.ssh_client,
                timeout=self.timeout,
                maxsize=self.max_pool_size,
                host=self.ssh_host
            )
        with self.pools.lock:
            pool = self.pools.get(url)
            if pool:
                return pool

            # Connection is closed; try to reconnect.
            if self.ssh_client and not self.ssh_client.get_transport():
                self._connect()

            pool = SSHConnectionPool(
                ssh_client=self.ssh_client,
                timeout=self.timeout,
                maxsize=self.max_pool_size,
                host=self.ssh_host
            )
            self.pools[url] = pool

        return pool

    def close(self):
        """Close the adapter's pools (via the base class), then the SSH client."""
        super().close()
        if self.ssh_client:
            self.ssh_client.close()
256