1# -*- coding: utf-8 -*-
2"""
3hpack/hpack
4~~~~~~~~~~~
5
6Implements the HPACK header compression algorithm as detailed by the IETF.
7"""
8import logging
9
10from .table import HeaderTable, table_entry_size
11from .compat import to_byte, to_bytes
12from .exceptions import (
13    HPACKDecodingError, OversizedHeaderListError, InvalidTableSizeError
14)
15from .huffman import HuffmanEncoder
16from .huffman_constants import (
17    REQUEST_CODES, REQUEST_CODES_LENGTH
18)
19from .huffman_table import decode_huffman
20from .struct import HeaderTuple, NeverIndexedHeaderTuple
21
log = logging.getLogger(__name__)

# First-byte patterns for the literal header field representations. The
# encoder uses INDEX_INCREMENTAL for ordinary headers and INDEX_NEVER for
# sensitive ones; the header name index (if any) is OR'd into the low bits.
INDEX_NONE = b'\x00'
INDEX_NEVER = b'\x10'
INDEX_INCREMENTAL = b'\x40'

# Precompute (2^i) - 1 for i in 0-8: the largest value that fits in an i-bit
# prefix. In HPACK's integer encoding a prefix holding exactly this value
# means "the rest of the integer follows in continuation bytes".
# Index 0 is not used but is there to save a subtraction, as prefix sizes
# are not zero indexed (the list is indexed directly by the prefix size).
_PREFIX_BIT_MAX_NUMBERS = [(2 ** i) - 1 for i in range(9)]

# Python 2 provides a ``basestring`` builtin; on Python 3 fall back to the
# tuple of types that ``isinstance`` checks should treat as strings.
try:  # pragma: no cover
    basestring = basestring
except NameError:  # pragma: no cover
    basestring = (str, bytes)


# We default the maximum header list we're willing to accept to 64kB. That's a
# lot of headers, but if applications want to raise it they can do.
DEFAULT_MAX_HEADER_LIST_SIZE = 2 ** 16
42
43
def _unicode_if_needed(header, raw):
    """
    Normalise a header's name and value to bytestrings and, unless ``raw``
    is truthy, decode both from UTF-8 to unicode strings.

    The result is rebuilt with ``header``'s own class, so the tuple type
    (e.g. a never-indexed header tuple) is preserved.
    """
    parts = [to_bytes(header[0]), to_bytes(header[1])]
    if not raw:
        parts = [part.decode('utf-8') for part in parts]
    return header.__class__(*parts)
55
56
def encode_integer(integer, prefix_bits):
    """
    Encode ``integer`` with the HPACK prefixed-integer scheme, using a
    ``prefix_bits``-bit prefix in the first byte.

    Returns a ``bytearray``; callers OR their own flag bits into the first
    byte afterwards.

    :raises ValueError: If ``integer`` is negative or ``prefix_bits`` is not
                        between 1 and 8.
    """
    log.debug("Encoding %d with %d bits", integer, prefix_bits)

    if integer < 0:
        raise ValueError(
            "Can only encode positive integers, got %s" % integer
        )

    if prefix_bits < 1 or prefix_bits > 8:
        raise ValueError(
            "Prefix bits must be between 1 and 8, got %s" % prefix_bits
        )

    prefix_max = _PREFIX_BIT_MAX_NUMBERS[prefix_bits]

    # Values below the saturated prefix fit entirely in the first byte.
    if integer < prefix_max:
        return bytearray([integer])

    # Otherwise saturate the prefix and emit the remainder in 7-bit groups,
    # least significant first; a set high bit means another group follows.
    octets = bytearray([prefix_max])
    remainder = integer - prefix_max
    while remainder >= 0x80:
        octets.append(0x80 | (remainder & 0x7F))
        remainder >>= 7
    octets.append(remainder)
    return octets
