1import requests
2import six
3import socket
4import ssl
5import sys
6import threading
7import time
8import websocket
9from six import string_types
10
11from .exceptions import ConnectionError, TimeoutError
12from .parsers import (
13    encode_engineIO_content, decode_engineIO_content,
14    format_packet_text, parse_packet_text)
15from .symmetries import format_query, memoryview, parse_url
16
17
# A different package that also installs a module named ``websocket`` can
# shadow websocket-client; without ``create_connection`` this module cannot
# work, so stop at import time and tell the user how to fix the install.
if not hasattr(websocket, 'create_connection'):
    sys.exit(
        'An incompatible websocket library is conflicting with the one we '
        'need.\n'
        'You can remove the incompatible library and install the correct '
        'one\n'
        'by running the following commands:\n'
        '\n'
        'yes | pip uninstall websocket websocket-client\n'
        'pip install -U websocket-client')


# Engine.IO protocol revision, sent as the ``EIO`` query parameter.
ENGINEIO_PROTOCOL = 3
# Names of the transports this module provides.
TRANSPORTS = ('xhr-polling', 'websocket')
30
31
class AbstractTransport(object):
    """Shared state and default (no-op) interface for Engine.IO transports.

    Every method here does nothing and returns None.  Transports that do not
    override a method (e.g. the polling transport's ``set_timeout``) inherit
    that no-op behaviour.
    """

    def __init__(self, http_session, is_secure, url, engineIO_session=None):
        # Remember the constructor arguments so subclasses can use them.
        self.http_session, self.is_secure = http_session, is_secure
        self.url, self.engineIO_session = url, engineIO_session

    def recv_packet(self):
        """Receive packets; subclasses yield (type, data) pairs, this returns None."""

    def send_packet(self, engineIO_packet_type, engineIO_packet_data=''):
        """Send one packet; the base implementation does nothing."""

    def set_timeout(self, seconds=None):
        """Adjust the transport timeout; the base implementation does nothing."""
48
49
class XHR_PollingTransport(AbstractTransport):
    """Engine.IO transport over HTTP polling: GET to receive, POST to send.

    ``set_timeout`` is inherited from the base class and does nothing here;
    request timeouts come from the session's ping timeout instead.
    """

    def __init__(self, http_session, is_secure, url, engineIO_session=None):
        super(XHR_PollingTransport, self).__init__(
            http_session, is_secure, url, engineIO_session)
        self._params = {'EIO': ENGINEIO_PROTOCOL, 'transport': 'polling'}
        # Defaults for the handshake request, made before a session exists.
        self._request_index = 0
        self._kw_get = {}
        self._kw_post = {}
        if engineIO_session:
            # With an established session, tag requests with its id and
            # bound each one by the session's ping timeout.
            ping_timeout = engineIO_session.ping_timeout
            self._request_index = 1
            self._kw_get = {'timeout': ping_timeout}
            self._kw_post = {
                'timeout': ping_timeout,
                'headers': {'content-type': 'application/octet-stream'},
            }
            self._params['sid'] = engineIO_session.id
        self._http_url = '%s://%s/' % ('https' if is_secure else 'http', url)
        # One lock guards the request counter, another serialises POSTs.
        self._request_index_lock = threading.Lock()
        self._send_packet_lock = threading.Lock()

    def recv_packet(self):
        """Issue one polling GET and yield each decoded (type, data) pair."""
        query = dict(self._params, t=self._get_timestamp())
        response = get_response(
            self.http_session.get, self._http_url,
            params=query, **self._kw_get)
        decoded = decode_engineIO_content(response.content)
        for packet_type, packet_data in decoded:
            yield packet_type, packet_data

    def send_packet(self, engineIO_packet_type, engineIO_packet_data=''):
        """POST a single encoded packet; concurrent sends are serialised."""
        with self._send_packet_lock:
            query = dict(self._params, t=self._get_timestamp())
            body = encode_engineIO_content(
                [(engineIO_packet_type, engineIO_packet_data)])
            get_response(
                self.http_session.post, self._http_url,
                params=query, data=memoryview(body), **self._kw_post)

    def _get_timestamp(self):
        """Return '<epoch milliseconds>-<request index>' and bump the index."""
        with self._request_index_lock:
            millis = int(time.time() * 1000)
            index = self._request_index
            self._request_index += 1
        return '%s-%s' % (millis, index)
