# Copyright 2012, Google Inc.
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
#     * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
#     * Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following disclaimer
# in the documentation and/or other materials provided with the
# distribution.
#     * Neither the name of Google Inc. nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from __future__ import absolute_import
from mod_pywebsocket import common
from mod_pywebsocket import util
from mod_pywebsocket.http_header_util import quote_if_necessary

# The list of available server side extension processor classes.
_available_processors = {}


class ExtensionProcessorInterface(object):
    """Base class for server-side WebSocket extension processors.

    A processor wraps a single extension offer (``request``) received from
    a client. Subclasses override ``name``,
    ``_get_extension_response_internal`` and
    ``_setup_stream_options_internal``; the public wrappers defined here
    skip that work once the processor has been deactivated.
    """

    def __init__(self, request):
        self._logger = util.get_class_logger(self)

        self._request = request
        # Processors start active. They are turned off either explicitly via
        # set_active(False) or when the offer is declined (see
        # get_extension_response).
        self._active = True

    def request(self):
        """Return the extension offer this processor was constructed with."""
        return self._request

    def name(self):
        """Return the extension name; the base class has none."""
        return None

    def check_consistency_with_other_processors(self, processors):
        """Hook for cross-processor validation; a no-op by default."""
        pass

    def set_active(self, active):
        self._active = active

    def is_active(self):
        return self._active

    def _get_extension_response_internal(self):
        """Build the response parameter, or return None to decline."""
        return None

    def get_extension_response(self):
        """Return the ExtensionParameter to send back, or None.

        Returns None without consulting the subclass when the processor is
        inactive. When the subclass declines the offer (returns None), the
        processor deactivates itself so setup_stream_options becomes a no-op.
        """
        if not self._active:
            self._logger.debug('Extension %s is deactivated', self.name())
            return None

        response = self._get_extension_response_internal()
        if response is None:
            self._active = False
        return response

    def _setup_stream_options_internal(self, stream_options):
        pass

    def setup_stream_options(self, stream_options):
        """Let an active processor install its filters on stream_options."""
        if self._active:
            self._setup_stream_options_internal(stream_options)


def _log_outgoing_compression_ratio(logger, original_bytes, filtered_bytes,
                                    average_ratio):
    """Log the compressed/original size ratio of one outgoing message.

    The ratio logged is ``filtered_bytes / original_bytes`` (smaller means
    better compression); ``inf`` is logged when ``original_bytes`` is 0.
    """
    # Print inf when ratio is not available.
    ratio = float('inf')
    if original_bytes != 0:
        ratio = float(filtered_bytes) / original_bytes

    # Lazy %-style args: the string is only formatted when DEBUG is enabled,
    # which matters because this runs once per message.
    logger.debug('Outgoing compression ratio: %f (average: %f)',
                 ratio, average_ratio)


def _log_incoming_compression_ratio(logger, received_bytes, filtered_bytes,
                                    average_ratio):
    """Log the received/decompressed size ratio of one incoming message.

    The ratio logged is ``received_bytes / filtered_bytes`` (smaller means
    better compression); ``inf`` is logged when ``filtered_bytes`` is 0.
    """
    # Print inf when ratio is not available.
    ratio = float('inf')
    if filtered_bytes != 0:
        ratio = float(received_bytes) / filtered_bytes

    # Lazy %-style args; see _log_outgoing_compression_ratio.
    logger.debug('Incoming compression ratio: %f (average: %f)',
                 ratio, average_ratio)


def _parse_window_bits(bits):
    """Return parsed integer value iff the given string conforms to the
    grammar of the window bits extension parameters.

    The value must be the canonical decimal spelling of an integer in the
    range [8, 15]: "8" is accepted, while "08", "+8" and "8.0" are not.

    Raises:
        ValueError: if ``bits`` is None or does not meet the rules above.
    """

    if bits is None:
        raise ValueError('Value is required')

    # For non integer values such as "10.0", ValueError will be raised.
    int_bits = int(bits)

    # The round-trip comparison rejects non-canonical spellings such as a
    # leading zero ("08"), an explicit sign or surrounding whitespace.
    if bits != str(int_bits) or int_bits < 8 or int_bits > 15:
        raise ValueError('Invalid value: %r' % bits)

    return int_bits


class _AverageRatioCalculator(object):
    """Stores total bytes of original and result data, and calculates average
    result / original ratio.
    """
    def __init__(self):
        self._total_original_bytes = 0
        self._total_result_bytes = 0

    def add_original_bytes(self, value):
        self._total_original_bytes += value

    def add_result_bytes(self, value):
        self._total_result_bytes += value

    def get_average_ratio(self):
        """Return total result bytes / total original bytes, or inf if no
        original bytes have been recorded yet."""
        if self._total_original_bytes != 0:
            return (float(self._total_result_bytes) /
                    self._total_original_bytes)
        else:
            return float('inf')