89
90
def decode_integer(data, prefix_bits):
    """
    Decode an HPACK prefixed integer from the start of ``data``, where the
    first byte carries a ``prefix_bits``-bit prefix.

    :returns: A ``(value, consumed)`` tuple: the decoded integer and the
              number of bytes of ``data`` it occupied.
    :raises ValueError: If ``prefix_bits`` is not between 1 and 8.
    :raises HPACKDecodingError: If ``data`` ends before the integer does.
    """
    if prefix_bits < 1 or prefix_bits > 8:
        raise ValueError(
            "Prefix bits must be between 1 and 8, got %s" % prefix_bits
        )

    prefix_max = _PREFIX_BIT_MAX_NUMBERS[prefix_bits]
    prefix_mask = 0xFF >> (8 - prefix_bits)
    consumed = 1

    try:
        value = to_byte(data[0]) & prefix_mask
        if value == prefix_max:
            # The prefix is saturated: the remainder follows in 7-bit groups,
            # least significant first, and a set high bit means another
            # group follows.
            shift = 0
            while True:
                octet = to_byte(data[consumed])
                consumed += 1
                value += (octet & 0x7F) << shift
                if not octet & 0x80:
                    break
                shift += 7
    except IndexError:
        raise HPACKDecodingError(
            "Unable to decode HPACK integer representation from %r" % data
        )

    log.debug("Decoded %d, consumed %d bytes", value, consumed)

    return value, consumed
130
131
def _dict_to_iterable(header_dict):
    """
    Yield the ``(name, value)`` pairs of ``header_dict`` with every special
    header (name starting with ``:``) ahead of the regular headers, since
    HTTP/2 requires special headers at the start of the header block.
    """
    assert isinstance(header_dict, dict)

    def is_regular_header(name):
        # False sorts before True, so special headers come first; sorting is
        # stable, so order within each group is otherwise preserved.
        return not _to_bytes(name).startswith(b':')

    for name in sorted(header_dict, key=is_regular_header):
        yield name, header_dict[name]
145
146
def _to_bytes(string):
    """
    Return ``string`` as a bytestring: bytes pass through unchanged, text is
    encoded as UTF-8, and any other object is first converted with ``str()``.
    """
    if isinstance(string, basestring):
        text = string
    else:  # pragma: no cover
        text = str(string)

    if isinstance(text, bytes):
        return text
    return text.encode('utf-8')
