// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import (
	"crypto/rand"
	"fmt"
	"io"
	rdebug "runtime/debug"
	"sync"

	"github.com/klauspost/compress/zstd/internal/xxhash"
)

// Encoder provides encoding to Zstandard.
// An Encoder can be used for either compressing a stream via the
// io.WriteCloser interface supported by the Encoder or as multiple independent
// tasks via the EncodeAll function.
// Smaller encodes are encouraged to use the EncodeAll function.
// Use NewWriter to create a new instance.
type Encoder struct {
	o        encoderOptions // Configuration set by NewWriter options.
	encoders chan encoder   // Pool of encoders used by EncodeAll; created once by initialize.
	state    encoderState   // State of the streaming (Write/ReadFrom/Close) interface.
	init     sync.Once      // Guards the one-time creation of the encoders pool.
}

// encoder is the interface a compression-level implementation must satisfy.
type encoder interface {
	// Encode compresses src into blk, using history carried between blocks.
	Encode(blk *blockEnc, src []byte)
	// EncodeNoHist compresses src into blk without history
	// (used for single-block frames without a dictionary).
	EncodeNoHist(blk *blockEnc, src []byte)
	// Block returns the encoder's current block.
	Block() *blockEnc
	// CRC returns the running xxhash digest of the uncompressed input.
	CRC() *xxhash.Digest
	// AppendCRC appends the frame checksum to the given slice.
	AppendCRC([]byte) []byte
	// WindowSize returns the window size to use for the given
	// (possibly zero/unknown) content size.
	WindowSize(size int64) int32
	// UseBlock transfers state (e.g. recent offsets) from the given block.
	UseBlock(*blockEnc)
	// Reset prepares the encoder for a new frame, optionally with a dictionary.
	Reset(d *dict, singleBlock bool)
}

// encoderState holds the mutable state of one streaming encode.
type encoderState struct {
	w                io.Writer // Destination writer.
	filling          []byte    // Buffer being filled by Write/ReadFrom; flushed when it reaches blockSize.
	current          []byte    // Input currently being encoded asynchronously.
	previous         []byte    // Input of the previously encoded block; its buffer is recycled as the next filling.
	encoder          encoder   // The encoder used for this stream.
	writing          *blockEnc // Block last handed to the writer goroutine; swapped with each finished block.
	err              error     // Sticky error from encoding; checked/returned by nextBlock.
	writeErr         error     // Sticky error from the async writer goroutine.
	nWritten         int64     // Compressed bytes written to w.
	nInput           int64     // Uncompressed input bytes consumed.
	frameContentSize int64     // Content size promised via ResetContentSize; 0 when unset.
	headerWritten    bool      // Frame header has been written (or frame completed via shortcut).
	eofWritten       bool      // The final (last) block has been written.
	fullFrameWritten bool      // A complete frame was written via the EncodeAll shortcut in nextBlock.

	// This waitgroup indicates an encode is running.
	wg sync.WaitGroup
	// This waitgroup indicates we have a block encoding/writing.
	wWg sync.WaitGroup
}

