1# Copyright 2012, Google Inc.
2# All rights reserved.
3#
4# Redistribution and use in source and binary forms, with or without
5# modification, are permitted provided that the following conditions are
6# met:
7#
8#     * Redistributions of source code must retain the above copyright
9# notice, this list of conditions and the following disclaimer.
10#     * Redistributions in binary form must reproduce the above
11# copyright notice, this list of conditions and the following disclaimer
12# in the documentation and/or other materials provided with the
13# distribution.
14#     * Neither the name of Google Inc. nor the names of its
15# contributors may be used to endorse or promote products derived from
16# this software without specific prior written permission.
17#
18# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
30from __future__ import absolute_import
31from mod_pywebsocket import common
32from mod_pywebsocket import util
33from mod_pywebsocket.http_header_util import quote_if_necessary
34
# Registry of the available server side extension processor classes. The
# dictionary is keyed by the value an offer's name() is looked up with in
# get_extension_processor(); each value is the processor class to instantiate.
_available_processors = {}
37
38
class ExtensionProcessorInterface(object):
    """Base class of the server side extension processors.

    A processor wraps one extension offer and carries an "active" flag.
    Subclasses override name(), _get_extension_response_internal() and
    _setup_stream_options_internal(); the public wrappers defined here take
    care of honouring the flag.
    """
    def __init__(self, request):
        """Remember the offer and start in the active state."""
        self._logger = util.get_class_logger(self)

        self._request = request
        self._active = True

    def request(self):
        """Return the offer this processor was constructed with."""
        return self._request

    def name(self):
        """Return the name of the extension; the base class has none."""
        return None

    def check_consistency_with_other_processors(self, processors):
        """Hook for cross-processor checks; does nothing by default."""

    def set_active(self, active):
        """Turn this processor on or off."""
        self._active = active

    def is_active(self):
        """Tell whether this processor is currently turned on."""
        return self._active

    def _get_extension_response_internal(self):
        """Build the response for the offer; None means "decline"."""
        return None

    def get_extension_response(self):
        """Return the response for the offer, or None.

        An inactive processor never builds a response. An active processor
        that ends up with no response deactivates itself.
        """
        if self._active:
            response = self._get_extension_response_internal()
            if response is None:
                self._active = False
            return response

        self._logger.debug('Extension %s is deactivated', self.name())
        return None

    def _setup_stream_options_internal(self, stream_options):
        """Hook that configures stream_options; does nothing by default."""

    def setup_stream_options(self, stream_options):
        """Configure stream_options, but only while the processor is active."""
        if not self._active:
            return
        self._setup_stream_options_internal(stream_options)
80
81
def _log_outgoing_compression_ratio(logger, original_bytes, filtered_bytes,
                                    average_ratio):
    """Log the compression ratio of one outgoing message at debug level.

    The ratio is filtered_bytes / original_bytes, i.e. the size after
    filtering divided by the size before it. average_ratio is reported next
    to it unchanged. logger must offer debug(format, *args) like
    logging.Logger does.
    """
    # Print inf when ratio is not available.
    if original_bytes != 0:
        ratio = float(filtered_bytes) / original_bytes
    else:
        ratio = float('inf')

    # Hand the values over as logging arguments so that the message is only
    # formatted when debug output is actually enabled.
    logger.debug('Outgoing compression ratio: %f (average: %f)', ratio,
                 average_ratio)
91
92
def _log_incoming_compression_ratio(logger, received_bytes, filtered_bytes,
                                    average_ratio):
    """Log the compression ratio of one incoming message at debug level.

    The ratio is received_bytes / filtered_bytes, i.e. the size as received
    divided by the size after filtering. average_ratio is reported next to
    it unchanged. logger must offer debug(format, *args) like logging.Logger
    does.
    """
    # Print inf when ratio is not available.
    if filtered_bytes != 0:
        ratio = float(received_bytes) / filtered_bytes
    else:
        ratio = float('inf')

    # Hand the values over as logging arguments so that the message is only
    # formatted when debug output is actually enabled.
    logger.debug('Incoming compression ratio: %f (average: %f)', ratio,
                 average_ratio)
