1 use std::io;
2 
/// Writes values bit by bit on top of a byte-oriented [`io::Write`] sink.
///
/// Bits are packed least-significant-bit first: the first bit written becomes bit 0 of the
/// first output byte.
#[derive(Debug)]
pub struct BitWriter<W> {
    /// Underlying byte sink.
    inner: W,
    /// Pending bits not yet handed to `inner`; the oldest pending bit is bit 0.
    buf: u32,
    /// Number of pending bits currently held in `buf`.
    end: u8,
}
impl<W> BitWriter<W>
where
    W: io::Write,
{
    /// Creates a writer with no pending bits that emits its bytes to `inner`.
    pub fn new(inner: W) -> Self {
        BitWriter {
            inner,
            buf: 0,
            end: 0,
        }
    }

    /// Appends a single bit (`true` = 1, `false` = 0).
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying writer.
    #[inline(always)]
    pub fn write_bit(&mut self, bit: bool) -> io::Result<()> {
        self.write_bits(1, u16::from(bit))
    }

    /// Appends the low `bitwidth` bits of `bits`.
    ///
    /// `bitwidth` may be anything from 0 to 16 inclusive. `bits` is OR-ed into the pending
    /// buffer without masking, so every bit at position `bitwidth` or above must be zero;
    /// otherwise it leaks into the bits written afterwards.
    ///
    /// # Errors
    ///
    /// Propagates any error from the underlying writer when two completed bytes are
    /// written out.
    #[inline(always)]
    pub fn write_bits(&mut self, bitwidth: u8, bits: u16) -> io::Result<()> {
        // `flush_if_needed` leaves fewer than 16 bits pending after every call, so even a
        // full 16-bit value always fits into the 32-bit buffer.
        debug_assert!(bitwidth <= 16);
        debug_assert!(self.end + bitwidth <= 32);
        self.buf |= u32::from(bits) << self.end;
        self.end += bitwidth;
        self.flush_if_needed()
    }

    /// Writes out every pending bit and then flushes the underlying writer.
    ///
    /// A trailing partial byte is emitted as a whole byte; its unused high bits come from
    /// the (otherwise empty) upper part of the buffer.
    ///
    /// # Errors
    ///
    /// Propagates any error from writing to or flushing the underlying writer.
    pub fn flush(&mut self) -> io::Result<()> {
        while self.end > 0 {
            // Truncation is intended: emit the lowest pending byte.
            self.inner.write_all(&[self.buf as u8])?;
            self.buf >>= 8;
            // The last byte may hold fewer than 8 valid bits, hence the saturation.
            self.end = self.end.saturating_sub(8);
        }
        self.inner.flush()?;
        Ok(())
    }

    /// Writes the two oldest pending bytes (little-endian) once at least 16 bits are pending.
    #[inline(always)]
    fn flush_if_needed(&mut self) -> io::Result<()> {
        if self.end >= 16 {
            self.inner.write_all(&(self.buf as u16).to_le_bytes())?;
            self.end -= 16;
            self.buf >>= 16;
        }
        Ok(())
    }
}
impl<W> BitWriter<W> {
    /// Returns a shared reference to the underlying writer.
    pub fn as_inner_ref(&self) -> &W {
        &self.inner
    }

    /// Returns a mutable reference to the underlying writer.
    pub fn as_inner_mut(&mut self) -> &mut W {
        &mut self.inner
    }

