1// Copyright 2015, Joe Tsai. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE.md file. 4 5package brotli 6 7import ( 8 "bufio" 9 "io" 10 11 "github.com/dsnet/compress/internal/errors" 12) 13 14// The bitReader preserves the property that it will never read more bytes than 15// is necessary. However, this feature dramatically hurts performance because 16// every byte needs to be obtained through a ReadByte method call. 17// Furthermore, the decoding of variable length codes in ReadSymbol, often 18// requires multiple passes before it knows the exact bit-length of the code. 19// 20// Thus, to improve performance, if the underlying byteReader is a bufio.Reader, 21// then the bitReader will use the Peek and Discard methods to fill the internal 22// bit buffer with as many bits as possible, allowing the TryReadBits and 23// TryReadSymbol methods to often succeed on the first try. 24 25type byteReader interface { 26 io.Reader 27 io.ByteReader 28} 29 30type bitReader struct { 31 rd byteReader 32 bufBits uint64 // Buffer to hold some bits 33 numBits uint // Number of valid bits in bufBits 34 offset int64 // Number of bytes read from the underlying io.Reader 35 36 // These fields are only used if rd is a bufio.Reader. 37 bufRd *bufio.Reader 38 bufPeek []byte // Buffer for the Peek data 39 discardBits int // Number of bits to discard from bufio.Reader 40 fedBits uint // Number of bits fed in last call to FeedBits 41 42 // Local copy of decoders to reduce memory allocations. 43 prefix prefixDecoder 44} 45 46func (br *bitReader) Init(r io.Reader) { 47 *br = bitReader{prefix: br.prefix} 48 if rr, ok := r.(byteReader); ok { 49 br.rd = rr 50 } else { 51 br.rd = bufio.NewReader(r) 52 } 53 if brd, ok := br.rd.(*bufio.Reader); ok { 54 br.bufRd = brd 55 } 56} 57 58// FlushOffset updates the read offset of the underlying byteReader. 59// If the byteReader is a bufio.Reader, then this calls Discard to update the 60// read offset. 61func (br *bitReader) FlushOffset() int64 { 62 if br.bufRd == nil { 63 return br.offset 64 } 65 66 // Update the number of total bits to discard. 67 br.discardBits += int(br.fedBits - br.numBits) 68 br.fedBits = br.numBits 69 70 // Discard some bytes to update read offset. 71 nd := (br.discardBits + 7) / 8 // Round up to nearest byte 72 nd, _ = br.bufRd.Discard(nd) 73 br.discardBits -= nd * 8 // -7..0 74 br.offset += int64(nd) 75 76 // These are invalid after Discard. 77 br.bufPeek = nil 78 return br.offset 79} 80 81// FeedBits ensures that at least nb bits exist in the bit buffer. 82// If the underlying byteReader is a bufio.Reader, then this will fill the 83// bit buffer with as many bits as possible, relying on Peek and Discard to 84// properly advance the read offset. Otherwise, it will use ReadByte to fill the 85// buffer with just the right number of bits. 86func (br *bitReader) FeedBits(nb uint) { 87 if br.bufRd != nil { 88 br.discardBits += int(br.fedBits - br.numBits) 89 for { 90 if len(br.bufPeek) == 0 { 91 br.fedBits = br.numBits // Don't discard bits just added 92 br.FlushOffset() 93 94 var err error 95 cntPeek := 8 // Minimum Peek amount to make progress 96 if br.bufRd.Buffered() > cntPeek { 97 cntPeek = br.bufRd.Buffered() 98 } 99 br.bufPeek, err = br.bufRd.Peek(cntPeek) 100 br.bufPeek = br.bufPeek[int(br.numBits/8):] // Skip buffered bits 101 if len(br.bufPeek) == 0 { 102 if br.numBits >= nb { 103 break 104 } 105 if err == io.EOF { 106 err = io.ErrUnexpectedEOF 107 } 108 errors.Panic(err) 109 } 110 } 111 cnt := int(64-br.numBits) / 8 112 if cnt > len(br.bufPeek) { 113 cnt = len(br.bufPeek) 114 } 115 for _, c := range br.bufPeek[:cnt] { 116 br.bufBits |= uint64(c) << br.numBits 117 br.numBits += 8 118 } 119 br.bufPeek = br.bufPeek[cnt:] 120 if br.numBits > 56 { 121 break 122 } 123 } 124 br.fedBits = br.numBits 125 } else { 126 for br.numBits < nb { 127 c, err := br.rd.ReadByte() 128 if err != nil { 129 if err == io.EOF { 130 err = io.ErrUnexpectedEOF 131 } 132 errors.Panic(err) 133 } 134 br.bufBits |= uint64(c) << br.numBits 135 br.numBits += 8 136 br.offset++ 137 } 138 } 139} 140 141// Read reads up to len(buf) bytes into buf. 142func (br *bitReader) Read(buf []byte) (cnt int, err error) { 143 if br.numBits%8 != 0 { 144 return 0, errorf(errors.Invalid, "non-aligned bit buffer") 145 } 146 if br.numBits > 0 { 147 for cnt = 0; len(buf) > cnt && br.numBits > 0; cnt++ { 148 buf[cnt] = byte(br.bufBits) 149 br.bufBits >>= 8 150 br.numBits -= 8 151 } 152 } else { 153 br.FlushOffset() 154 cnt, err = br.rd.Read(buf) 155 br.offset += int64(cnt) 156 } 157 return cnt, err 158} 159 160// TryReadBits attempts to read nb bits using the contents of the bit buffer 161// alone. It returns the value and whether it succeeded. 162// 163// This method is designed to be inlined for performance reasons. 164func (br *bitReader) TryReadBits(nb uint) (uint, bool) { 165 if br.numBits < nb { 166 return 0, false 167 } 168 val := uint(br.bufBits & uint64(1<<nb-1)) 169 br.bufBits >>= nb 170 br.numBits -= nb 171 return val, true 172} 173 174// ReadBits reads nb bits in LSB order from the underlying reader. 175func (br *bitReader) ReadBits(nb uint) uint { 176 br.FeedBits(nb) 177 val := uint(br.bufBits & uint64(1<<nb-1)) 178 br.bufBits >>= nb 179 br.numBits -= nb 180 return val 181} 182 183// ReadPads reads 0-7 bits from the bit buffer to achieve byte-alignment. 184func (br *bitReader) ReadPads() uint { 185 nb := br.numBits % 8 186 val := uint(br.bufBits & uint64(1<<nb-1)) 187 br.bufBits >>= nb 188 br.numBits -= nb 189 return val 190} 191 192// TryReadSymbol attempts to decode the next symbol using the contents of the 193// bit buffer alone. It returns the decoded symbol and whether it succeeded. 194// 195// This method is designed to be inlined for performance reasons. 196func (br *bitReader) TryReadSymbol(pd *prefixDecoder) (uint, bool) { 197 if br.numBits < uint(pd.minBits) || len(pd.chunks) == 0 { 198 return 0, false 199 } 200 chunk := pd.chunks[uint32(br.bufBits)&pd.chunkMask] 201 nb := uint(chunk & prefixCountMask) 202 if nb > br.numBits || nb > uint(pd.chunkBits) { 203 return 0, false 204 } 205 br.bufBits >>= nb 206 br.numBits -= nb 207 return uint(chunk >> prefixCountBits), true 208} 209 210// ReadSymbol reads the next prefix symbol using the provided prefixDecoder. 211func (br *bitReader) ReadSymbol(pd *prefixDecoder) uint { 212 if len(pd.chunks) == 0 { 213 errors.Panic(errInvalid) // Decode with empty tree 214 } 215 216 nb := uint(pd.minBits) 217 for { 218 br.FeedBits(nb) 219 chunk := pd.chunks[uint32(br.bufBits)&pd.chunkMask] 220 nb = uint(chunk & prefixCountMask) 221 if nb > uint(pd.chunkBits) { 222 linkIdx := chunk >> prefixCountBits 223 chunk = pd.links[linkIdx][uint32(br.bufBits>>pd.chunkBits)&pd.linkMask] 224 nb = uint(chunk & prefixCountMask) 225 } 226 if nb <= br.numBits { 227 br.bufBits >>= nb 228 br.numBits -= nb 229 return uint(chunk >> prefixCountBits) 230 } 231 } 232} 233 234// ReadOffset reads an offset value using the provided rangesCodes indexed by 235// the given symbol. 236func (br *bitReader) ReadOffset(sym uint, rcs []rangeCode) uint { 237 rc := rcs[sym] 238 return uint(rc.base) + br.ReadBits(uint(rc.bits)) 239} 240 241// ReadPrefixCode reads the prefix definition from the stream and initializes 242// the provided prefixDecoder. The value maxSyms is the alphabet size of the 243// prefix code being generated. The actual number of representable symbols 244// will be between 1 and maxSyms, inclusively. 245func (br *bitReader) ReadPrefixCode(pd *prefixDecoder, maxSyms uint) { 246 hskip := br.ReadBits(2) 247 if hskip == 1 { 248 br.readSimplePrefixCode(pd, maxSyms) 249 } else { 250 br.readComplexPrefixCode(pd, maxSyms, hskip) 251 } 252} 253 254// readSimplePrefixCode reads the prefix code according to RFC section 3.4. 255func (br *bitReader) readSimplePrefixCode(pd *prefixDecoder, maxSyms uint) { 256 var codes [4]prefixCode 257 nsym := int(br.ReadBits(2)) + 1 258 clen := neededBits(uint32(maxSyms)) 259 for i := 0; i < nsym; i++ { 260 codes[i].sym = uint32(br.ReadBits(clen)) 261 } 262 263 copyLens := func(lens []uint) { 264 for i := 0; i < nsym; i++ { 265 codes[i].len = uint32(lens[i]) 266 } 267 } 268 compareSwap := func(i, j int) { 269 if codes[i].sym > codes[j].sym { 270 codes[i], codes[j] = codes[j], codes[i] 271 } 272 } 273 274 switch nsym { 275 case 1: 276 copyLens(simpleLens1[:]) 277 case 2: 278 copyLens(simpleLens2[:]) 279 compareSwap(0, 1) 280 case 3: 281 copyLens(simpleLens3[:]) 282 compareSwap(0, 1) 283 compareSwap(0, 2) 284 compareSwap(1, 2) 285 case 4: 286 if tsel := br.ReadBits(1) == 1; !tsel { 287 copyLens(simpleLens4a[:]) 288 } else { 289 copyLens(simpleLens4b[:]) 290 } 291 compareSwap(0, 1) 292 compareSwap(2, 3) 293 compareSwap(0, 2) 294 compareSwap(1, 3) 295 compareSwap(1, 2) 296 } 297 if uint(codes[nsym-1].sym) >= maxSyms { 298 errors.Panic(errCorrupted) // Symbol goes beyond range of alphabet 299 } 300 pd.Init(codes[:nsym], true) // Must have 1..4 symbols 301} 302 303// readComplexPrefixCode reads the prefix code according to RFC section 3.5. 304func (br *bitReader) readComplexPrefixCode(pd *prefixDecoder, maxSyms, hskip uint) { 305 // Read the code-lengths prefix table. 306 var codeCLensArr [len(complexLens)]prefixCode // Sorted, but may have holes 307 sum := 32 308 for _, sym := range complexLens[hskip:] { 309 clen := br.ReadSymbol(&decCLens) 310 if clen > 0 { 311 codeCLensArr[sym] = prefixCode{sym: uint32(sym), len: uint32(clen)} 312 if sum -= 32 >> clen; sum <= 0 { 313 break 314 } 315 } 316 } 317 codeCLens := codeCLensArr[:0] // Compact the array to have no holes 318 for _, c := range codeCLensArr { 319 if c.len > 0 { 320 codeCLens = append(codeCLens, c) 321 } 322 } 323 if len(codeCLens) < 1 { 324 errors.Panic(errCorrupted) 325 } 326 br.prefix.Init(codeCLens, true) // Must have 1..len(complexLens) symbols 327 328 // Use code-lengths table to decode rest of prefix table. 329 var codesArr [maxNumAlphabetSyms]prefixCode 330 var sym, repSymLast, repCntLast, clenLast uint = 0, 0, 0, 8 331 codes := codesArr[:0] 332 for sym, sum = 0, 32768; sym < maxSyms && sum > 0; { 333 clen := br.ReadSymbol(&br.prefix) 334 if clen < 16 { 335 // Literal bit-length symbol used. 336 if clen > 0 { 337 codes = append(codes, prefixCode{sym: uint32(sym), len: uint32(clen)}) 338 clenLast = clen 339 sum -= 32768 >> clen 340 } 341 repSymLast = 0 // Reset last repeater symbol 342 sym++ 343 } else { 344 // Repeater symbol used. 345 // 16: Repeat previous non-zero code-length 346 // 17: Repeat code length of zero 347 348 repSym := clen // Rename clen for better clarity 349 if repSym != repSymLast { 350 repCntLast = 0 351 repSymLast = repSym 352 } 353 354 nb := repSym - 14 // 2..3 bits 355 rep := br.ReadBits(nb) + 3 // 3..6 or 3..10 356 if repCntLast > 0 { 357 rep += (repCntLast - 2) << nb // Modify previous repeat count 358 } 359 repDiff := rep - repCntLast // Always positive 360 repCntLast = rep 361 362 if repSym == 16 { 363 clen := clenLast 364 for symEnd := sym + repDiff; sym < symEnd; sym++ { 365 codes = append(codes, prefixCode{sym: uint32(sym), len: uint32(clen)}) 366 } 367 sum -= int(repDiff) * (32768 >> clen) 368 } else { 369 sym += repDiff 370 } 371 } 372 } 373 if len(codes) < 2 || sym > maxSyms { 374 errors.Panic(errCorrupted) 375 } 376 pd.Init(codes, true) // Must have 2..maxSyms symbols 377} 378