1# Copyright (c) 2018, Neil Booth
2#
3# All rights reserved.
4#
5# The MIT License (MIT)
6#
7# Permission is hereby granted, free of charge, to any person obtaining
8# a copy of this software and associated documentation files (the
9# "Software"), to deal in the Software without restriction, including
10# without limitation the rights to use, copy, modify, merge, publish,
11# distribute, sublicense, and/or sell copies of the Software, and to
12# permit persons to whom the Software is furnished to do so, subject to
13# the following conditions:
14#
15# The above copyright notice and this permission notice shall be
16# included in all copies or substantial portions of the Software.
17#
18# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
19# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
20# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
21# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
22# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
24# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
25
26'''SOCKS proxying.'''
27
28import asyncio
29import collections
30from ipaddress import IPv4Address, IPv6Address
31import secrets
32import socket
33import struct
34from functools import partial
35
36from aiorpcx.util import NetAddress
37
38
# Public API of this module; NeedData and SOCKSBase are internal helpers.
__all__ = (
    'SOCKSUserAuth', 'SOCKSRandomAuth',
    'SOCKS4', 'SOCKS4a', 'SOCKS5', 'SOCKSProxy',
    'SOCKSError', 'SOCKSProtocolError', 'SOCKSFailure',
)
41
42
SOCKSUserAuth = collections.namedtuple("SOCKSUserAuth", "username password")


# Random authentication is useful when used with Tor for stream isolation.
class SOCKSRandomAuth(SOCKSUserAuth):
    '''A SOCKSUserAuth whose credentials are fresh random tokens.

    Every read of .username or .password returns a new 64-character hex
    string from secrets.token_hex(32).  All other attribute access is left
    alone, so repr() and the other namedtuple helpers keep working.
    '''

    def __new__(cls, username=None, password=None):
        # The stored field values are never used, so let them default here.
        # Patching the inherited __new__.__defaults__ instead would mutate the
        # function shared with SOCKSUserAuth, making SOCKSUserAuth() silently
        # constructible without arguments.
        return super().__new__(cls, username, password)

    def __getattribute__(self, key):
        # Only the credential fields are randomized; overriding every
        # attribute (e.g. __class__) would break repr() and _replace().
        if key in ('username', 'password'):
            return secrets.token_hex(32)
        return super().__getattribute__(key)
53
54
class SOCKSError(Exception):
    '''Common root of the SOCKS exception hierarchy.

    Exceptions actually raised are always instances of a subclass, so
    catching SOCKSError covers every SOCKS-level failure.'''


class SOCKSProtocolError(SOCKSError):
    '''The proxy did not follow the SOCKS protocol, or a request could not
    be expressed in it.'''


class SOCKSFailure(SOCKSError):
    '''The proxy refused, or was unable to make, the requested connection.'''


class NeedData(Exception):
    '''Internal signal that the receive buffer is short of data.

    args[0] is the number of additional bytes required.'''
70
71
class SOCKSBase:
    '''Shared receive buffer and state machine for SOCKS client protocols.

    Instances carry per-handshake state, so each one serves a single
    connection only.  Subclasses supply _start(), the initial state.
    '''

    @classmethod
    def name(cls):
        '''Return the protocol's display name, which is its class name.'''
        return cls.__name__

    def __init__(self):
        self._buffer = b''
        # The current state is a callable returning the next message to send
        # (bytes), None when the handshake is done, or raising NeedData.
        self._state = self._start

    def _read(self, size):
        '''Consume and return exactly size bytes from the buffer.

        If fewer are buffered, raise NeedData with the shortfall and leave
        the buffer untouched.
        '''
        available = len(self._buffer)
        if available >= size:
            head, self._buffer = self._buffer[:size], self._buffer[size:]
            return head
        raise NeedData(size - available)

    def receive_data(self, data):
        '''Append bytes received from the proxy to the buffer.'''
        self._buffer = self._buffer + data

    def next_message(self):
        '''Advance the state machine by invoking the current state.'''
        return self._state()