106
107
class WebsocketTransport(AbstractTransport):
    """Engine.IO transport over one persistent WebSocket connection.

    The handshake reuses the ``requests`` session's headers (auth, cookies,
    custom headers), proxies and TLS settings so both transports behave alike.

    :raises ConnectionError: if the WebSocket connection cannot be opened.
    """

    def __init__(self, http_session, is_secure, url, engineIO_session=None):
        super(WebsocketTransport, self).__init__(
            http_session, is_secure, url, engineIO_session)
        # Default used by set_timeout(); stays None (blocking) when no
        # session exists, instead of leaving the attribute undefined.
        self._timeout = None
        http_scheme = 'https' if is_secure else 'http'
        params = dict(http_session.params, **{
            'EIO': ENGINEIO_PROTOCOL, 'transport': 'websocket'})
        # Let the session build the handshake headers.  Use a full http(s)
        # URL (as the polling transport does) so cookie domain matching works
        # and requests does not reject a scheme-less URL.
        request = http_session.prepare_request(requests.Request(
            'GET', '%s://%s/' % (http_scheme, url)))
        kw = {'header': ['%s: %s' % x for x in request.headers.items()]}
        if engineIO_session:
            params['sid'] = engineIO_session.id
            kw['timeout'] = self._timeout = engineIO_session.ping_timeout
        ws_url = '%s://%s/?%s' % (
            'wss' if is_secure else 'ws', url, format_query(params))
        kw.update(self._get_proxy_options(http_session, http_scheme))
        sslopt = self._get_sslopt(http_session)
        if sslopt:
            kw['sslopt'] = sslopt
        try:
            self._connection = websocket.create_connection(ws_url, **kw)
        except Exception as e:
            raise ConnectionError(e)

    @staticmethod
    def _get_proxy_options(http_session, http_scheme):
        """Return websocket-client proxy kwargs for the session's proxy, if any."""
        if http_scheme not in http_session.proxies:  # Use the correct proxy
            return {}
        proxy_url_pack = parse_url(http_session.proxies[http_scheme])
        options = {
            'http_proxy_host': proxy_url_pack.hostname,
            'http_proxy_port': proxy_url_pack.port,
        }
        if proxy_url_pack.username:
            options['http_proxy_auth'] = (
                proxy_url_pack.username, proxy_url_pack.password)
        return options

    @staticmethod
    def _get_sslopt(http_session):
        """Translate the session's TLS settings into websocket ``sslopt``.

        websocket-client reads TLS options only from ``sslopt``.  Settings
        follow requests semantics: a falsy ``verify`` disables certificate
        checks, a string ``verify`` is a CA bundle path, and ``cert`` is the
        client certificate (a path or a ``(certfile, keyfile)`` pair).
        """
        sslopt = {}
        verify = http_session.verify
        if not verify:  # Do not verify the SSL certificate
            sslopt['cert_reqs'] = ssl.CERT_NONE
        elif isinstance(verify, string_types):  # CA bundle path on disk
            sslopt['ca_certs'] = verify
        cert = http_session.cert
        if cert:
            if isinstance(cert, string_types):
                sslopt['certfile'] = cert
            else:
                certfile = cert[0]
                keyfile = cert[1] if len(cert) > 1 else None
                if certfile:
                    sslopt['certfile'] = certfile
                if keyfile:
                    sslopt['keyfile'] = keyfile
        return sslopt

    def recv_packet(self):
        """Receive one frame and yield it as a (type, data) pair.

        :raises TimeoutError: if the read timed out.
        :raises ConnectionError: if the connection was closed or failed.
        """
        try:
            packet_text = self._connection.recv()
        except websocket.WebSocketTimeoutException as e:
            raise TimeoutError('recv timed out (%s)' % e)
        except websocket.SSLError as e:
            raise ConnectionError('recv disconnected by SSL (%s)' % e)
        except websocket.WebSocketConnectionClosedException as e:
            raise ConnectionError('recv disconnected (%s)' % e)
        except socket.error as e:
            raise ConnectionError('recv disconnected (%s)' % e)
        # Text frames arrive as unicode on Python 3.  Encode them as UTF-8 (the
        # WebSocket text encoding) instead of six.b's latin-1, which fails on
        # any non-latin-1 character.  Binary frames are already bytes.
        if not isinstance(packet_text, six.binary_type):
            packet_text = packet_text.encode('utf-8')
        engineIO_packet_type, engineIO_packet_data = parse_packet_text(
            packet_text)
        yield engineIO_packet_type, engineIO_packet_data

    def send_packet(self, engineIO_packet_type, engineIO_packet_data=''):
        """Send one packet as a text frame.

        :raises TimeoutError: if the write timed out.
        :raises ConnectionError: if the connection was closed or failed.
        """
        packet = format_packet_text(engineIO_packet_type, engineIO_packet_data)
        try:
            self._connection.send(packet)
        except websocket.WebSocketTimeoutException as e:
            raise TimeoutError('send timed out (%s)' % e)
        except (
            TypeError,
            socket.error,
            websocket.WebSocketConnectionClosedException,
        ) as e:
            raise ConnectionError('send disconnected (%s)' % e)

    def set_timeout(self, seconds=None):
        """Set the socket timeout; a falsy ``seconds`` restores the default."""
        self._connection.settimeout(seconds or self._timeout)
