// Copyright 2019+ Klaus Post. All rights reserved.
// License information can be found in the LICENSE file.
// Based on work by Yann Collet, released under BSD License.

package zstd

import (
	"bytes"
	"fmt"

	"github.com/klauspost/compress"
)

const (
	bestLongTableBits = 22                     // Bits used in the long match table
	bestLongTableSize = 1 << bestLongTableBits // Size of the table
	bestLongLen       = 8                      // Bytes used for table hash

	// Note: Increasing the short table bits or making the hash shorter
	// can actually lead to compression degradation, since it will 'steal' more from the
	// long match table and match offsets are quite big.
	// This greatly depends on the type of input.
	bestShortTableBits = 18                      // Bits used in the short match table
	bestShortTableSize = 1 << bestShortTableBits // Size of the table
	bestShortLen       = 4                       // Bytes used for table hash
)

type match struct {
	offset int32
	s      int32
	length int32
	rep    int32
	est    int32
}

const highScore = 25000

// estBits will estimate output bits from predefined tables.
func (m *match) estBits(bitsPerByte int32) {
	mlc := mlCode(uint32(m.length - zstdMinMatch))
	var ofc uint8
	if m.rep < 0 {
		ofc = ofCode(uint32(m.s-m.offset) + 3)
	} else {
		ofc = ofCode(uint32(m.rep))
	}
	// Cost of the offset and match length codes, excluding literals.
	ofTT, mlTT := fsePredefEnc[tableOffsets].ct.symbolTT[ofc], fsePredefEnc[tableMatchLengths].ct.symbolTT[mlc]

	// Add cost of match encoding...
	m.est = int32(ofTT.outBits + mlTT.outBits)
	m.est += int32(ofTT.deltaNbBits>>16 + mlTT.deltaNbBits>>16)
	// Subtract savings compared to literal encoding...
	m.est -= (m.length * bitsPerByte) >> 10
	if m.est > 0 {
		// Unlikely gain..
		m.length = 0
		m.est = highScore
	}
}

// bestFastEncoder uses 2 tables, one for short matches (4 bytes) and one for long matches.
// The long match table contains the previous entry with the same hash,
// effectively making it a "chain" of length 2.
// When we find a long match we choose between the two values and select the longest.
// When we find a short match, after checking the long, we check if we can find a long at n+1
// and that it is longer (lazy matching).
type bestFastEncoder struct {
	fastBase
	table         [bestShortTableSize]prevEntry
	longTable     [bestLongTableSize]prevEntry
	dictTable     []prevEntry
	dictLongTable []prevEntry
}

// Encode will encode the block, trading encoding speed for improved compression.
func (e *bestFastEncoder) Encode(blk *blockEnc, src []byte) {
	const (
		// Input margin is the number of bytes we read (8)
		// and the maximum we will read ahead (4).
		inputMargin            = 8 + 4
		minNonLiteralBlockSize = 16
	)

	// Protect against e.cur wraparound.
	for e.cur >= bufferReset {
		if len(e.hist) == 0 {
			for i := range e.table[:] {
				e.table[i] = prevEntry{}
			}
			for i := range e.longTable[:] {
				e.longTable[i] = prevEntry{}
			}
			e.cur = e.maxMatchOff
			break
		}
		// Shift down everything in the table that isn't already too far away.
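		// Table entries store absolute positions (the index plus e.cur), so when
		// e.cur is rebased to e.maxMatchOff below, every stored offset is shifted
		// by (-e.cur + e.maxMatchOff). Entries more than e.maxMatchOff behind the
		// end of the history can never match again and are zeroed instead.
		// Illustrative example (values made up): an entry at absolute position
		// e.cur+100 becomes e.maxMatchOff+100, the same spot relative to the
		// start of the retained history.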
		minOff := e.cur + int32(len(e.hist)) - e.maxMatchOff
		for i := range e.table[:] {
			v := e.table[i].offset
			v2 := e.table[i].prev
			if v < minOff {
				v = 0
				v2 = 0
			} else {
				v = v - e.cur + e.maxMatchOff
				if v2 < minOff {
					v2 = 0
				} else {
					v2 = v2 - e.cur + e.maxMatchOff
				}
			}
			e.table[i] = prevEntry{
				offset: v,
				prev:   v2,
			}
		}
		for i := range e.longTable[:] {
			v := e.longTable[i].offset
			v2 := e.longTable[i].prev
			if v < minOff {
				v = 0
				v2 = 0
			} else {
				v = v - e.cur + e.maxMatchOff
				if v2 < minOff {
					v2 = 0
				} else {
					v2 = v2 - e.cur + e.maxMatchOff
				}
			}
			e.longTable[i] = prevEntry{
				offset: v,
				prev:   v2,
			}
		}
		e.cur = e.maxMatchOff
		break
	}

	s := e.addBlock(src)
	blk.size = len(src)
	if len(src) < minNonLiteralBlockSize {
		blk.extraLits = len(src)
		blk.literals = blk.literals[:len(src)]
		copy(blk.literals, src)
		return
	}

	// Use this to estimate literal cost.
	// Scaled by 10 bits (multiplied by 1024).
	bitsPerByte := int32((compress.ShannonEntropyBits(src) * 1024) / len(src))
	// Huffman can never go below 1 bit/byte.
	if bitsPerByte < 1024 {
		bitsPerByte = 1024
	}

	// Override src with the full history buffer.
	src = e.hist
	sLimit := int32(len(src)) - inputMargin
	const kSearchStrength = 10

	// nextEmit is where in src the next emitLiteral should start from.
	nextEmit := s
	cv := load6432(src, s)

	// Relative offsets
	offset1 := int32(blk.recentOffsets[0])
	offset2 := int32(blk.recentOffsets[1])
	offset3 := int32(blk.recentOffsets[2])

	// addLiterals emits pending literals up to 'until' and sets the
	// literal length of the sequence accordingly.
	addLiterals := func(s *seq, until int32) {
		if until == nextEmit {
			return
		}
		blk.literals = append(blk.literals, src[nextEmit:until]...)
		s.litLen = uint32(until - nextEmit)
	}
	_ = addLiterals

	if debugEncoder {
		println("recent offsets:", blk.recentOffsets)
	}
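
	// The search below is cost-driven: estBits scores a candidate by the
	// estimated bits needed for its offset and match length codes (from the
	// predefined FSE tables) minus the bits saved by not emitting the matched
	// bytes as literals. bestOf compares candidates that may start at
	// different positions by charging each for the extra literals its later
	// start would leave uncovered, using the scaled bitsPerByte estimate
	// (the >>10 undoes the 1024x scaling).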
encodeLoop:
	for {
		// We allow the encoder to optionally turn off repeat offsets across blocks.
		canRepeat := len(blk.sequences) > 2

		if debugAsserts && canRepeat && offset1 == 0 {
			panic("offset0 was 0")
		}

		// bestOf returns the cheaper of two candidate matches, charging each
		// for the literals between the two start positions.
		bestOf := func(a, b match) match {
			if a.est+(a.s-b.s)*bitsPerByte>>10 < b.est+(b.s-a.s)*bitsPerByte>>10 {
				return a
			}
			return b
		}
		const goodEnough = 100

		nextHashL := hashLen(cv, bestLongTableBits, bestLongLen)
		nextHashS := hashLen(cv, bestShortTableBits, bestShortLen)
		candidateL := e.longTable[nextHashL]
		candidateS := e.table[nextHashS]

		// matchAt returns a match starting at s with the given offset,
		// or a highScore (rejected) match if the first 4 bytes do not agree
		// or the offset is out of range.
		matchAt := func(offset int32, s int32, first uint32, rep int32) match {
			if s-offset >= e.maxMatchOff || load3232(src, offset) != first {
				return match{s: s, est: highScore}
			}
			if debugAsserts {
				if !bytes.Equal(src[s:s+4], src[offset:offset+4]) {
					panic(fmt.Sprintf("first match mismatch: %v != %v, first: %08x", src[s:s+4], src[offset:offset+4], first))
				}
			}
			m := match{offset: offset, s: s, length: 4 + e.matchlen(s+4, offset+4, src), rep: rep}
			m.estBits(bitsPerByte)
			return m
		}

		best := bestOf(matchAt(candidateL.offset-e.cur, s, uint32(cv), -1), matchAt(candidateL.prev-e.cur, s, uint32(cv), -1))
		best = bestOf(best, matchAt(candidateS.offset-e.cur, s, uint32(cv), -1))
		best = bestOf(best, matchAt(candidateS.prev-e.cur, s, uint32(cv), -1))
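
		// Zstandard keeps the three most recent offsets; a match that reuses
		// one of them is encoded as a small repeat code instead of a full
		// offset, which is considerably cheaper. The three recent offsets are
		// therefore also tried at s+1 and s+3 below, using the
		// correspondingly shifted 32-bit slices of cv.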
		if canRepeat && best.length < goodEnough {
			cv32 := uint32(cv >> 8)
			spp := s + 1
			best = bestOf(best, matchAt(spp-offset1, spp, cv32, 1))
			best = bestOf(best, matchAt(spp-offset2, spp, cv32, 2))
			best = bestOf(best, matchAt(spp-offset3, spp, cv32, 3))
			if best.length > 0 {
				cv32 = uint32(cv >> 24)
				spp += 2
				best = bestOf(best, matchAt(spp-offset1, spp, cv32, 1))
				best = bestOf(best, matchAt(spp-offset2, spp, cv32, 2))
				best = bestOf(best, matchAt(spp-offset3, spp, cv32, 3))
			}
		}
		// Load next and check...
		e.longTable[nextHashL] = prevEntry{offset: s + e.cur, prev: candidateL.offset}
		e.table[nextHashS] = prevEntry{offset: s + e.cur, prev: candidateS.offset}

		// Look far ahead, unless we have a really long match already...
		if best.length < goodEnough {
			// No match found, move forward on input, no need to check forward...
			if best.length < 4 {
				s += 1 + (s-nextEmit)>>(kSearchStrength-1)
				if s >= sLimit {
					break encodeLoop
				}
				cv = load6432(src, s)
				continue
			}

			s++
			candidateS = e.table[hashLen(cv>>8, bestShortTableBits, bestShortLen)]
			cv = load6432(src, s)
			cv2 := load6432(src, s+1)
			candidateL = e.longTable[hashLen(cv, bestLongTableBits, bestLongLen)]
			candidateL2 := e.longTable[hashLen(cv2, bestLongTableBits, bestLongLen)]

			// Short at s+1
			best = bestOf(best, matchAt(candidateS.offset-e.cur, s, uint32(cv), -1))
			// Long at s+1, s+2
			best = bestOf(best, matchAt(candidateL.offset-e.cur, s, uint32(cv), -1))
			best = bestOf(best, matchAt(candidateL.prev-e.cur, s, uint32(cv), -1))
			best = bestOf(best, matchAt(candidateL2.offset-e.cur, s+1, uint32(cv2), -1))
			best = bestOf(best, matchAt(candidateL2.prev-e.cur, s+1, uint32(cv2), -1))
			if false {
				// Short at s+3.
				// Too often worse...
				best = bestOf(best, matchAt(e.table[hashLen(cv2>>8, bestShortTableBits, bestShortLen)].offset-e.cur, s+2, uint32(cv2>>8), -1))
			}
			// See if we can find a better match by checking where the current best ends.
			// Use that offset to see if we can find a better full match.
			if sAt := best.s + best.length; sAt < sLimit {
				nextHashL := hashLen(load6432(src, sAt), bestLongTableBits, bestLongLen)
				candidateEnd := e.longTable[nextHashL]
				if pos := candidateEnd.offset - e.cur - best.length; pos >= 0 {
					bestEnd := bestOf(best, matchAt(pos, best.s, load3232(src, best.s), -1))
					if pos := candidateEnd.prev - e.cur - best.length; pos >= 0 {
						bestEnd = bestOf(bestEnd, matchAt(pos, best.s, load3232(src, best.s), -1))
					}
					best = bestEnd
				}
			}
		}

		if debugAsserts {
			if !bytes.Equal(src[best.s:best.s+best.length], src[best.offset:best.offset+best.length]) {
				panic(fmt.Sprintf("match mismatch: %v != %v", src[best.s:best.s+best.length], src[best.offset:best.offset+best.length]))
			}
		}

		// We have a match, we can store the forward value.
		if best.rep > 0 {
			s = best.s
			var seq seq
			seq.matchLen = uint32(best.length - zstdMinMatch)

			// We might be able to match backwards.
			// Extend as long as we can.
			start := best.s
			// We end the search early, so we don't risk 0 literals,
			// which would require special offset treatment.
			startLimit := nextEmit + 1

			tMin := s - e.maxMatchOff
			if tMin < 0 {
				tMin = 0
			}
			repIndex := best.offset
			for repIndex > tMin && start > startLimit && src[repIndex-1] == src[start-1] && seq.matchLen < maxMatchLength-zstdMinMatch-1 {
				repIndex--
				start--
				seq.matchLen++
			}
			addLiterals(&seq, start)

			// Repeat offset: rep is 1-3, matching the recent offset index.
			seq.offset = uint32(best.rep)
			if debugSequences {
				println("repeat sequence", seq, "next s:", s)
			}
			blk.sequences = append(blk.sequences, seq)

			// Index match start+1 (long) -> s - 1
			index0 := s
			s = best.s + best.length

			nextEmit = s
			if s >= sLimit {
				if debugEncoder {
					println("repeat ended", s, best.length)
				}
				break encodeLoop
			}
			// Index the skipped positions in both tables...
			off := index0 + e.cur
			for index0 < s-1 {
				cv0 := load6432(src, index0)
				h0 := hashLen(cv0, bestLongTableBits, bestLongLen)
				h1 := hashLen(cv0, bestShortTableBits, bestShortLen)
				e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
				e.table[h1] = prevEntry{offset: off, prev: e.table[h1].offset}
				off++
				index0++
			}
			// Move the used repeat offset to the front of the history.
			switch best.rep {
			case 2:
				offset1, offset2 = offset2, offset1
			case 3:
				offset1, offset2, offset3 = offset3, offset1, offset2
			}
			cv = load6432(src, s)
			continue
		}
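
		// No repeat was used: the new match's distance goes to the front of
		// the recent-offset history and the older entries shift down. When
		// the sequence is emitted below, the raw offset is stored as (s-t)+3,
		// since encoded offset values 1-3 denote the three repeat codes.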
		// A 4-byte match has been found. Update recent offsets.
		// We'll later see if more than 4 bytes.
		s = best.s
		t := best.offset
		offset1, offset2, offset3 = s-t, offset1, offset2

		if debugAsserts && s <= t {
			panic(fmt.Sprintf("s (%d) <= t (%d)", s, t))
		}

		if debugAsserts && int(offset1) > len(src) {
			panic("invalid offset")
		}

		// Extend the match as long as possible.
		l := best.length

		// Extend backwards
		tMin := s - e.maxMatchOff
		if tMin < 0 {
			tMin = 0
		}
		for t > tMin && s > nextEmit && src[t-1] == src[s-1] && l < maxMatchLength {
			s--
			t--
			l++
		}

		// Write our sequence
		var seq seq
		seq.litLen = uint32(s - nextEmit)
		seq.matchLen = uint32(l - zstdMinMatch)
		if seq.litLen > 0 {
			blk.literals = append(blk.literals, src[nextEmit:s]...)
		}
		seq.offset = uint32(s-t) + 3
		s += l
		if debugSequences {
			println("sequence", seq, "next s:", s)
		}
		blk.sequences = append(blk.sequences, seq)
		nextEmit = s
		if s >= sLimit {
			break encodeLoop
		}

		// Index match start+1 (long) -> s - 1
		index0 := s - l + 1
		// Index every entry of the match in both tables.
		for index0 < s-1 {
			cv0 := load6432(src, index0)
			h0 := hashLen(cv0, bestLongTableBits, bestLongLen)
			h1 := hashLen(cv0, bestShortTableBits, bestShortLen)
			off := index0 + e.cur
			e.longTable[h0] = prevEntry{offset: off, prev: e.longTable[h0].offset}
			e.table[h1] = prevEntry{offset: off, prev: e.table[h1].offset}
			index0++
		}

		cv = load6432(src, s)
		if !canRepeat {
			continue
		}

		// Check offset 2
		for {
			o2 := s - offset2
			if load3232(src, o2) != uint32(cv) {
				// Do regular search
				break
			}

			// Store this, since we have it.
			nextHashS := hashLen(cv, bestShortTableBits, bestShortLen)
			nextHashL := hashLen(cv, bestLongTableBits, bestLongLen)

			// We have at least a 4-byte match.
			// No need to check backwards. We come straight from a match.
			l := 4 + e.matchlen(s+4, o2+4, src)

			e.longTable[nextHashL] = prevEntry{offset: s + e.cur, prev: e.longTable[nextHashL].offset}
			e.table[nextHashS] = prevEntry{offset: s + e.cur, prev: e.table[nextHashS].offset}
			seq.matchLen = uint32(l) - zstdMinMatch
			seq.litLen = 0

			// Since litlen is always 0, this is offset 1.
			seq.offset = 1
			s += l
			nextEmit = s
			if debugSequences {
				println("sequence", seq, "next s:", s)
			}
			blk.sequences = append(blk.sequences, seq)

			// Swap offset 1 and 2.
			offset1, offset2 = offset2, offset1
			if s >= sLimit {
				// Finished
				break encodeLoop
			}
			cv = load6432(src, s)
		}
	}

	if int(nextEmit) < len(src) {
		blk.literals = append(blk.literals, src[nextEmit:]...)
		blk.extraLits = len(src) - int(nextEmit)
	}
	blk.recentOffsets[0] = uint32(offset1)
	blk.recentOffsets[1] = uint32(offset2)
	blk.recentOffsets[2] = uint32(offset3)
	if debugEncoder {
		println("returning, recent offsets:", blk.recentOffsets, "extra literals:", blk.extraLits)
	}
}
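
// A minimal usage sketch (illustrative only; bestFastEncoder is internal and
// is selected via the encoder level, not used directly; out is any io.Writer,
// data a []byte):
//
//	w, err := zstd.NewWriter(out, zstd.WithEncoderLevel(zstd.SpeedBestCompression))
//	if err != nil {
//		// handle error
//	}
//	defer w.Close()
//	_, err = w.Write(data)
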
func (e *bestFastEncoder) EncodeNoHist(blk *blockEnc, src []byte) {
	e.ensureHist(len(src))
	e.Encode(blk, src)
}

// Reset will reset and set a dictionary if not nil.
func (e *bestFastEncoder) Reset(d *dict, singleBlock bool) {
	e.resetBase(d, singleBlock)
	if d == nil {
		return
	}
	// Check against the previous dictionary ID once, so the long table check
	// below is not defeated by the short table check updating it first.
	dictChanged := d.id != e.lastDictID

	// Init or copy dict table
	if len(e.dictTable) != len(e.table) || dictChanged {
		if len(e.dictTable) != len(e.table) {
			e.dictTable = make([]prevEntry, len(e.table))
		}
		end := int32(len(d.content)) - 8 + e.maxMatchOff
		for i := e.maxMatchOff; i < end; i += 4 {
			const hashLog = bestShortTableBits

			cv := load6432(d.content, i-e.maxMatchOff)
			nextHash := hashLen(cv, hashLog, bestShortLen)      // 0 -> 4
			nextHash1 := hashLen(cv>>8, hashLog, bestShortLen)  // 1 -> 5
			nextHash2 := hashLen(cv>>16, hashLog, bestShortLen) // 2 -> 6
			nextHash3 := hashLen(cv>>24, hashLog, bestShortLen) // 3 -> 7
			e.dictTable[nextHash] = prevEntry{
				prev:   e.dictTable[nextHash].offset,
				offset: i,
			}
			e.dictTable[nextHash1] = prevEntry{
				prev:   e.dictTable[nextHash1].offset,
				offset: i + 1,
			}
			e.dictTable[nextHash2] = prevEntry{
				prev:   e.dictTable[nextHash2].offset,
				offset: i + 2,
			}
			e.dictTable[nextHash3] = prevEntry{
				prev:   e.dictTable[nextHash3].offset,
				offset: i + 3,
			}
		}
	}

	// Init or copy dict long table
	if len(e.dictLongTable) != len(e.longTable) || dictChanged {
		if len(e.dictLongTable) != len(e.longTable) {
			e.dictLongTable = make([]prevEntry, len(e.longTable))
		}
		if len(d.content) >= 8 {
			cv := load6432(d.content, 0)
			h := hashLen(cv, bestLongTableBits, bestLongLen)
			e.dictLongTable[h] = prevEntry{
				offset: e.maxMatchOff,
				prev:   e.dictLongTable[h].offset,
			}

			end := int32(len(d.content)) - 8 + e.maxMatchOff
			off := 8 // First byte to read.
			for i := e.maxMatchOff + 1; i < end; i++ {
				cv = cv>>8 | (uint64(d.content[off]) << 56)
				h := hashLen(cv, bestLongTableBits, bestLongLen)
				e.dictLongTable[h] = prevEntry{
					offset: i,
					prev:   e.dictLongTable[h].offset,
				}
				off++
			}
		}
	}
	e.lastDictID = d.id

	// Reset table to initial state
	copy(e.longTable[:], e.dictLongTable)

	e.cur = e.maxMatchOff
	// Reset table to initial state
	copy(e.table[:], e.dictTable)
}