// NewWriter will create a new Zstandard encoder.
// If the encoder will be used for encoding blocks a nil writer can be used.
65func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) { 66 initPredefined() 67 var e Encoder 68 e.o.setDefault() 69 for _, o := range opts { 70 err := o(&e.o) 71 if err != nil { 72 return nil, err 73 } 74 } 75 if w != nil { 76 e.Reset(w) 77 } 78 return &e, nil 79} 80 81func (e *Encoder) initialize() { 82 if e.o.concurrent == 0 { 83 e.o.setDefault() 84 } 85 e.encoders = make(chan encoder, e.o.concurrent) 86 for i := 0; i < e.o.concurrent; i++ { 87 enc := e.o.encoder() 88 e.encoders <- enc 89 } 90} 91 92// Reset will re-initialize the writer and new writes will encode to the supplied writer 93// as a new, independent stream. 94func (e *Encoder) Reset(w io.Writer) { 95 s := &e.state 96 s.wg.Wait() 97 s.wWg.Wait() 98 if cap(s.filling) == 0 { 99 s.filling = make([]byte, 0, e.o.blockSize) 100 } 101 if cap(s.current) == 0 { 102 s.current = make([]byte, 0, e.o.blockSize) 103 } 104 if cap(s.previous) == 0 { 105 s.previous = make([]byte, 0, e.o.blockSize) 106 } 107 if s.encoder == nil { 108 s.encoder = e.o.encoder() 109 } 110 if s.writing == nil { 111 s.writing = &blockEnc{lowMem: e.o.lowMem} 112 s.writing.init() 113 } 114 s.writing.initNewEncode() 115 s.filling = s.filling[:0] 116 s.current = s.current[:0] 117 s.previous = s.previous[:0] 118 s.encoder.Reset(e.o.dict, false) 119 s.headerWritten = false 120 s.eofWritten = false 121 s.fullFrameWritten = false 122 s.w = w 123 s.err = nil 124 s.nWritten = 0 125 s.nInput = 0 126 s.writeErr = nil 127 s.frameContentSize = 0 128} 129 130// ResetContentSize will reset and set a content size for the next stream. 131// If the bytes written does not match the size given an error will be returned 132// when calling Close(). 133// This is removed when Reset is called. 134// Sizes <= 0 results in no content size set. 135func (e *Encoder) ResetContentSize(w io.Writer, size int64) { 136 e.Reset(w) 137 if size >= 0 { 138 e.state.frameContentSize = size 139 } 140} 141 142// Write data to the encoder. 
// Input data will be buffered and as the buffer fills up
// content will be compressed and written to the output.
// When done writing, use Close to flush the remaining output
// and write CRC if requested.
func (e *Encoder) Write(p []byte) (n int, err error) {
	s := &e.state
	for len(p) > 0 {
		if len(p)+len(s.filling) < e.o.blockSize {
			// Fast path: everything fits in the current block buffer;
			// update the CRC and stash the bytes.
			if e.o.crc {
				_, _ = s.encoder.CRC().Write(p)
			}
			s.filling = append(s.filling, p...)
			return n + len(p), nil
		}
		// Take only enough of p to fill the block exactly.
		add := p
		if len(p)+len(s.filling) > e.o.blockSize {
			add = add[:e.o.blockSize-len(s.filling)]
		}
		if e.o.crc {
			_, _ = s.encoder.CRC().Write(add)
		}
		s.filling = append(s.filling, add...)
		p = p[len(add):]
		n += len(add)
		if len(s.filling) < e.o.blockSize {
			return n, nil
		}
		// Block buffer is full; hand it off for (async) compression.
		err := e.nextBlock(false)
		if err != nil {
			return n, err
		}
		if debugAsserts && len(s.filling) > 0 {
			// nextBlock must have consumed the filling buffer.
			panic(len(s.filling))
		}
	}
	return n, nil
}