    /// Consumes the bit writer and returns the underlying writer.
    ///
    /// Pending bits are not written out here; call `flush` first if they matter.
    pub fn into_inner(self) -> W {
        self.inner
    }
}
62 
/// Reads values bit by bit from a byte-oriented [`io::Read`] source.
///
/// Bits are consumed least-significant-bit first: the first bit returned is bit 0 of the
/// first input byte.
#[derive(Debug)]
pub struct BitReader<R> {
    /// Underlying byte source.
    inner: R,
    /// Sliding 32-bit window over the input; the most recently read byte sits in the top
    /// 8 bits.
    last_read: u32,
    /// Number of low bits of `last_read` that are already consumed (32 = window is empty).
    offset: u8,
    /// Error recorded by an unchecked read; handed out by `check_last_error`.
    last_error: Option<io::Error>,
}
impl<R> BitReader<R>
where
    R: io::Read,
{
    /// Creates a reader with an empty bit window on top of `inner`.
    pub fn new(inner: R) -> Self {
        BitReader {
            inner,
            last_read: 0,
            offset: 32,
            last_error: None,
        }
    }

    /// Records `e` as the pending error, replacing any error recorded before.
    #[inline(always)]
    pub fn set_last_error(&mut self, e: io::Error) {
        self.last_error = Some(e);
    }

    /// Takes the pending error, if any, and returns it.
    ///
    /// # Errors
    ///
    /// Returns the error recorded by an earlier unchecked read or by `set_last_error`.
    /// The pending error is cleared by this call.
    #[inline(always)]
    pub fn check_last_error(&mut self) -> io::Result<()> {
        if let Some(e) = self.last_error.take() {
            Err(e)
        } else {
            Ok(())
        }
    }

    /// Reads a single bit.
    ///
    /// # Errors
    ///
    /// See [`BitReader::read_bits`].
    #[inline(always)]
    pub fn read_bit(&mut self) -> io::Result<bool> {
        self.read_bits(1).map(|b| b != 0)
    }

    /// Reads `bitwidth` bits (0 to 16 inclusive) and returns them in the low bits of the
    /// result.
    ///
    /// # Errors
    ///
    /// Returns the pending error if one is recorded after the read attempt, for example
    /// because the underlying reader failed or ran out of data.
    #[inline(always)]
    pub fn read_bits(&mut self, bitwidth: u8) -> io::Result<u16> {
        let v = self.read_bits_unchecked(bitwidth);
        self.check_last_error().map(|_| v)
    }

    /// Reads `bitwidth` bits without reporting failures directly.
    ///
    /// On failure the error is kept pending (see `check_last_error`) and 0 is returned.
    #[inline(always)]
    pub fn read_bits_unchecked(&mut self, bitwidth: u8) -> u16 {
        let bits = self.peek_bits_unchecked(bitwidth);
        self.skip_bits(bitwidth);
        bits
    }

    /// Returns the next `bitwidth` bits (0 to 16 inclusive) without consuming them.
    ///
    /// If more input is needed while an error is already pending, or fetching it fails,
    /// the error is kept pending and 0 is returned.
    #[inline(always)]
    pub fn peek_bits_unchecked(&mut self, bitwidth: u8) -> u16 {
        debug_assert!(bitwidth <= 16);
        while 32 < self.offset + bitwidth {
            if self.last_error.is_some() {
                return 0;
            }
            if let Err(e) = self.fill_next_u8() {
                self.last_error = Some(e);
                return 0;
            }
        }
        debug_assert!(self.offset < 32 || bitwidth == 0);
        let window = self.last_read.wrapping_shr(u32::from(self.offset));
        // Build the mask in `u32`: `1u16 << 16` overflows, which panics in debug builds and
        // yields an all-zero mask in release builds for the permitted width of 16.
        let mask = (1u32 << bitwidth) - 1;
        (window & mask) as u16
    }

    /// Marks `bitwidth` bits as consumed.
    ///
    /// Unless an error is pending, the bits must already be buffered, e.g. by a preceding
    /// `peek_bits_unchecked` call of at least the same width.
    #[inline(always)]
    pub fn skip_bits(&mut self, bitwidth: u8) {
        debug_assert!(self.last_error.is_some() || 32 - self.offset >= bitwidth);
        self.offset += bitwidth;
    }

    /// Drops the oldest byte of the window and reads one new byte into its top 8 bits.
    ///
    /// The window is shifted before the read, so the state is altered even when the read
    /// fails.
    #[inline(always)]
    fn fill_next_u8(&mut self) -> io::Result<()> {
        self.offset -= 8;
        self.last_read >>= 8;

        let mut buf = [0; 1];
        self.inner.read_exact(&mut buf)?;
        let next = u32::from(buf[0]);
        self.last_read |= next << (32 - 8);
        Ok(())
    }

    /// Captures the bit window and its offset; the underlying reader and any pending error
    /// are not part of the snapshot.
    #[inline]
    pub(crate) fn state(&self) -> BitReaderState {
        BitReaderState {
            last_read: self.last_read,
            offset: self.offset,
        }
    }

    /// Restores a bit window previously captured with `state`.
    #[inline]
    pub(crate) fn restore_state(&mut self, state: BitReaderState) {
        self.last_read = state.last_read;
        self.offset = state.offset;
    }
}
impl<R> BitReader<R> {
    /// Marks the bit window as empty, discarding any buffered but unread bits.
    pub fn reset(&mut self) {
        self.offset = 32;
    }

    /// Returns a shared reference to the underlying reader.
    pub fn as_inner_ref(&self) -> &R {
        &self.inner
    }

    /// Returns a mutable reference to the underlying reader.
    pub fn as_inner_mut(&mut self) -> &mut R {
        &mut self.inner
    }

    /// Consumes the bit reader and returns the underlying reader.
    pub fn into_inner(self) -> R {
        self.inner
    }
}

/// Snapshot of a [`BitReader`]'s bit window, used by `state` / `restore_state`.
#[derive(Debug, Clone, Copy)]
pub(crate) struct BitReaderState {
    /// Copy of the reader's 32-bit window.
    last_read: u32,
    /// Copy of the reader's consumed-bit count.
    offset: u8,
}
174 
#[cfg(test)]
mod test {
    use super::*;
    use std::io;

    /// Bits must come out packed LSB-first, with each `flush` completing the current byte.
    #[test]
    fn writer_works() {
        let mut sink = BitWriter::new(Vec::new());
        sink.write_bit(true).unwrap();
        for &(width, value) in &[(3u8, 0b010u16), (11, 0b10101011010)] {
            sink.write_bits(width, value).unwrap();
        }
        sink.flush().unwrap();
        sink.write_bit(true).unwrap();
        sink.flush().unwrap();

        let expected: [u8; 3] = [0b10100101, 0b01010101, 0b00000001];
        assert_eq!(sink.into_inner(), expected);
    }

    /// Reads, repeatable peeks and skips must agree, and running dry must surface as EOF.
    #[test]
    fn reader_works() {
        let input = [0b10100101, 0b11010101];
        let mut source = BitReader::new(&input[..]);
        assert_eq!(source.read_bit().unwrap(), true);
        assert_eq!(source.read_bit().unwrap(), false);
        assert_eq!(source.read_bits(8).unwrap(), 0b01101001);

        // Peeking twice must yield the same bits because nothing is consumed.
        for _ in 0..2 {
            assert_eq!(source.peek_bits_unchecked(3), 0b101);
        }
        source.skip_bits(1);
        assert_eq!(source.peek_bits_unchecked(3), 0b010);

        // Only 5 bits remain, so asking for 8 must hit the end of the input.
        let outcome = source.read_bits(8).map_err(|e| e.kind());
        assert_eq!(outcome, Err(io::ErrorKind::UnexpectedEof));
    }
}
211