1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 #include "lib/jxl/enc_ans.h"
7
8 #include <stdint.h>
9
10 #include <algorithm>
11 #include <array>
12 #include <cmath>
13 #include <limits>
14 #include <numeric>
15 #include <type_traits>
16 #include <unordered_map>
17 #include <utility>
18 #include <vector>
19
20 #include "lib/jxl/ans_common.h"
21 #include "lib/jxl/aux_out.h"
22 #include "lib/jxl/aux_out_fwd.h"
23 #include "lib/jxl/base/bits.h"
24 #include "lib/jxl/dec_ans.h"
25 #include "lib/jxl/enc_cluster.h"
26 #include "lib/jxl/enc_context_map.h"
27 #include "lib/jxl/enc_huffman.h"
28 #include "lib/jxl/fast_math-inl.h"
29 #include "lib/jxl/fields.h"
30
31 namespace jxl {
32
33 namespace {
34
// When true, BuildAndStoreEntropyCodes skips histogram clustering (all
// contexts share one histogram) and uses a fixed HybridUintConfig(7, 0, 0).
bool ans_fuzzer_friendly_ = false;

// Capacity of the `symbols` array filled by NormalizeCounts: only the indices
// of the first this-many nonzero histogram entries are recorded.
constexpr int kMaxNumSymbolsForSmallCode = 4;
38
ANSBuildInfoTable(const ANSHistBin * counts,const AliasTable::Entry * table,size_t alphabet_size,size_t log_alpha_size,ANSEncSymbolInfo * info)39 void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table,
40 size_t alphabet_size, size_t log_alpha_size,
41 ANSEncSymbolInfo* info) {
42 size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size;
43 size_t entry_size_minus_1 = (1 << log_entry_size) - 1;
44 // create valid alias table for empty streams.
45 for (size_t s = 0; s < std::max<size_t>(1, alphabet_size); ++s) {
46 const ANSHistBin freq = s == alphabet_size ? ANS_TAB_SIZE : counts[s];
47 info[s].freq_ = static_cast<uint16_t>(freq);
48 #ifdef USE_MULT_BY_RECIPROCAL
49 if (freq != 0) {
50 info[s].ifreq_ =
51 ((1ull << RECIPROCAL_PRECISION) + info[s].freq_ - 1) / info[s].freq_;
52 } else {
53 info[s].ifreq_ = 1; // shouldn't matter (symbol shouldn't occur), but...
54 }
55 #endif
56 info[s].reverse_map_.resize(freq);
57 }
58 for (int i = 0; i < ANS_TAB_SIZE; i++) {
59 AliasTable::Symbol s =
60 AliasTable::Lookup(table, i, log_entry_size, entry_size_minus_1);
61 info[s.value].reverse_map_[s.offset] = i;
62 }
63 }
64
EstimateDataBits(const ANSHistBin * histogram,const ANSHistBin * counts,size_t len)65 float EstimateDataBits(const ANSHistBin* histogram, const ANSHistBin* counts,
66 size_t len) {
67 float sum = 0.0f;
68 int total_histogram = 0;
69 int total_counts = 0;
70 for (size_t i = 0; i < len; ++i) {
71 total_histogram += histogram[i];
72 total_counts += counts[i];
73 if (histogram[i] > 0) {
74 JXL_ASSERT(counts[i] > 0);
75 // += histogram[i] * -log(counts[i]/total_counts)
76 sum += histogram[i] *
77 std::max(0.0f, ANS_LOG_TAB_SIZE - FastLog2f(counts[i]));
78 }
79 }
80 if (total_histogram > 0) {
81 JXL_ASSERT(total_counts == ANS_TAB_SIZE);
82 }
83 return sum;
84 }
85
EstimateDataBitsFlat(const ANSHistBin * histogram,size_t len)86 float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) {
87 const float flat_bits = std::max(FastLog2f(len), 0.0f);
88 int total_histogram = 0;
89 for (size_t i = 0; i < len; ++i) {
90 total_histogram += histogram[i];
91 }
92 return total_histogram * flat_bits;
93 }
94
95 // Static Huffman code for encoding logcounts. The last symbol is used as RLE
96 // sequence.
97 static const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = {
98 5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
99 };
100 static const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = {
101 17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
102 };
103
104 // Returns the difference between largest count that can be represented and is
105 // smaller than "count" and smallest representable count larger than "count".
SmallestIncrement(uint32_t count,uint32_t shift)106 static int SmallestIncrement(uint32_t count, uint32_t shift) {
107 int bits = count == 0 ? -1 : FloorLog2Nonzero(count);
108 int drop_bits = bits - GetPopulationCountPrecision(bits, shift);
109 return drop_bits < 0 ? 1 : (1 << drop_bits);
110 }
111
// Turns real-valued `targets` for symbols [0, max_symbol) (already scaled so
// that they sum to about `table_size`) into integer `counts` summing exactly
// to `table_size`.
//
// - Targets in (0, 1) get count 1.
// - Targets >= 1 are scaled by `discount_ratio` (to make room for the forced
//   1s), truncated, and clamped to [1, table_size - 1]. They are then rounded
//   to a multiple of SmallestIncrement(count, shift) so EncodeCounts can
//   represent them exactly.
// - Symbols with a zero target are not written here.
//   NOTE(review): this relies on counts[n] already being 0 wherever
//   targets[n] == 0, which holds for NormalizeCounts — confirm for other
//   callers.
//
// Rounding: with minimize_error_of_sum == false each count is rounded against
// its own target. With true, it is rounded up whenever the running sum lags
// sum_nonrounded by more than half an increment.
//
// The first symbol with the highest floor(log2(count)) absorbs the remaining
// difference. Its index is returned in *omit_pos; EncodeCounts does not write
// that count explicitly. Returns false if the absorbing count ends up <= 0.
template <bool minimize_error_of_sum>
bool RebalanceHistogram(const float* targets, int max_symbol, int table_size,
                        uint32_t shift, int* omit_pos, ANSHistBin* counts) {
  int sum = 0;
  float sum_nonrounded = 0.0;
  int remainder_pos = 0;  // if all of them are handled in first loop
  int remainder_log = -1;
  // Pass 1: every symbol with a fractional target gets the minimum count 1.
  for (int n = 0; n < max_symbol; ++n) {
    if (targets[n] > 0 && targets[n] < 1.0f) {
      counts[n] = 1;
      sum_nonrounded += targets[n];
      sum += counts[n];
    }
  }
  // Fraction of the remaining table that the larger targets may use.
  const float discount_ratio =
      (table_size - sum) / (table_size - sum_nonrounded);
  JXL_ASSERT(discount_ratio > 0);
  JXL_ASSERT(discount_ratio <= 1.0f);
  // Invariant for minimize_error_of_sum == true:
  // abs(sum - sum_nonrounded)
  //   <= SmallestIncrement(max(targets[])) + max_symbol
  // Pass 2: round every target >= 1 to a representable count.
  for (int n = 0; n < max_symbol; ++n) {
    if (targets[n] >= 1.0f) {
      sum_nonrounded += targets[n];
      counts[n] =
          static_cast<ANSHistBin>(targets[n] * discount_ratio);  // truncate
      if (counts[n] == 0) counts[n] = 1;
      if (counts[n] == table_size) counts[n] = table_size - 1;
      // Round the count to the closest nonzero multiple of SmallestIncrement
      // (when minimize_error_of_sum is false) or one of two closest so as to
      // keep the sum as close as possible to sum_nonrounded.
      int inc = SmallestIncrement(counts[n], shift);
      counts[n] -= counts[n] & (inc - 1);
      // TODO(robryk): Should we rescale targets[n]?
      const float target =
          minimize_error_of_sum ? (sum_nonrounded - sum) : targets[n];
      if (counts[n] == 0 ||
          (target > counts[n] + inc / 2 && counts[n] + inc < table_size)) {
        counts[n] += inc;
      }
      sum += counts[n];
      // Track the first symbol with the largest floor(log2(count)); it will
      // absorb the rounding error below.
      const int count_log = FloorLog2Nonzero(static_cast<uint32_t>(counts[n]));
      if (count_log > remainder_log) {
        remainder_pos = n;
        remainder_log = count_log;
      }
    }
  }
  // NOTE(review): remainder_pos starts at 0 and is only ever assigned
  // non-negative indices, so this assertion cannot fail as written.
  JXL_ASSERT(remainder_pos != -1);
  // NOTE: This is the only place where counts could go negative. We could
  // detect that, return false and make ANSHistBin uint32_t.
  counts[remainder_pos] -= sum - table_size;
  *omit_pos = remainder_pos;
  return counts[remainder_pos] > 0;
}
167
NormalizeCounts(ANSHistBin * counts,int * omit_pos,const int length,const int precision_bits,uint32_t shift,int * num_symbols,int * symbols)168 Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length,
169 const int precision_bits, uint32_t shift,
170 int* num_symbols, int* symbols) {
171 const int32_t table_size = 1 << precision_bits; // target sum / table size
172 uint64_t total = 0;
173 int max_symbol = 0;
174 int symbol_count = 0;
175 for (int n = 0; n < length; ++n) {
176 total += counts[n];
177 if (counts[n] > 0) {
178 if (symbol_count < kMaxNumSymbolsForSmallCode) {
179 symbols[symbol_count] = n;
180 }
181 ++symbol_count;
182 max_symbol = n + 1;
183 }
184 }
185 *num_symbols = symbol_count;
186 if (symbol_count == 0) {
187 return true;
188 }
189 if (symbol_count == 1) {
190 counts[symbols[0]] = table_size;
191 return true;
192 }
193 if (symbol_count > table_size)
194 return JXL_FAILURE("Too many entries in an ANS histogram");
195
196 const float norm = 1.f * table_size / total;
197 std::vector<float> targets(max_symbol);
198 for (size_t n = 0; n < targets.size(); ++n) {
199 targets[n] = norm * counts[n];
200 }
201 if (!RebalanceHistogram<false>(&targets[0], max_symbol, table_size, shift,
202 omit_pos, counts)) {
203 // Use an alternative rebalancing mechanism if the one above failed
204 // to create a histogram that is positive wherever the original one was.
205 if (!RebalanceHistogram<true>(&targets[0], max_symbol, table_size, shift,
206 omit_pos, counts)) {
207 return JXL_FAILURE("Logic error: couldn't rebalance a histogram");
208 }
209 }
210 return true;
211 }
212
// Writer stand-in that only tallies how many bits would be written; lets the
// encoding routines estimate sizes without producing any output.
struct SizeWriter {
  size_t size = 0;  // total number of bits "written" so far
  void Write(size_t num_bits, size_t /*value*/) { size += num_bits; }
};
217
218 template <typename Writer>
StoreVarLenUint8(size_t n,Writer * writer)219 void StoreVarLenUint8(size_t n, Writer* writer) {
220 JXL_DASSERT(n <= 255);
221 if (n == 0) {
222 writer->Write(1, 0);
223 } else {
224 writer->Write(1, 1);
225 size_t nbits = FloorLog2Nonzero(n);
226 writer->Write(3, nbits);
227 writer->Write(nbits, n - (1ULL << nbits));
228 }
229 }
230
231 template <typename Writer>
StoreVarLenUint16(size_t n,Writer * writer)232 void StoreVarLenUint16(size_t n, Writer* writer) {
233 JXL_DASSERT(n <= 65535);
234 if (n == 0) {
235 writer->Write(1, 0);
236 } else {
237 writer->Write(1, 1);
238 size_t nbits = FloorLog2Nonzero(n);
239 writer->Write(4, nbits);
240 writer->Write(nbits, n - (1ULL << nbits));
241 }
242 }
243
// Writes the normalized histogram `counts` (alphabet_size entries) to
// `writer`. Returns false when the histogram needs more than 255 + 3 log-count
// entries; in that case the length written is wrong. Callers only estimating
// size may ignore this.
//
// Layouts:
// - num_symbols <= 2 ("small tree"):
//   - a 1 bit, then 1 bit holding num_symbols - 1 (0 for an empty histogram);
//   - the symbol indices as VarLenUint8 (a single 0 for the empty case);
//   - for two symbols, the first symbol's count in ANS_LOG_TAB_SIZE bits.
// - otherwise:
//   - a 0 bit (not small) and a 0 bit (not flat);
//   - `shift` in an Elias-gamma-like code;
//   - length - 3 as VarLenUint8;
//   - one static-Huffman-coded log-count per symbol, with RLE runs;
//   - the low bits of every count with more than one significant bit, except
//     the one at `omit_pos`, whose count is not written (presumably the
//     decoder recovers it from the others).
template <typename Writer>
bool EncodeCounts(const ANSHistBin* counts, const int alphabet_size,
                  const int omit_pos, const int num_symbols, uint32_t shift,
                  const int* symbols, Writer* writer) {
  bool ok = true;
  if (num_symbols <= 2) {
    // Small tree marker to encode 1-2 symbols.
    writer->Write(1, 1);
    if (num_symbols == 0) {
      writer->Write(1, 0);
      StoreVarLenUint8(0, writer);
    } else {
      writer->Write(1, num_symbols - 1);
      for (int i = 0; i < num_symbols; ++i) {
        StoreVarLenUint8(symbols[i], writer);
      }
    }
    if (num_symbols == 2) {
      writer->Write(ANS_LOG_TAB_SIZE, counts[symbols[0]]);
    }
  } else {
    // Mark non-small tree.
    writer->Write(1, 0);
    // Mark non-flat histogram.
    writer->Write(1, 0);

    // Precompute sequences for RLE encoding. Contains the number of identical
    // values starting at a given index. Only contains the value at the first
    // element of the series.
    // (The symbol that ends a run is not itself the start of the next run,
    // because `last` jumps to i + 1.)
    std::vector<uint32_t> same(alphabet_size, 0);
    int last = 0;
    for (int i = 1; i < alphabet_size; i++) {
      // Store the sequence length once different symbol reached, or we're at
      // the end, or the length is longer than we can encode, or we are at
      // the omit_pos. We don't support including the omit_pos in an RLE
      // sequence because this value may use a different amount of log2 bits
      // than standard, it is too complex to handle in the decoder.
      if (counts[i] != counts[last] || i + 1 == alphabet_size ||
          (i - last) >= 255 || i == omit_pos || i == omit_pos + 1) {
        same[last] = (i - last);
        last = i + 1;
      }
    }

    // logcounts[i] is 0 for a zero count and floor(log2(count)) + 1
    // otherwise. `length` is one past the last used symbol (omit_pos counts
    // as used).
    int length = 0;
    std::vector<int> logcounts(alphabet_size);
    int omit_log = 0;
    for (int i = 0; i < alphabet_size; ++i) {
      JXL_ASSERT(counts[i] <= ANS_TAB_SIZE);
      JXL_ASSERT(counts[i] >= 0);
      if (i == omit_pos) {
        length = i + 1;
      } else if (counts[i] > 0) {
        logcounts[i] = FloorLog2Nonzero(static_cast<uint32_t>(counts[i])) + 1;
        length = i + 1;
        if (i < omit_pos) {
          omit_log = std::max(omit_log, logcounts[i] + 1);
        } else {
          omit_log = std::max(omit_log, logcounts[i]);
        }
      }
    }
    // The value written for omit_pos is > every earlier log-count and >=
    // every later one. NOTE(review): presumably this lets the decoder locate
    // omit_pos as the first maximum — confirm against dec_ans.
    logcounts[omit_pos] = omit_log;

    // Elias gamma-like code for shift. Only difference is that if the number
    // of bits to be encoded is equal to FloorLog2(ANS_LOG_TAB_SIZE+1), we skip
    // the terminating 0 in unary coding.
    int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
    int log = FloorLog2Nonzero(shift + 1);
    writer->Write(log, (1 << log) - 1);
    if (log != upper_bound_log) writer->Write(1, 0);
    writer->Write(log, ((1 << log) - 1) & (shift + 1));

    // Since num_symbols >= 3, we know that length >= 3, therefore we encode
    // length - 3.
    if (length - 3 > 255) {
      // Pretend that everything is OK, but complain about correctness later.
      StoreVarLenUint8(255, writer);
      ok = false;
    } else {
      StoreVarLenUint8(length - 3, writer);
    }

    // The logcount values are encoded with a static Huffman code.
    static const size_t kMinReps = 4;
    size_t rep = ANS_LOG_TAB_SIZE + 1;
    for (int i = 0; i < length; ++i) {
      if (i > 0 && same[i - 1] > kMinReps) {
        // Encode the RLE symbol and skip the repeated ones.
        // Symbol i - 1 was already written individually; the RLE code covers
        // the remaining same[i - 1] - 1 symbols of the run.
        writer->Write(kLogCountBitLengths[rep], kLogCountSymbols[rep]);
        StoreVarLenUint8(same[i - 1] - kMinReps - 1, writer);
        i += same[i - 1] - 2;
        continue;
      }
      writer->Write(kLogCountBitLengths[logcounts[i]],
                    kLogCountSymbols[logcounts[i]]);
    }
    // Second pass: the low-order bits of each count, skipping the same RLE
    // runs as above.
    for (int i = 0; i < length; ++i) {
      if (i > 0 && same[i - 1] > kMinReps) {
        // Skip symbols encoded by RLE.
        i += same[i - 1] - 2;
        continue;
      }
      if (logcounts[i] > 1 && i != omit_pos) {
        // Only `bitcount` bits below the leading one are stored; the dropped
        // low bits must be zero (guaranteed by RebalanceHistogram rounding).
        int bitcount = GetPopulationCountPrecision(logcounts[i] - 1, shift);
        int drop_bits = logcounts[i] - 1 - bitcount;
        JXL_CHECK((counts[i] & ((1 << drop_bits) - 1)) == 0);
        writer->Write(bitcount, (counts[i] >> drop_bits) - (1 << bitcount));
      }
    }
  }
  return ok;
}
357
EncodeFlatHistogram(const int alphabet_size,BitWriter * writer)358 void EncodeFlatHistogram(const int alphabet_size, BitWriter* writer) {
359 // Mark non-small tree.
360 writer->Write(1, 0);
361 // Mark uniform histogram.
362 writer->Write(1, 1);
363 JXL_ASSERT(alphabet_size > 0);
364 // Encode alphabet size.
365 StoreVarLenUint8(alphabet_size - 1, writer);
366 }
367
ComputeHistoAndDataCost(const ANSHistBin * histogram,size_t alphabet_size,uint32_t method)368 float ComputeHistoAndDataCost(const ANSHistBin* histogram, size_t alphabet_size,
369 uint32_t method) {
370 if (method == 0) { // Flat code
371 return ANS_LOG_TAB_SIZE + 2 +
372 EstimateDataBitsFlat(histogram, alphabet_size);
373 }
374 // Non-flat: shift = method-1.
375 uint32_t shift = method - 1;
376 std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
377 int omit_pos = 0;
378 int num_symbols;
379 int symbols[kMaxNumSymbolsForSmallCode] = {};
380 JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
381 ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
382 SizeWriter writer;
383 // Ignore the correctness, no real encoding happens at this stage.
384 (void)EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols, shift,
385 symbols, &writer);
386 return writer.size +
387 EstimateDataBits(histogram, counts.data(), alphabet_size);
388 }
389
ComputeBestMethod(const ANSHistBin * histogram,size_t alphabet_size,float * cost,HistogramParams::ANSHistogramStrategy ans_histogram_strategy)390 uint32_t ComputeBestMethod(
391 const ANSHistBin* histogram, size_t alphabet_size, float* cost,
392 HistogramParams::ANSHistogramStrategy ans_histogram_strategy) {
393 size_t method = 0;
394 float fcost = ComputeHistoAndDataCost(histogram, alphabet_size, 0);
395 for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE;
396 ans_histogram_strategy != HistogramParams::ANSHistogramStrategy::kPrecise
397 ? shift += 2
398 : shift++) {
399 float c = ComputeHistoAndDataCost(histogram, alphabet_size, shift + 1);
400 if (c < fcost) {
401 method = shift + 1;
402 fcost = c;
403 } else if (ans_histogram_strategy ==
404 HistogramParams::ANSHistogramStrategy::kFast) {
405 // do not be as precise if estimating cost.
406 break;
407 }
408 }
409 *cost = fcost;
410 return method;
411 }
412
413 } // namespace
414
// Builds the encoder-side symbol tables `info` for `histogram` and, when
// `writer` is non-null, writes the histogram description to it.
//
// Returns an estimate of the cost in bits of encoding this histogram and the
// corresponding data:
// - With use_prefix_code: the bits spent on the Huffman tree plus
//   sum(histogram[i] * depth[i]). The tree bits are measured on `writer`, or
//   on a scratch BitWriter when `writer` is null. Alphabets of size <= 1
//   return 0 and leave `info` untouched.
// - Otherwise: the float cost of the method chosen by ComputeBestMethod,
//   truncated to size_t by the return type.
//
// NOTE(review): `info` is indexed up to max(1, alphabet_size) - 1; the visible
// caller sizes it accordingly — confirm for other callers.
size_t BuildAndStoreANSEncodingData(
    HistogramParams::ANSHistogramStrategy ans_histogram_strategy,
    const ANSHistBin* histogram, size_t alphabet_size, size_t log_alpha_size,
    bool use_prefix_code, ANSEncSymbolInfo* info, BitWriter* writer) {
  if (use_prefix_code) {
    if (alphabet_size <= 1) return 0;
    std::vector<uint32_t> histo(alphabet_size);
    // NOTE(review): `total` is accumulated but never read.
    size_t total = 0;
    for (size_t i = 0; i < alphabet_size; i++) {
      histo[i] = histogram[i];
      JXL_CHECK(histogram[i] >= 0);
      total += histo[i];
    }
    size_t cost = 0;
    {
      std::vector<uint8_t> depths(alphabet_size);
      std::vector<uint16_t> bits(alphabet_size);
      BitWriter tmp_writer;
      BitWriter* w = writer ? writer : &tmp_writer;
      size_t start = w->BitsWritten();
      BitWriter::Allotment allotment(
          w, 8 * alphabet_size + 8);  // safe upper bound
      BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(),
                               bits.data(), w);
      ReclaimAndCharge(w, &allotment, 0, /*aux_out=*/nullptr);

      for (size_t i = 0; i < alphabet_size; i++) {
        info[i].bits = depths[i] == 0 ? 0 : bits[i];
        info[i].depth = depths[i];
      }
      // Header cost: exactly the bits the tree description took.
      cost = w->BitsWritten() - start;
    }
    // Estimate data cost.
    for (size_t i = 0; i < alphabet_size; i++) {
      cost += histogram[i] * info[i].depth;
    }
    return cost;
  }
  JXL_ASSERT(alphabet_size <= ANS_TAB_SIZE);
  // Ensure we ignore trailing zeros in the histogram.
  // (An all-zero histogram ends up with alphabet_size == 1.)
  if (alphabet_size != 0) {
    size_t largest_symbol = 0;
    for (size_t i = 0; i < alphabet_size; i++) {
      if (histogram[i] != 0) largest_symbol = i;
    }
    alphabet_size = largest_symbol + 1;
  }
  float cost;
  uint32_t method = ComputeBestMethod(histogram, alphabet_size, &cost,
                                      ans_histogram_strategy);
  JXL_ASSERT(cost >= 0);
  int num_symbols;
  int symbols[kMaxNumSymbolsForSmallCode] = {};
  std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
  // An all-zero histogram still needs a valid table: give symbol 0 all of it.
  if (!counts.empty()) {
    size_t sum = 0;
    for (size_t i = 0; i < counts.size(); i++) {
      sum += counts[i];
    }
    if (sum == 0) {
      counts[0] = ANS_TAB_SIZE;
    }
  }
  if (method == 0) {
    // Flat histogram: every symbol gets an equal share of the table.
    counts = CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE);
    AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
    InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
    ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
    if (writer != nullptr) {
      EncodeFlatHistogram(alphabet_size, writer);
    }
    return cost;
  }
  // Normalized histogram with shift = method - 1.
  int omit_pos = 0;
  uint32_t shift = method - 1;
  JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
                            ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
  AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
  InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
  ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
  if (writer != nullptr) {
    bool ok = EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols,
                           shift, symbols, writer);
    (void)ok;
    JXL_DASSERT(ok);
  }
  return cost;
}
505
ANSPopulationCost(const ANSHistBin * data,size_t alphabet_size)506 float ANSPopulationCost(const ANSHistBin* data, size_t alphabet_size) {
507 float c;
508 ComputeBestMethod(data, alphabet_size, &c,
509 HistogramParams::ANSHistogramStrategy::kFast);
510 return c;
511 }
512
513 template <typename Writer>
EncodeUintConfig(const HybridUintConfig uint_config,Writer * writer,size_t log_alpha_size)514 void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer,
515 size_t log_alpha_size) {
516 writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
517 uint_config.split_exponent);
518 if (uint_config.split_exponent == log_alpha_size) {
519 return; // msb/lsb don't matter.
520 }
521 size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
522 writer->Write(nbits, uint_config.msb_in_token);
523 nbits = CeilLog2Nonzero(uint_config.split_exponent -
524 uint_config.msb_in_token + 1);
525 writer->Write(nbits, uint_config.lsb_in_token);
526 }
527 template <typename Writer>
EncodeUintConfigs(const std::vector<HybridUintConfig> & uint_config,Writer * writer,size_t log_alpha_size)528 void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config,
529 Writer* writer, size_t log_alpha_size) {
530 // TODO(veluca): RLE?
531 for (size_t i = 0; i < uint_config.size(); i++) {
532 EncodeUintConfig(uint_config[i], writer, log_alpha_size);
533 }
534 }
535 template void EncodeUintConfigs(const std::vector<HybridUintConfig>&,
536 BitWriter*, size_t);
537
538 namespace {
539
// Chooses a HybridUintConfig for every clustered histogram and rebuilds
// `clustered_histograms` from the tokens those configs produce.
//
// - kNone: codes->uint_config is only resized to one default config per
//   histogram; the histograms and *log_alpha_size are left untouched.
// - kContextMap: every histogram uses HybridUintConfig(2, 0, 1); the
//   histograms and *log_alpha_size are left untouched.
// - kBest / kFast: each candidate config is tried, and for every histogram
//   the one with the lowest PopulationCost() + extra raw bits wins.
//   - A config is rejected for a histogram if it yields a token >= the
//     alphabet limit, or one that would collide with the LZ77 length symbols.
//   - LZ77 length tokens are ignored while choosing.
//   - Any other method value reaches the rebuild with an empty candidate list.
// - Rebuild: histograms are recomputed with the chosen configs (LZ77 length
//   tokens included, offset by lz77.min_symbol). *log_alpha_size is set to the
//   smallest value >= 4 with every token < 1 << *log_alpha_size.
void ChooseUintConfigs(const HistogramParams& params,
                       const std::vector<std::vector<Token>>& tokens,
                       const std::vector<uint8_t>& context_map,
                       std::vector<Histogram>* clustered_histograms,
                       EntropyEncodingData* codes, size_t* log_alpha_size) {
  codes->uint_config.resize(clustered_histograms->size());
  if (params.uint_method == HistogramParams::HybridUintMethod::kNone) return;
  if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
    codes->uint_config.clear();
    codes->uint_config.resize(clustered_histograms->size(),
                              HybridUintConfig(2, 0, 1));
    return;
  }

  // Brute-force method that tries a few options.
  std::vector<HybridUintConfig> configs;
  if (params.uint_method == HistogramParams::HybridUintMethod::kBest) {
    configs = {
        HybridUintConfig(4, 2, 0),  // default
        HybridUintConfig(4, 1, 0),  // less precise
        HybridUintConfig(4, 2, 1),  // add sign
        HybridUintConfig(4, 2, 2),  // add sign+parity
        HybridUintConfig(4, 1, 2),  // add parity but less msb
        // Same as above, but more direct coding.
        HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0),
        HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2),
        HybridUintConfig(5, 1, 2),
        // Same as above, but less direct coding.
        HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0),
        HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2),
        // For near-lossless.
        HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4),
        HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5),
        HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0),
        // Other
        HybridUintConfig(0, 0, 0),   // varlenuint
        HybridUintConfig(2, 0, 1),   // works well for ctx map
        HybridUintConfig(7, 0, 0),   // direct coding
        HybridUintConfig(8, 0, 0),   // direct coding
        HybridUintConfig(9, 0, 0),   // direct coding
        HybridUintConfig(10, 0, 0),  // direct coding
        HybridUintConfig(11, 0, 0),  // direct coding
        HybridUintConfig(12, 0, 0),  // direct coding
    };
  } else if (params.uint_method == HistogramParams::HybridUintMethod::kFast) {
    configs = {
        HybridUintConfig(4, 2, 0),  // default
        HybridUintConfig(4, 1, 2),  // add parity but less msb
        HybridUintConfig(0, 0, 0),  // smallest histograms
        HybridUintConfig(2, 0, 1),  // works well for ctx map
    };
  }

  // Best cost found so far per histogram (starts at +inf-like max).
  std::vector<float> costs(clustered_histograms->size(),
                           std::numeric_limits<float>::max());
  std::vector<uint32_t> extra_bits(clustered_histograms->size());
  std::vector<uint8_t> is_valid(clustered_histograms->size());
  size_t max_alpha =
      codes->use_prefix_code ? PREFIX_MAX_ALPHABET_SIZE : ANS_MAX_ALPHABET_SIZE;
  for (HybridUintConfig cfg : configs) {
    std::fill(is_valid.begin(), is_valid.end(), true);
    std::fill(extra_bits.begin(), extra_bits.end(), 0);

    for (size_t i = 0; i < clustered_histograms->size(); i++) {
      (*clustered_histograms)[i].Clear();
    }
    for (size_t i = 0; i < tokens.size(); ++i) {
      for (size_t j = 0; j < tokens[i].size(); ++j) {
        const Token token = tokens[i][j];
        // TODO(veluca): do not ignore lz77 commands.
        if (token.is_lz77_length) continue;
        size_t histo = context_map[token.context];
        uint32_t tok, nbits, bits;
        cfg.Encode(token.value, &tok, &nbits, &bits);
        // Reject configs whose tokens exceed the alphabet or would overlap
        // the symbol range reserved for LZ77 lengths.
        if (tok >= max_alpha ||
            (codes->lz77.enabled && tok >= codes->lz77.min_symbol)) {
          is_valid[histo] = false;
          continue;
        }
        extra_bits[histo] += nbits;
        (*clustered_histograms)[histo].Add(tok);
      }
    }

    for (size_t i = 0; i < clustered_histograms->size(); i++) {
      if (!is_valid[i]) continue;
      float cost = (*clustered_histograms)[i].PopulationCost() + extra_bits[i];
      if (cost < costs[i]) {
        codes->uint_config[i] = cfg;
        costs[i] = cost;
      }
    }
  }

  // Rebuild histograms.
  for (size_t i = 0; i < clustered_histograms->size(); i++) {
    (*clustered_histograms)[i].Clear();
  }
  *log_alpha_size = 4;
  for (size_t i = 0; i < tokens.size(); ++i) {
    for (size_t j = 0; j < tokens[i].size(); ++j) {
      const Token token = tokens[i][j];
      uint32_t tok, nbits, bits;
      size_t histo = context_map[token.context];
      (token.is_lz77_length ? codes->lz77.length_uint_config
                            : codes->uint_config[histo])
          .Encode(token.value, &tok, &nbits, &bits);
      tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
      (*clustered_histograms)[histo].Add(tok);
      // Grow the alphabet bit width until it can hold this token.
      while (tok >= (1u << *log_alpha_size)) (*log_alpha_size)++;
    }
  }