95
96
class SOCKS4(SOCKSBase):
    '''SOCKS4 protocol wrapper.

    Only IPv4 destinations are accepted here; SOCKS4a relaxes the check to
    allow host names as well.
    '''

    # See http://ftp.icm.edu.pl/packages/socks/socks4/SOCKS4.protocol
    REPLY_CODES = {
        90: 'request granted',
        91: 'request rejected or failed',
        92: ('request rejected because SOCKS server cannot connect '
             'to identd on the client'),
        93: ('request rejected because the client program and identd '
             'report different user-ids'),
    }

    def __init__(self, remote_address, auth):
        super().__init__()
        self._remote_host, self._remote_port = remote_address.host, remote_address.port
        self._auth = auth
        self._check_remote_host()

    def _check_remote_host(self):
        '''Raise SOCKSProtocolError unless the destination is an IPv4 address.'''
        host = self._remote_host
        if isinstance(host, IPv4Address):
            return
        raise SOCKSProtocolError(f'SOCKS4 requires an IPv4 address: {host}')

    def _start(self):
        '''Build the CONNECT request; the reply goes to _first_response.'''
        self._state = self._first_response

        host = self._remote_host
        if isinstance(host, IPv4Address):
            # SOCKS4: real destination IP, nothing after the user id
            ip_field, trailer = host.packed, b''
        else:
            # SOCKS4a: placeholder IP 0.0.0.1 plus a NUL-terminated host name
            ip_field, trailer = b'\0\0\0\1', host.encode() + b'\0'

        user_id = (self._auth.username.encode()
                   if isinstance(self._auth, SOCKSUserAuth) else b'')

        # VN=4, CD=1 (CONNECT), DSTPORT, DSTIP, USERID, NUL, [host name]
        return (b'\4\1' + struct.pack('>H', self._remote_port) + ip_field
                + user_id + b'\0' + trailer)

    def _first_response(self):
        '''Parse the 8-byte reply: return None if granted, raise otherwise.'''
        reply = self._read(8)
        version, code = reply[0], reply[1]
        if version != 0:
            raise SOCKSProtocolError(f'invalid {self.name()} proxy '
                                     f'response: {reply}')
        if code == 90:
            # The remaining reply fields are not needed
            return None
        msg = self.REPLY_CODES.get(
            code, f'unknown {self.name()} reply code {code}')
        raise SOCKSFailure(f'{self.name()} proxy request failed: {msg}')
156
157
class SOCKS4a(SOCKS4):
    '''SOCKS4a protocol wrapper: SOCKS4 extended so that a host name can be
    sent in place of an IPv4 address.'''

    def _check_remote_host(self):
        '''Raise SOCKSProtocolError unless the destination is an IPv4
        address or a host name string.'''
        host = self._remote_host
        if isinstance(host, (str, IPv4Address)):
            return
        raise SOCKSProtocolError(
            f'SOCKS4a requires an IPv4 address or host name: {host}')
