1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/enc_ans.h"
7 
8 #include <stdint.h>
9 
10 #include <algorithm>
11 #include <array>
12 #include <cmath>
13 #include <limits>
14 #include <numeric>
15 #include <type_traits>
16 #include <unordered_map>
17 #include <utility>
18 #include <vector>
19 
20 #include "lib/jxl/ans_common.h"
21 #include "lib/jxl/aux_out.h"
22 #include "lib/jxl/aux_out_fwd.h"
23 #include "lib/jxl/base/bits.h"
24 #include "lib/jxl/dec_ans.h"
25 #include "lib/jxl/enc_cluster.h"
26 #include "lib/jxl/enc_context_map.h"
27 #include "lib/jxl/enc_huffman.h"
28 #include "lib/jxl/fast_math-inl.h"
29 #include "lib/jxl/fields.h"
30 
31 namespace jxl {
32 
33 namespace {
34 
// When true, HistogramBuilder::BuildAndStoreEntropyCodes skips histogram
// clustering and hybrid-uint config selection and instead uses a single
// histogram with a fixed HybridUintConfig.
bool ans_fuzzer_friendly_ = false;

// Capacity of the `symbols` arrays filled by NormalizeCounts: at most this
// many nonzero symbol indices are recorded there.
static const int kMaxNumSymbolsForSmallCode = 4;
38 
ANSBuildInfoTable(const ANSHistBin * counts,const AliasTable::Entry * table,size_t alphabet_size,size_t log_alpha_size,ANSEncSymbolInfo * info)39 void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table,
40                        size_t alphabet_size, size_t log_alpha_size,
41                        ANSEncSymbolInfo* info) {
42   size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size;
43   size_t entry_size_minus_1 = (1 << log_entry_size) - 1;
44   // create valid alias table for empty streams.
45   for (size_t s = 0; s < std::max<size_t>(1, alphabet_size); ++s) {
46     const ANSHistBin freq = s == alphabet_size ? ANS_TAB_SIZE : counts[s];
47     info[s].freq_ = static_cast<uint16_t>(freq);
48 #ifdef USE_MULT_BY_RECIPROCAL
49     if (freq != 0) {
50       info[s].ifreq_ =
51           ((1ull << RECIPROCAL_PRECISION) + info[s].freq_ - 1) / info[s].freq_;
52     } else {
53       info[s].ifreq_ = 1;  // shouldn't matter (symbol shouldn't occur), but...
54     }
55 #endif
56     info[s].reverse_map_.resize(freq);
57   }
58   for (int i = 0; i < ANS_TAB_SIZE; i++) {
59     AliasTable::Symbol s =
60         AliasTable::Lookup(table, i, log_entry_size, entry_size_minus_1);
61     info[s.value].reverse_map_[s.offset] = i;
62   }
63 }
64 
EstimateDataBits(const ANSHistBin * histogram,const ANSHistBin * counts,size_t len)65 float EstimateDataBits(const ANSHistBin* histogram, const ANSHistBin* counts,
66                        size_t len) {
67   float sum = 0.0f;
68   int total_histogram = 0;
69   int total_counts = 0;
70   for (size_t i = 0; i < len; ++i) {
71     total_histogram += histogram[i];
72     total_counts += counts[i];
73     if (histogram[i] > 0) {
74       JXL_ASSERT(counts[i] > 0);
75       // += histogram[i] * -log(counts[i]/total_counts)
76       sum += histogram[i] *
77              std::max(0.0f, ANS_LOG_TAB_SIZE - FastLog2f(counts[i]));
78     }
79   }
80   if (total_histogram > 0) {
81     JXL_ASSERT(total_counts == ANS_TAB_SIZE);
82   }
83   return sum;
84 }
85 
EstimateDataBitsFlat(const ANSHistBin * histogram,size_t len)86 float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) {
87   const float flat_bits = std::max(FastLog2f(len), 0.0f);
88   int total_histogram = 0;
89   for (size_t i = 0; i < len; ++i) {
90     total_histogram += histogram[i];
91   }
92   return total_histogram * flat_bits;
93 }
94 
// Static Huffman code for encoding logcounts. The last symbol is used as RLE
// sequence.
// For a logcount value v, EncodeCounts passes kLogCountBitLengths[v] as the
// bit count and kLogCountSymbols[v] as the bit pattern to the writer. The
// entry at index ANS_LOG_TAB_SIZE + 1 is the one EncodeCounts emits to start
// an RLE run.
static const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = {
    5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
};
static const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = {
    17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
};
103 
104 // Returns the difference between largest count that can be represented and is
105 // smaller than "count" and smallest representable count larger than "count".
SmallestIncrement(uint32_t count,uint32_t shift)106 static int SmallestIncrement(uint32_t count, uint32_t shift) {
107   int bits = count == 0 ? -1 : FloorLog2Nonzero(count);
108   int drop_bits = bits - GetPopulationCountPrecision(bits, shift);
109   return drop_bits < 0 ? 1 : (1 << drop_bits);
110 }
111 
// Turns the real-valued `targets[0..max_symbol)` into integer `counts` whose
// sum is exactly `table_size`.
//
// Entries with a target of zero (or less) are left untouched in `counts`.
// Targets in (0, 1) become a count of 1. Targets >= 1 are scaled by
// `discount_ratio`, truncated, clamped to [1, table_size - 1] and rounded to a
// multiple of SmallestIncrement(count, shift). The entry with the largest
// floor(log2(count)) among the latter (the first such entry on ties) then
// absorbs whatever is needed to make the total equal `table_size`; its index
// is returned through `omit_pos`.
//
// The template flag selects the rounding rule: when false each count is
// rounded relative to its own target; when true it is rounded relative to the
// running difference between the exact and the rounded sums.
//
// Returns true iff the absorbing entry stays strictly positive.
template <bool minimize_error_of_sum>
bool RebalanceHistogram(const float* targets, int max_symbol, int table_size,
                        uint32_t shift, int* omit_pos, ANSHistBin* counts) {
  int sum = 0;                 // sum of the integer counts assigned so far
  float sum_nonrounded = 0.0;  // sum of the corresponding exact targets
  int remainder_pos = 0;  // if all of them are handled in first loop
  int remainder_log = -1;
  // Pass 1: symbols that would round to zero get the minimum count of 1.
  for (int n = 0; n < max_symbol; ++n) {
    if (targets[n] > 0 && targets[n] < 1.0f) {
      counts[n] = 1;
      sum_nonrounded += targets[n];
      sum += counts[n];
    }
  }
  // Pass 1 handed out `sum` slots for only `sum_nonrounded` worth of targets;
  // the remaining targets are scaled by this ratio to fit in what is left.
  const float discount_ratio =
      (table_size - sum) / (table_size - sum_nonrounded);
  JXL_ASSERT(discount_ratio > 0);
  JXL_ASSERT(discount_ratio <= 1.0f);
  // Invariant for minimize_error_of_sum == true:
  // abs(sum - sum_nonrounded)
  //   <= SmallestIncrement(max(targets[])) + max_symbol
  // Pass 2: all symbols whose target is at least 1.
  for (int n = 0; n < max_symbol; ++n) {
    if (targets[n] >= 1.0f) {
      sum_nonrounded += targets[n];
      counts[n] =
          static_cast<ANSHistBin>(targets[n] * discount_ratio);  // truncate
      if (counts[n] == 0) counts[n] = 1;
      if (counts[n] == table_size) counts[n] = table_size - 1;
      // Round the count to the closest nonzero multiple of SmallestIncrement
      // (when minimize_error_of_sum is false) or one of two closest so as to
      // keep the sum as close as possible to sum_nonrounded.
      int inc = SmallestIncrement(counts[n], shift);
      // Clear the low bits, i.e. round down to a multiple of `inc`.
      counts[n] -= counts[n] & (inc - 1);
      // TODO(robryk): Should we rescale targets[n]?
      const float target =
          minimize_error_of_sum ? (sum_nonrounded - sum) : targets[n];
      // Step up by one increment if rounding down produced zero, or if the
      // reference value is above the midpoint and stepping up stays below
      // table_size. Note that `inc / 2` is an integer division.
      if (counts[n] == 0 ||
          (target > counts[n] + inc / 2 && counts[n] + inc < table_size)) {
        counts[n] += inc;
      }
      sum += counts[n];
      // Track the first entry with the largest log2 magnitude.
      const int count_log = FloorLog2Nonzero(static_cast<uint32_t>(counts[n]));
      if (count_log > remainder_log) {
        remainder_pos = n;
        remainder_log = count_log;
      }
    }
  }
  // NOTE(review): remainder_pos starts at 0 and is only ever assigned a loop
  // index, so this assertion cannot fail as written.
  JXL_ASSERT(remainder_pos != -1);
  // NOTE: This is the only place where counts could go negative. We could
  // detect that, return false and make ANSHistBin uint32_t.
  counts[remainder_pos] -= sum - table_size;
  *omit_pos = remainder_pos;
  return counts[remainder_pos] > 0;
}
167 
NormalizeCounts(ANSHistBin * counts,int * omit_pos,const int length,const int precision_bits,uint32_t shift,int * num_symbols,int * symbols)168 Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length,
169                        const int precision_bits, uint32_t shift,
170                        int* num_symbols, int* symbols) {
171   const int32_t table_size = 1 << precision_bits;  // target sum / table size
172   uint64_t total = 0;
173   int max_symbol = 0;
174   int symbol_count = 0;
175   for (int n = 0; n < length; ++n) {
176     total += counts[n];
177     if (counts[n] > 0) {
178       if (symbol_count < kMaxNumSymbolsForSmallCode) {
179         symbols[symbol_count] = n;
180       }
181       ++symbol_count;
182       max_symbol = n + 1;
183     }
184   }
185   *num_symbols = symbol_count;
186   if (symbol_count == 0) {
187     return true;
188   }
189   if (symbol_count == 1) {
190     counts[symbols[0]] = table_size;
191     return true;
192   }
193   if (symbol_count > table_size)
194     return JXL_FAILURE("Too many entries in an ANS histogram");
195 
196   const float norm = 1.f * table_size / total;
197   std::vector<float> targets(max_symbol);
198   for (size_t n = 0; n < targets.size(); ++n) {
199     targets[n] = norm * counts[n];
200   }
201   if (!RebalanceHistogram<false>(&targets[0], max_symbol, table_size, shift,
202                                  omit_pos, counts)) {
203     // Use an alternative rebalancing mechanism if the one above failed
204     // to create a histogram that is positive wherever the original one was.
205     if (!RebalanceHistogram<true>(&targets[0], max_symbol, table_size, shift,
206                                   omit_pos, counts)) {
207       return JXL_FAILURE("Logic error: couldn't rebalance a histogram");
208     }
209   }
210   return true;
211 }
212 
// Writer stand-in that discards the data and only tallies how many bits
// would have been written.
struct SizeWriter {
  size_t size = 0;  // total number of bits requested so far

