# -*- coding: utf-8 -*-
"""
hpack/huffman_decoder
~~~~~~~~~~~~~~~~~~~~~

An implementation of a bitwise prefix tree specially built for decoding
Huffman-coded content where we already know the Huffman table.
"""
from .compat import to_byte, decode_hex
from .exceptions import HPACKDecodingError


def _pad_binary(bin_str, req_len=8):
    """
    Given a binary string (returned by bin()), strip its two-character
    prefix and left-pad the digits with zeroes to at least ``req_len``
    characters.

    Strings that are already ``req_len`` digits or longer are returned
    unpadded (they are never truncated).
    """
    # bin() output starts with '0b'; drop it and keep only the digits.
    return bin_str[2:].zfill(req_len)


def _hex_to_bin_str(hex_string):
    """
    Given a Python bytestring, returns a text string of '0' and '1'
    characters: eight binary digits per input byte, most significant bit
    first, concatenated in input order.
    """
    # to_byte() comes from .compat; its result is passed straight to bin(),
    # so it must be the integer value of each element.
    return "".join(_pad_binary(bin(to_byte(c))) for c in hex_string)


class HuffmanDecoder(object):
    """
    Decodes a Huffman-coded bytestream according to the Huffman table laid out
    in the HPACK specification.

    The table is stored as a prefix tree whose edges are the characters '0'
    and '1' and whose leaves carry the index of the decoded symbol.
    """
    class _Node(object):
        """
        A single prefix-tree node.
        """
        def __init__(self, data):
            # Symbol index for a complete code, or None for interior nodes.
            self.data = data
            # Child nodes keyed by the bit character '0' or '1'.
            self.mapping = {}

    def __init__(self, huffman_code_list, huffman_code_list_lengths):
        """
        Build the prefix tree.

        ``huffman_code_list`` holds the integer code for each symbol and
        ``huffman_code_list_lengths`` the bit length of each code; the
        position in the lists is the symbol value. The two lists are walked
        in lockstep, so extra entries in the longer one are ignored.
        """
        self.root = self._Node(None)
        codes_and_lengths = zip(huffman_code_list, huffman_code_list_lengths)
        for symbol, (code, code_length) in enumerate(codes_and_lengths):
            self._insert(code, code_length, symbol)

    def _insert(self, hex_number, hex_length, letter):
        """
        Inserts a Huffman code point into the tree.

        ``hex_number`` is the integer code, ``hex_length`` its length in
        bits, and ``letter`` the symbol index stored at the end of the path.
        """
        bits = _pad_binary(bin(hex_number), hex_length)
        node = self.root
        for bit in bits:
            # Create the path lazily, one node per bit.
            if bit not in node.mapping:
                node.mapping[bit] = self._Node(None)
            node = node.mapping[bit]
        node.data = letter

    def decode(self, encoded_string):
        """
        Decode the given Huffman coded string.

        Returns the decoded symbols as ``bytes``. Raises
        ``HPACKDecodingError`` if the bit sequence leaves the prefix tree.

        NOTE: trailing bits that do not complete a symbol are silently
        ignored; they are not checked for being valid padding.
        """
        bits = _hex_to_bin_str(encoded_string)
        node = self.root
        decoded_message = bytearray()

        try:
            for bit in bits:
                node = node.mapping[bit]
                if node.data is None:
                    # Still inside a code: keep consuming bits.
                    continue

                # If we get EOS, everything else is padding.
                if node.data == 256:
                    break

                decoded_message.append(node.data)
                node = self.root
        except KeyError:
            # We have a Huffman-coded string that doesn't match our trie. This
            # is pretty bad: raise a useful exception.
            raise HPACKDecodingError("Invalid Huffman-coded string received.")
        return bytes(decoded_message)


class HuffmanEncoder(object):
    """
    Encodes a string according to the Huffman encoding table defined in the
    HPACK specification.
    """
    def __init__(self, huffman_code_list, huffman_code_list_lengths):
        """
        Store the code table: ``huffman_code_list[i]`` is the integer code
        for symbol ``i`` and ``huffman_code_list_lengths[i]`` its bit length.
        """
        self.huffman_code_list = huffman_code_list
        self.huffman_code_list_lengths = huffman_code_list_lengths

    def encode(self, bytes_to_encode):
        """
        Given a string of bytes, encodes them according to the HPACK Huffman
        specification.

        Returns the encoded bytestring, padded with one-bits up to a whole
        number of octets. Empty input yields ``b''``.
        """
        # If handed the empty string, just immediately return.
        if not bytes_to_encode:
            return b''

        final_num = 0
        final_int_len = 0

        # Turn each byte into its huffman code. These codes aren't necessarily
        # octet aligned, so keep track of how far through an octet we are. To
        # handle this cleanly, just use a single giant integer.
        for char in bytes_to_encode:
            byte = to_byte(char)
            bin_int_len = self.huffman_code_list_lengths[byte]
            # Keep exactly ``bin_int_len`` low bits of the table entry. A
            # wider mask would let a stray high bit overwrite the last bit
            # of the code appended before this one.
            bin_int = self.huffman_code_list[byte] & ((1 << bin_int_len) - 1)
            final_num <<= bin_int_len
            final_num |= bin_int
            final_int_len += bin_int_len

        # Pad out to an octet with ones.
        bits_to_be_padded = (8 - (final_int_len % 8)) % 8
        final_num <<= bits_to_be_padded
        final_num |= (1 << bits_to_be_padded) - 1

        # Render as hex with two digits per output byte. Zero-padding to the
        # full width restores any leading zero bits the integer dropped.
        total_bytes = (final_int_len + bits_to_be_padded) // 8
        expected_digits = total_bytes * 2
        hex_digits = '%0*x' % (expected_digits, final_num)

        return decode_hex(hex_digits)