102
103
def _parse_window_bits(bits):
    """Convert a window bits extension parameter value to an integer.

    The value is accepted only when it is the canonical decimal spelling of
    an integer from 8 to 15 inclusive. A missing value (None) or anything
    else raises ValueError.
    """

    if bits is None:
        raise ValueError('Value is required')

    # int() itself raises ValueError for non integer text such as "10.0".
    parsed = int(bits)

    # Comparing with the canonical spelling rejects leading zeros ("08").
    is_canonical = bits == str(parsed)
    if is_canonical and 8 <= parsed <= 15:
        return parsed

    raise ValueError('Invalid value: %r' % bits)
120
121
class _AverageRatioCalculator(object):
    """Keeps running byte totals for original and result data and reports
    the overall result / original ratio.
    """
    def __init__(self):
        self._total_original_bytes = 0
        self._total_result_bytes = 0

    def add_original_bytes(self, value):
        """Add value to the total of original bytes."""
        self._total_original_bytes += value

    def add_result_bytes(self, value):
        """Add value to the total of result bytes."""
        self._total_result_bytes += value

    def get_average_ratio(self):
        """Return total result bytes divided by total original bytes.

        While no original bytes have been recorded the ratio is undefined
        and inf is returned instead.
        """
        original = self._total_original_bytes
        if original == 0:
            return float('inf')
        return float(self._total_result_bytes) / original
