1# -*- coding: utf-8 -*-
2"""
3hpack/huffman_decoder
4~~~~~~~~~~~~~~~~~~~~~
5
6An implementation of a bitwise prefix tree specially built for decoding
7Huffman-coded content where we already know the Huffman table.
8"""
9from .compat import to_byte, decode_hex
10from .exceptions import HPACKDecodingError
11
12def _pad_binary(bin_str, req_len=8):
13    """
14    Given a binary string (returned by bin()), pad it to a full byte length.
15    """
16    bin_str = bin_str[2:]  # Strip the 0b prefix
17    return max(0, req_len - len(bin_str)) * '0' + bin_str
18
def _hex_to_bin_str(hex_string):
    """
    Given a Python bytestring, returns a string of ``'0'`` and ``'1'``
    characters spelling out the bits of those bytes.

    Each element of ``hex_string`` is passed through ``to_byte`` and
    ``bin()``, zero-padded to eight digits, and the per-byte digit groups are
    concatenated in input order.
    """
    bit_groups = [_pad_binary(bin(to_byte(octet))) for octet in hex_string]
    return "".join(bit_groups)
28
29
class HuffmanDecoder(object):
    """
    Decodes a Huffman-coded bytestream according to the Huffman table laid out
    in the HPACK specification.

    The table is stored as a bitwise prefix tree: every edge is labelled with
    the character ``'0'`` or ``'1'``, and the node reached by following a
    complete code holds the symbol that code stands for.
    """
    class _Node(object):
        """
        A single tree node.
        """
        def __init__(self, data):
            # Decoded symbol for the code ending here, or None if no code
            # ends at this node.
            self.data = data
            # Child nodes, keyed by the bit character ('0' or '1').
            self.mapping = {}

    def __init__(self, huffman_code_list, huffman_code_list_lengths):
        """
        Build the tree from two parallel sequences: the codes, and the bit
        length of each code. The position of a code in the sequences is the
        symbol it decodes to.
        """
        self.root = self._Node(None)
        code_table = zip(huffman_code_list, huffman_code_list_lengths)
        for symbol, (code, length) in enumerate(code_table):
            self._insert(code, length, symbol)

    def _insert(self, hex_number, hex_length, letter):
        """
        Inserts a Huffman code point into the tree.

        ``hex_number`` is the code as an integer, ``hex_length`` the number of
        bits it occupies (shorter codes are zero-padded on the left up to this
        length), and ``letter`` the symbol stored at the end of the path.
        """
        bits = _pad_binary(bin(hex_number), hex_length)
        node = self.root
        for bit in bits:
            child = node.mapping.get(bit)
            if child is None:
                # First code to pass through here: grow the tree.
                child = node.mapping[bit] = self._Node(None)
            node = child
        node.data = letter

    def decode(self, encoded_string):
        """
        Decode the given Huffman coded string.

        Returns the decoded symbols as ``bytes``. Raises
        ``HPACKDecodingError`` if the bits lead off the edge of the tree.
        Bits left over after the last complete code are ignored.
        """
        bits = _hex_to_bin_str(encoded_string)
        output = bytearray()
        node = self.root

        try:
            for bit in bits:
                node = node.mapping[bit]
                symbol = node.data
                if symbol is None:
                    # Still part-way through a code; keep descending.
                    continue

                # Symbol 256 is EOS: stop and treat whatever follows as
                # padding.
                # NOTE(review): RFC 7541 section 5.2 appears to require that
                # an EOS inside the string be a decoding error -- confirm
                # whether this leniency is intended.
                if symbol == 256:
                    break

                output.append(symbol)
                node = self.root
        except KeyError:
            # The bit sequence does not correspond to any path in our tree,
            # so the input cannot be valid. Report it with a useful exception.
            raise HPACKDecodingError("Invalid Huffman-coded string received.")
        return bytes(output)
80
81
class HuffmanEncoder(object):
    """
    Encodes a string according to the Huffman encoding table defined in the
    HPACK specification.

    ``huffman_code_list`` holds the integer code for each symbol and
    ``huffman_code_list_lengths`` the number of bits in that code; both are
    indexed by the value ``to_byte`` yields for an input element.
    """
    def __init__(self, huffman_code_list, huffman_code_list_lengths):
        self.huffman_code_list = huffman_code_list
        self.huffman_code_list_lengths = huffman_code_list_lengths

    def encode(self, bytes_to_encode):
        """
        Given a string of bytes, encodes them according to the HPACK Huffman
        specification.

        Returns the encoded bytestring, padded with 1 bits up to a whole
        number of octets. Empty input produces ``b''``.
        """
        # If handed the empty string, just immediately return.
        if not bytes_to_encode:
            return b''

        final_num = 0
        final_int_len = 0

        # Turn each byte into its huffman code. These codes aren't necessarily
        # octet aligned, so keep track of how far through an octet we are. To
        # handle this cleanly, just use a single giant integer.
        for char in bytes_to_encode:
            byte = to_byte(char)
            bin_int_len = self.huffman_code_list_lengths[byte]
            # Keep exactly bin_int_len bits of the code. Letting an extra
            # high bit through would overwrite the last bit of the previous
            # code once the two are OR-ed together below.
            bin_int = self.huffman_code_list[byte] & ((1 << bin_int_len) - 1)
            final_num = (final_num << bin_int_len) | bin_int
            final_int_len += bin_int_len

        # Pad out to an octet with ones.
        bits_to_be_padded = (8 - (final_int_len % 8)) % 8
        final_num <<= bits_to_be_padded
        final_num |= (1 << bits_to_be_padded) - 1

        # Convert the number to hex and strip off the leading '0x' and the
        # trailing 'L', if present.
        hex_digits = hex(final_num)[2:].rstrip('L')

        # Hex decoding needs whole bytes: if the digit count is odd, prepend
        # a zero.
        if len(hex_digits) % 2 != 0:
            hex_digits = '0' + hex_digits

        # Leading zero bits vanish when the integer is rendered as hex, so
        # restore them: there must be two digits for every output byte.
        total_bytes = (final_int_len + bits_to_be_padded) // 8
        hex_digits = hex_digits.zfill(total_bytes * 2)

        return decode_hex(hex_digits)
137