#if JXL_ENABLE_ASSERT
  size_t max_log_alpha_size = codes->use_prefix_code ? PREFIX_MAX_BITS : 8;
  JXL_ASSERT(*log_alpha_size <= max_log_alpha_size);
#endif
}
657
658 class HistogramBuilder {
659 public:
HistogramBuilder(const size_t num_contexts)660 explicit HistogramBuilder(const size_t num_contexts)
661 : histograms_(num_contexts) {}
662
VisitSymbol(int symbol,size_t histo_idx)663 void VisitSymbol(int symbol, size_t histo_idx) {
664 JXL_DASSERT(histo_idx < histograms_.size());
665 histograms_[histo_idx].Add(symbol);
666 }
667
668 // NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge.
BuildAndStoreEntropyCodes(const HistogramParams & params,const std::vector<std::vector<Token>> & tokens,EntropyEncodingData * codes,std::vector<uint8_t> * context_map,bool use_prefix_code,BitWriter * writer,size_t layer,AuxOut * aux_out) const669 size_t BuildAndStoreEntropyCodes(
670 const HistogramParams& params,
671 const std::vector<std::vector<Token>>& tokens, EntropyEncodingData* codes,
672 std::vector<uint8_t>* context_map, bool use_prefix_code,
673 BitWriter* writer, size_t layer, AuxOut* aux_out) const {
674 size_t cost = 0;
675 codes->encoding_info.clear();
676 std::vector<Histogram> clustered_histograms(histograms_);
677 context_map->resize(histograms_.size());
678 if (histograms_.size() > 1) {
679 if (!ans_fuzzer_friendly_) {
680 std::vector<uint32_t> histogram_symbols;
681 ClusterHistograms(params, histograms_, histograms_.size(),
682 kClustersLimit, &clustered_histograms,
683 &histogram_symbols);
684 for (size_t c = 0; c < histograms_.size(); ++c) {
685 (*context_map)[c] = static_cast<uint8_t>(histogram_symbols[c]);
686 }
687 } else {
688 fill(context_map->begin(), context_map->end(), 0);
689 size_t max_symbol = 0;
690 for (const Histogram& h : histograms_) {
691 max_symbol = std::max(h.data_.size(), max_symbol);
692 }
693 size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1);
694 clustered_histograms.resize(1);
695 clustered_histograms[0].Clear();
696 for (size_t i = 0; i < num_symbols; i++) {
697 clustered_histograms[0].Add(i);
698 }
699 }
700 if (writer != nullptr) {
701 EncodeContextMap(*context_map, clustered_histograms.size(), writer);
702 }
703 }
704 if (aux_out != nullptr) {
705 for (size_t i = 0; i < clustered_histograms.size(); ++i) {
706 aux_out->layers[layer].clustered_entropy +=
707 clustered_histograms[i].ShannonEntropy();
708 }
709 }
710 codes->use_prefix_code = use_prefix_code;
711 size_t log_alpha_size = codes->lz77.enabled ? 8 : 7; // Sane default.
712 if (ans_fuzzer_friendly_) {
713 codes->uint_config.clear();
714 codes->uint_config.resize(1, HybridUintConfig(7, 0, 0));
715 } else {
716 ChooseUintConfigs(params, tokens, *context_map, &clustered_histograms,
717 codes, &log_alpha_size);
718 }
719 if (log_alpha_size < 5) log_alpha_size = 5;
720 SizeWriter size_writer; // Used if writer == nullptr to estimate costs.
721 cost += 1;
722 if (writer) writer->Write(1, use_prefix_code);
723
724 if (use_prefix_code) {
725 log_alpha_size = PREFIX_MAX_BITS;
726 } else {
727 cost += 2;
728 }
729 if (writer == nullptr) {
730 EncodeUintConfigs(codes->uint_config, &size_writer, log_alpha_size);
731 } else {
732 if (!use_prefix_code) writer->Write(2, log_alpha_size - 5);
733 EncodeUintConfigs(codes->uint_config, writer, log_alpha_size);
734 }
735 if (use_prefix_code) {
736 for (size_t c = 0; c < clustered_histograms.size(); ++c) {
737 size_t num_symbol = 1;
738 for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
739 if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
740 }
741 if (writer) {
742 StoreVarLenUint16(num_symbol - 1, writer);
743 } else {
744 StoreVarLenUint16(num_symbol - 1, &size_writer);
745 }
746 }
747 }
748 cost += size_writer.size;
749 for (size_t c = 0; c < clustered_histograms.size(); ++c) {
750 size_t num_symbol = 1;
751 for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
752 if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
753 }
754 codes->encoding_info.emplace_back();
755 codes->encoding_info.back().resize(std::max<size_t>(1, num_symbol));
756
757 BitWriter::Allotment allotment(writer, 256 + num_symbol * 24);
758 cost += BuildAndStoreANSEncodingData(
759 params.ans_histogram_strategy, clustered_histograms[c].data_.data(),
760 num_symbol, log_alpha_size, use_prefix_code,
761 codes->encoding_info.back().data(), writer);
762 allotment.FinishedHistogram(writer);
763 ReclaimAndCharge(writer, &allotment, layer, aux_out);
764 }
765 return cost;
766 }
767
Histo(size_t i) const768 const Histogram& Histo(size_t i) const { return histograms_[i]; }
769
770 private:
771 std::vector<Histogram> histograms_;
772 };
773
774 class SymbolCostEstimator {
775 public:
SymbolCostEstimator(size_t num_contexts,bool force_huffman,const std::vector<std::vector<Token>> & tokens,const LZ77Params & lz77)776 SymbolCostEstimator(size_t num_contexts, bool force_huffman,
777 const std::vector<std::vector<Token>>& tokens,
778 const LZ77Params& lz77) {
779 HistogramBuilder builder(num_contexts);
780 // Build histograms for estimating lz77 savings.
781 HybridUintConfig uint_config;
782 for (size_t i = 0; i < tokens.size(); ++i) {
783 for (size_t j = 0; j < tokens[i].size(); ++j) {
784 const Token token = tokens[i][j];
785 uint32_t tok, nbits, bits;
786 (token.is_lz77_length ? lz77.length_uint_config : uint_config)
787 .Encode(token.value, &tok, &nbits, &bits);
788 tok += token.is_lz77_length ? lz77.min_symbol : 0;
789 builder.VisitSymbol(tok, token.context);
790 }
791 }
792 max_alphabet_size_ = 0;
793 for (size_t i = 0; i < num_contexts; i++) {
794 max_alphabet_size_ =
795 std::max(max_alphabet_size_, builder.Histo(i).data_.size());
796 }
797 bits_.resize(num_contexts * max_alphabet_size_);
798 // TODO(veluca): SIMD?
799 add_symbol_cost_.resize(num_contexts);
800 for (size_t i = 0; i < num_contexts; i++) {
801 float inv_total = 1.0f / (builder.Histo(i).total_count_ + 1e-8f);
802 float total_cost = 0;
803 for (size_t j = 0; j < builder.Histo(i).data_.size(); j++) {
804 size_t cnt = builder.Histo(i).data_[j];
805 float cost = 0;
806 if (cnt != 0 && cnt != builder.Histo(i).total_count_) {
807 cost = -FastLog2f(cnt * inv_total);
808 if (force_huffman) cost = std::ceil(cost);
809 } else if (cnt == 0) {
810 cost = ANS_LOG_TAB_SIZE; // Highest possible cost.
811 }
812 bits_[i * max_alphabet_size_ + j] = cost;
813 total_cost += cost * builder.Histo(i).data_[j];
814 }
815 // Penalty for adding a lz77 symbol to this contest (only used for static
816 // cost model). Higher penalty for contexts that have a very low
817 // per-symbol entropy.
818 add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total);
819 }
820 }
Bits(size_t ctx,size_t sym) const821 float Bits(size_t ctx, size_t sym) const {
822 return bits_[ctx * max_alphabet_size_ + sym];
823 }
LenCost(size_t ctx,size_t len,const LZ77Params & lz77) const824 float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const {
825 uint32_t nbits, bits, tok;
826 lz77.length_uint_config.Encode(len, &tok, &nbits, &bits);
827 tok += lz77.min_symbol;
828 return nbits + Bits(ctx, tok);
829 }
DistCost(size_t len,const LZ77Params & lz77) const830 float DistCost(size_t len, const LZ77Params& lz77) const {
831 uint32_t nbits, bits, tok;
832 HybridUintConfig().Encode(len, &tok, &nbits, &bits);
833 return nbits + Bits(lz77.nonserialized_distance_context, tok);
834 }
AddSymbolCost(size_t idx) const835 float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; }
836
837 private:
838 size_t max_alphabet_size_;
839 std::vector<float> bits_;
840 std::vector<float> add_symbol_cost_;
841 };
842
ApplyLZ77_RLE(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)843 void ApplyLZ77_RLE(const HistogramParams& params, size_t num_contexts,
844 const std::vector<std::vector<Token>>& tokens,
845 LZ77Params& lz77,
846 std::vector<std::vector<Token>>& tokens_lz77) {
847 // TODO(veluca): tune heuristics here.
848 SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
849 float bit_decrease = 0;
850 size_t total_symbols = 0;
851 tokens_lz77.resize(tokens.size());
852 std::vector<float> sym_cost;
853 HybridUintConfig uint_config;
854 for (size_t stream = 0; stream < tokens.size(); stream++) {
855 size_t distance_multiplier =
856 params.image_widths.size() > stream ? params.image_widths[stream] : 0;
857 const auto& in = tokens[stream];
858 auto& out = tokens_lz77[stream];
859 total_symbols += in.size();
860 // Cumulative sum of bit costs.
861 sym_cost.resize(in.size() + 1);
862 for (size_t i = 0; i < in.size(); i++) {
863 uint32_t tok, nbits, unused_bits;
864 uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
865 sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
866 }
867 out.reserve(in.size());
868 for (size_t i = 0; i < in.size(); i++) {
869 size_t num_to_copy = 0;
870 size_t distance_symbol = 0; // 1 for RLE.
871 if (distance_multiplier != 0) {
872 distance_symbol = 1; // Special distance 1 if enabled.
873 JXL_DASSERT(kSpecialDistances[1][0] == 1);
874 JXL_DASSERT(kSpecialDistances[1][1] == 0);
875 }
876 if (i > 0) {
877 for (; i + num_to_copy < in.size(); num_to_copy++) {
878 if (in[i + num_to_copy].value != in[i - 1].value) {
879 break;
880 }
881 }
882 }
883 if (num_to_copy == 0) {
884 out.push_back(in[i]);
885 continue;
886 }
887 float cost = sym_cost[i + num_to_copy] - sym_cost[i];
888 // This subtraction might overflow, but that's OK.
889 size_t lz77_len = num_to_copy - lz77.min_length;
890 float lz77_cost = num_to_copy >= lz77.min_length
891 ? CeilLog2Nonzero(lz77_len + 1) + 1
892 : 0;
893 if (num_to_copy < lz77.min_length || cost <= lz77_cost) {
894 for (size_t j = 0; j < num_to_copy; j++) {
895 out.push_back(in[i + j]);
896 }
897 i += num_to_copy - 1;
898 continue;
899 }
900 // Output the LZ77 length
901 out.emplace_back(in[i].context, lz77_len);
902 out.back().is_lz77_length = true;
903 i += num_to_copy - 1;
904 bit_decrease += cost - lz77_cost;
905 // Output the LZ77 copy distance.
906 out.emplace_back(lz77.nonserialized_distance_context, distance_symbol);
907 }
908 }
909
910 if (bit_decrease > total_symbols * 0.2 + 16) {
911 lz77.enabled = true;
912 }
913 }
914
915 // Hash chain for LZ77 matching
916 struct HashChain {
917 size_t size_;
918 std::vector<uint32_t> data_;
919
920 unsigned hash_num_values_ = 32768;
921 unsigned hash_mask_ = hash_num_values_ - 1;
922 unsigned hash_shift_ = 5;
923
924 std::vector<int> head;
925 std::vector<uint32_t> chain;
926 std::vector<int> val;
927
928 // Speed up repetitions of zero
929 std::vector<int> headz;
930 std::vector<uint32_t> chainz;
931 std::vector<uint32_t> zeros;
932 uint32_t numzeros = 0;
933
934 size_t window_size_;
935 size_t window_mask_;
936 size_t min_length_;
937 size_t max_length_;
938
939 // Map of special distance codes.
940 std::unordered_map<int, int> special_dist_table_;
941 size_t num_special_distances_ = 0;
942
943 uint32_t maxchainlength = 256; // window_size_ to allow all
944
HashChainjxl::__anon30e99be10211::HashChain945 HashChain(const Token* data, size_t size, size_t window_size,
946 size_t min_length, size_t max_length, size_t distance_multiplier)
947 : size_(size),
948 window_size_(window_size),
949 window_mask_(window_size - 1),
950 min_length_(min_length),
951 max_length_(max_length) {
952 data_.resize(size);
953 for (size_t i = 0; i < size; i++) {
954 data_[i] = data[i].value;
955 }
956
957 head.resize(hash_num_values_, -1);
958 val.resize(window_size_, -1);
959 chain.resize(window_size_);
960 for (uint32_t i = 0; i < window_size_; ++i) {
961 chain[i] = i; // same value as index indicates uninitialized
962 }
963
964 zeros.resize(window_size_);
965 headz.resize(window_size_ + 1, -1);
966 chainz.resize(window_size_);
967 for (uint32_t i = 0; i < window_size_; ++i) {
968 chainz[i] = i;
969 }
970 // Translate distance to special distance code.
971 if (distance_multiplier) {
972 // Count down, so if due to small distance multiplier multiple distances
973 // map to the same code, the smallest code will be used in the end.
974 for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
975 int xi = kSpecialDistances[i][0];
976 int yi = kSpecialDistances[i][1];
977 int distance = yi * distance_multiplier + xi;
978 // Ensure that we map distance 1 to the lowest symbols.
979 if (distance < 1) distance = 1;
980 special_dist_table_[distance] = i;
981 }
982 num_special_distances_ = kNumSpecialDistances;
983 }
984 }
985
GetHashjxl::__anon30e99be10211::HashChain986 uint32_t GetHash(size_t pos) const {
987 uint32_t result = 0;
988 if (pos + 2 < size_) {
989 // TODO(lode): take the MSB's of the uint32_t values into account as well,
990 // given that the hash code itself is less than 32 bits.
991 result ^= (uint32_t)(data_[pos + 0] << 0u);
992 result ^= (uint32_t)(data_[pos + 1] << hash_shift_);
993 result ^= (uint32_t)(data_[pos + 2] << (hash_shift_ * 2));
994 } else {
995 // No need to compute hash of last 2 bytes, the length 2 is too short.
996 return 0;
997 }
998 return result & hash_mask_;
999 }
1000
CountZerosjxl::__anon30e99be10211::HashChain1001 uint32_t CountZeros(size_t pos, uint32_t prevzeros) const {
1002 size_t end = pos + window_size_;
1003 if (end > size_) end = size_;
1004 if (prevzeros > 0) {
1005 if (prevzeros >= window_mask_ && data_[end - 1] == 0 &&
1006 end == pos + window_size_) {
1007 return prevzeros;
1008 } else {
1009 return prevzeros - 1;
1010 }
1011 }
1012 uint32_t num = 0;
1013 while (pos + num < end && data_[pos + num] == 0) num++;
1014 return num;
1015 }
1016
Updatejxl::__anon30e99be10211::HashChain1017 void Update(size_t pos) {
1018 uint32_t hashval = GetHash(pos);
1019 uint32_t wpos = pos & window_mask_;
1020
1021 val[wpos] = (int)hashval;
1022 if (head[hashval] != -1) chain[wpos] = head[hashval];
1023 head[hashval] = wpos;
1024
1025 if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0;
1026 numzeros = CountZeros(pos, numzeros);
1027
1028 zeros[wpos] = numzeros;
1029 if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros];
1030 headz[numzeros] = wpos;
1031 }
1032
Updatejxl::__anon30e99be10211::HashChain1033 void Update(size_t pos, size_t len) {
1034 for (size_t i = 0; i < len; i++) {
1035 Update(pos + i);
1036 }
1037 }
1038
1039 template <typename CB>
FindMatchesjxl::__anon30e99be10211::HashChain1040 void FindMatches(size_t pos, int max_dist, const CB& found_match) const {
1041 uint32_t wpos = pos & window_mask_;
1042 uint32_t hashval = GetHash(pos);
1043 uint32_t hashpos = chain[wpos];
1044
1045 int prev_dist = 0;
1046 int end = std::min<int>(pos + max_length_, size_);
1047 uint32_t chainlength = 0;
1048 uint32_t best_len = 0;
1049 for (;;) {
1050 int dist = (hashpos <= wpos) ? (wpos - hashpos)
1051 : (wpos - hashpos + window_mask_ + 1);
1052 if (dist < prev_dist) break;
1053 prev_dist = dist;
1054 uint32_t len = 0;
1055 if (dist > 0) {
1056 int i = pos;
1057 int j = pos - dist;
1058 if (numzeros > 3) {
1059 int r = std::min<int>(numzeros - 1, zeros[hashpos]);
1060 if (i + r >= end) r = end - i - 1;
1061 i += r;
1062 j += r;
1063 }
1064 while (i < end && data_[i] == data_[j]) {
1065 i++;
1066 j++;
1067 }
1068 len = i - pos;
1069 // This can trigger even if the new length is slightly smaller than the
1070 // best length, because it is possible for a slightly cheaper distance
1071 // symbol to occur.
1072 if (len >= min_length_ && len + 2 >= best_len) {
1073 auto it = special_dist_table_.find(dist);
1074 int dist_symbol = (it == special_dist_table_.end())
1075 ? (num_special_distances_ + dist - 1)
1076 : it->second;
1077 found_match(len, dist_symbol);
1078 if (len > best_len) best_len = len;
1079 }
1080 }
1081
1082 chainlength++;
1083 if (chainlength >= maxchainlength) break;
1084
1085 if (numzeros >= 3 && len > numzeros) {
1086 if (hashpos == chainz[hashpos]) break;
1087 hashpos = chainz[hashpos];
1088 if (zeros[hashpos] != numzeros) break;
1089 } else {
1090 if (hashpos == chain[hashpos]) break;
1091 hashpos = chain[hashpos];
1092 if (val[hashpos] != (int)hashval) break; // outdated hash value
1093 }
1094 }
1095 }
FindMatchjxl::__anon30e99be10211::HashChain1096 void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol,
1097 size_t* result_len) const {
1098 *result_dist_symbol = 0;
1099 *result_len = 1;
1100 FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) {
1101 if (len > *result_len ||
1102 (len == *result_len && *result_dist_symbol > dist_symbol)) {
1103 *result_len = len;
1104 *result_dist_symbol = dist_symbol;
1105 }
1106 });
1107 }
1108 };
1109
LenCost(size_t len)1110 float LenCost(size_t len) {
1111 uint32_t nbits, bits, tok;
1112 HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits);
1113 constexpr float kCostTable[] = {
1114 2.797667318563126, 3.213177690381199, 2.5706009246743737,
1115 2.408392498667534, 2.829649191872326, 3.3923087753324577,
1116 4.029267451554331, 4.415576699706408, 4.509357574741465,
1117 9.21481543803004, 10.020590190114898, 11.858671627804766,
1118 12.45853300490526, 11.713105831990857, 12.561996324849314,
1119 13.775477692278367, 13.174027068768641,
1120 };
1121 size_t table_size = sizeof kCostTable / sizeof *kCostTable;
1122 if (tok >= table_size) tok = table_size - 1;
1123 return kCostTable[tok] + nbits;
1124 }
1125
1126 // TODO(veluca): this does not take into account usage or non-usage of distance
1127 // multipliers.
DistCost(size_t dist)1128 float DistCost(size_t dist) {
1129 uint32_t nbits, bits, tok;
1130 HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits);
1131 constexpr float kCostTable[] = {
1132 6.368282626312716, 5.680793277090298, 8.347404197105247,
1133 7.641619201599141, 6.914328374119438, 7.959808291537444,
1134 8.70023120759855, 8.71378518934703, 9.379132523982769,
1135 9.110472749092708, 9.159029569270908, 9.430936766731973,
1136 7.278284055315169, 7.8278514904267755, 10.026641158289236,
1137 9.976049229827066, 9.64351607048908, 9.563403863480442,
1138 10.171474111762747, 10.45950155077234, 9.994813912104219,
1139 10.322524683741156, 8.465808729388186, 8.756254166066853,
1140 10.160930174662234, 10.247329273413435, 10.04090403724809,
1141 10.129398517544082, 9.342311691539546, 9.07608009102374,
1142 10.104799540677513, 10.378079384990906, 10.165828974075072,
1143 10.337595322341553, 7.940557464567944, 10.575665823319431,
1144 11.023344321751955, 10.736144698831827, 11.118277044595054,
1145 7.468468230648442, 10.738305230932939, 10.906980780216568,
1146 10.163468216353817, 10.17805759656433, 11.167283670483565,
1147 11.147050200274544, 10.517921919244333, 10.651764778156886,
1148 10.17074446448919, 11.217636876224745, 11.261630721139484,
1149 11.403140815247259, 10.892472096873417, 11.1859607804481,
1150 8.017346947551262, 7.895143720278828, 11.036577113822025,
1151 11.170562110315794, 10.326988722591086, 10.40872184751056,
1152 11.213498225466386, 11.30580635516863, 10.672272515665442,
1153 10.768069466228063, 11.145257364153565, 11.64668307145549,
1154 10.593156194627339, 11.207499484844943, 10.767517766396908,
1155 10.826629811407042, 10.737764794499988, 10.6200448518045,
1156 10.191315385198092, 8.468384171390085, 11.731295299170432,
1157 11.824619886654398, 10.41518844301179, 10.16310536548649,
1158 10.539423685097576, 10.495136599328031, 10.469112847728267,
1159 11.72057686174922, 10.910326337834674, 11.378921834673758,
1160 11.847759036098536, 11.92071647623854, 10.810628276345282,
1161 11.008601085273893, 11.910326337834674, 11.949212023423133,
1162 11.298614839104337, 11.611603659010392, 10.472930394619985,
1163 11.835564720850282, 11.523267392285337, 12.01055816679611,
1164 8.413029688994023, 11.895784139536406, 11.984679534970505,
1165 11.220654278717394, 11.716311684833672, 10.61036646226114,
1166 10.89849965960364, 10.203762898863669, 10.997560826267238,
1167 11.484217379438984, 11.792836176993665, 12.24310468755171,
1168 11.464858097919262, 12.212747017409377, 11.425595666074955,
1169 11.572048533398757, 12.742093965163013, 11.381874288645637,
1170 12.191870445817015, 11.683156920035426, 11.152442115262197,
1171 11.90303691580457, 11.653292787169159, 11.938615382266098,
1172 16.970641701570223, 16.853602280380002, 17.26240782594733,
1173 16.644655390108507, 17.14310889757499, 16.910935455445955,
1174 17.505678976959697, 17.213498225466388, 2.4162310293553024,
1175 3.494587244462329, 3.5258600986408344, 3.4959806589517095,
1176 3.098390886949687, 3.343454654302911, 3.588847442290287,
1177 4.14614790111827, 5.152948641990529, 7.433696808092598,
1178 9.716311684833672,
1179 };
1180 size_t table_size = sizeof kCostTable / sizeof *kCostTable;
1181 if (tok >= table_size) tok = table_size - 1;
1182 return kCostTable[tok] + nbits;
1183 }
1184
ApplyLZ77_LZ77(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)1185 void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts,
1186 const std::vector<std::vector<Token>>& tokens,
1187 LZ77Params& lz77,
1188 std::vector<std::vector<Token>>& tokens_lz77) {
1189 // TODO(veluca): tune heuristics here.
1190 SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
1191 float bit_decrease = 0;
1192 size_t total_symbols = 0;
1193 tokens_lz77.resize(tokens.size());
1194 HybridUintConfig uint_config;
1195 std::vector<float> sym_cost;
1196 for (size_t stream = 0; stream < tokens.size(); stream++) {
1197 size_t distance_multiplier =
1198 params.image_widths.size() > stream ? params.image_widths[stream] : 0;
1199 const auto& in = tokens[stream];
1200 auto& out = tokens_lz77[stream];
1201 total_symbols += in.size();
1202 // Cumulative sum of bit costs.
1203 sym_cost.resize(in.size() + 1);
1204 for (size_t i = 0; i < in.size(); i++) {
1205 uint32_t tok, nbits, unused_bits;
1206 uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
1207 sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
1208 }
1209
1210 out.reserve(in.size());
1211 size_t max_distance = in.size();
1212 size_t min_length = lz77.min_length;
1213 JXL_ASSERT(min_length >= 3);
1214 size_t max_length = in.size();
1215
1216 // Use next power of two as window size.
1217 size_t window_size = 1;
1218 while (window_size < max_distance && window_size < kWindowSize) {
1219 window_size <<= 1;
1220 }
1221
1222 HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
1223 distance_multiplier);
1224 size_t len, dist_symbol;
1225
1226 const size_t max_lazy_match_len = 256; // 0 to disable lazy matching
1227
1228 // Whether the next symbol was already updated (to test lazy matching)
1229 bool already_updated = false;
1230 for (size_t i = 0; i < in.size(); i++) {
1231 out.push_back(in[i]);
1232 if (!already_updated) chain.Update(i);
1233 already_updated = false;
1234 chain.FindMatch(i, max_distance, &dist_symbol, &len);
1235 if (len >= min_length) {
1236 if (len < max_lazy_match_len && i + 1 < in.size()) {
1237 // Try length at next symbol lazy matching
1238 chain.Update(i + 1);
1239 already_updated = true;
1240 size_t len2, dist_symbol2;
1241 chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
1242 if (len2 > len) {
1243 // Use the lazy match. Add literal, and use the next length starting
1244 // from the next byte.
1245 ++i;
1246 already_updated = false;
1247 len = len2;
1248 dist_symbol = dist_symbol2;
1249 out.push_back(in[i]);
1250 }
1251 }
1252
1253 float cost = sym_cost[i + len] - sym_cost[i];
1254 size_t lz77_len = len - lz77.min_length;
1255 float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
1256 sce.AddSymbolCost(out.back().context);
1257
1258 if (lz77_cost <= cost) {
1259 out.back().value = len - min_length;
1260 out.back().is_lz77_length = true;
1261 out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
1262 bit_decrease += cost - lz77_cost;
1263 } else {
1264 // LZ77 match ignored, and symbol already pushed. Push all other
1265 // symbols and skip.
1266 for (size_t j = 1; j < len; j++) {
1267 out.push_back(in[i + j]);
1268 }
1269 }
1270
1271 if (already_updated) {
1272 chain.Update(i + 2, len - 2);
1273 already_updated = false;
1274 } else {
1275 chain.Update(i + 1, len - 1);
1276 }
1277 i += len - 1;
1278 } else {
1279 // Literal, already pushed
1280 }
1281 }
1282 }
1283
1284 if (bit_decrease > total_symbols * 0.2 + 16) {
1285 lz77.enabled = true;
1286 }
1287 }
1288
ApplyLZ77_Optimal(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)1289 void ApplyLZ77_Optimal(const HistogramParams& params, size_t num_contexts,
1290 const std::vector<std::vector<Token>>& tokens,
1291 LZ77Params& lz77,
1292 std::vector<std::vector<Token>>& tokens_lz77) {
1293 std::vector<std::vector<Token>> tokens_for_cost_estimate;
1294 ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_for_cost_estimate);
1295 // If greedy-LZ77 does not give better compression than no-lz77, no reason to
1296 // run the optimal matching.
1297 if (!lz77.enabled) return;
1298 SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
1299 tokens_for_cost_estimate, lz77);
1300 size_t total_symbols = 0;
1301 tokens_lz77.resize(tokens.size());
1302 HybridUintConfig uint_config;
1303 std::vector<float> sym_cost;
1304 std::vector<uint32_t> dist_symbols;
1305 for (size_t stream = 0; stream < tokens.size(); stream++) {
1306 size_t distance_multiplier =
1307 params.image_widths.size() > stream ? params.image_widths[stream] : 0;
1308 const auto& in = tokens[stream];
1309 auto& out = tokens_lz77[stream];
1310 total_symbols += in.size();
1311 // Cumulative sum of bit costs.
1312 sym_cost.resize(in.size() + 1);
1313 for (size_t i = 0; i < in.size(); i++) {
1314 uint32_t tok, nbits, unused_bits;
1315 uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
1316 sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
1317 }
1318
1319 out.reserve(in.size());
1320 size_t max_distance = in.size();
1321 size_t min_length = lz77.min_length;
1322 JXL_ASSERT(min_length >= 3);
1323 size_t max_length = in.size();
1324
1325 // Use next power of two as window size.
1326 size_t window_size = 1;
1327 while (window_size < max_distance && window_size < kWindowSize) {
1328 window_size <<= 1;
1329 }
1330
1331 HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
1332 distance_multiplier);
1333
1334 struct MatchInfo {
1335 uint32_t len;
1336 uint32_t dist_symbol;
1337 uint32_t ctx;
1338 float total_cost = std::numeric_limits<float>::max();
1339 };
1340 // Total cost to encode the first N symbols.
1341 std::vector<MatchInfo> prefix_costs(in.size() + 1);
1342 prefix_costs[0].total_cost = 0;
1343
1344 size_t rle_length = 0;
1345 size_t skip_lz77 = 0;
1346 for (size_t i = 0; i < in.size(); i++) {
1347 chain.Update(i);
1348 float lit_cost =
1349 prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
1350 if (prefix_costs[i + 1].total_cost > lit_cost) {
1351 prefix_costs[i + 1].dist_symbol = 0;
1352 prefix_costs[i + 1].len = 1;
1353 prefix_costs[i + 1].ctx = in[i].context;
1354 prefix_costs[i + 1].total_cost = lit_cost;
1355 }
1356 if (skip_lz77 > 0) {
1357 skip_lz77--;
1358 continue;
1359 }
1360 dist_symbols.clear();
1361 chain.FindMatches(i, max_distance,
1362 [&dist_symbols](size_t len, size_t dist_symbol) {
1363 if (dist_symbols.size() <= len) {
1364 dist_symbols.resize(len + 1, dist_symbol);
1365 }
1366 if (dist_symbol < dist_symbols[len]) {
1367 dist_symbols[len] = dist_symbol;
1368 }
1369 });
1370 if (dist_symbols.size() <= min_length) continue;
1371 {
1372 size_t best_cost = dist_symbols.back();
1373 for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
1374 if (dist_symbols[j] < best_cost) {
1375 best_cost = dist_symbols[j];
1376 }
1377 dist_symbols[j] = best_cost;
1378 }
1379 }
1380 for (size_t j = min_length; j < dist_symbols.size(); j++) {
1381 // Cost model that uses results from lazy LZ77.
1382 float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
1383 sce.DistCost(dist_symbols[j], lz77);
1384 float cost = prefix_costs[i].total_cost + lz77_cost;
1385 if (prefix_costs[i + j].total_cost > cost) {
1386 prefix_costs[i + j].len = j;
1387 prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
1388 prefix_costs[i + j].ctx = in[i].context;
1389 prefix_costs[i + j].total_cost = cost;
1390 }
1391 }
1392 // We are in a RLE sequence: skip all the symbols except the first 8 and
1393 // the last 8. This avoid quadratic costs for sequences with long runs of
1394 // the same symbol.
1395 if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
1396 (dist_symbols.back() == 1 && distance_multiplier != 0)) {
1397 rle_length++;
1398 } else {
1399 rle_length = 0;
1400 }
1401 if (rle_length >= 8 && dist_symbols.size() > 9) {
1402 skip_lz77 = dist_symbols.size() - 10;
1403 rle_length = 0;
1404 }
1405 }
1406 size_t pos = in.size();
1407 while (pos > 0) {
1408 bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
1409 if (is_lz77_length) {
1410 size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
1411 out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
1412 }
1413 size_t val = is_lz77_length ? prefix_costs[pos].len - min_length
1414 : in[pos - 1].value;
1415 out.emplace_back(prefix_costs[pos].ctx, val);
1416 out.back().is_lz77_length = is_lz77_length;
1417 pos -= prefix_costs[pos].len;
1418 }
1419 std::reverse(out.begin(), out.end());
1420 }
1421 }
1422
ApplyLZ77(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)1423 void ApplyLZ77(const HistogramParams& params, size_t num_contexts,
1424 const std::vector<std::vector<Token>>& tokens, LZ77Params& lz77,
1425 std::vector<std::vector<Token>>& tokens_lz77) {
1426 lz77.enabled = false;
1427 if (params.force_huffman) {
1428 lz77.min_symbol = std::min(PREFIX_MAX_ALPHABET_SIZE - 32, 512);
1429 } else {
1430 lz77.min_symbol = 224;
1431 }
1432 if (params.lz77_method == HistogramParams::LZ77Method::kNone) {
1433 return;
1434 } else if (params.lz77_method == HistogramParams::LZ77Method::kRLE) {
1435 ApplyLZ77_RLE(params, num_contexts, tokens, lz77, tokens_lz77);
1436 } else if (params.lz77_method == HistogramParams::LZ77Method::kLZ77) {
1437 ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_lz77);
1438 } else if (params.lz77_method == HistogramParams::LZ77Method::kOptimal) {
1439 ApplyLZ77_Optimal(params, num_contexts, tokens, lz77, tokens_lz77);
1440 } else {
1441 JXL_ABORT("Not implemented");
1442 }
1443 }
1444 } // namespace
1445
BuildAndEncodeHistograms(const HistogramParams & params,size_t num_contexts,std::vector<std::vector<Token>> & tokens,EntropyEncodingData * codes,std::vector<uint8_t> * context_map,BitWriter * writer,size_t layer,AuxOut * aux_out)1446 size_t BuildAndEncodeHistograms(const HistogramParams& params,
1447 size_t num_contexts,
1448 std::vector<std::vector<Token>>& tokens,
1449 EntropyEncodingData* codes,
1450 std::vector<uint8_t>* context_map,
1451 BitWriter* writer, size_t layer,
1452 AuxOut* aux_out) {
1453 size_t total_bits = 0;
1454 codes->lz77.nonserialized_distance_context = num_contexts;
1455 std::vector<std::vector<Token>> tokens_lz77;
1456 ApplyLZ77(params, num_contexts, tokens, codes->lz77, tokens_lz77);
1457 if (ans_fuzzer_friendly_) {
1458 codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0);
1459 codes->lz77.min_symbol = 2048;
1460 }
1461
1462 const size_t max_contexts = std::min(num_contexts, kClustersLimit);
1463 BitWriter::Allotment allotment(writer,
1464 128 + num_contexts * 40 + max_contexts * 96);
1465 if (writer) {
1466 JXL_CHECK(Bundle::Write(codes->lz77, writer, layer, aux_out));
1467 } else {
1468 size_t ebits, bits;
1469 JXL_CHECK(Bundle::CanEncode(codes->lz77, &ebits, &bits));
1470 total_bits += bits;
1471 }
1472 if (codes->lz77.enabled) {
1473 if (writer) {
1474 size_t b = writer->BitsWritten();
1475 EncodeUintConfig(codes->lz77.length_uint_config, writer,
1476 /*log_alpha_size=*/8);
1477 total_bits += writer->BitsWritten() - b;
1478 } else {
1479 SizeWriter size_writer;
1480 EncodeUintConfig(codes->lz77.length_uint_config, &size_writer,
1481 /*log_alpha_size=*/8);
1482 total_bits += size_writer.size;
1483 }
1484 num_contexts += 1;
1485 tokens = std::move(tokens_lz77);
1486 }
1487 size_t total_tokens = 0;
1488 // Build histograms.
1489 HistogramBuilder builder(num_contexts);
1490 HybridUintConfig uint_config; // Default config for clustering.
1491 // Unless we are using the kContextMap histogram option.
1492 if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
1493 uint_config = HybridUintConfig(2, 0, 1);
1494 }
1495 if (ans_fuzzer_friendly_) {
1496 uint_config = HybridUintConfig(10, 0, 0);
1497 }
1498 for (size_t i = 0; i < tokens.size(); ++i) {
1499 for (size_t j = 0; j < tokens[i].size(); ++j) {
1500 const Token token = tokens[i][j];
1501 total_tokens++;
1502 uint32_t tok, nbits, bits;
1503 (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config)
1504 .Encode(token.value, &tok, &nbits, &bits);
1505 tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
1506 builder.VisitSymbol(tok, token.context);
1507 }
1508 }
1509
1510 bool use_prefix_code =
1511 params.force_huffman || total_tokens < 100 ||
1512 params.clustering == HistogramParams::ClusteringType::kFastest ||
1513 ans_fuzzer_friendly_;
1514 if (!use_prefix_code) {
1515 bool all_singleton = true;
1516 for (size_t i = 0; i < num_contexts; i++) {
1517 if (builder.Histo(i).ShannonEntropy() >= 1e-5) {
1518 all_singleton = false;
1519 }
1520 }
1521 if (all_singleton) {
1522 use_prefix_code = true;
1523 }
1524 }
1525
1526 // Encode histograms.
1527 total_bits += builder.BuildAndStoreEntropyCodes(params, tokens, codes,
1528 context_map, use_prefix_code,
1529 writer, layer, aux_out);
1530 allotment.FinishedHistogram(writer);
1531 ReclaimAndCharge(writer, &allotment, layer, aux_out);
1532
1533 if (aux_out != nullptr) {
1534 aux_out->layers[layer].num_clustered_histograms +=
1535 codes->encoding_info.size();
1536 }
1537 return total_bits;
1538 }
1539
WriteTokens(const std::vector<Token> & tokens,const EntropyEncodingData & codes,const std::vector<uint8_t> & context_map,BitWriter * writer)1540 size_t WriteTokens(const std::vector<Token>& tokens,
1541 const EntropyEncodingData& codes,
1542 const std::vector<uint8_t>& context_map, BitWriter* writer) {
1543 size_t num_extra_bits = 0;
1544 if (codes.use_prefix_code) {
1545 for (size_t i = 0; i < tokens.size(); i++) {
1546 uint32_t tok, nbits, bits;
1547 const Token& token = tokens[i];
1548 size_t histo = context_map[token.context];
1549 (token.is_lz77_length ? codes.lz77.length_uint_config
1550 : codes.uint_config[histo])
1551 .Encode(token.value, &tok, &nbits, &bits);
1552 tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1553 // Combine two calls to the BitWriter. Equivalent to:
1554 // writer->Write(codes.encoding_info[histo][tok].depth,
1555 // codes.encoding_info[histo][tok].bits);
1556 // writer->Write(nbits, bits);
1557 uint64_t data = codes.encoding_info[histo][tok].bits;
1558 data |= bits << codes.encoding_info[histo][tok].depth;
1559 writer->Write(codes.encoding_info[histo][tok].depth + nbits, data);
1560 num_extra_bits += nbits;
1561 }
1562 return num_extra_bits;
1563 }
1564 std::vector<uint64_t> out;
1565 std::vector<uint8_t> out_nbits;
1566 out.reserve(tokens.size());
1567 out_nbits.reserve(tokens.size());
1568 uint64_t allbits = 0;
1569 size_t numallbits = 0;
1570 // Writes in *reversed* order.
1571 auto addbits = [&](size_t bits, size_t nbits) {
1572 JXL_DASSERT(bits >> nbits == 0);
1573 if (JXL_UNLIKELY(numallbits + nbits > BitWriter::kMaxBitsPerCall)) {
1574 out.push_back(allbits);
1575 out_nbits.push_back(numallbits);
1576 numallbits = allbits = 0;
1577 }
1578 allbits <<= nbits;
1579 allbits |= bits;
1580 numallbits += nbits;
1581 };
1582 const int end = tokens.size();
1583 ANSCoder ans;
1584 for (int i = end - 1; i >= 0; --i) {
1585 const Token token = tokens[i];
1586 const uint8_t histo = context_map[token.context];
1587 uint32_t tok, nbits, bits;
1588 (token.is_lz77_length ? codes.lz77.length_uint_config
1589 : codes.uint_config[histo])
1590 .Encode(tokens[i].value, &tok, &nbits, &bits);
1591 tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1592 const ANSEncSymbolInfo& info = codes.encoding_info[histo][tok];
1593 // Extra bits first as this is reversed.
1594 addbits(bits, nbits);
1595 num_extra_bits += nbits;
1596 uint8_t ans_nbits = 0;
1597 uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
1598 addbits(ans_bits, ans_nbits);
1599 }
1600 const uint32_t state = ans.GetState();
1601 writer->Write(32, state);
1602 writer->Write(numallbits, allbits);
1603 for (int i = out.size(); i > 0; --i) {
1604 writer->Write(out_nbits[i - 1], out[i - 1]);
1605 }
1606 return num_extra_bits;
1607 }
1608
WriteTokens(const std::vector<Token> & tokens,const EntropyEncodingData & codes,const std::vector<uint8_t> & context_map,BitWriter * writer,size_t layer,AuxOut * aux_out)1609 void WriteTokens(const std::vector<Token>& tokens,
1610 const EntropyEncodingData& codes,
1611 const std::vector<uint8_t>& context_map, BitWriter* writer,
1612 size_t layer, AuxOut* aux_out) {
1613 BitWriter::Allotment allotment(writer, 32 * tokens.size() + 32 * 1024 * 4);
1614 size_t num_extra_bits = WriteTokens(tokens, codes, context_map, writer);
1615 ReclaimAndCharge(writer, &allotment, layer, aux_out);
1616 if (aux_out != nullptr) {
1617 aux_out->layers[layer].extra_bits += num_extra_bits;
1618 }
1619 }
1620
// Sets the file-level ans_fuzzer_friendly_ flag read by
// BuildAndEncodeHistograms. The flag can only be changed in builds where
// JXL_IS_DEBUG_BUILD is true; elsewhere this is a no-op, which guards against
// accidental / malicious changes.
void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) {
#if JXL_IS_DEBUG_BUILD
  ans_fuzzer_friendly_ = ans_fuzzer_friendly;
#else
  (void)ans_fuzzer_friendly;  // Unused outside debug builds.
#endif
}
1626 } // namespace jxl
1627