  // Adds `num` bits to the tally; the value to be written is ignored.
  void Write(size_t num, size_t /*bits*/) { size += num; }
};
217 
218 template <typename Writer>
StoreVarLenUint8(size_t n,Writer * writer)219 void StoreVarLenUint8(size_t n, Writer* writer) {
220   JXL_DASSERT(n <= 255);
221   if (n == 0) {
222     writer->Write(1, 0);
223   } else {
224     writer->Write(1, 1);
225     size_t nbits = FloorLog2Nonzero(n);
226     writer->Write(3, nbits);
227     writer->Write(nbits, n - (1ULL << nbits));
228   }
229 }
230 
231 template <typename Writer>
StoreVarLenUint16(size_t n,Writer * writer)232 void StoreVarLenUint16(size_t n, Writer* writer) {
233   JXL_DASSERT(n <= 65535);
234   if (n == 0) {
235     writer->Write(1, 0);
236   } else {
237     writer->Write(1, 1);
238     size_t nbits = FloorLog2Nonzero(n);
239     writer->Write(4, nbits);
240     writer->Write(nbits, n - (1ULL << nbits));
241   }
242 }
243 
// Serializes a normalized histogram through `writer`, which may be any type
// offering Write(number_of_bits, value); in this file it is called with both
// SizeWriter and BitWriter.
//
//   counts        normalized counts, indexed by symbol
//   alphabet_size number of entries of `counts` that are considered
//   omit_pos      index whose exact count is not written (only its logcount)
//   num_symbols   number of nonzero counts
//   shift         precision parameter forwarded to GetPopulationCountPrecision
//   symbols       indices of the first nonzero symbols; only read when
//                 num_symbols is 1 or 2
//
// Returns false only if the encoded length minus 3 does not fit in 8 bits; in
// that case the output has still been written, with 255 in place of that
// value.
template <typename Writer>
bool EncodeCounts(const ANSHistBin* counts, const int alphabet_size,
                  const int omit_pos, const int num_symbols, uint32_t shift,
                  const int* symbols, Writer* writer) {
  bool ok = true;
  if (num_symbols <= 2) {
    // Small tree marker to encode 1-2 symbols.
    writer->Write(1, 1);
    if (num_symbols == 0) {
      // No symbols at all: written like a single-symbol histogram whose
      // symbol index is 0.
      writer->Write(1, 0);
      StoreVarLenUint8(0, writer);
    } else {
      // One bit for the symbol count (0 -> one symbol, 1 -> two symbols),
      // followed by the symbol indices.
      writer->Write(1, num_symbols - 1);
      for (int i = 0; i < num_symbols; ++i) {
        StoreVarLenUint8(symbols[i], writer);
      }
    }
    if (num_symbols == 2) {
      // Only the first symbol's count is written.
      writer->Write(ANS_LOG_TAB_SIZE, counts[symbols[0]]);
    }
  } else {
    // Mark non-small tree.
    writer->Write(1, 0);
    // Mark non-flat histogram.
    writer->Write(1, 0);

    // Precompute sequences for RLE encoding. Contains the number of identical
    // values starting at a given index. Only contains the value at the first
    // element of the series.
    std::vector<uint32_t> same(alphabet_size, 0);
    int last = 0;
    for (int i = 1; i < alphabet_size; i++) {
      // Store the sequence length once different symbol reached, or we're at
      // the end, or the length is longer than we can encode, or we are at
      // the omit_pos. We don't support including the omit_pos in an RLE
      // sequence because this value may use a different amount of log2 bits
      // than standard, it is too complex to handle in the decoder.
      if (counts[i] != counts[last] || i + 1 == alphabet_size ||
          (i - last) >= 255 || i == omit_pos || i == omit_pos + 1) {
        same[last] = (i - last);
        last = i + 1;
      }
    }

    // `length` becomes one past the last index that is either omit_pos or
    // has a nonzero count. logcounts[i] is floor(log2(counts[i])) + 1 for
    // nonzero counts and stays 0 otherwise.
    int length = 0;
    std::vector<int> logcounts(alphabet_size);
    int omit_log = 0;
    for (int i = 0; i < alphabet_size; ++i) {
      JXL_ASSERT(counts[i] <= ANS_TAB_SIZE);
      JXL_ASSERT(counts[i] >= 0);
      if (i == omit_pos) {
        length = i + 1;
      } else if (counts[i] > 0) {
        logcounts[i] = FloorLog2Nonzero(static_cast<uint32_t>(counts[i])) + 1;
        length = i + 1;
        // The logcount stored for omit_pos must be strictly greater than
        // every logcount before it and at least as large as every logcount
        // after it.
        if (i < omit_pos) {
          omit_log = std::max(omit_log, logcounts[i] + 1);
        } else {
          omit_log = std::max(omit_log, logcounts[i]);
        }
      }
    }
    logcounts[omit_pos] = omit_log;

    // Elias gamma-like code for shift. Only difference is that if the number
    // of bits to be encoded is equal to FloorLog2(ANS_LOG_TAB_SIZE+1), we skip
    // the terminating 0 in unary coding.
    int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
    int log = FloorLog2Nonzero(shift + 1);
    writer->Write(log, (1 << log) - 1);
    if (log != upper_bound_log) writer->Write(1, 0);
    writer->Write(log, ((1 << log) - 1) & (shift + 1));

    // Since num_symbols >= 3, we know that length >= 3, therefore we encode
    // length - 3.
    if (length - 3 > 255) {
      // Pretend that everything is OK, but complain about correctness later.
      StoreVarLenUint8(255, writer);
      ok = false;
    } else {
      StoreVarLenUint8(length - 3, writer);
    }

    // The logcount values are encoded with a static Huffman code.
    static const size_t kMinReps = 4;
    size_t rep = ANS_LOG_TAB_SIZE + 1;  // index of the RLE entry in the tables
    // First pass: one logcount code per position, or an RLE marker plus run
    // length when the run recorded at the previous index exceeds kMinReps.
    for (int i = 0; i < length; ++i) {
      if (i > 0 && same[i - 1] > kMinReps) {
        // Encode the RLE symbol and skip the repeated ones.
        writer->Write(kLogCountBitLengths[rep], kLogCountSymbols[rep]);
        StoreVarLenUint8(same[i - 1] - kMinReps - 1, writer);
        // Together with the loop's ++i this moves i forward by
        // same[i - 1] - 1 positions.
        i += same[i - 1] - 2;
        continue;
      }
      writer->Write(kLogCountBitLengths[logcounts[i]],
                    kLogCountSymbols[logcounts[i]]);
    }
    // Second pass: the bits of each count below its leading one, at the
    // precision allowed by `shift`. Positions covered by RLE, omit_pos and
    // logcounts of 0 or 1 carry no such bits.
    for (int i = 0; i < length; ++i) {
      if (i > 0 && same[i - 1] > kMinReps) {
        // Skip symbols encoded by RLE.
        i += same[i - 1] - 2;
        continue;
      }
      if (logcounts[i] > 1 && i != omit_pos) {
        int bitcount = GetPopulationCountPrecision(logcounts[i] - 1, shift);
        int drop_bits = logcounts[i] - 1 - bitcount;
        // The dropped low bits must already be zero, otherwise the count
        // cannot be represented exactly.
        JXL_CHECK((counts[i] & ((1 << drop_bits) - 1)) == 0);
        writer->Write(bitcount, (counts[i] >> drop_bits) - (1 << bitcount));
      }
    }
  }
  return ok;
}
357 
EncodeFlatHistogram(const int alphabet_size,BitWriter * writer)358 void EncodeFlatHistogram(const int alphabet_size, BitWriter* writer) {
359   // Mark non-small tree.
360   writer->Write(1, 0);
361   // Mark uniform histogram.
362   writer->Write(1, 1);
363   JXL_ASSERT(alphabet_size > 0);
364   // Encode alphabet size.
365   StoreVarLenUint8(alphabet_size - 1, writer);
366 }
367 
ComputeHistoAndDataCost(const ANSHistBin * histogram,size_t alphabet_size,uint32_t method)368 float ComputeHistoAndDataCost(const ANSHistBin* histogram, size_t alphabet_size,
369                               uint32_t method) {
370   if (method == 0) {  // Flat code
371     return ANS_LOG_TAB_SIZE + 2 +
372            EstimateDataBitsFlat(histogram, alphabet_size);
373   }
374   // Non-flat: shift = method-1.
375   uint32_t shift = method - 1;
376   std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
377   int omit_pos = 0;
378   int num_symbols;
379   int symbols[kMaxNumSymbolsForSmallCode] = {};
380   JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
381                             ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
382   SizeWriter writer;
383   // Ignore the correctness, no real encoding happens at this stage.
384   (void)EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols, shift,
385                      symbols, &writer);
386   return writer.size +
387          EstimateDataBits(histogram, counts.data(), alphabet_size);
388 }
389 
ComputeBestMethod(const ANSHistBin * histogram,size_t alphabet_size,float * cost,HistogramParams::ANSHistogramStrategy ans_histogram_strategy)390 uint32_t ComputeBestMethod(
391     const ANSHistBin* histogram, size_t alphabet_size, float* cost,
392     HistogramParams::ANSHistogramStrategy ans_histogram_strategy) {
393   size_t method = 0;
394   float fcost = ComputeHistoAndDataCost(histogram, alphabet_size, 0);
395   for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE;
396        ans_histogram_strategy != HistogramParams::ANSHistogramStrategy::kPrecise
397            ? shift += 2
398            : shift++) {
399     float c = ComputeHistoAndDataCost(histogram, alphabet_size, shift + 1);
400     if (c < fcost) {
401       method = shift + 1;
402       fcost = c;
403     } else if (ans_histogram_strategy ==
404                HistogramParams::ANSHistogramStrategy::kFast) {
405       // do not be as precise if estimating cost.
406       break;
407     }
408   }
409   *cost = fcost;
410   return method;
411 }
412 
413 }  // namespace
414 
415 // Returns an estimate of the cost of encoding this histogram and the
416 // corresponding data.
BuildAndStoreANSEncodingData(HistogramParams::ANSHistogramStrategy ans_histogram_strategy,const ANSHistBin * histogram,size_t alphabet_size,size_t log_alpha_size,bool use_prefix_code,ANSEncSymbolInfo * info,BitWriter * writer)417 size_t BuildAndStoreANSEncodingData(
418     HistogramParams::ANSHistogramStrategy ans_histogram_strategy,
419     const ANSHistBin* histogram, size_t alphabet_size, size_t log_alpha_size,
420     bool use_prefix_code, ANSEncSymbolInfo* info, BitWriter* writer) {
421   if (use_prefix_code) {
422     if (alphabet_size <= 1) return 0;
423     std::vector<uint32_t> histo(alphabet_size);
424     for (size_t i = 0; i < alphabet_size; i++) {
425       histo[i] = histogram[i];
426       JXL_CHECK(histogram[i] >= 0);
427     }
428     size_t cost = 0;
429     {
430       std::vector<uint8_t> depths(alphabet_size);
431       std::vector<uint16_t> bits(alphabet_size);
432       BitWriter tmp_writer;
433       BitWriter* w = writer ? writer : &tmp_writer;
434       size_t start = w->BitsWritten();
435       BitWriter::Allotment allotment(
436           w, 8 * alphabet_size + 8);  // safe upper bound
437       BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(),
438                                bits.data(), w);
439       ReclaimAndCharge(w, &allotment, 0, /*aux_out=*/nullptr);
440 
441       for (size_t i = 0; i < alphabet_size; i++) {
442         info[i].bits = depths[i] == 0 ? 0 : bits[i];
443         info[i].depth = depths[i];
444       }
445       cost = w->BitsWritten() - start;
446     }
447     // Estimate data cost.
448     for (size_t i = 0; i < alphabet_size; i++) {
449       cost += histogram[i] * info[i].depth;
450     }
451     return cost;
452   }
453   JXL_ASSERT(alphabet_size <= ANS_TAB_SIZE);
454   // Ensure we ignore trailing zeros in the histogram.
455   if (alphabet_size != 0) {
456     size_t largest_symbol = 0;
457     for (size_t i = 0; i < alphabet_size; i++) {
458       if (histogram[i] != 0) largest_symbol = i;
459     }
460     alphabet_size = largest_symbol + 1;
461   }
462   float cost;
463   uint32_t method = ComputeBestMethod(histogram, alphabet_size, &cost,
464                                       ans_histogram_strategy);
465   JXL_ASSERT(cost >= 0);
466   int num_symbols;
467   int symbols[kMaxNumSymbolsForSmallCode] = {};
468   std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
469   if (!counts.empty()) {
470     size_t sum = 0;
471     for (size_t i = 0; i < counts.size(); i++) {
472       sum += counts[i];
473     }
474     if (sum == 0) {
475       counts[0] = ANS_TAB_SIZE;
476     }
477   }
478   if (method == 0) {
479     counts = CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE);
480     AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
481     InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
482     ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
483     if (writer != nullptr) {
484       EncodeFlatHistogram(alphabet_size, writer);
485     }
486     return cost;
487   }
488   int omit_pos = 0;
489   uint32_t shift = method - 1;
490   JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
491                             ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
492   AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
493   InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
494   ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
495   if (writer != nullptr) {
496     bool ok = EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols,
497                            shift, symbols, writer);
498     (void)ok;
499     JXL_DASSERT(ok);
500   }
501   return cost;
502 }
503 
ANSPopulationCost(const ANSHistBin * data,size_t alphabet_size)504 float ANSPopulationCost(const ANSHistBin* data, size_t alphabet_size) {
505   float c;
506   ComputeBestMethod(data, alphabet_size, &c,
507                     HistogramParams::ANSHistogramStrategy::kFast);
508   return c;
509 }
510 
511 template <typename Writer>
EncodeUintConfig(const HybridUintConfig uint_config,Writer * writer,size_t log_alpha_size)512 void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer,
513                       size_t log_alpha_size) {
514   writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
515                 uint_config.split_exponent);
516   if (uint_config.split_exponent == log_alpha_size) {
517     return;  // msb/lsb don't matter.
518   }
519   size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
520   writer->Write(nbits, uint_config.msb_in_token);
521   nbits = CeilLog2Nonzero(uint_config.split_exponent -
522                           uint_config.msb_in_token + 1);
523   writer->Write(nbits, uint_config.lsb_in_token);
524 }
525 template <typename Writer>
EncodeUintConfigs(const std::vector<HybridUintConfig> & uint_config,Writer * writer,size_t log_alpha_size)526 void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config,
527                        Writer* writer, size_t log_alpha_size) {
528   // TODO(veluca): RLE?
529   for (size_t i = 0; i < uint_config.size(); i++) {
530     EncodeUintConfig(uint_config[i], writer, log_alpha_size);
531   }
532 }
533 template void EncodeUintConfigs(const std::vector<HybridUintConfig>&,
534                                 BitWriter*, size_t);
535 
536 namespace {
537 
ChooseUintConfigs(const HistogramParams & params,const std::vector<std::vector<Token>> & tokens,const std::vector<uint8_t> & context_map,std::vector<Histogram> * clustered_histograms,EntropyEncodingData * codes,size_t * log_alpha_size)538 void ChooseUintConfigs(const HistogramParams& params,
539                        const std::vector<std::vector<Token>>& tokens,
540                        const std::vector<uint8_t>& context_map,
541                        std::vector<Histogram>* clustered_histograms,
542                        EntropyEncodingData* codes, size_t* log_alpha_size) {
543   codes->uint_config.resize(clustered_histograms->size());
544   if (params.uint_method == HistogramParams::HybridUintMethod::kNone) return;
545   if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
546     codes->uint_config.clear();
547     codes->uint_config.resize(clustered_histograms->size(),
548                               HybridUintConfig(2, 0, 1));
549     return;
550   }
551 
552   // Brute-force method that tries a few options.
553   std::vector<HybridUintConfig> configs;
554   if (params.uint_method == HistogramParams::HybridUintMethod::kBest) {
555     configs = {
556         HybridUintConfig(4, 2, 0),  // default
557         HybridUintConfig(4, 1, 0),  // less precise
558         HybridUintConfig(4, 2, 1),  // add sign
559         HybridUintConfig(4, 2, 2),  // add sign+parity
560         HybridUintConfig(4, 1, 2),  // add parity but less msb
561         // Same as above, but more direct coding.
562         HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0),
563         HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2),
564         HybridUintConfig(5, 1, 2),
565         // Same as above, but less direct coding.
566         HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0),
567         HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2),
568         // For near-lossless.
569         HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4),
570         HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5),
571         HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0),
572         // Other
573         HybridUintConfig(0, 0, 0),   // varlenuint
574         HybridUintConfig(2, 0, 1),   // works well for ctx map
575         HybridUintConfig(7, 0, 0),   // direct coding
576         HybridUintConfig(8, 0, 0),   // direct coding
577         HybridUintConfig(9, 0, 0),   // direct coding
578         HybridUintConfig(10, 0, 0),  // direct coding
579         HybridUintConfig(11, 0, 0),  // direct coding
580         HybridUintConfig(12, 0, 0),  // direct coding
581     };
582   } else if (params.uint_method == HistogramParams::HybridUintMethod::kFast) {
583     configs = {
584         HybridUintConfig(4, 2, 0),  // default
585         HybridUintConfig(4, 1, 2),  // add parity but less msb
586         HybridUintConfig(0, 0, 0),  // smallest histograms
587         HybridUintConfig(2, 0, 1),  // works well for ctx map
588     };
589   }
590 
591   std::vector<float> costs(clustered_histograms->size(),
592                            std::numeric_limits<float>::max());
593   std::vector<uint32_t> extra_bits(clustered_histograms->size());
594   std::vector<uint8_t> is_valid(clustered_histograms->size());
595   size_t max_alpha =
596       codes->use_prefix_code ? PREFIX_MAX_ALPHABET_SIZE : ANS_MAX_ALPHABET_SIZE;
597   for (HybridUintConfig cfg : configs) {
598     std::fill(is_valid.begin(), is_valid.end(), true);
599     std::fill(extra_bits.begin(), extra_bits.end(), 0);
600 
601     for (size_t i = 0; i < clustered_histograms->size(); i++) {
602       (*clustered_histograms)[i].Clear();
603     }
604     for (size_t i = 0; i < tokens.size(); ++i) {
605       for (size_t j = 0; j < tokens[i].size(); ++j) {
606         const Token token = tokens[i][j];
607         // TODO(veluca): do not ignore lz77 commands.
608         if (token.is_lz77_length) continue;
609         size_t histo = context_map[token.context];
610         uint32_t tok, nbits, bits;
611         cfg.Encode(token.value, &tok, &nbits, &bits);
612         if (tok >= max_alpha ||
613             (codes->lz77.enabled && tok >= codes->lz77.min_symbol)) {
614           is_valid[histo] = false;
615           continue;
616         }
617         extra_bits[histo] += nbits;
618         (*clustered_histograms)[histo].Add(tok);
619       }
620     }
621 
622     for (size_t i = 0; i < clustered_histograms->size(); i++) {
623       if (!is_valid[i]) continue;
624       float cost = (*clustered_histograms)[i].PopulationCost() + extra_bits[i];
625       if (cost < costs[i]) {
626         codes->uint_config[i] = cfg;
627         costs[i] = cost;
628       }
629     }
630   }
631 
632   // Rebuild histograms.
633   for (size_t i = 0; i < clustered_histograms->size(); i++) {
634     (*clustered_histograms)[i].Clear();
635   }
636   *log_alpha_size = 4;
637   for (size_t i = 0; i < tokens.size(); ++i) {
638     for (size_t j = 0; j < tokens[i].size(); ++j) {
639       const Token token = tokens[i][j];
640       uint32_t tok, nbits, bits;
641       size_t histo = context_map[token.context];
642       (token.is_lz77_length ? codes->lz77.length_uint_config
643                             : codes->uint_config[histo])
644           .Encode(token.value, &tok, &nbits, &bits);
645       tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
646       (*clustered_histograms)[histo].Add(tok);
647       while (tok >= (1u << *log_alpha_size)) (*log_alpha_size)++;
648     }
649   }
650 #if JXL_ENABLE_ASSERT
651   size_t max_log_alpha_size = codes->use_prefix_code ? PREFIX_MAX_BITS : 8;
652   JXL_ASSERT(*log_alpha_size <= max_log_alpha_size);
653 #endif
654 }
655 
656 class HistogramBuilder {
657  public:
HistogramBuilder(const size_t num_contexts)658   explicit HistogramBuilder(const size_t num_contexts)
659       : histograms_(num_contexts) {}
660 
VisitSymbol(int symbol,size_t histo_idx)661   void VisitSymbol(int symbol, size_t histo_idx) {
662     JXL_DASSERT(histo_idx < histograms_.size());
663     histograms_[histo_idx].Add(symbol);
664   }
665 
666   // NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge.
BuildAndStoreEntropyCodes(const HistogramParams & params,const std::vector<std::vector<Token>> & tokens,EntropyEncodingData * codes,std::vector<uint8_t> * context_map,bool use_prefix_code,BitWriter * writer,size_t layer,AuxOut * aux_out) const667   size_t BuildAndStoreEntropyCodes(
668       const HistogramParams& params,
669       const std::vector<std::vector<Token>>& tokens, EntropyEncodingData* codes,
670       std::vector<uint8_t>* context_map, bool use_prefix_code,
671       BitWriter* writer, size_t layer, AuxOut* aux_out) const {
672     size_t cost = 0;
673     codes->encoding_info.clear();
674     std::vector<Histogram> clustered_histograms(histograms_);
675     context_map->resize(histograms_.size());
676     if (histograms_.size() > 1) {
677       if (!ans_fuzzer_friendly_) {
678         std::vector<uint32_t> histogram_symbols;
679         ClusterHistograms(params, histograms_, histograms_.size(),
680                           kClustersLimit, &clustered_histograms,
681                           &histogram_symbols);
682         for (size_t c = 0; c < histograms_.size(); ++c) {
683           (*context_map)[c] = static_cast<uint8_t>(histogram_symbols[c]);
684         }
685       } else {
686         fill(context_map->begin(), context_map->end(), 0);
687         size_t max_symbol = 0;
688         for (const Histogram& h : histograms_) {
689           max_symbol = std::max(h.data_.size(), max_symbol);
690         }
691         size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1);
692         clustered_histograms.resize(1);
693         clustered_histograms[0].Clear();
694         for (size_t i = 0; i < num_symbols; i++) {
695           clustered_histograms[0].Add(i);
696         }
697       }
698       if (writer != nullptr) {
699         EncodeContextMap(*context_map, clustered_histograms.size(), writer);
700       }
701     }
702     if (aux_out != nullptr) {
703       for (size_t i = 0; i < clustered_histograms.size(); ++i) {
704         aux_out->layers[layer].clustered_entropy +=
705             clustered_histograms[i].ShannonEntropy();
706       }
707     }
708     codes->use_prefix_code = use_prefix_code;
709     size_t log_alpha_size = codes->lz77.enabled ? 8 : 7;  // Sane default.
710     if (ans_fuzzer_friendly_) {
711       codes->uint_config.clear();
712       codes->uint_config.resize(1, HybridUintConfig(7, 0, 0));
713     } else {
714       ChooseUintConfigs(params, tokens, *context_map, &clustered_histograms,
715                         codes, &log_alpha_size);
716     }
717     if (log_alpha_size < 5) log_alpha_size = 5;
718     SizeWriter size_writer;  // Used if writer == nullptr to estimate costs.
719     cost += 1;
720     if (writer) writer->Write(1, use_prefix_code);
721 
722     if (use_prefix_code) {
723       log_alpha_size = PREFIX_MAX_BITS;
724     } else {
725       cost += 2;
726     }
727     if (writer == nullptr) {
728       EncodeUintConfigs(codes->uint_config, &size_writer, log_alpha_size);
729     } else {
730       if (!use_prefix_code) writer->Write(2, log_alpha_size - 5);
731       EncodeUintConfigs(codes->uint_config, writer, log_alpha_size);
732     }
733     if (use_prefix_code) {
734       for (size_t c = 0; c < clustered_histograms.size(); ++c) {
735         size_t num_symbol = 1;
736         for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
737           if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
738         }
739         if (writer) {
740           StoreVarLenUint16(num_symbol - 1, writer);
741         } else {
742           StoreVarLenUint16(num_symbol - 1, &size_writer);
743         }
744       }
745     }
746     cost += size_writer.size;
747     for (size_t c = 0; c < clustered_histograms.size(); ++c) {
748       size_t num_symbol = 1;
749       for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
750         if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
751       }
752       codes->encoding_info.emplace_back();
753       codes->encoding_info.back().resize(std::max<size_t>(1, num_symbol));
754 
755       BitWriter::Allotment allotment(writer, 256 + num_symbol * 24);
756       cost += BuildAndStoreANSEncodingData(
757           params.ans_histogram_strategy, clustered_histograms[c].data_.data(),
758           num_symbol, log_alpha_size, use_prefix_code,
759           codes->encoding_info.back().data(), writer);
760       allotment.FinishedHistogram(writer);
761       ReclaimAndCharge(writer, &allotment, layer, aux_out);
762     }
763     return cost;
764   }
765 
Histo(size_t i) const766   const Histogram& Histo(size_t i) const { return histograms_[i]; }
767 
768  private:
769   std::vector<Histogram> histograms_;
770 };
771 
772 class SymbolCostEstimator {
773  public:
SymbolCostEstimator(size_t num_contexts,bool force_huffman,const std::vector<std::vector<Token>> & tokens,const LZ77Params & lz77)774   SymbolCostEstimator(size_t num_contexts, bool force_huffman,
775                       const std::vector<std::vector<Token>>& tokens,
776                       const LZ77Params& lz77) {
777     HistogramBuilder builder(num_contexts);
778     // Build histograms for estimating lz77 savings.
779     HybridUintConfig uint_config;
780     for (size_t i = 0; i < tokens.size(); ++i) {
781       for (size_t j = 0; j < tokens[i].size(); ++j) {
782         const Token token = tokens[i][j];
783         uint32_t tok, nbits, bits;
784         (token.is_lz77_length ? lz77.length_uint_config : uint_config)
785             .Encode(token.value, &tok, &nbits, &bits);
786         tok += token.is_lz77_length ? lz77.min_symbol : 0;
787         builder.VisitSymbol(tok, token.context);
788       }
789     }
790     max_alphabet_size_ = 0;
791     for (size_t i = 0; i < num_contexts; i++) {
792       max_alphabet_size_ =
793           std::max(max_alphabet_size_, builder.Histo(i).data_.size());
794     }
795     bits_.resize(num_contexts * max_alphabet_size_);
796     // TODO(veluca): SIMD?
797     add_symbol_cost_.resize(num_contexts);
798     for (size_t i = 0; i < num_contexts; i++) {
799       float inv_total = 1.0f / (builder.Histo(i).total_count_ + 1e-8f);
800       float total_cost = 0;
801       for (size_t j = 0; j < builder.Histo(i).data_.size(); j++) {
802         size_t cnt = builder.Histo(i).data_[j];
803         float cost = 0;
804         if (cnt != 0 && cnt != builder.Histo(i).total_count_) {
805           cost = -FastLog2f(cnt * inv_total);
806           if (force_huffman) cost = std::ceil(cost);
807         } else if (cnt == 0) {
808           cost = ANS_LOG_TAB_SIZE;  // Highest possible cost.
809         }
810         bits_[i * max_alphabet_size_ + j] = cost;
811         total_cost += cost * builder.Histo(i).data_[j];
812       }
813       // Penalty for adding a lz77 symbol to this contest (only used for static
814       // cost model). Higher penalty for contexts that have a very low
815       // per-symbol entropy.
816       add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total);
817     }
818   }
Bits(size_t ctx,size_t sym) const819   float Bits(size_t ctx, size_t sym) const {
820     return bits_[ctx * max_alphabet_size_ + sym];
821   }
LenCost(size_t ctx,size_t len,const LZ77Params & lz77) const822   float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const {
823     uint32_t nbits, bits, tok;
824     lz77.length_uint_config.Encode(len, &tok, &nbits, &bits);
825     tok += lz77.min_symbol;
826     return nbits + Bits(ctx, tok);
827   }
DistCost(size_t len,const LZ77Params & lz77) const828   float DistCost(size_t len, const LZ77Params& lz77) const {
829     uint32_t nbits, bits, tok;
830     HybridUintConfig().Encode(len, &tok, &nbits, &bits);
831     return nbits + Bits(lz77.nonserialized_distance_context, tok);
832   }
AddSymbolCost(size_t idx) const833   float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; }
834 
835  private:
836   size_t max_alphabet_size_;
837   std::vector<float> bits_;
838   std::vector<float> add_symbol_cost_;
839 };
840 
ApplyLZ77_RLE(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)841 void ApplyLZ77_RLE(const HistogramParams& params, size_t num_contexts,
842                    const std::vector<std::vector<Token>>& tokens,
843                    LZ77Params& lz77,
844                    std::vector<std::vector<Token>>& tokens_lz77) {
845   // TODO(veluca): tune heuristics here.
846   SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
847   float bit_decrease = 0;
848   size_t total_symbols = 0;
849   tokens_lz77.resize(tokens.size());
850   std::vector<float> sym_cost;
851   HybridUintConfig uint_config;
852   for (size_t stream = 0; stream < tokens.size(); stream++) {
853     size_t distance_multiplier =
854         params.image_widths.size() > stream ? params.image_widths[stream] : 0;
855     const auto& in = tokens[stream];
856     auto& out = tokens_lz77[stream];
857     total_symbols += in.size();
858     // Cumulative sum of bit costs.
859     sym_cost.resize(in.size() + 1);
860     for (size_t i = 0; i < in.size(); i++) {
861       uint32_t tok, nbits, unused_bits;
862       uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
863       sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
864     }
865     out.reserve(in.size());
866     for (size_t i = 0; i < in.size(); i++) {
867       size_t num_to_copy = 0;
868       size_t distance_symbol = 0;  // 1 for RLE.
869       if (distance_multiplier != 0) {
870         distance_symbol = 1;  // Special distance 1 if enabled.
871         JXL_DASSERT(kSpecialDistances[1][0] == 1);
872         JXL_DASSERT(kSpecialDistances[1][1] == 0);
873       }
874       if (i > 0) {
875         for (; i + num_to_copy < in.size(); num_to_copy++) {
876           if (in[i + num_to_copy].value != in[i - 1].value) {
877             break;
878           }
879         }
880       }
881       if (num_to_copy == 0) {
882         out.push_back(in[i]);
883         continue;
884       }
885       float cost = sym_cost[i + num_to_copy] - sym_cost[i];
886       // This subtraction might overflow, but that's OK.
887       size_t lz77_len = num_to_copy - lz77.min_length;
888       float lz77_cost = num_to_copy >= lz77.min_length
889                             ? CeilLog2Nonzero(lz77_len + 1) + 1
890                             : 0;
891       if (num_to_copy < lz77.min_length || cost <= lz77_cost) {
892         for (size_t j = 0; j < num_to_copy; j++) {
893           out.push_back(in[i + j]);
894         }
895         i += num_to_copy - 1;
896         continue;
897       }
898       // Output the LZ77 length
899       out.emplace_back(in[i].context, lz77_len);
900       out.back().is_lz77_length = true;
901       i += num_to_copy - 1;
902       bit_decrease += cost - lz77_cost;
903       // Output the LZ77 copy distance.
904       out.emplace_back(lz77.nonserialized_distance_context, distance_symbol);
905     }
906   }
907 
908   if (bit_decrease > total_symbols * 0.2 + 16) {
909     lz77.enabled = true;
910   }
911 }
912 
913 // Hash chain for LZ77 matching
// Indexes the values of one token stream so that earlier occurrences of the
// three-value sequence starting at a position can be enumerated. Positions are
// stored modulo the window size (window_mask_ = window_size - 1, so the
// window size is expected to be a power of two, as the callers in this file
// provide).
//
// Usage pattern in this file: call Update(pos) for every position in
// increasing order, and query FindMatch/FindMatches(pos) only after
// Update(pos) has been called for that position.
struct HashChain {
  // Number of values in the stream and a private copy of the token values.
  size_t size_;
  std::vector<uint32_t> data_;

