use std::io;

/// Packs bit fields into an underlying [`io::Write`], least-significant bit first.
///
/// Bits are staged in a 32-bit buffer and written out as 16-bit little-endian chunks as soon
/// as 16 bits are pending. Call [`BitWriter::flush`] to emit the trailing partial byte.
#[derive(Debug)]
pub struct BitWriter<W> {
    inner: W,
    /// Staged bits; bit 0 is the next bit to be emitted. Bits at or above `end` are zero.
    buf: u32,
    /// Number of staged bits in `buf`.
    end: u8,
}

impl<W> BitWriter<W>
where
    W: io::Write,
{
    /// Creates a writer with an empty bit buffer on top of `inner`.
    pub fn new(inner: W) -> Self {
        BitWriter {
            inner,
            buf: 0,
            end: 0,
        }
    }

    /// Appends a single bit (`true` is written as `1`).
    ///
    /// # Errors
    /// Propagates any error from the underlying writer (see [`BitWriter::write_bits`]).
    #[inline(always)]
    pub fn write_bit(&mut self, bit: bool) -> io::Result<()> {
        self.write_bits(1, u16::from(bit))
    }

    /// Appends the low `bitwidth` bits of `bits` to the stream, least-significant bit first.
    ///
    /// `bitwidth` must be at most 16. Bits of `bits` at or above `bitwidth` are ignored.
    ///
    /// # Errors
    /// Propagates any error from the underlying writer when a 16-bit chunk is emitted.
    #[inline(always)]
    pub fn write_bits(&mut self, bitwidth: u8, bits: u16) -> io::Result<()> {
        debug_assert!(bitwidth <= 16);
        debug_assert!(self.end + bitwidth <= 32);
        // Drop stray high bits; otherwise they would be OR-ed into the following fields.
        let mask = (1u32 << bitwidth) - 1;
        self.buf |= (u32::from(bits) & mask) << self.end;
        self.end += bitwidth;
        self.flush_if_needed()
    }

    /// Writes out every staged bit, padding the final partial byte with zero bits, and then
    /// flushes the underlying writer. Afterwards the writer is byte-aligned.
    ///
    /// # Errors
    /// Propagates any error from writing to or flushing the underlying writer.
    pub fn flush(&mut self) -> io::Result<()> {
        // Emit one byte at a time so the staged state always matches what was written.
        while self.end > 0 {
            self.inner.write_all(&[self.buf as u8])?;
            self.buf >>= 8;
            self.end = self.end.saturating_sub(8);
        }
        self.inner.flush()
    }

    /// Emits the low 16 staged bits once at least 16 bits are pending.
    #[inline(always)]
    fn flush_if_needed(&mut self) -> io::Result<()> {
        if self.end >= 16 {
            self.inner.write_all(&(self.buf as u16).to_le_bytes())?;
            self.end -= 16;
            self.buf >>= 16;
        }
        Ok(())
    }
}

impl<W> BitWriter<W> {
    /// Returns a shared reference to the underlying writer.
    pub fn as_inner_ref(&self) -> &W {
        &self.inner
    }

    /// Returns a mutable reference to the underlying writer.
    pub fn as_inner_mut(&mut self) -> &mut W {
        &mut self.inner
    }

    /// Consumes the bit writer and returns the underlying writer.
    ///
    /// Staged bits that have not been written by [`BitWriter::flush`] are discarded.
    pub fn into_inner(self) -> W {
        self.inner
    }
}

/// Extracts bit fields from an underlying [`io::Read`], least-significant bit first.
///
/// The `*_unchecked` methods do not return errors directly: when the underlying reader fails,
/// the error is recorded and the method yields `0`. Use [`BitReader::check_last_error`], or the
/// checked [`BitReader::read_bits`] / [`BitReader::read_bit`], to surface it.
#[derive(Debug)]
pub struct BitReader<R> {
    inner: R,
    /// Bit window; every newly read byte is shifted in at bits 24..32.
    last_read: u32,
    /// Position of the next unconsumed bit in `last_read`; `32` means nothing is buffered.
    /// Always kept `<= 32`.
    offset: u8,
    /// Error recorded by an unchecked read (or injected by the caller) and not yet reported.
    last_error: Option<io::Error>,
}

