1# Copyright 2012, Google Inc.
2# All rights reserved.
3#
4# Redistribution and use in source and binary forms, with or without
5# modification, are permitted provided that the following conditions are
6# met:
7#
8#     * Redistributions of source code must retain the above copyright
9# notice, this list of conditions and the following disclaimer.
10#     * Redistributions in binary form must reproduce the above
11# copyright notice, this list of conditions and the following disclaimer
12# in the documentation and/or other materials provided with the
13# distribution.
14#     * Neither the name of Google Inc. nor the names of its
15# contributors may be used to endorse or promote products derived from
16# this software without specific prior written permission.
17#
18# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
19# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
20# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
21# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
22# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
23# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
24# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29
30from __future__ import absolute_import
31from mod_pywebsocket import common
32from mod_pywebsocket import util
33from mod_pywebsocket.http_header_util import quote_if_necessary
34
# Registry of server-side extension processor classes, keyed by the
# extension name a client offers.
_available_processors = dict()
37
38
class ExtensionProcessorInterface(object):
    """Base class of server-side WebSocket extension processors.

    A processor wraps a single extension offer received from the client.
    Subclasses override name(), _get_extension_response_internal() and
    _setup_stream_options_internal(); the public wrappers defined here take
    care of the active/inactive bookkeeping.
    """

    def __init__(self, request):
        """Remember the extension offer and start out active.

        Args:
            request: the offer handled by this processor; returned as-is by
                request().
        """
        self._logger = util.get_class_logger(self)
        self._request = request
        self._active = True

    def request(self):
        """Return the extension offer given to the constructor."""
        return self._request

    def name(self):
        """Return the extension name; the base class has none (None)."""
        return None

    def check_consistency_with_other_processors(self, processors):
        """Hook to check compatibility with the other negotiated processors.

        The base implementation accepts any combination and does nothing.
        """
        return None

    def set_active(self, active):
        """Mark this processor as active (True) or inactive (False)."""
        self._active = active

    def is_active(self):
        """Return the current active flag."""
        return self._active

    def _get_extension_response_internal(self):
        """Subclass hook returning the response parameter, or None to decline.

        The base implementation always declines.
        """
        return None

    def get_extension_response(self):
        """Return the handshake response for the offer, or None.

        An inactive processor returns None without consulting the subclass
        hook. When the hook returns None, the processor becomes inactive.
        """
        if self._active:
            extension_response = self._get_extension_response_internal()
            if extension_response is None:
                self._active = False
            return extension_response

        self._logger.debug('Extension %s is deactivated', self.name())
        return None

    def _setup_stream_options_internal(self, stream_options):
        """Subclass hook that installs filters into stream_options; no-op."""
        return None

    def setup_stream_options(self, stream_options):
        """Let an active processor configure stream_options.

        Inactive processors leave stream_options untouched.
        """
        if not self._active:
            return
        self._setup_stream_options_internal(stream_options)
80
81
def _log_outgoing_compression_ratio(logger, original_bytes, filtered_bytes,
                                    average_ratio):
    """Log the compression ratio of one outgoing message at DEBUG level.

    The ratio is filtered_bytes / original_bytes, i.e. compressed size over
    uncompressed size. It is reported as inf when original_bytes is 0.

    Args:
        logger: logger to write to (only its debug() method is used).
        original_bytes: payload size before compression.
        filtered_bytes: payload size after compression.
        average_ratio: running average ratio, logged alongside.
    """
    # Print inf when ratio is not available.
    ratio = float('inf')
    if original_bytes != 0:
        ratio = float(filtered_bytes) / original_bytes

    # Pass the values as arguments instead of %-formatting eagerly, so a
    # standard logging.Logger only builds the string when DEBUG is enabled.
    # This runs once per outgoing message.
    logger.debug('Outgoing compression ratio: %f (average: %f)', ratio,
                 average_ratio)