155
156
class Encoder(object):
    """
    An HPACK encoder object. This object takes HTTP headers and emits encoded
    HTTP/2 header blocks.
    """

    def __init__(self):
        self.header_table = HeaderTable()
        self.huffman_coder = HuffmanEncoder(
            REQUEST_CODES, REQUEST_CODES_LENGTH
        )
        # Header table sizes set since the last header block was emitted.
        # Each is signalled, in order, as a table size update at the start of
        # the next block produced by ``encode``.
        self.table_size_changes = []

    @property
    def header_table_size(self):
        """
        Controls the size of the HPACK header table.

        Assigning a value resizes the table and, when the table reports that
        it was resized, queues a table size update that is emitted at the
        start of the next header block produced by :meth:`encode`.
        """
        return self.header_table.maxsize

    @header_table_size.setter
    def header_table_size(self, value):
        self.header_table.maxsize = value
        if self.header_table.resized:
            self.table_size_changes.append(value)

    def encode(self, headers, huffman=True):
        """
        Takes a set of headers and encodes them into a HPACK-encoded header
        block.

        :param headers: The headers to encode. Must be either an iterable of
                        tuples, an iterable of :class:`HeaderTuple
                        <hpack.struct.HeaderTuple>`, or a ``dict``.

                        If an iterable of tuples, the tuples may be either
                        two-tuples or three-tuples. If they are two-tuples, the
                        tuples must be of the format ``(name, value)``. If they
                        are three-tuples, they must be of the format
                        ``(name, value, sensitive)``, where ``sensitive`` is a
                        boolean value indicating whether the header is
                        sensitive. A sensitive header is never added to the
                        header table and, when sent as a literal, uses the
                        never-indexed representation (an exact name/value
                        match already in the header table is still sent as an
                        index). If not present, ``sensitive`` defaults to
                        ``False``.

                        If an iterable of :class:`HeaderTuple
                        <hpack.struct.HeaderTuple>`, the tuples must always be
                        two-tuples. Instead of using ``sensitive`` as a third
                        tuple entry, use :class:`NeverIndexedHeaderTuple
                        <hpack.struct.NeverIndexedHeaderTuple>` to request that
                        the field never be indexed.

                        .. warning:: HTTP/2 requires that all special headers
                            (headers whose names begin with ``:`` characters)
                            appear at the *start* of the header block. While
                            this method will ensure that happens for ``dict``
                            subclasses, callers using any other iterable of
                            tuples **must** ensure they place their special
                            headers at the start of the iterable.

                            For efficiency reasons users should prefer to use
                            iterables of two-tuples: fixing the ordering of
                            dictionary headers is an expensive operation that
                            should be avoided if possible.

        :param huffman: (optional) Whether to Huffman-encode any header sent as
                        a literal value. Except for use when debugging, it is
                        recommended that this be left enabled.

        :returns: A bytestring containing the HPACK-encoded header block.
        """
        # Transforming the headers into a header block is a procedure that can
        # be modeled as a chain or pipe. First, the headers are encoded. This
        # encoding can be done a number of ways. If the header name-value pair
        # are already in the header table we can represent them using the
        # indexed representation: the same is true if they are in the static
        # table. Otherwise, a literal representation will be used.
        header_block = []

        # Turn the headers into a list of tuples if possible. This is the
        # natural way to interact with them in HPACK. Because dictionaries are
        # un-ordered, we need to make sure we grab the "special" headers first.
        if isinstance(headers, dict):
            headers = _dict_to_iterable(headers)

        # Before we begin, if the header table size has been changed we need
        # to signal all changes since last emission appropriately. The
        # pending list is this encoder's own record of what still has to be
        # signalled; ``resized`` is HeaderTable state that a later assignment
        # may reset (depending on HeaderTable) before we get here, so check
        # both rather than silently dropping queued updates.
        if self.header_table.resized or self.table_size_changes:
            header_block.append(self._encode_table_size_change())
            self.header_table.resized = False

        # Add each header to the header block
        for header in headers:
            sensitive = False
            if isinstance(header, HeaderTuple):
                sensitive = not header.indexable
            elif len(header) > 2:
                sensitive = header[2]

            header = (_to_bytes(header[0]), _to_bytes(header[1]))
            header_block.append(self.add(header, sensitive, huffman))

        header_block = b''.join(header_block)

        log.debug("Encoded header block to %s", header_block)

        return header_block

    def add(self, to_add, sensitive, huffman=False):
        """
        Serializes a single ``(name, value)`` header tuple (bytestrings, as
        produced by :meth:`encode`) and returns the encoded bytes.

        If the exact pair is already in the header table it is sent using the
        indexed representation. Otherwise it is sent as a literal, with an
        indexed name when only the name matches. Non-sensitive literals are
        then added to the header table; sensitive ones use the never-indexed
        representation and are not added.
        """
        log.debug("Adding %s to the header table", to_add)

        name, value = to_add

        # Set our indexing mode
        indexbit = INDEX_INCREMENTAL if not sensitive else INDEX_NEVER

        # Search for a matching header in the header table.
        match = self.header_table.search(name, value)

        if match is None:
            # Not in the header table. Encode using the literal syntax,
            # and add it to the header table.
            encoded = self._encode_literal(name, value, indexbit, huffman)
            if not sensitive:
                self.header_table.add(name, value)
            return encoded

        # The header is in the table, break out the values. If we matched
        # perfectly, we can use the indexed representation: otherwise we
        # can use the indexed literal.
        index, name, perfect = match

        if perfect:
            # Indexed representation.
            encoded = self._encode_indexed(index)
        else:
            # Indexed literal. We are going to add header to the
            # header table unconditionally. It is a future todo to
            # filter out headers which are known to be ineffective for
            # indexing since they just take space in the table and
            # push out other valuable headers.
            encoded = self._encode_indexed_literal(
                index, value, indexbit, huffman
            )
            if not sensitive:
                self.header_table.add(name, value)

        return encoded

    def _encode_indexed(self, index):
        """
        Encodes a header using the indexed representation: a 7-bit prefix
        integer holding ``index`` with the top bit of the first byte set.
        """
        field = encode_integer(index, 7)
        field[0] |= 0x80  # we set the top bit
        return bytes(field)

    def _encode_literal(self, name, value, indexbit, huffman=False):
        """
        Encodes a header with a literal name and literal value.

        ``indexbit`` is the first byte of the representation (e.g.
        ``INDEX_INCREMENTAL`` or ``INDEX_NEVER``) and tells the peer how to
        treat the header; this method does not itself touch the header
        table. If ``huffman`` is true, name and value are Huffman-encoded
        and their length prefixes are flagged accordingly.
        """
        if huffman:
            name = self.huffman_coder.encode(name)
            value = self.huffman_coder.encode(value)

        name_len = encode_integer(len(name), 7)
        value_len = encode_integer(len(value), 7)

        if huffman:
            name_len[0] |= 0x80
            value_len[0] |= 0x80

        return b''.join(
            [indexbit, bytes(name_len), name, bytes(value_len), value]
        )

    def _encode_indexed_literal(self, index, value, indexbit, huffman=False):
        """
        Encodes a header whose name is referenced by ``index`` and whose value
        is sent as a literal.

        With ``INDEX_INCREMENTAL`` the name index uses a 6-bit prefix; any
        other ``indexbit`` (e.g. ``INDEX_NEVER``) uses a 4-bit prefix. This
        method does not itself touch the header table.
        """
        if indexbit != INDEX_INCREMENTAL:
            prefix = encode_integer(index, 4)
        else:
            prefix = encode_integer(index, 6)

        prefix[0] |= ord(indexbit)

        if huffman:
            value = self.huffman_coder.encode(value)

        value_len = encode_integer(len(value), 7)

        if huffman:
            value_len[0] |= 0x80

        return b''.join([bytes(prefix), bytes(value_len), value])

    def _encode_table_size_change(self):
        """
        Produces the encoded form of all pending header table size changes,
        in the order they were made, and clears the pending list.
        """
        block = []
        for new_size in self.table_size_changes:
            field = encode_integer(new_size, 5)
            field[0] |= 0x20  # '001' pattern marks a table size update
            block.append(bytes(field))
        self.table_size_changes = []
        return b''.join(block)
