1# Copyright 2012, Google Inc. 2# All rights reserved. 3# 4# Redistribution and use in source and binary forms, with or without 5# modification, are permitted provided that the following conditions are 6# met: 7# 8# * Redistributions of source code must retain the above copyright 9# notice, this list of conditions and the following disclaimer. 10# * Redistributions in binary form must reproduce the above 11# copyright notice, this list of conditions and the following disclaimer 12# in the documentation and/or other materials provided with the 13# distribution. 14# * Neither the name of Google Inc. nor the names of its 15# contributors may be used to endorse or promote products derived from 16# this software without specific prior written permission. 17# 18# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS 19# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT 20# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR 21# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT 22# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, 23# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT 24# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 28# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 29 30from __future__ import absolute_import 31from mod_pywebsocket import common 32from mod_pywebsocket import util 33from mod_pywebsocket.http_header_util import quote_if_necessary 34 35# The list of available server side extension processor classes. 
# Registry of available server side extension processor classes, keyed by
# the extension name found in a client's offer.
_available_processors = {}


class ExtensionProcessorInterface(object):
    """Base class for server side extension processors.

    A processor is built from one extension offer (``request``) received from
    a client. Subclasses override ``name``, ``_get_extension_response_internal``
    and ``_setup_stream_options_internal``.
    """

    def __init__(self, request):
        self._logger = util.get_class_logger(self)

        self._request = request
        # Cleared via set_active(False) or when the offer is declined.
        self._active = True

    def request(self):
        """Return the extension offer this processor was built from."""
        return self._request

    def name(self):
        """Return the extension name; the base class has none."""
        return None

    def check_consistency_with_other_processors(self, processors):
        """Hook to detect conflicts with other processors; no-op here."""
        pass

    def set_active(self, active):
        self._active = active

    def is_active(self):
        return self._active

    def _get_extension_response_internal(self):
        """Return the response ExtensionParameter, or None to decline."""
        return None

    def get_extension_response(self):
        """Return the ExtensionParameter to send back, or None.

        Returns None without consulting the subclass when this processor has
        been deactivated. If the subclass declines the offer (returns None),
        the processor deactivates itself so that setup_stream_options()
        becomes a no-op.
        """
        if not self._active:
            self._logger.debug('Extension %s is deactivated', self.name())
            return None

        response = self._get_extension_response_internal()
        if response is None:
            self._active = False
        return response

    def _setup_stream_options_internal(self, stream_options):
        pass

    def setup_stream_options(self, stream_options):
        """Install this processor's filters into stream_options if active."""
        if self._active:
            self._setup_stream_options_internal(stream_options)


def _log_outgoing_compression_ratio(logger, original_bytes, filtered_bytes,
                                    average_ratio):
    """Log the filtered / original size ratio of one outgoing message.

    The ratio is logged as inf when original_bytes is 0.
    """
    ratio = float('inf')
    if original_bytes != 0:
        ratio = float(filtered_bytes) / original_bytes

    # Lazy arguments: formatting only happens if debug logging is enabled.
    logger.debug('Outgoing compression ratio: %f (average: %f)', ratio,
                 average_ratio)


def _log_incoming_compression_ratio(logger, received_bytes, filtered_bytes,
                                    average_ratio):
    """Log the received / filtered size ratio of one incoming message.

    The ratio is logged as inf when filtered_bytes is 0.
    """
    ratio = float('inf')
    if filtered_bytes != 0:
        ratio = float(received_bytes) / filtered_bytes

    logger.debug('Incoming compression ratio: %f (average: %f)', ratio,
                 average_ratio)


def _parse_window_bits(bits):
    """Parse a *_max_window_bits extension parameter value.

    Args:
        bits: the parameter value string, or None if it had no value.

    Returns:
        The value as an int in the range 8..15.

    Raises:
        ValueError: bits is None, is not a canonical decimal integer (e.g.
            "10.0", "08", "+9"), or is outside 8..15.
    """
    if bits is None:
        raise ValueError('Value is required')

    # For non integer values such as "10.0", ValueError will be raised.
    int_bits = int(bits)

    # The round-trip comparison rejects non-canonical spellings such as a
    # leading zero ("08").
    if bits != str(int_bits) or int_bits < 8 or int_bits > 15:
        raise ValueError('Invalid value: %r' % bits)

    return int_bits