91
92
def _log_incoming_compression_ratio(logger, received_bytes, filtered_bytes,
                                    average_ratio):
    """Log the compression ratio of one incoming message at DEBUG level.

    The ratio is received_bytes / filtered_bytes, i.e. size on the wire over
    decompressed size. It is reported as inf when filtered_bytes is 0.

    Args:
        logger: logger to write to (only its debug() method is used).
        received_bytes: payload size as received from the network.
        filtered_bytes: payload size after decompression.
        average_ratio: running average ratio, logged alongside.
    """
    # Print inf when ratio is not available.
    ratio = float('inf')
    if filtered_bytes != 0:
        ratio = float(received_bytes) / filtered_bytes

    # Pass the values as arguments instead of %-formatting eagerly, so a
    # standard logging.Logger only builds the string when DEBUG is enabled.
    # This runs once per incoming compressed message.
    logger.debug('Incoming compression ratio: %f (average: %f)', ratio,
                 average_ratio)
102
103
def _parse_window_bits(bits):
    """Parse a window-bits extension parameter value.

    Args:
        bits: the raw parameter value (a string, or None when absent).

    Returns:
        The integer value, when bits is written in canonical decimal form
        (no sign, whitespace or leading zeros) and lies within 8..15.

    Raises:
        ValueError: if bits is None, not an integer literal, not canonical,
            or out of range.
    """
    if bits is None:
        raise ValueError('Value is required')

    # int() itself rejects non-integers such as "10.0".
    parsed = int(bits)

    # Round-tripping through str() rejects spellings like "08", "+9", " 9".
    is_canonical = str(parsed) == bits
    if not is_canonical or not 8 <= parsed <= 15:
        raise ValueError('Invalid value: %r' % bits)

    return parsed
120
121
class _AverageRatioCalculator(object):
    """Accumulates two byte counters and reports their running ratio.

    Callers add the size of data before a transformation with
    add_original_bytes() and the size after it with add_result_bytes().
    get_average_ratio() returns total_result / total_original, or
    float('inf') while no original bytes have been recorded.
    """
    def __init__(self):
        self._total_original_bytes = 0
        self._total_result_bytes = 0

    def add_original_bytes(self, value):
        """Add value to the "original" (denominator) counter."""
        self._total_original_bytes += value

    def add_result_bytes(self, value):
        """Add value to the "result" (numerator) counter."""
        self._total_result_bytes += value

    def get_average_ratio(self):
        """Return total result bytes / total original bytes (inf if zero)."""
        if not self._total_original_bytes:
            return float('inf')
        return self._total_result_bytes / float(self._total_original_bytes)