  // Hash table geometry: number of buckets, bucket mask and the per-value
  // shift used by GetHash().
  unsigned hash_num_values_ = 32768;
  unsigned hash_mask_ = hash_num_values_ - 1;
  unsigned hash_shift_ = 5;

  // head[h]: window position most recently inserted with hash h, -1 if none.
  std::vector<int> head;
  // chain[w]: window position previously inserted with the same hash as the
  // position stored at w; chain[w] == w marks the end of a chain.
  std::vector<uint32_t> chain;
  // val[w]: hash value recorded for window position w (-1 if never set).
  std::vector<int> val;

  // Speed up repetitions of zero
  // Same head/chain layout as above, but keyed by the zero-run length
  // (numzeros) recorded at each position instead of by the hash.
  std::vector<int> headz;
  std::vector<uint32_t> chainz;
  std::vector<uint32_t> zeros;
  // Zero-run length computed by the most recent Update().
  uint32_t numzeros = 0;

  size_t window_size_;
  size_t window_mask_;
  // Matches shorter than min_length_ are not reported; comparisons stop
  // max_length_ values after the query position.
  size_t min_length_;
  size_t max_length_;

  // Map of special distance codes.
  // Key: distance in tokens; value: index into kSpecialDistances.
  std::unordered_map<int, int> special_dist_table_;
  // Offset added to plain distances when special distances are in use.
  size_t num_special_distances_ = 0;

