1// Copyright 2019+ Klaus Post. All rights reserved. 2// License information can be found in the LICENSE file. 3// Based on work by Yann Collet, released under BSD License. 4 5package zstd 6 7import ( 8 "encoding/binary" 9 "errors" 10 "io" 11 "math/bits" 12) 13 14// bitReader reads a bitstream in reverse. 15// The last set bit indicates the start of the stream and is used 16// for aligning the input. 17type bitReader struct { 18 in []byte 19 off uint // next byte to read is at in[off - 1] 20 value uint64 // Maybe use [16]byte, but shifting is awkward. 21 bitsRead uint8 22} 23 24// init initializes and resets the bit reader. 25func (b *bitReader) init(in []byte) error { 26 if len(in) < 1 { 27 return errors.New("corrupt stream: too short") 28 } 29 b.in = in 30 b.off = uint(len(in)) 31 // The highest bit of the last byte indicates where to start 32 v := in[len(in)-1] 33 if v == 0 { 34 return errors.New("corrupt stream, did not find end of stream") 35 } 36 b.bitsRead = 64 37 b.value = 0 38 if len(in) >= 8 { 39 b.fillFastStart() 40 } else { 41 b.fill() 42 b.fill() 43 } 44 b.bitsRead += 8 - uint8(highBits(uint32(v))) 45 return nil 46} 47 48// getBits will return n bits. n can be 0. 49func (b *bitReader) getBits(n uint8) int { 50 if n == 0 /*|| b.bitsRead >= 64 */ { 51 return 0 52 } 53 return b.getBitsFast(n) 54} 55 56// getBitsFast requires that at least one bit is requested every time. 57// There are no checks if the buffer is filled. 58func (b *bitReader) getBitsFast(n uint8) int { 59 const regMask = 64 - 1 60 v := uint32((b.value << (b.bitsRead & regMask)) >> ((regMask + 1 - n) & regMask)) 61 b.bitsRead += n 62 return int(v) 63} 64 65// fillFast() will make sure at least 32 bits are available. 66// There must be at least 4 bytes available. 67func (b *bitReader) fillFast() { 68 if b.bitsRead < 32 { 69 return 70 } 71 // 2 bounds checks. 72 v := b.in[b.off-4:] 73 v = v[:4] 74 low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24) 75 b.value = (b.value << 32) | uint64(low) 76 b.bitsRead -= 32 77 b.off -= 4 78} 79 80// fillFastStart() assumes the bitreader is empty and there is at least 8 bytes to read. 81func (b *bitReader) fillFastStart() { 82 // Do single re-slice to avoid bounds checks. 83 b.value = binary.LittleEndian.Uint64(b.in[b.off-8:]) 84 b.bitsRead = 0 85 b.off -= 8 86} 87 88// fill() will make sure at least 32 bits are available. 89func (b *bitReader) fill() { 90 if b.bitsRead < 32 { 91 return 92 } 93 if b.off >= 4 { 94 v := b.in[b.off-4:] 95 v = v[:4] 96 low := (uint32(v[0])) | (uint32(v[1]) << 8) | (uint32(v[2]) << 16) | (uint32(v[3]) << 24) 97 b.value = (b.value << 32) | uint64(low) 98 b.bitsRead -= 32 99 b.off -= 4 100 return 101 } 102 for b.off > 0 { 103 b.value = (b.value << 8) | uint64(b.in[b.off-1]) 104 b.bitsRead -= 8 105 b.off-- 106 } 107} 108 109// finished returns true if all bits have been read from the bit stream. 110func (b *bitReader) finished() bool { 111 return b.off == 0 && b.bitsRead >= 64 112} 113 114// overread returns true if more bits have been requested than is on the stream. 115func (b *bitReader) overread() bool { 116 return b.bitsRead > 64 117} 118 119// remain returns the number of bits remaining. 120func (b *bitReader) remain() uint { 121 return b.off*8 + 64 - uint(b.bitsRead) 122} 123 124// close the bitstream and returns an error if out-of-buffer reads occurred. 125func (b *bitReader) close() error { 126 // Release reference. 127 b.in = nil 128 if b.bitsRead > 64 { 129 return io.ErrUnexpectedEOF 130 } 131 return nil 132} 133 134func highBits(val uint32) (n uint32) { 135 return uint32(bits.Len32(val) - 1) 136} 137