// nextBlock will synchronize and start compressing input in e.state.filling.
// If an error has occurred during encoding it will be returned.
// final marks the block as the last block of the frame.
func (e *Encoder) nextBlock(final bool) error {
	s := &e.state
	// Wait for current block.
	s.wg.Wait()
	if s.err != nil {
		return s.err
	}
	if len(s.filling) > e.o.blockSize {
		return fmt.Errorf("block > maxStoreBlockSize")
	}
	if !s.headerWritten {
		// If we have a single block encode, do a sync compression.
		if final && len(s.filling) == 0 && !e.o.fullZero {
			// Nothing was ever written and empty frames are not requested:
			// mark the frame done without emitting anything.
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}
		if final && len(s.filling) > 0 {
			// All input fits in a single block: delegate to EncodeAll,
			// which emits a complete frame synchronously.
			s.current = e.EncodeAll(s.filling, s.current[:0])
			var n2 int
			n2, s.err = s.w.Write(s.current)
			if s.err != nil {
				return s.err
			}
			s.nWritten += int64(n2)
			s.nInput += int64(len(s.filling))
			s.current = s.current[:0]
			s.filling = s.filling[:0]
			s.headerWritten = true
			s.fullFrameWritten = true
			s.eofWritten = true
			return nil
		}

		// Write the frame header before the first block.
		var tmp [maxHeaderSize]byte
		fh := frameHeader{
			ContentSize:   uint64(s.frameContentSize),
			WindowSize:    uint32(s.encoder.WindowSize(s.frameContentSize)),
			SingleSegment: false,
			Checksum:      e.o.crc,
			DictID:        e.o.dict.ID(),
		}

		dst, err := fh.appendTo(tmp[:0])
		if err != nil {
			return err
		}
		s.headerWritten = true
		// Wait for any pending block writer before touching s.w.
		s.wWg.Wait()
		var n2 int
		n2, s.err = s.w.Write(dst)
		if s.err != nil {
			return s.err
		}
		s.nWritten += int64(n2)
	}
	if s.eofWritten {
		// Ensure we only write it once.
		final = false
	}

	if len(s.filling) == 0 {
		// Final block, but no data.
		if final {
			// Emit an empty raw block marked "last" to terminate the frame.
			enc := s.encoder
			blk := enc.Block()
			blk.reset(nil)
			blk.last = true
			blk.encodeRaw(nil)
			s.wWg.Wait()
			_, s.err = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
			s.eofWritten = true
		}
		return s.err
	}

	// Move blocks forward.
	// Rotate buffers: filling reuses the previous buffer, current takes the
	// freshly filled data, and previous keeps the just-encoded data.
	s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current
	s.nInput += int64(len(s.current))
	s.wg.Add(1)
	go func(src []byte) {
		if debugEncoder {
			println("Adding block,", len(src), "bytes, final:", final)
		}
		defer func() {
			if r := recover(); r != nil {
				s.err = fmt.Errorf("panic while encoding: %v", r)
				rdebug.PrintStack()
			}
			s.wg.Done()
		}()
		enc := s.encoder
		blk := enc.Block()
		enc.Encode(blk, src)
		blk.last = final
		if final {
			s.eofWritten = true
		}
		// Wait for pending writes.
		s.wWg.Wait()
		if s.writeErr != nil {
			s.err = s.writeErr
			return
		}
		// Transfer encoders from previous write block.
		blk.swapEncoders(s.writing)
		// Transfer recent offsets to next.
		enc.UseBlock(s.writing)
		s.writing = blk
		s.wWg.Add(1)
		// Entropy coding and writing run in a second goroutine (tracked by
		// wWg) so the next block's encode (tracked by wg) can overlap it.
		go func() {
			defer func() {
				if r := recover(); r != nil {
					s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r)
					rdebug.PrintStack()
				}
				s.wWg.Done()
			}()
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			if len(src) != len(blk.literals) || len(src) != e.o.blockSize {
				err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
			}
			switch err {
			case errIncompressible:
				if debugEncoder {
					println("Storing incompressible block as raw")
				}
				blk.encodeRaw(src)
				// In fast mode, we do not transfer offsets, so we don't have to deal with changing them here.
			case nil:
			default:
				s.writeErr = err
				return
			}
			_, s.writeErr = s.w.Write(blk.output)
			s.nWritten += int64(len(blk.output))
		}()
	}(s.current)
	return nil
}