  // Upper bound on the number of chain entries visited per query.
  uint32_t maxchainlength = 256;  // window_size_ to allow all

  // Copies the `size` token values from `data` and allocates the tables for
  // a window of `window_size` positions. A non-zero `distance_multiplier`
  // enables the special distance codes, computed as
  // kSpecialDistances[i][1] * distance_multiplier + kSpecialDistances[i][0].
  HashChain(const Token* data, size_t size, size_t window_size,
            size_t min_length, size_t max_length, size_t distance_multiplier)
      : size_(size),
        window_size_(window_size),
        window_mask_(window_size - 1),
        min_length_(min_length),
        max_length_(max_length) {
    data_.resize(size);
    for (size_t i = 0; i < size; i++) {
      data_[i] = data[i].value;
    }

    head.resize(hash_num_values_, -1);
    val.resize(window_size_, -1);
    chain.resize(window_size_);
    for (uint32_t i = 0; i < window_size_; ++i) {
      chain[i] = i;  // same value as index indicates uninitialized
    }

    zeros.resize(window_size_);
    // One extra slot: headz is indexed by a zero-run length, which CountZeros
    // can return as large as window_size_.
    headz.resize(window_size_ + 1, -1);
    chainz.resize(window_size_);
    for (uint32_t i = 0; i < window_size_; ++i) {
      chainz[i] = i;
    }
    // Translate distance to special distance code.
    if (distance_multiplier) {
      // Count down, so if due to small distance multiplier multiple distances
      // map to the same code, the smallest code will be used in the end.
      for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
        int xi = kSpecialDistances[i][0];
        int yi = kSpecialDistances[i][1];
        int distance = yi * distance_multiplier + xi;
        // Ensure that we map distance 1 to the lowest symbols.
        if (distance < 1) distance = 1;
        special_dist_table_[distance] = i;
      }
      num_special_distances_ = kNumSpecialDistances;
    }
  }

  // Hash of the three consecutive values starting at `pos`, already reduced
  // to a valid index into `head`. Returns 0 when fewer than three values
  // remain in the stream.
  uint32_t GetHash(size_t pos) const {
    uint32_t result = 0;
    if (pos + 2 < size_) {
      // TODO(lode): take the MSB's of the uint32_t values into account as well,
      // given that the hash code itself is less than 32 bits.
      result ^= (uint32_t)(data_[pos + 0] << 0u);
      result ^= (uint32_t)(data_[pos + 1] << hash_shift_);
      result ^= (uint32_t)(data_[pos + 2] << (hash_shift_ * 2));
    } else {
      // No need to compute hash of last 2 bytes, the length 2 is too short.
      return 0;
    }
    return result & hash_mask_;
  }

  // Number of zero values starting at `pos`, looking at most window_size_
  // values ahead and never past the end of the stream. `prevzeros` is the
  // running count carried over by Update(); when it is non-zero the result is
  // derived from it without rescanning.
  uint32_t CountZeros(size_t pos, uint32_t prevzeros) const {
    size_t end = pos + window_size_;
    if (end > size_) end = size_;
    if (prevzeros > 0) {
      // The count is kept unchanged only when it had reached window_mask_,
      // the lookahead was not cut short by the end of the stream and the last
      // value looked at is zero; otherwise it shrinks by one.
      if (prevzeros >= window_mask_ && data_[end - 1] == 0 &&
          end == pos + window_size_) {
        return prevzeros;
      } else {
        return prevzeros - 1;
      }
    }
    uint32_t num = 0;
    while (pos + num < end && data_[pos + num] == 0) num++;
    return num;
  }

  // Inserts position `pos` into both the hash chain and the zero-run chain.
  void Update(size_t pos) {
    uint32_t hashval = GetHash(pos);
    uint32_t wpos = pos & window_mask_;

    // Link to the previous holder of this hash (if any), then become the head.
    val[wpos] = (int)hashval;
    if (head[hashval] != -1) chain[wpos] = head[hashval];
    head[hashval] = wpos;

    // A change of value resets the running zero count, forcing a rescan.
    if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0;
    numzeros = CountZeros(pos, numzeros);

    zeros[wpos] = numzeros;
    if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros];
    headz[numzeros] = wpos;
  }

  // Inserts the `len` consecutive positions starting at `pos`.
  void Update(size_t pos, size_t len) {
    for (size_t i = 0; i < len; i++) {
      Update(pos + i);
    }
  }

  // Walks the chain of earlier positions linked from `pos` and calls
  // found_match(len, dist_symbol) for every candidate whose match length is
  // at least min_length_ and not more than 2 shorter than the best length
  // reported so far. dist_symbol is the special distance code when the
  // distance has one, and num_special_distances_ + dist - 1 otherwise.
  // `max_dist` is not referenced by the body.
  template <typename CB>
  void FindMatches(size_t pos, int max_dist, const CB& found_match) const {
    uint32_t wpos = pos & window_mask_;
    uint32_t hashval = GetHash(pos);
    uint32_t hashpos = chain[wpos];

    int prev_dist = 0;
    // Comparisons never read at or beyond this index.
    int end = std::min<int>(pos + max_length_, size_);
    uint32_t chainlength = 0;
    uint32_t best_len = 0;
    for (;;) {
      // Distance from the candidate to `pos`, modulo the window size.
      int dist = (hashpos <= wpos) ? (wpos - hashpos)
                                   : (wpos - hashpos + window_mask_ + 1);
      // Distances are expected to grow along the chain; a smaller one ends
      // the walk.
      if (dist < prev_dist) break;
      prev_dist = dist;
      uint32_t len = 0;
      if (dist > 0) {
        int i = pos;
        int j = pos - dist;
        if (numzeros > 3) {
          // Use the recorded zero-run lengths to start the comparison `r`
          // values further on instead of at the first value.
          int r = std::min<int>(numzeros - 1, zeros[hashpos]);
          if (i + r >= end) r = end - i - 1;
          i += r;
          j += r;
        }
        while (i < end && data_[i] == data_[j]) {
          i++;
          j++;
        }
        len = i - pos;
        // This can trigger even if the new length is slightly smaller than the
        // best length, because it is possible for a slightly cheaper distance
        // symbol to occur.
        if (len >= min_length_ && len + 2 >= best_len) {
          auto it = special_dist_table_.find(dist);
          int dist_symbol = (it == special_dist_table_.end())
                                ? (num_special_distances_ + dist - 1)
                                : it->second;
          found_match(len, dist_symbol);
          if (len > best_len) best_len = len;
        }
      }

      chainlength++;
      if (chainlength >= maxchainlength) break;

      // Pick the next candidate: follow the zero-run chain when the match
      // extended past the current zero run, the hash chain otherwise.
      if (numzeros >= 3 && len > numzeros) {
        if (hashpos == chainz[hashpos]) break;
        hashpos = chainz[hashpos];
        if (zeros[hashpos] != numzeros) break;
      } else {
        if (hashpos == chain[hashpos]) break;
        hashpos = chain[hashpos];
        if (val[hashpos] != (int)hashval) break;  // outdated hash value
      }
    }
  }

  // Reports the longest match found by FindMatches() at `pos`; among matches
  // of equal length the smallest distance symbol wins. When nothing longer
  // than one value is reported, the outputs stay at *result_len == 1 and
  // *result_dist_symbol == 0.
  void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol,
                 size_t* result_len) const {
    *result_dist_symbol = 0;
    *result_len = 1;
    FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) {
      if (len > *result_len ||
          (len == *result_len && *result_dist_symbol > dist_symbol)) {
        *result_len = len;
        *result_dist_symbol = dist_symbol;
      }
    });
  }
};
1107 
LenCost(size_t len)1108 float LenCost(size_t len) {
1109   uint32_t nbits, bits, tok;
1110   HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits);
1111   constexpr float kCostTable[] = {
1112       2.797667318563126,  3.213177690381199,  2.5706009246743737,
1113       2.408392498667534,  2.829649191872326,  3.3923087753324577,
1114       4.029267451554331,  4.415576699706408,  4.509357574741465,
1115       9.21481543803004,   10.020590190114898, 11.858671627804766,
1116       12.45853300490526,  11.713105831990857, 12.561996324849314,
1117       13.775477692278367, 13.174027068768641,
1118   };
1119   size_t table_size = sizeof kCostTable / sizeof *kCostTable;
1120   if (tok >= table_size) tok = table_size - 1;
1121   return kCostTable[tok] + nbits;
1122 }
1123 
1124 // TODO(veluca): this does not take into account usage or non-usage of distance
1125 // multipliers.
DistCost(size_t dist)1126 float DistCost(size_t dist) {
1127   uint32_t nbits, bits, tok;
1128   HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits);
1129   constexpr float kCostTable[] = {
1130       6.368282626312716,  5.680793277090298,  8.347404197105247,
1131       7.641619201599141,  6.914328374119438,  7.959808291537444,
1132       8.70023120759855,   8.71378518934703,   9.379132523982769,
1133       9.110472749092708,  9.159029569270908,  9.430936766731973,
1134       7.278284055315169,  7.8278514904267755, 10.026641158289236,
1135       9.976049229827066,  9.64351607048908,   9.563403863480442,
1136       10.171474111762747, 10.45950155077234,  9.994813912104219,
1137       10.322524683741156, 8.465808729388186,  8.756254166066853,
1138       10.160930174662234, 10.247329273413435, 10.04090403724809,
1139       10.129398517544082, 9.342311691539546,  9.07608009102374,
1140       10.104799540677513, 10.378079384990906, 10.165828974075072,
1141       10.337595322341553, 7.940557464567944,  10.575665823319431,
1142       11.023344321751955, 10.736144698831827, 11.118277044595054,
1143       7.468468230648442,  10.738305230932939, 10.906980780216568,
1144       10.163468216353817, 10.17805759656433,  11.167283670483565,
1145       11.147050200274544, 10.517921919244333, 10.651764778156886,
1146       10.17074446448919,  11.217636876224745, 11.261630721139484,
1147       11.403140815247259, 10.892472096873417, 11.1859607804481,
1148       8.017346947551262,  7.895143720278828,  11.036577113822025,
1149       11.170562110315794, 10.326988722591086, 10.40872184751056,
1150       11.213498225466386, 11.30580635516863,  10.672272515665442,
1151       10.768069466228063, 11.145257364153565, 11.64668307145549,
1152       10.593156194627339, 11.207499484844943, 10.767517766396908,
1153       10.826629811407042, 10.737764794499988, 10.6200448518045,
1154       10.191315385198092, 8.468384171390085,  11.731295299170432,
1155       11.824619886654398, 10.41518844301179,  10.16310536548649,
1156       10.539423685097576, 10.495136599328031, 10.469112847728267,
1157       11.72057686174922,  10.910326337834674, 11.378921834673758,
1158       11.847759036098536, 11.92071647623854,  10.810628276345282,
1159       11.008601085273893, 11.910326337834674, 11.949212023423133,
1160       11.298614839104337, 11.611603659010392, 10.472930394619985,
1161       11.835564720850282, 11.523267392285337, 12.01055816679611,
1162       8.413029688994023,  11.895784139536406, 11.984679534970505,
1163       11.220654278717394, 11.716311684833672, 10.61036646226114,
1164       10.89849965960364,  10.203762898863669, 10.997560826267238,
1165       11.484217379438984, 11.792836176993665, 12.24310468755171,
1166       11.464858097919262, 12.212747017409377, 11.425595666074955,
1167       11.572048533398757, 12.742093965163013, 11.381874288645637,
1168       12.191870445817015, 11.683156920035426, 11.152442115262197,
1169       11.90303691580457,  11.653292787169159, 11.938615382266098,
1170       16.970641701570223, 16.853602280380002, 17.26240782594733,
1171       16.644655390108507, 17.14310889757499,  16.910935455445955,
1172       17.505678976959697, 17.213498225466388, 2.4162310293553024,
1173       3.494587244462329,  3.5258600986408344, 3.4959806589517095,
1174       3.098390886949687,  3.343454654302911,  3.588847442290287,
1175       4.14614790111827,   5.152948641990529,  7.433696808092598,
1176       9.716311684833672,
1177   };
1178   size_t table_size = sizeof kCostTable / sizeof *kCostTable;
1179   if (tok >= table_size) tok = table_size - 1;
1180   return kCostTable[tok] + nbits;
1181 }
1182 
ApplyLZ77_LZ77(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)1183 void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts,
1184                     const std::vector<std::vector<Token>>& tokens,
1185                     LZ77Params& lz77,
1186                     std::vector<std::vector<Token>>& tokens_lz77) {
1187   // TODO(veluca): tune heuristics here.
1188   SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
1189   float bit_decrease = 0;
1190   size_t total_symbols = 0;
1191   tokens_lz77.resize(tokens.size());
1192   HybridUintConfig uint_config;
1193   std::vector<float> sym_cost;
1194   for (size_t stream = 0; stream < tokens.size(); stream++) {
1195     size_t distance_multiplier =
1196         params.image_widths.size() > stream ? params.image_widths[stream] : 0;
1197     const auto& in = tokens[stream];
1198     auto& out = tokens_lz77[stream];
1199     total_symbols += in.size();
1200     // Cumulative sum of bit costs.
1201     sym_cost.resize(in.size() + 1);
1202     for (size_t i = 0; i < in.size(); i++) {
1203       uint32_t tok, nbits, unused_bits;
1204       uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
1205       sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
1206     }
1207 
1208     out.reserve(in.size());
1209     size_t max_distance = in.size();
1210     size_t min_length = lz77.min_length;
1211     JXL_ASSERT(min_length >= 3);
1212     size_t max_length = in.size();
1213 
1214     // Use next power of two as window size.
1215     size_t window_size = 1;
1216     while (window_size < max_distance && window_size < kWindowSize) {
1217       window_size <<= 1;
1218     }
1219 
1220     HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
1221                     distance_multiplier);
1222     size_t len, dist_symbol;
1223 
1224     const size_t max_lazy_match_len = 256;  // 0 to disable lazy matching
1225 
1226     // Whether the next symbol was already updated (to test lazy matching)
1227     bool already_updated = false;
1228     for (size_t i = 0; i < in.size(); i++) {
1229       out.push_back(in[i]);
1230       if (!already_updated) chain.Update(i);
1231       already_updated = false;
1232       chain.FindMatch(i, max_distance, &dist_symbol, &len);
1233       if (len >= min_length) {
1234         if (len < max_lazy_match_len && i + 1 < in.size()) {
1235           // Try length at next symbol lazy matching
1236           chain.Update(i + 1);
1237           already_updated = true;
1238           size_t len2, dist_symbol2;
1239           chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
1240           if (len2 > len) {
1241             // Use the lazy match. Add literal, and use the next length starting
1242             // from the next byte.
1243             ++i;
1244             already_updated = false;
1245             len = len2;
1246             dist_symbol = dist_symbol2;
1247             out.push_back(in[i]);
1248           }
1249         }
1250 
1251         float cost = sym_cost[i + len] - sym_cost[i];
1252         size_t lz77_len = len - lz77.min_length;
1253         float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
1254                           sce.AddSymbolCost(out.back().context);
1255 
1256         if (lz77_cost <= cost) {
1257           out.back().value = len - min_length;
1258           out.back().is_lz77_length = true;
1259           out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
1260           bit_decrease += cost - lz77_cost;
1261         } else {
1262           // LZ77 match ignored, and symbol already pushed. Push all other
1263           // symbols and skip.
1264           for (size_t j = 1; j < len; j++) {
1265             out.push_back(in[i + j]);
1266           }
1267         }
1268 
1269         if (already_updated) {
1270           chain.Update(i + 2, len - 2);
1271           already_updated = false;
1272         } else {
1273           chain.Update(i + 1, len - 1);
1274         }
1275         i += len - 1;
1276       } else {
1277         // Literal, already pushed
1278       }
1279     }
1280   }
1281 
1282   if (bit_decrease > total_symbols * 0.2 + 16) {
1283     lz77.enabled = true;
1284   }
1285 }
1286 
ApplyLZ77_Optimal(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)1287 void ApplyLZ77_Optimal(const HistogramParams& params, size_t num_contexts,
1288                        const std::vector<std::vector<Token>>& tokens,
1289                        LZ77Params& lz77,
1290                        std::vector<std::vector<Token>>& tokens_lz77) {
1291   std::vector<std::vector<Token>> tokens_for_cost_estimate;
1292   ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_for_cost_estimate);
1293   // If greedy-LZ77 does not give better compression than no-lz77, no reason to
1294   // run the optimal matching.
1295   if (!lz77.enabled) return;
1296   SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
1297                           tokens_for_cost_estimate, lz77);
1298   tokens_lz77.resize(tokens.size());
1299   HybridUintConfig uint_config;
1300   std::vector<float> sym_cost;
1301   std::vector<uint32_t> dist_symbols;
1302   for (size_t stream = 0; stream < tokens.size(); stream++) {
1303     size_t distance_multiplier =
1304         params.image_widths.size() > stream ? params.image_widths[stream] : 0;
1305     const auto& in = tokens[stream];
1306     auto& out = tokens_lz77[stream];
1307     // Cumulative sum of bit costs.
1308     sym_cost.resize(in.size() + 1);
1309     for (size_t i = 0; i < in.size(); i++) {
1310       uint32_t tok, nbits, unused_bits;
1311       uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
1312       sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
1313     }
1314 
1315     out.reserve(in.size());
1316     size_t max_distance = in.size();
1317     size_t min_length = lz77.min_length;
1318     JXL_ASSERT(min_length >= 3);
1319     size_t max_length = in.size();
1320 
1321     // Use next power of two as window size.
1322     size_t window_size = 1;
1323     while (window_size < max_distance && window_size < kWindowSize) {
1324       window_size <<= 1;
1325     }
1326 
1327     HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
1328                     distance_multiplier);
1329 
1330     struct MatchInfo {
1331       uint32_t len;
1332       uint32_t dist_symbol;
1333       uint32_t ctx;
1334       float total_cost = std::numeric_limits<float>::max();
1335     };
1336     // Total cost to encode the first N symbols.
1337     std::vector<MatchInfo> prefix_costs(in.size() + 1);
1338     prefix_costs[0].total_cost = 0;
1339 
1340     size_t rle_length = 0;
1341     size_t skip_lz77 = 0;
1342     for (size_t i = 0; i < in.size(); i++) {
1343       chain.Update(i);
1344       float lit_cost =
1345           prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
1346       if (prefix_costs[i + 1].total_cost > lit_cost) {
1347         prefix_costs[i + 1].dist_symbol = 0;
1348         prefix_costs[i + 1].len = 1;
1349         prefix_costs[i + 1].ctx = in[i].context;
1350         prefix_costs[i + 1].total_cost = lit_cost;
1351       }
1352       if (skip_lz77 > 0) {
1353         skip_lz77--;
1354         continue;
1355       }
1356       dist_symbols.clear();
1357       chain.FindMatches(i, max_distance,
1358                         [&dist_symbols](size_t len, size_t dist_symbol) {
1359                           if (dist_symbols.size() <= len) {
1360                             dist_symbols.resize(len + 1, dist_symbol);
1361                           }
1362                           if (dist_symbol < dist_symbols[len]) {
1363                             dist_symbols[len] = dist_symbol;
1364                           }
1365                         });
1366       if (dist_symbols.size() <= min_length) continue;
1367       {
1368         size_t best_cost = dist_symbols.back();
1369         for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
1370           if (dist_symbols[j] < best_cost) {
1371             best_cost = dist_symbols[j];
1372           }
1373           dist_symbols[j] = best_cost;
1374         }
1375       }
1376       for (size_t j = min_length; j < dist_symbols.size(); j++) {
1377         // Cost model that uses results from lazy LZ77.
1378         float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
1379                           sce.DistCost(dist_symbols[j], lz77);
1380         float cost = prefix_costs[i].total_cost + lz77_cost;
1381         if (prefix_costs[i + j].total_cost > cost) {
1382           prefix_costs[i + j].len = j;
1383           prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
1384           prefix_costs[i + j].ctx = in[i].context;
1385           prefix_costs[i + j].total_cost = cost;
1386         }
1387       }
1388       // We are in a RLE sequence: skip all the symbols except the first 8 and
1389       // the last 8. This avoid quadratic costs for sequences with long runs of
1390       // the same symbol.
1391       if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
1392           (dist_symbols.back() == 1 && distance_multiplier != 0)) {
1393         rle_length++;
1394       } else {
1395         rle_length = 0;
1396       }
1397       if (rle_length >= 8 && dist_symbols.size() > 9) {
1398         skip_lz77 = dist_symbols.size() - 10;
1399         rle_length = 0;
1400       }
1401     }
1402     size_t pos = in.size();
1403     while (pos > 0) {
1404       bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
1405       if (is_lz77_length) {
1406         size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
1407         out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
1408       }
1409       size_t val = is_lz77_length ? prefix_costs[pos].len - min_length
1410                                   : in[pos - 1].value;
1411       out.emplace_back(prefix_costs[pos].ctx, val);
1412       out.back().is_lz77_length = is_lz77_length;
1413       pos -= prefix_costs[pos].len;
1414     }
1415     std::reverse(out.begin(), out.end());
1416   }
1417 }
1418 
ApplyLZ77(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)1419 void ApplyLZ77(const HistogramParams& params, size_t num_contexts,
1420                const std::vector<std::vector<Token>>& tokens, LZ77Params& lz77,
1421                std::vector<std::vector<Token>>& tokens_lz77) {
1422   lz77.enabled = false;
1423   if (params.force_huffman) {
1424     lz77.min_symbol = std::min(PREFIX_MAX_ALPHABET_SIZE - 32, 512);
1425   } else {
1426     lz77.min_symbol = 224;
1427   }
1428   if (params.lz77_method == HistogramParams::LZ77Method::kNone) {
1429     return;
1430   } else if (params.lz77_method == HistogramParams::LZ77Method::kRLE) {
1431     ApplyLZ77_RLE(params, num_contexts, tokens, lz77, tokens_lz77);
1432   } else if (params.lz77_method == HistogramParams::LZ77Method::kLZ77) {
1433     ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_lz77);
1434   } else if (params.lz77_method == HistogramParams::LZ77Method::kOptimal) {
1435     ApplyLZ77_Optimal(params, num_contexts, tokens, lz77, tokens_lz77);
1436   } else {
1437     JXL_ABORT("Not implemented");
1438   }
1439 }
1440 }  // namespace
1441 
BuildAndEncodeHistograms(const HistogramParams & params,size_t num_contexts,std::vector<std::vector<Token>> & tokens,EntropyEncodingData * codes,std::vector<uint8_t> * context_map,BitWriter * writer,size_t layer,AuxOut * aux_out)1442 size_t BuildAndEncodeHistograms(const HistogramParams& params,
1443                                 size_t num_contexts,
1444                                 std::vector<std::vector<Token>>& tokens,
1445                                 EntropyEncodingData* codes,
1446                                 std::vector<uint8_t>* context_map,
1447                                 BitWriter* writer, size_t layer,
1448                                 AuxOut* aux_out) {
1449   size_t total_bits = 0;
1450   codes->lz77.nonserialized_distance_context = num_contexts;
1451   std::vector<std::vector<Token>> tokens_lz77;
1452   ApplyLZ77(params, num_contexts, tokens, codes->lz77, tokens_lz77);
1453   if (ans_fuzzer_friendly_) {
1454     codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0);
1455     codes->lz77.min_symbol = 2048;
1456   }
1457 
1458   const size_t max_contexts = std::min(num_contexts, kClustersLimit);
1459   BitWriter::Allotment allotment(writer,
1460                                  128 + num_contexts * 40 + max_contexts * 96);
1461   if (writer) {
1462     JXL_CHECK(Bundle::Write(codes->lz77, writer, layer, aux_out));
1463   } else {
1464     size_t ebits, bits;
1465     JXL_CHECK(Bundle::CanEncode(codes->lz77, &ebits, &bits));
1466     total_bits += bits;
1467   }
1468   if (codes->lz77.enabled) {
1469     if (writer) {
1470       size_t b = writer->BitsWritten();
1471       EncodeUintConfig(codes->lz77.length_uint_config, writer,
1472                        /*log_alpha_size=*/8);
1473       total_bits += writer->BitsWritten() - b;
1474     } else {
1475       SizeWriter size_writer;
1476       EncodeUintConfig(codes->lz77.length_uint_config, &size_writer,
1477                        /*log_alpha_size=*/8);
1478       total_bits += size_writer.size;
1479     }
1480     num_contexts += 1;
1481     tokens = std::move(tokens_lz77);
1482   }
1483   size_t total_tokens = 0;
1484   // Build histograms.
1485   HistogramBuilder builder(num_contexts);
1486   HybridUintConfig uint_config;  //  Default config for clustering.
1487   // Unless we are using the kContextMap histogram option.
1488   if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
1489     uint_config = HybridUintConfig(2, 0, 1);
1490   }
1491   if (ans_fuzzer_friendly_) {
1492     uint_config = HybridUintConfig(10, 0, 0);
1493   }
1494   for (size_t i = 0; i < tokens.size(); ++i) {
1495     for (size_t j = 0; j < tokens[i].size(); ++j) {
1496       const Token token = tokens[i][j];
1497       total_tokens++;
1498       uint32_t tok, nbits, bits;
1499       (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config)
1500           .Encode(token.value, &tok, &nbits, &bits);
1501       tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
1502       builder.VisitSymbol(tok, token.context);
1503     }
1504   }
1505 
1506   bool use_prefix_code =
1507       params.force_huffman || total_tokens < 100 ||
1508       params.clustering == HistogramParams::ClusteringType::kFastest ||
1509       ans_fuzzer_friendly_;
1510   if (!use_prefix_code) {
1511     bool all_singleton = true;
1512     for (size_t i = 0; i < num_contexts; i++) {
1513       if (builder.Histo(i).ShannonEntropy() >= 1e-5) {
1514         all_singleton = false;
1515       }
1516     }
1517     if (all_singleton) {
1518       use_prefix_code = true;
1519     }
1520   }
1521 
1522   // Encode histograms.
1523   total_bits += builder.BuildAndStoreEntropyCodes(params, tokens, codes,
1524                                                   context_map, use_prefix_code,
1525                                                   writer, layer, aux_out);
1526   allotment.FinishedHistogram(writer);
1527   ReclaimAndCharge(writer, &allotment, layer, aux_out);
1528 
1529   if (aux_out != nullptr) {
1530     aux_out->layers[layer].num_clustered_histograms +=
1531         codes->encoding_info.size();
1532   }
1533   return total_bits;
1534 }
1535 
WriteTokens(const std::vector<Token> & tokens,const EntropyEncodingData & codes,const std::vector<uint8_t> & context_map,BitWriter * writer)1536 size_t WriteTokens(const std::vector<Token>& tokens,
1537                    const EntropyEncodingData& codes,
1538                    const std::vector<uint8_t>& context_map, BitWriter* writer) {
1539   size_t num_extra_bits = 0;
1540   if (codes.use_prefix_code) {
1541     for (size_t i = 0; i < tokens.size(); i++) {
1542       uint32_t tok, nbits, bits;
1543       const Token& token = tokens[i];
1544       size_t histo = context_map[token.context];
1545       (token.is_lz77_length ? codes.lz77.length_uint_config
1546                             : codes.uint_config[histo])
1547           .Encode(token.value, &tok, &nbits, &bits);
1548       tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1549       // Combine two calls to the BitWriter. Equivalent to:
1550       // writer->Write(codes.encoding_info[histo][tok].depth,
1551       //               codes.encoding_info[histo][tok].bits);
1552       // writer->Write(nbits, bits);
1553       uint64_t data = codes.encoding_info[histo][tok].bits;
1554       data |= bits << codes.encoding_info[histo][tok].depth;
1555       writer->Write(codes.encoding_info[histo][tok].depth + nbits, data);
1556       num_extra_bits += nbits;
1557     }
1558     return num_extra_bits;
1559   }
1560   std::vector<uint64_t> out;
1561   std::vector<uint8_t> out_nbits;
1562   out.reserve(tokens.size());
1563   out_nbits.reserve(tokens.size());
1564   uint64_t allbits = 0;
1565   size_t numallbits = 0;
1566   // Writes in *reversed* order.
1567   auto addbits = [&](size_t bits, size_t nbits) {
1568     JXL_DASSERT(bits >> nbits == 0);
1569     if (JXL_UNLIKELY(numallbits + nbits > BitWriter::kMaxBitsPerCall)) {
1570       out.push_back(allbits);
1571       out_nbits.push_back(numallbits);
1572       numallbits = allbits = 0;
1573     }
1574     allbits <<= nbits;
1575     allbits |= bits;
1576     numallbits += nbits;
1577   };
1578   const int end = tokens.size();
1579   ANSCoder ans;
1580   for (int i = end - 1; i >= 0; --i) {
1581     const Token token = tokens[i];
1582     const uint8_t histo = context_map[token.context];
1583     uint32_t tok, nbits, bits;
1584     (token.is_lz77_length ? codes.lz77.length_uint_config
1585                           : codes.uint_config[histo])
1586         .Encode(tokens[i].value, &tok, &nbits, &bits);
1587     tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1588     const ANSEncSymbolInfo& info = codes.encoding_info[histo][tok];
1589     // Extra bits first as this is reversed.
1590     addbits(bits, nbits);
1591     num_extra_bits += nbits;
1592     uint8_t ans_nbits = 0;
1593     uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
1594     addbits(ans_bits, ans_nbits);
1595   }
1596   const uint32_t state = ans.GetState();
1597   writer->Write(32, state);
1598   writer->Write(numallbits, allbits);
1599   for (int i = out.size(); i > 0; --i) {
1600     writer->Write(out_nbits[i - 1], out[i - 1]);
1601   }
1602   return num_extra_bits;
1603 }
1604 
WriteTokens(const std::vector<Token> & tokens,const EntropyEncodingData & codes,const std::vector<uint8_t> & context_map,BitWriter * writer,size_t layer,AuxOut * aux_out)1605 void WriteTokens(const std::vector<Token>& tokens,
1606                  const EntropyEncodingData& codes,
1607                  const std::vector<uint8_t>& context_map, BitWriter* writer,
1608                  size_t layer, AuxOut* aux_out) {
1609   BitWriter::Allotment allotment(writer, 32 * tokens.size() + 32 * 1024 * 4);
1610   size_t num_extra_bits = WriteTokens(tokens, codes, context_map, writer);
1611   ReclaimAndCharge(writer, &allotment, layer, aux_out);
1612   if (aux_out != nullptr) {
1613     aux_out->layers[layer].extra_bits += num_extra_bits;
1614   }
1615 }
1616 
// Sets the file-local fuzzer-friendly flag. The request is honoured only when
// JXL_IS_DEBUG_BUILD is set; otherwise the call is a no-op, which guards
// against accidental / malicious changes.
void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) {
#if JXL_IS_DEBUG_BUILD
  ans_fuzzer_friendly_ = ans_fuzzer_friendly;
#else
  // Intentionally ignored outside debug builds.
  static_cast<void>(ans_fuzzer_friendly);
#endif
}
1622 }  // namespace jxl
1623