class _AverageRatioCalculator(object):
    """Stores total bytes of original and result data, and calculates average
    result / original ratio.
    """
    def __init__(self):
        self._total_original_bytes = 0
        self._total_result_bytes = 0

    def add_original_bytes(self, value):
        self._total_original_bytes += value

    def add_result_bytes(self, value):
        self._total_result_bytes += value

    def get_average_ratio(self):
        """Return total result / total original bytes (inf if no input)."""
        if self._total_original_bytes != 0:
            return (float(self._total_result_bytes) /
                    self._total_original_bytes)
        else:
            return float('inf')


class PerMessageDeflateExtensionProcessor(ExtensionProcessorInterface):
    """permessage-deflate extension processor.

    Specification:
    http://tools.ietf.org/html/draft-ietf-hybi-permessage-compression-08
    (published as RFC 7692)
    """

    _SERVER_MAX_WINDOW_BITS_PARAM = 'server_max_window_bits'
    _SERVER_NO_CONTEXT_TAKEOVER_PARAM = 'server_no_context_takeover'
    _CLIENT_MAX_WINDOW_BITS_PARAM = 'client_max_window_bits'
    _CLIENT_NO_CONTEXT_TAKEOVER_PARAM = 'client_no_context_takeover'

    # Parameters a client may put in an offer. Any other parameter makes
    # this processor decline the offer.
    _OFFER_PARAMS = (
        _SERVER_MAX_WINDOW_BITS_PARAM,
        _SERVER_NO_CONTEXT_TAKEOVER_PARAM,
        _CLIENT_MAX_WINDOW_BITS_PARAM,
        _CLIENT_NO_CONTEXT_TAKEOVER_PARAM,
    )

    def __init__(self, request):
        """Construct PerMessageDeflateExtensionProcessor."""

        ExtensionProcessorInterface.__init__(self, request)

        # Options that shape the handshake response; they only take effect
        # if set before the offer is processed.
        self._preferred_client_max_window_bits = None
        self._client_no_context_takeover = False

        # Outgoing options. They are recorded here so they can be set before
        # negotiation, and are applied to the framer when it is created.
        self._bfinal = False
        self._compress_outgoing_enabled = True

        # Created only when an offer is accepted.
        self._framer = None

    def name(self):
        # This method returns "deflate" (not "permessage-deflate") for
        # compatibility.
        return 'deflate'

    def _has_valueless_parameter(self, name):
        """Return whether the offer contains parameter `name`.

        Raises:
            ValueError: the parameter is present but carries a value, which
                is not allowed for it in a client's offer.
        """
        if not self._request.has_parameter(name):
            return False
        value = self._request.get_parameter_value(name)
        if value is not None:
            raise ValueError(
                '%s parameter must not have a value in a '
                'client\'s opening handshake: %r' % (name, value))
        return True

    def _get_extension_response_internal(self):
        """Validate the offer; return the response parameter or None."""
        for name in self._request.get_parameter_names():
            if name not in self._OFFER_PARAMS:
                self._logger.debug('Unknown parameter: %r', name)
                return None

        server_max_window_bits = None
        if self._request.has_parameter(self._SERVER_MAX_WINDOW_BITS_PARAM):
            raw_value = self._request.get_parameter_value(
                self._SERVER_MAX_WINDOW_BITS_PARAM)
            try:
                server_max_window_bits = _parse_window_bits(raw_value)
            except ValueError as e:
                self._logger.debug('Bad %s parameter: %r',
                                   self._SERVER_MAX_WINDOW_BITS_PARAM, e)
                return None

        try:
            server_no_context_takeover = self._has_valueless_parameter(
                self._SERVER_NO_CONTEXT_TAKEOVER_PARAM)
            # client_max_window_bits from a client indicates whether the
            # client can accept client_max_window_bits from a server or not.
            client_client_max_window_bits = self._has_valueless_parameter(
                self._CLIENT_MAX_WINDOW_BITS_PARAM)
            # client_no_context_takeover from a client is only a hint that the
            # client will not reuse its compression context. It does not
            # change how we decompress, so it only needs validating.
            self._has_valueless_parameter(
                self._CLIENT_NO_CONTEXT_TAKEOVER_PARAM)
        except ValueError as e:
            self._logger.debug('%s', e)
            return None

        if (self._preferred_client_max_window_bits is not None
                and not client_client_max_window_bits):
            self._logger.debug(
                'Processor is configured to use %s but '
                'the client cannot accept it',
                self._CLIENT_MAX_WINDOW_BITS_PARAM)
            return None

        # All checks passed; only now allocate the compression state. Note
        # that the framer's inflater is prepared for incoming messages
        # compressed with window bits up to 15 regardless of the
        # client_max_window_bits value sent to the client.
        self._framer = _PerMessageDeflateFramer(server_max_window_bits,
                                                server_no_context_takeover)
        self._framer.set_bfinal(self._bfinal)
        self._framer.set_compress_outgoing_enabled(
            self._compress_outgoing_enabled)

        response = common.ExtensionParameter(self._request.name())

        if server_max_window_bits is not None:
            response.add_parameter(self._SERVER_MAX_WINDOW_BITS_PARAM,
                                   str(server_max_window_bits))

        if server_no_context_takeover:
            response.add_parameter(self._SERVER_NO_CONTEXT_TAKEOVER_PARAM,
                                   None)

        if self._preferred_client_max_window_bits is not None:
            response.add_parameter(self._CLIENT_MAX_WINDOW_BITS_PARAM,
                                   str(self._preferred_client_max_window_bits))

        if self._client_no_context_takeover:
            response.add_parameter(self._CLIENT_NO_CONTEXT_TAKEOVER_PARAM,
                                   None)

        self._logger.debug(
            'Enable %s extension ('
            'request: server_max_window_bits=%s; '
            'server_no_context_takeover=%r, '
            'response: client_max_window_bits=%s; '
            'client_no_context_takeover=%r)', self._request.name(),
            server_max_window_bits, server_no_context_takeover,
            self._preferred_client_max_window_bits,
            self._client_no_context_takeover)

        return response

    def _setup_stream_options_internal(self, stream_options):
        self._framer.setup_stream_options(stream_options)

    def set_client_max_window_bits(self, value):
        """Ask the client to use at most `value` LZ77 window bits.

        When value is not None, the client_max_window_bits extension parameter
        with this value is added to the handshake response. The inflater is
        NOT shrunk (it still accepts windows up to 15 bits), so this is useful
        for testing client implementations but does not reduce memory usage
        of this class.

        While this is set, an offer without the client_max_window_bits
        parameter is declined, because the client cannot accept it.

        Must be called before the offer is processed to take effect.
        """

        self._preferred_client_max_window_bits = value

    def set_client_no_context_takeover(self, value):
        """If this option is specified, this class adds the
        client_no_context_takeover extension parameter to the handshake
        response, but doesn't reset inflater for each message. I.e., you can
        use this for testing client implementation but cannot reduce memory
        usage of this class.

        Must be called before the offer is processed to take effect.
        """

        self._client_no_context_takeover = value

    def set_bfinal(self, value):
        """Set the bfinal flag passed to the deflater for outgoing data."""
        self._bfinal = value
        if self._framer is not None:
            self._framer.set_bfinal(value)

    def enable_outgoing_compression(self):
        self._set_compress_outgoing_enabled(True)

    def disable_outgoing_compression(self):
        self._set_compress_outgoing_enabled(False)

    def _set_compress_outgoing_enabled(self, value):
        """Record the outgoing compression switch and forward it if possible."""
        self._compress_outgoing_enabled = value
        if self._framer is not None:
            self._framer.set_compress_outgoing_enabled(value)


