// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import (
	"crypto/rand"
	"fmt"
	"io"
	rdebug "runtime/debug"
	"sync"

	"github.com/klauspost/compress/zstd/internal/xxhash"
)

// Encoder provides encoding to Zstandard.
// An Encoder can be used for either compressing a stream via the
// io.WriteCloser interface supported by the Encoder or as multiple independent
// tasks via the EncodeAll function.
// Smaller encodes are encouraged to use the EncodeAll function.
// Use NewWriter to create a new instance.
type Encoder struct {
	// o holds the options: the defaults from setDefault, then every
	// EOption passed to NewWriter applied in order.
	o encoderOptions
	// encoders is the pool of encoders borrowed by EncodeAll. It is
	// created lazily by initialize with o.concurrent entries.
	encoders chan encoder
	// state is the streaming state used by Reset, Write, ReadFrom,
	// Flush and Close.
	state encoderState
	// init makes sure initialize runs only once (it is invoked from EncodeAll).
	init sync.Once
}

// encoder is the internal interface of a block encoder.
// Instances are created by encoderOptions.encoder().
type encoder interface {
	// Encode encodes src into blk.
	Encode(blk *blockEnc, src []byte)
	// EncodeNoHist encodes src into blk without history; EncodeAll uses it
	// for single-block input when no dictionary is configured.
	EncodeNoHist(blk *blockEnc, src []byte)
	// Block returns the block the encoder currently encodes into.
	Block() *blockEnc
	// CRC returns the running xxhash digest; callers feed it the raw input.
	CRC() *xxhash.Digest
	// AppendCRC appends the frame checksum to the given slice and returns it.
	AppendCRC([]byte) []byte
	// WindowSize returns the window size to advertise in the frame header
	// for an input of size bytes (the streaming path passes 0).
	WindowSize(size int) int32
	// UseBlock hands a block to the encoder; nextBlock uses it to transfer
	// recent offsets to the next block.
	UseBlock(*blockEnc)
	// Reset prepares the encoder for a new frame using dictionary d
	// (which may be nil); EncodeAll passes singleBlock=true when the
	// whole input fits in one block.
	Reset(d *dict, singleBlock bool)
}

// encoderState is the per-stream state of an Encoder used through its
// io.Writer/io.ReaderFrom/io.Closer methods.
type encoderState struct {
	// w receives the compressed stream.
	w io.Writer
	// filling accumulates input until a full block is available.
	filling []byte
	// current is the buffer handed to the encode goroutine.
	current []byte
	// previous is the spare buffer; nextBlock rotates
	// filling -> current -> previous -> filling.
	previous []byte
	// encoder is the single encoder used for the stream.
	encoder encoder
	// writing is the block whose output is being (or was last) written to w.
	writing *blockEnc
	// err holds a stream error; once set, nextBlock returns it.
	err error
	// writeErr is set by the asynchronous writer goroutine.
	writeErr error
	// nWritten counts bytes written to w; Close uses it to size padding.
	nWritten int64
	// headerWritten is set once the frame header has been written.
	headerWritten bool
	// eofWritten is set once the last block of the frame has been emitted.
	eofWritten bool
	// fullFrameWritten is set when the first block was also the final one
	// and the whole frame was produced synchronously with EncodeAll
	// (which already appended CRC and padding).
	fullFrameWritten bool

	// This waitgroup indicates an encode is running.
	wg sync.WaitGroup
	// This waitgroup tracks the writer goroutine (final block encoding and
	// the write of its output to w).
	wWg sync.WaitGroup
}

