1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 #include "lib/jxl/enc_ans.h"
7
8 #include <stdint.h>
9
10 #include <algorithm>
11 #include <array>
12 #include <cmath>
13 #include <limits>
14 #include <numeric>
15 #include <type_traits>
16 #include <unordered_map>
17 #include <utility>
18 #include <vector>
19
20 #include "lib/jxl/ans_common.h"
21 #include "lib/jxl/aux_out.h"
22 #include "lib/jxl/aux_out_fwd.h"
23 #include "lib/jxl/base/bits.h"
24 #include "lib/jxl/dec_ans.h"
25 #include "lib/jxl/enc_cluster.h"
26 #include "lib/jxl/enc_context_map.h"
27 #include "lib/jxl/enc_huffman.h"
28 #include "lib/jxl/fast_math-inl.h"
29 #include "lib/jxl/fields.h"
30
31 namespace jxl {
32
33 namespace {
34
// When true, HistogramBuilder::BuildAndStoreEntropyCodes skips histogram
// clustering (all contexts are mapped to a single synthetic histogram) and
// uses one fixed HybridUintConfig instead of searching for the best one.
bool ans_fuzzer_friendly_ = false;

// Capacity of the `symbols` arrays handed to NormalizeCounts: only the first
// kMaxNumSymbolsForSmallCode used symbols are recorded there.
static const int kMaxNumSymbolsForSmallCode = 4;
38
ANSBuildInfoTable(const ANSHistBin * counts,const AliasTable::Entry * table,size_t alphabet_size,size_t log_alpha_size,ANSEncSymbolInfo * info)39 void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table,
40 size_t alphabet_size, size_t log_alpha_size,
41 ANSEncSymbolInfo* info) {
42 size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size;
43 size_t entry_size_minus_1 = (1 << log_entry_size) - 1;
44 // create valid alias table for empty streams.
45 for (size_t s = 0; s < std::max<size_t>(1, alphabet_size); ++s) {
46 const ANSHistBin freq = s == alphabet_size ? ANS_TAB_SIZE : counts[s];
47 info[s].freq_ = static_cast<uint16_t>(freq);
48 #ifdef USE_MULT_BY_RECIPROCAL
49 if (freq != 0) {
50 info[s].ifreq_ =
51 ((1ull << RECIPROCAL_PRECISION) + info[s].freq_ - 1) / info[s].freq_;
52 } else {
53 info[s].ifreq_ = 1; // shouldn't matter (symbol shouldn't occur), but...
54 }
55 #endif
56 info[s].reverse_map_.resize(freq);
57 }
58 for (int i = 0; i < ANS_TAB_SIZE; i++) {
59 AliasTable::Symbol s =
60 AliasTable::Lookup(table, i, log_entry_size, entry_size_minus_1);
61 info[s.value].reverse_map_[s.offset] = i;
62 }
63 }
64
EstimateDataBits(const ANSHistBin * histogram,const ANSHistBin * counts,size_t len)65 float EstimateDataBits(const ANSHistBin* histogram, const ANSHistBin* counts,
66 size_t len) {
67 float sum = 0.0f;
68 int total_histogram = 0;
69 int total_counts = 0;
70 for (size_t i = 0; i < len; ++i) {
71 total_histogram += histogram[i];
72 total_counts += counts[i];
73 if (histogram[i] > 0) {
74 JXL_ASSERT(counts[i] > 0);
75 // += histogram[i] * -log(counts[i]/total_counts)
76 sum += histogram[i] *
77 std::max(0.0f, ANS_LOG_TAB_SIZE - FastLog2f(counts[i]));
78 }
79 }
80 if (total_histogram > 0) {
81 JXL_ASSERT(total_counts == ANS_TAB_SIZE);
82 }
83 return sum;
84 }
85
EstimateDataBitsFlat(const ANSHistBin * histogram,size_t len)86 float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) {
87 const float flat_bits = std::max(FastLog2f(len), 0.0f);
88 int total_histogram = 0;
89 for (size_t i = 0; i < len; ++i) {
90 total_histogram += histogram[i];
91 }
92 return total_histogram * flat_bits;
93 }
94
// Static Huffman code for encoding logcounts. The last symbol is used as RLE
// sequence.
//
// Both tables are indexed by the logcount value (EncodeCounts uses index
// ANS_LOG_TAB_SIZE + 1 for the RLE marker). kLogCountBitLengths holds the
// number of bits to write and kLogCountSymbols the bit pattern written.
static const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = {
    5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
};
static const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = {
    17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
};
103
104 // Returns the difference between largest count that can be represented and is
105 // smaller than "count" and smallest representable count larger than "count".
SmallestIncrement(uint32_t count,uint32_t shift)106 static int SmallestIncrement(uint32_t count, uint32_t shift) {
107 int bits = count == 0 ? -1 : FloorLog2Nonzero(count);
108 int drop_bits = bits - GetPopulationCountPrecision(bits, shift);
109 return drop_bits < 0 ? 1 : (1 << drop_bits);
110 }
111
112 template <bool minimize_error_of_sum>
RebalanceHistogram(const float * targets,int max_symbol,int table_size,uint32_t shift,int * omit_pos,ANSHistBin * counts)113 bool RebalanceHistogram(const float* targets, int max_symbol, int table_size,
114 uint32_t shift, int* omit_pos, ANSHistBin* counts) {
115 int sum = 0;
116 float sum_nonrounded = 0.0;
117 int remainder_pos = 0; // if all of them are handled in first loop
118 int remainder_log = -1;
119 for (int n = 0; n < max_symbol; ++n) {
120 if (targets[n] > 0 && targets[n] < 1.0f) {
121 counts[n] = 1;
122 sum_nonrounded += targets[n];
123 sum += counts[n];
124 }
125 }
126 const float discount_ratio =
127 (table_size - sum) / (table_size - sum_nonrounded);
128 JXL_ASSERT(discount_ratio > 0);
129 JXL_ASSERT(discount_ratio <= 1.0f);
130 // Invariant for minimize_error_of_sum == true:
131 // abs(sum - sum_nonrounded)
132 // <= SmallestIncrement(max(targets[])) + max_symbol
133 for (int n = 0; n < max_symbol; ++n) {
134 if (targets[n] >= 1.0f) {
135 sum_nonrounded += targets[n];
136 counts[n] =
137 static_cast<ANSHistBin>(targets[n] * discount_ratio); // truncate
138 if (counts[n] == 0) counts[n] = 1;
139 if (counts[n] == table_size) counts[n] = table_size - 1;
140 // Round the count to the closest nonzero multiple of SmallestIncrement
141 // (when minimize_error_of_sum is false) or one of two closest so as to
142 // keep the sum as close as possible to sum_nonrounded.
143 int inc = SmallestIncrement(counts[n], shift);
144 counts[n] -= counts[n] & (inc - 1);
145 // TODO(robryk): Should we rescale targets[n]?
146 const float target =
147 minimize_error_of_sum ? (sum_nonrounded - sum) : targets[n];
148 if (counts[n] == 0 ||
149 (target > counts[n] + inc / 2 && counts[n] + inc < table_size)) {
150 counts[n] += inc;
151 }
152 sum += counts[n];
153 const int count_log = FloorLog2Nonzero(static_cast<uint32_t>(counts[n]));
154 if (count_log > remainder_log) {
155 remainder_pos = n;
156 remainder_log = count_log;
157 }
158 }
159 }
160 JXL_ASSERT(remainder_pos != -1);
161 // NOTE: This is the only place where counts could go negative. We could
162 // detect that, return false and make ANSHistBin uint32_t.
163 counts[remainder_pos] -= sum - table_size;
164 *omit_pos = remainder_pos;
165 return counts[remainder_pos] > 0;
166 }
167
NormalizeCounts(ANSHistBin * counts,int * omit_pos,const int length,const int precision_bits,uint32_t shift,int * num_symbols,int * symbols)168 Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length,
169 const int precision_bits, uint32_t shift,
170 int* num_symbols, int* symbols) {
171 const int32_t table_size = 1 << precision_bits; // target sum / table size
172 uint64_t total = 0;
173 int max_symbol = 0;
174 int symbol_count = 0;
175 for (int n = 0; n < length; ++n) {
176 total += counts[n];
177 if (counts[n] > 0) {
178 if (symbol_count < kMaxNumSymbolsForSmallCode) {
179 symbols[symbol_count] = n;
180 }
181 ++symbol_count;
182 max_symbol = n + 1;
183 }
184 }
185 *num_symbols = symbol_count;
186 if (symbol_count == 0) {
187 return true;
188 }
189 if (symbol_count == 1) {
190 counts[symbols[0]] = table_size;
191 return true;
192 }
193 if (symbol_count > table_size)
194 return JXL_FAILURE("Too many entries in an ANS histogram");
195
196 const float norm = 1.f * table_size / total;
197 std::vector<float> targets(max_symbol);
198 for (size_t n = 0; n < targets.size(); ++n) {
199 targets[n] = norm * counts[n];
200 }
201 if (!RebalanceHistogram<false>(&targets[0], max_symbol, table_size, shift,
202 omit_pos, counts)) {
203 // Use an alternative rebalancing mechanism if the one above failed
204 // to create a histogram that is positive wherever the original one was.
205 if (!RebalanceHistogram<true>(&targets[0], max_symbol, table_size, shift,
206 omit_pos, counts)) {
207 return JXL_FAILURE("Logic error: couldn't rebalance a histogram");
208 }
209 }
210 return true;
211 }
212
// Writer stand-in that only measures output size: every Write() adds its
// first argument (the bit count) to `size` and ignores the value.
struct SizeWriter {
  size_t size = 0;

