1# Copyright 2012, Google Inc. 2# All rights reserved. 3# 4# Redistribution and use in source and binary forms, with or without 5# modification, are permitted provided that the following conditions are 6# met: 7# 8# * Redistributions of source code must retain the above copyright 9# notice, this list of conditions and the following disclaimer. 10# * Redistributions in binary form must reproduce the above 11# copyright notice, this list of conditions and the following disclaimer 12# in the documentation and/or other materials provided with the 13# distribution. 14# * Neither the name of Google Inc. nor the names of its 15# contributors may be used to endorse or promote products derived from 16# this software without specific prior written permission. 17# 18# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29"""This file provides the opening handshake processor for the WebSocket 30protocol (RFC 6455). 31 32Specification: 33http://tools.ietf.org/html/rfc6455 34""" 35 36from __future__ import absolute_import 37import base64 38import logging 39import os 40import re 41from hashlib import sha1 42 43from mod_pywebsocket import common 44from mod_pywebsocket.extensions import get_extension_processor 45from mod_pywebsocket.handshake._base import check_request_line 46from mod_pywebsocket.handshake._base import format_header 47from mod_pywebsocket.handshake._base import get_mandatory_header 48from mod_pywebsocket.handshake._base import HandshakeException 49from mod_pywebsocket.handshake._base import parse_token_list 50from mod_pywebsocket.handshake._base import validate_mandatory_header 51from mod_pywebsocket.handshake._base import validate_subprotocol 52from mod_pywebsocket.handshake._base import VersionException 53from mod_pywebsocket.stream import Stream 54from mod_pywebsocket.stream import StreamOptions 55from mod_pywebsocket import util 56from six.moves import map 57from six.moves import range 58 59# Used to validate the value in the Sec-WebSocket-Key header strictly. RFC 4648 60# disallows non-zero padding, so the character right before == must be any of 61# A, Q, g and w. 62_SEC_WEBSOCKET_KEY_REGEX = re.compile('^[+/0-9A-Za-z]{21}[AQgw]==$') 63 64# Defining aliases for values used frequently. 65_VERSION_LATEST = common.VERSION_HYBI_LATEST 66_VERSION_LATEST_STRING = str(_VERSION_LATEST) 67_SUPPORTED_VERSIONS = [ 68 _VERSION_LATEST, 69] 70 71 72def compute_accept(key): 73 """Computes value for the Sec-WebSocket-Accept header from value of the 74 Sec-WebSocket-Key header. 75 """ 76 77 accept_binary = sha1(key + common.WEBSOCKET_ACCEPT_UUID).digest() 78 accept = base64.b64encode(accept_binary) 79 80 return accept 81 82 83def compute_accept_from_unicode(unicode_key): 84 """A wrapper function for compute_accept which takes a unicode string as an 85 argument, and encodes it to byte string. It then passes it on to 86 compute_accept. 87 """ 88 89 key = unicode_key.encode('UTF-8') 90 return compute_accept(key) 91 92 93class Handshaker(object): 94 """Opening handshake processor for the WebSocket protocol (RFC 6455).""" 95 def __init__(self, request, dispatcher): 96 """Construct an instance. 97 98 Args: 99 request: mod_python request. 100 dispatcher: Dispatcher (dispatch.Dispatcher). 101 102 Handshaker will add attributes such as ws_resource during handshake. 103 """ 104 105 self._logger = util.get_class_logger(self) 106 107 self._request = request 108 self._dispatcher = dispatcher 109 110 def _validate_connection_header(self): 111 connection = get_mandatory_header(self._request, 112 common.CONNECTION_HEADER) 113 114 try: 115 connection_tokens = parse_token_list(connection) 116 except HandshakeException as e: 117 raise HandshakeException('Failed to parse %s: %s' % 118 (common.CONNECTION_HEADER, e)) 119 120 connection_is_valid = False 121 for token in connection_tokens: 122 if token.lower() == common.UPGRADE_CONNECTION_TYPE.lower(): 123 connection_is_valid = True 124 break 125 if not connection_is_valid: 126 raise HandshakeException( 127 '%s header doesn\'t contain "%s"' % 128 (common.CONNECTION_HEADER, common.UPGRADE_CONNECTION_TYPE)) 129 130 def do_handshake(self): 131 self._request.ws_close_code = None 132 self._request.ws_close_reason = None 133 134 # Parsing. 135 136 check_request_line(self._request) 137 138 validate_mandatory_header(self._request, common.UPGRADE_HEADER, 139 common.WEBSOCKET_UPGRADE_TYPE) 140 141 self._validate_connection_header() 142 143 self._request.ws_resource = self._request.uri 144 145 unused_host = get_mandatory_header(self._request, common.HOST_HEADER) 146 147 self._request.ws_version = self._check_version() 148 149 try: 150 self._get_origin() 151 self._set_protocol() 152 self._parse_extensions() 153 154 # Key validation, response generation. 155 156 key = self._get_key() 157 accept = compute_accept(key) 158 self._logger.debug('%s: %r (%s)', 159 common.SEC_WEBSOCKET_ACCEPT_HEADER, accept, 160 util.hexify(base64.b64decode(accept))) 161 162 self._logger.debug('Protocol version is RFC 6455') 163 164 # Setup extension processors. 165 166 processors = [] 167 if self._request.ws_requested_extensions is not None: 168 for extension_request in self._request.ws_requested_extensions: 169 processor = get_extension_processor(extension_request) 170 # Unknown extension requests are just ignored. 171 if processor is not None: 172 processors.append(processor) 173 self._request.ws_extension_processors = processors 174 175 # List of extra headers. The extra handshake handler may add header 176 # data as name/value pairs to this list and pywebsocket appends 177 # them to the WebSocket handshake. 178 self._request.extra_headers = [] 179 180 # Extra handshake handler may modify/remove processors. 181 self._dispatcher.do_extra_handshake(self._request) 182 processors = [ 183 processor 184 for processor in self._request.ws_extension_processors 185 if processor is not None 186 ] 187 188 # Ask each processor if there are extensions on the request which 189 # cannot co-exist. When processor decided other processors cannot 190 # co-exist with it, the processor marks them (or itself) as 191 # "inactive". The first extension processor has the right to 192 # make the final call. 193 for processor in reversed(processors): 194 if processor.is_active(): 195 processor.check_consistency_with_other_processors( 196 processors) 197 processors = [ 198 processor for processor in processors if processor.is_active() 199 ] 200 201 accepted_extensions = [] 202 203 stream_options = StreamOptions() 204 205 for index, processor in enumerate(processors): 206 if not processor.is_active(): 207 continue 208 209 extension_response = processor.get_extension_response() 210 if extension_response is None: 211 # Rejected. 212 continue 213 214 accepted_extensions.append(extension_response) 215 216 processor.setup_stream_options(stream_options) 217 218 # Inactivate all of the following compression extensions. 219 for j in range(index + 1, len(processors)): 220 processors[j].set_active(False) 221 222 if len(accepted_extensions) > 0: 223 self._request.ws_extensions = accepted_extensions 224 self._logger.debug( 225 'Extensions accepted: %r', 226 list( 227 map(common.ExtensionParameter.name, 228 accepted_extensions))) 229 else: 230 self._request.ws_extensions = None 231 232 self._request.ws_stream = self._create_stream(stream_options) 233 234 if self._request.ws_requested_protocols is not None: 235 if self._request.ws_protocol is None: 236 raise HandshakeException( 237 'do_extra_handshake must choose one subprotocol from ' 238 'ws_requested_protocols and set it to ws_protocol') 239 validate_subprotocol(self._request.ws_protocol) 240 241 self._logger.debug('Subprotocol accepted: %r', 242 self._request.ws_protocol) 243 else: 244 if self._request.ws_protocol is not None: 245 raise HandshakeException( 246 'ws_protocol must be None when the client didn\'t ' 247 'request any subprotocol') 248 249 self._send_handshake(accept) 250 except HandshakeException as e: 251 if not e.status: 252 # Fallback to 400 bad request by default. 253 e.status = common.HTTP_STATUS_BAD_REQUEST 254 raise e 255 256 def _get_origin(self): 257 origin_header = common.ORIGIN_HEADER 258 origin = self._request.headers_in.get(origin_header) 259 if origin is None: 260 self._logger.debug('Client request does not have origin header') 261 self._request.ws_origin = origin 262 263 def _check_version(self): 264 version = get_mandatory_header(self._request, 265 common.SEC_WEBSOCKET_VERSION_HEADER) 266 if version == _VERSION_LATEST_STRING: 267 return _VERSION_LATEST 268 269 if version.find(',') >= 0: 270 raise HandshakeException( 271 'Multiple versions (%r) are not allowed for header %s' % 272 (version, common.SEC_WEBSOCKET_VERSION_HEADER), 273 status=common.HTTP_STATUS_BAD_REQUEST) 274 raise VersionException('Unsupported version %r for header %s' % 275 (version, common.SEC_WEBSOCKET_VERSION_HEADER), 276 supported_versions=', '.join( 277 map(str, _SUPPORTED_VERSIONS))) 278 279 def _set_protocol(self): 280 self._request.ws_protocol = None 281 # MOZILLA 282 self._request.sts = None 283 # /MOZILLA 284 285 protocol_header = self._request.headers_in.get( 286 common.SEC_WEBSOCKET_PROTOCOL_HEADER) 287 288 if protocol_header is None: 289 self._request.ws_requested_protocols = None 290 return 291 292 self._request.ws_requested_protocols = parse_token_list( 293 protocol_header) 294 self._logger.debug('Subprotocols requested: %r', 295 self._request.ws_requested_protocols) 296 297 def _parse_extensions(self): 298 extensions_header = self._request.headers_in.get( 299 common.SEC_WEBSOCKET_EXTENSIONS_HEADER) 300 if not extensions_header: 301 self._request.ws_requested_extensions = None 302 return 303 304 try: 305 self._request.ws_requested_extensions = common.parse_extensions( 306 extensions_header) 307 except common.ExtensionParsingException as e: 308 raise HandshakeException( 309 'Failed to parse Sec-WebSocket-Extensions header: %r' % e) 310 311 self._logger.debug( 312 'Extensions requested: %r', 313 list( 314 map(common.ExtensionParameter.name, 315 self._request.ws_requested_extensions))) 316 317 def _validate_key(self, key): 318 if key.find(',') >= 0: 319 raise HandshakeException('Request has multiple %s header lines or ' 320 'contains illegal character \',\': %r' % 321 (common.SEC_WEBSOCKET_KEY_HEADER, key)) 322 323 # Validate 324 key_is_valid = False 325 try: 326 # Validate key by quick regex match before parsing by base64 327 # module. Because base64 module skips invalid characters, we have 328 # to do this in advance to make this server strictly reject illegal 329 # keys. 330 if _SEC_WEBSOCKET_KEY_REGEX.match(key): 331 decoded_key = base64.b64decode(key) 332 if len(decoded_key) == 16: 333 key_is_valid = True 334 except TypeError as e: 335 pass 336 337 if not key_is_valid: 338 raise HandshakeException('Illegal value for header %s: %r' % 339 (common.SEC_WEBSOCKET_KEY_HEADER, key)) 340 341 return decoded_key 342 343 def _get_key(self): 344 key = get_mandatory_header(self._request, 345 common.SEC_WEBSOCKET_KEY_HEADER) 346 347 decoded_key = self._validate_key(key) 348 349 self._logger.debug('%s: %r (%s)', common.SEC_WEBSOCKET_KEY_HEADER, key, 350 util.hexify(decoded_key)) 351 352 return key.encode('UTF-8') 353 354 def _create_stream(self, stream_options): 355 return Stream(self._request, stream_options) 356 357 def _create_handshake_response(self, accept): 358 response = [] 359 360 response.append(u'HTTP/1.1 101 Switching Protocols\r\n') 361 362 # WebSocket headers 363 response.append( 364 format_header(common.UPGRADE_HEADER, 365 common.WEBSOCKET_UPGRADE_TYPE)) 366 response.append( 367 format_header(common.CONNECTION_HEADER, 368 common.UPGRADE_CONNECTION_TYPE)) 369 response.append( 370 format_header(common.SEC_WEBSOCKET_ACCEPT_HEADER, 371 accept.decode('UTF-8'))) 372 if self._request.ws_protocol is not None: 373 response.append( 374 format_header(common.SEC_WEBSOCKET_PROTOCOL_HEADER, 375 self._request.ws_protocol)) 376 if (self._request.ws_extensions is not None 377 and len(self._request.ws_extensions) != 0): 378 response.append( 379 format_header( 380 common.SEC_WEBSOCKET_EXTENSIONS_HEADER, 381 common.format_extensions(self._request.ws_extensions))) 382 # MOZILLA 383 if self._request.sts is not None: 384 response.append(format_header("Strict-Transport-Security", 385 self._request.sts)) 386 # /MOZILLA 387 388 # Headers not specific for WebSocket 389 for name, value in self._request.extra_headers: 390 response.append(format_header(name, value)) 391 392 response.append(u'\r\n') 393 394 return u''.join(response) 395 396 def _send_handshake(self, accept): 397 raw_response = self._create_handshake_response(accept) 398 self._request.connection.write(raw_response.encode('UTF-8')) 399 self._logger.debug('Sent server\'s opening handshake: %r', 400 raw_response) 401 402 403# vi:sts=4 sw=4 et 404