class PerMessageDeflateExtensionProcessor(ExtensionProcessorInterface):
    """permessage-deflate extension processor.

    Specification:
    http://tools.ietf.org/html/draft-ietf-hybi-permessage-compression-08
    """

    _SERVER_MAX_WINDOW_BITS_PARAM = 'server_max_window_bits'
    _SERVER_NO_CONTEXT_TAKEOVER_PARAM = 'server_no_context_takeover'
    _CLIENT_MAX_WINDOW_BITS_PARAM = 'client_max_window_bits'
    _CLIENT_NO_CONTEXT_TAKEOVER_PARAM = 'client_no_context_takeover'

    def __init__(self, request):
        """Construct PerMessageDeflateExtensionProcessor."""

        # The base constructor already sets self._logger for this instance.
        ExtensionProcessorInterface.__init__(self, request)

        # Server-side configuration echoed to the client in the response;
        # set via set_client_max_window_bits / set_client_no_context_takeover.
        self._preferred_client_max_window_bits = None
        self._client_no_context_takeover = False

    def name(self):
        # This method returns "deflate" (not "permessage-deflate") for
        # compatibility.
        return 'deflate'

    def _get_extension_response_internal(self):
        """Validate the client's offer and build the accepted response.

        Returns None (declining the offer) when the offer:
          - carries a parameter other than server_max_window_bits,
            server_no_context_takeover or client_max_window_bits;
          - has a server_max_window_bits value rejected by
            _parse_window_bits;
          - gives a value to server_no_context_takeover or
            client_max_window_bits; or
          - lacks client_max_window_bits while this processor has been
            configured to send one.
        On success, creates the framer used by setup_stream_options.
        """
        # NOTE(review): RFC 7692 also allows a client to offer
        # client_no_context_takeover; offers containing it are declined here.
        for name in self._request.get_parameter_names():
            if name not in [
                    self._SERVER_MAX_WINDOW_BITS_PARAM,
                    self._SERVER_NO_CONTEXT_TAKEOVER_PARAM,
                    self._CLIENT_MAX_WINDOW_BITS_PARAM
            ]:
                self._logger.debug('Unknown parameter: %r', name)
                return None

        server_max_window_bits = None
        if self._request.has_parameter(self._SERVER_MAX_WINDOW_BITS_PARAM):
            server_max_window_bits = self._request.get_parameter_value(
                self._SERVER_MAX_WINDOW_BITS_PARAM)
            try:
                server_max_window_bits = _parse_window_bits(
                    server_max_window_bits)
            except ValueError as e:
                self._logger.debug('Bad %s parameter: %r',
                                   self._SERVER_MAX_WINDOW_BITS_PARAM, e)
                return None

        server_no_context_takeover = self._request.has_parameter(
            self._SERVER_NO_CONTEXT_TAKEOVER_PARAM)
        if server_no_context_takeover:
            server_no_context_takeover_value = (
                self._request.get_parameter_value(
                    self._SERVER_NO_CONTEXT_TAKEOVER_PARAM))
            if server_no_context_takeover_value is not None:
                # Log the offending value itself, not the presence flag.
                self._logger.debug('%s parameter must not have a value: %r',
                                   self._SERVER_NO_CONTEXT_TAKEOVER_PARAM,
                                   server_no_context_takeover_value)
                return None

        # client_max_window_bits from a client indicates whether the client can
        # accept client_max_window_bits from a server or not.
        client_client_max_window_bits = self._request.has_parameter(
            self._CLIENT_MAX_WINDOW_BITS_PARAM)
        if client_client_max_window_bits:
            client_max_window_bits_value = self._request.get_parameter_value(
                self._CLIENT_MAX_WINDOW_BITS_PARAM)
            if client_max_window_bits_value is not None:
                # Log the offending value itself, not the presence flag.
                self._logger.debug(
                    '%s parameter must not have a value in a '
                    'client\'s opening handshake: %r',
                    self._CLIENT_MAX_WINDOW_BITS_PARAM,
                    client_max_window_bits_value)
                return None

        # NOTE(review): these two objects are not referenced anywhere else in
        # this module; the framer below owns its own deflater/inflater.
        self._rfc1979_deflater = util._RFC1979Deflater(
            server_max_window_bits, server_no_context_takeover)

        # Note that we prepare for incoming messages compressed with window
        # bits up to 15 regardless of the client_max_window_bits value to be
        # sent to the client.
        self._rfc1979_inflater = util._RFC1979Inflater()

        self._framer = _PerMessageDeflateFramer(server_max_window_bits,
                                                server_no_context_takeover)
        self._framer.set_bfinal(False)
        self._framer.set_compress_outgoing_enabled(True)

        response = common.ExtensionParameter(self._request.name())

        if server_max_window_bits is not None:
            response.add_parameter(self._SERVER_MAX_WINDOW_BITS_PARAM,
                                   str(server_max_window_bits))

        if server_no_context_takeover:
            response.add_parameter(self._SERVER_NO_CONTEXT_TAKEOVER_PARAM,
                                   None)

        if self._preferred_client_max_window_bits is not None:
            if not client_client_max_window_bits:
                self._logger.debug(
                    'Processor is configured to use %s but '
                    'the client cannot accept it',
                    self._CLIENT_MAX_WINDOW_BITS_PARAM)
                return None
            response.add_parameter(self._CLIENT_MAX_WINDOW_BITS_PARAM,
                                   str(self._preferred_client_max_window_bits))

        if self._client_no_context_takeover:
            response.add_parameter(self._CLIENT_NO_CONTEXT_TAKEOVER_PARAM,
                                   None)

        self._logger.debug('Enable %s extension ('
                           'request: server_max_window_bits=%s; '
                           'server_no_context_takeover=%r, '
                           'response: client_max_window_bits=%s; '
                           'client_no_context_takeover=%r)',
                           self._request.name(), server_max_window_bits,
                           server_no_context_takeover,
                           self._preferred_client_max_window_bits,
                           self._client_no_context_takeover)

        return response

    def _setup_stream_options_internal(self, stream_options):
        # self._framer exists only after a successful
        # _get_extension_response_internal.
        self._framer.setup_stream_options(stream_options)

    def set_client_max_window_bits(self, value):
        """If this option is specified, this class adds the
        client_max_window_bits extension parameter (with ``str(value)``) to
        the handshake response, but doesn't reduce the LZ77 sliding window
        size of its inflater. I.e., you can use this for testing client
        implementation but cannot reduce memory usage of this class.

        If a value other than None has been set and an offer without the
        client_max_window_bits extension parameter is received, this
        processor declines the offer.
        """

        self._preferred_client_max_window_bits = value

    def set_client_no_context_takeover(self, value):
        """If this option is specified, this class adds the
        client_no_context_takeover extension parameter to the handshake
        response, but doesn't reset inflater for each message. I.e., you can
        use this for testing client implementation but cannot reduce memory
        usage of this class.
        """

        self._client_no_context_takeover = value

    # The following three methods forward to the framer, which is created
    # only once an offer has been accepted.

    def set_bfinal(self, value):
        self._framer.set_bfinal(value)

    def enable_outgoing_compression(self):
        self._framer.set_compress_outgoing_enabled(True)

    def disable_outgoing_compression(self):
        self._framer.set_compress_outgoing_enabled(False)