173
174
def get_response(request, *args, **kw):
    """Perform a streaming HTTP request and translate its failures.

    :param request: a callable such as ``http_session.get`` or
        ``http_session.post``.
    :param args: positional arguments forwarded to ``request``.
    :param kw: keyword arguments forwarded to ``request``; ``stream=True``
        is always added.
    :returns: the response, but only when its status code is 200.
    :raises TimeoutError: if the request timed out.
    :raises ConnectionError: on SSL negotiation failure, any other
        connection failure, or a non-200 status code.
    """
    try:
        response = request(*args, stream=True, **kw)
    except requests.exceptions.Timeout as e:
        raise TimeoutError(e)
    except requests.exceptions.SSLError as e:
        # SSLError subclasses requests' ConnectionError, so it must be
        # handled first or this branch is unreachable.
        raise ConnectionError('could not negotiate SSL (%s)' % e)
    except requests.exceptions.ConnectionError as e:
        raise ConnectionError(e)
    status_code = response.status_code
    if status_code != 200:
        raise ConnectionError('unexpected status code (%s %s)' % (
            status_code, response.text))
    return response
189
190
def prepare_http_session(kw):
    """Build a ``requests.Session`` configured from client keyword options.

    Recognised keys: ``headers``, ``auth``, ``proxies``, ``hooks``,
    ``params``, ``verify``, ``cert`` and ``cookies``.  A key that is missing,
    or explicitly set to None, leaves the session's default in place.

    :param kw: mapping of options (typically the client's ``**kw``).
    :returns: the configured session.
    """
    http_session = requests.Session()
    # ``or {}`` so an explicit None does not reach ``update`` and raise.
    http_session.headers.update(kw.get('headers') or {})
    http_session.auth = kw.get('auth')
    http_session.proxies.update(kw.get('proxies') or {})
    http_session.hooks.update(kw.get('hooks') or {})
    http_session.params.update(kw.get('params') or {})
    # verify=None means "not specified": keep certificate verification on.
    # Otherwise None is falsy and WebsocketTransport, which tests
    # ``if http_session.verify``, would silently disable TLS verification.
    verify = kw.get('verify')
    http_session.verify = True if verify is None else verify
    http_session.cert = _get_cert(kw)
    http_session.cookies.update(kw.get('cookies') or {})
    return http_session
202
203
204def _get_cert(kw):
205    # Reduce (None, None) to None
206    cert = kw.get('cert')
207    if hasattr(cert, '__iter__') and cert[0] is None:
208        cert = None
209    return cert
210