371
372
class Decoder(object):
    """
    An HPACK decoder object.

    .. versionchanged:: 2.3.0
       Added ``max_header_list_size`` argument.

    :param max_header_list_size: The maximum decompressed size we will allow
        for any single header block. This is a protection against DoS attacks
        that attempt to force the application to expand a relatively small
        amount of data into a really large header list, allowing enormous
        amounts of memory to be allocated.

        If this amount of data is exceeded, an `OversizedHeaderListError
        <hpack.OversizedHeaderListError>` exception will be raised. At this
        point the connection should be shut down, as the HPACK state will no
        longer be usable.

        Defaults to 64kB.
    :type max_header_list_size: ``int``
    """
    def __init__(self, max_header_list_size=DEFAULT_MAX_HEADER_LIST_SIZE):
        self.header_table = HeaderTable()

        #: The maximum decompressed size we will allow for any single header
        #: block. This is a protection against DoS attacks that attempt to
        #: force the application to expand a relatively small amount of data
        #: into a really large header list, allowing enormous amounts of memory
        #: to be allocated.
        #:
        #: If this amount of data is exceeded, an `OversizedHeaderListError
        #: <hpack.OversizedHeaderListError>` exception will be raised. At this
        #: point the connection should be shut down, as the HPACK state will no
        #: longer be usable.
        #:
        #: Defaults to 64kB.
        #:
        #: .. versionadded:: 2.3.0
        self.max_header_list_size = max_header_list_size

        #: Maximum allowed header table size.
        #:
        #: A HTTP/2 implementation should set this to the most recent value of
        #: SETTINGS_HEADER_TABLE_SIZE that it sent *and has received an ACK
        #: for*. Once this setting is set, the actual header table size will be
        #: checked at the end of each decoding run and whenever it is changed,
        #: to confirm that it fits in this size.
        self.max_allowed_table_size = self.header_table.maxsize

    @property
    def header_table_size(self):
        """
        Controls the size of the HPACK header table.
        """
        return self.header_table.maxsize

    @header_table_size.setter
    def header_table_size(self, value):
        self.header_table.maxsize = value

    def decode(self, data, raw=False):
        """
        Takes an HPACK-encoded header block and decodes it into a header set.

        :param data: A bytestring representing a complete HPACK-encoded header
                     block.
        :param raw: (optional) Whether to return the headers as tuples of raw
                    byte strings or to decode them as UTF-8 before returning
                    them. The default value is False, which returns tuples of
                    Unicode strings
        :returns: A list of ``(name, value)`` header tuples representing the
                  HPACK-encoded headers, in the order they were decoded.
                  Headers the encoder marked as never-indexed are returned as
                  :class:`NeverIndexedHeaderTuple
                  <hpack.struct.NeverIndexedHeaderTuple>`, all others as
                  :class:`HeaderTuple <hpack.struct.HeaderTuple>`.
        :raises HPACKDecodingError: If an error is encountered while decoding
                                    the header block.
        :raises OversizedHeaderListError: If the decoded header list grows
                                          beyond ``max_header_list_size``.
        :raises InvalidTableSizeError: If the encoder sets, or leaves, the
                                       header table larger than
                                       ``max_allowed_table_size``.
        """
        log.debug("Decoding %s", data)

        data_mem = memoryview(data)
        headers = []
        data_len = len(data)
        inflated_size = 0
        current_index = 0

        while current_index < data_len:
            # Work out what kind of header we're decoding.
            # If the high bit is 1, it's an indexed field.
            current = to_byte(data[current_index])
            indexed = True if current & 0x80 else False

            # Otherwise, if the second-highest bit is 1 it's a field that does
            # alter the header table.
            literal_index = True if current & 0x40 else False

            # Otherwise, if the third-highest bit is 1 it's an encoding context
            # update.
            encoding_update = True if current & 0x20 else False

            if indexed:
                header, consumed = self._decode_indexed(
                    data_mem[current_index:]
                )
            elif literal_index:
                # It's a literal header that does affect the header table.
                header, consumed = self._decode_literal_index(
                    data_mem[current_index:]
                )
            elif encoding_update:
                # It's an update to the encoding context. These are forbidden
                # in a header block after any actual header.
                if headers:
                    raise HPACKDecodingError(
                        "Table size update not at the start of the block"
                    )
                consumed = self._update_encoding_context(
                    data_mem[current_index:]
                )
                header = None
            else:
                # It's a literal header that does not affect the header table.
                header, consumed = self._decode_literal_no_index(
                    data_mem[current_index:]
                )

            if header is not None:
                headers.append(header)
                inflated_size += table_entry_size(*header)

                if inflated_size > self.max_header_list_size:
                    raise OversizedHeaderListError(
                        "A header list larger than %d has been received" %
                        self.max_header_list_size
                    )

            current_index += consumed

        # Confirm that the table size is lower than the maximum. We do this
        # here to ensure that we catch when the max has been *shrunk* and the
        # remote peer hasn't actually done that.
        self._assert_valid_table_size()

        try:
            return [_unicode_if_needed(h, raw) for h in headers]
        except UnicodeDecodeError:
            raise HPACKDecodingError("Unable to decode headers as UTF-8.")

    def _assert_valid_table_size(self):
        """
        Check that the table size set by the encoder is lower than the maximum
        we expect to have.

        :raises InvalidTableSizeError: If the current header table size
                                       exceeds ``max_allowed_table_size``.
        """
        if self.header_table_size > self.max_allowed_table_size:
            raise InvalidTableSizeError(
                "Encoder did not shrink table size to within the max"
            )

    def _update_encoding_context(self, data):
        """
        Handles a dynamic table size update: decodes the new size (a 5-bit
        prefix integer), validates it against ``max_allowed_table_size`` and
        resizes the header table.

        :returns: The number of bytes consumed from ``data``.
        :raises InvalidTableSizeError: If the new size exceeds the maximum.
        """
        # We've been asked to resize the header table.
        new_size, consumed = decode_integer(data, 5)
        if new_size > self.max_allowed_table_size:
            raise InvalidTableSizeError(
                "Encoder exceeded max allowable table size"
            )
        self.header_table_size = new_size
        return consumed

    def _decode_indexed(self, data):
        """
        Decodes a header represented using the indexed representation (a
        7-bit prefix integer naming a header table entry).

        :returns: A ``(header, consumed)`` tuple.
        """
        index, consumed = decode_integer(data, 7)
        header = HeaderTuple(*self.header_table.get_by_index(index))
        log.debug("Decoded %s, consumed %d", header, consumed)
        return header, consumed

    def _decode_literal_no_index(self, data):
        """
        Decodes a literal header that must not be added to the header table.
        """
        return self._decode_literal(data, False)

    def _decode_literal_index(self, data):
        """
        Decodes a literal header that is added to the header table.
        """
        return self._decode_literal(data, True)

    def _decode_literal(self, data, should_index):
        """
        Decodes a header represented with a literal value, whose name is
        either an index into the header table or itself a literal.

        :param data: The remaining header block, starting at this field.
        :param should_index: ``True`` for the incremental-indexing form
            (6-bit name index prefix; the header is added to the header
            table), ``False`` for the forms that do not alter the table
            (4-bit name index prefix).
        :returns: A ``(header, consumed)`` tuple.
        """
        total_consumed = 0

        # When should_index is true, if the low six bits of the first byte are
        # nonzero, the header name is indexed.
        # When should_index is false, if the low four bits of the first byte
        # are nonzero the header name is indexed.
        if should_index:
            indexed_name = to_byte(data[0]) & 0x3F
            name_len = 6
            not_indexable = False
        else:
            high_byte = to_byte(data[0])
            indexed_name = high_byte & 0x0F
            name_len = 4
            # The 0x10 bit marks the header as never-indexed.
            not_indexable = bool(high_byte & 0x10)

        if indexed_name:
            # Indexed header name.
            index, consumed = decode_integer(data, name_len)
            name = self.header_table.get_by_index(index)[0]

            total_consumed = consumed
            length = 0
        else:
            # Literal header name. The first byte was consumed, so we need to
            # move forward.
            data = data[1:]

            length, consumed = decode_integer(data, 7)
            name = data[consumed:consumed + length]
            if len(name) != length:
                raise HPACKDecodingError("Truncated header block")

            if to_byte(data[0]) & 0x80:
                name = decode_huffman(name)
            total_consumed = consumed + length + 1  # Since we moved forward 1.

        data = data[consumed + length:]

        # The header value is definitely length-based.
        length, consumed = decode_integer(data, 7)
        value = data[consumed:consumed + length]
        if len(value) != length:
            raise HPACKDecodingError("Truncated header block")

        if to_byte(data[0]) & 0x80:
            value = decode_huffman(value)

        # Updated the total consumed length.
        total_consumed += length + consumed

        # Non-Huffman literals are still memoryview slices of the caller's
        # buffer. Copy them into real bytestrings before they go into a header
        # tuple or the header table: a stored view keeps the whole header block
        # alive for as long as the table entry lives, and if the caller's
        # buffer is mutable (e.g. a reused bytearray) the stored header would
        # change underneath us.
        name = to_bytes(name)
        value = to_bytes(value)

        # If we have been told never to index the header field, encode that in
        # the tuple we use.
        if not_indexable:
            header = NeverIndexedHeaderTuple(name, value)
        else:
            header = HeaderTuple(name, value)

        # If we've been asked to index this, add it to the header table.
        if should_index:
            self.header_table.add(name, value)

        log.debug(
            "Decoded %s, total consumed %d bytes, indexed %s",
            header,
            total_consumed,
            should_index
        )

        return header, total_consumed
630