1import sys
2import time
3
4from . import exceptions
5from . import packet
6from . import payload
7
8
class Socket(object):
    """An Engine.IO socket.

    Tracks one client session (``sid``) on ``server``. Outgoing packets are
    put on ``self.queue`` and drained by :meth:`poll`, either from
    long-polling GET requests or, once upgraded, from a websocket writer
    task. A ``None`` entry in the queue is a sentinel that tells the
    consumer to stop.
    """
    upgrade_protocols = ['websocket']

    def __init__(self, server, sid):
        self.server = server
        self.sid = sid
        self.queue = self.server.create_queue()
        # time.time() at which the last PING was sent; reset to None when a
        # new ping cycle starts (see schedule_ping)
        self.last_ping = None
        self.connected = False
        self.upgrading = False
        self.upgraded = False
        self.closing = False
        self.closed = False
        self.session = {}

    def poll(self):
        """Wait for packets to send to the client.

        Blocks up to ``ping_interval + ping_timeout`` seconds for the first
        packet, then drains any further queued packets without blocking.

        :returns: the list of packets. Returns ``[]`` if the first item is
                  the ``None`` sentinel.
        :raises exceptions.QueueEmpty: if nothing arrived before the timeout.
        """
        queue_empty = self.server.get_queue_empty_exception()
        try:
            packets = [self.queue.get(
                timeout=self.server.ping_interval + self.server.ping_timeout)]
            self.queue.task_done()
        except queue_empty:
            raise exceptions.QueueEmpty()
        if packets == [None]:
            return []
        while True:
            try:
                pkt = self.queue.get(block=False)
                self.queue.task_done()
                if pkt is None:
                    # put the sentinel back so the *next* poll returns []
                    # after these packets have been delivered
                    self.queue.put(None)
                    break
                packets.append(pkt)
            except queue_empty:
                break
        return packets

    def receive(self, pkt):
        """Receive packet from the client.

        PONG schedules the next ping, MESSAGE fires the ``message`` event,
        UPGRADE is answered with a NOOP and CLOSE closes the socket.

        :raises exceptions.UnknownPacketError: for any other packet type.
        """
        packet_name = packet.packet_names[pkt.packet_type] \
            if pkt.packet_type < len(packet.packet_names) else 'UNKNOWN'
        self.server.logger.info('%s: Received packet %s data %s',
                                self.sid, packet_name,
                                pkt.data if not isinstance(pkt.data, bytes)
                                else '<binary>')
        if pkt.packet_type == packet.PONG:
            self.schedule_ping()
        elif pkt.packet_type == packet.MESSAGE:
            self.server._trigger_event('message', self.sid, pkt.data,
                                       run_async=self.server.async_handlers)
        elif pkt.packet_type == packet.UPGRADE:
            self.send(packet.Packet(packet.NOOP))
        elif pkt.packet_type == packet.CLOSE:
            self.close(wait=False, abort=True)
        else:
            raise exceptions.UnknownPacketError()

    def check_ping_timeout(self):
        """Make sure the client is still responding to pings.

        :returns: ``True`` if the socket is alive. Returns ``False`` (after
                  closing the socket) if the last PING has been outstanding
                  for longer than ``ping_timeout``.
        :raises exceptions.SocketIsClosedError: if the socket is closed.
        """
        if self.closed:
            raise exceptions.SocketIsClosedError()
        if self.last_ping and \
                time.time() - self.last_ping > self.server.ping_timeout:
            self.server.logger.info('%s: Client is gone, closing socket',
                                    self.sid)
            # Passing abort=False here will cause close() to write a
            # CLOSE packet. This has the effect of updating half-open sockets
            # to their correct state of disconnected
            self.close(wait=False, abort=False)
            return False
        return True

    def send(self, pkt):
        """Queue a packet for delivery to the client.

        The packet is dropped if :meth:`check_ping_timeout` reports that the
        client is gone.
        """
        if not self.check_ping_timeout():
            return
        self.queue.put(pkt)
        self.server.logger.info('%s: Sending packet %s data %s',
                                self.sid, packet.packet_names[pkt.packet_type],
                                pkt.data if not isinstance(pkt.data, bytes)
                                else '<binary>')

    def handle_get_request(self, environ, start_response):
        """Handle a long-polling GET request from the client.

        Requests that ask for a protocol upgrade are dispatched to the
        matching ``_upgrade_<transport>`` method. While an upgrade is in
        progress or complete, a single NOOP is returned so the polling
        transport stops receiving real packets.

        :raises exceptions.QueueEmpty: if no packet arrived in time; the
                                       socket is closed before re-raising.
        """
        connections = [
            s.strip()
            for s in environ.get('HTTP_CONNECTION', '').lower().split(',')]
        transport = environ.get('HTTP_UPGRADE', '').lower()
        if 'upgrade' in connections and transport in self.upgrade_protocols:
            self.server.logger.info('%s: Received request to upgrade to %s',
                                    self.sid, transport)
            return getattr(self, '_upgrade_' + transport)(environ,
                                                          start_response)
        if self.upgrading or self.upgraded:
            # we are upgrading to WebSocket, do not return any more packets
            # through the polling endpoint
            return [packet.Packet(packet.NOOP)]
        try:
            packets = self.poll()
        except exceptions.QueueEmpty:
            self.close(wait=False)
            # bare raise re-raises the QueueEmpty with its original traceback
            raise
        return packets

    def handle_post_request(self, environ):
        """Handle a long-polling POST request from the client.

        Reads ``CONTENT_LENGTH`` bytes of the body, decodes them as a payload
        and feeds every contained packet to :meth:`receive`.

        :raises exceptions.ContentTooLongError: if the declared length
                                                exceeds
                                                ``max_http_buffer_size``.
        """
        # PEP 3333: CONTENT_LENGTH "may be empty or absent"; both mean 0.
        # int('') would otherwise raise ValueError.
        length = int(environ.get('CONTENT_LENGTH') or 0)
        if length > self.server.max_http_buffer_size:
            raise exceptions.ContentTooLongError()
        body = environ['wsgi.input'].read(length).decode('utf-8')
        p = payload.Payload(encoded_payload=body)
        for pkt in p.packets:
            self.receive(pkt)

    def close(self, wait=True, abort=False):
        """Close the socket connection.

        Fires the ``disconnect`` event once, sends a CLOSE packet unless
        ``abort`` is true, and pushes the ``None`` sentinel to stop the queue
        consumer. Repeated calls are no-ops.

        :param wait: if true, block until every queued item has been marked
                     done by the consumer (``queue.join()``).
        :param abort: if true, skip sending the CLOSE packet.
        """
        if not self.closed and not self.closing:
            self.closing = True
            self.server._trigger_event('disconnect', self.sid, run_async=False)
            if not abort:
                self.send(packet.Packet(packet.CLOSE))
            self.closed = True
            self.queue.put(None)
            if wait:
                self.queue.join()

    def schedule_ping(self):
        """Start a background task that sends a PING after ``ping_interval``.

        The PING is skipped if the socket started closing in the meantime.
        """
        def send_ping():
            self.last_ping = None
            self.server.sleep(self.server.ping_interval)
            if not self.closing and not self.closed:
                self.last_ping = time.time()
                self.send(packet.Packet(packet.PING))

        self.server.start_background_task(send_ping)

    def _upgrade_websocket(self, environ, start_response):
        """Upgrade the connection from polling to websocket.

        :raises IOError: if the socket was already upgraded.
        """
        if self.upgraded:
            raise IOError('Socket has been upgraded already')
        if self.server._async['websocket'] is None:
            # the selected async mode does not support websocket
            return self.server._bad_request()
        ws = self.server._async['websocket'](self._websocket_handler)
        return ws(environ, start_response)

    def _websocket_upgrade_handshake(self, ws, websocket_wait):
        """Run the probe exchange that upgrades a polling session.

        Expects a PING packet with data ``'probe'``, answers with a
        PONG/``'probe'``, queues a NOOP to end any pending long-poll, then
        expects an UPGRADE packet.

        :param ws: the websocket connection object.
        :param websocket_wait: callable returning the next raw message, or
                               ``None`` if the client closed the connection.
        :returns: ``True`` if the handshake completed, ``False`` otherwise.
        """
        def decode(raw):
            # a None message means the client went away mid-handshake
            return None if raw is None else packet.Packet(encoded_packet=raw)

        pkt = websocket_wait()
        decoded_pkt = decode(pkt)
        if decoded_pkt is None or decoded_pkt.packet_type != packet.PING or \
                decoded_pkt.data != 'probe':
            self.server.logger.info(
                '%s: Failed websocket upgrade, no PING packet', self.sid)
            return False
        ws.send(packet.Packet(packet.PONG, data='probe').encode())
        self.queue.put(packet.Packet(packet.NOOP))  # end poll

        pkt = websocket_wait()
        decoded_pkt = decode(pkt)
        if decoded_pkt is None or decoded_pkt.packet_type != packet.UPGRADE:
            self.server.logger.info(
                ('%s: Failed websocket upgrade, expected UPGRADE packet, '
                 'received %s instead.'),
                self.sid, pkt)
            return False
        return True

    def _websocket_handler(self, ws):
        """Engine.IO handler for websocket transport.

        Performs the upgrade handshake if this session was already connected
        via polling. It then starts a writer task that drains the queue into
        ``ws``, and reads incoming messages until the client disconnects or
        an error occurs. The socket is closed on exit.

        :returns: an empty list (empty response body).
        """
        def websocket_wait():
            data = ws.wait()
            if data and len(data) > self.server.max_http_buffer_size:
                raise ValueError('packet is too large')
            return data

        # try to set a socket timeout matching the configured ping interval
        # and timeout
        for attr in ['_sock', 'socket']:  # pragma: no cover
            if hasattr(ws, attr) and hasattr(getattr(ws, attr), 'settimeout'):
                getattr(ws, attr).settimeout(
                    self.server.ping_interval + self.server.ping_timeout)

        if self.connected:
            # the socket was already connected, so this is an upgrade; while
            # this flag is set, long-polling GETs get a NOOP instead of
            # draining the queue
            self.upgrading = True
            try:
                if not self._websocket_upgrade_handshake(ws, websocket_wait):
                    return []
                # set before clearing ``upgrading`` so polling never sees
                # both flags false in between
                self.upgraded = True
            finally:
                # always clear the flag, even if the handshake raised;
                # otherwise the polling transport would stay stuck on NOOPs
                self.upgrading = False
        else:
            self.connected = True
            self.upgraded = True

        # start separate writer thread
        def writer():
            while True:
                packets = None
                try:
                    packets = self.poll()
                except exceptions.QueueEmpty:
                    break
                if not packets:
                    # empty packet list returned -> connection closed
                    break
                try:
                    for pkt in packets:
                        ws.send(pkt.encode())
                except:  # deliberate catch-all: any send failure ends writer
                    break
        writer_task = self.server.start_background_task(writer)

        self.server.logger.info(
            '%s: Upgrade to websocket successful', self.sid)

        while True:
            p = None
            try:
                p = websocket_wait()
            except Exception as e:
                # if the socket is already closed, we can assume this is a
                # downstream error of that
                if not self.closed:  # pragma: no cover
                    self.server.logger.info(
                        '%s: Unexpected error "%s", closing connection',
                        self.sid, str(e))
                break
            if p is None:
                # connection closed by client
                break
            pkt = packet.Packet(encoded_packet=p)
            try:
                self.receive(pkt)
            except exceptions.UnknownPacketError:  # pragma: no cover
                pass
            except exceptions.SocketIsClosedError:  # pragma: no cover
                self.server.logger.info('Receive error -- socket is closed')
                break
            except:  # pragma: no cover
                # if we get an unexpected exception we log the error and exit
                # the connection properly
                self.server.logger.exception('Unknown receive error')
                break

        self.queue.put(None)  # unlock the writer task so that it can exit
        writer_task.join()
        self.close(wait=False, abort=True)

        return []
261