// NewWriter will create a new Zstandard encoder.
// If the encoder will be used for encoding blocks a nil writer can be used.
// With a non-nil w the stream state is initialized through Reset.
// An error returned by any option is returned unchanged.
63func NewWriter(w io.Writer, opts ...EOption) (*Encoder, error) { 64 initPredefined() 65 var e Encoder 66 e.o.setDefault() 67 for _, o := range opts { 68 err := o(&e.o) 69 if err != nil { 70 return nil, err 71 } 72 } 73 if w != nil { 74 e.Reset(w) 75 } 76 return &e, nil 77} 78 79func (e *Encoder) initialize() { 80 if e.o.concurrent == 0 { 81 e.o.setDefault() 82 } 83 e.encoders = make(chan encoder, e.o.concurrent) 84 for i := 0; i < e.o.concurrent; i++ { 85 enc := e.o.encoder() 86 e.encoders <- enc 87 } 88} 89 90// Reset will re-initialize the writer and new writes will encode to the supplied writer 91// as a new, independent stream. 92func (e *Encoder) Reset(w io.Writer) { 93 s := &e.state 94 s.wg.Wait() 95 s.wWg.Wait() 96 if cap(s.filling) == 0 { 97 s.filling = make([]byte, 0, e.o.blockSize) 98 } 99 if cap(s.current) == 0 { 100 s.current = make([]byte, 0, e.o.blockSize) 101 } 102 if cap(s.previous) == 0 { 103 s.previous = make([]byte, 0, e.o.blockSize) 104 } 105 if s.encoder == nil { 106 s.encoder = e.o.encoder() 107 } 108 if s.writing == nil { 109 s.writing = &blockEnc{} 110 s.writing.init() 111 } 112 s.writing.initNewEncode() 113 s.filling = s.filling[:0] 114 s.current = s.current[:0] 115 s.previous = s.previous[:0] 116 s.encoder.Reset(e.o.dict, false) 117 s.headerWritten = false 118 s.eofWritten = false 119 s.fullFrameWritten = false 120 s.w = w 121 s.err = nil 122 s.nWritten = 0 123 s.writeErr = nil 124} 125 126// Write data to the encoder. 127// Input data will be buffered and as the buffer fills up 128// content will be compressed and written to the output. 129// When done writing, use Close to flush the remaining output 130// and write CRC if requested. 131func (e *Encoder) Write(p []byte) (n int, err error) { 132 s := &e.state 133 for len(p) > 0 { 134 if len(p)+len(s.filling) < e.o.blockSize { 135 if e.o.crc { 136 _, _ = s.encoder.CRC().Write(p) 137 } 138 s.filling = append(s.filling, p...) 
139 return n + len(p), nil 140 } 141 add := p 142 if len(p)+len(s.filling) > e.o.blockSize { 143 add = add[:e.o.blockSize-len(s.filling)] 144 } 145 if e.o.crc { 146 _, _ = s.encoder.CRC().Write(add) 147 } 148 s.filling = append(s.filling, add...) 149 p = p[len(add):] 150 n += len(add) 151 if len(s.filling) < e.o.blockSize { 152 return n, nil 153 } 154 err := e.nextBlock(false) 155 if err != nil { 156 return n, err 157 } 158 if debugAsserts && len(s.filling) > 0 { 159 panic(len(s.filling)) 160 } 161 } 162 return n, nil 163} 164 165// nextBlock will synchronize and start compressing input in e.state.filling. 166// If an error has occurred during encoding it will be returned. 167func (e *Encoder) nextBlock(final bool) error { 168 s := &e.state 169 // Wait for current block. 170 s.wg.Wait() 171 if s.err != nil { 172 return s.err 173 } 174 if len(s.filling) > e.o.blockSize { 175 return fmt.Errorf("block > maxStoreBlockSize") 176 } 177 if !s.headerWritten { 178 // If we have a single block encode, do a sync compression. 179 if final && len(s.filling) > 0 { 180 s.current = e.EncodeAll(s.filling, s.current[:0]) 181 var n2 int 182 n2, s.err = s.w.Write(s.current) 183 if s.err != nil { 184 return s.err 185 } 186 s.nWritten += int64(n2) 187 s.current = s.current[:0] 188 s.filling = s.filling[:0] 189 s.headerWritten = true 190 s.fullFrameWritten = true 191 s.eofWritten = true 192 return nil 193 } 194 195 var tmp [maxHeaderSize]byte 196 fh := frameHeader{ 197 ContentSize: 0, 198 WindowSize: uint32(s.encoder.WindowSize(0)), 199 SingleSegment: false, 200 Checksum: e.o.crc, 201 DictID: e.o.dict.ID(), 202 } 203 204 dst, err := fh.appendTo(tmp[:0]) 205 if err != nil { 206 return err 207 } 208 s.headerWritten = true 209 s.wWg.Wait() 210 var n2 int 211 n2, s.err = s.w.Write(dst) 212 if s.err != nil { 213 return s.err 214 } 215 s.nWritten += int64(n2) 216 } 217 if s.eofWritten { 218 // Ensure we only write it once. 
219 final = false 220 } 221 222 if len(s.filling) == 0 { 223 // Final block, but no data. 224 if final { 225 enc := s.encoder 226 blk := enc.Block() 227 blk.reset(nil) 228 blk.last = true 229 blk.encodeRaw(nil) 230 s.wWg.Wait() 231 _, s.err = s.w.Write(blk.output) 232 s.nWritten += int64(len(blk.output)) 233 s.eofWritten = true 234 } 235 return s.err 236 } 237 238 // Move blocks forward. 239 s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current 240 s.wg.Add(1) 241 go func(src []byte) { 242 if debug { 243 println("Adding block,", len(src), "bytes, final:", final) 244 } 245 defer func() { 246 if r := recover(); r != nil { 247 s.err = fmt.Errorf("panic while encoding: %v", r) 248 rdebug.PrintStack() 249 } 250 s.wg.Done() 251 }() 252 enc := s.encoder 253 blk := enc.Block() 254 enc.Encode(blk, src) 255 blk.last = final 256 if final { 257 s.eofWritten = true 258 } 259 // Wait for pending writes. 260 s.wWg.Wait() 261 if s.writeErr != nil { 262 s.err = s.writeErr 263 return 264 } 265 // Transfer encoders from previous write block. 266 blk.swapEncoders(s.writing) 267 // Transfer recent offsets to next. 268 enc.UseBlock(s.writing) 269 s.writing = blk 270 s.wWg.Add(1) 271 go func() { 272 defer func() { 273 if r := recover(); r != nil { 274 s.writeErr = fmt.Errorf("panic while encoding/writing: %v", r) 275 rdebug.PrintStack() 276 } 277 s.wWg.Done() 278 }() 279 err := errIncompressible 280 // If we got the exact same number of literals as input, 281 // assume the literals cannot be compressed. 282 if len(src) != len(blk.literals) || len(src) != e.o.blockSize { 283 err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy) 284 } 285 switch err { 286 case errIncompressible: 287 if debug { 288 println("Storing incompressible block as raw") 289 } 290 blk.encodeRaw(src) 291 // In fast mode, we do not transfer offsets, so we don't have to deal with changing the. 
292 case nil: 293 default: 294 s.writeErr = err 295 return 296 } 297 _, s.writeErr = s.w.Write(blk.output) 298 s.nWritten += int64(len(blk.output)) 299 }() 300 }(s.current) 301 return nil 302} 303 304// ReadFrom reads data from r until EOF or error. 305// The return value n is the number of bytes read. 306// Any error except io.EOF encountered during the read is also returned. 307// 308// The Copy function uses ReaderFrom if available. 309func (e *Encoder) ReadFrom(r io.Reader) (n int64, err error) { 310 if debug { 311 println("Using ReadFrom") 312 } 313 314 // Flush any current writes. 315 if len(e.state.filling) > 0 { 316 if err := e.nextBlock(false); err != nil { 317 return 0, err 318 } 319 } 320 e.state.filling = e.state.filling[:e.o.blockSize] 321 src := e.state.filling 322 for { 323 n2, err := r.Read(src) 324 if e.o.crc { 325 _, _ = e.state.encoder.CRC().Write(src[:n2]) 326 } 327 // src is now the unfilled part... 328 src = src[n2:] 329 n += int64(n2) 330 switch err { 331 case io.EOF: 332 e.state.filling = e.state.filling[:len(e.state.filling)-len(src)] 333 if debug { 334 println("ReadFrom: got EOF final block:", len(e.state.filling)) 335 } 336 return n, nil 337 default: 338 if debug { 339 println("ReadFrom: got error:", err) 340 } 341 e.state.err = err 342 return n, err 343 case nil: 344 } 345 if len(src) > 0 { 346 if debug { 347 println("ReadFrom: got space left in source:", len(src)) 348 } 349 continue 350 } 351 err = e.nextBlock(false) 352 if err != nil { 353 return n, err 354 } 355 e.state.filling = e.state.filling[:e.o.blockSize] 356 src = e.state.filling 357 } 358} 359 360// Flush will send the currently written data to output 361// and block until everything has been written. 362// This should only be used on rare occasions where pushing the currently queued data is critical. 
363func (e *Encoder) Flush() error { 364 s := &e.state 365 if len(s.filling) > 0 { 366 err := e.nextBlock(false) 367 if err != nil { 368 return err 369 } 370 } 371 s.wg.Wait() 372 s.wWg.Wait() 373 if s.err != nil { 374 return s.err 375 } 376 return s.writeErr 377} 378 379// Close will flush the final output and close the stream. 380// The function will block until everything has been written. 381// The Encoder can still be re-used after calling this. 382func (e *Encoder) Close() error { 383 s := &e.state 384 if s.encoder == nil { 385 return nil 386 } 387 err := e.nextBlock(true) 388 if err != nil { 389 return err 390 } 391 if e.state.fullFrameWritten { 392 return s.err 393 } 394 s.wg.Wait() 395 s.wWg.Wait() 396 397 if s.err != nil { 398 return s.err 399 } 400 if s.writeErr != nil { 401 return s.writeErr 402 } 403 404 // Write CRC 405 if e.o.crc && s.err == nil { 406 // heap alloc. 407 var tmp [4]byte 408 _, s.err = s.w.Write(s.encoder.AppendCRC(tmp[:0])) 409 s.nWritten += 4 410 } 411 412 // Add padding with content from crypto/rand.Reader 413 if s.err == nil && e.o.pad > 0 { 414 add := calcSkippableFrame(s.nWritten, int64(e.o.pad)) 415 frame, err := skippableFrame(s.filling[:0], add, rand.Reader) 416 if err != nil { 417 return err 418 } 419 _, s.err = s.w.Write(frame) 420 } 421 return s.err 422} 423 424// EncodeAll will encode all input in src and append it to dst. 425// This function can be called concurrently, but each call will only run on a single goroutine. 426// If empty input is given, nothing is returned, unless WithZeroFrames is specified. 427// Encoded blocks can be concatenated and the result will be the combined input stream. 428// Data compressed with EncodeAll can be decoded with the Decoder, 429// using either a stream or DecodeAll. 430func (e *Encoder) EncodeAll(src, dst []byte) []byte { 431 if len(src) == 0 { 432 if e.o.fullZero { 433 // Add frame header. 
434 fh := frameHeader{ 435 ContentSize: 0, 436 WindowSize: MinWindowSize, 437 SingleSegment: true, 438 // Adding a checksum would be a waste of space. 439 Checksum: false, 440 DictID: 0, 441 } 442 dst, _ = fh.appendTo(dst) 443 444 // Write raw block as last one only. 445 var blk blockHeader 446 blk.setSize(0) 447 blk.setType(blockTypeRaw) 448 blk.setLast(true) 449 dst = blk.appendTo(dst) 450 } 451 return dst 452 } 453 e.init.Do(e.initialize) 454 enc := <-e.encoders 455 defer func() { 456 // Release encoder reference to last block. 457 // If a non-single block is needed the encoder will reset again. 458 e.encoders <- enc 459 }() 460 // Use single segments when above minimum window and below 1MB. 461 single := len(src) < 1<<20 && len(src) > MinWindowSize 462 if e.o.single != nil { 463 single = *e.o.single 464 } 465 fh := frameHeader{ 466 ContentSize: uint64(len(src)), 467 WindowSize: uint32(enc.WindowSize(len(src))), 468 SingleSegment: single, 469 Checksum: e.o.crc, 470 DictID: e.o.dict.ID(), 471 } 472 473 // If less than 1MB, allocate a buffer up front. 474 if len(dst) == 0 && cap(dst) == 0 && len(src) < 1<<20 { 475 dst = make([]byte, 0, len(src)) 476 } 477 dst, err := fh.appendTo(dst) 478 if err != nil { 479 panic(err) 480 } 481 482 // If we can do everything in one block, prefer that. 483 if len(src) <= maxCompressedBlockSize { 484 enc.Reset(e.o.dict, true) 485 // Slightly faster with no history and everything in one block. 486 if e.o.crc { 487 _, _ = enc.CRC().Write(src) 488 } 489 blk := enc.Block() 490 blk.last = true 491 if e.o.dict == nil { 492 enc.EncodeNoHist(blk, src) 493 } else { 494 enc.Encode(blk, src) 495 } 496 497 // If we got the exact same number of literals as input, 498 // assume the literals cannot be compressed. 
499 err := errIncompressible 500 oldout := blk.output 501 if len(blk.literals) != len(src) || len(src) != e.o.blockSize { 502 // Output directly to dst 503 blk.output = dst 504 err = blk.encode(src, e.o.noEntropy, !e.o.allLitEntropy) 505 } 506 507 switch err { 508 case errIncompressible: 509 if debug { 510 println("Storing incompressible block as raw") 511 } 512 dst = blk.encodeRawTo(dst, src) 513 case nil: 514 dst = blk.output 515 default: 516 panic(err) 517 } 518 blk.output = oldout 519 } else { 520 enc.Reset(e.o.dict, false) 521 blk := enc.Block() 522 for len(src) > 0 { 523 todo := src 524 if len(todo) > e.o.blockSize { 525 todo = todo[:e.o.blockSize] 526 } 527 src = src[len(todo):] 528 if e.o.crc { 529 _, _ = enc.CRC().Write(todo) 530 } 531 blk.pushOffsets() 532 enc.Encode(blk, todo) 533 if len(src) == 0 { 534 blk.last = true 535 } 536 err := errIncompressible 537 // If we got the exact same number of literals as input, 538 // assume the literals cannot be compressed. 539 if len(blk.literals) != len(todo) || len(todo) != e.o.blockSize { 540 err = blk.encode(todo, e.o.noEntropy, !e.o.allLitEntropy) 541 } 542 543 switch err { 544 case errIncompressible: 545 if debug { 546 println("Storing incompressible block as raw") 547 } 548 dst = blk.encodeRawTo(dst, todo) 549 blk.popOffsets() 550 case nil: 551 dst = append(dst, blk.output...) 552 default: 553 panic(err) 554 } 555 blk.reset(nil) 556 } 557 } 558 if e.o.crc { 559 dst = enc.AppendCRC(dst) 560 } 561 // Add padding with content from crypto/rand.Reader 562 if e.o.pad > 0 { 563 add := calcSkippableFrame(int64(len(dst)), int64(e.o.pad)) 564 dst, err = skippableFrame(dst, add, rand.Reader) 565 if err != nil { 566 panic(err) 567 } 568 } 569 return dst 570} 571