// ReadFrom reads data from r until EOF or error.
// The return value n is the number of bytes read.
// Any error except io.EOF encountered during the read is also returned.
//
// The Copy function uses ReaderFrom if available.
func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) {
	if debugEncoder {
		println("Using ReadFrom")
	}

	// Flush any current writes.
	if len(e.state.filling) > 0 {
		if err := e.nextBlock(false); err != nil {
			return 0, err
		}
	}
	// Read directly into the block buffer; src aliases its unfilled tail.
	e.state.filling = e.state.filling[:e.o.blockSize]
	src := e.state.filling
	for {
		n2, err := r.Read(src)
		if e.o.crc {
			_, _ = e.state.encoder.CRC().Write(src[:n2])
		}
		// src is now the unfilled part...
		src = src[n2:]
		n += int64(n2)
		switch err {
		case io.EOF:
			// Trim filling to the bytes actually read; a later
			// Close/Flush will compress this final partial block.
			e.state.filling = e.state.filling[:len(e.state.filling)-len(src)]
			if debugEncoder {
				println("ReadFrom: got EOF final block:", len(e.state.filling))
			}
			return n, nil
		case nil:
		default:
			if debugEncoder {
				println("ReadFrom: got error:", err)
			}
			e.state.err = err
			return n, err
		}
		if len(src) > 0 {
			if debugEncoder {
				println("ReadFrom: got space left in source:", len(src))
			}
			continue
		}
		// Buffer is full: compress it and start filling a fresh one.
		err = e.nextBlock(false)
		if err != nil {
			return n, err
		}
		e.state.filling = e.state.filling[:e.o.blockSize]
		src = e.state.filling
	}
}

// Flush will send the currently written data to output
// and block until everything has been written.
// This should only be used on rare occasions where pushing the currently queued data is critical.
func (e *Encoder) Flush() error {
	s := &e.state
	if len(s.filling) > 0 {
		// Compress whatever is buffered, even a partial block.
		err := e.nextBlock(false)
		if err != nil {
			return err
		}
	}
	// Wait for the encode and write goroutines to finish.
	s.wg.Wait()
	s.wWg.Wait()
	if s.err != nil {
		return s.err
	}
	return s.writeErr
}

// Close will flush the final output and close the stream.
// The function will block until everything has been written.
// The Encoder can still be re-used after calling this.
func (e *Encoder) Close() error {
	s := &e.state
	if s.encoder == nil {
		// Never initialized (Reset was not called); nothing to do.
		return nil
	}
	// Encode any remaining buffered data and mark the frame finished.
	err := e.nextBlock(true)
	if err != nil {
		return err
	}
	if s.frameContentSize > 0 {
		// A content size was promised via ResetContentSize; verify it.
		if s.nInput != s.frameContentSize {
			return fmt.Errorf("frame content size %d given, but %d bytes was written", s.frameContentSize, s.nInput)
		}
	}
	if e.state.fullFrameWritten {
		// The EncodeAll shortcut in nextBlock already wrote a complete frame.
		return s.err
	}
	s.wg.Wait()
	s.wWg.Wait()

	if s.err != nil {
		return s.err
	}
	if s.writeErr != nil {
		return s.writeErr
	}

	// Write CRC
	if e.o.crc && s.err == nil {
		// heap alloc.
		var tmp [4]byte
		_, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0]))
		s.nWritten += 4
	}

	// Add padding with content from crypto/rand.Reader
	if s.err == nil && e.o.pad > 0 {
		add := calcSkippableFrame(s.nWritten, int64(e.o.pad))
		// Reuse the (now empty) filling buffer as scratch for the skippable frame.
		frame, err := skippableFrame(s.filling[:0], add, rand.Reader)
		if err != nil {
			return err
		}
		_, s.err = s.w.Write(frame)
	}
	return s.err
}