class _PerMessageDeflateFramer(object):
    """A framer for extensions with per-message DEFLATE feature."""
    def __init__(self, deflate_max_window_bits, deflate_no_context_takeover):
        self._logger = util.get_class_logger(self)

        self._rfc1979_deflater = util._RFC1979Deflater(
            deflate_max_window_bits, deflate_no_context_takeover)

        self._rfc1979_inflater = util._RFC1979Inflater()

        self._bfinal = False

        self._compress_outgoing_enabled = False

        # True if a message is fragmented and compression is ongoing.
        self._compress_ongoing = False

        # Calculates
        #     (Total bytes sent to the network after applying this filter) /
        #     (Total outgoing bytes supplied to this filter)
        self._outgoing_average_ratio_calculator = _AverageRatioCalculator()

        # Calculates
        #     (Total bytes received from the network) /
        #     (Total incoming bytes obtained after applying this filter)
        self._incoming_average_ratio_calculator = _AverageRatioCalculator()

    def set_bfinal(self, value):
        self._bfinal = value

    def set_compress_outgoing_enabled(self, value):
        self._compress_outgoing_enabled = value

    def _process_incoming_message(self, message, decompress):
        """Inflate `message` if `decompress` is set, tracking the ratio."""
        if not decompress:
            return message

        received_payload_size = len(message)
        self._incoming_average_ratio_calculator.add_result_bytes(
            received_payload_size)

        message = self._rfc1979_inflater.filter(message)

        filtered_payload_size = len(message)
        self._incoming_average_ratio_calculator.add_original_bytes(
            filtered_payload_size)

        _log_incoming_compression_ratio(
            self._logger, received_payload_size, filtered_payload_size,
            self._incoming_average_ratio_calculator.get_average_ratio())

        return message

    def _process_outgoing_message(self, message, end, binary):
        """UTF-8 encode text, then deflate it if outgoing compression is on.

        The first fragment of each compressed message asks the outgoing
        frame filter to set RSV1. `end` marks the final fragment.
        """
        if not binary:
            message = message.encode('utf-8')

        if not self._compress_outgoing_enabled:
            return message

        original_payload_size = len(message)
        self._outgoing_average_ratio_calculator.add_original_bytes(
            original_payload_size)

        message = self._rfc1979_deflater.filter(message,
                                                end=end,
                                                bfinal=self._bfinal)

        filtered_payload_size = len(message)
        self._outgoing_average_ratio_calculator.add_result_bytes(
            filtered_payload_size)

        _log_outgoing_compression_ratio(
            self._logger, original_payload_size, filtered_payload_size,
            self._outgoing_average_ratio_calculator.get_average_ratio())

        if not self._compress_ongoing:
            self._outgoing_frame_filter.set_compression_bit()
        self._compress_ongoing = not end
        return message

    def _process_incoming_frame(self, frame):
        # RSV1 on a non-control frame marks a compressed message: have the
        # message filter inflate it and clear the bit.
        if frame.rsv1 == 1 and not common.is_control_opcode(frame.opcode):
            self._incoming_message_filter.decompress_next_message()
            frame.rsv1 = 0

    def _process_outgoing_frame(self, frame, compression_bit):
        if (not compression_bit or common.is_control_opcode(frame.opcode)):
            return

        frame.rsv1 = 1

    def setup_stream_options(self, stream_options):
        """Creates filters and sets them to the StreamOptions."""
        class _OutgoingMessageFilter(object):
            def __init__(self, parent):
                self._parent = parent

            def filter(self, message, end=True, binary=False):
                return self._parent._process_outgoing_message(
                    message, end, binary)

        class _IncomingMessageFilter(object):
            def __init__(self, parent):
                self._parent = parent
                self._decompress_next_message = False

            def decompress_next_message(self):
                self._decompress_next_message = True

            def filter(self, message):
                message = self._parent._process_incoming_message(
                    message, self._decompress_next_message)
                self._decompress_next_message = False
                return message

        self._outgoing_message_filter = _OutgoingMessageFilter(self)
        self._incoming_message_filter = _IncomingMessageFilter(self)
        stream_options.outgoing_message_filters.append(
            self._outgoing_message_filter)
        stream_options.incoming_message_filters.append(
            self._incoming_message_filter)

        class _OutgoingFrameFilter(object):
            def __init__(self, parent):
                self._parent = parent
                self._set_compression_bit = False

            def set_compression_bit(self):
                self._set_compression_bit = True

            def filter(self, frame):
                self._parent._process_outgoing_frame(frame,
                                                     self._set_compression_bit)
                self._set_compression_bit = False

        class _IncomingFrameFilter(object):
            def __init__(self, parent):
                self._parent = parent

            def filter(self, frame):
                self._parent._process_incoming_frame(frame)

        self._outgoing_frame_filter = _OutgoingFrameFilter(self)
        self._incoming_frame_filter = _IncomingFrameFilter(self)
        stream_options.outgoing_frame_filters.append(
            self._outgoing_frame_filter)
        stream_options.incoming_frame_filters.append(
            self._incoming_frame_filter)

        # _process_outgoing_message already UTF-8 encodes text messages.
        stream_options.encode_text_message_to_utf8 = False


_available_processors[common.PERMESSAGE_DEFLATE_EXTENSION] = (
    PerMessageDeflateExtensionProcessor)


def get_extension_processor(extension_request):
    """Given an ExtensionParameter representing an extension offer received
    from a client, configures and returns an instance of the corresponding
    extension processor class, or None if the extension is not supported.
    """

    processor_class = _available_processors.get(extension_request.name())
    if processor_class is None:
        return None
    return processor_class(extension_request)