142
143
class PerMessageDeflateExtensionProcessor(ExtensionProcessorInterface):
    """permessage-deflate extension processor.

    Validates the parameters of a client's offer, builds the extension
    response for the handshake and creates the _PerMessageDeflateFramer that
    is later hooked into the stream options.

    Specification:
    http://tools.ietf.org/html/draft-ietf-hybi-permessage-compression-08
    """

    _SERVER_MAX_WINDOW_BITS_PARAM = 'server_max_window_bits'
    _SERVER_NO_CONTEXT_TAKEOVER_PARAM = 'server_no_context_takeover'
    _CLIENT_MAX_WINDOW_BITS_PARAM = 'client_max_window_bits'
    _CLIENT_NO_CONTEXT_TAKEOVER_PARAM = 'client_no_context_takeover'

    def __init__(self, request):
        """Construct PerMessageDeflateExtensionProcessor.

        request is the extension offer this processor answers.
        """

        ExtensionProcessorInterface.__init__(self, request)
        self._logger = util.get_class_logger(self)

        # Parameters to add to the response. They stay unset unless
        # set_client_max_window_bits() / set_client_no_context_takeover() is
        # called before the response is built.
        self._preferred_client_max_window_bits = None
        self._client_no_context_takeover = False

    def name(self):
        # This method returns "deflate" (not "permessage-deflate") for
        # compatibility.
        return 'deflate'

    def _get_extension_response_internal(self):
        """Validate the offer and build the response for it.

        Returns the common.ExtensionParameter to send back. Returns None,
        i.e. declines the offer, when

        - the offer has a parameter other than server_max_window_bits,
          server_no_context_takeover and client_max_window_bits,
        - server_max_window_bits is not a valid window bits value,
        - server_no_context_takeover or client_max_window_bits comes with a
          value, or
        - this processor is configured to send client_max_window_bits but
          the offer doesn't contain that parameter.
        """
        accepted_params = (
            self._SERVER_MAX_WINDOW_BITS_PARAM,
            self._SERVER_NO_CONTEXT_TAKEOVER_PARAM,
            self._CLIENT_MAX_WINDOW_BITS_PARAM,
        )
        for name in self._request.get_parameter_names():
            if name not in accepted_params:
                self._logger.debug('Unknown parameter: %r', name)
                return None

        server_max_window_bits = None
        if self._request.has_parameter(self._SERVER_MAX_WINDOW_BITS_PARAM):
            server_max_window_bits = self._request.get_parameter_value(
                self._SERVER_MAX_WINDOW_BITS_PARAM)
            try:
                server_max_window_bits = _parse_window_bits(
                    server_max_window_bits)
            except ValueError as e:
                self._logger.debug('Bad %s parameter: %r',
                                   self._SERVER_MAX_WINDOW_BITS_PARAM, e)
                return None

        server_no_context_takeover = self._request.has_parameter(
            self._SERVER_NO_CONTEXT_TAKEOVER_PARAM)
        if server_no_context_takeover:
            unexpected_value = self._request.get_parameter_value(
                self._SERVER_NO_CONTEXT_TAKEOVER_PARAM)
            if unexpected_value is not None:
                # Report the value the client sent so that the log shows
                # what was wrong with the offer.
                self._logger.debug('%s parameter must not have a value: %r',
                                   self._SERVER_NO_CONTEXT_TAKEOVER_PARAM,
                                   unexpected_value)
                return None

        # client_max_window_bits from a client indicates whether the client can
        # accept client_max_window_bits from a server or not.
        client_client_max_window_bits = self._request.has_parameter(
            self._CLIENT_MAX_WINDOW_BITS_PARAM)
        if client_client_max_window_bits:
            unexpected_value = self._request.get_parameter_value(
                self._CLIENT_MAX_WINDOW_BITS_PARAM)
            if unexpected_value is not None:
                # Report the value the client sent so that the log shows
                # what was wrong with the offer.
                self._logger.debug(
                    '%s parameter must not have a value in a '
                    'client\'s opening handshake: %r',
                    self._CLIENT_MAX_WINDOW_BITS_PARAM, unexpected_value)
                return None

        # NOTE(review): nothing else in this class reads these two
        # attributes; the framer below builds its own deflater/inflater.
        # They are kept so that the object's state stays as before.
        self._rfc1979_deflater = util._RFC1979Deflater(
            server_max_window_bits, server_no_context_takeover)

        # Note that we prepare for incoming messages compressed with window
        # bits upto 15 regardless of the client_max_window_bits value to be
        # sent to the client.
        self._rfc1979_inflater = util._RFC1979Inflater()

        self._framer = _PerMessageDeflateFramer(server_max_window_bits,
                                                server_no_context_takeover)
        self._framer.set_bfinal(False)
        self._framer.set_compress_outgoing_enabled(True)

        response = common.ExtensionParameter(self._request.name())

        if server_max_window_bits is not None:
            response.add_parameter(self._SERVER_MAX_WINDOW_BITS_PARAM,
                                   str(server_max_window_bits))

        if server_no_context_takeover:
            response.add_parameter(self._SERVER_NO_CONTEXT_TAKEOVER_PARAM,
                                   None)

        if self._preferred_client_max_window_bits is not None:
            if not client_client_max_window_bits:
                self._logger.debug(
                    'Processor is configured to use %s but '
                    'the client cannot accept it',
                    self._CLIENT_MAX_WINDOW_BITS_PARAM)
                return None
            response.add_parameter(self._CLIENT_MAX_WINDOW_BITS_PARAM,
                                   str(self._preferred_client_max_window_bits))

        if self._client_no_context_takeover:
            response.add_parameter(self._CLIENT_NO_CONTEXT_TAKEOVER_PARAM,
                                   None)

        # The values are passed as logging arguments so that the message is
        # only formatted when debug output is enabled.
        self._logger.debug(
            'Enable %s extension ('
            'request: server_max_window_bits=%s; '
            'server_no_context_takeover=%r, '
            'response: client_max_window_bits=%s; '
            'client_no_context_takeover=%r)', self._request.name(),
            server_max_window_bits, server_no_context_takeover,
            self._preferred_client_max_window_bits,
            self._client_no_context_takeover)

        return response

    def _setup_stream_options_internal(self, stream_options):
        """Let the framer register its filters on stream_options."""
        self._framer.setup_stream_options(stream_options)

    def set_client_max_window_bits(self, value):
        """If this option is specified, this class adds the
        client_max_window_bits extension parameter to the handshake response,
        but doesn't reduce the LZ77 sliding window size of its inflater.
        I.e., you can use this for testing client implementation but cannot
        reduce memory usage of this class.

        If this method has been called with a value other than None and an
        offer without the client_max_window_bits extension parameter is
        received, this processor declines the offer.
        """

        self._preferred_client_max_window_bits = value

    def set_client_no_context_takeover(self, value):
        """If this option is specified, this class adds the
        client_no_context_takeover extension parameter to the handshake
        response, but doesn't reset inflater for each message. I.e., you can
        use this for testing client implementation but cannot reduce memory
        usage of this class.
        """

        self._client_no_context_takeover = value

    def set_bfinal(self, value):
        """Forward value to the framer's set_bfinal().

        The framer is created while the extension response is built, so
        this can only be used afterwards.
        """
        self._framer.set_bfinal(value)

    def enable_outgoing_compression(self):
        """Make the framer compress outgoing messages.

        The framer is created while the extension response is built, so
        this can only be used afterwards.
        """
        self._framer.set_compress_outgoing_enabled(True)

    def disable_outgoing_compression(self):
        """Make the framer pass outgoing messages through uncompressed.

        The framer is created while the extension response is built, so
        this can only be used afterwards.
        """
        self._framer.set_compress_outgoing_enabled(False)
