1// Copyright 2019+ Klaus Post. All rights reserved. 2// License information can be found in the LICENSE file. 3// Based on work by Yann Collet, released under BSD License. 4 5package zstd 6 7import ( 8 "bytes" 9 "encoding/hex" 10 "errors" 11 "hash" 12 "io" 13 "sync" 14 15 "github.com/klauspost/compress/zstd/internal/xxhash" 16) 17 18type frameDec struct { 19 o decoderOptions 20 crc hash.Hash64 21 offset int64 22 23 WindowSize uint64 24 25 // maxWindowSize is the maximum windows size to support. 26 // should never be bigger than max-int. 27 maxWindowSize uint64 28 29 // In order queue of blocks being decoded. 30 decoding chan *blockDec 31 32 // Frame history passed between blocks 33 history history 34 35 rawInput byteBuffer 36 37 // Byte buffer that can be reused for small input blocks. 38 bBuf byteBuf 39 40 FrameContentSize uint64 41 frameDone sync.WaitGroup 42 43 DictionaryID *uint32 44 HasCheckSum bool 45 SingleSegment bool 46 47 // asyncRunning indicates whether the async routine processes input on 'decoding'. 48 asyncRunningMu sync.Mutex 49 asyncRunning bool 50} 51 52const ( 53 // The minimum Window_Size is 1 KB. 54 MinWindowSize = 1 << 10 55 MaxWindowSize = 1 << 29 56) 57 58var ( 59 frameMagic = []byte{0x28, 0xb5, 0x2f, 0xfd} 60 skippableFrameMagic = []byte{0x2a, 0x4d, 0x18} 61) 62 63func newFrameDec(o decoderOptions) *frameDec { 64 d := frameDec{ 65 o: o, 66 maxWindowSize: MaxWindowSize, 67 } 68 if d.maxWindowSize > o.maxDecodedSize { 69 d.maxWindowSize = o.maxDecodedSize 70 } 71 return &d 72} 73 74// reset will read the frame header and prepare for block decoding. 75// If nothing can be read from the input, io.EOF will be returned. 76// Any other error indicated that the stream contained data, but 77// there was a problem. 
func (d *frameDec) reset(br byteBuffer) error {
	// Parse one frame header from br, skipping any skippable frames in front
	// of it, then derive the window and history sizes for the frame's blocks.
	d.HasCheckSum = false
	d.WindowSize = 0
	var b []byte
	for {
		b = br.readSmall(4)
		if b == nil {
			return io.EOF
		}
		// A skippable frame's magic has bytes 1..3 equal to skippableFrameMagic
		// and an upper nibble of 0x5 in byte 0.
		if !bytes.Equal(b[1:4], skippableFrameMagic) || b[0]&0xf0 != 0x50 {
			if debug {
				println("Not skippable", hex.EncodeToString(b), hex.EncodeToString(skippableFrameMagic))
			}
			// Break if not skippable frame.
			break
		}
		// Read size to skip (4 bytes, little-endian).
		b = br.readSmall(4)
		if b == nil {
			println("Reading Frame Size EOF")
			return io.ErrUnexpectedEOF
		}
		n := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
		println("Skipping frame with", n, "bytes.")
		// NOTE(review): on 32-bit platforms int(n) wraps negative for n >= 1<<31;
		// confirm every byteBuffer.skipN implementation rejects negative counts.
		err := br.skipN(int(n))
		if err != nil {
			if debug {
				println("Reading discarded frame", err)
			}
			return err
		}
	}
	// b holds the first 4 bytes that did not form a skippable frame header.
	if !bytes.Equal(b, frameMagic) {
		println("Got magic numbers: ", b, "want:", frameMagic)
		return ErrMagicMismatch
	}

	// Read Frame_Header_Descriptor
	fhd, err := br.readByte()
	if err != nil {
		println("Reading Frame_Header_Descriptor", err)
		return err
	}
	// Bit 5: Single_Segment_flag.
	d.SingleSegment = fhd&(1<<5) != 0

	// Bit 3 is reserved and must be zero.
	if fhd&(1<<3) != 0 {
		return errors.New("Reserved bit set on frame header")
	}

	// Read Window_Descriptor
	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#window_descriptor
	d.WindowSize = 0
	if !d.SingleSegment {
		wd, err := br.readByte()
		if err != nil {
			println("Reading Window_Descriptor", err)
			return err
		}
		printf("raw: %x, mantissa: %d, exponent: %d\n", wd, wd&7, wd>>3)
		// Window size = 2^(10+exponent) plus mantissa eighths of that base.
		windowLog := 10 + (wd >> 3)
		windowBase := uint64(1) << windowLog
		windowAdd := (windowBase / 8) * uint64(wd&0x7)
		d.WindowSize = windowBase + windowAdd
	}

	// Read Dictionary_ID
	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary_id
	// The low two bits of fhd select a 0, 1, 2 or 4 byte field.
	d.DictionaryID = nil
	if size := fhd & 3; size != 0 {
		if size == 3 {
			size = 4
		}
		b = br.readSmall(int(size))
		if b == nil {
			if debug {
				println("Reading Dictionary_ID", io.ErrUnexpectedEOF)
			}
			return io.ErrUnexpectedEOF
		}
		var id uint32
		switch size {
		case 1:
			id = uint32(b[0])
		case 2:
			id = uint32(b[0]) | (uint32(b[1]) << 8)
		case 4:
			id = uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
		}
		if debug {
			println("Dict size", size, "ID:", id)
		}
		if id > 0 {
			// ID 0 means "sorry, no dictionary anyway".
			// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#dictionary-format
			d.DictionaryID = &id
		}
	}

	// Read Frame_Content_Size
	// https://github.com/facebook/zstd/blob/dev/doc/zstd_compression_format.md#frame_content_size
	// The top two bits of fhd select the field size: 0 means 1 byte for
	// single-segment frames and no field otherwise; 1, 2, 3 mean 2, 4, 8 bytes.
	var fcsSize int
	v := fhd >> 6
	switch v {
	case 0:
		if d.SingleSegment {
			fcsSize = 1
		}
	default:
		fcsSize = 1 << v
	}
	d.FrameContentSize = 0
	if fcsSize > 0 {
		b := br.readSmall(fcsSize)
		if b == nil {
			println("Reading Frame content", io.ErrUnexpectedEOF)
			return io.ErrUnexpectedEOF
		}
		switch fcsSize {
		case 1:
			d.FrameContentSize = uint64(b[0])
		case 2:
			// When FCS_Field_Size is 2, the offset of 256 is added.
			// `|` and `+` share precedence in Go and associate left to right,
			// so 256 is added to the assembled 16-bit value.
			d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) + 256
		case 4:
			d.FrameContentSize = uint64(b[0]) | (uint64(b[1]) << 8) | (uint64(b[2]) << 16) | (uint64(b[3]) << 24)
		case 8:
			d1 := uint32(b[0]) | (uint32(b[1]) << 8) | (uint32(b[2]) << 16) | (uint32(b[3]) << 24)
			d2 := uint32(b[4]) | (uint32(b[5]) << 8) | (uint32(b[6]) << 16) | (uint32(b[7]) << 24)
			d.FrameContentSize = uint64(d1) | (uint64(d2) << 32)
		}
		if debug {
			println("field size bits:", v, "fcsSize:", fcsSize, "FrameContentSize:", d.FrameContentSize, hex.EncodeToString(b[:fcsSize]), "singleseg:", d.SingleSegment, "window:", d.WindowSize)
		}
	}
	// Move this to shared.
	// Bit 2: Content_Checksum_flag. The hasher is created lazily and reused.
	d.HasCheckSum = fhd&(1<<2) != 0
	if d.HasCheckSum {
		if d.crc == nil {
			d.crc = xxhash.New()
		}
		d.crc.Reset()
	}

	// Single-segment frames have no Window_Descriptor: the window covers the
	// whole content, but is never below MinWindowSize.
	if d.WindowSize == 0 && d.SingleSegment {
		// We may not need window in this case.
		d.WindowSize = d.FrameContentSize
		if d.WindowSize < MinWindowSize {
			d.WindowSize = MinWindowSize
		}
	}

	if d.WindowSize > d.maxWindowSize {
		printf("window size %d > max %d\n", d.WindowSize, d.maxWindowSize)
		return ErrWindowSizeExceeded
	}
	// The minimum Window_Size is 1 KB.
	if d.WindowSize < MinWindowSize {
		println("got window size: ", d.WindowSize)
		return ErrWindowSizeTooSmall
	}
	// History capacity: twice the window for small windows in low-memory
	// mode, otherwise the window plus one maximum-size block.
	d.history.windowSize = int(d.WindowSize)
	if d.o.lowMem && d.history.windowSize < maxBlockSize {
		d.history.maxSize = d.history.windowSize * 2
	} else {
		d.history.maxSize = d.history.windowSize + maxBlockSize
	}
	// history contains input - maybe we do something
	d.rawInput = br
	return nil
}