164
165
class SOCKS5(SOCKSBase):
    '''SOCKS5 protocol wrapper supporting "no authentication" and
    username/password authentication.'''

    # See https://tools.ietf.org/html/rfc1928
    ERROR_CODES = {
        1: 'general SOCKS server failure',
        2: 'connection not allowed by ruleset',
        3: 'network unreachable',
        4: 'host unreachable',
        5: 'connection refused',
        6: 'TTL expired',
        7: 'command not supported',
        8: 'address type not supported',
    }

    def __init__(self, remote_address, auth):
        super().__init__()
        self._dst_bytes = SOCKS5._destination_bytes(remote_address.host, remote_address.port)
        self._auth_bytes, self._auth_methods = SOCKS5._authentication(auth)

    @staticmethod
    def _destination_bytes(host, port):
        '''Return the address-type byte, address and port of a request.

        Raises SOCKSProtocolError if host is neither an IP address nor a
        host name that encodes to between 1 and 255 bytes.
        '''
        if isinstance(host, IPv4Address):
            addr_bytes = b'\1' + host.packed
        elif isinstance(host, IPv6Address):
            addr_bytes = b'\4' + host.packed
        elif isinstance(host, str):
            # Host names carry a one-byte length prefix, so 1..255 bytes.
            # Validate explicitly: asserts vanish under `python -O`.
            host_bytes = host.encode()
            if not 0 < len(host_bytes) < 256:
                raise SOCKSProtocolError(f'host name {host} has invalid '
                                         f'length {len(host_bytes)}')
            addr_bytes = b'\3' + bytes([len(host_bytes)]) + host_bytes
        else:
            raise SOCKSProtocolError(
                f'SOCKS5 requires an IP address or host name: {host}')
        return addr_bytes + struct.pack('>H', port)

    @staticmethod
    def _authentication(auth):
        '''Return (auth_request_bytes, offered_method_codes) for auth.

        With a SOCKSUserAuth, methods 0 (none) and 2 (username/password)
        are offered along with the username/password request; otherwise
        only method 0 with an empty request.  Raises SOCKSProtocolError if
        the username or password does not encode to 1-255 bytes.
        '''
        if isinstance(auth, SOCKSUserAuth):
            user_bytes = auth.username.encode()
            if not 0 < len(user_bytes) < 256:
                raise SOCKSProtocolError(f'username {auth.username} has '
                                         f'invalid length {len(user_bytes)}')
            pwd_bytes = auth.password.encode()
            if not 0 < len(pwd_bytes) < 256:
                raise SOCKSProtocolError(f'password has invalid length '
                                         f'{len(pwd_bytes)}')
            return b''.join([bytes([1, len(user_bytes)]), user_bytes,
                             bytes([len(pwd_bytes)]), pwd_bytes]), [0, 2]
        return b'', [0]

    def _start(self):
        '''Send the greeting: version 5 and the offered auth methods.'''
        self._state = self._first_response
        return (b'\5' + bytes([len(self._auth_methods)])
                + bytes(self._auth_methods))

    def _first_response(self):
        '''Handle the proxy's 2-byte choice of authentication method.'''
        data = self._read(2)
        if data[0] != 5:
            raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}')
        if data[1] not in self._auth_methods:
            raise SOCKSFailure('SOCKS5 proxy rejected authentication methods')

        # Authenticate if user-password authentication
        if data[1] == 2:
            self._state = self._auth_response
            return self._auth_bytes
        return self._request_connection()

    def _auth_response(self):
        '''Handle the 2-byte username/password authentication reply.'''
        data = self._read(2)
        if data[0] != 1:
            raise SOCKSProtocolError(f'invalid SOCKS5 proxy auth '
                                     f'response: {data}')
        if data[1] != 0:
            raise SOCKSFailure(f'SOCKS5 proxy auth failure code: '
                               f'{data[1]}')

        return self._request_connection()

    def _request_connection(self):
        '''Send the CONNECT request (version 5, command 1, reserved 0).'''
        self._state = self._connect_response
        return b'\5\1\0' + self._dst_bytes

    def _connect_response(self):
        '''Validate the reply header, then consume the bound address.'''
        # VER, REP, RSV, ATYP and the first byte of the bound address; for
        # host names that byte is the length of what follows.
        data = self._read(5)
        if data[0] != 5 or data[2] != 0 or data[3] not in (1, 3, 4):
            raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}')
        if data[1] != 0:
            raise SOCKSFailure(self.ERROR_CODES.get(
                data[1], f'unknown SOCKS5 error code: {data[1]}'))

        if data[3] == 1:
            addr_len = 3   # IPv4
        elif data[3] == 3:
            addr_len = data[4]  # Hostname
        else:
            addr_len = 15  # IPv6

        self._state = partial(self._connect_response_rest, addr_len)
        return self.next_message()

    def _connect_response_rest(self, addr_len):
        '''Discard the rest of the bound address and the 2-byte port.'''
        self._read(addr_len + 2)
        return None