// EncodeAll will encode all input in src and append it to dst.
// This function can be called concurrently, but each call will only run on a single goroutine.
// If empty input is given, nothing is returned, unless WithZeroFrames is specified.
// Encoded blocks can be concatenated and the result will be the combined input stream.
// Data compressed with EncodeAll can be decoded with the Decoder,
// using either a stream or DecodeAll.
func (e *Encoder) EncodeAll(src, dst []byte) []byte {
	if len(src) == 0 {
		if e.o.fullZero {
			// Add frame header.
			fh := frameHeader{
				ContentSize:   0,
				WindowSize:    MinWindowSize,
				SingleSegment: true,
				// Adding a checksum would be a waste of space.
				Checksum: false,
				DictID:   0,
			}
			dst, _ = fh.appendTo(dst)

			// Write raw block as last one only.
			var blk blockHeader
			blk.setSize(0)
			blk.setType(blockTypeRaw)
			blk.setLast(true)
			dst = blk.appendTo(dst)
		}
		return dst
	}
	e.init.Do(e.initialize)
	// Borrow an encoder from the pool; returned by the deferred send below.
	enc := <-e.encoders
	defer func() {
		// Release encoder reference to last block.
		// If a non-single block is needed the encoder will reset again.
		e.encoders <- enc
	}()
	// Use single segments when above minimum window and below 1MB.
	single := len(src) < 1<<20 && len(src) > MinWindowSize
	if e.o.single != nil {
		single = *e.o.single
	}
	fh := frameHeader{
		ContentSize:   uint64(len(src)),
		WindowSize:    uint32(enc.WindowSize(int64(len(src)))),
		SingleSegment: single,
		Checksum:      e.o.crc,
		DictID:        e.o.dict.ID(),
	}

	// If less than 1MB, allocate a buffer up front.
	if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 && !e.o.lowMem {
		dst = make([]byte, 0, len(src))
	}
	dst, err := fh.appendTo(dst)
	if err != nil {
		panic(err)
	}

	// If we can do everything in one block, prefer that.
	if len(src) <= maxCompressedBlockSize {
		enc.Reset(e.o.dict, true)
		// Slightly faster with no history and everything in one block.
		if e.o.crc {
			_, _ = enc.CRC().Write(src)
		}
		blk := enc.Block()
		blk.last = true
		if e.o.dict == nil {
			enc.EncodeNoHist(blk, src)
		} else {
			enc.Encode(blk, src)
		}

		// If we got the exact same number of literals as input,
		// assume the literals cannot be compressed.
		err := errIncompressible
		oldout := blk.output
		if len(blk.literals) != len(src) || len(src) != e.o.blockSize {
			// Output directly to dst
			blk.output = dst
			err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy)
		}

		switch err {
		case errIncompressible:
			if debugEncoder {
				println("Storing incompressible block as raw")
			}
			dst = blk.encodeRawTo(dst, src)
		case nil:
			dst = blk.output
		default:
			panic(err)
		}
		// Restore the block's own output buffer for later reuse.
		blk.output = oldout
	} else {
		// Multi-block frame: encode blockSize chunks with history.
		enc.Reset(e.o.dict, false)
		blk := enc.Block()
		for len(src) > 0 {
			todo := src
			if len(todo) > e.o.blockSize {
				todo = todo[:e.o.blockSize]
			}
			src = src[len(todo):]
			if e.o.crc {
				_, _ = enc.CRC().Write(todo)
			}
			blk.pushOffsets()
			enc.Encode(blk, todo)
			if len(src) == 0 {
				blk.last = true
			}
			err := errIncompressible
			// If we got the exact same number of literals as input,
			// assume the literals cannot be compressed.
			if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize {
				err = blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy)
			}

			switch err {
			case errIncompressible:
				if debugEncoder {
					println("Storing incompressible block as raw")
				}
				dst = blk.encodeRawTo(dst, todo)
				// Block stored raw: restore the offsets saved by pushOffsets.
				blk.popOffsets()
			case nil:
				dst = append(dst, blk.output...)
			default:
				panic(err)
			}
			blk.reset(nil)
		}
	}
	if e.o.crc {
		dst = enc.AppendCRC(dst)
	}
	// Add padding with content from crypto/rand.Reader
	if e.o.pad > 0 {
		add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad))
		dst, err = skippableFrame(dst, add, rand.Reader)
		if err != nil {
			panic(err)
		}
	}
	return dst
}