// next reads the next block header from the frame's input into block and
// queues the block for the async frame decoder when that decoder is running.
// It returns io.EOF after queuing the frame's last block, and the block's
// reset error (also forwarded through sendErr) if the block cannot be read.
func (d *frameDec) next(block *blockDec) error {
	if debug {
		printf("decoding new block %p:%p", block, block.data)
	}
	err := block.reset(d.rawInput, d.WindowSize)
	if err != nil {
		println("block error:", err)
		// Signal the frame decoder we have a problem.
		d.sendErr(block, err)
		return err
	}
	// NOTE(review): presumably wakes the block's own decoding goroutine;
	// confirm in blockDec.
	block.input <- struct{}{}
	if debug {
		println("next block:", block)
	}
	// The lock keeps asyncRunning from flipping to false between the check
	// and the send below.
	d.asyncRunningMu.Lock()
	defer d.asyncRunningMu.Unlock()
	if !d.asyncRunning {
		return nil
	}
	if block.Last {
		// Returning io.EOF tells the caller the frame is done.
		d.decoding <- block
		return io.EOF
	}
	d.decoding <- block
	return nil
}

// sendErr queues an error block on the frame.
// This causes the frame decoder to return when it encounters the block.
// It returns true if the block was queued, which only happens while the
// async frame decoder is running.
282func (d *frameDec) sendErr(block *blockDec, err error) bool { 283 d.asyncRunningMu.Lock() 284 defer d.asyncRunningMu.Unlock() 285 if !d.asyncRunning { 286 return false 287 } 288 289 println("sending error", err.Error()) 290 block.sendErr(err) 291 d.decoding <- block 292 return true 293} 294 295// checkCRC will check the checksum if the frame has one. 296// Will return ErrCRCMismatch if crc check failed, otherwise nil. 297func (d *frameDec) checkCRC() error { 298 if !d.HasCheckSum { 299 return nil 300 } 301 var tmp [4]byte 302 got := d.crc.Sum64() 303 // Flip to match file order. 304 tmp[0] = byte(got >> 0) 305 tmp[1] = byte(got >> 8) 306 tmp[2] = byte(got >> 16) 307 tmp[3] = byte(got >> 24) 308 309 // We can overwrite upper tmp now 310 want := d.rawInput.readSmall(4) 311 if want == nil { 312 println("CRC missing?") 313 return io.ErrUnexpectedEOF 314 } 315 316 if !bytes.Equal(tmp[:], want) { 317 if debug { 318 println("CRC Check Failed:", tmp[:], "!=", want) 319 } 320 return ErrCRCMismatch 321 } 322 if debug { 323 println("CRC ok", tmp[:]) 324 } 325 return nil 326} 327 328func (d *frameDec) initAsync() { 329 if !d.o.lowMem && !d.SingleSegment { 330 // set max extra size history to 10MB. 331 d.history.maxSize = d.history.windowSize + maxBlockSize*5 332 } 333 // re-alloc if more than one extra block size. 334 if d.o.lowMem && cap(d.history.b) > d.history.maxSize+maxBlockSize { 335 d.history.b = make([]byte, 0, d.history.maxSize) 336 } 337 if cap(d.history.b) < d.history.maxSize { 338 d.history.b = make([]byte, 0, d.history.maxSize) 339 } 340 if cap(d.decoding) < d.o.concurrent { 341 d.decoding = make(chan *blockDec, d.o.concurrent) 342 } 343 if debug { 344 h := d.history 345 printf("history init. len: %d, cap: %d", len(h.b), cap(h.b)) 346 } 347 d.asyncRunningMu.Lock() 348 d.asyncRunning = true 349 d.asyncRunningMu.Unlock() 350} 351 352// startDecoder will start decoding blocks and write them to the writer. 
353// The decoder will stop as soon as an error occurs or at end of frame. 354// When the frame has finished decoding the *bufio.Reader 355// containing the remaining input will be sent on frameDec.frameDone. 356func (d *frameDec) startDecoder(output chan decodeOutput) { 357 written := int64(0) 358 359 defer func() { 360 d.asyncRunningMu.Lock() 361 d.asyncRunning = false 362 d.asyncRunningMu.Unlock() 363 364 // Drain the currently decoding. 365 d.history.error = true 366 flushdone: 367 for { 368 select { 369 case b := <-d.decoding: 370 b.history <- &d.history 371 output <- <-b.result 372 default: 373 break flushdone 374 } 375 } 376 println("frame decoder done, signalling done") 377 d.frameDone.Done() 378 }() 379 // Get decoder for first block. 380 block := <-d.decoding 381 block.history <- &d.history 382 for { 383 var next *blockDec 384 // Get result 385 r := <-block.result 386 if r.err != nil { 387 println("Result contained error", r.err) 388 output <- r 389 return 390 } 391 if debug { 392 println("got result, from ", d.offset, "to", d.offset+int64(len(r.b))) 393 d.offset += int64(len(r.b)) 394 } 395 if !block.Last { 396 // Send history to next block 397 select { 398 case next = <-d.decoding: 399 if debug { 400 println("Sending ", len(d.history.b), "bytes as history") 401 } 402 next.history <- &d.history 403 default: 404 // Wait until we have sent the block, so 405 // other decoders can potentially get the decoder. 406 next = nil 407 } 408 } 409 410 // Add checksum, async to decoding. 
411 if d.HasCheckSum { 412 n, err := d.crc.Write(r.b) 413 if err != nil { 414 r.err = err 415 if n != len(r.b) { 416 r.err = io.ErrShortWrite 417 } 418 output <- r 419 return 420 } 421 } 422 written += int64(len(r.b)) 423 if d.SingleSegment && uint64(written) > d.FrameContentSize { 424 println("runDecoder: single segment and", uint64(written), ">", d.FrameContentSize) 425 r.err = ErrFrameSizeExceeded 426 output <- r 427 return 428 } 429 if block.Last { 430 r.err = d.checkCRC() 431 output <- r 432 return 433 } 434 output <- r 435 if next == nil { 436 // There was no decoder available, we wait for one now that we have sent to the writer. 437 if debug { 438 println("Sending ", len(d.history.b), " bytes as history") 439 } 440 next = <-d.decoding 441 next.history <- &d.history 442 } 443 block = next 444 } 445} 446 447// runDecoder will create a sync decoder that will decode a block of data. 448func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) { 449 saved := d.history.b 450 451 // We use the history for output to avoid copying it. 452 d.history.b = dst 453 // Store input length, so we only check new data. 
454 crcStart := len(dst) 455 var err error 456 for { 457 err = dec.reset(d.rawInput, d.WindowSize) 458 if err != nil { 459 break 460 } 461 if debug { 462 println("next block:", dec) 463 } 464 err = dec.decodeBuf(&d.history) 465 if err != nil || dec.Last { 466 break 467 } 468 if uint64(len(d.history.b)) > d.o.maxDecodedSize { 469 err = ErrDecoderSizeExceeded 470 break 471 } 472 if d.SingleSegment && uint64(len(d.history.b)) > d.o.maxDecodedSize { 473 println("runDecoder: single segment and", uint64(len(d.history.b)), ">", d.o.maxDecodedSize) 474 err = ErrFrameSizeExceeded 475 break 476 } 477 } 478 dst = d.history.b 479 if err == nil { 480 if d.HasCheckSum { 481 var n int 482 n, err = d.crc.Write(dst[crcStart:]) 483 if err == nil { 484 if n != len(dst)-crcStart { 485 err = io.ErrShortWrite 486 } else { 487 err = d.checkCRC() 488 } 489 } 490 } 491 } 492 d.history.b = saved 493 return dst, err 494} 495