// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import (
	"crypto/rand"
	"fmt"
	"io"
	rdebug "runtime/debug"
	"sync"

	"github.com/klauspost/compress/zstd/internal/xxhash"
)

// Encoder provides encoding to Zstandard.
// An Encoder can be used for either compressing a stream via the
// io.WriteCloser interface supported by the Encoder or as multiple independent
// tasks via the EncodeAll function.
// Smaller encodes are encouraged to use the EncodeAll function.
// Use NewWriter to create a new instance.
type Encoder struct {
	// o is the encoder configuration, populated by NewWriter from its options.
	o encoderOptions
	// encoders is the pool used by EncodeAll: initialize fills it with
	// o.concurrent encoders, and each EncodeAll call takes one and puts it back.
	encoders chan encoder
	// state is the state of the streaming API (Write/ReadFrom/Flush/Close).
	state encoderState
	// init guards the lazy, one-time call to initialize made by EncodeAll.
	init sync.Once
}

// encoder is the interface of the block encoders returned by e.o.encoder().
// NOTE(review): the implementations live elsewhere in the package; the
// per-method notes below describe how this file uses each method.
type encoder interface {
	// Encode encodes src into blk. It is used for streaming blocks, and by
	// EncodeAll when a dictionary is configured or input spans several blocks.
	Encode(blk *blockEnc, src []byte)
	// EncodeNoHist encodes src into blk. EncodeAll uses it for single-block
	// input when no dictionary is set (presumably skipping history handling).
	EncodeNoHist(blk *blockEnc, src []byte)
	// Block returns the encoder's current output block.
	Block() *blockEnc
	// CRC returns the running checksum digest that is fed uncompressed input.
	CRC() *xxhash.Digest
	// AppendCRC appends the frame checksum to the given slice.
	AppendCRC([]byte) []byte
	// WindowSize returns the window size to advertise in the frame header for
	// an input of the given size; 0 is passed when streaming (size unknown).
	WindowSize(size int) int32
	// UseBlock hands a block to the encoder; the streaming code calls it to
	// transfer recent offsets to the next block.
	UseBlock(*blockEnc)
	// Reset prepares the encoder for a new frame with the optional dictionary
	// d; singleBlock is true when EncodeAll encodes everything in one block.
	Reset(d *dict, singleBlock bool)
}

// encoderState holds the per-stream state of an Encoder.
type encoderState struct {
	// w receives the compressed stream.
	w io.Writer
	// filling accumulates input from Write/ReadFrom until a block is complete.
	filling []byte
	// current is the block most recently handed to the encoding goroutine.
	current []byte
	// previous is the block before current. The three buffers rotate in
	// nextBlock (presumably so older input stays valid while it may still be
	// referenced as match history — confirm).
	previous []byte
	// encoder is the encoder dedicated to the stream.
	encoder encoder
	// writing is the block whose output is being (or was last) written to w.
	writing *blockEnc
	// err is the sticky error returned by stream operations.
	err error
	// writeErr is set by the asynchronous block-writing goroutine.
	writeErr error
	// nWritten counts bytes written to w; Close uses it to size padding.
	nWritten int64
	// headerWritten is set once the frame header has been handled.
	headerWritten bool
	// eofWritten is set once the last block of the frame has been emitted.
	eofWritten bool
	// fullFrameWritten is set when nextBlock completed the frame in one step
	// (a single-block frame via EncodeAll, or an empty stream with no output),
	// in which case Close appends nothing further.
	fullFrameWritten bool

	// This waitgroup indicates an encode is running.
	wg sync.WaitGroup
	// This waitgroup indicates we have a block encoding/writing.
	wWg sync.WaitGroup
}