270
271
class SOCKSProxy:
    '''A SOCKS proxy at a network address, used through one SOCKS protocol
    class (such as SOCKS4, SOCKS4a or SOCKS5) and an optional
    authentication method.'''

    def __init__(self, address, protocol, auth):
        '''A SOCKS proxy at a NetAddress following a SOCKS protocol.

        address is a NetAddress, or a string parsed by NetAddress.from_string.
        protocol is the protocol class instantiated for each connection.
        auth is an authentication method to use when connecting, or None.
        '''
        if not isinstance(address, NetAddress):
            address = NetAddress.from_string(address)
        self.address = address
        self.protocol = protocol
        self.auth = auth
        # Set on each successful connection via the proxy to the
        # result of socket.getpeername()
        self.peername = None

    def __str__(self):
        auth = 'username' if self.auth else 'none'
        return f'{self.protocol.name()} proxy at {self.address}, auth: {auth}'

    async def _handshake(self, client, sock, loop):
        '''Run client's handshake state machine over the connected socket.

        Each message the client produces is sent.  When the client raises
        NeedData, up to that many bytes are received and fed back to it.
        Returns once the client yields None.  Raises SOCKSProtocolError if
        the proxy closes the connection early.  Errors raised by the client
        or by the socket operations propagate.
        '''
        while True:
            count = 0
            try:
                message = client.next_message()
            except NeedData as e:
                count = e.args[0]
            else:
                if message is None:
                    return
                await loop.sock_sendall(sock, message)

            if count:
                # sock_recv may return fewer bytes than requested; the
                # client then raises NeedData again for the remainder.
                data = await loop.sock_recv(sock, count)
                if not data:
                    raise SOCKSProtocolError("EOF received")
                client.receive_data(data)

    async def _connect_one(self, remote_address):
        '''Connect to the proxy and perform a handshake requesting a connection.

        Each address the proxy resolves to is tried in turn.  Return the open
        socket on success, or the exception from the last attempt on failure.
        Errors resolving the proxy's own address are raised, not returned.
        '''
        loop = asyncio.get_event_loop()
        # Returned if getaddrinfo yields no addresses; without it the final
        # `return exception` would raise UnboundLocalError.
        exception = OSError(f'no addresses found for proxy {self.address}')

        for info in await loop.getaddrinfo(str(self.address.host), self.address.port,
                                           type=socket.SOCK_STREAM):
            # This object has state so is only good for one connection
            client = self.protocol(remote_address, self.auth)
            sock = socket.socket(family=info[0])
            try:
                # A non-blocking socket is required by loop socket methods
                sock.setblocking(False)
                await loop.sock_connect(sock, info[4])
                await self._handshake(client, sock, loop)
                self.peername = sock.getpeername()
                return sock
            except (OSError, SOCKSError) as e:
                exception = e
                # Don't close the socket because of an asyncio bug
                # see https://github.com/kyuupichan/aiorpcX/issues/8
        return exception

    async def _connect(self, remote_addresses):
        '''Connect to the proxy and perform a handshake requesting a connection
        to each address in remote_addresses, in order, until one succeeds.

        Return an (open_socket, remote_address) pair on success.  If every
        attempt fails, raise the failure itself when all attempts failed the
        same way (by repr), otherwise an OSError listing each distinct one.
        '''
        assert remote_addresses

        exceptions = []
        for remote_address in remote_addresses:
            result = await self._connect_one(remote_address)
            if isinstance(result, socket.socket):
                return result, remote_address
            exceptions.append(result)

        # De-duplicate by repr but keep first-seen order so the message is
        # deterministic (iterating a set would order it arbitrarily).
        strings = list(dict.fromkeys(repr(exc) for exc in exceptions))
        if len(strings) == 1:
            raise exceptions[0]
        raise OSError(f'multiple exceptions: {", ".join(strings)}')

    async def _detect_proxy(self):
        '''Return True if it appears we can connect to a SOCKS proxy,
        otherwise False.
        '''
        # SOCKS4a is probed with a host name, the others with an IPv4 address
        if self.protocol is SOCKS4a:
            remote_address = NetAddress('www.apple.com', 80)
        else:
            remote_address = NetAddress('8.8.8.8', 53)

        sock = await self._connect_one(remote_address)
        if isinstance(sock, socket.socket):
            sock.close()
            return True

        # SOCKSFailure indicates something failed, but that we are likely talking to a
        # proxy
        return isinstance(sock, SOCKSFailure)

    @classmethod
    async def auto_detect_at_address(cls, address, auth):
        '''Try to detect a SOCKS proxy at address using the authentication method (or None).
        SOCKS5, SOCKS4a and SOCKS4 are tried in order.  If a SOCKS proxy is detected a
        SOCKSProxy object is returned.

        Returning a SOCKSProxy does not mean it is functioning - for example, it may have
        no network connectivity.

        If no proxy is detected return None.
        '''
        for protocol in (SOCKS5, SOCKS4a, SOCKS4):
            proxy = cls(address, protocol, auth)
            if await proxy._detect_proxy():
                return proxy
        return None

    @classmethod
    async def auto_detect_at_host(cls, host, ports, auth):
        '''Try to detect a SOCKS proxy on a host on one of the ports.

        Calls auto_detect_at_address for the ports in order.  Returning a SOCKSProxy
        does not mean it is functioning - for example, it may have no network
        connectivity.

        If no proxy is detected return None.
        '''
        for port in ports:
            proxy = await cls.auto_detect_at_address(NetAddress(host, port), auth)
            if proxy:
                return proxy

        return None

    async def create_connection(self, protocol_factory, host, port, *,
                                resolve=False, ssl=None,
                                family=0, proto=0, flags=0):
        '''Set up a connection to (host, port) through the proxy.

        If resolve is True then host is resolved locally with
        getaddrinfo using family, proto and flags, otherwise the proxy
        is asked to resolve host.

        The function signature is similar to loop.create_connection()
        with the same result.  The attributes _proxy and _remote_address
        are set on the protocol to this proxy and to the address of the
        successful remote connection respectively.  Additionally raises
        SOCKSError if something goes wrong with the proxy handshake.
        '''
        loop = asyncio.get_event_loop()
        if resolve:
            remote_addresses = [NetAddress(info[4][0], info[4][1]) for info in
                                await loop.getaddrinfo(host, port, family=family, proto=proto,
                                                       type=socket.SOCK_STREAM, flags=flags)]
        else:
            remote_addresses = [NetAddress(host, port)]

        sock, remote_address = await self._connect(remote_addresses)

        def set_address():
            protocol = protocol_factory()
            protocol._proxy = self
            protocol._remote_address = remote_address
            return protocol

        return await loop.create_connection(set_address, sock=sock, ssl=ssl,
                                            server_hostname=host if ssl else None)
438