impl<R> BitReader<R>
where
    R: io::Read,
{
    /// Creates a reader with an empty bit window on top of `inner`.
    pub fn new(inner: R) -> Self {
        BitReader {
            inner,
            last_read: 0,
            offset: 32,
            last_error: None,
        }
    }

    /// Records `e` as the pending error; it is returned by the next [`BitReader::check_last_error`].
    #[inline(always)]
    pub fn set_last_error(&mut self, e: io::Error) {
        self.last_error = Some(e);
    }

    /// Takes the pending error, if any, and returns it as `Err`; otherwise returns `Ok(())`.
    #[inline(always)]
    pub fn check_last_error(&mut self) -> io::Result<()> {
        match self.last_error.take() {
            Some(e) => Err(e),
            None => Ok(()),
        }
    }

    /// Reads a single bit.
    ///
    /// # Errors
    /// Same as [`BitReader::read_bits`].
    #[inline(always)]
    pub fn read_bit(&mut self) -> io::Result<bool> {
        self.read_bits(1).map(|b| b != 0)
    }

    /// Reads the next `bitwidth` bits (at most 16) as an unsigned value.
    ///
    /// # Errors
    /// Returns the pending error: either an I/O error from the underlying reader, or one
    /// injected with [`BitReader::set_last_error`]. On error no bits are consumed, so a read
    /// that failed because of a transient I/O error can simply be retried.
    #[inline(always)]
    pub fn read_bits(&mut self, bitwidth: u8) -> io::Result<u16> {
        let bits = self.peek_bits_unchecked(bitwidth);
        // Consume only after the peek is known to have succeeded, so that a failure never
        // swallows bits that were never delivered to the caller.
        self.check_last_error()?;
        self.skip_bits(bitwidth);
        Ok(bits)
    }

    /// Like [`BitReader::read_bits`], but records errors instead of returning them
    /// (yielding `0` on failure).
    #[inline(always)]
    pub fn read_bits_unchecked(&mut self, bitwidth: u8) -> u16 {
        let bits = self.peek_bits_unchecked(bitwidth);
        self.skip_bits(bitwidth);
        bits
    }

    /// Returns the next `bitwidth` bits (at most 16) without consuming them.
    ///
    /// Refills the window from the underlying reader as needed. If a refill is needed while an
    /// error is pending, or the refill fails (the error is then recorded), returns `0`.
    #[inline(always)]
    pub fn peek_bits_unchecked(&mut self, bitwidth: u8) -> u16 {
        debug_assert!(bitwidth <= 16);
        while 32 < self.offset + bitwidth {
            if self.last_error.is_some() {
                return 0;
            }
            if let Err(e) = self.fill_next_u8() {
                self.last_error = Some(e);
                return 0;
            }
        }
        debug_assert!(self.offset < 32 || bitwidth == 0);
        // `offset` can be 32 only when `bitwidth == 0`; a 32-bit shift is out of range, and
        // the mask below is zero in that case anyway.
        let window = self.last_read.checked_shr(u32::from(self.offset)).unwrap_or(0);
        // Build the mask in u32: as a u16, `1 << 16` overflows for a full 16-bit peek.
        let mask = (1u32 << bitwidth) - 1;
        (window & mask) as u16
    }

    /// Consumes `bitwidth` bits, normally ones just returned by a successful peek.
    ///
    /// When an error is pending, fewer than `bitwidth` bits may be buffered. The position is
    /// then clamped to the end of the window, so it can never run past it or overflow.
    #[inline(always)]
    pub fn skip_bits(&mut self, bitwidth: u8) {
        debug_assert!(self.last_error.is_some() || 32 - self.offset >= bitwidth);
        self.offset = self.offset.saturating_add(bitwidth).min(32);
    }

    /// Reads one byte from the underlying reader and shifts it into the top of the window.
    ///
    /// The window is only modified after the byte has been read, so a failed read leaves the
    /// already-buffered bits and the current position intact.
    #[inline(always)]
    fn fill_next_u8(&mut self) -> io::Result<()> {
        let mut byte = [0; 1];
        self.inner.read_exact(&mut byte)?;
        self.offset -= 8;
        self.last_read = (self.last_read >> 8) | (u32::from(byte[0]) << (32 - 8));
        Ok(())
    }

    /// Captures the current bit window and position (not the pending error).
    #[inline]
    pub(crate) fn state(&self) -> BitReaderState {
        BitReaderState {
            last_read: self.last_read,
            offset: self.offset,
        }
    }

    /// Restores a window and position previously captured by [`BitReader::state`].
    ///
    /// The underlying reader is not rewound. Any bytes fetched into the window after the state
    /// was captured are therefore discarded by the restore.
    #[inline]
    pub(crate) fn restore_state(&mut self, state: BitReaderState) {
        self.last_read = state.last_read;
        self.offset = state.offset;
    }
}

