1import asyncio
2import ssl
3
4import async_timeout
5
6from ...._errors import ProxyConnectionError, ProxyTimeoutError
7from ...._proto.http_async import HttpProto
8from ...._proto.socks4_async import Socks4Proto
9from ...._proto.socks5_async import Socks5Proto
10
11from .._resolver import Resolver
12from ._stream import AsyncioSocketStream
13from ._connect import connect_tcp
14
# Overall timeout (in seconds, as passed to async_timeout.timeout) applied by
# AsyncioProxy.connect() when the caller does not supply one.
DEFAULT_TIMEOUT = 60
16
17
class AsyncioProxy:
    """Base class for asyncio proxy clients.

    Opens (or adopts) a stream to the proxy, optionally wraps it in TLS,
    delegates the protocol-specific handshake to :meth:`_negotiate`, and
    optionally starts TLS to the final destination over the tunnel.
    Subclasses (e.g. the SOCKS4/SOCKS5/HTTP proxies in this module) must
    implement :meth:`_negotiate`.
    """

    def __init__(
        self,
        proxy_host: str,
        proxy_port: int,
        proxy_ssl: ssl.SSLContext = None,
        loop: asyncio.AbstractEventLoop = None,
    ):
        """Store proxy settings.

        :param proxy_host: proxy host name or address.
        :param proxy_port: proxy port.
        :param proxy_ssl: if given, TLS is started to the proxy itself.
        :param loop: event loop to use; defaults to asyncio.get_event_loop().
        """
        if loop is None:
            loop = asyncio.get_event_loop()

        self._loop = loop

        self._proxy_host = proxy_host
        self._proxy_port = proxy_port
        self._proxy_ssl = proxy_ssl

        # Per-connection state, (re)populated by connect().
        self._dest_host = None
        self._dest_port = None
        self._dest_ssl = None
        self._timeout = None

        self._stream = None
        self._resolver = Resolver(loop=loop)

    async def connect(
        self,
        dest_host: str,
        dest_port: int,
        dest_ssl: ssl.SSLContext = None,
        timeout: float = None,
        _stream: AsyncioSocketStream = None,
    ) -> AsyncioSocketStream:
        """Establish a tunnel through the proxy to ``dest_host:dest_port``.

        :param dest_host: destination host.
        :param dest_port: destination port.
        :param dest_ssl: if given, TLS is started to the destination over
            the tunnel after negotiation.
        :param timeout: overall timeout in seconds for the whole procedure;
            ``None`` means DEFAULT_TIMEOUT.
        :param _stream: an already-open stream to the proxy to use instead of
            opening a new TCP connection. It is closed if connecting fails.
        :returns: the stream connected (through the proxy) to the destination.
        :raises ProxyTimeoutError: if the procedure exceeds ``timeout``.
        :raises ProxyConnectionError: if connecting to the proxy raises OSError.
        """
        if timeout is None:
            timeout = DEFAULT_TIMEOUT

        self._dest_host = dest_host
        self._dest_port = dest_port
        self._dest_ssl = dest_ssl
        self._timeout = timeout

        try:
            return await self._connect(_stream)
        except asyncio.TimeoutError as e:
            raise ProxyTimeoutError('Proxy connection timed out: {}'.format(self._timeout)) from e

    async def _connect(self, _stream: AsyncioSocketStream) -> AsyncioSocketStream:
        """Connect to the proxy, negotiate, and optionally start destination TLS."""
        # Forget any stream from a previous connect() on this object. Without
        # this, a failure before a new stream is assigned below would make
        # _close() shut down a stream already handed to an earlier caller.
        self._stream = None

        async with async_timeout.timeout(self._timeout):
            try:
                if _stream is None:
                    reader, writer = await connect_tcp(
                        host=self._proxy_host,
                        port=self._proxy_port,
                    )
                    self._stream = AsyncioSocketStream(
                        loop=self._loop,
                        reader=reader,
                        writer=writer,
                    )
                else:
                    self._stream = _stream

                if self._proxy_ssl is not None:  # pragma: no cover
                    self._stream = await self._stream.start_tls(
                        hostname=self._proxy_host,
                        ssl_context=self._proxy_ssl,
                    )
            except OSError as e:
                await self._close()
                # OSError.strerror is None when the error was built from a
                # plain message; fall back to str(e) so the reason is useful.
                reason = e.strerror or str(e)
                msg = 'Could not connect to proxy {}:{} [{}]'.format(
                    self._proxy_host,
                    self._proxy_port,
                    reason,
                )
                raise ProxyConnectionError(e.errno, msg) from e
            # CancelledError is listed explicitly because on Python 3.8+ it
            # no longer derives from Exception; the stream must still be closed.
            except (asyncio.CancelledError, Exception):
                await self._close()
                raise

            try:
                await self._negotiate()

                if self._dest_ssl is not None:
                    self._stream = await self._stream.start_tls(
                        hostname=self._dest_host,
                        ssl_context=self._dest_ssl,
                    )
            except (asyncio.CancelledError, Exception):
                await self._close()
                raise

            return self._stream

    async def _negotiate(self):
        """Perform the proxy-protocol handshake over ``self._stream``."""
        raise NotImplementedError()

    async def _close(self):
        """Close the current stream, if any, and drop the reference to it."""
        stream, self._stream = self._stream, None
        if stream is not None:
            await stream.close()

    @property
    def proxy_host(self):
        """Proxy host as given to the constructor."""
        return self._proxy_host

    @property
    def proxy_port(self):
        """Proxy port as given to the constructor."""
        return self._proxy_port
