1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 #include "lib/jxl/enc_ans.h"
7
8 #include <stdint.h>
9
10 #include <algorithm>
11 #include <array>
12 #include <cmath>
13 #include <limits>
14 #include <numeric>
15 #include <type_traits>
16 #include <unordered_map>
17 #include <utility>
18 #include <vector>
19
20 #include "lib/jxl/ans_common.h"
21 #include "lib/jxl/aux_out.h"
22 #include "lib/jxl/aux_out_fwd.h"
23 #include "lib/jxl/base/bits.h"
24 #include "lib/jxl/dec_ans.h"
25 #include "lib/jxl/enc_cluster.h"
26 #include "lib/jxl/enc_context_map.h"
27 #include "lib/jxl/enc_huffman.h"
28 #include "lib/jxl/fast_math-inl.h"
29 #include "lib/jxl/fields.h"
30
31 namespace jxl {
32
33 namespace {
34
// When set, entropy-code construction skips histogram clustering and uint
// config selection: all contexts map to a single histogram and a fixed
// HybridUintConfig is used (see HistogramBuilder::BuildAndStoreEntropyCodes).
bool ans_fuzzer_friendly_ = false;

// Capacity of the `symbols` scratch arrays filled by NormalizeCounts; only the
// first this-many used symbols are recorded.
constexpr int kMaxNumSymbolsForSmallCode = 4;
38
ANSBuildInfoTable(const ANSHistBin * counts,const AliasTable::Entry * table,size_t alphabet_size,size_t log_alpha_size,ANSEncSymbolInfo * info)39 void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table,
40 size_t alphabet_size, size_t log_alpha_size,
41 ANSEncSymbolInfo* info) {
42 size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size;
43 size_t entry_size_minus_1 = (1 << log_entry_size) - 1;
44 // create valid alias table for empty streams.
45 for (size_t s = 0; s < std::max<size_t>(1, alphabet_size); ++s) {
46 const ANSHistBin freq = s == alphabet_size ? ANS_TAB_SIZE : counts[s];
47 info[s].freq_ = static_cast<uint16_t>(freq);
48 #ifdef USE_MULT_BY_RECIPROCAL
49 if (freq != 0) {
50 info[s].ifreq_ =
51 ((1ull << RECIPROCAL_PRECISION) + info[s].freq_ - 1) / info[s].freq_;
52 } else {
53 info[s].ifreq_ = 1; // shouldn't matter (symbol shouldn't occur), but...
54 }
55 #endif
56 info[s].reverse_map_.resize(freq);
57 }
58 for (int i = 0; i < ANS_TAB_SIZE; i++) {
59 AliasTable::Symbol s =
60 AliasTable::Lookup(table, i, log_entry_size, entry_size_minus_1);
61 info[s.value].reverse_map_[s.offset] = i;
62 }
63 }
64
EstimateDataBits(const ANSHistBin * histogram,const ANSHistBin * counts,size_t len)65 float EstimateDataBits(const ANSHistBin* histogram, const ANSHistBin* counts,
66 size_t len) {
67 float sum = 0.0f;
68 int total_histogram = 0;
69 int total_counts = 0;
70 for (size_t i = 0; i < len; ++i) {
71 total_histogram += histogram[i];
72 total_counts += counts[i];
73 if (histogram[i] > 0) {
74 JXL_ASSERT(counts[i] > 0);
75 // += histogram[i] * -log(counts[i]/total_counts)
76 sum += histogram[i] *
77 std::max(0.0f, ANS_LOG_TAB_SIZE - FastLog2f(counts[i]));
78 }
79 }
80 if (total_histogram > 0) {
81 JXL_ASSERT(total_counts == ANS_TAB_SIZE);
82 }
83 return sum;
84 }
85
EstimateDataBitsFlat(const ANSHistBin * histogram,size_t len)86 float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) {
87 const float flat_bits = std::max(FastLog2f(len), 0.0f);
88 int total_histogram = 0;
89 for (size_t i = 0; i < len; ++i) {
90 total_histogram += histogram[i];
91 }
92 return total_histogram * flat_bits;
93 }
94
95 // Static Huffman code for encoding logcounts. The last symbol is used as RLE
96 // sequence.
97 static const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = {
98 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
99 };
100 static const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = {
101 17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
102 };
103
104 // Returns the difference between largest count that can be represented and is
105 // smaller than "count" and smallest representable count larger than "count".
SmallestIncrement(uint32_t count,uint32_t shift)106 static int SmallestIncrement(uint32_t count, uint32_t shift) {
107 int bits = count == 0 ? -1 : FloorLog2Nonzero(count);
108 int drop_bits = bits - GetPopulationCountPrecision(bits, shift);
109 return drop_bits < 0 ? 1 : (1 << drop_bits);
110 }
111
112 template <bool minimize_error_of_sum>
RebalanceHistogram(const float * targets,int max_symbol,int table_size,uint32_t shift,int * omit_pos,ANSHistBin * counts)113 bool RebalanceHistogram(const float* targets, int max_symbol, int table_size,
114 uint32_t shift, int* omit_pos, ANSHistBin* counts) {
115 int sum = 0;
116 float sum_nonrounded = 0.0;
117 int remainder_pos = 0; // if all of them are handled in first loop
118 int remainder_log = -1;
119 for (int n = 0; n < max_symbol; ++n) {
120 if (targets[n] > 0 && targets[n] < 1.0f) {
121 counts[n] = 1;
122 sum_nonrounded += targets[n];
123 sum += counts[n];
124 }
125 }
126 const float discount_ratio =
127 (table_size - sum) / (table_size - sum_nonrounded);
128 JXL_ASSERT(discount_ratio > 0);
129 JXL_ASSERT(discount_ratio <= 1.0f);
130 // Invariant for minimize_error_of_sum == true:
131 // abs(sum - sum_nonrounded)
132 // <= SmallestIncrement(max(targets[])) + max_symbol
133 for (int n = 0; n < max_symbol; ++n) {
134 if (targets[n] >= 1.0f) {
135 sum_nonrounded += targets[n];
136 counts[n] =
137 static_cast<ANSHistBin>(targets[n] * discount_ratio); // truncate
138 if (counts[n] == 0) counts[n] = 1;
139 if (counts[n] == table_size) counts[n] = table_size - 1;
140 // Round the count to the closest nonzero multiple of SmallestIncrement
141 // (when minimize_error_of_sum is false) or one of two closest so as to
142 // keep the sum as close as possible to sum_nonrounded.
143 int inc = SmallestIncrement(counts[n], shift);
144 counts[n] -= counts[n] & (inc - 1);
145 // TODO(robryk): Should we rescale targets[n]?
146 const float target =
147 minimize_error_of_sum ? (sum_nonrounded - sum) : targets[n];
148 if (counts[n] == 0 ||
149 (target > counts[n] + inc / 2 && counts[n] + inc < table_size)) {
150 counts[n] += inc;
151 }
152 sum += counts[n];
153 const int count_log = FloorLog2Nonzero(static_cast<uint32_t>(counts[n]));
154 if (count_log > remainder_log) {
155 remainder_pos = n;
156 remainder_log = count_log;
157 }
158 }
159 }
160 JXL_ASSERT(remainder_pos != -1);
161 // NOTE: This is the only place where counts could go negative. We could
162 // detect that, return false and make ANSHistBin uint32_t.
163 counts[remainder_pos] -= sum - table_size;
164 *omit_pos = remainder_pos;
165 return counts[remainder_pos] > 0;
166 }
167
NormalizeCounts(ANSHistBin * counts,int * omit_pos,const int length,const int precision_bits,uint32_t shift,int * num_symbols,int * symbols)168 Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length,
169 const int precision_bits, uint32_t shift,
170 int* num_symbols, int* symbols) {
171 const int32_t table_size = 1 << precision_bits; // target sum / table size
172 uint64_t total = 0;
173 int max_symbol = 0;
174 int symbol_count = 0;
175 for (int n = 0; n < length; ++n) {
176 total += counts[n];
177 if (counts[n] > 0) {
178 if (symbol_count < kMaxNumSymbolsForSmallCode) {
179 symbols[symbol_count] = n;
180 }
181 ++symbol_count;
182 max_symbol = n + 1;
183 }
184 }
185 *num_symbols = symbol_count;
186 if (symbol_count == 0) {
187 return true;
188 }
189 if (symbol_count == 1) {
190 counts[symbols[0]] = table_size;
191 return true;
192 }
193 if (symbol_count > table_size)
194 return JXL_FAILURE("Too many entries in an ANS histogram");
195
196 const float norm = 1.f * table_size / total;
197 std::vector<float> targets(max_symbol);
198 for (size_t n = 0; n < targets.size(); ++n) {
199 targets[n] = norm * counts[n];
200 }
201 if (!RebalanceHistogram<false>(&targets[0], max_symbol, table_size, shift,
202 omit_pos, counts)) {
203 // Use an alternative rebalancing mechanism if the one above failed
204 // to create a histogram that is positive wherever the original one was.
205 if (!RebalanceHistogram<true>(&targets[0], max_symbol, table_size, shift,
206 omit_pos, counts)) {
207 return JXL_FAILURE("Logic error: couldn't rebalance a histogram");
208 }
209 }
210 return true;
211 }
212
// Writer stand-in with the same Write(num_bits, value) shape as the real
// writer; it discards the value and only accumulates the number of bits.
struct SizeWriter {
  size_t size = 0;  // Total number of bits "written" so far.
  void Write(size_t num, size_t /*bits*/) { size += num; }
};
217
218 template <typename Writer>
StoreVarLenUint8(size_t n,Writer * writer)219 void StoreVarLenUint8(size_t n, Writer* writer) {
220 JXL_DASSERT(n <= 255);
221 if (n == 0) {
222 writer->Write(1, 0);
223 } else {
224 writer->Write(1, 1);
225 size_t nbits = FloorLog2Nonzero(n);
226 writer->Write(3, nbits);
227 writer->Write(nbits, n - (1ULL << nbits));
228 }
229 }
230
231 template <typename Writer>
StoreVarLenUint16(size_t n,Writer * writer)232 void StoreVarLenUint16(size_t n, Writer* writer) {
233 JXL_DASSERT(n <= 65535);
234 if (n == 0) {
235 writer->Write(1, 0);
236 } else {
237 writer->Write(1, 1);
238 size_t nbits = FloorLog2Nonzero(n);
239 writer->Write(4, nbits);
240 writer->Write(nbits, n - (1ULL << nbits));
241 }
242 }
243
244 template <typename Writer>
EncodeCounts(const ANSHistBin * counts,const int alphabet_size,const int omit_pos,const int num_symbols,uint32_t shift,const int * symbols,Writer * writer)245 bool EncodeCounts(const ANSHistBin* counts, const int alphabet_size,
246 const int omit_pos, const int num_symbols, uint32_t shift,
247 const int* symbols, Writer* writer) {
248 bool ok = true;
249 if (num_symbols <= 2) {
250 // Small tree marker to encode 1-2 symbols.
251 writer->Write(1, 1);
252 if (num_symbols == 0) {
253 writer->Write(1, 0);
254 StoreVarLenUint8(0, writer);
255 } else {
256 writer->Write(1, num_symbols - 1);
257 for (int i = 0; i < num_symbols; ++i) {
258 StoreVarLenUint8(symbols[i], writer);
259 }
260 }
261 if (num_symbols == 2) {
262 writer->Write(ANS_LOG_TAB_SIZE, counts[symbols[0]]);
263 }
264 } else {
265 // Mark non-small tree.
266 writer->Write(1, 0);
267 // Mark non-flat histogram.
268 writer->Write(1, 0);
269
270 // Precompute sequences for RLE encoding. Contains the number of identical
271 // values starting at a given index. Only contains the value at the first
272 // element of the series.
273 std::vector<uint32_t> same(alphabet_size, 0);
274 int last = 0;
275 for (int i = 1; i < alphabet_size; i++) {
276 // Store the sequence length once different symbol reached, or we're at
277 // the end, or the length is longer than we can encode, or we are at
278 // the omit_pos. We don't support including the omit_pos in an RLE
279 // sequence because this value may use a different amount of log2 bits
280 // than standard, it is too complex to handle in the decoder.
281 if (counts[i] != counts[last] || i + 1 == alphabet_size ||
282 (i - last) >= 255 || i == omit_pos || i == omit_pos + 1) {
283 same[last] = (i - last);
284 last = i + 1;
285 }
286 }
287
288 int length = 0;
289 std::vector<int> logcounts(alphabet_size);
290 int omit_log = 0;
291 for (int i = 0; i < alphabet_size; ++i) {
292 JXL_ASSERT(counts[i] <= ANS_TAB_SIZE);
293 JXL_ASSERT(counts[i] >= 0);
294 if (i == omit_pos) {
295 length = i + 1;
296 } else if (counts[i] > 0) {
297 logcounts[i] = FloorLog2Nonzero(static_cast<uint32_t>(counts[i])) + 1;
298 length = i + 1;
299 if (i < omit_pos) {
300 omit_log = std::max(omit_log, logcounts[i] + 1);
301 } else {
302 omit_log = std::max(omit_log, logcounts[i]);
303 }
304 }
305 }
306 logcounts[omit_pos] = omit_log;
307
308 // Elias gamma-like code for shift. Only difference is that if the number
309 // of bits to be encoded is equal to FloorLog2(ANS_LOG_TAB_SIZE+1), we skip
310 // the terminating 0 in unary coding.
311 int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
312 int log = FloorLog2Nonzero(shift + 1);
313 writer->Write(log, (1 << log) - 1);
314 if (log != upper_bound_log) writer->Write(1, 0);
315 writer->Write(log, ((1 << log) - 1) & (shift + 1));
316
317 // Since num_symbols >= 3, we know that length >= 3, therefore we encode
318 // length - 3.
319 if (length - 3 > 255) {
320 // Pretend that everything is OK, but complain about correctness later.
321 StoreVarLenUint8(255, writer);
322 ok = false;
323 } else {
324 StoreVarLenUint8(length - 3, writer);
325 }
326
327 // The logcount values are encoded with a static Huffman code.
328 static const size_t kMinReps = 4;
329 size_t rep = ANS_LOG_TAB_SIZE + 1;
330 for (int i = 0; i < length; ++i) {
331 if (i > 0 && same[i - 1] > kMinReps) {
332 // Encode the RLE symbol and skip the repeated ones.
333 writer->Write(kLogCountBitLengths[rep], kLogCountSymbols[rep]);
334 StoreVarLenUint8(same[i - 1] - kMinReps - 1, writer);
335 i += same[i - 1] - 2;
336 continue;
337 }
338 writer->Write(kLogCountBitLengths[logcounts[i]],
339 kLogCountSymbols[logcounts[i]]);
340 }
341 for (int i = 0; i < length; ++i) {
342 if (i > 0 && same[i - 1] > kMinReps) {
343 // Skip symbols encoded by RLE.
344 i += same[i - 1] - 2;
345 continue;
346 }
347 if (logcounts[i] > 1 && i != omit_pos) {
348 int bitcount = GetPopulationCountPrecision(logcounts[i] - 1, shift);
349 int drop_bits = logcounts[i] - 1 - bitcount;
350 JXL_CHECK((counts[i] & ((1 << drop_bits) - 1)) == 0);
351 writer->Write(bitcount, (counts[i] >> drop_bits) - (1 << bitcount));
352 }
353 }
354 }
355 return ok;
356 }
357
EncodeFlatHistogram(const int alphabet_size,BitWriter * writer)358 void EncodeFlatHistogram(const int alphabet_size, BitWriter* writer) {
359 // Mark non-small tree.
360 writer->Write(1, 0);
361 // Mark uniform histogram.
362 writer->Write(1, 1);
363 JXL_ASSERT(alphabet_size > 0);
364 // Encode alphabet size.
365 StoreVarLenUint8(alphabet_size - 1, writer);
366 }
367
ComputeHistoAndDataCost(const ANSHistBin * histogram,size_t alphabet_size,uint32_t method)368 float ComputeHistoAndDataCost(const ANSHistBin* histogram, size_t alphabet_size,
369 uint32_t method) {
370 if (method == 0) { // Flat code
371 return ANS_LOG_TAB_SIZE + 2 +
372 EstimateDataBitsFlat(histogram, alphabet_size);
373 }
374 // Non-flat: shift = method-1.
375 uint32_t shift = method - 1;
376 std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
377 int omit_pos = 0;
378 int num_symbols;
379 int symbols[kMaxNumSymbolsForSmallCode] = {};
380 JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
381 ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
382 SizeWriter writer;
383 // Ignore the correctness, no real encoding happens at this stage.
384 (void)EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols, shift,
385 symbols, &writer);
386 return writer.size +
387 EstimateDataBits(histogram, counts.data(), alphabet_size);
388 }
389
ComputeBestMethod(const ANSHistBin * histogram,size_t alphabet_size,float * cost,HistogramParams::ANSHistogramStrategy ans_histogram_strategy)390 uint32_t ComputeBestMethod(
391 const ANSHistBin* histogram, size_t alphabet_size, float* cost,
392 HistogramParams::ANSHistogramStrategy ans_histogram_strategy) {
393 size_t method = 0;
394 float fcost = ComputeHistoAndDataCost(histogram, alphabet_size, 0);
395 for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE;
396 ans_histogram_strategy != HistogramParams::ANSHistogramStrategy::kPrecise
397 ? shift += 2
398 : shift++) {
399 float c = ComputeHistoAndDataCost(histogram, alphabet_size, shift + 1);
400 if (c < fcost) {
401 method = shift + 1;
402 fcost = c;
403 } else if (ans_histogram_strategy ==
404 HistogramParams::ANSHistogramStrategy::kFast) {
405 // do not be as precise if estimating cost.
406 break;
407 }
408 }
409 *cost = fcost;
410 return method;
411 }
412
413 } // namespace
414
415 // Returns an estimate of the cost of encoding this histogram and the
416 // corresponding data.
BuildAndStoreANSEncodingData(HistogramParams::ANSHistogramStrategy ans_histogram_strategy,const ANSHistBin * histogram,size_t alphabet_size,size_t log_alpha_size,bool use_prefix_code,ANSEncSymbolInfo * info,BitWriter * writer)417 size_t BuildAndStoreANSEncodingData(
418 HistogramParams::ANSHistogramStrategy ans_histogram_strategy,
419 const ANSHistBin* histogram, size_t alphabet_size, size_t log_alpha_size,
420 bool use_prefix_code, ANSEncSymbolInfo* info, BitWriter* writer) {
421 if (use_prefix_code) {
422 if (alphabet_size <= 1) return 0;
423 std::vector<uint32_t> histo(alphabet_size);
424 for (size_t i = 0; i < alphabet_size; i++) {
425 histo[i] = histogram[i];
426 JXL_CHECK(histogram[i] >= 0);
427 }
428 size_t cost = 0;
429 {
430 std::vector<uint8_t> depths(alphabet_size);
431 std::vector<uint16_t> bits(alphabet_size);
432 BitWriter tmp_writer;
433 BitWriter* w = writer ? writer : &tmp_writer;
434 size_t start = w->BitsWritten();
435 BitWriter::Allotment allotment(
436 w, 8 * alphabet_size + 8); // safe upper bound
437 BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(),
438 bits.data(), w);
439 ReclaimAndCharge(w, &allotment, 0, /*aux_out=*/nullptr);
440
441 for (size_t i = 0; i < alphabet_size; i++) {
442 info[i].bits = depths[i] == 0 ? 0 : bits[i];
443 info[i].depth = depths[i];
444 }
445 cost = w->BitsWritten() - start;
446 }
447 // Estimate data cost.
448 for (size_t i = 0; i < alphabet_size; i++) {
449 cost += histogram[i] * info[i].depth;
450 }
451 return cost;
452 }
453 JXL_ASSERT(alphabet_size <= ANS_TAB_SIZE);
454 // Ensure we ignore trailing zeros in the histogram.
455 if (alphabet_size != 0) {
456 size_t largest_symbol = 0;
457 for (size_t i = 0; i < alphabet_size; i++) {
458 if (histogram[i] != 0) largest_symbol = i;
459 }
460 alphabet_size = largest_symbol + 1;
461 }
462 float cost;
463 uint32_t method = ComputeBestMethod(histogram, alphabet_size, &cost,
464 ans_histogram_strategy);
465 JXL_ASSERT(cost >= 0);
466 int num_symbols;
467 int symbols[kMaxNumSymbolsForSmallCode] = {};
468 std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
469 if (!counts.empty()) {
470 size_t sum = 0;
471 for (size_t i = 0; i < counts.size(); i++) {
472 sum += counts[i];
473 }
474 if (sum == 0) {
475 counts[0] = ANS_TAB_SIZE;
476 }
477 }
478 if (method == 0) {
479 counts = CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE);
480 AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
481 InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
482 ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
483 if (writer != nullptr) {
484 EncodeFlatHistogram(alphabet_size, writer);
485 }
486 return cost;
487 }
488 int omit_pos = 0;
489 uint32_t shift = method - 1;
490 JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
491 ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
492 AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
493 InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
494 ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
495 if (writer != nullptr) {
496 bool ok = EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols,
497 shift, symbols, writer);
498 (void)ok;
499 JXL_DASSERT(ok);
500 }
501 return cost;
502 }
503
ANSPopulationCost(const ANSHistBin * data,size_t alphabet_size)504 float ANSPopulationCost(const ANSHistBin* data, size_t alphabet_size) {
505 float c;
506 ComputeBestMethod(data, alphabet_size, &c,
507 HistogramParams::ANSHistogramStrategy::kFast);
508 return c;
509 }
510
511 template <typename Writer>
EncodeUintConfig(const HybridUintConfig uint_config,Writer * writer,size_t log_alpha_size)512 void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer,
513 size_t log_alpha_size) {
514 writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
515 uint_config.split_exponent);
516 if (uint_config.split_exponent == log_alpha_size) {
517 return; // msb/lsb don't matter.
518 }
519 size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
520 writer->Write(nbits, uint_config.msb_in_token);
521 nbits = CeilLog2Nonzero(uint_config.split_exponent -
522 uint_config.msb_in_token + 1);
523 writer->Write(nbits, uint_config.lsb_in_token);
524 }
525 template <typename Writer>
EncodeUintConfigs(const std::vector<HybridUintConfig> & uint_config,Writer * writer,size_t log_alpha_size)526 void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config,
527 Writer* writer, size_t log_alpha_size) {
528 // TODO(veluca): RLE?
529 for (size_t i = 0; i < uint_config.size(); i++) {
530 EncodeUintConfig(uint_config[i], writer, log_alpha_size);
531 }
532 }
533 template void EncodeUintConfigs(const std::vector<HybridUintConfig>&,
534 BitWriter*, size_t);
535
536 namespace {
537
ChooseUintConfigs(const HistogramParams & params,const std::vector<std::vector<Token>> & tokens,const std::vector<uint8_t> & context_map,std::vector<Histogram> * clustered_histograms,EntropyEncodingData * codes,size_t * log_alpha_size)538 void ChooseUintConfigs(const HistogramParams& params,
539 const std::vector<std::vector<Token>>& tokens,
540 const std::vector<uint8_t>& context_map,
541 std::vector<Histogram>* clustered_histograms,
542 EntropyEncodingData* codes, size_t* log_alpha_size) {
543 codes->uint_config.resize(clustered_histograms->size());
544 if (params.uint_method == HistogramParams::HybridUintMethod::kNone) return;
545 if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
546 codes->uint_config.clear();
547 codes->uint_config.resize(clustered_histograms->size(),
548 HybridUintConfig(2, 0, 1));
549 return;
550 }
551
552 // Brute-force method that tries a few options.
553 std::vector<HybridUintConfig> configs;
554 if (params.uint_method == HistogramParams::HybridUintMethod::kBest) {
555 configs = {
556 HybridUintConfig(4, 2, 0), // default
557 HybridUintConfig(4, 1, 0), // less precise
558 HybridUintConfig(4, 2, 1), // add sign
559 HybridUintConfig(4, 2, 2), // add sign+parity
560 HybridUintConfig(4, 1, 2), // add parity but less msb
561 // Same as above, but more direct coding.
562 HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0),
563 HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2),
564 HybridUintConfig(5, 1, 2),
565 // Same as above, but less direct coding.
566 HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0),
567 HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2),
568 // For near-lossless.
569 HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4),
570 HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5),
571 HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0),
572 // Other
573 HybridUintConfig(0, 0, 0), // varlenuint
574 HybridUintConfig(2, 0, 1), // works well for ctx map
575 HybridUintConfig(7, 0, 0), // direct coding
576 HybridUintConfig(8, 0, 0), // direct coding
577 HybridUintConfig(9, 0, 0), // direct coding
578 HybridUintConfig(10, 0, 0), // direct coding
579 HybridUintConfig(11, 0, 0), // direct coding
580 HybridUintConfig(12, 0, 0), // direct coding
581 };
582 } else if (params.uint_method == HistogramParams::HybridUintMethod::kFast) {
583 configs = {
584 HybridUintConfig(4, 2, 0), // default
585 HybridUintConfig(4, 1, 2), // add parity but less msb
586 HybridUintConfig(0, 0, 0), // smallest histograms
587 HybridUintConfig(2, 0, 1), // works well for ctx map
588 };
589 }
590
591 std::vector<float> costs(clustered_histograms->size(),
592 std::numeric_limits<float>::max());
593 std::vector<uint32_t> extra_bits(clustered_histograms->size());
594 std::vector<uint8_t> is_valid(clustered_histograms->size());
595 size_t max_alpha =
596 codes->use_prefix_code ? PREFIX_MAX_ALPHABET_SIZE : ANS_MAX_ALPHABET_SIZE;
597 for (HybridUintConfig cfg : configs) {
598 std::fill(is_valid.begin(), is_valid.end(), true);
599 std::fill(extra_bits.begin(), extra_bits.end(), 0);
600
601 for (size_t i = 0; i < clustered_histograms->size(); i++) {
602 (*clustered_histograms)[i].Clear();
603 }
604 for (size_t i = 0; i < tokens.size(); ++i) {
605 for (size_t j = 0; j < tokens[i].size(); ++j) {
606 const Token token = tokens[i][j];
607 // TODO(veluca): do not ignore lz77 commands.
608 if (token.is_lz77_length) continue;
609 size_t histo = context_map[token.context];
610 uint32_t tok, nbits, bits;
611 cfg.Encode(token.value, &tok, &nbits, &bits);
612 if (tok >= max_alpha ||
613 (codes->lz77.enabled && tok >= codes->lz77.min_symbol)) {
614 is_valid[histo] = false;
615 continue;
616 }
617 extra_bits[histo] += nbits;
618 (*clustered_histograms)[histo].Add(tok);
619 }
620 }
621
622 for (size_t i = 0; i < clustered_histograms->size(); i++) {
623 if (!is_valid[i]) continue;
624 float cost = (*clustered_histograms)[i].PopulationCost() + extra_bits[i];
625 if (cost < costs[i]) {
626 codes->uint_config[i] = cfg;
627 costs[i] = cost;
628 }
629 }
630 }
631
632 // Rebuild histograms.
633 for (size_t i = 0; i < clustered_histograms->size(); i++) {
634 (*clustered_histograms)[i].Clear();
635 }
636 *log_alpha_size = 4;
637 for (size_t i = 0; i < tokens.size(); ++i) {
638 for (size_t j = 0; j < tokens[i].size(); ++j) {
639 const Token token = tokens[i][j];
640 uint32_t tok, nbits, bits;
641 size_t histo = context_map[token.context];
642 (token.is_lz77_length ? codes->lz77.length_uint_config
643 : codes->uint_config[histo])
644 .Encode(token.value, &tok, &nbits, &bits);
645 tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
646 (*clustered_histograms)[histo].Add(tok);
647 while (tok >= (1u << *log_alpha_size)) (*log_alpha_size)++;
648 }
649 }
650 #if JXL_ENABLE_ASSERT
651 size_t max_log_alpha_size = codes->use_prefix_code ? PREFIX_MAX_BITS : 8;
652 JXL_ASSERT(*log_alpha_size <= max_log_alpha_size);
653 #endif
654 }
655
656 class HistogramBuilder {
657 public:
HistogramBuilder(const size_t num_contexts)658 explicit HistogramBuilder(const size_t num_contexts)
659 : histograms_(num_contexts) {}
660
VisitSymbol(int symbol,size_t histo_idx)661 void VisitSymbol(int symbol, size_t histo_idx) {
662 JXL_DASSERT(histo_idx < histograms_.size());
663 histograms_[histo_idx].Add(symbol);
664 }
665
666 // NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge.
BuildAndStoreEntropyCodes(const HistogramParams & params,const std::vector<std::vector<Token>> & tokens,EntropyEncodingData * codes,std::vector<uint8_t> * context_map,bool use_prefix_code,BitWriter * writer,size_t layer,AuxOut * aux_out) const667 size_t BuildAndStoreEntropyCodes(
668 const HistogramParams& params,
669 const std::vector<std::vector<Token>>& tokens, EntropyEncodingData* codes,
670 std::vector<uint8_t>* context_map, bool use_prefix_code,
671 BitWriter* writer, size_t layer, AuxOut* aux_out) const {
672 size_t cost = 0;
673 codes->encoding_info.clear();
674 std::vector<Histogram> clustered_histograms(histograms_);
675 context_map->resize(histograms_.size());
676 if (histograms_.size() > 1) {
677 if (!ans_fuzzer_friendly_) {
678 std::vector<uint32_t> histogram_symbols;
679 ClusterHistograms(params, histograms_, histograms_.size(),
680 kClustersLimit, &clustered_histograms,
681 &histogram_symbols);
682 for (size_t c = 0; c < histograms_.size(); ++c) {
683 (*context_map)[c] = static_cast<uint8_t>(histogram_symbols[c]);
684 }
685 } else {
686 fill(context_map->begin(), context_map->end(), 0);
687 size_t max_symbol = 0;
688 for (const Histogram& h : histograms_) {
689 max_symbol = std::max(h.data_.size(), max_symbol);
690 }
691 size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1);
692 clustered_histograms.resize(1);
693 clustered_histograms[0].Clear();
694 for (size_t i = 0; i < num_symbols; i++) {
695 clustered_histograms[0].Add(i);
696 }
697 }
698 if (writer != nullptr) {
699 EncodeContextMap(*context_map, clustered_histograms.size(), writer);
700 }
701 }
702 if (aux_out != nullptr) {
703 for (size_t i = 0; i < clustered_histograms.size(); ++i) {
704 aux_out->layers[layer].clustered_entropy +=
705 clustered_histograms[i].ShannonEntropy();
706 }
707 }
708 codes->use_prefix_code = use_prefix_code;
709 size_t log_alpha_size = codes->lz77.enabled ? 8 : 7; // Sane default.
710 if (ans_fuzzer_friendly_) {
711 codes->uint_config.clear();
712 codes->uint_config.resize(1, HybridUintConfig(7, 0, 0));
713 } else {
714 ChooseUintConfigs(params, tokens, *context_map, &clustered_histograms,
715 codes, &log_alpha_size);
716 }
717 if (log_alpha_size < 5) log_alpha_size = 5;
718 SizeWriter size_writer; // Used if writer == nullptr to estimate costs.
719 cost += 1;
720 if (writer) writer->Write(1, use_prefix_code);
721
722 if (use_prefix_code) {
723 log_alpha_size = PREFIX_MAX_BITS;
724 } else {
725 cost += 2;
726 }
727 if (writer == nullptr) {
728 EncodeUintConfigs(codes->uint_config, &size_writer, log_alpha_size);
729 } else {
730 if (!use_prefix_code) writer->Write(2, log_alpha_size - 5);
731 EncodeUintConfigs(codes->uint_config, writer, log_alpha_size);
732 }
733 if (use_prefix_code) {
734 for (size_t c = 0; c < clustered_histograms.size(); ++c) {
735 size_t num_symbol = 1;
736 for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
737 if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
738 }
739 if (writer) {
740 StoreVarLenUint16(num_symbol - 1, writer);
741 } else {
742 StoreVarLenUint16(num_symbol - 1, &size_writer);
743 }
744 }
745 }
746 cost += size_writer.size;
747 for (size_t c = 0; c < clustered_histograms.size(); ++c) {
748 size_t num_symbol = 1;
749 for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
750 if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
751 }
752 codes->encoding_info.emplace_back();
753 codes->encoding_info.back().resize(std::max<size_t>(1, num_symbol));
754
755 BitWriter::Allotment allotment(writer, 256 + num_symbol * 24);
756 cost += BuildAndStoreANSEncodingData(
757 params.ans_histogram_strategy, clustered_histograms[c].data_.data(),
758 num_symbol, log_alpha_size, use_prefix_code,
759 codes->encoding_info.back().data(), writer);
760 allotment.FinishedHistogram(writer);
761 ReclaimAndCharge(writer, &allotment, layer, aux_out);
762 }
763 return cost;
764 }
765
Histo(size_t i) const766 const Histogram& Histo(size_t i) const { return histograms_[i]; }
767
768 private:
769 std::vector<Histogram> histograms_;
770 };
771
772 class SymbolCostEstimator {
773 public:
SymbolCostEstimator(size_t num_contexts,bool force_huffman,const std::vector<std::vector<Token>> & tokens,const LZ77Params & lz77)774 SymbolCostEstimator(size_t num_contexts, bool force_huffman,
775 const std::vector<std::vector<Token>>& tokens,
776 const LZ77Params& lz77) {
777 HistogramBuilder builder(num_contexts);
778 // Build histograms for estimating lz77 savings.
779 HybridUintConfig uint_config;
780 for (size_t i = 0; i < tokens.size(); ++i) {
781 for (size_t j = 0; j < tokens[i].size(); ++j) {
782 const Token token = tokens[i][j];
783 uint32_t tok, nbits, bits;
784 (token.is_lz77_length ? lz77.length_uint_config : uint_config)
785 .Encode(token.value, &tok, &nbits, &bits);
786 tok += token.is_lz77_length ? lz77.min_symbol : 0;
787 builder.VisitSymbol(tok, token.context);
788 }
789 }
790 max_alphabet_size_ = 0;
791 for (size_t i = 0; i < num_contexts; i++) {
792 max_alphabet_size_ =
793 std::max(max_alphabet_size_, builder.Histo(i).data_.size());
794 }
795 bits_.resize(num_contexts * max_alphabet_size_);
796 // TODO(veluca): SIMD?
797 add_symbol_cost_.resize(num_contexts);
798 for (size_t i = 0; i < num_contexts; i++) {
799 float inv_total = 1.0f / (builder.Histo(i).total_count_ + 1e-8f);
800 float total_cost = 0;
801 for (size_t j = 0; j < builder.Histo(i).data_.size(); j++) {
802 size_t cnt = builder.Histo(i).data_[j];
803 float cost = 0;
804 if (cnt != 0 && cnt != builder.Histo(i).total_count_) {
805 cost = -FastLog2f(cnt * inv_total);
806 if (force_huffman) cost = std::ceil(cost);
807 } else if (cnt == 0) {
808 cost = ANS_LOG_TAB_SIZE; // Highest possible cost.
809 }
810 bits_[i * max_alphabet_size_ + j] = cost;
811 total_cost += cost * builder.Histo(i).data_[j];
812 }
813 // Penalty for adding a lz77 symbol to this contest (only used for static
814 // cost model). Higher penalty for contexts that have a very low
815 // per-symbol entropy.
816 add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total);
817 }
818 }
Bits(size_t ctx,size_t sym) const819 float Bits(size_t ctx, size_t sym) const {
820 return bits_[ctx * max_alphabet_size_ + sym];
821 }
LenCost(size_t ctx,size_t len,const LZ77Params & lz77) const822 float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const {
823 uint32_t nbits, bits, tok;
824 lz77.length_uint_config.Encode(len, &tok, &nbits, &bits);
825 tok += lz77.min_symbol;
826 return nbits + Bits(ctx, tok);
827 }
DistCost(size_t len,const LZ77Params & lz77) const828 float DistCost(size_t len, const LZ77Params& lz77) const {
829 uint32_t nbits, bits, tok;
830 HybridUintConfig().Encode(len, &tok, &nbits, &bits);
831 return nbits + Bits(lz77.nonserialized_distance_context, tok);
832 }
AddSymbolCost(size_t idx) const833 float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; }
834
835 private:
836 size_t max_alphabet_size_;
837 std::vector<float> bits_;
838 std::vector<float> add_symbol_cost_;
839 };
840
ApplyLZ77_RLE(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)841 void ApplyLZ77_RLE(const HistogramParams& params, size_t num_contexts,
842 const std::vector<std::vector<Token>>& tokens,
843 LZ77Params& lz77,
844 std::vector<std::vector<Token>>& tokens_lz77) {
845 // TODO(veluca): tune heuristics here.
846 SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
847 float bit_decrease = 0;
848 size_t total_symbols = 0;
849 tokens_lz77.resize(tokens.size());
850 std::vector<float> sym_cost;
851 HybridUintConfig uint_config;
852 for (size_t stream = 0; stream < tokens.size(); stream++) {
853 size_t distance_multiplier =
854 params.image_widths.size() > stream ? params.image_widths[stream] : 0;
855 const auto& in = tokens[stream];
856 auto& out = tokens_lz77[stream];
857 total_symbols += in.size();
858 // Cumulative sum of bit costs.
859 sym_cost.resize(in.size() + 1);
860 for (size_t i = 0; i < in.size(); i++) {
861 uint32_t tok, nbits, unused_bits;
862 uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
863 sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
864 }
865 out.reserve(in.size());
866 for (size_t i = 0; i < in.size(); i++) {
867 size_t num_to_copy = 0;
868 size_t distance_symbol = 0; // 1 for RLE.
869 if (distance_multiplier != 0) {
870 distance_symbol = 1; // Special distance 1 if enabled.
871 JXL_DASSERT(kSpecialDistances[1][0] == 1);
872 JXL_DASSERT(kSpecialDistances[1][1] == 0);
873 }
874 if (i > 0) {
875 for (; i + num_to_copy < in.size(); num_to_copy++) {
876 if (in[i + num_to_copy].value != in[i - 1].value) {
877 break;
878 }
879 }
880 }
881 if (num_to_copy == 0) {
882 out.push_back(in[i]);
883 continue;
884 }
885 float cost = sym_cost[i + num_to_copy] - sym_cost[i];
886 // This subtraction might overflow, but that's OK.
887 size_t lz77_len = num_to_copy - lz77.min_length;
888 float lz77_cost = num_to_copy >= lz77.min_length
889 ? CeilLog2Nonzero(lz77_len + 1) + 1
890 : 0;
891 if (num_to_copy < lz77.min_length || cost <= lz77_cost) {
892 for (size_t j = 0; j < num_to_copy; j++) {
893 out.push_back(in[i + j]);
894 }
895 i += num_to_copy - 1;
896 continue;
897 }
898 // Output the LZ77 length
899 out.emplace_back(in[i].context, lz77_len);
900 out.back().is_lz77_length = true;
901 i += num_to_copy - 1;
902 bit_decrease += cost - lz77_cost;
903 // Output the LZ77 copy distance.
904 out.emplace_back(lz77.nonserialized_distance_context, distance_symbol);
905 }
906 }
907
908 if (bit_decrease > total_symbols * 0.2 + 16) {
909 lz77.enabled = true;
910 }
911 }
912
913 // Hash chain for LZ77 matching
914 struct HashChain {
915 size_t size_;
916 std::vector<uint32_t> data_;
917
918 unsigned hash_num_values_ = 32768;
919 unsigned hash_mask_ = hash_num_values_ - 1;
920 unsigned hash_shift_ = 5;
921
922 std::vector<int> head;
923 std::vector<uint32_t> chain;
924 std::vector<int> val;
925
926 // Speed up repetitions of zero
927 std::vector<int> headz;
928 std::vector<uint32_t> chainz;
929 std::vector<uint32_t> zeros;
930 uint32_t numzeros = 0;
931
932 size_t window_size_;
933 size_t window_mask_;
934 size_t min_length_;
935 size_t max_length_;
936
937 // Map of special distance codes.
938 std::unordered_map<int, int> special_dist_table_;
939 size_t num_special_distances_ = 0;
940
941 uint32_t maxchainlength = 256; // window_size_ to allow all
942
HashChainjxl::__anon551ae0000211::HashChain943 HashChain(const Token* data, size_t size, size_t window_size,
944 size_t min_length, size_t max_length, size_t distance_multiplier)
945 : size_(size),
946 window_size_(window_size),
947 window_mask_(window_size - 1),
948 min_length_(min_length),
949 max_length_(max_length) {
950 data_.resize(size);
951 for (size_t i = 0; i < size; i++) {
952 data_[i] = data[i].value;
953 }
954
955 head.resize(hash_num_values_, -1);
956 val.resize(window_size_, -1);
957 chain.resize(window_size_);
958 for (uint32_t i = 0; i < window_size_; ++i) {
959 chain[i] = i; // same value as index indicates uninitialized
960 }
961
962 zeros.resize(window_size_);
963 headz.resize(window_size_ + 1, -1);
964 chainz.resize(window_size_);
965 for (uint32_t i = 0; i < window_size_; ++i) {
966 chainz[i] = i;
967 }
968 // Translate distance to special distance code.
969 if (distance_multiplier) {
970 // Count down, so if due to small distance multiplier multiple distances
971 // map to the same code, the smallest code will be used in the end.
972 for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
973 int xi = kSpecialDistances[i][0];
974 int yi = kSpecialDistances[i][1];
975 int distance = yi * distance_multiplier + xi;
976 // Ensure that we map distance 1 to the lowest symbols.
977 if (distance < 1) distance = 1;
978 special_dist_table_[distance] = i;
979 }
980 num_special_distances_ = kNumSpecialDistances;
981 }
982 }
983
GetHashjxl::__anon551ae0000211::HashChain984 uint32_t GetHash(size_t pos) const {
985 uint32_t result = 0;
986 if (pos + 2 < size_) {
987 // TODO(lode): take the MSB's of the uint32_t values into account as well,
988 // given that the hash code itself is less than 32 bits.
989 result ^= (uint32_t)(data_[pos + 0] << 0u);
990 result ^= (uint32_t)(data_[pos + 1] << hash_shift_);
991 result ^= (uint32_t)(data_[pos + 2] << (hash_shift_ * 2));
992 } else {
993 // No need to compute hash of last 2 bytes, the length 2 is too short.
994 return 0;
995 }
996 return result & hash_mask_;
997 }
998
CountZerosjxl::__anon551ae0000211::HashChain999 uint32_t CountZeros(size_t pos, uint32_t prevzeros) const {
1000 size_t end = pos + window_size_;
1001 if (end > size_) end = size_;
1002 if (prevzeros > 0) {
1003 if (prevzeros >= window_mask_ && data_[end - 1] == 0 &&
1004 end == pos + window_size_) {
1005 return prevzeros;
1006 } else {
1007 return prevzeros - 1;
1008 }
1009 }
1010 uint32_t num = 0;
1011 while (pos + num < end && data_[pos + num] == 0) num++;
1012 return num;
1013 }
1014
Updatejxl::__anon551ae0000211::HashChain1015 void Update(size_t pos) {
1016 uint32_t hashval = GetHash(pos);
1017 uint32_t wpos = pos & window_mask_;
1018
1019 val[wpos] = (int)hashval;
1020 if (head[hashval] != -1) chain[wpos] = head[hashval];
1021 head[hashval] = wpos;
1022
1023 if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0;
1024 numzeros = CountZeros(pos, numzeros);
1025
1026 zeros[wpos] = numzeros;
1027 if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros];
1028 headz[numzeros] = wpos;
1029 }
1030
Updatejxl::__anon551ae0000211::HashChain1031 void Update(size_t pos, size_t len) {
1032 for (size_t i = 0; i < len; i++) {
1033 Update(pos + i);
1034 }
1035 }
1036
1037 template <typename CB>
FindMatchesjxl::__anon551ae0000211::HashChain1038 void FindMatches(size_t pos, int max_dist, const CB& found_match) const {
1039 uint32_t wpos = pos & window_mask_;
1040 uint32_t hashval = GetHash(pos);
1041 uint32_t hashpos = chain[wpos];
1042
1043 int prev_dist = 0;
1044 int end = std::min<int>(pos + max_length_, size_);
1045 uint32_t chainlength = 0;
1046 uint32_t best_len = 0;
1047 for (;;) {
1048 int dist = (hashpos <= wpos) ? (wpos - hashpos)
1049 : (wpos - hashpos + window_mask_ + 1);
1050 if (dist < prev_dist) break;
1051 prev_dist = dist;
1052 uint32_t len = 0;
1053 if (dist > 0) {
1054 int i = pos;
1055 int j = pos - dist;
1056 if (numzeros > 3) {
1057 int r = std::min<int>(numzeros - 1, zeros[hashpos]);
1058 if (i + r >= end) r = end - i - 1;
1059 i += r;
1060 j += r;
1061 }
1062 while (i < end && data_[i] == data_[j]) {
1063 i++;
1064 j++;
1065 }
1066 len = i - pos;
1067 // This can trigger even if the new length is slightly smaller than the
1068 // best length, because it is possible for a slightly cheaper distance
1069 // symbol to occur.
1070 if (len >= min_length_ && len + 2 >= best_len) {
1071 auto it = special_dist_table_.find(dist);
1072 int dist_symbol = (it == special_dist_table_.end())
1073 ? (num_special_distances_ + dist - 1)
1074 : it->second;
1075 found_match(len, dist_symbol);
1076 if (len > best_len) best_len = len;
1077 }
1078 }
1079
1080 chainlength++;
1081 if (chainlength >= maxchainlength) break;
1082
1083 if (numzeros >= 3 && len > numzeros) {
1084 if (hashpos == chainz[hashpos]) break;
1085 hashpos = chainz[hashpos];
1086 if (zeros[hashpos] != numzeros) break;
1087 } else {
1088 if (hashpos == chain[hashpos]) break;
1089 hashpos = chain[hashpos];
1090 if (val[hashpos] != (int)hashval) break; // outdated hash value
1091 }
1092 }
1093 }
FindMatchjxl::__anon551ae0000211::HashChain1094 void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol,
1095 size_t* result_len) const {
1096 *result_dist_symbol = 0;
1097 *result_len = 1;
1098 FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) {
1099 if (len > *result_len ||
1100 (len == *result_len && *result_dist_symbol > dist_symbol)) {
1101 *result_len = len;
1102 *result_dist_symbol = dist_symbol;
1103 }
1104 });
1105 }
1106 };
1107
LenCost(size_t len)1108 float LenCost(size_t len) {
1109 uint32_t nbits, bits, tok;
1110 HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits);
1111 constexpr float kCostTable[] = {
1112 2.797667318563126, 3.213177690381199, 2.5706009246743737,
1113 2.408392498667534, 2.829649191872326, 3.3923087753324577,
1114 4.029267451554331, 4.415576699706408, 4.509357574741465,
1115 9.21481543803004, 10.020590190114898, 11.858671627804766,
1116 12.45853300490526, 11.713105831990857, 12.561996324849314,
1117 13.775477692278367, 13.174027068768641,
1118 };
1119 size_t table_size = sizeof kCostTable / sizeof *kCostTable;
1120 if (tok >= table_size) tok = table_size - 1;
1121 return kCostTable[tok] + nbits;
1122 }
1123
1124 // TODO(veluca): this does not take into account usage or non-usage of distance
1125 // multipliers.
DistCost(size_t dist)1126 float DistCost(size_t dist) {
1127 uint32_t nbits, bits, tok;
1128 HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits);
1129 constexpr float kCostTable[] = {
1130 6.368282626312716, 5.680793277090298, 8.347404197105247,
1131 7.641619201599141, 6.914328374119438, 7.959808291537444,
1132 8.70023120759855, 8.71378518934703, 9.379132523982769,
1133 9.110472749092708, 9.159029569270908, 9.430936766731973,
1134 7.278284055315169, 7.8278514904267755, 10.026641158289236,
1135 9.976049229827066, 9.64351607048908, 9.563403863480442,
1136 10.171474111762747, 10.45950155077234, 9.994813912104219,
1137 10.322524683741156, 8.465808729388186, 8.756254166066853,
1138 10.160930174662234, 10.247329273413435, 10.04090403724809,
1139 10.129398517544082, 9.342311691539546, 9.07608009102374,
1140 10.104799540677513, 10.378079384990906, 10.165828974075072,
1141 10.337595322341553, 7.940557464567944, 10.575665823319431,
1142 11.023344321751955, 10.736144698831827, 11.118277044595054,
1143 7.468468230648442, 10.738305230932939, 10.906980780216568,
1144 10.163468216353817, 10.17805759656433, 11.167283670483565,
1145 11.147050200274544, 10.517921919244333, 10.651764778156886,
1146 10.17074446448919, 11.217636876224745, 11.261630721139484,
1147 11.403140815247259, 10.892472096873417, 11.1859607804481,
1148 8.017346947551262, 7.895143720278828, 11.036577113822025,
1149 11.170562110315794, 10.326988722591086, 10.40872184751056,
1150 11.213498225466386, 11.30580635516863, 10.672272515665442,
1151 10.768069466228063, 11.145257364153565, 11.64668307145549,
1152 10.593156194627339, 11.207499484844943, 10.767517766396908,
1153 10.826629811407042, 10.737764794499988, 10.6200448518045,
1154 10.191315385198092, 8.468384171390085, 11.731295299170432,
1155 11.824619886654398, 10.41518844301179, 10.16310536548649,
1156 10.539423685097576, 10.495136599328031, 10.469112847728267,
1157 11.72057686174922, 10.910326337834674, 11.378921834673758,
1158 11.847759036098536, 11.92071647623854, 10.810628276345282,
1159 11.008601085273893, 11.910326337834674, 11.949212023423133,
1160 11.298614839104337, 11.611603659010392, 10.472930394619985,
1161 11.835564720850282, 11.523267392285337, 12.01055816679611,
1162 8.413029688994023, 11.895784139536406, 11.984679534970505,
1163 11.220654278717394, 11.716311684833672, 10.61036646226114,
1164 10.89849965960364, 10.203762898863669, 10.997560826267238,
1165 11.484217379438984, 11.792836176993665, 12.24310468755171,
1166 11.464858097919262, 12.212747017409377, 11.425595666074955,
1167 11.572048533398757, 12.742093965163013, 11.381874288645637,
1168 12.191870445817015, 11.683156920035426, 11.152442115262197,
1169 11.90303691580457, 11.653292787169159, 11.938615382266098,
1170 16.970641701570223, 16.853602280380002, 17.26240782594733,
1171 16.644655390108507, 17.14310889757499, 16.910935455445955,
1172 17.505678976959697, 17.213498225466388, 2.4162310293553024,
1173 3.494587244462329, 3.5258600986408344, 3.4959806589517095,
1174 3.098390886949687, 3.343454654302911, 3.588847442290287,
1175 4.14614790111827, 5.152948641990529, 7.433696808092598,
1176 9.716311684833672,
1177 };
1178 size_t table_size = sizeof kCostTable / sizeof *kCostTable;
1179 if (tok >= table_size) tok = table_size - 1;
1180 return kCostTable[tok] + nbits;
1181 }
1182
ApplyLZ77_LZ77(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)1183 void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts,
1184 const std::vector<std::vector<Token>>& tokens,
1185 LZ77Params& lz77,
1186 std::vector<std::vector<Token>>& tokens_lz77) {
1187 // TODO(veluca): tune heuristics here.
1188 SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
1189 float bit_decrease = 0;
1190 size_t total_symbols = 0;
1191 tokens_lz77.resize(tokens.size());
1192 HybridUintConfig uint_config;
1193 std::vector<float> sym_cost;
1194 for (size_t stream = 0; stream < tokens.size(); stream++) {
1195 size_t distance_multiplier =
1196 params.image_widths.size() > stream ? params.image_widths[stream] : 0;
1197 const auto& in = tokens[stream];
1198 auto& out = tokens_lz77[stream];
1199 total_symbols += in.size();
1200 // Cumulative sum of bit costs.
1201 sym_cost.resize(in.size() + 1);
1202 for (size_t i = 0; i < in.size(); i++) {
1203 uint32_t tok, nbits, unused_bits;
1204 uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
1205 sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
1206 }
1207
1208 out.reserve(in.size());
1209 size_t max_distance = in.size();
1210 size_t min_length = lz77.min_length;
1211 JXL_ASSERT(min_length >= 3);
1212 size_t max_length = in.size();
1213
1214 // Use next power of two as window size.
1215 size_t window_size = 1;
1216 while (window_size < max_distance && window_size < kWindowSize) {
1217 window_size <<= 1;
1218 }
1219
1220 HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
1221 distance_multiplier);
1222 size_t len, dist_symbol;
1223
1224 const size_t max_lazy_match_len = 256; // 0 to disable lazy matching
1225
1226 // Whether the next symbol was already updated (to test lazy matching)
1227 bool already_updated = false;
1228 for (size_t i = 0; i < in.size(); i++) {
1229 out.push_back(in[i]);
1230 if (!already_updated) chain.Update(i);
1231 already_updated = false;
1232 chain.FindMatch(i, max_distance, &dist_symbol, &len);
1233 if (len >= min_length) {
1234 if (len < max_lazy_match_len && i + 1 < in.size()) {
1235 // Try length at next symbol lazy matching
1236 chain.Update(i + 1);
1237 already_updated = true;
1238 size_t len2, dist_symbol2;
1239 chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
1240 if (len2 > len) {
1241 // Use the lazy match. Add literal, and use the next length starting
1242 // from the next byte.
1243 ++i;
1244 already_updated = false;
1245 len = len2;
1246 dist_symbol = dist_symbol2;
1247 out.push_back(in[i]);
1248 }
1249 }
1250
1251 float cost = sym_cost[i + len] - sym_cost[i];
1252 size_t lz77_len = len - lz77.min_length;
1253 float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
1254 sce.AddSymbolCost(out.back().context);
1255
1256 if (lz77_cost <= cost) {
1257 out.back().value = len - min_length;
1258 out.back().is_lz77_length = true;
1259 out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
1260 bit_decrease += cost - lz77_cost;
1261 } else {
1262 // LZ77 match ignored, and symbol already pushed. Push all other
1263 // symbols and skip.
1264 for (size_t j = 1; j < len; j++) {
1265 out.push_back(in[i + j]);
1266 }
1267 }
1268
1269 if (already_updated) {
1270 chain.Update(i + 2, len - 2);
1271 already_updated = false;
1272 } else {
1273 chain.Update(i + 1, len - 1);
1274 }
1275 i += len - 1;
1276 } else {
1277 // Literal, already pushed
1278 }
1279 }
1280 }
1281
1282 if (bit_decrease > total_symbols * 0.2 + 16) {
1283 lz77.enabled = true;
1284 }
1285 }
1286
ApplyLZ77_Optimal(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)1287 void ApplyLZ77_Optimal(const HistogramParams& params, size_t num_contexts,
1288 const std::vector<std::vector<Token>>& tokens,
1289 LZ77Params& lz77,
1290 std::vector<std::vector<Token>>& tokens_lz77) {
1291 std::vector<std::vector<Token>> tokens_for_cost_estimate;
1292 ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_for_cost_estimate);
1293 // If greedy-LZ77 does not give better compression than no-lz77, no reason to
1294 // run the optimal matching.
1295 if (!lz77.enabled) return;
1296 SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
1297 tokens_for_cost_estimate, lz77);
1298 tokens_lz77.resize(tokens.size());
1299 HybridUintConfig uint_config;
1300 std::vector<float> sym_cost;
1301 std::vector<uint32_t> dist_symbols;
1302 for (size_t stream = 0; stream < tokens.size(); stream++) {
1303 size_t distance_multiplier =
1304 params.image_widths.size() > stream ? params.image_widths[stream] : 0;
1305 const auto& in = tokens[stream];
1306 auto& out = tokens_lz77[stream];
1307 // Cumulative sum of bit costs.
1308 sym_cost.resize(in.size() + 1);
1309 for (size_t i = 0; i < in.size(); i++) {
1310 uint32_t tok, nbits, unused_bits;
1311 uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
1312 sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
1313 }
1314
1315 out.reserve(in.size());
1316 size_t max_distance = in.size();
1317 size_t min_length = lz77.min_length;
1318 JXL_ASSERT(min_length >= 3);
1319 size_t max_length = in.size();
1320
1321 // Use next power of two as window size.
1322 size_t window_size = 1;
1323 while (window_size < max_distance && window_size < kWindowSize) {
1324 window_size <<= 1;
1325 }
1326
1327 HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
1328 distance_multiplier);
1329
1330 struct MatchInfo {
1331 uint32_t len;
1332 uint32_t dist_symbol;
1333 uint32_t ctx;
1334 float total_cost = std::numeric_limits<float>::max();
1335 };
1336 // Total cost to encode the first N symbols.
1337 std::vector<MatchInfo> prefix_costs(in.size() + 1);
1338 prefix_costs[0].total_cost = 0;
1339
1340 size_t rle_length = 0;
1341 size_t skip_lz77 = 0;
1342 for (size_t i = 0; i < in.size(); i++) {
1343 chain.Update(i);
1344 float lit_cost =
1345 prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
1346 if (prefix_costs[i + 1].total_cost > lit_cost) {
1347 prefix_costs[i + 1].dist_symbol = 0;
1348 prefix_costs[i + 1].len = 1;
1349 prefix_costs[i + 1].ctx = in[i].context;
1350 prefix_costs[i + 1].total_cost = lit_cost;
1351 }
1352 if (skip_lz77 > 0) {
1353 skip_lz77--;
1354 continue;
1355 }
1356 dist_symbols.clear();
1357 chain.FindMatches(i, max_distance,
1358 [&dist_symbols](size_t len, size_t dist_symbol) {
1359 if (dist_symbols.size() <= len) {
1360 dist_symbols.resize(len + 1, dist_symbol);
1361 }
1362 if (dist_symbol < dist_symbols[len]) {
1363 dist_symbols[len] = dist_symbol;
1364 }
1365 });
1366 if (dist_symbols.size() <= min_length) continue;
1367 {
1368 size_t best_cost = dist_symbols.back();
1369 for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
1370 if (dist_symbols[j] < best_cost) {
1371 best_cost = dist_symbols[j];
1372 }
1373 dist_symbols[j] = best_cost;
1374 }
1375 }
1376 for (size_t j = min_length; j < dist_symbols.size(); j++) {
1377 // Cost model that uses results from lazy LZ77.
1378 float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
1379 sce.DistCost(dist_symbols[j], lz77);
1380 float cost = prefix_costs[i].total_cost + lz77_cost;
1381 if (prefix_costs[i + j].total_cost > cost) {
1382 prefix_costs[i + j].len = j;
1383 prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
1384 prefix_costs[i + j].ctx = in[i].context;
1385 prefix_costs[i + j].total_cost = cost;
1386 }
1387 }
1388 // We are in a RLE sequence: skip all the symbols except the first 8 and
1389 // the last 8. This avoid quadratic costs for sequences with long runs of
1390 // the same symbol.
1391 if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
1392 (dist_symbols.back() == 1 && distance_multiplier != 0)) {
1393 rle_length++;
1394 } else {
1395 rle_length = 0;
1396 }
1397 if (rle_length >= 8 && dist_symbols.size() > 9) {
1398 skip_lz77 = dist_symbols.size() - 10;
1399 rle_length = 0;
1400 }
1401 }
1402 size_t pos = in.size();
1403 while (pos > 0) {
1404 bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
1405 if (is_lz77_length) {
1406 size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
1407 out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
1408 }
1409 size_t val = is_lz77_length ? prefix_costs[pos].len - min_length
1410 : in[pos - 1].value;
1411 out.emplace_back(prefix_costs[pos].ctx, val);
1412 out.back().is_lz77_length = is_lz77_length;
1413 pos -= prefix_costs[pos].len;
1414 }
1415 std::reverse(out.begin(), out.end());
1416 }
1417 }
1418
ApplyLZ77(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)1419 void ApplyLZ77(const HistogramParams& params, size_t num_contexts,
1420 const std::vector<std::vector<Token>>& tokens, LZ77Params& lz77,
1421 std::vector<std::vector<Token>>& tokens_lz77) {
1422 lz77.enabled = false;
1423 if (params.force_huffman) {
1424 lz77.min_symbol = std::min(PREFIX_MAX_ALPHABET_SIZE - 32, 512);
1425 } else {
1426 lz77.min_symbol = 224;
1427 }
1428 if (params.lz77_method == HistogramParams::LZ77Method::kNone) {
1429 return;
1430 } else if (params.lz77_method == HistogramParams::LZ77Method::kRLE) {
1431 ApplyLZ77_RLE(params, num_contexts, tokens, lz77, tokens_lz77);
1432 } else if (params.lz77_method == HistogramParams::LZ77Method::kLZ77) {
1433 ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_lz77);
1434 } else if (params.lz77_method == HistogramParams::LZ77Method::kOptimal) {
1435 ApplyLZ77_Optimal(params, num_contexts, tokens, lz77, tokens_lz77);
1436 } else {
1437 JXL_ABORT("Not implemented");
1438 }
1439 }
1440 } // namespace
1441
BuildAndEncodeHistograms(const HistogramParams & params,size_t num_contexts,std::vector<std::vector<Token>> & tokens,EntropyEncodingData * codes,std::vector<uint8_t> * context_map,BitWriter * writer,size_t layer,AuxOut * aux_out)1442 size_t BuildAndEncodeHistograms(const HistogramParams& params,
1443 size_t num_contexts,
1444 std::vector<std::vector<Token>>& tokens,
1445 EntropyEncodingData* codes,
1446 std::vector<uint8_t>* context_map,
1447 BitWriter* writer, size_t layer,
1448 AuxOut* aux_out) {
1449 size_t total_bits = 0;
1450 codes->lz77.nonserialized_distance_context = num_contexts;
1451 std::vector<std::vector<Token>> tokens_lz77;
1452 ApplyLZ77(params, num_contexts, tokens, codes->lz77, tokens_lz77);
1453 if (ans_fuzzer_friendly_) {
1454 codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0);
1455 codes->lz77.min_symbol = 2048;
1456 }
1457
1458 const size_t max_contexts = std::min(num_contexts, kClustersLimit);
1459 BitWriter::Allotment allotment(writer,
1460 128 + num_contexts * 40 + max_contexts * 96);
1461 if (writer) {
1462 JXL_CHECK(Bundle::Write(codes->lz77, writer, layer, aux_out));
1463 } else {
1464 size_t ebits, bits;
1465 JXL_CHECK(Bundle::CanEncode(codes->lz77, &ebits, &bits));
1466 total_bits += bits;
1467 }
1468 if (codes->lz77.enabled) {
1469 if (writer) {
1470 size_t b = writer->BitsWritten();
1471 EncodeUintConfig(codes->lz77.length_uint_config, writer,
1472 /*log_alpha_size=*/8);
1473 total_bits += writer->BitsWritten() - b;
1474 } else {
1475 SizeWriter size_writer;
1476 EncodeUintConfig(codes->lz77.length_uint_config, &size_writer,
1477 /*log_alpha_size=*/8);
1478 total_bits += size_writer.size;
1479 }
1480 num_contexts += 1;
1481 tokens = std::move(tokens_lz77);
1482 }
1483 size_t total_tokens = 0;
1484 // Build histograms.
1485 HistogramBuilder builder(num_contexts);
1486 HybridUintConfig uint_config; // Default config for clustering.
1487 // Unless we are using the kContextMap histogram option.
1488 if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
1489 uint_config = HybridUintConfig(2, 0, 1);
1490 }
1491 if (ans_fuzzer_friendly_) {
1492 uint_config = HybridUintConfig(10, 0, 0);
1493 }
1494 for (size_t i = 0; i < tokens.size(); ++i) {
1495 for (size_t j = 0; j < tokens[i].size(); ++j) {
1496 const Token token = tokens[i][j];
1497 total_tokens++;
1498 uint32_t tok, nbits, bits;
1499 (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config)
1500 .Encode(token.value, &tok, &nbits, &bits);
1501 tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
1502 builder.VisitSymbol(tok, token.context);
1503 }
1504 }
1505
1506 bool use_prefix_code =
1507 params.force_huffman || total_tokens < 100 ||
1508 params.clustering == HistogramParams::ClusteringType::kFastest ||
1509 ans_fuzzer_friendly_;
1510 if (!use_prefix_code) {
1511 bool all_singleton = true;
1512 for (size_t i = 0; i < num_contexts; i++) {
1513 if (builder.Histo(i).ShannonEntropy() >= 1e-5) {
1514 all_singleton = false;
1515 }
1516 }
1517 if (all_singleton) {
1518 use_prefix_code = true;
1519 }
1520 }
1521
1522 // Encode histograms.
1523 total_bits += builder.BuildAndStoreEntropyCodes(params, tokens, codes,
1524 context_map, use_prefix_code,
1525 writer, layer, aux_out);
1526 allotment.FinishedHistogram(writer);
1527 ReclaimAndCharge(writer, &allotment, layer, aux_out);
1528
1529 if (aux_out != nullptr) {
1530 aux_out->layers[layer].num_clustered_histograms +=
1531 codes->encoding_info.size();
1532 }
1533 return total_bits;
1534 }
1535
WriteTokens(const std::vector<Token> & tokens,const EntropyEncodingData & codes,const std::vector<uint8_t> & context_map,BitWriter * writer)1536 size_t WriteTokens(const std::vector<Token>& tokens,
1537 const EntropyEncodingData& codes,
1538 const std::vector<uint8_t>& context_map, BitWriter* writer) {
1539 size_t num_extra_bits = 0;
1540 if (codes.use_prefix_code) {
1541 for (size_t i = 0; i < tokens.size(); i++) {
1542 uint32_t tok, nbits, bits;
1543 const Token& token = tokens[i];
1544 size_t histo = context_map[token.context];
1545 (token.is_lz77_length ? codes.lz77.length_uint_config
1546 : codes.uint_config[histo])
1547 .Encode(token.value, &tok, &nbits, &bits);
1548 tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1549 // Combine two calls to the BitWriter. Equivalent to:
1550 // writer->Write(codes.encoding_info[histo][tok].depth,
1551 // codes.encoding_info[histo][tok].bits);
1552 // writer->Write(nbits, bits);
1553 uint64_t data = codes.encoding_info[histo][tok].bits;
1554 data |= bits << codes.encoding_info[histo][tok].depth;
1555 writer->Write(codes.encoding_info[histo][tok].depth + nbits, data);
1556 num_extra_bits += nbits;
1557 }
1558 return num_extra_bits;
1559 }
1560 std::vector<uint64_t> out;
1561 std::vector<uint8_t> out_nbits;
1562 out.reserve(tokens.size());
1563 out_nbits.reserve(tokens.size());
1564 uint64_t allbits = 0;
1565 size_t numallbits = 0;
1566 // Writes in *reversed* order.
1567 auto addbits = [&](size_t bits, size_t nbits) {
1568 JXL_DASSERT(bits >> nbits == 0);
1569 if (JXL_UNLIKELY(numallbits + nbits > BitWriter::kMaxBitsPerCall)) {
1570 out.push_back(allbits);
1571 out_nbits.push_back(numallbits);
1572 numallbits = allbits = 0;
1573 }
1574 allbits <<= nbits;
1575 allbits |= bits;
1576 numallbits += nbits;
1577 };
1578 const int end = tokens.size();
1579 ANSCoder ans;
1580 for (int i = end - 1; i >= 0; --i) {
1581 const Token token = tokens[i];
1582 const uint8_t histo = context_map[token.context];
1583 uint32_t tok, nbits, bits;
1584 (token.is_lz77_length ? codes.lz77.length_uint_config
1585 : codes.uint_config[histo])
1586 .Encode(tokens[i].value, &tok, &nbits, &bits);
1587 tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1588 const ANSEncSymbolInfo& info = codes.encoding_info[histo][tok];
1589 // Extra bits first as this is reversed.
1590 addbits(bits, nbits);
1591 num_extra_bits += nbits;
1592 uint8_t ans_nbits = 0;
1593 uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
1594 addbits(ans_bits, ans_nbits);
1595 }
1596 const uint32_t state = ans.GetState();
1597 writer->Write(32, state);
1598 writer->Write(numallbits, allbits);
1599 for (int i = out.size(); i > 0; --i) {
1600 writer->Write(out_nbits[i - 1], out[i - 1]);
1601 }
1602 return num_extra_bits;
1603 }
1604
WriteTokens(const std::vector<Token> & tokens,const EntropyEncodingData & codes,const std::vector<uint8_t> & context_map,BitWriter * writer,size_t layer,AuxOut * aux_out)1605 void WriteTokens(const std::vector<Token>& tokens,
1606 const EntropyEncodingData& codes,
1607 const std::vector<uint8_t>& context_map, BitWriter* writer,
1608 size_t layer, AuxOut* aux_out) {
1609 BitWriter::Allotment allotment(writer, 32 * tokens.size() + 32 * 1024 * 4);
1610 size_t num_extra_bits = WriteTokens(tokens, codes, context_map, writer);
1611 ReclaimAndCharge(writer, &allotment, layer, aux_out);
1612 if (aux_out != nullptr) {
1613 aux_out->layers[layer].extra_bits += num_extra_bits;
1614 }
1615 }
1616
// Sets the file-level fuzzer-friendly flag. The assignment is compiled in
// only when JXL_IS_DEBUG_BUILD is set; otherwise the call has no effect.
void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) {
#if JXL_IS_DEBUG_BUILD  // Guard against accidental / malicious changes.
  ans_fuzzer_friendly_ = ans_fuzzer_friendly;
#else
  (void)ans_fuzzer_friendly;  // Deliberately ignored outside debug builds.
#endif
}
1622 } // namespace jxl
1623