1 use std::cmp;
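// Length-limited Huffman codes: canonical code reconstruction from code lengths,
// table-driven decoding, and encoder construction from symbol frequencies.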
4 use std::io;
5 
6 use bit;
7 
/// Widest Huffman code, in bits, that this module supports.
const MAX_BITWIDTH: u8 = 15;

/// A single Huffman code: `width` bits taken from the low end of `bits`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Code {
    /// Number of bits in the code.
    pub width: u8,
    /// The code value; only the low `width` bits are used by this module.
    pub bits: u16,
}
impl Code {
    /// Creates a code of `width` bits whose value is `bits`.
    ///
    /// Debug builds assert that `width` does not exceed `MAX_BITWIDTH`.
    pub fn new(width: u8, bits: u16) -> Self {
        debug_assert!(width <= MAX_BITWIDTH);
        Code { width, bits }
    }

    /// Returns a code of the same width whose low `width` bits are in reverse order.
    ///
    /// Bits of `self.bits` above position `width` are discarded.
    fn inverse_endian(&self) -> Self {
        // Peel bits off the bottom of `rest` and push them onto the bottom of `acc`.
        let (_, reversed) = (0..self.width).fold((self.bits, 0u16), |(rest, acc), _| {
            (rest >> 1, (acc << 1) | (rest & 1))
        });
        Code::new(self.width, reversed)
    }
}
31 
32 pub trait Builder: Sized {
33     type Instance;
set_mapping(&mut self, symbol: u16, code: Code) -> io::Result<()>34     fn set_mapping(&mut self, symbol: u16, code: Code) -> io::Result<()>;
finish(self) -> Self::Instance35     fn finish(self) -> Self::Instance;
restore_canonical_huffman_codes(mut self, bitwidthes: &[u8]) -> io::Result<Self::Instance>36     fn restore_canonical_huffman_codes(mut self, bitwidthes: &[u8]) -> io::Result<Self::Instance> {
37         debug_assert!(!bitwidthes.is_empty());
38 
39         let mut symbols = bitwidthes
40             .iter()
41             .enumerate()
42             .filter(|&(_, &code_bitwidth)| code_bitwidth > 0)
43             .map(|(symbol, &code_bitwidth)| (symbol as u16, code_bitwidth))
44             .collect::<Vec<_>>();
45         symbols.sort_by_key(|x| x.1);
46 
47         let mut code = 0;
48         let mut prev_width = 0;
49         for (symbol, bitwidth) in symbols {
50             code <<= bitwidth - prev_width;
51             self.set_mapping(symbol, Code::new(bitwidth, code))?;
52             code += 1;
53             prev_width = bitwidth;
54         }
55         Ok(self.finish())
56     }
57 }
58 
59 pub struct DecoderBuilder {
60     table: Vec<u16>,
61     eob_symbol: Option<u16>,
62     eob_bitwidth: u8,
63     max_bitwidth: u8,
64 }
65 impl DecoderBuilder {
new(max_bitwidth: u8, eob_symbol: Option<u16>) -> Self66     pub fn new(max_bitwidth: u8, eob_symbol: Option<u16>) -> Self {
67         debug_assert!(max_bitwidth <= MAX_BITWIDTH);
68         DecoderBuilder {
69             table: vec![u16::from(MAX_BITWIDTH) + 1; 1 << max_bitwidth],
70             eob_symbol,
71             eob_bitwidth: max_bitwidth,
72             max_bitwidth,
73         }
74     }
from_bitwidthes(bitwidthes: &[u8], eob_symbol: Option<u16>) -> io::Result<Decoder>75     pub fn from_bitwidthes(bitwidthes: &[u8], eob_symbol: Option<u16>) -> io::Result<Decoder> {
76         let builder = Self::new(bitwidthes.iter().cloned().max().unwrap_or(0), eob_symbol);
77         builder.restore_canonical_huffman_codes(bitwidthes)
78     }
79 }
80 impl Builder for DecoderBuilder {
81     type Instance = Decoder;
set_mapping(&mut self, symbol: u16, code: Code) -> io::Result<()>82     fn set_mapping(&mut self, symbol: u16, code: Code) -> io::Result<()> {
83         debug_assert!(code.width <= self.max_bitwidth);
84         if Some(symbol) == self.eob_symbol {
85             self.eob_bitwidth = code.width;
86         }
87 
88         // `bitwidth` encoded `to` value
89         let value = (symbol << 5) | u16::from(code.width);
90 
91         // Sets the mapping to all possible indices
92         let code_be = code.inverse_endian();
93         for padding in 0..(1 << (self.max_bitwidth - code.width)) {
94             let i = ((padding << code.width) | code_be.bits) as usize;
95             if self.table[i] != u16::from(MAX_BITWIDTH) + 1 {
96                 let message = format!(
97                     "Bit region conflict: i={}, old_value={}, new_value={}, symbol={}, code={:?}",
98                     i, self.table[i], value, symbol, code
99                 );
100                 return Err(io::Error::new(io::ErrorKind::InvalidData, message));
101             }
102             self.table[i] = value;
103         }
104         Ok(())
105     }
finish(self) -> Self::Instance106     fn finish(self) -> Self::Instance {
107         Decoder {
108             table: self.table,
109             eob_bitwidth: self.eob_bitwidth,
110             max_bitwidth: self.max_bitwidth,
111         }
112     }
113 }
114 
115 #[derive(Debug)]
116 pub struct Decoder {
117     table: Vec<u16>,
118     eob_bitwidth: u8,
119     max_bitwidth: u8,
120 }
121 impl Decoder {
122     #[inline(always)]
decode<R>(&self, reader: &mut bit::BitReader<R>) -> io::Result<u16> where R: io::Read,123     pub fn decode<R>(&self, reader: &mut bit::BitReader<R>) -> io::Result<u16>
124     where
125         R: io::Read,
126     {
127         let v = self.decode_unchecked(reader);
128         reader.check_last_error()?;
129         Ok(v)
130     }
131 
132     #[inline(always)]
decode_unchecked<R>(&self, reader: &mut bit::BitReader<R>) -> u16 where R: io::Read,133     pub fn decode_unchecked<R>(&self, reader: &mut bit::BitReader<R>) -> u16
134     where
135         R: io::Read,
136     {
137         let code = reader.peek_bits_unchecked(self.eob_bitwidth);
138         let mut value = self.table[code as usize];
139         let mut bitwidth = (value & 0b1_1111) as u8;
140         if bitwidth > self.eob_bitwidth {
141             let code = reader.peek_bits_unchecked(self.max_bitwidth);
142             value = self.table[code as usize];
143             bitwidth = (value & 0b1_1111) as u8;
144             if bitwidth > self.max_bitwidth {
145                 reader.set_last_error(invalid_data_error!("Invalid huffman coded stream"));
146             }
147         }
148         reader.skip_bits(bitwidth as u8);
149         value >> 5
150     }
151 }
152 
153 #[derive(Debug)]
154 pub struct EncoderBuilder {
155     table: Vec<Code>,
156 }
157 impl EncoderBuilder {
new(symbol_count: usize) -> Self158     pub fn new(symbol_count: usize) -> Self {
159         EncoderBuilder {
160             table: vec![Code::new(0, 0); symbol_count],
161         }
162     }
from_bitwidthes(bitwidthes: &[u8]) -> io::Result<Encoder>163     pub fn from_bitwidthes(bitwidthes: &[u8]) -> io::Result<Encoder> {
164         let symbol_count = bitwidthes
165             .iter()
166             .enumerate()
167             .filter(|e| *e.1 > 0)
168             .last()
169             .map_or(0, |e| e.0)
170             + 1;
171         let builder = Self::new(symbol_count);
172         builder.restore_canonical_huffman_codes(bitwidthes)
173     }
from_frequencies(symbol_frequencies: &[usize], max_bitwidth: u8) -> io::Result<Encoder>174     pub fn from_frequencies(symbol_frequencies: &[usize], max_bitwidth: u8) -> io::Result<Encoder> {
175         let max_bitwidth = cmp::min(
176             max_bitwidth,
177             ordinary_huffman_codes::calc_optimal_max_bitwidth(symbol_frequencies),
178         );
179         let code_bitwidthes = length_limited_huffman_codes::calc(max_bitwidth, symbol_frequencies);
180         Self::from_bitwidthes(&code_bitwidthes)
181     }
182 }
183 impl Builder for EncoderBuilder {
184     type Instance = Encoder;
set_mapping(&mut self, symbol: u16, code: Code) -> io::Result<()>185     fn set_mapping(&mut self, symbol: u16, code: Code) -> io::Result<()> {
186         debug_assert_eq!(self.table[symbol as usize], Code::new(0, 0));
187         self.table[symbol as usize] = code.inverse_endian();
188         Ok(())
189     }
finish(self) -> Self::Instance190     fn finish(self) -> Self::Instance {
191         Encoder { table: self.table }
192     }
193 }
194 
195 #[derive(Debug, Clone)]
196 pub struct Encoder {
197     table: Vec<Code>,
198 }
199 impl Encoder {
200     #[inline(always)]
encode<W>(&self, writer: &mut bit::BitWriter<W>, symbol: u16) -> io::Result<()> where W: io::Write,201     pub fn encode<W>(&self, writer: &mut bit::BitWriter<W>, symbol: u16) -> io::Result<()>
202     where
203         W: io::Write,
204     {
205         let code = self.lookup(symbol);
206         debug_assert_ne!(code, Code::new(0, 0));
207         writer.write_bits(code.width, code.bits)
208     }
209     #[inline(always)]
lookup(&self, symbol: u16) -> Code210     pub fn lookup(&self, symbol: u16) -> Code {
211         debug_assert!(
212             symbol < self.table.len() as u16,
213             "symbol:{}, table:{}",
214             symbol,
215             self.table.len()
216         );
217         self.table[symbol as usize].clone()
218     }
used_max_symbol(&self) -> Option<u16>219     pub fn used_max_symbol(&self) -> Option<u16> {
220         self.table
221             .iter()
222             .rev()
223             .position(|x| x.width > 0)
224             .map(|trailing_zeros| (self.table.len() - 1 - trailing_zeros) as u16)
225     }
226 }
227 
// NOTE(review): the reason for allowing dead code here is not evident from this file —
// confirm it is still needed.
#[allow(dead_code)]
mod ordinary_huffman_codes {
    use std::cmp;
    use std::collections::BinaryHeap;

    /// Returns the height of an ordinary (unlimited) Huffman tree built from `frequencies`.
    ///
    /// Zero frequencies are ignored. The result is at least 1, including when fewer than
    /// two symbols are used.
    pub fn calc_optimal_max_bitwidth(frequencies: &[usize]) -> u8 {
        // `Reverse` turns the max-heap into a min-heap on frequency, so the two lightest
        // subtrees are combined first. On equal frequency the taller subtree is popped
        // first, because the second tuple field is compared in max order.
        let mut heap = BinaryHeap::new();
        for &freq in frequencies.iter().filter(|&&f| f > 0) {
            heap.push((cmp::Reverse(freq), 0u8));
        }
        while heap.len() > 1 {
            let (cmp::Reverse(weight1), height1) = heap.pop().expect("heap has >= 2 items");
            let (cmp::Reverse(weight2), height2) = heap.pop().expect("heap has >= 2 items");
            heap.push((
                cmp::Reverse(weight1 + weight2),
                1 + cmp::max(height1, height2),
            ));
        }
        let max_bitwidth = heap.pop().map_or(0, |(_, height)| height);
        cmp::max(1, max_bitwidth)
    }
}
mod length_limited_huffman_codes {
    /// A package: the symbols it covers (with repetition) and their total weight.
    #[derive(Debug, Clone)]
    struct Node {
        symbols: Vec<u16>,
        weight: usize,
    }
    impl Node {
        /// A leaf holding a single `symbol` of the given `weight`.
        pub fn single(symbol: u16, weight: usize) -> Self {
            Node {
                symbols: vec![symbol],
                weight,
            }
        }

        /// Absorbs `other`, adding its weight and appending its symbols after ours.
        pub fn merge(&mut self, other: Self) {
            self.weight += other.weight;
            self.symbols.extend(other.symbols);
        }
    }

    /// Computes per-symbol code lengths, limited to `max_bitwidth`, using package-merge.
    ///
    /// The result is indexed like `frequencies`, and symbols with zero frequency get
    /// length 0. Each symbol's length is the number of times it appears in the final
    /// selection of packages. The lengths only form a valid prefix code when `max_bitwidth`
    /// leaves room for every used symbol. A `max_bitwidth` of 0 is treated like 1.
    ///
    /// Reference: [A Fast Algorithm for Optimal Length-Limited Huffman Codes][LenLimHuff.pdf]
    ///
    /// [LenLimHuff.pdf]: https://www.ics.uci.edu/~dan/pubs/LenLimHuff.pdf
    pub fn calc(max_bitwidth: u8, frequencies: &[usize]) -> Vec<u8> {
        // NOTE: unoptimized implementation
        let mut leaves = frequencies
            .iter()
            .enumerate()
            .filter(|&(_, &f)| f > 0)
            .map(|(symbol, &weight)| Node::single(symbol as u16, weight))
            .collect::<Vec<_>>();
        // Stable sort: equal weights keep ascending symbol order.
        leaves.sort_by_key(|node| node.weight);

        // `1..max_bitwidth` runs `max_bitwidth - 1` rounds without underflowing on 0.
        let mut weighted = leaves.clone();
        for _ in 1..max_bitwidth {
            weighted = merge(package(weighted), leaves.clone());
        }

        let mut code_bitwidthes = vec![0; frequencies.len()];
        for symbol in package(weighted)
            .into_iter()
            .flat_map(|node| node.symbols)
        {
            code_bitwidthes[symbol as usize] += 1;
        }
        code_bitwidthes
    }

    /// Merges two lists, each sorted by ascending weight, into one sorted list.
    ///
    /// On equal weights the node from `y` is taken first.
    fn merge(x: Vec<Node>, y: Vec<Node>) -> Vec<Node> {
        let mut merged = Vec::with_capacity(x.len() + y.len());
        let mut x = x.into_iter().peekable();
        let mut y = y.into_iter().peekable();
        loop {
            let x_weight = x.peek().map(|node| node.weight);
            let y_weight = y.peek().map(|node| node.weight);
            match (x_weight, y_weight) {
                (None, _) => {
                    merged.extend(y);
                    return merged;
                }
                (_, None) => {
                    merged.extend(x);
                    return merged;
                }
                (Some(wx), Some(wy)) => {
                    let next = if wx < wy { x.next() } else { y.next() };
                    merged.push(next.expect("peeked iterator is non-empty"));
                }
            }
        }
    }

    /// Combines adjacent pairs `(0, 1), (2, 3), ...` into single packages.
    ///
    /// A trailing unpaired node is dropped. Lists with fewer than two nodes are returned
    /// unchanged.
    fn package(nodes: Vec<Node>) -> Vec<Node> {
        if nodes.len() < 2 {
            return nodes;
        }
        let mut packaged = Vec::with_capacity(nodes.len() / 2);
        let mut rest = nodes.into_iter();
        while let (Some(mut first), Some(second)) = (rest.next(), rest.next()) {
            first.merge(second);
            packaged.push(first);
        }
        packaged
    }
}
335 
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn inverse_endian_reverses_only_the_low_width_bits() {
        assert_eq!(Code::new(3, 0b110).inverse_endian(), Code::new(3, 0b011));
        assert_eq!(Code::new(4, 0b1_0011).inverse_endian(), Code::new(4, 0b1100));
        assert_eq!(Code::new(0, 0).inverse_endian(), Code::new(0, 0));
    }

    #[test]
    fn encoder_gets_canonical_codes_from_bitwidthes() {
        // Canonical codes: 1 -> 0, 0 -> 10, 2 -> 110, 3 -> 111 (stored bit-reversed).
        let encoder = EncoderBuilder::from_bitwidthes(&[2, 1, 3, 3]).unwrap();
        assert_eq!(encoder.lookup(1), Code::new(1, 0b0));
        assert_eq!(encoder.lookup(0), Code::new(2, 0b01));
        assert_eq!(encoder.lookup(2), Code::new(3, 0b011));
        assert_eq!(encoder.lookup(3), Code::new(3, 0b111));
        assert_eq!(encoder.used_max_symbol(), Some(3));
    }

    #[test]
    fn decoder_rejects_oversubscribed_bitwidthes() {
        assert!(DecoderBuilder::from_bitwidthes(&[1, 1, 1], None).is_err());
        assert!(DecoderBuilder::from_bitwidthes(&[1, 1], None).is_ok());
    }

    #[test]
    fn optimal_and_length_limited_bitwidthes() {
        assert_eq!(ordinary_huffman_codes::calc_optimal_max_bitwidth(&[1, 2, 4, 8]), 3);
        assert_eq!(length_limited_huffman_codes::calc(3, &[1, 2, 4, 8]), vec![3, 3, 2, 1]);
        assert_eq!(length_limited_huffman_codes::calc(2, &[1, 2, 4, 8]), vec![2, 2, 2, 2]);
    }

    #[test]
    fn encoder_from_frequencies_uses_huffman_lengths() {
        let encoder = EncoderBuilder::from_frequencies(&[1, 2, 4, 8], MAX_BITWIDTH).unwrap();
        let widths: Vec<u8> = (0..4).map(|symbol| encoder.lookup(symbol).width).collect();
        assert_eq!(widths, vec![3, 3, 2, 1]);
    }
}
341