142
143
class PerMessageDeflateExtensionProcessor(ExtensionProcessorInterface):
    """permessage-deflate extension processor.

    Validates a client's permessage-deflate offer, builds the handshake
    response and, once negotiated, installs a _PerMessageDeflateFramer's
    filters into the stream options.

    Specification:
    http://tools.ietf.org/html/draft-ietf-hybi-permessage-compression-08
    (later published as RFC 7692).
    """

    _SERVER_MAX_WINDOW_BITS_PARAM = 'server_max_window_bits'
    _SERVER_NO_CONTEXT_TAKEOVER_PARAM = 'server_no_context_takeover'
    _CLIENT_MAX_WINDOW_BITS_PARAM = 'client_max_window_bits'
    _CLIENT_NO_CONTEXT_TAKEOVER_PARAM = 'client_no_context_takeover'

    def __init__(self, request):
        """Construct PerMessageDeflateExtensionProcessor.

        Args:
            request: the client's extension offer; see
                ExtensionProcessorInterface.
        """

        ExtensionProcessorInterface.__init__(self, request)
        self._logger = util.get_class_logger(self)

        # client_max_window_bits value to put in the response, or None to
        # omit it. Set via set_client_max_window_bits().
        self._preferred_client_max_window_bits = None
        # Whether to put client_no_context_takeover in the response. Set via
        # set_client_no_context_takeover().
        self._client_no_context_takeover = False

    def name(self):
        # This method returns "deflate" (not "permessage-deflate") for
        # compatibility.
        return 'deflate'

    def _get_extension_response_internal(self):
        """Validate the client's offer and build the handshake response.

        On success this also creates self._framer, which
        _setup_stream_options_internal(), set_bfinal() and the
        enable/disable_outgoing_compression() methods operate on.

        Returns:
            A common.ExtensionParameter to send back, or None to decline the
            offer (the base class then marks this processor inactive).
        """
        # Decline offers carrying parameters we do not know. An offer may
        # legitimately include client_no_context_takeover (RFC 7692 section
        # 7.1.1.2). The client only announces that it can work without
        # context takeover, so the offer can be accepted without echoing it.
        known_parameters = (
            self._SERVER_MAX_WINDOW_BITS_PARAM,
            self._SERVER_NO_CONTEXT_TAKEOVER_PARAM,
            self._CLIENT_MAX_WINDOW_BITS_PARAM,
            self._CLIENT_NO_CONTEXT_TAKEOVER_PARAM,
        )
        for name in self._request.get_parameter_names():
            if name not in known_parameters:
                self._logger.debug('Unknown parameter: %r', name)
                return None

        server_max_window_bits = None
        if self._request.has_parameter(self._SERVER_MAX_WINDOW_BITS_PARAM):
            try:
                server_max_window_bits = _parse_window_bits(
                    self._request.get_parameter_value(
                        self._SERVER_MAX_WINDOW_BITS_PARAM))
            except ValueError as e:
                self._logger.debug('Bad %s parameter: %r',
                                   self._SERVER_MAX_WINDOW_BITS_PARAM, e)
                return None

        server_no_context_takeover = self._request.has_parameter(
            self._SERVER_NO_CONTEXT_TAKEOVER_PARAM)
        if server_no_context_takeover:
            value = self._request.get_parameter_value(
                self._SERVER_NO_CONTEXT_TAKEOVER_PARAM)
            if value is not None:
                # Log the offending value, not the presence flag.
                self._logger.debug('%s parameter must not have a value: %r',
                                   self._SERVER_NO_CONTEXT_TAKEOVER_PARAM,
                                   value)
                return None

        # client_max_window_bits from a client indicates whether the client can
        # accept client_max_window_bits from a server or not.
        client_client_max_window_bits = self._request.has_parameter(
            self._CLIENT_MAX_WINDOW_BITS_PARAM)
        if client_client_max_window_bits:
            value = self._request.get_parameter_value(
                self._CLIENT_MAX_WINDOW_BITS_PARAM)
            if value is not None:
                self._logger.debug(
                    '%s parameter must not have a value in a '
                    'client\'s opening handshake: %r',
                    self._CLIENT_MAX_WINDOW_BITS_PARAM, value)
                return None

        # client_no_context_takeover in an offer is a flag; it has no value.
        if self._request.has_parameter(self._CLIENT_NO_CONTEXT_TAKEOVER_PARAM):
            value = self._request.get_parameter_value(
                self._CLIENT_NO_CONTEXT_TAKEOVER_PARAM)
            if value is not None:
                self._logger.debug('%s parameter must not have a value: %r',
                                   self._CLIENT_NO_CONTEXT_TAKEOVER_PARAM,
                                   value)
                return None

        # All compression state lives in the framer; no separate
        # deflater/inflater is kept on this processor. Note that the framer's
        # inflater is created without a window-size argument, i.e. regardless
        # of the client_max_window_bits value sent to the client.
        self._framer = _PerMessageDeflateFramer(server_max_window_bits,
                                                server_no_context_takeover)
        self._framer.set_bfinal(False)
        self._framer.set_compress_outgoing_enabled(True)

        response = common.ExtensionParameter(self._request.name())

        if server_max_window_bits is not None:
            response.add_parameter(self._SERVER_MAX_WINDOW_BITS_PARAM,
                                   str(server_max_window_bits))

        if server_no_context_takeover:
            response.add_parameter(self._SERVER_NO_CONTEXT_TAKEOVER_PARAM,
                                   None)

        if self._preferred_client_max_window_bits is not None:
            if not client_client_max_window_bits:
                self._logger.debug(
                    'Processor is configured to use %s but '
                    'the client cannot accept it',
                    self._CLIENT_MAX_WINDOW_BITS_PARAM)
                return None
            response.add_parameter(self._CLIENT_MAX_WINDOW_BITS_PARAM,
                                   str(self._preferred_client_max_window_bits))

        if self._client_no_context_takeover:
            response.add_parameter(self._CLIENT_NO_CONTEXT_TAKEOVER_PARAM,
                                   None)

        # Lazy logging arguments: formatted only when DEBUG is enabled.
        self._logger.debug('Enable %s extension ('
                           'request: server_max_window_bits=%s; '
                           'server_no_context_takeover=%r, '
                           'response: client_max_window_bits=%s; '
                           'client_no_context_takeover=%r)',
                           self._request.name(), server_max_window_bits,
                           server_no_context_takeover,
                           self._preferred_client_max_window_bits,
                           self._client_no_context_takeover)

        return response

    def _setup_stream_options_internal(self, stream_options):
        """Install the framer's message and frame filters."""
        self._framer.setup_stream_options(stream_options)

    def set_client_max_window_bits(self, value):
        """Set the client_max_window_bits value sent in the response.

        If value is not None, this class adds the client_max_window_bits
        extension parameter with that value to the handshake response, but
        doesn't reduce the LZ77 sliding window size of its inflater. I.e.,
        you can use this for testing client implementation but cannot reduce
        memory usage of this class.

        If a value has been set and the received offer does not contain the
        client_max_window_bits extension parameter, this processor declines
        the offer.
        """

        self._preferred_client_max_window_bits = value

    def set_client_no_context_takeover(self, value):
        """If this option is specified, this class adds the
        client_no_context_takeover extension parameter to the handshake
        response, but doesn't reset inflater for each message. I.e., you can
        use this for testing client implementation but cannot reduce memory
        usage of this class.
        """

        self._client_no_context_takeover = value

    def set_bfinal(self, value):
        """Forward the BFINAL setting to the framer (after negotiation)."""
        self._framer.set_bfinal(value)

    def enable_outgoing_compression(self):
        """Compress subsequent outgoing messages (after negotiation)."""
        self._framer.set_compress_outgoing_enabled(True)

    def disable_outgoing_compression(self):
        """Send subsequent outgoing messages uncompressed (after negotiation)."""
        self._framer.set_compress_outgoing_enabled(False)