// NewWriter will create a new Zstandard encoder.
// Options are applied in order and the first option error is returned.
// If the encoder will only be used via EncodeAll, a nil writer can be used;
// Reset must then be called to attach a writer before using the streaming API.
63func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) { 64 initPredefined() 65 var e Encoder 66 e.o.setDefault() 67 for _, o := range opts { 68 err := o(&e.o) 69 if err != nil { 70 return nil, err 71 } 72 } 73 if w != nil { 74 e.Reset(w) 75 } 76 return &e, nil 77} 78 79func (e *Encoder) initialize() { 80 if e.o.concurrent == 0 { 81 e.o.setDefault() 82 } 83 e.encoders = make(chan encoder, e.o.concurrent) 84 for i := 0; i < e.o.concurrent; i++ { 85 enc := e.o.encoder() 86 e.encoders <- enc 87 } 88} 89 90// Reset will re-initialize the writer and new writes will encode to the supplied writer 91// as a new, independent stream. 92func (e *Encoder) Reset(w io.Writer) { 93 s := &e.state 94 s.wg.Wait() 95 s.wWg.Wait() 96 if cap(s.filling) == 0 { 97 s.filling = make([]byte, 0, e.o.blockSize) 98 } 99 if cap(s.current) == 0 { 100 s.current = make([]byte, 0, e.o.blockSize) 101 } 102 if cap(s.previous) == 0 { 103 s.previous = make([]byte, 0, e.o.blockSize) 104 } 105 if s.encoder == nil { 106 s.encoder = e.o.encoder() 107 } 108 if s.writing == nil { 109 s.writing = &blockEnc{lowMem: e.o.lowMem} 110 s.writing.init() 111 } 112 s.writing.initNewEncode() 113 s.filling = s.filling[:0] 114 s.current = s.current[:0] 115 s.previous = s.previous[:0] 116 s.encoder.Reset(e.o.dict, false) 117 s.headerWritten = false 118 s.eofWritten = false 119 s.fullFrameWritten = false 120 s.w = w 121 s.err = nil 122 s.nWritten = 0 123 s.writeErr = nil 124} 125 126// Write data to the encoder. 127// Input data will be buffered and as the buffer fills up 128// content will be compressed and written to the output. 129// When done writing, use Close to flush the remaining output 130// and write CRC if requested. 131func (e *Encoder) Write(p []byte) (n int, err error) { 132 s := &e.state 133 for len(p) > 0 { 134 if len(p)+len(s.filling) < e.o.blockSize { 135 if e.o.crc { 136 _, _ = s.encoder.CRC().Write(p) 137 } 138 s.filling = append(s.filling, p...) 
139 return n + len(p), nil 140 } 141 add := p 142 if len(p)+len(s.filling) > e.o.blockSize { 143 add = add[:e.o.blockSize-len(s.filling)] 144 } 145 if e.o.crc { 146 _, _ = s.encoder.CRC().Write(add) 147 } 148 s.filling = append(s.filling, add...) 149 p = p[len(add):] 150 n += len(add) 151 if len(s.filling) < e.o.blockSize { 152 return n, nil 153 } 154 err := e.nextBlock(false) 155 if err != nil { 156 return n, err 157 } 158 if debugAsserts && len(s.filling) > 0 { 159 panic(len(s.filling)) 160 } 161 } 162 return n, nil 163} 164 165// nextBlock will synchronize and start compressing input in e.state.filling. 166// If an error has occurred during encoding it will be returned. 167func (e *Encoder) nextBlock(final bool) error { 168 s := &e.state 169 // Wait for current block. 170 s.wg.Wait() 171 if s.err != nil { 172 return s.err 173 } 174 if len(s.filling) > e.o.blockSize { 175 return fmt.Errorf("block > maxStoreBlockSize") 176 } 177 if !s.headerWritten { 178 // If we have a single block encode, do a sync compression. 
179 if final && len(s.filling) == 0 && !e.o.fullZero { 180 s.headerWritten = true 181 s.fullFrameWritten = true 182 s.eofWritten = true 183 return nil 184 } 185 if final && len(s.filling) > 0 { 186 s.current = e.EncodeAll(s.filling, s.current[:0]) 187 var n2 int 188 n2, s.err = s.w.Write(s.current) 189 if s.err != nil { 190 return s.err 191 } 192 s.nWritten += int64(n2) 193 s.current = s.current[:0] 194 s.filling = s.filling[:0] 195 s.headerWritten = true 196 s.fullFrameWritten = true 197 s.eofWritten = true 198 return nil 199 } 200 201 var tmp [maxHeaderSize]byte 202 fh := frameHeader{ 203 ContentSize: 0, 204 WindowSize: uint32(s.encoder.WindowSize(0)), 205 SingleSegment: false, 206 Checksum: e.o.crc, 207 DictID: e.o.dict.ID(), 208 } 209 210 dst, err := fh.appendTo(tmp[:0]) 211 if err != nil { 212 return err 213 } 214 s.headerWritten = true 215 s.wWg.Wait() 216 var n2 int 217 n2, s.err = s.w.Write(dst) 218 if s.err != nil { 219 return s.err 220 } 221 s.nWritten += int64(n2) 222 } 223 if s.eofWritten { 224 // Ensure we only write it once. 225 final = false 226 } 227 228 if len(s.filling) == 0 { 229 // Final block, but no data. 230 if final { 231 enc := s.encoder 232 blk := enc.Block() 233 blk.reset(nil) 234 blk.last = true 235 blk.encodeRaw(nil) 236 s.wWg.Wait() 237 _, s.err = s.w.Write(blk.output) 238 s.nWritten += int64(len(blk.output)) 239 s.eofWritten = true 240 } 241 return s.err 242 } 243 244 // Move blocks forward. 245 s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current 246 s.wg.Add(1) 247 go func(src []byte) { 248 if debugEncoder { 249 println("Adding block,", len(src), "bytes, final:", final) 250 } 251 defer func() { 252 if r := recover(); r != nil { 253 s.err = fmt.Errorf("panic while encoding: %v", r) 254 rdebug.PrintStack() 255 } 256 s.wg.Done() 257 }() 258 enc := s.encoder 259 blk := enc.Block() 260 enc.Encode(blk, src) 261 blk.last = final 262 if final { 263 s.eofWritten = true 264 } 265 // Wait for pending writes. 
266 s.wWg.Wait() 267 if s.writeErr != nil { 268 s.err = s.writeErr 269 return 270 } 271 // Transfer encoders from previous write block. 272 blk.swapEncoders(s.writing) 273 // Transfer recent offsets to next. 274 enc.UseBlock(s.writing) 275 s.writing = blk 276 s.wWg.Add(1) 277 go func() { 278 defer func() { 279 if r := recover(); r != nil { 280 s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r) 281 rdebug.PrintStack() 282 } 283 s.wWg.Done() 284 }() 285 err := errIncompressible 286 // If we got the exact same number of literals as input, 287 // assume the literals cannot be compressed. 288 if len(src) != len(blk.literals) || len(src) != e.o.blockSize { 289 err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy) 290 } 291 switch err { 292 case errIncompressible: 293 if debugEncoder { 294 println("Storing incompressible block as raw") 295 } 296 blk.encodeRaw(src) 297 // In fast mode, we do not transfer offsets, so we don't have to deal with changing the. 298 case nil: 299 default: 300 s.writeErr = err 301 return 302 } 303 _, s.writeErr = s.w.Write(blk.output) 304 s.nWritten += int64(len(blk.output)) 305 }() 306 }(s.current) 307 return nil 308} 309 310// ReadFrom reads data from r until EOF or error. 311// The return value n is the number of bytes read. 312// Any error except io.EOF encountered during the read is also returned. 313// 314// The Copy function uses ReaderFrom if available. 315func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) { 316 if debugEncoder { 317 println("Using ReadFrom") 318 } 319 320 // Flush any current writes. 321 if len(e.state.filling) > 0 { 322 if err := e.nextBlock(false); err != nil { 323 return 0, err 324 } 325 } 326 e.state.filling = e.state.filling[:e.o.blockSize] 327 src := e.state.filling 328 for { 329 n2, err := r.Read(src) 330 if e.o.crc { 331 _, _ = e.state.encoder.CRC().Write(src[:n2]) 332 } 333 // src is now the unfilled part... 
334 src = src[n2:] 335 n += int64(n2) 336 switch err { 337 case io.EOF: 338 e.state.filling = e.state.filling[:len(e.state.filling)-len(src)] 339 if debugEncoder { 340 println("ReadFrom: got EOF final block:", len(e.state.filling)) 341 } 342 return n, nil 343 case nil: 344 default: 345 if debugEncoder { 346 println("ReadFrom: got error:", err) 347 } 348 e.state.err = err 349 return n, err 350 } 351 if len(src) > 0 { 352 if debugEncoder { 353 println("ReadFrom: got space left in source:", len(src)) 354 } 355 continue 356 } 357 err = e.nextBlock(false) 358 if err != nil { 359 return n, err 360 } 361 e.state.filling = e.state.filling[:e.o.blockSize] 362 src = e.state.filling 363 } 364} 365 366// Flush will send the currently written data to output 367// and block until everything has been written. 368// This should only be used on rare occasions where pushing the currently queued data is critical. 369func (e *Encoder) Flush() error { 370 s := &e.state 371 if len(s.filling) > 0 { 372 err := e.nextBlock(false) 373 if err != nil { 374 return err 375 } 376 } 377 s.wg.Wait() 378 s.wWg.Wait() 379 if s.err != nil { 380 return s.err 381 } 382 return s.writeErr 383} 384 385// Close will flush the final output and close the stream. 386// The function will block until everything has been written. 387// The Encoder can still be re-used after calling this. 388func (e *Encoder) Close() error { 389 s := &e.state 390 if s.encoder == nil { 391 return nil 392 } 393 err := e.nextBlock(true) 394 if err != nil { 395 return err 396 } 397 if e.state.fullFrameWritten { 398 return s.err 399 } 400 s.wg.Wait() 401 s.wWg.Wait() 402 403 if s.err != nil { 404 return s.err 405 } 406 if s.writeErr != nil { 407 return s.writeErr 408 } 409 410 // Write CRC 411 if e.o.crc && s.err == nil { 412 // heap alloc. 
413 var tmp [4]byte 414 _, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0])) 415 s.nWritten += 4 416 } 417 418 // Add padding with content from crypto/rand.Reader 419 if s.err == nil && e.o.pad > 0 { 420 add := calcSkippableFrame(s.nWritten, int64(e.o.pad)) 421 frame, err := skippableFrame(s.filling[:0], add, rand.Reader) 422 if err != nil { 423 return err 424 } 425 _, s.err = s.w.Write(frame) 426 } 427 return s.err 428} 429 430// EncodeAll will encode all input in src and append it to dst. 431// This function can be called concurrently, but each call will only run on a single goroutine. 432// If empty input is given, nothing is returned, unless WithZeroFrames is specified. 433// Encoded blocks can be concatenated and the result will be the combined input stream. 434// Data compressed with EncodeAll can be decoded with the Decoder, 435// using either a stream or DecodeAll. 436func (e *Encoder) EncodeAll(src, dst []byte) []byte { 437 if len(src) == 0 { 438 if e.o.fullZero { 439 // Add frame header. 440 fh := frameHeader{ 441 ContentSize: 0, 442 WindowSize: MinWindowSize, 443 SingleSegment: true, 444 // Adding a checksum would be a waste of space. 445 Checksum: false, 446 DictID: 0, 447 } 448 dst, _ = fh.appendTo(dst) 449 450 // Write raw block as last one only. 451 var blk blockHeader 452 blk.setSize(0) 453 blk.setType(blockTypeRaw) 454 blk.setLast(true) 455 dst = blk.appendTo(dst) 456 } 457 return dst 458 } 459 e.init.Do(e.initialize) 460 enc := <-e.encoders 461 defer func() { 462 // Release encoder reference to last block. 463 // If a non-single block is needed the encoder will reset again. 464 e.encoders <- enc 465 }() 466 // Use single segments when above minimum window and below 1MB. 
467 single := len(src) < 1<<20 && len(src) > MinWindowSize 468 if e.o.single != nil { 469 single = *e.o.single 470 } 471 fh := frameHeader{ 472 ContentSize: uint64(len(src)), 473 WindowSize: uint32(enc.WindowSize(len(src))), 474 SingleSegment: single, 475 Checksum: e.o.crc, 476 DictID: e.o.dict.ID(), 477 } 478 479 // If less than 1MB, allocate a buffer up front. 480 if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem { 481 dst = make([]byte, 0, len(src)) 482 } 483 dst, err := fh.appendTo(dst) 484 if err != nil { 485 panic(err) 486 } 487 488 // If we can do everything in one block, prefer that. 489 if len(src) <= maxCompressedBlockSize { 490 enc.Reset(e.o.dict, true) 491 // Slightly faster with no history and everything in one block. 492 if e.o.crc { 493 _, _ = enc.CRC().Write(src) 494 } 495 blk := enc.Block() 496 blk.last = true 497 if e.o.dict == nil { 498 enc.EncodeNoHist(blk, src) 499 } else { 500 enc.Encode(blk, src) 501 } 502 503 // If we got the exact same number of literals as input, 504 // assume the literals cannot be compressed. 
505 err := errIncompressible 506 oldout := blk.output 507 if len(blk.literals) != len(src) || len(src) != e.o.blockSize { 508 // Output directly to dst 509 blk.output = dst 510 err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy) 511 } 512 513 switch err { 514 case errIncompressible: 515 if debugEncoder { 516 println("Storing incompressible block as raw") 517 } 518 dst = blk.encodeRawTo(dst, src) 519 case nil: 520 dst = blk.output 521 default: 522 panic(err) 523 } 524 blk.output = oldout 525 } else { 526 enc.Reset(e.o.dict, false) 527 blk := enc.Block() 528 for len(src) > 0 { 529 todo := src 530 if len(todo) > e.o.blockSize { 531 todo = todo[:e.o.blockSize] 532 } 533 src = src[len(todo):] 534 if e.o.crc { 535 _, _ = enc.CRC().Write(todo) 536 } 537 blk.pushOffsets() 538 enc.Encode(blk, todo) 539 if len(src) == 0 { 540 blk.last = true 541 } 542 err := errIncompressible 543 // If we got the exact same number of literals as input, 544 // assume the literals cannot be compressed. 545 if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize { 546 err = blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy) 547 } 548 549 switch err { 550 case errIncompressible: 551 if debugEncoder { 552 println("Storing incompressible block as raw") 553 } 554 dst = blk.encodeRawTo(dst, todo) 555 blk.popOffsets() 556 case nil: 557 dst = append(dst, blk.output...) 558 default: 559 panic(err) 560 } 561 blk.reset(nil) 562 } 563 } 564 if e.o.crc { 565 dst = enc.AppendCRC(dst) 566 } 567 // Add padding with content from crypto/rand.Reader 568 if e.o.pad > 0 { 569 add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad)) 570 dst, err = skippableFrame(dst, add, rand.Reader) 571 if err != nil { 572 panic(err) 573 } 574 } 575 return dst 576} 577