1 use std::cmp; 2 /// Length-limited Huffman Codes 3 /// 4 use std::io; 5 6 use bit; 7 8 const MAX_BITWIDTH: u8 = 15; 9 10 #[derive(Debug, Clone, PartialEq, Eq)] 11 pub struct Code { 12 pub width: u8, 13 pub bits: u16, 14 } 15 impl Code { new(width: u8, bits: u16) -> Self16 pub fn new(width: u8, bits: u16) -> Self { 17 debug_assert!(width <= MAX_BITWIDTH); 18 Code { width, bits } 19 } inverse_endian(&self) -> Self20 fn inverse_endian(&self) -> Self { 21 let mut f = self.bits; 22 let mut t = 0; 23 for _ in 0..self.width { 24 t <<= 1; 25 t |= f & 1; 26 f >>= 1; 27 } 28 Code::new(self.width, t) 29 } 30 } 31 32 pub trait Builder: Sized { 33 type Instance; set_mapping(&mut self, symbol: u16, code: Code) -> io::Result<()>34 fn set_mapping(&mut self, symbol: u16, code: Code) -> io::Result<()>; finish(self) -> Self::Instance35 fn finish(self) -> Self::Instance; restore_canonical_huffman_codes(mut self, bitwidthes: &[u8]) -> io::Result<Self::Instance>36 fn restore_canonical_huffman_codes(mut self, bitwidthes: &[u8]) -> io::Result<Self::Instance> { 37 debug_assert!(!bitwidthes.is_empty()); 38 39 let mut symbols = bitwidthes 40 .iter() 41 .enumerate() 42 .filter(|&(_, &code_bitwidth)| code_bitwidth > 0) 43 .map(|(symbol, &code_bitwidth)| (symbol as u16, code_bitwidth)) 44 .collect::<Vec<_>>(); 45 symbols.sort_by_key(|x| x.1); 46 47 let mut code = 0; 48 let mut prev_width = 0; 49 for (symbol, bitwidth) in symbols { 50 code <<= bitwidth - prev_width; 51 self.set_mapping(symbol, Code::new(bitwidth, code))?; 52 code += 1; 53 prev_width = bitwidth; 54 } 55 Ok(self.finish()) 56 } 57 } 58 59 pub struct DecoderBuilder { 60 table: Vec<u16>, 61 eob_symbol: Option<u16>, 62 eob_bitwidth: u8, 63 max_bitwidth: u8, 64 } 65 impl DecoderBuilder { new(max_bitwidth: u8, eob_symbol: Option<u16>) -> Self66 pub fn new(max_bitwidth: u8, eob_symbol: Option<u16>) -> Self { 67 debug_assert!(max_bitwidth <= MAX_BITWIDTH); 68 DecoderBuilder { 69 table: vec![u16::from(MAX_BITWIDTH) + 1; 1 << max_bitwidth], 70 eob_symbol, 71 eob_bitwidth: max_bitwidth, 72 max_bitwidth, 73 } 74 } from_bitwidthes(bitwidthes: &[u8], eob_symbol: Option<u16>) -> io::Result<Decoder>75 pub fn from_bitwidthes(bitwidthes: &[u8], eob_symbol: Option<u16>) -> io::Result<Decoder> { 76 let builder = Self::new(bitwidthes.iter().cloned().max().unwrap_or(0), eob_symbol); 77 builder.restore_canonical_huffman_codes(bitwidthes) 78 } 79 } 80 impl Builder for DecoderBuilder { 81 type Instance = Decoder; set_mapping(&mut self, symbol: u16, code: Code) -> io::Result<()>82 fn set_mapping(&mut self, symbol: u16, code: Code) -> io::Result<()> { 83 debug_assert!(code.width <= self.max_bitwidth); 84 if Some(symbol) == self.eob_symbol { 85 self.eob_bitwidth = code.width; 86 } 87 88 // `bitwidth` encoded `to` value 89 let value = (symbol << 5) | u16::from(code.width); 90 91 // Sets the mapping to all possible indices 92 let code_be = code.inverse_endian(); 93 for padding in 0..(1 << (self.max_bitwidth - code.width)) { 94 let i = ((padding << code.width) | code_be.bits) as usize; 95 if self.table[i] != u16::from(MAX_BITWIDTH) + 1 { 96 let message = format!( 97 "Bit region conflict: i={}, old_value={}, new_value={}, symbol={}, code={:?}", 98 i, self.table[i], value, symbol, code 99 ); 100 return Err(io::Error::new(io::ErrorKind::InvalidData, message)); 101 } 102 self.table[i] = value; 103 } 104 Ok(()) 105 } finish(self) -> Self::Instance106 fn finish(self) -> Self::Instance { 107 Decoder { 108 table: self.table, 109 eob_bitwidth: self.eob_bitwidth, 110 max_bitwidth: self.max_bitwidth, 111 } 112 } 113 } 114 115 #[derive(Debug)] 116 pub struct Decoder { 117 table: Vec<u16>, 118 eob_bitwidth: u8, 119 max_bitwidth: u8, 120 } 121 impl Decoder { 122 #[inline(always)] decode<R>(&self, reader: &mut bit::BitReader<R>) -> io::Result<u16> where R: io::Read,123 pub fn decode<R>(&self, reader: &mut bit::BitReader<R>) -> io::Result<u16> 124 where 125 R: io::Read, 126 { 127 let v = self.decode_unchecked(reader); 128 reader.check_last_error()?; 129 Ok(v) 130 } 131 132 #[inline(always)] decode_unchecked<R>(&self, reader: &mut bit::BitReader<R>) -> u16 where R: io::Read,133 pub fn decode_unchecked<R>(&self, reader: &mut bit::BitReader<R>) -> u16 134 where 135 R: io::Read, 136 { 137 let code = reader.peek_bits_unchecked(self.eob_bitwidth); 138 let mut value = self.table[code as usize]; 139 let mut bitwidth = (value & 0b1_1111) as u8; 140 if bitwidth > self.eob_bitwidth { 141 let code = reader.peek_bits_unchecked(self.max_bitwidth); 142 value = self.table[code as usize]; 143 bitwidth = (value & 0b1_1111) as u8; 144 if bitwidth > self.max_bitwidth { 145 reader.set_last_error(invalid_data_error!("Invalid huffman coded stream")); 146 } 147 } 148 reader.skip_bits(bitwidth as u8); 149 value >> 5 150 } 151 } 152 153 #[derive(Debug)] 154 pub struct EncoderBuilder { 155 table: Vec<Code>, 156 } 157 impl EncoderBuilder { new(symbol_count: usize) -> Self158 pub fn new(symbol_count: usize) -> Self { 159 EncoderBuilder { 160 table: vec![Code::new(0, 0); symbol_count], 161 } 162 } from_bitwidthes(bitwidthes: &[u8]) -> io::Result<Encoder>163 pub fn from_bitwidthes(bitwidthes: &[u8]) -> io::Result<Encoder> { 164 let symbol_count = bitwidthes 165 .iter() 166 .enumerate() 167 .filter(|e| *e.1 > 0) 168 .last() 169 .map_or(0, |e| e.0) 170 + 1; 171 let builder = Self::new(symbol_count); 172 builder.restore_canonical_huffman_codes(bitwidthes) 173 } from_frequencies(symbol_frequencies: &[usize], max_bitwidth: u8) -> io::Result<Encoder>174 pub fn from_frequencies(symbol_frequencies: &[usize], max_bitwidth: u8) -> io::Result<Encoder> { 175 let max_bitwidth = cmp::min( 176 max_bitwidth, 177 ordinary_huffman_codes::calc_optimal_max_bitwidth(symbol_frequencies), 178 ); 179 let code_bitwidthes = length_limited_huffman_codes::calc(max_bitwidth, symbol_frequencies); 180 Self::from_bitwidthes(&code_bitwidthes) 181 } 182 } 183 impl Builder for EncoderBuilder { 184 type Instance = Encoder; set_mapping(&mut self, symbol: u16, code: Code) -> io::Result<()>185 fn set_mapping(&mut self, symbol: u16, code: Code) -> io::Result<()> { 186 debug_assert_eq!(self.table[symbol as usize], Code::new(0, 0)); 187 self.table[symbol as usize] = code.inverse_endian(); 188 Ok(()) 189 } finish(self) -> Self::Instance190 fn finish(self) -> Self::Instance { 191 Encoder { table: self.table } 192 } 193 } 194 195 #[derive(Debug, Clone)] 196 pub struct Encoder { 197 table: Vec<Code>, 198 } 199 impl Encoder { 200 #[inline(always)] encode<W>(&self, writer: &mut bit::BitWriter<W>, symbol: u16) -> io::Result<()> where W: io::Write,201 pub fn encode<W>(&self, writer: &mut bit::BitWriter<W>, symbol: u16) -> io::Result<()> 202 where 203 W: io::Write, 204 { 205 let code = self.lookup(symbol); 206 debug_assert_ne!(code, Code::new(0, 0)); 207 writer.write_bits(code.width, code.bits) 208 } 209 #[inline(always)] lookup(&self, symbol: u16) -> Code210 pub fn lookup(&self, symbol: u16) -> Code { 211 debug_assert!( 212 symbol < self.table.len() as u16, 213 "symbol:{}, table:{}", 214 symbol, 215 self.table.len() 216 ); 217 self.table[symbol as usize].clone() 218 } used_max_symbol(&self) -> Option<u16>219 pub fn used_max_symbol(&self) -> Option<u16> { 220 self.table 221 .iter() 222 .rev() 223 .position(|x| x.width > 0) 224 .map(|trailing_zeros| (self.table.len() - 1 - trailing_zeros) as u16) 225 } 226 } 227 228 #[allow(dead_code)] 229 mod ordinary_huffman_codes { 230 use std::cmp; 231 use std::collections::BinaryHeap; 232 calc_optimal_max_bitwidth(frequencies: &[usize]) -> u8233 pub fn calc_optimal_max_bitwidth(frequencies: &[usize]) -> u8 { 234 let mut heap = BinaryHeap::new(); 235 for &freq in frequencies.iter().filter(|&&f| f > 0) { 236 let weight = -(freq as isize); 237 heap.push((weight, 0 as u8)); 238 } 239 while heap.len() > 1 { 240 let (weight1, width1) = heap.pop().unwrap(); 241 let (weight2, width2) = heap.pop().unwrap(); 242 heap.push((weight1 + weight2, 1 + cmp::max(width1, width2))); 243 } 244 let max_bitwidth = heap.pop().map_or(0, |x| x.1); 245 cmp::max(1, max_bitwidth) 246 } 247 } 248 mod length_limited_huffman_codes { 249 use std::mem; 250 251 #[derive(Debug, Clone)] 252 struct Node { 253 symbols: Vec<u16>, 254 weight: usize, 255 } 256 impl Node { empty() -> Self257 pub fn empty() -> Self { 258 Node { 259 symbols: vec![], 260 weight: 0, 261 } 262 } single(symbol: u16, weight: usize) -> Self263 pub fn single(symbol: u16, weight: usize) -> Self { 264 Node { 265 symbols: vec![symbol], 266 weight, 267 } 268 } merge(&mut self, other: Self)269 pub fn merge(&mut self, other: Self) { 270 self.weight += other.weight; 271 self.symbols.extend(other.symbols); 272 } 273 } 274 275 /// Reference: [A Fast Algorithm for Optimal Length-Limited Huffman Codes][LenLimHuff.pdf] 276 /// 277 /// [LenLimHuff.pdf]: https://www.ics.uci.edu/~dan/pubs/LenLimHuff.pdf calc(max_bitwidth: u8, frequencies: &[usize]) -> Vec<u8>278 pub fn calc(max_bitwidth: u8, frequencies: &[usize]) -> Vec<u8> { 279 // NOTE: unoptimized implementation 280 let mut source = frequencies 281 .iter() 282 .enumerate() 283 .filter(|&(_, &f)| f > 0) 284 .map(|(symbol, &weight)| Node::single(symbol as u16, weight)) 285 .collect::<Vec<_>>(); 286 source.sort_by_key(|o| o.weight); 287 288 let weighted = 289 (0..max_bitwidth - 1).fold(source.clone(), |w, _| merge(package(w), source.clone())); 290 291 let mut code_bitwidthes = vec![0; frequencies.len()]; 292 for symbol in package(weighted) 293 .into_iter() 294 .flat_map(|n| n.symbols.into_iter()) 295 { 296 code_bitwidthes[symbol as usize] += 1; 297 } 298 code_bitwidthes 299 } merge(x: Vec<Node>, y: Vec<Node>) -> Vec<Node>300 fn merge(x: Vec<Node>, y: Vec<Node>) -> Vec<Node> { 301 let mut z = Vec::with_capacity(x.len() + y.len()); 302 let mut x = x.into_iter().peekable(); 303 let mut y = y.into_iter().peekable(); 304 loop { 305 let x_weight = x.peek().map(|s| s.weight); 306 let y_weight = y.peek().map(|s| s.weight); 307 if x_weight.is_none() { 308 z.extend(y); 309 break; 310 } else if y_weight.is_none() { 311 z.extend(x); 312 break; 313 } else if x_weight < y_weight { 314 z.push(x.next().unwrap()); 315 } else { 316 z.push(y.next().unwrap()); 317 } 318 } 319 z 320 } package(mut nodes: Vec<Node>) -> Vec<Node>321 fn package(mut nodes: Vec<Node>) -> Vec<Node> { 322 if nodes.len() >= 2 { 323 let new_len = nodes.len() / 2; 324 325 for i in 0..new_len { 326 nodes[i] = mem::replace(&mut nodes[i * 2], Node::empty()); 327 let other = mem::replace(&mut nodes[i * 2 + 1], Node::empty()); 328 nodes[i].merge(other); 329 } 330 nodes.truncate(new_len); 331 } 332 nodes 333 } 334 } 335 336 #[cfg(test)] 337 mod test { 338 #[test] it_works()339 fn it_works() {} 340 } 341