  void Write(size_t num, size_t /*bits*/) { size = size + num; }
};
217
218 template <typename Writer>
StoreVarLenUint8(size_t n,Writer * writer)219 void StoreVarLenUint8(size_t n, Writer* writer) {
220 JXL_DASSERT(n <= 255);
221 if (n == 0) {
222 writer->Write(1, 0);
223 } else {
224 writer->Write(1, 1);
225 size_t nbits = FloorLog2Nonzero(n);
226 writer->Write(3, nbits);
227 writer->Write(nbits, n - (1ULL << nbits));
228 }
229 }
230
231 template <typename Writer>
StoreVarLenUint16(size_t n,Writer * writer)232 void StoreVarLenUint16(size_t n, Writer* writer) {
233 JXL_DASSERT(n <= 65535);
234 if (n == 0) {
235 writer->Write(1, 0);
236 } else {
237 writer->Write(1, 1);
238 size_t nbits = FloorLog2Nonzero(n);
239 writer->Write(4, nbits);
240 writer->Write(nbits, n - (1ULL << nbits));
241 }
242 }
243
244 template <typename Writer>
EncodeCounts(const ANSHistBin * counts,const int alphabet_size,const int omit_pos,const int num_symbols,uint32_t shift,const int * symbols,Writer * writer)245 bool EncodeCounts(const ANSHistBin* counts, const int alphabet_size,
246 const int omit_pos, const int num_symbols, uint32_t shift,
247 const int* symbols, Writer* writer) {
248 bool ok = true;
249 if (num_symbols <= 2) {
250 // Small tree marker to encode 1-2 symbols.
251 writer->Write(1, 1);
252 if (num_symbols == 0) {
253 writer->Write(1, 0);
254 StoreVarLenUint8(0, writer);
255 } else {
256 writer->Write(1, num_symbols - 1);
257 for (int i = 0; i < num_symbols; ++i) {
258 StoreVarLenUint8(symbols[i], writer);
259 }
260 }
261 if (num_symbols == 2) {
262 writer->Write(ANS_LOG_TAB_SIZE, counts[symbols[0]]);
263 }
264 } else {
265 // Mark non-small tree.
266 writer->Write(1, 0);
267 // Mark non-flat histogram.
268 writer->Write(1, 0);
269
270 // Precompute sequences for RLE encoding. Contains the number of identical
271 // values starting at a given index. Only contains the value at the first
272 // element of the series.
273 std::vector<uint32_t> same(alphabet_size, 0);
274 int last = 0;
275 for (int i = 1; i < alphabet_size; i++) {
276 // Store the sequence length once different symbol reached, or we're at
277 // the end, or the length is longer than we can encode, or we are at
278 // the omit_pos. We don't support including the omit_pos in an RLE
279 // sequence because this value may use a different amount of log2 bits
280 // than standard, it is too complex to handle in the decoder.
281 if (counts[i] != counts[last] || i + 1 == alphabet_size ||
282 (i - last) >= 255 || i == omit_pos || i == omit_pos + 1) {
283 same[last] = (i - last);
284 last = i + 1;
285 }
286 }
287
288 int length = 0;
289 std::vector<int> logcounts(alphabet_size);
290 int omit_log = 0;
291 for (int i = 0; i < alphabet_size; ++i) {
292 JXL_ASSERT(counts[i] <= ANS_TAB_SIZE);
293 JXL_ASSERT(counts[i] >= 0);
294 if (i == omit_pos) {
295 length = i + 1;
296 } else if (counts[i] > 0) {
297 logcounts[i] = FloorLog2Nonzero(static_cast<uint32_t>(counts[i])) + 1;
298 length = i + 1;
299 if (i < omit_pos) {
300 omit_log = std::max(omit_log, logcounts[i] + 1);
301 } else {
302 omit_log = std::max(omit_log, logcounts[i]);
303 }
304 }
305 }
306 logcounts[omit_pos] = omit_log;
307
308 // Elias gamma-like code for shift. Only difference is that if the number
309 // of bits to be encoded is equal to FloorLog2(ANS_LOG_TAB_SIZE+1), we skip
310 // the terminating 0 in unary coding.
311 int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
312 int log = FloorLog2Nonzero(shift + 1);
313 writer->Write(log, (1 << log) - 1);
314 if (log != upper_bound_log) writer->Write(1, 0);
315 writer->Write(log, ((1 << log) - 1) & (shift + 1));
316
317 // Since num_symbols >= 3, we know that length >= 3, therefore we encode
318 // length - 3.
319 if (length - 3 > 255) {
320 // Pretend that everything is OK, but complain about correctness later.
321 StoreVarLenUint8(255, writer);
322 ok = false;
323 } else {
324 StoreVarLenUint8(length - 3, writer);
325 }
326
327 // The logcount values are encoded with a static Huffman code.
328 static const size_t kMinReps = 4;
329 size_t rep = ANS_LOG_TAB_SIZE + 1;
330 for (int i = 0; i < length; ++i) {
331 if (i > 0 && same[i - 1] > kMinReps) {
332 // Encode the RLE symbol and skip the repeated ones.
333 writer->Write(kLogCountBitLengths[rep], kLogCountSymbols[rep]);
334 StoreVarLenUint8(same[i - 1] - kMinReps - 1, writer);
335 i += same[i - 1] - 2;
336 continue;
337 }
338 writer->Write(kLogCountBitLengths[logcounts[i]],
339 kLogCountSymbols[logcounts[i]]);
340 }
341 for (int i = 0; i < length; ++i) {
342 if (i > 0 && same[i - 1] > kMinReps) {
343 // Skip symbols encoded by RLE.
344 i += same[i - 1] - 2;
345 continue;
346 }
347 if (logcounts[i] > 1 && i != omit_pos) {
348 int bitcount = GetPopulationCountPrecision(logcounts[i] - 1, shift);
349 int drop_bits = logcounts[i] - 1 - bitcount;
350 JXL_CHECK((counts[i] & ((1 << drop_bits) - 1)) == 0);
351 writer->Write(bitcount, (counts[i] >> drop_bits) - (1 << bitcount));
352 }
353 }
354 }
355 return ok;
356 }
357
EncodeFlatHistogram(const int alphabet_size,BitWriter * writer)358 void EncodeFlatHistogram(const int alphabet_size, BitWriter* writer) {
359 // Mark non-small tree.
360 writer->Write(1, 0);
361 // Mark uniform histogram.
362 writer->Write(1, 1);
363 JXL_ASSERT(alphabet_size > 0);
364 // Encode alphabet size.
365 StoreVarLenUint8(alphabet_size - 1, writer);
366 }
367
ComputeHistoAndDataCost(const ANSHistBin * histogram,size_t alphabet_size,uint32_t method)368 float ComputeHistoAndDataCost(const ANSHistBin* histogram, size_t alphabet_size,
369 uint32_t method) {
370 if (method == 0) { // Flat code
371 return ANS_LOG_TAB_SIZE + 2 +
372 EstimateDataBitsFlat(histogram, alphabet_size);
373 }
374 // Non-flat: shift = method-1.
375 uint32_t shift = method - 1;
376 std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
377 int omit_pos = 0;
378 int num_symbols;
379 int symbols[kMaxNumSymbolsForSmallCode] = {};
380 JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
381 ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
382 SizeWriter writer;
383 // Ignore the correctness, no real encoding happens at this stage.
384 (void)EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols, shift,
385 symbols, &writer);
386 return writer.size +
387 EstimateDataBits(histogram, counts.data(), alphabet_size);
388 }
389
ComputeBestMethod(const ANSHistBin * histogram,size_t alphabet_size,float * cost,HistogramParams::ANSHistogramStrategy ans_histogram_strategy)390 uint32_t ComputeBestMethod(
391 const ANSHistBin* histogram, size_t alphabet_size, float* cost,
392 HistogramParams::ANSHistogramStrategy ans_histogram_strategy) {
393 size_t method = 0;
394 float fcost = ComputeHistoAndDataCost(histogram, alphabet_size, 0);
395 auto try_shift = [&](size_t shift) {
396 float c = ComputeHistoAndDataCost(histogram, alphabet_size, shift + 1);
397 if (c < fcost) {
398 method = shift + 1;
399 fcost = c;
400 }
401 };
402 switch (ans_histogram_strategy) {
403 case HistogramParams::ANSHistogramStrategy::kPrecise: {
404 for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift++) {
405 try_shift(shift);
406 }
407 break;
408 }
409 case HistogramParams::ANSHistogramStrategy::kApproximate: {
410 for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift += 2) {
411 try_shift(shift);
412 }
413 break;
414 }
415 case HistogramParams::ANSHistogramStrategy::kFast: {
416 try_shift(0);
417 try_shift(ANS_LOG_TAB_SIZE / 2);
418 try_shift(ANS_LOG_TAB_SIZE);
419 break;
420 }
421 };
422 *cost = fcost;
423 return method;
424 }
425
426 } // namespace
427
428 // Returns an estimate of the cost of encoding this histogram and the
429 // corresponding data.
BuildAndStoreANSEncodingData(HistogramParams::ANSHistogramStrategy ans_histogram_strategy,const ANSHistBin * histogram,size_t alphabet_size,size_t log_alpha_size,bool use_prefix_code,ANSEncSymbolInfo * info,BitWriter * writer)430 size_t BuildAndStoreANSEncodingData(
431 HistogramParams::ANSHistogramStrategy ans_histogram_strategy,
432 const ANSHistBin* histogram, size_t alphabet_size, size_t log_alpha_size,
433 bool use_prefix_code, ANSEncSymbolInfo* info, BitWriter* writer) {
434 if (use_prefix_code) {
435 if (alphabet_size <= 1) return 0;
436 std::vector<uint32_t> histo(alphabet_size);
437 for (size_t i = 0; i < alphabet_size; i++) {
438 histo[i] = histogram[i];
439 JXL_CHECK(histogram[i] >= 0);
440 }
441 size_t cost = 0;
442 {
443 std::vector<uint8_t> depths(alphabet_size);
444 std::vector<uint16_t> bits(alphabet_size);
445 BitWriter tmp_writer;
446 BitWriter* w = writer ? writer : &tmp_writer;
447 size_t start = w->BitsWritten();
448 BitWriter::Allotment allotment(
449 w, 8 * alphabet_size + 8); // safe upper bound
450 BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(),
451 bits.data(), w);
452 ReclaimAndCharge(w, &allotment, 0, /*aux_out=*/nullptr);
453
454 for (size_t i = 0; i < alphabet_size; i++) {
455 info[i].bits = depths[i] == 0 ? 0 : bits[i];
456 info[i].depth = depths[i];
457 }
458 cost = w->BitsWritten() - start;
459 }
460 // Estimate data cost.
461 for (size_t i = 0; i < alphabet_size; i++) {
462 cost += histogram[i] * info[i].depth;
463 }
464 return cost;
465 }
466 JXL_ASSERT(alphabet_size <= ANS_TAB_SIZE);
467 // Ensure we ignore trailing zeros in the histogram.
468 if (alphabet_size != 0) {
469 size_t largest_symbol = 0;
470 for (size_t i = 0; i < alphabet_size; i++) {
471 if (histogram[i] != 0) largest_symbol = i;
472 }
473 alphabet_size = largest_symbol + 1;
474 }
475 float cost;
476 uint32_t method = ComputeBestMethod(histogram, alphabet_size, &cost,
477 ans_histogram_strategy);
478 JXL_ASSERT(cost >= 0);
479 int num_symbols;
480 int symbols[kMaxNumSymbolsForSmallCode] = {};
481 std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
482 if (!counts.empty()) {
483 size_t sum = 0;
484 for (size_t i = 0; i < counts.size(); i++) {
485 sum += counts[i];
486 }
487 if (sum == 0) {
488 counts[0] = ANS_TAB_SIZE;
489 }
490 }
491 if (method == 0) {
492 counts = CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE);
493 AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
494 InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
495 ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
496 if (writer != nullptr) {
497 EncodeFlatHistogram(alphabet_size, writer);
498 }
499 return cost;
500 }
501 int omit_pos = 0;
502 uint32_t shift = method - 1;
503 JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
504 ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
505 AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
506 InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
507 ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
508 if (writer != nullptr) {
509 bool ok = EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols,
510 shift, symbols, writer);
511 (void)ok;
512 JXL_DASSERT(ok);
513 }
514 return cost;
515 }
516
ANSPopulationCost(const ANSHistBin * data,size_t alphabet_size)517 float ANSPopulationCost(const ANSHistBin* data, size_t alphabet_size) {
518 float c;
519 ComputeBestMethod(data, alphabet_size, &c,
520 HistogramParams::ANSHistogramStrategy::kFast);
521 return c;
522 }
523
524 template <typename Writer>
EncodeUintConfig(const HybridUintConfig uint_config,Writer * writer,size_t log_alpha_size)525 void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer,
526 size_t log_alpha_size) {
527 writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
528 uint_config.split_exponent);
529 if (uint_config.split_exponent == log_alpha_size) {
530 return; // msb/lsb don't matter.
531 }
532 size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
533 writer->Write(nbits, uint_config.msb_in_token);
534 nbits = CeilLog2Nonzero(uint_config.split_exponent -
535 uint_config.msb_in_token + 1);
536 writer->Write(nbits, uint_config.lsb_in_token);
537 }
538 template <typename Writer>
EncodeUintConfigs(const std::vector<HybridUintConfig> & uint_config,Writer * writer,size_t log_alpha_size)539 void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config,
540 Writer* writer, size_t log_alpha_size) {
541 // TODO(veluca): RLE?
542 for (size_t i = 0; i < uint_config.size(); i++) {
543 EncodeUintConfig(uint_config[i], writer, log_alpha_size);
544 }
545 }
546 template void EncodeUintConfigs(const std::vector<HybridUintConfig>&,
547 BitWriter*, size_t);
548
549 namespace {
550
ChooseUintConfigs(const HistogramParams & params,const std::vector<std::vector<Token>> & tokens,const std::vector<uint8_t> & context_map,std::vector<Histogram> * clustered_histograms,EntropyEncodingData * codes,size_t * log_alpha_size)551 void ChooseUintConfigs(const HistogramParams& params,
552 const std::vector<std::vector<Token>>& tokens,
553 const std::vector<uint8_t>& context_map,
554 std::vector<Histogram>* clustered_histograms,
555 EntropyEncodingData* codes, size_t* log_alpha_size) {
556 codes->uint_config.resize(clustered_histograms->size());
557
558 if (params.uint_method == HistogramParams::HybridUintMethod::kNone) return;
559 if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
560 codes->uint_config.clear();
561 codes->uint_config.resize(clustered_histograms->size(),
562 HybridUintConfig(2, 0, 1));
563 return;
564 }
565
566 // Brute-force method that tries a few options.
567 std::vector<HybridUintConfig> configs;
568 if (params.uint_method == HistogramParams::HybridUintMethod::kBest) {
569 configs = {
570 HybridUintConfig(4, 2, 0), // default
571 HybridUintConfig(4, 1, 0), // less precise
572 HybridUintConfig(4, 2, 1), // add sign
573 HybridUintConfig(4, 2, 2), // add sign+parity
574 HybridUintConfig(4, 1, 2), // add parity but less msb
575 // Same as above, but more direct coding.
576 HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0),
577 HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2),
578 HybridUintConfig(5, 1, 2),
579 // Same as above, but less direct coding.
580 HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0),
581 HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2),
582 // For near-lossless.
583 HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4),
584 HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5),
585 HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0),
586 // Other
587 HybridUintConfig(0, 0, 0), // varlenuint
588 HybridUintConfig(2, 0, 1), // works well for ctx map
589 HybridUintConfig(7, 0, 0), // direct coding
590 HybridUintConfig(8, 0, 0), // direct coding
591 HybridUintConfig(9, 0, 0), // direct coding
592 HybridUintConfig(10, 0, 0), // direct coding
593 HybridUintConfig(11, 0, 0), // direct coding
594 HybridUintConfig(12, 0, 0), // direct coding
595 };
596 } else if (params.uint_method == HistogramParams::HybridUintMethod::kFast) {
597 configs = {
598 HybridUintConfig(4, 2, 0), // default
599 HybridUintConfig(4, 1, 2), // add parity but less msb
600 HybridUintConfig(0, 0, 0), // smallest histograms
601 HybridUintConfig(2, 0, 1), // works well for ctx map
602 };
603 }
604
605 std::vector<float> costs(clustered_histograms->size(),
606 std::numeric_limits<float>::max());
607 std::vector<uint32_t> extra_bits(clustered_histograms->size());
608 std::vector<uint8_t> is_valid(clustered_histograms->size());
609 size_t max_alpha =
610 codes->use_prefix_code ? PREFIX_MAX_ALPHABET_SIZE : ANS_MAX_ALPHABET_SIZE;
611 for (HybridUintConfig cfg : configs) {
612 std::fill(is_valid.begin(), is_valid.end(), true);
613 std::fill(extra_bits.begin(), extra_bits.end(), 0);
614
615 for (size_t i = 0; i < clustered_histograms->size(); i++) {
616 (*clustered_histograms)[i].Clear();
617 }
618 for (size_t i = 0; i < tokens.size(); ++i) {
619 for (size_t j = 0; j < tokens[i].size(); ++j) {
620 const Token token = tokens[i][j];
621 // TODO(veluca): do not ignore lz77 commands.
622 if (token.is_lz77_length) continue;
623 size_t histo = context_map[token.context];
624 uint32_t tok, nbits, bits;
625 cfg.Encode(token.value, &tok, &nbits, &bits);
626 if (tok >= max_alpha ||
627 (codes->lz77.enabled && tok >= codes->lz77.min_symbol)) {
628 is_valid[histo] = false;
629 continue;
630 }
631 extra_bits[histo] += nbits;
632 (*clustered_histograms)[histo].Add(tok);
633 }
634 }
635
636 for (size_t i = 0; i < clustered_histograms->size(); i++) {
637 if (!is_valid[i]) continue;
638 float cost = (*clustered_histograms)[i].PopulationCost() + extra_bits[i];
639 // add signaling cost of the hybriduintconfig itself
640 cost += CeilLog2Nonzero(cfg.split_exponent + 1);
641 cost += CeilLog2Nonzero(cfg.split_exponent - cfg.msb_in_token + 1);
642 if (cost < costs[i]) {
643 codes->uint_config[i] = cfg;
644 costs[i] = cost;
645 }
646 }
647 }
648
649 // Rebuild histograms.
650 for (size_t i = 0; i < clustered_histograms->size(); i++) {
651 (*clustered_histograms)[i].Clear();
652 }
653 *log_alpha_size = 4;
654 for (size_t i = 0; i < tokens.size(); ++i) {
655 for (size_t j = 0; j < tokens[i].size(); ++j) {
656 const Token token = tokens[i][j];
657 uint32_t tok, nbits, bits;
658 size_t histo = context_map[token.context];
659 (token.is_lz77_length ? codes->lz77.length_uint_config
660 : codes->uint_config[histo])
661 .Encode(token.value, &tok, &nbits, &bits);
662 tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
663 (*clustered_histograms)[histo].Add(tok);
664 while (tok >= (1u << *log_alpha_size)) (*log_alpha_size)++;
665 }
666 }
667 #if JXL_ENABLE_ASSERT
668 size_t max_log_alpha_size = codes->use_prefix_code ? PREFIX_MAX_BITS : 8;
669 JXL_ASSERT(*log_alpha_size <= max_log_alpha_size);
670 #endif
671 }
672
673 class HistogramBuilder {
674 public:
HistogramBuilder(const size_t num_contexts)675 explicit HistogramBuilder(const size_t num_contexts)
676 : histograms_(num_contexts) {}
677
VisitSymbol(int symbol,size_t histo_idx)678 void VisitSymbol(int symbol, size_t histo_idx) {
679 JXL_DASSERT(histo_idx < histograms_.size());
680 histograms_[histo_idx].Add(symbol);
681 }
682
683 // NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge.
BuildAndStoreEntropyCodes(const HistogramParams & params,const std::vector<std::vector<Token>> & tokens,EntropyEncodingData * codes,std::vector<uint8_t> * context_map,bool use_prefix_code,BitWriter * writer,size_t layer,AuxOut * aux_out) const684 size_t BuildAndStoreEntropyCodes(
685 const HistogramParams& params,
686 const std::vector<std::vector<Token>>& tokens, EntropyEncodingData* codes,
687 std::vector<uint8_t>* context_map, bool use_prefix_code,
688 BitWriter* writer, size_t layer, AuxOut* aux_out) const {
689 size_t cost = 0;
690 codes->encoding_info.clear();
691 std::vector<Histogram> clustered_histograms(histograms_);
692 context_map->resize(histograms_.size());
693 if (histograms_.size() > 1) {
694 if (!ans_fuzzer_friendly_) {
695 std::vector<uint32_t> histogram_symbols;
696 ClusterHistograms(params, histograms_, histograms_.size(),
697 kClustersLimit, &clustered_histograms,
698 &histogram_symbols);
699 for (size_t c = 0; c < histograms_.size(); ++c) {
700 (*context_map)[c] = static_cast<uint8_t>(histogram_symbols[c]);
701 }
702 } else {
703 fill(context_map->begin(), context_map->end(), 0);
704 size_t max_symbol = 0;
705 for (const Histogram& h : histograms_) {
706 max_symbol = std::max(h.data_.size(), max_symbol);
707 }
708 size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1);
709 clustered_histograms.resize(1);
710 clustered_histograms[0].Clear();
711 for (size_t i = 0; i < num_symbols; i++) {
712 clustered_histograms[0].Add(i);
713 }
714 }
715 if (writer != nullptr) {
716 EncodeContextMap(*context_map, clustered_histograms.size(), writer);
717 }
718 }
719 if (aux_out != nullptr) {
720 for (size_t i = 0; i < clustered_histograms.size(); ++i) {
721 aux_out->layers[layer].clustered_entropy +=
722 clustered_histograms[i].ShannonEntropy();
723 }
724 }
725 codes->use_prefix_code = use_prefix_code;
726 size_t log_alpha_size = codes->lz77.enabled ? 8 : 7; // Sane default.
727 if (ans_fuzzer_friendly_) {
728 codes->uint_config.clear();
729 codes->uint_config.resize(1, HybridUintConfig(7, 0, 0));
730 } else {
731 ChooseUintConfigs(params, tokens, *context_map, &clustered_histograms,
732 codes, &log_alpha_size);
733 }
734 if (log_alpha_size < 5) log_alpha_size = 5;
735 SizeWriter size_writer; // Used if writer == nullptr to estimate costs.
736 cost += 1;
737 if (writer) writer->Write(1, use_prefix_code);
738
739 if (use_prefix_code) {
740 log_alpha_size = PREFIX_MAX_BITS;
741 } else {
742 cost += 2;
743 }
744 if (writer == nullptr) {
745 EncodeUintConfigs(codes->uint_config, &size_writer, log_alpha_size);
746 } else {
747 if (!use_prefix_code) writer->Write(2, log_alpha_size - 5);
748 EncodeUintConfigs(codes->uint_config, writer, log_alpha_size);
749 }
750 if (use_prefix_code) {
751 for (size_t c = 0; c < clustered_histograms.size(); ++c) {
752 size_t num_symbol = 1;
753 for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
754 if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
755 }
756 if (writer) {
757 StoreVarLenUint16(num_symbol - 1, writer);
758 } else {
759 StoreVarLenUint16(num_symbol - 1, &size_writer);
760 }
761 }
762 }
763 cost += size_writer.size;
764 for (size_t c = 0; c < clustered_histograms.size(); ++c) {
765 size_t num_symbol = 1;
766 for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
767 if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
768 }
769 codes->encoding_info.emplace_back();
770 codes->encoding_info.back().resize(std::max<size_t>(1, num_symbol));
771
772 BitWriter::Allotment allotment(writer, 256 + num_symbol * 24);
773 cost += BuildAndStoreANSEncodingData(
774 params.ans_histogram_strategy, clustered_histograms[c].data_.data(),
775 num_symbol, log_alpha_size, use_prefix_code,
776 codes->encoding_info.back().data(), writer);
777 allotment.FinishedHistogram(writer);
778 ReclaimAndCharge(writer, &allotment, layer, aux_out);
779 }
780 return cost;
781 }
782
Histo(size_t i) const783 const Histogram& Histo(size_t i) const { return histograms_[i]; }
784
785 private:
786 std::vector<Histogram> histograms_;
787 };
788
789 class SymbolCostEstimator {
790 public:
SymbolCostEstimator(size_t num_contexts,bool force_huffman,const std::vector<std::vector<Token>> & tokens,const LZ77Params & lz77)791 SymbolCostEstimator(size_t num_contexts, bool force_huffman,
792 const std::vector<std::vector<Token>>& tokens,
793 const LZ77Params& lz77) {
794 HistogramBuilder builder(num_contexts);
795 // Build histograms for estimating lz77 savings.
796 HybridUintConfig uint_config;
797 for (size_t i = 0; i < tokens.size(); ++i) {
798 for (size_t j = 0; j < tokens[i].size(); ++j) {
799 const Token token = tokens[i][j];
800 uint32_t tok, nbits, bits;
801 (token.is_lz77_length ? lz77.length_uint_config : uint_config)
802 .Encode(token.value, &tok, &nbits, &bits);
803 tok += token.is_lz77_length ? lz77.min_symbol : 0;
804 builder.VisitSymbol(tok, token.context);
805 }
806 }
807 max_alphabet_size_ = 0;
808 for (size_t i = 0; i < num_contexts; i++) {
809 max_alphabet_size_ =
810 std::max(max_alphabet_size_, builder.Histo(i).data_.size());
811 }
812 bits_.resize(num_contexts * max_alphabet_size_);
813 // TODO(veluca): SIMD?
814 add_symbol_cost_.resize(num_contexts);
815 for (size_t i = 0; i < num_contexts; i++) {
816 float inv_total = 1.0f / (builder.Histo(i).total_count_ + 1e-8f);
817 float total_cost = 0;
818 for (size_t j = 0; j < builder.Histo(i).data_.size(); j++) {
819 size_t cnt = builder.Histo(i).data_[j];
820 float cost = 0;
821 if (cnt != 0 && cnt != builder.Histo(i).total_count_) {
822 cost = -FastLog2f(cnt * inv_total);
823 if (force_huffman) cost = std::ceil(cost);
824 } else if (cnt == 0) {
825 cost = ANS_LOG_TAB_SIZE; // Highest possible cost.
826 }
827 bits_[i * max_alphabet_size_ + j] = cost;
828 total_cost += cost * builder.Histo(i).data_[j];
829 }
830 // Penalty for adding a lz77 symbol to this contest (only used for static
831 // cost model). Higher penalty for contexts that have a very low
832 // per-symbol entropy.
833 add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total);
834 }
835 }
Bits(size_t ctx,size_t sym) const836 float Bits(size_t ctx, size_t sym) const {
837 return bits_[ctx * max_alphabet_size_ + sym];
838 }
LenCost(size_t ctx,size_t len,const LZ77Params & lz77) const839 float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const {
840 uint32_t nbits, bits, tok;
841 lz77.length_uint_config.Encode(len, &tok, &nbits, &bits);
842 tok += lz77.min_symbol;
843 return nbits + Bits(ctx, tok);
844 }
DistCost(size_t len,const LZ77Params & lz77) const845 float DistCost(size_t len, const LZ77Params& lz77) const {
846 uint32_t nbits, bits, tok;
847 HybridUintConfig().Encode(len, &tok, &nbits, &bits);
848 return nbits + Bits(lz77.nonserialized_distance_context, tok);
849 }
AddSymbolCost(size_t idx) const850 float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; }
851
852 private:
853 size_t max_alphabet_size_;
854 std::vector<float> bits_;
855 std::vector<float> add_symbol_cost_;
856 };
857
ApplyLZ77_RLE(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)858 void ApplyLZ77_RLE(const HistogramParams& params, size_t num_contexts,
859 const std::vector<std::vector<Token>>& tokens,
860 LZ77Params& lz77,
861 std::vector<std::vector<Token>>& tokens_lz77) {
862 // TODO(veluca): tune heuristics here.
863 SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
864 float bit_decrease = 0;
865 size_t total_symbols = 0;
866 tokens_lz77.resize(tokens.size());
867 std::vector<float> sym_cost;
868 HybridUintConfig uint_config;
869 for (size_t stream = 0; stream < tokens.size(); stream++) {
870 size_t distance_multiplier =
871 params.image_widths.size() > stream ? params.image_widths[stream] : 0;
872 const auto& in = tokens[stream];
873 auto& out = tokens_lz77[stream];
874 total_symbols += in.size();
875 // Cumulative sum of bit costs.
876 sym_cost.resize(in.size() + 1);
877 for (size_t i = 0; i < in.size(); i++) {
878 uint32_t tok, nbits, unused_bits;
879 uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
880 sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
881 }
882 out.reserve(in.size());
883 for (size_t i = 0; i < in.size(); i++) {
884 size_t num_to_copy = 0;
885 size_t distance_symbol = 0; // 1 for RLE.
886 if (distance_multiplier != 0) {
887 distance_symbol = 1; // Special distance 1 if enabled.
888 JXL_DASSERT(kSpecialDistances[1][0] == 1);
889 JXL_DASSERT(kSpecialDistances[1][1] == 0);
890 }
891 if (i > 0) {
892 for (; i + num_to_copy < in.size(); num_to_copy++) {
893 if (in[i + num_to_copy].value != in[i - 1].value) {
894 break;
895 }
896 }
897 }
898 if (num_to_copy == 0) {
899 out.push_back(in[i]);
900 continue;
901 }
902 float cost = sym_cost[i + num_to_copy] - sym_cost[i];
903 // This subtraction might overflow, but that's OK.
904 size_t lz77_len = num_to_copy - lz77.min_length;
905 float lz77_cost = num_to_copy >= lz77.min_length
906 ? CeilLog2Nonzero(lz77_len + 1) + 1
907 : 0;
908 if (num_to_copy < lz77.min_length || cost <= lz77_cost) {
909 for (size_t j = 0; j < num_to_copy; j++) {
910 out.push_back(in[i + j]);
911 }
912 i += num_to_copy - 1;
913 continue;
914 }
915 // Output the LZ77 length
916 out.emplace_back(in[i].context, lz77_len);
917 out.back().is_lz77_length = true;
918 i += num_to_copy - 1;
919 bit_decrease += cost - lz77_cost;
920 // Output the LZ77 copy distance.
921 out.emplace_back(lz77.nonserialized_distance_context, distance_symbol);
922 }
923 }
924
925 if (bit_decrease > total_symbols * 0.2 + 16) {
926 lz77.enabled = true;
927 }
928 }
929
930 // Hash chain for LZ77 matching
931 struct HashChain {
932 size_t size_;
933 std::vector<uint32_t> data_;
934
935 unsigned hash_num_values_ = 32768;
936 unsigned hash_mask_ = hash_num_values_ - 1;
937 unsigned hash_shift_ = 5;
938
939 std::vector<int> head;
940 std::vector<uint32_t> chain;
941 std::vector<int> val;
942
943 // Speed up repetitions of zero
944 std::vector<int> headz;
945 std::vector<uint32_t> chainz;
946 std::vector<uint32_t> zeros;
947 uint32_t numzeros = 0;
948
949 size_t window_size_;
950 size_t window_mask_;
951 size_t min_length_;
952 size_t max_length_;
953
954 // Map of special distance codes.
955 std::unordered_map<int, int> special_dist_table_;
956 size_t num_special_distances_ = 0;
957
958 uint32_t maxchainlength = 256; // window_size_ to allow all
959
HashChainjxl::__anonfd93e7cc0311::HashChain960 HashChain(const Token* data, size_t size, size_t window_size,
961 size_t min_length, size_t max_length, size_t distance_multiplier)
962 : size_(size),
963 window_size_(window_size),
964 window_mask_(window_size - 1),
965 min_length_(min_length),
966 max_length_(max_length) {
967 data_.resize(size);
968 for (size_t i = 0; i < size; i++) {
969 data_[i] = data[i].value;
970 }
971
972 head.resize(hash_num_values_, -1);
973 val.resize(window_size_, -1);
974 chain.resize(window_size_);
975 for (uint32_t i = 0; i < window_size_; ++i) {
976 chain[i] = i; // same value as index indicates uninitialized
977 }
978
979 zeros.resize(window_size_);
980 headz.resize(window_size_ + 1, -1);
981 chainz.resize(window_size_);
982 for (uint32_t i = 0; i < window_size_; ++i) {
983 chainz[i] = i;
984 }
985 // Translate distance to special distance code.
986 if (distance_multiplier) {
987 // Count down, so if due to small distance multiplier multiple distances
988 // map to the same code, the smallest code will be used in the end.
989 for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
990 int xi = kSpecialDistances[i][0];
991 int yi = kSpecialDistances[i][1];
992 int distance = yi * distance_multiplier + xi;
993 // Ensure that we map distance 1 to the lowest symbols.
994 if (distance < 1) distance = 1;
995 special_dist_table_[distance] = i;
996 }
997 num_special_distances_ = kNumSpecialDistances;
998 }
999 }
1000
GetHashjxl::__anonfd93e7cc0311::HashChain1001 uint32_t GetHash(size_t pos) const {
1002 uint32_t result = 0;
1003 if (pos + 2 < size_) {
1004 // TODO(lode): take the MSB's of the uint32_t values into account as well,
1005 // given that the hash code itself is less than 32 bits.
1006 result ^= (uint32_t)(data_[pos + 0] << 0u);
1007 result ^= (uint32_t)(data_[pos + 1] << hash_shift_);
1008 result ^= (uint32_t)(data_[pos + 2] << (hash_shift_ * 2));
1009 } else {
1010 // No need to compute hash of last 2 bytes, the length 2 is too short.
1011 return 0;
1012 }
1013 return result & hash_mask_;
1014 }
1015
CountZerosjxl::__anonfd93e7cc0311::HashChain1016 uint32_t CountZeros(size_t pos, uint32_t prevzeros) const {
1017 size_t end = pos + window_size_;
1018 if (end > size_) end = size_;
1019 if (prevzeros > 0) {
1020 if (prevzeros >= window_mask_ && data_[end - 1] == 0 &&
1021 end == pos + window_size_) {
1022 return prevzeros;
1023 } else {
1024 return prevzeros - 1;
1025 }
1026 }
1027 uint32_t num = 0;
1028 while (pos + num < end && data_[pos + num] == 0) num++;
1029 return num;
1030 }
1031
Updatejxl::__anonfd93e7cc0311::HashChain1032 void Update(size_t pos) {
1033 uint32_t hashval = GetHash(pos);
1034 uint32_t wpos = pos & window_mask_;
1035
1036 val[wpos] = (int)hashval;
1037 if (head[hashval] != -1) chain[wpos] = head[hashval];
1038 head[hashval] = wpos;
1039
1040 if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0;
1041 numzeros = CountZeros(pos, numzeros);
1042
1043 zeros[wpos] = numzeros;
1044 if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros];
1045 headz[numzeros] = wpos;
1046 }
1047
Updatejxl::__anonfd93e7cc0311::HashChain1048 void Update(size_t pos, size_t len) {
1049 for (size_t i = 0; i < len; i++) {
1050 Update(pos + i);
1051 }
1052 }
1053
1054 template <typename CB>
FindMatchesjxl::__anonfd93e7cc0311::HashChain1055 void FindMatches(size_t pos, int max_dist, const CB& found_match) const {
1056 uint32_t wpos = pos & window_mask_;
1057 uint32_t hashval = GetHash(pos);
1058 uint32_t hashpos = chain[wpos];
1059
1060 int prev_dist = 0;
1061 int end = std::min<int>(pos + max_length_, size_);
1062 uint32_t chainlength = 0;
1063 uint32_t best_len = 0;
1064 for (;;) {
1065 int dist = (hashpos <= wpos) ? (wpos - hashpos)
1066 : (wpos - hashpos + window_mask_ + 1);
1067 if (dist < prev_dist) break;
1068 prev_dist = dist;
1069 uint32_t len = 0;
1070 if (dist > 0) {
1071 int i = pos;
1072 int j = pos - dist;
1073 if (numzeros > 3) {
1074 int r = std::min<int>(numzeros - 1, zeros[hashpos]);
1075 if (i + r >= end) r = end - i - 1;
1076 i += r;
1077 j += r;
1078 }
1079 while (i < end && data_[i] == data_[j]) {
1080 i++;
1081 j++;
1082 }
1083 len = i - pos;
1084 // This can trigger even if the new length is slightly smaller than the
1085 // best length, because it is possible for a slightly cheaper distance
1086 // symbol to occur.
1087 if (len >= min_length_ && len + 2 >= best_len) {
1088 auto it = special_dist_table_.find(dist);
1089 int dist_symbol = (it == special_dist_table_.end())
1090 ? (num_special_distances_ + dist - 1)
1091 : it->second;
1092 found_match(len, dist_symbol);
1093 if (len > best_len) best_len = len;
1094 }
1095 }
1096
1097 chainlength++;
1098 if (chainlength >= maxchainlength) break;
1099
1100 if (numzeros >= 3 && len > numzeros) {
1101 if (hashpos == chainz[hashpos]) break;
1102 hashpos = chainz[hashpos];
1103 if (zeros[hashpos] != numzeros) break;
1104 } else {
1105 if (hashpos == chain[hashpos]) break;
1106 hashpos = chain[hashpos];
1107 if (val[hashpos] != (int)hashval) break; // outdated hash value
1108 }
1109 }
1110 }
FindMatchjxl::__anonfd93e7cc0311::HashChain1111 void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol,
1112 size_t* result_len) const {
1113 *result_dist_symbol = 0;
1114 *result_len = 1;
1115 FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) {
1116 if (len > *result_len ||
1117 (len == *result_len && *result_dist_symbol > dist_symbol)) {
1118 *result_len = len;
1119 *result_dist_symbol = dist_symbol;
1120 }
1121 });
1122 }
1123 };
1124
LenCost(size_t len)1125 float LenCost(size_t len) {
1126 uint32_t nbits, bits, tok;
1127 HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits);
1128 constexpr float kCostTable[] = {
1129 2.797667318563126, 3.213177690381199, 2.5706009246743737,
1130 2.408392498667534, 2.829649191872326, 3.3923087753324577,
1131 4.029267451554331, 4.415576699706408, 4.509357574741465,
1132 9.21481543803004, 10.020590190114898, 11.858671627804766,
1133 12.45853300490526, 11.713105831990857, 12.561996324849314,
1134 13.775477692278367, 13.174027068768641,
1135 };
1136 size_t table_size = sizeof kCostTable / sizeof *kCostTable;
1137 if (tok >= table_size) tok = table_size - 1;
1138 return kCostTable[tok] + nbits;
1139 }
1140
1141 // TODO(veluca): this does not take into account usage or non-usage of distance
1142 // multipliers.
DistCost(size_t dist)1143 float DistCost(size_t dist) {
1144 uint32_t nbits, bits, tok;
1145 HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits);
1146 constexpr float kCostTable[] = {
1147 6.368282626312716, 5.680793277090298, 8.347404197105247,
1148 7.641619201599141, 6.914328374119438, 7.959808291537444,
1149 8.70023120759855, 8.71378518934703, 9.379132523982769,
1150 9.110472749092708, 9.159029569270908, 9.430936766731973,
1151 7.278284055315169, 7.8278514904267755, 10.026641158289236,
1152 9.976049229827066, 9.64351607048908, 9.563403863480442,
1153 10.171474111762747, 10.45950155077234, 9.994813912104219,
1154 10.322524683741156, 8.465808729388186, 8.756254166066853,
1155 10.160930174662234, 10.247329273413435, 10.04090403724809,
1156 10.129398517544082, 9.342311691539546, 9.07608009102374,
1157 10.104799540677513, 10.378079384990906, 10.165828974075072,
1158 10.337595322341553, 7.940557464567944, 10.575665823319431,
1159 11.023344321751955, 10.736144698831827, 11.118277044595054,
1160 7.468468230648442, 10.738305230932939, 10.906980780216568,
1161 10.163468216353817, 10.17805759656433, 11.167283670483565,
1162 11.147050200274544, 10.517921919244333, 10.651764778156886,
1163 10.17074446448919, 11.217636876224745, 11.261630721139484,
1164 11.403140815247259, 10.892472096873417, 11.1859607804481,
1165 8.017346947551262, 7.895143720278828, 11.036577113822025,
1166 11.170562110315794, 10.326988722591086, 10.40872184751056,
1167 11.213498225466386, 11.30580635516863, 10.672272515665442,
1168 10.768069466228063, 11.145257364153565, 11.64668307145549,
1169 10.593156194627339, 11.207499484844943, 10.767517766396908,
1170 10.826629811407042, 10.737764794499988, 10.6200448518045,
1171 10.191315385198092, 8.468384171390085, 11.731295299170432,
1172 11.824619886654398, 10.41518844301179, 10.16310536548649,
1173 10.539423685097576, 10.495136599328031, 10.469112847728267,
1174 11.72057686174922, 10.910326337834674, 11.378921834673758,
1175 11.847759036098536, 11.92071647623854, 10.810628276345282,
1176 11.008601085273893, 11.910326337834674, 11.949212023423133,
1177 11.298614839104337, 11.611603659010392, 10.472930394619985,
1178 11.835564720850282, 11.523267392285337, 12.01055816679611,
1179 8.413029688994023, 11.895784139536406, 11.984679534970505,
1180 11.220654278717394, 11.716311684833672, 10.61036646226114,
1181 10.89849965960364, 10.203762898863669, 10.997560826267238,
1182 11.484217379438984, 11.792836176993665, 12.24310468755171,
1183 11.464858097919262, 12.212747017409377, 11.425595666074955,
1184 11.572048533398757, 12.742093965163013, 11.381874288645637,
1185 12.191870445817015, 11.683156920035426, 11.152442115262197,
1186 11.90303691580457, 11.653292787169159, 11.938615382266098,
1187 16.970641701570223, 16.853602280380002, 17.26240782594733,
1188 16.644655390108507, 17.14310889757499, 16.910935455445955,
1189 17.505678976959697, 17.213498225466388, 2.4162310293553024,
1190 3.494587244462329, 3.5258600986408344, 3.4959806589517095,
1191 3.098390886949687, 3.343454654302911, 3.588847442290287,
1192 4.14614790111827, 5.152948641990529, 7.433696808092598,
1193 9.716311684833672,
1194 };
1195 size_t table_size = sizeof kCostTable / sizeof *kCostTable;
1196 if (tok >= table_size) tok = table_size - 1;
1197 return kCostTable[tok] + nbits;
1198 }
1199
ApplyLZ77_LZ77(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)1200 void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts,
1201 const std::vector<std::vector<Token>>& tokens,
1202 LZ77Params& lz77,
1203 std::vector<std::vector<Token>>& tokens_lz77) {
1204 // TODO(veluca): tune heuristics here.
1205 SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
1206 float bit_decrease = 0;
1207 size_t total_symbols = 0;
1208 tokens_lz77.resize(tokens.size());
1209 HybridUintConfig uint_config;
1210 std::vector<float> sym_cost;
1211 for (size_t stream = 0; stream < tokens.size(); stream++) {
1212 size_t distance_multiplier =
1213 params.image_widths.size() > stream ? params.image_widths[stream] : 0;
1214 const auto& in = tokens[stream];
1215 auto& out = tokens_lz77[stream];
1216 total_symbols += in.size();
1217 // Cumulative sum of bit costs.
1218 sym_cost.resize(in.size() + 1);
1219 for (size_t i = 0; i < in.size(); i++) {
1220 uint32_t tok, nbits, unused_bits;
1221 uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
1222 sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
1223 }
1224
1225 out.reserve(in.size());
1226 size_t max_distance = in.size();
1227 size_t min_length = lz77.min_length;
1228 JXL_ASSERT(min_length >= 3);
1229 size_t max_length = in.size();
1230
1231 // Use next power of two as window size.
1232 size_t window_size = 1;
1233 while (window_size < max_distance && window_size < kWindowSize) {
1234 window_size <<= 1;
1235 }
1236
1237 HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
1238 distance_multiplier);
1239 size_t len, dist_symbol;
1240
1241 const size_t max_lazy_match_len = 256; // 0 to disable lazy matching
1242
1243 // Whether the next symbol was already updated (to test lazy matching)
1244 bool already_updated = false;
1245 for (size_t i = 0; i < in.size(); i++) {
1246 out.push_back(in[i]);
1247 if (!already_updated) chain.Update(i);
1248 already_updated = false;
1249 chain.FindMatch(i, max_distance, &dist_symbol, &len);
1250 if (len >= min_length) {
1251 if (len < max_lazy_match_len && i + 1 < in.size()) {
1252 // Try length at next symbol lazy matching
1253 chain.Update(i + 1);
1254 already_updated = true;
1255 size_t len2, dist_symbol2;
1256 chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
1257 if (len2 > len) {
1258 // Use the lazy match. Add literal, and use the next length starting
1259 // from the next byte.
1260 ++i;
1261 already_updated = false;
1262 len = len2;
1263 dist_symbol = dist_symbol2;
1264 out.push_back(in[i]);
1265 }
1266 }
1267
1268 float cost = sym_cost[i + len] - sym_cost[i];
1269 size_t lz77_len = len - lz77.min_length;
1270 float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
1271 sce.AddSymbolCost(out.back().context);
1272
1273 if (lz77_cost <= cost) {
1274 out.back().value = len - min_length;
1275 out.back().is_lz77_length = true;
1276 out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
1277 bit_decrease += cost - lz77_cost;
1278 } else {
1279 // LZ77 match ignored, and symbol already pushed. Push all other
1280 // symbols and skip.
1281 for (size_t j = 1; j < len; j++) {
1282 out.push_back(in[i + j]);
1283 }
1284 }
1285
1286 if (already_updated) {
1287 chain.Update(i + 2, len - 2);
1288 already_updated = false;
1289 } else {
1290 chain.Update(i + 1, len - 1);
1291 }
1292 i += len - 1;
1293 } else {
1294 // Literal, already pushed
1295 }
1296 }
1297 }
1298
1299 if (bit_decrease > total_symbols * 0.2 + 16) {
1300 lz77.enabled = true;
1301 }
1302 }
1303
ApplyLZ77_Optimal(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)1304 void ApplyLZ77_Optimal(const HistogramParams& params, size_t num_contexts,
1305 const std::vector<std::vector<Token>>& tokens,
1306 LZ77Params& lz77,
1307 std::vector<std::vector<Token>>& tokens_lz77) {
1308 std::vector<std::vector<Token>> tokens_for_cost_estimate;
1309 ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_for_cost_estimate);
1310 // If greedy-LZ77 does not give better compression than no-lz77, no reason to
1311 // run the optimal matching.
1312 if (!lz77.enabled) return;
1313 SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
1314 tokens_for_cost_estimate, lz77);
1315 tokens_lz77.resize(tokens.size());
1316 HybridUintConfig uint_config;
1317 std::vector<float> sym_cost;
1318 std::vector<uint32_t> dist_symbols;
1319 for (size_t stream = 0; stream < tokens.size(); stream++) {
1320 size_t distance_multiplier =
1321 params.image_widths.size() > stream ? params.image_widths[stream] : 0;
1322 const auto& in = tokens[stream];
1323 auto& out = tokens_lz77[stream];
1324 // Cumulative sum of bit costs.
1325 sym_cost.resize(in.size() + 1);
1326 for (size_t i = 0; i < in.size(); i++) {
1327 uint32_t tok, nbits, unused_bits;
1328 uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
1329 sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
1330 }
1331
1332 out.reserve(in.size());
1333 size_t max_distance = in.size();
1334 size_t min_length = lz77.min_length;
1335 JXL_ASSERT(min_length >= 3);
1336 size_t max_length = in.size();
1337
1338 // Use next power of two as window size.
1339 size_t window_size = 1;
1340 while (window_size < max_distance && window_size < kWindowSize) {
1341 window_size <<= 1;
1342 }
1343
1344 HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
1345 distance_multiplier);
1346
1347 struct MatchInfo {
1348 uint32_t len;
1349 uint32_t dist_symbol;
1350 uint32_t ctx;
1351 float total_cost = std::numeric_limits<float>::max();
1352 };
1353 // Total cost to encode the first N symbols.
1354 std::vector<MatchInfo> prefix_costs(in.size() + 1);
1355 prefix_costs[0].total_cost = 0;
1356
1357 size_t rle_length = 0;
1358 size_t skip_lz77 = 0;
1359 for (size_t i = 0; i < in.size(); i++) {
1360 chain.Update(i);
1361 float lit_cost =
1362 prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
1363 if (prefix_costs[i + 1].total_cost > lit_cost) {
1364 prefix_costs[i + 1].dist_symbol = 0;
1365 prefix_costs[i + 1].len = 1;
1366 prefix_costs[i + 1].ctx = in[i].context;
1367 prefix_costs[i + 1].total_cost = lit_cost;
1368 }
1369 if (skip_lz77 > 0) {
1370 skip_lz77--;
1371 continue;
1372 }
1373 dist_symbols.clear();
1374 chain.FindMatches(i, max_distance,
1375 [&dist_symbols](size_t len, size_t dist_symbol) {
1376 if (dist_symbols.size() <= len) {
1377 dist_symbols.resize(len + 1, dist_symbol);
1378 }
1379 if (dist_symbol < dist_symbols[len]) {
1380 dist_symbols[len] = dist_symbol;
1381 }
1382 });
1383 if (dist_symbols.size() <= min_length) continue;
1384 {
1385 size_t best_cost = dist_symbols.back();
1386 for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
1387 if (dist_symbols[j] < best_cost) {
1388 best_cost = dist_symbols[j];
1389 }
1390 dist_symbols[j] = best_cost;
1391 }
1392 }
1393 for (size_t j = min_length; j < dist_symbols.size(); j++) {
1394 // Cost model that uses results from lazy LZ77.
1395 float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
1396 sce.DistCost(dist_symbols[j], lz77);
1397 float cost = prefix_costs[i].total_cost + lz77_cost;
1398 if (prefix_costs[i + j].total_cost > cost) {
1399 prefix_costs[i + j].len = j;
1400 prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
1401 prefix_costs[i + j].ctx = in[i].context;
1402 prefix_costs[i + j].total_cost = cost;
1403 }
1404 }
1405 // We are in a RLE sequence: skip all the symbols except the first 8 and
1406 // the last 8. This avoid quadratic costs for sequences with long runs of
1407 // the same symbol.
1408 if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
1409 (dist_symbols.back() == 1 && distance_multiplier != 0)) {
1410 rle_length++;
1411 } else {
1412 rle_length = 0;
1413 }
1414 if (rle_length >= 8 && dist_symbols.size() > 9) {
1415 skip_lz77 = dist_symbols.size() - 10;
1416 rle_length = 0;
1417 }
1418 }
1419 size_t pos = in.size();
1420 while (pos > 0) {
1421 bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
1422 if (is_lz77_length) {
1423 size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
1424 out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
1425 }
1426 size_t val = is_lz77_length ? prefix_costs[pos].len - min_length
1427 : in[pos - 1].value;
1428 out.emplace_back(prefix_costs[pos].ctx, val);
1429 out.back().is_lz77_length = is_lz77_length;
1430 pos -= prefix_costs[pos].len;
1431 }
1432 std::reverse(out.begin(), out.end());
1433 }
1434 }
1435
ApplyLZ77(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)1436 void ApplyLZ77(const HistogramParams& params, size_t num_contexts,
1437 const std::vector<std::vector<Token>>& tokens, LZ77Params& lz77,
1438 std::vector<std::vector<Token>>& tokens_lz77) {
1439 lz77.enabled = false;
1440 if (params.force_huffman) {
1441 lz77.min_symbol = std::min(PREFIX_MAX_ALPHABET_SIZE - 32, 512);
1442 } else {
1443 lz77.min_symbol = 224;
1444 }
1445 if (params.lz77_method == HistogramParams::LZ77Method::kNone) {
1446 return;
1447 } else if (params.lz77_method == HistogramParams::LZ77Method::kRLE) {
1448 ApplyLZ77_RLE(params, num_contexts, tokens, lz77, tokens_lz77);
1449 } else if (params.lz77_method == HistogramParams::LZ77Method::kLZ77) {
1450 ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_lz77);
1451 } else if (params.lz77_method == HistogramParams::LZ77Method::kOptimal) {
1452 ApplyLZ77_Optimal(params, num_contexts, tokens, lz77, tokens_lz77);
1453 } else {
1454 JXL_ABORT("Not implemented");
1455 }
1456 }
1457 } // namespace
1458
BuildAndEncodeHistograms(const HistogramParams & params,size_t num_contexts,std::vector<std::vector<Token>> & tokens,EntropyEncodingData * codes,std::vector<uint8_t> * context_map,BitWriter * writer,size_t layer,AuxOut * aux_out)1459 size_t BuildAndEncodeHistograms(const HistogramParams& params,
1460 size_t num_contexts,
1461 std::vector<std::vector<Token>>& tokens,
1462 EntropyEncodingData* codes,
1463 std::vector<uint8_t>* context_map,
1464 BitWriter* writer, size_t layer,
1465 AuxOut* aux_out) {
1466 size_t total_bits = 0;
1467 codes->lz77.nonserialized_distance_context = num_contexts;
1468 std::vector<std::vector<Token>> tokens_lz77;
1469 ApplyLZ77(params, num_contexts, tokens, codes->lz77, tokens_lz77);
1470 if (ans_fuzzer_friendly_) {
1471 codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0);
1472 codes->lz77.min_symbol = 2048;
1473 }
1474
1475 const size_t max_contexts = std::min(num_contexts, kClustersLimit);
1476 BitWriter::Allotment allotment(writer,
1477 128 + num_contexts * 40 + max_contexts * 96);
1478 if (writer) {
1479 JXL_CHECK(Bundle::Write(codes->lz77, writer, layer, aux_out));
1480 } else {
1481 size_t ebits, bits;
1482 JXL_CHECK(Bundle::CanEncode(codes->lz77, &ebits, &bits));
1483 total_bits += bits;
1484 }
1485 if (codes->lz77.enabled) {
1486 if (writer) {
1487 size_t b = writer->BitsWritten();
1488 EncodeUintConfig(codes->lz77.length_uint_config, writer,
1489 /*log_alpha_size=*/8);
1490 total_bits += writer->BitsWritten() - b;
1491 } else {
1492 SizeWriter size_writer;
1493 EncodeUintConfig(codes->lz77.length_uint_config, &size_writer,
1494 /*log_alpha_size=*/8);
1495 total_bits += size_writer.size;
1496 }
1497 num_contexts += 1;
1498 tokens = std::move(tokens_lz77);
1499 }
1500 size_t total_tokens = 0;
1501 // Build histograms.
1502 HistogramBuilder builder(num_contexts);
1503 HybridUintConfig uint_config; // Default config for clustering.
1504 // Unless we are using the kContextMap histogram option.
1505 if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
1506 uint_config = HybridUintConfig(2, 0, 1);
1507 }
1508 if (ans_fuzzer_friendly_) {
1509 uint_config = HybridUintConfig(10, 0, 0);
1510 }
1511 for (size_t i = 0; i < tokens.size(); ++i) {
1512 if (codes->lz77.enabled) {
1513 for (size_t j = 0; j < tokens[i].size(); ++j) {
1514 const Token& token = tokens[i][j];
1515 total_tokens++;
1516 uint32_t tok, nbits, bits;
1517 (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config)
1518 .Encode(token.value, &tok, &nbits, &bits);
1519 tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
1520 builder.VisitSymbol(tok, token.context);
1521 }
1522 } else if (num_contexts == 1) {
1523 for (size_t j = 0; j < tokens[i].size(); ++j) {
1524 const Token& token = tokens[i][j];
1525 total_tokens++;
1526 uint32_t tok, nbits, bits;
1527 uint_config.Encode(token.value, &tok, &nbits, &bits);
1528 builder.VisitSymbol(tok, /*token.context=*/0);
1529 }
1530 } else {
1531 for (size_t j = 0; j < tokens[i].size(); ++j) {
1532 const Token& token = tokens[i][j];
1533 total_tokens++;
1534 uint32_t tok, nbits, bits;
1535 uint_config.Encode(token.value, &tok, &nbits, &bits);
1536 builder.VisitSymbol(tok, token.context);
1537 }
1538 }
1539 }
1540
1541 bool use_prefix_code =
1542 params.force_huffman || total_tokens < 100 ||
1543 params.clustering == HistogramParams::ClusteringType::kFastest ||
1544 ans_fuzzer_friendly_;
1545 if (!use_prefix_code) {
1546 bool all_singleton = true;
1547 for (size_t i = 0; i < num_contexts; i++) {
1548 if (builder.Histo(i).ShannonEntropy() >= 1e-5) {
1549 all_singleton = false;
1550 }
1551 }
1552 if (all_singleton) {
1553 use_prefix_code = true;
1554 }
1555 }
1556
1557 // Encode histograms.
1558 total_bits += builder.BuildAndStoreEntropyCodes(params, tokens, codes,
1559 context_map, use_prefix_code,
1560 writer, layer, aux_out);
1561 allotment.FinishedHistogram(writer);
1562 ReclaimAndCharge(writer, &allotment, layer, aux_out);
1563
1564 if (aux_out != nullptr) {
1565 aux_out->layers[layer].num_clustered_histograms +=
1566 codes->encoding_info.size();
1567 }
1568 return total_bits;
1569 }
1570
WriteTokens(const std::vector<Token> & tokens,const EntropyEncodingData & codes,const std::vector<uint8_t> & context_map,BitWriter * writer)1571 size_t WriteTokens(const std::vector<Token>& tokens,
1572 const EntropyEncodingData& codes,
1573 const std::vector<uint8_t>& context_map, BitWriter* writer) {
1574 size_t num_extra_bits = 0;
1575 if (codes.use_prefix_code) {
1576 for (size_t i = 0; i < tokens.size(); i++) {
1577 uint32_t tok, nbits, bits;
1578 const Token& token = tokens[i];
1579 size_t histo = context_map[token.context];
1580 (token.is_lz77_length ? codes.lz77.length_uint_config
1581 : codes.uint_config[histo])
1582 .Encode(token.value, &tok, &nbits, &bits);
1583 tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1584 // Combine two calls to the BitWriter. Equivalent to:
1585 // writer->Write(codes.encoding_info[histo][tok].depth,
1586 // codes.encoding_info[histo][tok].bits);
1587 // writer->Write(nbits, bits);
1588 uint64_t data = codes.encoding_info[histo][tok].bits;
1589 data |= bits << codes.encoding_info[histo][tok].depth;
1590 writer->Write(codes.encoding_info[histo][tok].depth + nbits, data);
1591 num_extra_bits += nbits;
1592 }
1593 return num_extra_bits;
1594 }
1595 std::vector<uint64_t> out;
1596 std::vector<uint8_t> out_nbits;
1597 out.reserve(tokens.size());
1598 out_nbits.reserve(tokens.size());
1599 uint64_t allbits = 0;
1600 size_t numallbits = 0;
1601 // Writes in *reversed* order.
1602 auto addbits = [&](size_t bits, size_t nbits) {
1603 if (JXL_UNLIKELY(nbits)) {
1604 JXL_DASSERT(bits >> nbits == 0);
1605 if (JXL_UNLIKELY(numallbits + nbits > BitWriter::kMaxBitsPerCall)) {
1606 out.push_back(allbits);
1607 out_nbits.push_back(numallbits);
1608 numallbits = allbits = 0;
1609 }
1610 allbits <<= nbits;
1611 allbits |= bits;
1612 numallbits += nbits;
1613 }
1614 };
1615 const int end = tokens.size();
1616 ANSCoder ans;
1617 if (codes.lz77.enabled || context_map.size() > 1) {
1618 for (int i = end - 1; i >= 0; --i) {
1619 const Token token = tokens[i];
1620 const uint8_t histo = context_map[token.context];
1621 uint32_t tok, nbits, bits;
1622 (token.is_lz77_length ? codes.lz77.length_uint_config
1623 : codes.uint_config[histo])
1624 .Encode(tokens[i].value, &tok, &nbits, &bits);
1625 tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1626 const ANSEncSymbolInfo& info = codes.encoding_info[histo][tok];
1627 // Extra bits first as this is reversed.
1628 addbits(bits, nbits);
1629 num_extra_bits += nbits;
1630 uint8_t ans_nbits = 0;
1631 uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
1632 addbits(ans_bits, ans_nbits);
1633 }
1634 } else {
1635 for (int i = end - 1; i >= 0; --i) {
1636 uint32_t tok, nbits, bits;
1637 codes.uint_config[0].Encode(tokens[i].value, &tok, &nbits, &bits);
1638 const ANSEncSymbolInfo& info = codes.encoding_info[0][tok];
1639 // Extra bits first as this is reversed.
1640 addbits(bits, nbits);
1641 num_extra_bits += nbits;
1642 uint8_t ans_nbits = 0;
1643 uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
1644 addbits(ans_bits, ans_nbits);
1645 }
1646 }
1647 const uint32_t state = ans.GetState();
1648 writer->Write(32, state);
1649 writer->Write(numallbits, allbits);
1650 for (int i = out.size(); i > 0; --i) {
1651 writer->Write(out_nbits[i - 1], out[i - 1]);
1652 }
1653 return num_extra_bits;
1654 }
1655
WriteTokens(const std::vector<Token> & tokens,const EntropyEncodingData & codes,const std::vector<uint8_t> & context_map,BitWriter * writer,size_t layer,AuxOut * aux_out)1656 void WriteTokens(const std::vector<Token>& tokens,
1657 const EntropyEncodingData& codes,
1658 const std::vector<uint8_t>& context_map, BitWriter* writer,
1659 size_t layer, AuxOut* aux_out) {
1660 BitWriter::Allotment allotment(writer, 32 * tokens.size() + 32 * 1024 * 4);
1661 size_t num_extra_bits = WriteTokens(tokens, codes, context_map, writer);
1662 ReclaimAndCharge(writer, &allotment, layer, aux_out);
1663 if (aux_out != nullptr) {
1664 aux_out->layers[layer].extra_bits += num_extra_bits;
1665 }
1666 }
1667
// Sets the file-level fuzzer-friendly flag. The assignment is only compiled in
// when JXL_IS_DEBUG_BUILD is set, as a guard against accidental / malicious
// changes; otherwise the call has no effect.
void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) {
#if JXL_IS_DEBUG_BUILD
  ans_fuzzer_friendly_ = ans_fuzzer_friendly;
#else
  (void)ans_fuzzer_friendly;
#endif
}
1673 } // namespace jxl
1674