1 use crate::bit;
2 use crate::huffman;
3 use crate::huffman::Builder;
4 use crate::lz77;
5 use std::cmp;
6 use std::io;
7 use std::iter;
8 use std::ops::Range;
9 
/// Fixed Huffman code assignment for the literal/length alphabet.
///
/// Each entry is `(bit width, symbol range, code base)`: the `i`-th symbol of the range is
/// mapped to the code `code base + i`, which is `bit width` bits wide.
const FIXED_LITERAL_OR_LENGTH_CODE_TABLE: [(u8, Range<u16>, u16); 4] = [
    (8, 000..144, 0b0_0011_0000),
    (9, 144..256, 0b1_1001_0000),
    (7, 256..280, 0b0_0000_0000),
    (8, 280..288, 0b0_1100_0000),
];
16 
/// Order in which the bit widths of the 19 bitwidth-code symbols are written to, and read
/// from, the header of a dynamic Huffman block.
const BITWIDTH_CODE_ORDER: [usize; 19] = [
    16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15,
];
20 
/// Literal/length code that marks the end of a block.
const END_OF_BLOCK: u16 = 256;
22 
/// `(base length, number of extra bits)` for every length code.
///
/// The table is indexed by `length code - 257`; the decoded length is the base plus the
/// value of the extra bits.
const LENGTH_TABLE: [(u16, u8); 29] = [
    (3, 0),
    (4, 0),
    (5, 0),
    (6, 0),
    (7, 0),
    (8, 0),
    (9, 0),
    (10, 0),
    (11, 1),
    (13, 1),
    (15, 1),
    (17, 1),
    (19, 2),
    (23, 2),
    (27, 2),
    (31, 2),
    (35, 3),
    (43, 3),
    (51, 3),
    (59, 3),
    (67, 4),
    (83, 4),
    (99, 4),
    (115, 4),
    (131, 5),
    (163, 5),
    (195, 5),
    (227, 5),
    (258, 0),
];
54 
/// Largest number of distance codes that a dynamic Huffman block header may declare.
const MAX_DISTANCE_CODE_COUNT: usize = 30;
56 
/// `(base distance, number of extra bits)` for every distance code.
///
/// The table is indexed by the distance code itself; the decoded distance is the base plus
/// the value of the extra bits.
const DISTANCE_TABLE: [(u16, u8); 30] = [
    (1, 0),
    (2, 0),
    (3, 0),
    (4, 0),
    (5, 1),
    (7, 1),
    (9, 2),
    (13, 2),
    (17, 3),
    (25, 3),
    (33, 4),
    (49, 4),
    (65, 5),
    (97, 5),
    (129, 6),
    (193, 6),
    (257, 7),
    (385, 7),
    (513, 8),
    (769, 8),
    (1025, 9),
    (1537, 9),
    (2049, 10),
    (3073, 10),
    (4097, 11),
    (6145, 11),
    (8193, 12),
    (12_289, 12),
    (16_385, 13),
    (24_577, 13),
];
89 
/// A symbol of a compressed block: either the end-of-block marker or an LZ77 code.
#[derive(Debug, PartialEq, Eq)]
pub enum Symbol {
    /// Marks the end of the current block.
    EndOfBlock,
    /// An LZ77 code (a literal byte or a length/backward-distance pointer).
    Code(lz77::Code),
}
95 impl Symbol {
code(&self) -> u1696     pub fn code(&self) -> u16 {
97         match *self {
98             Symbol::Code(lz77::Code::Literal(b)) => u16::from(b),
99             Symbol::EndOfBlock => 256,
100             Symbol::Code(lz77::Code::Pointer { length, .. }) => match length {
101                 3..=10 => 257 + length - 3,
102                 11..=18 => 265 + (length - 11) / 2,
103                 19..=34 => 269 + (length - 19) / 4,
104                 35..=66 => 273 + (length - 35) / 8,
105                 67..=130 => 277 + (length - 67) / 16,
106                 131..=257 => 281 + (length - 131) / 32,
107                 258 => 285,
108                 _ => unreachable!(),
109             },
110         }
111     }
extra_lengh(&self) -> Option<(u8, u16)>112     pub fn extra_lengh(&self) -> Option<(u8, u16)> {
113         if let Symbol::Code(lz77::Code::Pointer { length, .. }) = *self {
114             match length {
115                 3..=10 | 258 => None,
116                 11..=18 => Some((1, (length - 11) % 2)),
117                 19..=34 => Some((2, (length - 19) % 4)),
118                 35..=66 => Some((3, (length - 35) % 8)),
119                 67..=130 => Some((4, (length - 67) % 16)),
120                 131..=257 => Some((5, (length - 131) % 32)),
121                 _ => unreachable!(),
122             }
123         } else {
124             None
125         }
126     }
distance(&self) -> Option<(u8, u8, u16)>127     pub fn distance(&self) -> Option<(u8, u8, u16)> {
128         if let Symbol::Code(lz77::Code::Pointer {
129             backward_distance: distance,
130             ..
131         }) = *self
132         {
133             if distance <= 4 {
134                 Some((distance as u8 - 1, 0, 0))
135             } else {
136                 let mut extra_bits = 1;
137                 let mut code = 4;
138                 let mut base = 4;
139                 while base * 2 < distance {
140                     extra_bits += 1;
141                     code += 2;
142                     base *= 2;
143                 }
144                 let half = base / 2;
145                 let delta = distance - base - 1;
146                 if distance <= base + half {
147                     Some((code, extra_bits, delta % half))
148                 } else {
149                     Some((code + 1, extra_bits, delta % half))
150                 }
151             }
152         } else {
153             None
154         }
155     }
156 }
157 impl From<lz77::Code> for Symbol {
from(code: lz77::Code) -> Self158     fn from(code: lz77::Code) -> Self {
159         Symbol::Code(code)
160     }
161 }
162 
/// Writes [`Symbol`]s to a bit stream using a pair of Huffman encoders.
#[derive(Debug)]
pub struct Encoder {
    /// Encoder for the literal/length alphabet (fed with `Symbol::code`).
    literal: huffman::Encoder,
    /// Encoder for the distance alphabet (fed with the code from `Symbol::distance`).
    distance: huffman::Encoder,
}
168 impl Encoder {
encode<W>(&self, writer: &mut bit::BitWriter<W>, symbol: &Symbol) -> io::Result<()> where W: io::Write,169     pub fn encode<W>(&self, writer: &mut bit::BitWriter<W>, symbol: &Symbol) -> io::Result<()>
170     where
171         W: io::Write,
172     {
173         self.literal.encode(writer, symbol.code())?;
174         if let Some((bits, extra)) = symbol.extra_lengh() {
175             writer.write_bits(bits, extra)?;
176         }
177         if let Some((code, bits, extra)) = symbol.distance() {
178             self.distance.encode(writer, u16::from(code))?;
179             if bits > 0 {
180                 writer.write_bits(bits, extra)?;
181             }
182         }
183         Ok(())
184     }
185 }
186 
/// Reads [`Symbol`]s from a bit stream using a pair of Huffman decoders.
#[derive(Debug)]
pub struct Decoder {
    /// Decoder for the literal/length alphabet.
    literal: huffman::Decoder,
    /// Decoder for the distance alphabet.
    distance: huffman::Decoder,
}
192 impl Decoder {
193     #[inline(always)]
decode_unchecked<R>(&self, reader: &mut bit::BitReader<R>) -> Symbol where R: io::Read,194     pub fn decode_unchecked<R>(&self, reader: &mut bit::BitReader<R>) -> Symbol
195     where
196         R: io::Read,
197     {
198         let mut symbol = self.decode_literal_or_length(reader);
199         if let Symbol::Code(lz77::Code::Pointer {
200             ref mut backward_distance,
201             ..
202         }) = symbol
203         {
204             *backward_distance = self.decode_distance(reader);
205         }
206         symbol
207     }
208     #[inline(always)]
decode_literal_or_length<R>(&self, reader: &mut bit::BitReader<R>) -> Symbol where R: io::Read,209     fn decode_literal_or_length<R>(&self, reader: &mut bit::BitReader<R>) -> Symbol
210     where
211         R: io::Read,
212     {
213         let decoded = self.literal.decode_unchecked(reader);
214         match decoded {
215             0..=255 => Symbol::Code(lz77::Code::Literal(decoded as u8)),
216             256 => Symbol::EndOfBlock,
217             286 | 287 => {
218                 let message = format!("The value {} must not occur in compressed data", decoded);
219                 reader.set_last_error(io::Error::new(io::ErrorKind::InvalidData, message));
220                 Symbol::EndOfBlock // dummy value
221             }
222             length_code => {
223                 let (base, extra_bits) = LENGTH_TABLE[length_code as usize - 257];
224                 let extra = reader.read_bits_unchecked(extra_bits);
225                 Symbol::Code(lz77::Code::Pointer {
226                     length: base + extra,
227                     backward_distance: 0,
228                 })
229             }
230         }
231     }
232     #[inline(always)]
decode_distance<R>(&self, reader: &mut bit::BitReader<R>) -> u16 where R: io::Read,233     fn decode_distance<R>(&self, reader: &mut bit::BitReader<R>) -> u16
234     where
235         R: io::Read,
236     {
237         let decoded = self.distance.decode_unchecked(reader) as usize;
238         let (base, extra_bits) = DISTANCE_TABLE[decoded];
239         let extra = reader.read_bits_unchecked(extra_bits);
240         base + extra
241     }
242 }
243 
/// Strategy for constructing the Huffman codecs of a block and for storing their code
/// description in the compressed stream.
pub trait HuffmanCodec {
    /// Builds an [`Encoder`] suitable for encoding `symbols`.
    fn build(&self, symbols: &[Symbol]) -> io::Result<Encoder>;
    /// Writes to `writer` whatever description of `codec` the matching `load` needs.
    fn save<W>(&self, writer: &mut bit::BitWriter<W>, codec: &Encoder) -> io::Result<()>
    where
        W: io::Write;
    /// Builds a [`Decoder`], reading the code description from `reader` if one is stored.
    fn load<R>(&self, reader: &mut bit::BitReader<R>) -> io::Result<Decoder>
    where
        R: io::Read;
}
253 
/// [`HuffmanCodec`] that uses predefined code tables, so nothing is stored in the stream.
#[derive(Debug)]
pub struct FixedHuffmanCodec;
256 impl HuffmanCodec for FixedHuffmanCodec {
257     #[allow(unused_variables)]
build(&self, symbols: &[Symbol]) -> io::Result<Encoder>258     fn build(&self, symbols: &[Symbol]) -> io::Result<Encoder> {
259         let mut literal_builder = huffman::EncoderBuilder::new(288);
260         for &(bitwidth, ref symbols, code_base) in &FIXED_LITERAL_OR_LENGTH_CODE_TABLE {
261             for (code, symbol) in symbols
262                 .clone()
263                 .enumerate()
264                 .map(|(i, s)| (code_base + i as u16, s))
265             {
266                 literal_builder.set_mapping(symbol, huffman::Code::new(bitwidth, code))?;
267             }
268         }
269 
270         let mut distance_builder = huffman::EncoderBuilder::new(30);
271         for i in 0..30 {
272             distance_builder.set_mapping(i, huffman::Code::new(5, i))?;
273         }
274 
275         Ok(Encoder {
276             literal: literal_builder.finish(),
277             distance: distance_builder.finish(),
278         })
279     }
280     #[allow(unused_variables)]
save<W>(&self, writer: &mut bit::BitWriter<W>, codec: &Encoder) -> io::Result<()> where W: io::Write,281     fn save<W>(&self, writer: &mut bit::BitWriter<W>, codec: &Encoder) -> io::Result<()>
282     where
283         W: io::Write,
284     {
285         Ok(())
286     }
287     #[allow(unused_variables)]
load<R>(&self, reader: &mut bit::BitReader<R>) -> io::Result<Decoder> where R: io::Read,288     fn load<R>(&self, reader: &mut bit::BitReader<R>) -> io::Result<Decoder>
289     where
290         R: io::Read,
291     {
292         let mut literal_builder = huffman::DecoderBuilder::new(9, None, Some(END_OF_BLOCK));
293         for &(bitwidth, ref symbols, code_base) in &FIXED_LITERAL_OR_LENGTH_CODE_TABLE {
294             for (code, symbol) in symbols
295                 .clone()
296                 .enumerate()
297                 .map(|(i, s)| (code_base + i as u16, s))
298             {
299                 literal_builder.set_mapping(symbol, huffman::Code::new(bitwidth, code))?;
300             }
301         }
302 
303         let mut distance_builder =
304             huffman::DecoderBuilder::new(5, literal_builder.safely_peek_bitwidth(), None);
305         for i in 0..30 {
306             distance_builder.set_mapping(i, huffman::Code::new(5, i))?;
307         }
308 
309         Ok(Decoder {
310             literal: literal_builder.finish(),
311             distance: distance_builder.finish(),
312         })
313     }
314 }
315 
/// [`HuffmanCodec`] that derives its codes from symbol frequencies and stores the resulting
/// code bit widths in the stream.
#[derive(Debug)]
pub struct DynamicHuffmanCodec;
318 impl HuffmanCodec for DynamicHuffmanCodec {
build(&self, symbols: &[Symbol]) -> io::Result<Encoder>319     fn build(&self, symbols: &[Symbol]) -> io::Result<Encoder> {
320         let mut literal_counts = [0; 286];
321         let mut distance_counts = [0; 30];
322         let mut empty_distance_table = true;
323         for s in symbols {
324             literal_counts[s.code() as usize] += 1;
325             if let Some((d, _, _)) = s.distance() {
326                 empty_distance_table = false;
327                 distance_counts[d as usize] += 1;
328             }
329         }
330         if empty_distance_table {
331             // Sets a dummy value because an empty distance table causes decoding error on Windows.
332             //
333             // See https://github.com/sile/libflate/issues/23 for more details.
334             distance_counts[0] = 1;
335         }
336         Ok(Encoder {
337             literal: huffman::EncoderBuilder::from_frequencies(&literal_counts, 15)?,
338             distance: huffman::EncoderBuilder::from_frequencies(&distance_counts, 15)?,
339         })
340     }
save<W>(&self, writer: &mut bit::BitWriter<W>, codec: &Encoder) -> io::Result<()> where W: io::Write,341     fn save<W>(&self, writer: &mut bit::BitWriter<W>, codec: &Encoder) -> io::Result<()>
342     where
343         W: io::Write,
344     {
345         let literal_code_count = cmp::max(257, codec.literal.used_max_symbol().unwrap_or(0) + 1);
346         let distance_code_count = cmp::max(1, codec.distance.used_max_symbol().unwrap_or(0) + 1);
347         let codes = build_bitwidth_codes(codec, literal_code_count, distance_code_count);
348 
349         let mut code_counts = [0; 19];
350         for x in &codes {
351             code_counts[x.0 as usize] += 1;
352         }
353         let bitwidth_encoder = huffman::EncoderBuilder::from_frequencies(&code_counts, 7)?;
354 
355         let bitwidth_code_count = cmp::max(
356             4,
357             BITWIDTH_CODE_ORDER
358                 .iter()
359                 .rev()
360                 .position(|&i| code_counts[i] != 0 && bitwidth_encoder.lookup(i as u16).width > 0)
361                 .map_or(0, |trailing_zeros| 19 - trailing_zeros),
362         ) as u16;
363         writer.write_bits(5, literal_code_count - 257)?;
364         writer.write_bits(5, distance_code_count - 1)?;
365         writer.write_bits(4, bitwidth_code_count - 4)?;
366         for &i in BITWIDTH_CODE_ORDER
367             .iter()
368             .take(bitwidth_code_count as usize)
369         {
370             let width = if code_counts[i] == 0 {
371                 0
372             } else {
373                 u16::from(bitwidth_encoder.lookup(i as u16).width)
374             };
375             writer.write_bits(3, width)?;
376         }
377         for &(code, bits, extra) in &codes {
378             bitwidth_encoder.encode(writer, u16::from(code))?;
379             if bits > 0 {
380                 writer.write_bits(bits, u16::from(extra))?;
381             }
382         }
383         Ok(())
384     }
load<R>(&self, reader: &mut bit::BitReader<R>) -> io::Result<Decoder> where R: io::Read,385     fn load<R>(&self, reader: &mut bit::BitReader<R>) -> io::Result<Decoder>
386     where
387         R: io::Read,
388     {
389         let literal_code_count = reader.read_bits(5)? + 257;
390         let distance_code_count = reader.read_bits(5)? + 1;
391         let bitwidth_code_count = reader.read_bits(4)? + 4;
392 
393         if distance_code_count as usize > MAX_DISTANCE_CODE_COUNT {
394             let message = format!(
395                 "The value of HDIST is too big: max={}, actual={}",
396                 MAX_DISTANCE_CODE_COUNT, distance_code_count
397             );
398             return Err(io::Error::new(io::ErrorKind::InvalidData, message));
399         }
400 
401         let mut bitwidth_code_bitwidthes = [0; 19];
402         for &i in BITWIDTH_CODE_ORDER
403             .iter()
404             .take(bitwidth_code_count as usize)
405         {
406             bitwidth_code_bitwidthes[i] = reader.read_bits(3)? as u8;
407         }
408         let bitwidth_decoder =
409             huffman::DecoderBuilder::from_bitwidthes(&bitwidth_code_bitwidthes, Some(1), None)?;
410 
411         let mut literal_code_bitwidthes = Vec::with_capacity(literal_code_count as usize);
412         while literal_code_bitwidthes.len() < literal_code_count as usize {
413             let c = bitwidth_decoder.decode(reader)?;
414             let last = literal_code_bitwidthes.last().cloned();
415             literal_code_bitwidthes.extend(load_bitwidthes(reader, c, last)?);
416         }
417 
418         let mut distance_code_bitwidthes = literal_code_bitwidthes
419             .drain(literal_code_count as usize..)
420             .collect::<Vec<_>>();
421         while distance_code_bitwidthes.len() < distance_code_count as usize {
422             let c = bitwidth_decoder.decode(reader)?;
423             let last = distance_code_bitwidthes
424                 .last()
425                 .cloned()
426                 .or_else(|| literal_code_bitwidthes.last().cloned());
427             distance_code_bitwidthes.extend(load_bitwidthes(reader, c, last)?);
428         }
429         if distance_code_bitwidthes.len() > distance_code_count as usize {
430             let message = format!(
431                 "The length of `distance_code_bitwidthes` is too large: actual={}, expected={}",
432                 distance_code_bitwidthes.len(),
433                 distance_code_count
434             );
435             return Err(io::Error::new(io::ErrorKind::InvalidData, message));
436         }
437 
438         let literal = huffman::DecoderBuilder::from_bitwidthes(
439             &literal_code_bitwidthes,
440             None,
441             Some(END_OF_BLOCK),
442         )?;
443         let distance = huffman::DecoderBuilder::from_bitwidthes(
444             &distance_code_bitwidthes,
445             Some(literal.safely_peek_bitwidth()),
446             None,
447         )?;
448         Ok(Decoder { literal, distance })
449     }
450 }
451 
load_bitwidthes<R>( reader: &mut bit::BitReader<R>, code: u16, last: Option<u8>, ) -> io::Result<Box<dyn Iterator<Item = u8>>> where R: io::Read,452 fn load_bitwidthes<R>(
453     reader: &mut bit::BitReader<R>,
454     code: u16,
455     last: Option<u8>,
456 ) -> io::Result<Box<dyn Iterator<Item = u8>>>
457 where
458     R: io::Read,
459 {
460     Ok(match code {
461         0..=15 => Box::new(iter::once(code as u8)),
462         16 => {
463             let count = reader.read_bits(2)? + 3;
464             let last = last.ok_or_else(|| invalid_data_error!("No preceding value"))?;
465             Box::new(iter::repeat(last).take(count as usize))
466         }
467         17 => {
468             let zeros = reader.read_bits(3)? + 3;
469             Box::new(iter::repeat(0).take(zeros as usize))
470         }
471         18 => {
472             let zeros = reader.read_bits(7)? + 11;
473             Box::new(iter::repeat(0).take(zeros as usize))
474         }
475         _ => unreachable!(),
476     })
477 }
478 
build_bitwidth_codes( codec: &Encoder, literal_code_count: u16, distance_code_count: u16, ) -> Vec<(u8, u8, u8)>479 fn build_bitwidth_codes(
480     codec: &Encoder,
481     literal_code_count: u16,
482     distance_code_count: u16,
483 ) -> Vec<(u8, u8, u8)> {
484     struct RunLength {
485         value: u8,
486         count: usize,
487     }
488 
489     let mut run_lens: Vec<RunLength> = Vec::new();
490     for &(e, size) in &[
491         (&codec.literal, literal_code_count),
492         (&codec.distance, distance_code_count),
493     ] {
494         for (i, c) in (0..size).map(|x| e.lookup(x as u16).width).enumerate() {
495             if i > 0 && run_lens.last().map_or(false, |s| s.value == c) {
496                 run_lens.last_mut().unwrap().count += 1;
497             } else {
498                 run_lens.push(RunLength { value: c, count: 1 })
499             }
500         }
501     }
502 
503     let mut codes: Vec<(u8, u8, u8)> = Vec::new();
504     for r in run_lens {
505         if r.value == 0 {
506             let mut c = r.count;
507             while c >= 11 {
508                 let n = cmp::min(138, c) as u8;
509                 codes.push((18, 7, n - 11));
510                 c -= n as usize;
511             }
512             if c >= 3 {
513                 codes.push((17, 3, c as u8 - 3));
514                 c = 0;
515             }
516             for _ in 0..c {
517                 codes.push((0, 0, 0));
518             }
519         } else {
520             codes.push((r.value, 0, 0));
521             let mut c = r.count - 1;
522             while c >= 3 {
523                 let n = cmp::min(6, c) as u8;
524                 codes.push((16, 2, n - 3));
525                 c -= n as usize;
526             }
527             for _ in 0..c {
528                 codes.push((r.value, 0, 0));
529             }
530         }
531     }
532     codes
533 }
534