1# -*- coding: utf-8 -*-
2import base64
3from hashlib import sha1
4from email.parser import BytesHeaderParser
5import io
6
7import asyncio
8
9from ws4py import WS_KEY, WS_VERSION
10from ws4py.exc import HandshakeError
11from ws4py.websocket import WebSocket
12
# Byte-level tokens shared by the HTTP parsing/serialising code below.
LF, CRLF = b'\n', b'\r\n'      # bare line feed / HTTP line terminator
SPACE, EMPTY = b' ', b''       # request-line separator / empty payload
17
# Names exported by ``from <module> import *``: only the protocol class.
__all__ = ['WebSocketProtocol']
19
class WebSocketProtocol(asyncio.StreamReaderProtocol):
    """
    Stream protocol that answers a websocket upgrade request and
    then lets a websocket handler drive the connection.

    One handler is created per protocol instance by calling
    ``handler_cls(self)``; it is kept as :attr:`ws`.
    """

    def __init__(self, handler_cls):
        """
        :param handler_cls: callable invoked with this protocol as
            its sole argument. The object it returns must provide
            the ``run``, ``close``, ``close_connection``, ``closed``
            and ``started`` members used by this class.
        """
        asyncio.StreamReaderProtocol.__init__(self, asyncio.StreamReader(),
                                              self._pseudo_connected)
        self.ws = handler_cls(self)
        # Set once the 101 response has been written: from then on
        # the stream carries websocket frames, not HTTP.
        self._upgraded = False
        # Extra headers to add to the HTTP error reply written by
        # `terminated` when the handshake is rejected.
        self._error_headers = {}

    def _pseudo_connected(self, reader, writer):
        """
        No-op "client connected" callback handed to the base class.

        NOTE(review): presumably only there so that the base class
        builds the stream writer exposed by :attr:`writer`.
        """
        pass

    def connection_made(self, transport):
        """
        A peer is now connected and we receive an instance
        of the underlying :class:`asyncio.Transport`.

        The transport is associated with the stream reader by
        the base class, then the initial HTTP handshake is
        scheduled as a separate task.
        """
        asyncio.StreamReaderProtocol.connection_made(self, transport)
        # Run the handshake concurrently so other connections can
        # tag along. `asyncio.async` cannot be used any longer:
        # `async` is a reserved word since Python 3.7.
        f = asyncio.ensure_future(self.handle_initial_handshake())
        f.add_done_callback(self.terminated)

    @property
    def writer(self):
        """The stream writer kept by the base class."""
        return self._stream_writer

    @property
    def reader(self):
        """The stream reader kept by the base class."""
        return self._stream_reader

    def terminated(self, f):
        """
        Done-callback of the handshake task.

        Nothing happens when the task was cancelled or ended
        without an exception. On failure the connection is closed
        through the handler. If the failure happened before the
        upgrade was sent, a ``400 Bad Request`` reply (plus any
        header recorded during the handshake) is written first.

        :param f: the finished future.
        """
        if f.done() and not f.cancelled():
            ex = f.exception()
            if ex:
                # Once upgraded, an HTTP reply would only be garbage
                # in the middle of the websocket stream.
                if not self._upgraded:
                    response = [b'HTTP/1.0 400 Bad Request',
                                b'Content-Length: 0',
                                b'Connection: close']
                    for name, value in self._error_headers.items():
                        response.append(
                            ('%s: %s' % (name, value)).encode('utf-8'))
                    # Two empty items yield the blank line that
                    # terminates the header section.
                    response.append(b'')
                    response.append(b'')
                    self.writer.write(CRLF.join(response))
                self.ws.close_connection()

    def close(self):
        """
        Initiate the websocket closing handshake by asking
        the handler to close.
        """
        self.ws.close()

    def timeout(self):
        """
        Close the connection and, if the handler had started,
        report a 1002 closure to it.
        """
        self.ws.close_connection()
        if self.ws.started:
            self.ws.closed(1002, "Peer connection timed-out")

    def connection_lost(self, exc):
        """
        The peer connection is now gone, the closing
        handshake won't work so let's not even try.
        However, when the loss was caused by an error, let's
        make the websocket handler aware of it by calling
        its `closed` method (only if it had started).

        :param exc: the error that broke the connection, or
            ``None``, in which case nothing is done.
        """
        # NOTE(review): the base class `connection_lost` is not
        # invoked here -- confirm this is intended.
        if exc is not None:
            self.ws.close_connection()
            if self.ws.started:
                self.ws.closed(1002, "Peer connection was lost")

    @asyncio.coroutine
    def handle_initial_handshake(self):
        """
        Performs the HTTP handshake described in :rfc:`6455`. Note that
        this implementation is really basic and it is strongly advised
        against using it in production. It would probably break for
        most clients. If you want a better support for HTTP, please
        use a more reliable HTTP server implemented using asyncio.

        On success the ``101`` response is written and the coroutine
        carries on with :meth:`handle_websocket`.

        :raises HandshakeError: when the method is not GET, when the
            upgrade headers, the version or the key are missing or
            invalid.
        :raises ValueError: when the request line is malformed or an
            HTTP/1.1 request has no ``Host`` header.
        """
        request_line = yield from self.next_line()
        method, uri, req_protocol = request_line.strip().split(SPACE, 2)

        # GET required
        if method.upper() != b'GET':
            raise HandshakeError('HTTP method must be a GET')

        headers = yield from self.read_headers()
        if req_protocol == b'HTTP/1.1' and 'Host' not in headers:
            raise ValueError("Missing host header")

        for name, expected_value in [('Upgrade', 'websocket'),
                                     ('Connection', 'upgrade')]:
            actual_value = headers.get(name, '').lower()
            if not actual_value:
                raise HandshakeError('Header %s is not defined' % str(name))
            if expected_value not in actual_value:
                raise HandshakeError('Illegal value for header %s: %s' %
                                     (name, actual_value))

        ws_version = WS_VERSION
        version = headers.get('Sec-WebSocket-Version')
        version_is_valid = False
        if version:
            try:
                version = int(version)
            except ValueError:
                # Not a number: leave it flagged as invalid.
                pass
            else:
                version_is_valid = version in ws_version

        if not version_is_valid:
            # Tell the peer which versions would have been accepted;
            # `terminated` adds this header to the error reply.
            self._error_headers['Sec-WebSocket-Version'] = \
                ', '.join([str(v) for v in ws_version])
            raise HandshakeError('Unhandled or missing WebSocket version')

        key = headers.get('Sec-WebSocket-Key')
        if not key:
            # Without a key the accept token cannot be computed.
            raise HandshakeError('Header Sec-WebSocket-Key is not defined')
        try:
            key_bytes = key.encode('utf-8')
            ws_key = base64.b64decode(key_bytes)
        except ValueError:
            raise HandshakeError('WebSocket key is not valid base64')
        if len(ws_key) != 16:
            raise HandshakeError("WebSocket key's length is invalid")

        # No subprotocol is supported here, hence nothing the peer
        # offers is ever selected.
        protocols = []
        ws_protocols = []
        subprotocols = headers.get('Sec-WebSocket-Protocol')
        if subprotocols:
            for s in subprotocols.split(','):
                s = s.strip()
                if s in protocols:
                    ws_protocols.append(s)

        # Same for extensions: none is supported.
        exts = []
        ws_extensions = []
        extensions = headers.get('Sec-WebSocket-Extensions')
        if extensions:
            for ext in extensions.split(','):
                ext = ext.strip()
                if ext in exts:
                    ws_extensions.append(ext)

        accept = base64.b64encode(sha1(key_bytes + WS_KEY).digest())
        response = [req_protocol + b' 101 Switching Protocols',
                    b'Upgrade: websocket',
                    b'Content-Type: text/plain',
                    b'Content-Length: 0',
                    b'Connection: Upgrade',
                    b'Sec-WebSocket-Version:' + bytes(str(version), 'utf-8'),
                    b'Sec-WebSocket-Accept:' + accept]
        # Header values are text: join them as text, then encode.
        if ws_protocols:
            response.append(b'Sec-WebSocket-Protocol:' +
                            ', '.join(ws_protocols).encode('utf-8'))
        if ws_extensions:
            response.append(b'Sec-WebSocket-Extensions:' +
                            ','.join(ws_extensions).encode('utf-8'))
        response.append(b'')
        response.append(b'')
        self.writer.write(CRLF.join(response))
        self._upgraded = True
        yield from self.handle_websocket()

    @asyncio.coroutine
    def handle_websocket(self):
        """
        Starts the websocket process until the
        exchange is completed and terminated.
        """
        yield from self.ws.run()

    @asyncio.coroutine
    def read_headers(self):
        """
        Read all HTTP headers from the HTTP request, up to and
        including the empty line that ends them, and return
        them parsed by :class:`email.parser.BytesHeaderParser`.
        """
        lines = []
        while True:
            line = yield from self.next_line()
            lines.append(line)
            if line == CRLF:
                break
        return BytesHeaderParser().parsebytes(b''.join(lines))

    @asyncio.coroutine
    def next_line(self):
        """
        Reads one line from the stream and returns all the read
        bytes, trailing CRLF included.

        :raises ValueError: when the line does not end with CRLF.
        """
        line = yield from self.reader.readline()
        if not line.endswith(CRLF):
            raise ValueError("Missing mandatory trailing CRLF")
        return line
208
if __name__ == '__main__':
    # Demo entry point: serve an echo websocket on port 9007, all interfaces.
    from ws4py.async_websocket import EchoWebSocket

    loop = asyncio.get_event_loop()

    def make_protocol():
        # One protocol, hence one echo handler, per incoming connection.
        return WebSocketProtocol(EchoWebSocket)

    def start_server():
        return loop.create_server(make_protocol, '', 9007)

    s = loop.run_until_complete(start_server())
    print('serving on', s.sockets[0].getsockname())
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop the demo.
        pass
    finally:
        # Release the listening socket and the loop instead of leaving
        # them to interpreter teardown.
        s.close()
        loop.run_until_complete(s.wait_closed())
        loop.close()
221