303
304
class _PerMessageDeflateFramer(object):
    """A framer for extensions with per-message DEFLATE feature.

    The framer owns a deflater/inflater pair. setup_stream_options() creates
    four small filter objects that call back into the _process_*() methods
    of this class.
    """
    def __init__(self, deflate_max_window_bits, deflate_no_context_takeover):
        """Create the deflater/inflater pair and the ratio calculators.

        Both arguments are handed to util._RFC1979Deflater unchanged.
        """
        self._logger = util.get_class_logger(self)

        self._rfc1979_deflater = util._RFC1979Deflater(
            deflate_max_window_bits, deflate_no_context_takeover)
        self._rfc1979_inflater = util._RFC1979Inflater()

        self._bfinal = False
        self._compress_outgoing_enabled = False

        # True while a fragmented outgoing message is being compressed, i.e.
        # the last piece passed to the deflater was not the final one.
        self._compress_ongoing = False

        # Tracks
        #     (Total bytes produced by compressing outgoing messages) /
        #     (Total outgoing bytes supplied to this filter)
        self._outgoing_average_ratio_calculator = _AverageRatioCalculator()

        # Tracks
        #     (Total bytes received from the network) /
        #     (Total incoming bytes obtained after applying this filter)
        self._incoming_average_ratio_calculator = _AverageRatioCalculator()

    def set_bfinal(self, value):
        """Set the bfinal flag that is passed to the deflater."""
        self._bfinal = value

    def set_compress_outgoing_enabled(self, value):
        """Switch compression of outgoing messages on or off."""
        self._compress_outgoing_enabled = value

    def _process_incoming_message(self, message, decompress):
        """Return message, inflated first when decompress is true."""
        if decompress:
            stats = self._incoming_average_ratio_calculator

            compressed_size = len(message)
            stats.add_result_bytes(compressed_size)

            message = self._rfc1979_inflater.filter(message)

            inflated_size = len(message)
            stats.add_original_bytes(inflated_size)

            _log_incoming_compression_ratio(self._logger, compressed_size,
                                            inflated_size,
                                            stats.get_average_ratio())

        return message

    def _process_outgoing_message(self, message, end, binary):
        """Return the payload to send for message.

        A message that is not binary is encoded as UTF-8. The payload is
        then deflated if outgoing compression is enabled. end is passed on
        to the deflater and tells whether more pieces of the same message
        follow.
        """
        payload = message if binary else message.encode('utf-8')

        if not self._compress_outgoing_enabled:
            return payload

        stats = self._outgoing_average_ratio_calculator

        plain_size = len(payload)
        stats.add_original_bytes(plain_size)

        payload = self._rfc1979_deflater.filter(payload,
                                                end=end,
                                                bfinal=self._bfinal)

        deflated_size = len(payload)
        stats.add_result_bytes(deflated_size)

        _log_outgoing_compression_ratio(self._logger, plain_size,
                                        deflated_size,
                                        stats.get_average_ratio())

        # The compression bit is requested only when this piece is not the
        # continuation of a message whose compression is already ongoing.
        if not self._compress_ongoing:
            self._outgoing_frame_filter.set_compression_bit()
        self._compress_ongoing = not end
        return payload

    def _process_incoming_frame(self, frame):
        """Turn a set RSV1 bit on a non-control frame into a request to
        decompress the next message, and clear the bit.
        """
        if frame.rsv1 != 1:
            return
        if common.is_control_opcode(frame.opcode):
            return

        self._incoming_message_filter.decompress_next_message()
        frame.rsv1 = 0

    def _process_outgoing_frame(self, frame, compression_bit):
        """Set RSV1 on frame if compression_bit is true and the frame is not
        a control frame.
        """
        if compression_bit and not common.is_control_opcode(frame.opcode):
            frame.rsv1 = 1

    def setup_stream_options(self, stream_options):
        """Creates filters and sets them to the StreamOptions."""
        class _OutgoingMessageFilter(object):
            """Passes outgoing messages to the framer."""
            def __init__(self, parent):
                self._parent = parent

            def filter(self, message, end=True, binary=False):
                framer = self._parent
                return framer._process_outgoing_message(message, end, binary)

        class _IncomingMessageFilter(object):
            """Passes incoming messages to the framer, together with a
            one-shot "decompress" flag.
            """
            def __init__(self, parent):
                self._parent = parent
                self._decompress_next_message = False

            def decompress_next_message(self):
                self._decompress_next_message = True

            def filter(self, message):
                decompress = self._decompress_next_message
                result = self._parent._process_incoming_message(
                    message, decompress)
                # The flag covers a single message only.
                self._decompress_next_message = False
                return result

        class _OutgoingFrameFilter(object):
            """Passes outgoing frames to the framer, together with a
            one-shot "compression bit" flag.
            """
            def __init__(self, parent):
                self._parent = parent
                self._set_compression_bit = False

            def set_compression_bit(self):
                self._set_compression_bit = True

            def filter(self, frame):
                compression_bit = self._set_compression_bit
                self._parent._process_outgoing_frame(frame, compression_bit)
                # The flag covers a single frame only.
                self._set_compression_bit = False

        class _IncomingFrameFilter(object):
            """Passes incoming frames to the framer."""
            def __init__(self, parent):
                self._parent = parent

            def filter(self, frame):
                self._parent._process_incoming_frame(frame)

        self._outgoing_message_filter = _OutgoingMessageFilter(self)
        self._incoming_message_filter = _IncomingMessageFilter(self)
        self._outgoing_frame_filter = _OutgoingFrameFilter(self)
        self._incoming_frame_filter = _IncomingFrameFilter(self)

        stream_options.outgoing_message_filters.append(
            self._outgoing_message_filter)
        stream_options.incoming_message_filters.append(
            self._incoming_message_filter)
        stream_options.outgoing_frame_filters.append(
            self._outgoing_frame_filter)
        stream_options.incoming_frame_filters.append(
            self._incoming_frame_filter)

        # Text messages are encoded in _process_outgoing_message() instead.
        stream_options.encode_text_message_to_utf8 = False
456
457
# Register the permessage-deflate processor under
# common.PERMESSAGE_DEFLATE_EXTENSION so that get_extension_processor() can
# look it up for a matching offer.
_available_processors[common.PERMESSAGE_DEFLATE_EXTENSION] = (
    PerMessageDeflateExtensionProcessor)
460
461
def get_extension_processor(extension_request):
    """Return a processor for an extension offer received from a client.

    extension_request is the ExtensionParameter representing the offer. The
    processor class registered for the offer's name() is instantiated with
    the offer and returned. None is returned when no class is registered
    for that name.
    """

    registered = _available_processors.get(extension_request.name())
    return None if registered is None else registered(extension_request)
472
473
474# vi:sts=4 sw=4 et
475