302
303
class _PerMessageDeflateFramer(object):
    """A framer for extensions with per-message DEFLATE feature.

    Owns the DEFLATE compressor/decompressor for one connection and installs
    message- and frame-level filters into a StreamOptions object (see
    setup_stream_options()).
    - Outgoing: when compression is enabled, each message is deflated and the
      first frame of the message gets RSV1 set.
    - Incoming: a non-control frame with RSV1 set marks the message being
      received for inflation.
    """
    def __init__(self, deflate_max_window_bits, deflate_no_context_takeover):
        """Create the codec state.

        Args:
            deflate_max_window_bits: passed through to util._RFC1979Deflater
                (None when the offer did not restrict it).
            deflate_no_context_takeover: passed through to
                util._RFC1979Deflater.
        """
        self._logger = util.get_class_logger(self)

        self._rfc1979_deflater = util._RFC1979Deflater(
            deflate_max_window_bits, deflate_no_context_takeover)

        self._rfc1979_inflater = util._RFC1979Inflater()

        self._bfinal = False

        self._compress_outgoing_enabled = False

        # True if a message is fragmented and compression is ongoing.
        self._compress_ongoing = False

        # Calculates
        #     (Total bytes sent to the network after applying this filter) /
        #     (Total outgoing bytes supplied to this filter)
        # _process_outgoing_message() feeds the uncompressed size as
        # "original" and the compressed size as "result".
        self._outgoing_average_ratio_calculator = _AverageRatioCalculator()

        # Calculates
        #     (Total bytes received from the network) /
        #     (Total incoming bytes obtained after applying this filter)
        # _process_incoming_message() feeds the received size as "result" and
        # the inflated size as "original".
        self._incoming_average_ratio_calculator = _AverageRatioCalculator()

    def set_bfinal(self, value):
        """Set the bfinal flag passed to the deflater for each message."""
        self._bfinal = value

    def set_compress_outgoing_enabled(self, value):
        """Enable or disable compression of subsequent outgoing messages."""
        self._compress_outgoing_enabled = value

    def _process_incoming_message(self, message, decompress):
        """Inflate message when decompress is true; else return it as-is."""
        if not decompress:
            return message

        received_payload_size = len(message)
        self._incoming_average_ratio_calculator.add_result_bytes(
            received_payload_size)

        message = self._rfc1979_inflater.filter(message)

        filtered_payload_size = len(message)
        self._incoming_average_ratio_calculator.add_original_bytes(
            filtered_payload_size)

        _log_incoming_compression_ratio(
            self._logger, received_payload_size, filtered_payload_size,
            self._incoming_average_ratio_calculator.get_average_ratio())

        return message

    def _process_outgoing_message(self, message, end, binary):
        """Encode (text) and optionally deflate one outgoing message/fragment.

        Args:
            message: the payload; text is UTF-8 encoded here because
                setup_stream_options() turns off the stream's own encoding.
            end: True if this is the final fragment of the message.
            binary: False for text messages.
        """
        if not binary:
            message = message.encode('utf-8')

        if not self._compress_outgoing_enabled:
            return message

        original_payload_size = len(message)
        self._outgoing_average_ratio_calculator.add_original_bytes(
            original_payload_size)

        message = self._rfc1979_deflater.filter(message,
                                                end=end,
                                                bfinal=self._bfinal)

        filtered_payload_size = len(message)
        self._outgoing_average_ratio_calculator.add_result_bytes(
            filtered_payload_size)

        _log_outgoing_compression_ratio(
            self._logger, original_payload_size, filtered_payload_size,
            self._outgoing_average_ratio_calculator.get_average_ratio())

        # Only the first frame of a compressed message carries RSV1; later
        # fragments of the same message must not set it again.
        if not self._compress_ongoing:
            self._outgoing_frame_filter.set_compression_bit()
        self._compress_ongoing = not end
        return message

    def _process_incoming_frame(self, frame):
        """Consume RSV1 on a data frame and schedule inflation of its message."""
        if frame.rsv1 == 1 and not common.is_control_opcode(frame.opcode):
            self._incoming_message_filter.decompress_next_message()
            frame.rsv1 = 0

    def _process_outgoing_frame(self, frame, compression_bit):
        """Set RSV1 on a data frame when compression_bit is requested."""
        if not compression_bit or common.is_control_opcode(frame.opcode):
            return

        frame.rsv1 = 1

    def setup_stream_options(self, stream_options):
        """Creates filters and sets them to the StreamOptions."""
        class _OutgoingMessageFilter(object):
            def __init__(self, parent):
                self._parent = parent

            def filter(self, message, end=True, binary=False):
                return self._parent._process_outgoing_message(
                    message, end, binary)

        class _IncomingMessageFilter(object):
            def __init__(self, parent):
                self._parent = parent
                self._decompress_next_message = False

            def decompress_next_message(self):
                self._decompress_next_message = True

            def filter(self, message):
                message = self._parent._process_incoming_message(
                    message, self._decompress_next_message)
                # The flag applies to exactly one message.
                self._decompress_next_message = False
                return message

        self._outgoing_message_filter = _OutgoingMessageFilter(self)
        self._incoming_message_filter = _IncomingMessageFilter(self)
        stream_options.outgoing_message_filters.append(
            self._outgoing_message_filter)
        stream_options.incoming_message_filters.append(
            self._incoming_message_filter)

        class _OutgoingFrameFilter(object):
            def __init__(self, parent):
                self._parent = parent
                self._set_compression_bit = False

            def set_compression_bit(self):
                self._set_compression_bit = True

            def filter(self, frame):
                self._parent._process_outgoing_frame(frame,
                                                     self._set_compression_bit)
                # The request applies to the next frame only.
                self._set_compression_bit = False

        class _IncomingFrameFilter(object):
            def __init__(self, parent):
                self._parent = parent

            def filter(self, frame):
                self._parent._process_incoming_frame(frame)

        self._outgoing_frame_filter = _OutgoingFrameFilter(self)
        self._incoming_frame_filter = _IncomingFrameFilter(self)
        stream_options.outgoing_frame_filters.append(
            self._outgoing_frame_filter)
        stream_options.incoming_frame_filters.append(
            self._incoming_frame_filter)

        # _process_outgoing_message() does the UTF-8 encoding itself.
        stream_options.encode_text_message_to_utf8 = False
455
456
# Register the permessage-deflate processor under its extension token, so
# get_extension_processor() finds it by the name the client offers.
_available_processors.update(
    {common.PERMESSAGE_DEFLATE_EXTENSION: PerMessageDeflateExtensionProcessor})
459
460
def get_extension_processor(extension_request):
    """Instantiate the processor registered for a client's extension offer.

    Args:
        extension_request: an ExtensionParameter holding one offer; its
            name() selects the processor class.

    Returns:
        A new processor wrapping extension_request, or None when no processor
        is registered for that extension name.
    """
    offered_name = extension_request.name()
    factory = _available_processors.get(offered_name)
    if factory is not None:
        return factory(extension_request)
    return None
471
472
473# vi:sts=4 sw=4 et
474