# -*- coding: utf-8 -*-
import base64
from hashlib import sha1
from email.parser import BytesHeaderParser
import io

import asyncio

from ws4py import WS_KEY, WS_VERSION
from ws4py.exc import HandshakeError
from ws4py.websocket import WebSocket

LF = b'\n'
CRLF = b'\r\n'
SPACE = b' '
EMPTY = b''

__all__ = ['WebSocketProtocol']


class WebSocketProtocol(asyncio.StreamReaderProtocol):
    """
    asyncio protocol performing a basic server-side :rfc:`6455` opening
    handshake, then delegating the connection to a websocket handler.

    NOTE(review): this module uses generator-based coroutines
    (``@asyncio.coroutine`` / ``yield from``), which were removed from
    asyncio in Python 3.11.
    """

    def __init__(self, handler_cls):
        """
        :param handler_cls: websocket handler class; it is instantiated
            with this protocol as its single argument.
        """
        reader = asyncio.StreamReader()
        # A client_connected_cb must be given: StreamReaderProtocol only
        # creates its StreamWriter when such a callback is set.
        asyncio.StreamReaderProtocol.__init__(self, reader,
                                              self._pseudo_connected)
        # Own strong references: on recent Pythons the base class only
        # holds the reader weakly and drops reader/writer in
        # connection_lost(), while the handler may still need them.
        self._ws_reader = reader
        self._ws_writer = None
        # Set once the "101 Switching Protocols" response is written; from
        # then on an HTTP error response must not be sent on this stream.
        self._upgraded = False
        # Extra header lines to include in the error response when the
        # handshake is rejected.
        self._error_headers = []
        self.ws = handler_cls(self)

    def _pseudo_connected(self, reader, writer):
        """No-op client_connected_cb, only needed so a writer is created."""
        pass

    def connection_made(self, transport):
        """
        A peer is now connected and we receive an instance
        of the underlying :class:`asyncio.Transport`.

        The base class associates the transport with the
        :class:`asyncio.StreamReader` and creates the stream writer,
        then the initial HTTP handshake is started as its own task.
        """
        asyncio.StreamReaderProtocol.connection_made(self, transport)
        self._ws_writer = self._stream_writer
        # Run the handshake (and then the websocket session) concurrently;
        # ``terminated`` is called once that task is finished.
        f = asyncio.ensure_future(self.handle_initial_handshake())
        f.add_done_callback(self.terminated)

    @property
    def writer(self):
        """Stream writer of the connection (None before connection_made)."""
        return self._ws_writer

    @property
    def reader(self):
        """Stream reader fed with the bytes received from the peer."""
        return self._ws_reader

    def terminated(self, f):
        """
        Done-callback of the task started in :meth:`connection_made`.

        When the task failed before the connection was upgraded, a
        ``400 Bad Request`` response is written to the peer. In every
        case the handler's ``close_connection()`` is called afterwards.

        :param f: the finished handshake task.
        """
        if f.done() and not f.cancelled():
            ex = f.exception()
            # After the 101 response the stream speaks the websocket
            # protocol; writing an HTTP response there would corrupt it.
            if ex is not None and not self._upgraded:
                response = [b'HTTP/1.0 400 Bad Request']
                response.extend(self._error_headers)
                response.append(b'Content-Length: 0')
                response.append(b'Connection: close')
                response.append(b'')
                response.append(b'')
                if self.writer is not None:
                    self.writer.write(CRLF.join(response))
        self.ws.close_connection()

    def close(self):
        """
        Initiate the websocket closing handshake (through the handler's
        ``close()``), which should eventually lead to the underlying
        transport being closed.
        """
        self.ws.close()

    def timeout(self):
        """
        Close the connection because the peer timed out; if the websocket
        session had started, the handler is notified via ``closed(1002)``.
        """
        self.ws.close_connection()
        if self.ws.started:
            self.ws.closed(1002, "Peer connection timed-out")

    def connection_lost(self, exc):
        """
        The peer connection is now lost, so the closing
        handshake won't work and we don't even try.

        The base class is notified first so that pending reads on
        :attr:`reader` are woken up (with EOF or ``exc``) instead of
        waiting forever. When ``exc`` is not ``None`` the websocket
        handler is made aware of it by calling its ``closed`` method.
        """
        asyncio.StreamReaderProtocol.connection_lost(self, exc)
        if exc is not None:
            self.ws.close_connection()
            if self.ws.started:
                self.ws.closed(1002, "Peer connection was lost")

    @asyncio.coroutine
    def handle_initial_handshake(self):
        """
        Performs the HTTP handshake described in :rfc:`6455`. Note that
        this implementation is really basic and it is strongly advised
        against using it in production. It would probably break for
        most clients. If you want a better support for HTTP, please
        use a more reliable HTTP server implemented using asyncio.

        On success the 101 response is written and
        :meth:`handle_websocket` is run.

        :raises HandshakeError: (or ``ValueError``) when the request is
            not a valid websocket upgrade request.
        """
        request_line = yield from self.next_line()
        parts = request_line.strip().split(SPACE, 2)
        if len(parts) != 3:
            raise HandshakeError('Malformed HTTP request line')
        method, uri, req_protocol = parts

        # GET required
        if method.upper() != b'GET':
            raise HandshakeError('HTTP method must be a GET')

        headers = yield from self.read_headers()
        if req_protocol == b'HTTP/1.1' and 'Host' not in headers:
            raise ValueError("Missing host header")

        for key, expected_value in [('Upgrade', 'websocket'),
                                    ('Connection', 'upgrade')]:
            actual_value = headers.get(key, '').lower()
            if not actual_value:
                raise HandshakeError('Header %s is not defined' % str(key))
            if expected_value not in actual_value:
                raise HandshakeError('Illegal value for header %s: %s' %
                                     (key, actual_value))

        version = self._check_version(headers)
        key = self._check_key(headers)

        # No subprotocol/extension is supported yet, so nothing offered
        # by the client is ever selected.
        ws_protocols = self._negotiate(
            headers.get('Sec-WebSocket-Protocol'), [])
        ws_extensions = self._negotiate(
            headers.get('Sec-WebSocket-Extensions'), [])

        accept = base64.b64encode(sha1(key.encode('utf-8') + WS_KEY).digest())
        # No Content-Length/Content-Type: a 1xx response has no body and
        # must not carry Content-Length (RFC 7230, section 3.3.2).
        response = [req_protocol + b' 101 Switching Protocols',
                    b'Upgrade: websocket',
                    b'Connection: Upgrade',
                    b'Sec-WebSocket-Version: ' + str(version).encode('utf-8'),
                    b'Sec-WebSocket-Accept: ' + accept]
        if ws_protocols:
            response.append(b'Sec-WebSocket-Protocol: ' + b', '.join(
                p.encode('utf-8') for p in ws_protocols))
        if ws_extensions:
            response.append(b'Sec-WebSocket-Extensions: ' + b','.join(
                e.encode('utf-8') for e in ws_extensions))
        response.append(b'')
        response.append(b'')
        self.writer.write(CRLF.join(response))
        self._upgraded = True
        yield from self.handle_websocket()

    def _check_version(self, headers):
        """
        Return the requested websocket version as an ``int``.

        :raises HandshakeError: when ``Sec-WebSocket-Version`` is missing,
            not an integer or not in ``WS_VERSION``. The supported versions
            are then queued for the error response (:rfc:`6455#section-4.4`).
        """
        raw = headers.get('Sec-WebSocket-Version')
        version = None
        if raw:
            try:
                version = int(raw)
            except (TypeError, ValueError):
                version = None
        if version not in WS_VERSION:
            supported = ', '.join(str(v) for v in WS_VERSION)
            self._error_headers.append(
                b'Sec-WebSocket-Version: ' + supported.encode('utf-8'))
            raise HandshakeError('Unhandled or missing WebSocket version')
        return version

    @staticmethod
    def _check_key(headers):
        """
        Return the (stripped) ``Sec-WebSocket-Key`` value.

        :raises HandshakeError: when the header is missing, is not valid
            base64, or does not decode to 16 bytes.
        """
        key = headers.get('Sec-WebSocket-Key')
        if not key:
            raise HandshakeError('Header Sec-WebSocket-Key is not defined')
        key = key.strip()
        try:
            ws_key = base64.b64decode(key.encode('utf-8'), validate=True)
        except ValueError as exc:
            # binascii.Error and UnicodeEncodeError are ValueError subclasses.
            raise HandshakeError(
                'Header Sec-WebSocket-Key is not valid base64') from exc
        if len(ws_key) != 16:
            raise HandshakeError("WebSocket key's length is invalid")
        return key

    @staticmethod
    def _negotiate(offered, supported):
        """
        Return the comma-separated tokens of the ``offered`` header value
        that also appear in ``supported``, in the client's order.
        """
        if not offered:
            return []
        tokens = (token.strip() for token in offered.split(','))
        return [token for token in tokens if token in supported]

    @asyncio.coroutine
    def handle_websocket(self):
        """
        Starts the websocket process until the
        exchange is completed and terminated.
        """
        yield from self.ws.run()

    @asyncio.coroutine
    def read_headers(self):
        """
        Read all HTTP headers from the HTTP request, up to and including
        the empty line, and return them parsed by
        :class:`email.parser.BytesHeaderParser`.
        """
        lines = []
        while True:
            line = yield from self.next_line()
            lines.append(line)
            if line == CRLF:
                break
        return BytesHeaderParser().parsebytes(EMPTY.join(lines))

    @asyncio.coroutine
    def next_line(self):
        """
        Reads data until \\r\\n is met and then return all read
        bytes.

        :raises ValueError: when the line does not end with CRLF
            (including EOF, where ``readline()`` returns ``b''``).
        """
        line = yield from self.reader.readline()
        if not line.endswith(CRLF):
            raise ValueError("Missing mandatory trailing CRLF")
        return line


if __name__ == '__main__':
    from ws4py.async_websocket import EchoWebSocket

    loop = asyncio.get_event_loop()

    def start_server():
        proto_factory = lambda: WebSocketProtocol(EchoWebSocket)
        return loop.create_server(proto_factory, '', 9007)

    s = loop.run_until_complete(start_server())
    print('serving on', s.sockets[0].getsockname())
    loop.run_forever()