1import six
2import sys
3import time
4
5from . import exceptions
6from . import packet
7from . import payload
8
9
class Socket(object):
    """An Engine.IO socket.

    Holds the per-client state of one Engine.IO session (identified by
    ``sid``) and moves packets between the client and ``server``. Outgoing
    packets go onto a queue created by the server and are drained either by
    long-polling GET requests (``handle_get_request``) or, once the
    connection uses WebSocket, by a background writer task.
    """
    upgrade_protocols = ['websocket']

    def __init__(self, server, sid):
        self.server = server
        self.sid = sid
        # outgoing packets; a ``None`` entry is the "socket closed" sentinel
        self.queue = self.server.create_queue()
        # refreshed on every PING received (see check_ping_timeout)
        self.last_ping = time.time()
        self.connected = False
        self.upgrading = False  # polling -> websocket upgrade in progress
        self.upgraded = False  # websocket transport is in use
        self.closing = False
        self.closed = False
        self.session = {}

    def poll(self):
        """Wait for packets to send to the client.

        Blocks for up to ``server.ping_timeout`` seconds for the first packet,
        then collects whatever else is already queued without blocking.

        :returns: the list of pending packets, or an empty list once the
                  ``None`` close sentinel has been reached.
        :raises exceptions.QueueEmpty: if nothing arrived before the timeout.
        """
        queue_empty = self.server.get_queue_empty_exception()
        try:
            packets = [self.queue.get(timeout=self.server.ping_timeout)]
            self.queue.task_done()
        except queue_empty:
            raise exceptions.QueueEmpty()
        if packets == [None]:
            return []
        while True:
            try:
                pkt = self.queue.get(block=False)
                self.queue.task_done()
                if pkt is None:
                    # put the sentinel back so that the next poll() observes
                    # the close and returns an empty list
                    self.queue.put(None)
                    break
                packets.append(pkt)
            except queue_empty:
                break
        return packets

    def receive(self, pkt):
        """Receive packet from the client.

        PING is answered with a PONG carrying the same data and refreshes
        ``last_ping``; MESSAGE fires the server's ``message`` event; UPGRADE
        is answered with a NOOP; CLOSE closes the socket without sending a
        CLOSE packet back.

        :raises exceptions.UnknownPacketError: for any other packet type.
        """
        packet_name = packet.packet_names[pkt.packet_type] \
            if pkt.packet_type < len(packet.packet_names) else 'UNKNOWN'
        self.server.logger.info('%s: Received packet %s data %s',
                                self.sid, packet_name,
                                pkt.data if not isinstance(pkt.data, bytes)
                                else '<binary>')
        if pkt.packet_type == packet.PING:
            self.last_ping = time.time()
            self.send(packet.Packet(packet.PONG, pkt.data))
        elif pkt.packet_type == packet.MESSAGE:
            self.server._trigger_event('message', self.sid, pkt.data,
                                       run_async=self.server.async_handlers)
        elif pkt.packet_type == packet.UPGRADE:
            self.send(packet.Packet(packet.NOOP))
        elif pkt.packet_type == packet.CLOSE:
            self.close(wait=False, abort=True)
        else:
            raise exceptions.UnknownPacketError()

    def check_ping_timeout(self):
        """Make sure the client is still sending pings.

        This helps detect disconnections for long-polling clients.

        :returns: True if the client is considered alive; False if more than
                  ``ping_interval + ping_interval_grace_period`` seconds have
                  passed since the last PING, in which case the socket is
                  closed.
        :raises exceptions.SocketIsClosedError: if the socket is closed.
        """
        if self.closed:
            raise exceptions.SocketIsClosedError()
        if time.time() - self.last_ping > self.server.ping_interval + \
                self.server.ping_interval_grace_period:
            self.server.logger.info('%s: Client is gone, closing socket',
                                    self.sid)
            # Passing abort=False here will cause close() to write a
            # CLOSE packet. This has the effect of updating half-open sockets
            # to their correct state of disconnected
            self.close(wait=False, abort=False)
            return False
        return True

    def send(self, pkt):
        """Send a packet to the client.

        The packet is dropped if the ping timeout has expired; in that case
        ``check_ping_timeout`` has already closed the socket.

        :raises exceptions.SocketIsClosedError: if the socket is closed.
        """
        if not self.check_ping_timeout():
            return
        self._queue_packet(pkt)

    def _queue_packet(self, pkt):
        """Queue a packet for delivery and log it, without a liveness check."""
        self.queue.put(pkt)
        self.server.logger.info('%s: Sending packet %s data %s',
                                self.sid, packet.packet_names[pkt.packet_type],
                                pkt.data if not isinstance(pkt.data, bytes)
                                else '<binary>')

    def handle_get_request(self, environ, start_response):
        """Handle a long-polling GET request from the client.

        Requests carrying ``Connection: upgrade`` with a supported
        ``Upgrade`` transport are handed to the matching ``_upgrade_*``
        method. While an upgrade is in progress or done, a single NOOP is
        returned instead of queued packets.

        :returns: the list of packets to deliver in the response.
        :raises exceptions.QueueEmpty: if no packet arrived before the poll
                                       timeout; the socket is closed first.
        """
        connections = [
            s.strip()
            for s in environ.get('HTTP_CONNECTION', '').lower().split(',')]
        transport = environ.get('HTTP_UPGRADE', '').lower()
        if 'upgrade' in connections and transport in self.upgrade_protocols:
            self.server.logger.info('%s: Received request to upgrade to %s',
                                    self.sid, transport)
            return getattr(self, '_upgrade_' + transport)(environ,
                                                          start_response)
        if self.upgrading or self.upgraded:
            # we are upgrading to WebSocket, do not return any more packets
            # through the polling endpoint
            return [packet.Packet(packet.NOOP)]
        try:
            packets = self.poll()
        except exceptions.QueueEmpty:
            # capture the exception before close() runs, so it is re-raised
            # unchanged even if close() handles exceptions internally
            exc = sys.exc_info()
            self.close(wait=False)
            six.reraise(*exc)
        return packets

    def handle_post_request(self, environ):
        """Handle a long-polling POST request from the client.

        Reads ``CONTENT_LENGTH`` bytes from ``wsgi.input``, decodes them as a
        payload and passes each contained packet to ``receive``.

        :raises ValueError: if CONTENT_LENGTH is not a non-negative integer.
        :raises exceptions.ContentTooLongError: if CONTENT_LENGTH exceeds
                                                ``max_http_buffer_size``.
        """
        # PEP 3333: CONTENT_LENGTH "may be empty or absent"; both mean 0.
        length = int(environ.get('CONTENT_LENGTH') or 0)
        if length < 0:
            # read(-1) would consume the whole stream, bypassing the limit
            raise ValueError('Invalid Content-Length: %d' % length)
        if length > self.server.max_http_buffer_size:
            raise exceptions.ContentTooLongError()
        body = environ['wsgi.input'].read(length)
        p = payload.Payload(encoded_payload=body)
        for pkt in p.packets:
            self.receive(pkt)

    def close(self, wait=True, abort=False):
        """Close the socket connection.

        Only the first call has an effect; later calls are no-ops. The
        server's ``disconnect`` event is triggered synchronously.

        :param wait: if True, block on ``queue.join()`` until every queued
                     item has been taken and marked done.
        :param abort: if True, do not queue a CLOSE packet for the client.
        """
        if not self.closed and not self.closing:
            self.closing = True
            self.server._trigger_event('disconnect', self.sid, run_async=False)
            if not abort:
                # Queue directly instead of via send(): send() re-runs the
                # ping timeout check, and when close() was triggered by that
                # same check the CLOSE packet would be silently dropped.
                self._queue_packet(packet.Packet(packet.CLOSE))
            self.closed = True
            self.queue.put(None)  # sentinel: wakes a pending poll()/writer
            if wait:
                self.queue.join()

    def _upgrade_websocket(self, environ, start_response):
        """Upgrade the connection from polling to websocket.

        :raises IOError: if the socket has already been upgraded.
        """
        if self.upgraded:
            raise IOError('Socket has been upgraded already')
        if self.server._async['websocket'] is None:
            # the selected async mode does not support websocket
            return self.server._bad_request()
        ws = self.server._async['websocket'](self._websocket_handler)
        return ws(environ, start_response)

    def _websocket_handler(self, ws):
        """Engine.IO handler for websocket transport.

        If the session is already connected, this first performs the
        polling -> websocket upgrade handshake: it expects PING "probe",
        answers PONG "probe", then expects an UPGRADE packet. Otherwise the
        websocket is the initial transport. It then starts a background
        writer that forwards queued packets to ``ws``, and reads incoming
        packets until the connection ends, after which the socket is closed.

        :returns: an empty list in all cases.
        """
        # try to set a socket timeout matching the configured ping interval
        for attr in ['_sock', 'socket']:  # pragma: no cover
            if hasattr(ws, attr) and hasattr(getattr(ws, attr), 'settimeout'):
                getattr(ws, attr).settimeout(self.server.ping_timeout)

        if self.connected:
            # the socket was already connected, so this is an upgrade
            self.upgrading = True  # hold packet sends during the upgrade

            pkt = ws.wait()
            decoded_pkt = packet.Packet(encoded_packet=pkt)
            if decoded_pkt.packet_type != packet.PING or \
                    decoded_pkt.data != 'probe':
                self.server.logger.info(
                    '%s: Failed websocket upgrade, no PING packet', self.sid)
                self.upgrading = False
                return []
            ws.send(packet.Packet(
                packet.PONG,
                data=six.text_type('probe')).encode(always_bytes=False))
            self.queue.put(packet.Packet(packet.NOOP))  # end poll

            pkt = ws.wait()
            decoded_pkt = packet.Packet(encoded_packet=pkt)
            if decoded_pkt.packet_type != packet.UPGRADE:
                self.server.logger.info(
                    ('%s: Failed websocket upgrade, expected UPGRADE packet, '
                     'received %s instead.'),
                    self.sid, pkt)
                self.upgrading = False
                return []
            self.upgraded = True
            self.upgrading = False
        else:
            self.connected = True
            self.upgraded = True

        # start separate writer thread
        def writer():
            while True:
                packets = None
                try:
                    packets = self.poll()
                except exceptions.QueueEmpty:
                    break
                if not packets:
                    # empty packet list returned -> connection closed
                    break
                try:
                    for pkt in packets:
                        ws.send(pkt.encode(always_bytes=False))
                except Exception:
                    # the websocket is unusable; stop writing
                    break
        writer_task = self.server.start_background_task(writer)

        self.server.logger.info(
            '%s: Upgrade to websocket successful', self.sid)

        while True:
            p = None
            try:
                p = ws.wait()
            except Exception as e:
                # if the socket is already closed, we can assume this is a
                # downstream error of that
                if not self.closed:  # pragma: no cover
                    self.server.logger.info(
                        '%s: Unexpected error "%s", closing connection',
                        self.sid, str(e))
                break
            if p is None:
                # connection closed by client
                break
            if isinstance(p, six.text_type):  # pragma: no cover
                p = p.encode('utf-8')
            pkt = packet.Packet(encoded_packet=p)
            try:
                self.receive(pkt)
            except exceptions.UnknownPacketError:  # pragma: no cover
                pass
            except exceptions.SocketIsClosedError:  # pragma: no cover
                self.server.logger.info('Receive error -- socket is closed')
                break
            except:  # pragma: no cover
                # if we get an unexpected exception we log the error and exit
                # the connection properly; deliberately broad so that the
                # cleanup below always runs
                self.server.logger.exception('Unknown receive error')
                break

        self.queue.put(None)  # unlock the writer task so that it can exit
        writer_task.join()
        self.close(wait=False, abort=True)

        return []
251