impl<R> BitReader<R> {
    /// Discards every buffered bit, so the next read starts on a fresh byte of the underlying
    /// reader. This includes whole bytes that a wider peek may have prefetched.
    pub fn reset(&mut self) {
        self.offset = 32;
    }

    /// Returns a shared reference to the underlying reader.
    pub fn as_inner_ref(&self) -> &R {
        &self.inner
    }

    /// Returns a mutable reference to the underlying reader.
    pub fn as_inner_mut(&mut self) -> &mut R {
        &mut self.inner
    }

    /// Consumes the bit reader and returns the underlying reader; buffered bits are lost.
    pub fn into_inner(self) -> R {
        self.inner
    }
}

/// Snapshot of a [`BitReader`]'s bit window and position.
#[derive(Debug, Clone, Copy)]
pub(crate) struct BitReaderState {
    last_read: u32,
    offset: u8,
}

#[cfg(test)]
mod test {
    use super::*;
    use std::io;

    #[test]
    fn writer_works() {
        let mut writer = BitWriter::new(Vec::new());
        writer.write_bit(true).unwrap();
        writer.write_bits(3, 0b010).unwrap();
        writer.write_bits(11, 0b10101011010).unwrap();
        writer.flush().unwrap();
        writer.write_bit(true).unwrap();
        writer.flush().unwrap();

        let buf = writer.into_inner();
        assert_eq!(buf, [0b10100101, 0b01010101, 0b00000001]);
    }

    #[test]
    fn writer_masks_stray_bits_and_accepts_16() {
        let mut writer = BitWriter::new(Vec::new());
        writer.write_bits(3, 0xFFFF).unwrap();
        writer.write_bits(5, 0).unwrap();
        writer.write_bits(16, 0x1234).unwrap();
        writer.flush().unwrap();
        assert_eq!(writer.into_inner(), [0b0000_0111, 0x34, 0x12]);
    }

    #[test]
    fn reader_works() {
        let buf = [0b10100101, 0b11010101];
        let mut reader = BitReader::new(&buf[..]);
        assert!(reader.read_bit().unwrap());
        assert!(!reader.read_bit().unwrap());
        assert_eq!(reader.read_bits(8).unwrap(), 0b01101001);
        assert_eq!(reader.peek_bits_unchecked(3), 0b101);
        assert_eq!(reader.peek_bits_unchecked(3), 0b101);
        reader.skip_bits(1);
        assert_eq!(reader.peek_bits_unchecked(3), 0b010);
        assert_eq!(
            reader.read_bits(8).map_err(|e| e.kind()),
            Err(io::ErrorKind::UnexpectedEof)
        );
        // The failed read must not have consumed or zeroed the 5 bits still buffered.
        assert_eq!(reader.read_bits(5).unwrap(), 0b11010);
        assert_eq!(
            reader.read_bits(1).map_err(|e| e.kind()),
            Err(io::ErrorKind::UnexpectedEof)
        );
    }

    #[test]
    fn reader_reads_16_bits() {
        let buf = [0x34, 0x12];
        let mut reader = BitReader::new(&buf[..]);
        assert_eq!(reader.read_bits(16).unwrap(), 0x1234);
    }

    #[test]
    fn reader_retries_after_transient_error() {
        // Yields the scripted bytes in order; `None` fails exactly once.
        struct Scripted(Vec<Option<u8>>);
        impl io::Read for Scripted {
            fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
                if buf.is_empty() || self.0.is_empty() {
                    return Ok(0);
                }
                match self.0.remove(0) {
                    Some(b) => {
                        buf[0] = b;
                        Ok(1)
                    }
                    None => Err(io::Error::new(io::ErrorKind::Other, "transient")),
                }
            }
        }

        let mut reader = BitReader::new(Scripted(vec![Some(0xAB), None, Some(0xCD)]));
        assert_eq!(reader.read_bits(4).unwrap(), 0xB);
        assert_eq!(
            reader.read_bits(8).map_err(|e| e.kind()),
            Err(io::ErrorKind::Other)
        );
        assert_eq!(reader.read_bits(8).unwrap(), 0xDA);
        assert_eq!(reader.read_bits(4).unwrap(), 0xC);
    }
}