class _PerMessageDeflateFramer(object):
    """A framer for extensions with per-message DEFLATE feature.

    Compresses outgoing message payloads and marks the first frame of each
    compressed message with RSV1. On the incoming side, it inflates messages
    whose first non-control frame carries RSV1.
    """
    def __init__(self, deflate_max_window_bits, deflate_no_context_takeover):
        self._logger = util.get_class_logger(self)

        self._rfc1979_deflater = util._RFC1979Deflater(
            deflate_max_window_bits, deflate_no_context_takeover)

        self._rfc1979_inflater = util._RFC1979Inflater()

        # Passed as ``bfinal`` to the deflater for every outgoing message.
        self._bfinal = False

        # When False, outgoing messages are only UTF-8 encoded (if text).
        self._compress_outgoing_enabled = False

        # True if a message is fragmented and compression is ongoing.
        self._compress_ongoing = False

        # Calculates
        #   (Total bytes sent to the network after applying this filter) /
        #   (Total outgoing bytes supplied to this filter)
        # i.e. compressed size over original size; smaller is better.
        self._outgoing_average_ratio_calculator = _AverageRatioCalculator()

        # Calculates
        #   (Total bytes received from the network) /
        #   (Total incoming bytes obtained after applying this filter)
        self._incoming_average_ratio_calculator = _AverageRatioCalculator()

    def set_bfinal(self, value):
        self._bfinal = value

    def set_compress_outgoing_enabled(self, value):
        self._compress_outgoing_enabled = value

    def _process_incoming_message(self, message, decompress):
        """Inflate ``message`` when ``decompress`` is set; else pass it on."""
        if not decompress:
            return message

        received_payload_size = len(message)
        self._incoming_average_ratio_calculator.add_result_bytes(
            received_payload_size)

        message = self._rfc1979_inflater.filter(message)

        filtered_payload_size = len(message)
        self._incoming_average_ratio_calculator.add_original_bytes(
            filtered_payload_size)

        _log_incoming_compression_ratio(
            self._logger, received_payload_size, filtered_payload_size,
            self._incoming_average_ratio_calculator.get_average_ratio())

        return message

    def _process_outgoing_message(self, message, end, binary):
        """Encode (for text) and optionally deflate one outgoing fragment.

        ``end`` is True for the final fragment of a message. Text is UTF-8
        encoded here because setup_stream_options turns off the stream's own
        text encoding.
        """
        if not binary:
            message = message.encode('utf-8')

        if not self._compress_outgoing_enabled:
            return message

        original_payload_size = len(message)
        self._outgoing_average_ratio_calculator.add_original_bytes(
            original_payload_size)

        message = self._rfc1979_deflater.filter(message,
                                                end=end,
                                                bfinal=self._bfinal)

        filtered_payload_size = len(message)
        self._outgoing_average_ratio_calculator.add_result_bytes(
            filtered_payload_size)

        _log_outgoing_compression_ratio(
            self._logger, original_payload_size, filtered_payload_size,
            self._outgoing_average_ratio_calculator.get_average_ratio())

        # Only the first frame of a compressed message gets RSV1, so request
        # the bit only when no compressed message is already in progress.
        # Requires setup_stream_options to have created the frame filter.
        if not self._compress_ongoing:
            self._outgoing_frame_filter.set_compression_bit()
        self._compress_ongoing = not end
        return message

    def _process_incoming_frame(self, frame):
        # RSV1 on a data frame marks a compressed message: ask the message
        # filter to inflate it and clear the bit for the layers below.
        if frame.rsv1 == 1 and not common.is_control_opcode(frame.opcode):
            self._incoming_message_filter.decompress_next_message()
            frame.rsv1 = 0

    def _process_outgoing_frame(self, frame, compression_bit):
        # Control frames are never marked as compressed.
        if not compression_bit or common.is_control_opcode(frame.opcode):
            return

        frame.rsv1 = 1

    def setup_stream_options(self, stream_options):
        """Create this framer's message and frame filters and register them
        on ``stream_options``.

        Also sets ``encode_text_message_to_utf8`` to False because the
        outgoing message filter performs the UTF-8 encoding itself.
        """
        class _OutgoingMessageFilter(object):
            def __init__(self, parent):
                self._parent = parent

            def filter(self, message, end=True, binary=False):
                return self._parent._process_outgoing_message(
                    message, end, binary)

        class _IncomingMessageFilter(object):
            def __init__(self, parent):
                self._parent = parent
                # Set by the incoming frame filter; consumed by the next
                # filter() call.
                self._decompress_next_message = False

            def decompress_next_message(self):
                self._decompress_next_message = True

            def filter(self, message):
                message = self._parent._process_incoming_message(
                    message, self._decompress_next_message)
                self._decompress_next_message = False
                return message

        self._outgoing_message_filter = _OutgoingMessageFilter(self)
        self._incoming_message_filter = _IncomingMessageFilter(self)
        stream_options.outgoing_message_filters.append(
            self._outgoing_message_filter)
        stream_options.incoming_message_filters.append(
            self._incoming_message_filter)

        class _OutgoingFrameFilter(object):
            def __init__(self, parent):
                self._parent = parent
                # One-shot flag set by the outgoing message path and cleared
                # after the next frame is filtered.
                self._set_compression_bit = False

            def set_compression_bit(self):
                self._set_compression_bit = True

            def filter(self, frame):
                self._parent._process_outgoing_frame(frame,
                                                     self._set_compression_bit)
                self._set_compression_bit = False

        class _IncomingFrameFilter(object):
            def __init__(self, parent):
                self._parent = parent

            def filter(self, frame):
                self._parent._process_incoming_frame(frame)

        self._outgoing_frame_filter = _OutgoingFrameFilter(self)
        self._incoming_frame_filter = _IncomingFrameFilter(self)
        stream_options.outgoing_frame_filters.append(
            self._outgoing_frame_filter)
        stream_options.incoming_frame_filters.append(
            self._incoming_frame_filter)

        stream_options.encode_text_message_to_utf8 = False


_available_processors[common.PERMESSAGE_DEFLATE_EXTENSION] = (
    PerMessageDeflateExtensionProcessor)


def get_extension_processor(extension_request):
    """Given an ExtensionParameter representing an extension offer received
    from a client, configures and returns an instance of the corresponding
    extension processor class.

    Returns None when no processor is registered for the offered name.
    """

    processor_class = _available_processors.get(extension_request.name())
    if processor_class is None:
        return None
    return processor_class(extension_request)


# vi:sts=4 sw=4 et