127
128
class Socks5Proxy(AsyncioProxy):
    """SOCKS5 proxy client.

    The credentials and ``rdns`` flag are kept as given and handed to
    ``Socks5Proto`` when the handshake runs (``rdns`` presumably selects
    proxy-side name resolution — see Socks5Proto to confirm).
    """

    def __init__(
        self,
        proxy_host,
        proxy_port,
        username=None,
        password=None,
        rdns=None,
        proxy_ssl=None,
        loop: asyncio.AbstractEventLoop = None,
    ):
        super().__init__(proxy_host, proxy_port, proxy_ssl=proxy_ssl, loop=loop)
        self._username, self._password, self._rdns = username, password, rdns

    async def _negotiate(self):
        """Run the SOCKS5 handshake over the already-open proxy stream."""
        handshake_args = dict(
            stream=self._stream,
            resolver=self._resolver,
            dest_host=self._dest_host,
            dest_port=self._dest_port,
            username=self._username,
            password=self._password,
            rdns=self._rdns,
        )
        await Socks5Proto(**handshake_args).negotiate()
161
162
class Socks4Proxy(AsyncioProxy):
    """SOCKS4 proxy client.

    ``user_id`` and ``rdns`` are kept as given and handed to ``Socks4Proto``
    when the handshake runs.
    """

    def __init__(
        self,
        proxy_host,
        proxy_port,
        user_id=None,
        rdns=None,
        proxy_ssl=None,
        loop: asyncio.AbstractEventLoop = None,
    ):
        super().__init__(proxy_host, proxy_port, proxy_ssl=proxy_ssl, loop=loop)
        self._user_id, self._rdns = user_id, rdns

    async def _negotiate(self):
        """Run the SOCKS4 handshake over the already-open proxy stream."""
        handshake_args = dict(
            stream=self._stream,
            resolver=self._resolver,
            dest_host=self._dest_host,
            dest_port=self._dest_port,
            user_id=self._user_id,
            rdns=self._rdns,
        )
        await Socks4Proto(**handshake_args).negotiate()
192
193
class HttpProxy(AsyncioProxy):
    """HTTP proxy client.

    Optional ``username``/``password`` are kept as given and handed to
    ``HttpProto`` when the handshake runs.
    """

    def __init__(
        self,
        proxy_host,
        proxy_port,
        username=None,
        password=None,
        proxy_ssl=None,
        loop: asyncio.AbstractEventLoop = None,
    ):
        super().__init__(proxy_host, proxy_port, proxy_ssl=proxy_ssl, loop=loop)
        self._username, self._password = username, password

    async def _negotiate(self):
        """Run the HTTP proxy handshake over the already-open proxy stream."""
        handshake_args = dict(
            stream=self._stream,  # noqa
            dest_host=self._dest_host,
            dest_port=self._dest_port,
            username=self._username,
            password=self._password,
        )
        await HttpProto(**handshake_args).negotiate()
222