1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/enc_ans.h"
7 
8 #include <stdint.h>
9 
10 #include <algorithm>
11 #include <array>
12 #include <cmath>
13 #include <limits>
14 #include <numeric>
15 #include <type_traits>
16 #include <unordered_map>
17 #include <utility>
18 #include <vector>
19 
20 #include "lib/jxl/ans_common.h"
21 #include "lib/jxl/aux_out.h"
22 #include "lib/jxl/aux_out_fwd.h"
23 #include "lib/jxl/base/bits.h"
24 #include "lib/jxl/dec_ans.h"
25 #include "lib/jxl/enc_cluster.h"
26 #include "lib/jxl/enc_context_map.h"
27 #include "lib/jxl/enc_huffman.h"
28 #include "lib/jxl/fast_math-inl.h"
29 #include "lib/jxl/fields.h"
30 
31 namespace jxl {
32 
33 namespace {
34 
// File-local switch read by HistogramBuilder::BuildAndStoreEntropyCodes: when
// true, histogram clustering and hybrid-uint configuration search are skipped
// in favour of a single fixed histogram and a single fixed configuration.
// NOTE(review): nothing in this part of the file assigns it — presumably a
// setter is defined further down; confirm.
bool ans_fuzzer_friendly_ = false;

// Capacity of the `symbols` arrays passed to NormalizeCounts(), which records
// the indices of at most this many nonzero histogram entries.
static const int kMaxNumSymbolsForSmallCode = 4;
38 
ANSBuildInfoTable(const ANSHistBin * counts,const AliasTable::Entry * table,size_t alphabet_size,size_t log_alpha_size,ANSEncSymbolInfo * info)39 void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table,
40                        size_t alphabet_size, size_t log_alpha_size,
41                        ANSEncSymbolInfo* info) {
42   size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size;
43   size_t entry_size_minus_1 = (1 << log_entry_size) - 1;
44   // create valid alias table for empty streams.
45   for (size_t s = 0; s < std::max<size_t>(1, alphabet_size); ++s) {
46     const ANSHistBin freq = s == alphabet_size ? ANS_TAB_SIZE : counts[s];
47     info[s].freq_ = static_cast<uint16_t>(freq);
48 #ifdef USE_MULT_BY_RECIPROCAL
49     if (freq != 0) {
50       info[s].ifreq_ =
51           ((1ull << RECIPROCAL_PRECISION) + info[s].freq_ - 1) / info[s].freq_;
52     } else {
53       info[s].ifreq_ = 1;  // shouldn't matter (symbol shouldn't occur), but...
54     }
55 #endif
56     info[s].reverse_map_.resize(freq);
57   }
58   for (int i = 0; i < ANS_TAB_SIZE; i++) {
59     AliasTable::Symbol s =
60         AliasTable::Lookup(table, i, log_entry_size, entry_size_minus_1);
61     info[s.value].reverse_map_[s.offset] = i;
62   }
63 }
64 
EstimateDataBits(const ANSHistBin * histogram,const ANSHistBin * counts,size_t len)65 float EstimateDataBits(const ANSHistBin* histogram, const ANSHistBin* counts,
66                        size_t len) {
67   float sum = 0.0f;
68   int total_histogram = 0;
69   int total_counts = 0;
70   for (size_t i = 0; i < len; ++i) {
71     total_histogram += histogram[i];
72     total_counts += counts[i];
73     if (histogram[i] > 0) {
74       JXL_ASSERT(counts[i] > 0);
75       // += histogram[i] * -log(counts[i]/total_counts)
76       sum += histogram[i] *
77              std::max(0.0f, ANS_LOG_TAB_SIZE - FastLog2f(counts[i]));
78     }
79   }
80   if (total_histogram > 0) {
81     JXL_ASSERT(total_counts == ANS_TAB_SIZE);
82   }
83   return sum;
84 }
85 
EstimateDataBitsFlat(const ANSHistBin * histogram,size_t len)86 float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) {
87   const float flat_bits = std::max(FastLog2f(len), 0.0f);
88   int total_histogram = 0;
89   for (size_t i = 0; i < len; ++i) {
90     total_histogram += histogram[i];
91   }
92   return total_histogram * flat_bits;
93 }
94 
// Static Huffman code for encoding logcounts. The last symbol is used as RLE
// sequence.
// Both tables have ANS_LOG_TAB_SIZE + 2 entries and are indexed in parallel by
// EncodeCounts(): entry [v] is the code for logcount value `v`, and entry
// [ANS_LOG_TAB_SIZE + 1] is the run-length marker.
// kLogCountBitLengths[v] is the number of bits written for that entry.
static const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = {
    5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
};
// kLogCountSymbols[v] is the bit pattern written for that entry, using the
// corresponding number of bits from kLogCountBitLengths.
static const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = {
    17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
};
103 
104 // Returns the difference between largest count that can be represented and is
105 // smaller than "count" and smallest representable count larger than "count".
SmallestIncrement(uint32_t count,uint32_t shift)106 static int SmallestIncrement(uint32_t count, uint32_t shift) {
107   int bits = count == 0 ? -1 : FloorLog2Nonzero(count);
108   int drop_bits = bits - GetPopulationCountPrecision(bits, shift);
109   return drop_bits < 0 ? 1 : (1 << drop_bits);
110 }
111 
// Turns the real-valued `targets[0..max_symbol)` into integer `counts` whose
// sum is exactly `table_size`.
//
// Targets strictly between 0 and 1 are forced to a count of 1. The remaining
// targets (>= 1) are scaled down by a discount ratio that compensates for
// those forced ones, truncated, and then snapped to a multiple of
// SmallestIncrement() for the given `shift`. Whatever difference from
// `table_size` is left over is absorbed by a single "remainder" entry, whose
// index is returned through `omit_pos`. Entries whose target is <= 0 are not
// written.
//
// Template parameter:
//   minimize_error_of_sum - when false, each count is rounded relative to its
//     own target; when true, it is rounded relative to the running difference
//     between the unrounded and rounded sums.
//
// Returns true iff the remainder entry ended up strictly positive.
template <bool minimize_error_of_sum>
bool RebalanceHistogram(const float* targets, int max_symbol, int table_size,
                        uint32_t shift, int* omit_pos, ANSHistBin* counts) {
  int sum = 0;
  float sum_nonrounded = 0.0;
  int remainder_pos = 0;  // if all of them are handled in first loop
  int remainder_log = -1;
  // First pass: symbols that occur but would round down to zero get count 1.
  for (int n = 0; n < max_symbol; ++n) {
    if (targets[n] > 0 && targets[n] < 1.0f) {
      counts[n] = 1;
      sum_nonrounded += targets[n];
      sum += counts[n];
    }
  }
  // Ratio of the table space still available after the first pass to the
  // space the remaining targets nominally ask for.
  const float discount_ratio =
      (table_size - sum) / (table_size - sum_nonrounded);
  JXL_ASSERT(discount_ratio > 0);
  JXL_ASSERT(discount_ratio <= 1.0f);
  // Invariant for minimize_error_of_sum == true:
  // abs(sum - sum_nonrounded)
  //   <= SmallestIncrement(max(targets[])) + max_symbol
  for (int n = 0; n < max_symbol; ++n) {
    if (targets[n] >= 1.0f) {
      sum_nonrounded += targets[n];
      counts[n] =
          static_cast<ANSHistBin>(targets[n] * discount_ratio);  // truncate
      // Keep the count inside [1, table_size - 1].
      if (counts[n] == 0) counts[n] = 1;
      if (counts[n] == table_size) counts[n] = table_size - 1;
      // Round the count to the closest nonzero multiple of SmallestIncrement
      // (when minimize_error_of_sum is false) or one of two closest so as to
      // keep the sum as close as possible to sum_nonrounded.
      int inc = SmallestIncrement(counts[n], shift);
      // Round down to a multiple of `inc` by masking off the low bits.
      counts[n] -= counts[n] & (inc - 1);
      // TODO(robryk): Should we rescale targets[n]?
      const float target =
          minimize_error_of_sum ? (sum_nonrounded - sum) : targets[n];
      // Round up instead if rounding down produced zero, or if the target is
      // past the midpoint and rounding up stays below table_size.
      if (counts[n] == 0 ||
          (target > counts[n] + inc / 2 && counts[n] + inc < table_size)) {
        counts[n] += inc;
      }
      sum += counts[n];
      // The remainder goes to the first entry with the largest floor(log2) of
      // its count among those processed in this loop.
      const int count_log = FloorLog2Nonzero(static_cast<uint32_t>(counts[n]));
      if (count_log > remainder_log) {
        remainder_pos = n;
        remainder_log = count_log;
      }
    }
  }
  // NOTE(review): remainder_pos starts at 0 and is only ever assigned a loop
  // index, so this assertion cannot fire as written.
  JXL_ASSERT(remainder_pos != -1);
  // NOTE: This is the only place where counts could go negative. We could
  // detect that, return false and make ANSHistBin uint32_t.
  counts[remainder_pos] -= sum - table_size;
  *omit_pos = remainder_pos;
  return counts[remainder_pos] > 0;
}
167 
NormalizeCounts(ANSHistBin * counts,int * omit_pos,const int length,const int precision_bits,uint32_t shift,int * num_symbols,int * symbols)168 Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length,
169                        const int precision_bits, uint32_t shift,
170                        int* num_symbols, int* symbols) {
171   const int32_t table_size = 1 << precision_bits;  // target sum / table size
172   uint64_t total = 0;
173   int max_symbol = 0;
174   int symbol_count = 0;
175   for (int n = 0; n < length; ++n) {
176     total += counts[n];
177     if (counts[n] > 0) {
178       if (symbol_count < kMaxNumSymbolsForSmallCode) {
179         symbols[symbol_count] = n;
180       }
181       ++symbol_count;
182       max_symbol = n + 1;
183     }
184   }
185   *num_symbols = symbol_count;
186   if (symbol_count == 0) {
187     return true;
188   }
189   if (symbol_count == 1) {
190     counts[symbols[0]] = table_size;
191     return true;
192   }
193   if (symbol_count > table_size)
194     return JXL_FAILURE("Too many entries in an ANS histogram");
195 
196   const float norm = 1.f * table_size / total;
197   std::vector<float> targets(max_symbol);
198   for (size_t n = 0; n < targets.size(); ++n) {
199     targets[n] = norm * counts[n];
200   }
201   if (!RebalanceHistogram<false>(&targets[0], max_symbol, table_size, shift,
202                                  omit_pos, counts)) {
203     // Use an alternative rebalancing mechanism if the one above failed
204     // to create a histogram that is positive wherever the original one was.
205     if (!RebalanceHistogram<true>(&targets[0], max_symbol, table_size, shift,
206                                   omit_pos, counts)) {
207       return JXL_FAILURE("Logic error: couldn't rebalance a histogram");
208     }
209   }
210   return true;
211 }
212 
// Writer stand-in that only counts bits: it offers the same
// Write(num_bits, value) call shape the templated encoders in this file use,
// but discards the value and adds the bit count to `size`.
struct SizeWriter {
  // Sum of the bit counts passed to Write() so far.
  size_t size = 0;

  // Accounts for `num` bits; the value that would have been written is
  // ignored.
  void Write(size_t num, size_t /*bits*/) { size = size + num; }
};
217 
218 template <typename Writer>
StoreVarLenUint8(size_t n,Writer * writer)219 void StoreVarLenUint8(size_t n, Writer* writer) {
220   JXL_DASSERT(n <= 255);
221   if (n == 0) {
222     writer->Write(1, 0);
223   } else {
224     writer->Write(1, 1);
225     size_t nbits = FloorLog2Nonzero(n);
226     writer->Write(3, nbits);
227     writer->Write(nbits, n - (1ULL << nbits));
228   }
229 }
230 
231 template <typename Writer>
StoreVarLenUint16(size_t n,Writer * writer)232 void StoreVarLenUint16(size_t n, Writer* writer) {
233   JXL_DASSERT(n <= 65535);
234   if (n == 0) {
235     writer->Write(1, 0);
236   } else {
237     writer->Write(1, 1);
238     size_t nbits = FloorLog2Nonzero(n);
239     writer->Write(4, nbits);
240     writer->Write(nbits, n - (1ULL << nbits));
241   }
242 }
243 
// Serializes the normalized histogram `counts[0..alphabet_size)` to `writer`.
//
// `Writer` only needs a Write(num_bits, value) member, so a SizeWriter can be
// passed to measure the size without producing output.
//
// Parameters:
//   omit_pos    - index of the entry whose exact count is not stored (only a
//                 logcount is written for it).
//   num_symbols - number of nonzero entries; 0..2 select the "small tree"
//                 form, anything larger the general form.
//   shift       - precision parameter forwarded to
//                 GetPopulationCountPrecision().
//   symbols     - indices of the first nonzero entries; only read in the
//                 small-tree form.
//
// Returns false if the histogram length does not fit the length field
// (length - 3 > 255). Output is still produced in that case, with the length
// field clamped to 255.
template <typename Writer>
bool EncodeCounts(const ANSHistBin* counts, const int alphabet_size,
                  const int omit_pos, const int num_symbols, uint32_t shift,
                  const int* symbols, Writer* writer) {
  bool ok = true;
  if (num_symbols <= 2) {
    // Small tree marker to encode 1-2 symbols.
    writer->Write(1, 1);
    if (num_symbols == 0) {
      // An empty histogram is written like a one-symbol histogram whose only
      // symbol is 0.
      writer->Write(1, 0);
      StoreVarLenUint8(0, writer);
    } else {
      // One bit for the symbol count (minus one), then the symbol indices.
      writer->Write(1, num_symbols - 1);
      for (int i = 0; i < num_symbols; ++i) {
        StoreVarLenUint8(symbols[i], writer);
      }
    }
    if (num_symbols == 2) {
      // Only the first symbol's count is written, in ANS_LOG_TAB_SIZE bits.
      writer->Write(ANS_LOG_TAB_SIZE, counts[symbols[0]]);
    }
  } else {
    // Mark non-small tree.
    writer->Write(1, 0);
    // Mark non-flat histogram.
    writer->Write(1, 0);

    // Precompute sequences for RLE encoding. Contains the number of identical
    // values starting at a given index. Only contains the value at the first
    // element of the series.
    std::vector<uint32_t> same(alphabet_size, 0);
    int last = 0;
    for (int i = 1; i < alphabet_size; i++) {
      // Store the sequence length once different symbol reached, or we're at
      // the end, or the length is longer than we can encode, or we are at
      // the omit_pos. We don't support including the omit_pos in an RLE
      // sequence because this value may use a different amount of log2 bits
      // than standard, it is too complex to handle in the decoder.
      if (counts[i] != counts[last] || i + 1 == alphabet_size ||
          (i - last) >= 255 || i == omit_pos || i == omit_pos + 1) {
        same[last] = (i - last);
        // The next candidate run starts after position i, not at it.
        last = i + 1;
      }
    }

    // `length` becomes one past the last entry that is nonzero or is the
    // omitted one. logcounts[i] is floor(log2(counts[i])) + 1 for nonzero
    // entries and 0 otherwise.
    int length = 0;
    std::vector<int> logcounts(alphabet_size);
    int omit_log = 0;
    for (int i = 0; i < alphabet_size; ++i) {
      JXL_ASSERT(counts[i] <= ANS_TAB_SIZE);
      JXL_ASSERT(counts[i] >= 0);
      if (i == omit_pos) {
        length = i + 1;
      } else if (counts[i] > 0) {
        logcounts[i] = FloorLog2Nonzero(static_cast<uint32_t>(counts[i])) + 1;
        length = i + 1;
        // The logcount stored for the omitted entry is made strictly larger
        // than every logcount before it and at least as large as every
        // logcount after it.
        if (i < omit_pos) {
          omit_log = std::max(omit_log, logcounts[i] + 1);
        } else {
          omit_log = std::max(omit_log, logcounts[i]);
        }
      }
    }
    logcounts[omit_pos] = omit_log;

    // Elias gamma-like code for shift. Only difference is that if the number
    // of bits to be encoded is equal to FloorLog2(ANS_LOG_TAB_SIZE+1), we skip
    // the terminating 0 in unary coding.
    int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
    int log = FloorLog2Nonzero(shift + 1);
    // Unary part: `log` one-bits, then (usually) a terminating zero.
    writer->Write(log, (1 << log) - 1);
    if (log != upper_bound_log) writer->Write(1, 0);
    // Binary part: the low `log` bits of shift + 1.
    writer->Write(log, ((1 << log) - 1) & (shift + 1));

    // Since num_symbols >= 3, we know that length >= 3, therefore we encode
    // length - 3.
    if (length - 3 > 255) {
      // Pretend that everything is OK, but complain about correctness later.
      StoreVarLenUint8(255, writer);
      ok = false;
    } else {
      StoreVarLenUint8(length - 3, writer);
    }

    // The logcount values are encoded with a static Huffman code.
    static const size_t kMinReps = 4;
    // Index of the run-length marker in the static code tables.
    size_t rep = ANS_LOG_TAB_SIZE + 1;
    for (int i = 0; i < length; ++i) {
      // The first entry of a run (at i - 1) has already been written
      // explicitly; if that run is longer than kMinReps, a single RLE token
      // written here stands in for the rest of it.
      if (i > 0 && same[i - 1] > kMinReps) {
        // Encode the RLE symbol and skip the repeated ones.
        writer->Write(kLogCountBitLengths[rep], kLogCountSymbols[rep]);
        StoreVarLenUint8(same[i - 1] - kMinReps - 1, writer);
        // Together with the loop's ++i this advances by same[i - 1] - 1.
        i += same[i - 1] - 2;
        continue;
      }
      writer->Write(kLogCountBitLengths[logcounts[i]],
                    kLogCountSymbols[logcounts[i]]);
    }
    // Second pass, walking the same positions as the first: write the stored
    // bits of each explicitly coded count.
    for (int i = 0; i < length; ++i) {
      if (i > 0 && same[i - 1] > kMinReps) {
        // Skip symbols encoded by RLE.
        i += same[i - 1] - 2;
        continue;
      }
      // Counts of 0 and 1 are fully described by their logcount, and the
      // omitted entry's count is never written.
      if (logcounts[i] > 1 && i != omit_pos) {
        int bitcount = GetPopulationCountPrecision(logcounts[i] - 1, shift);
        int drop_bits = logcounts[i] - 1 - bitcount;
        // The dropped low bits must be zero for the count to be representable.
        JXL_CHECK((counts[i] & ((1 << drop_bits) - 1)) == 0);
        // Write the `bitcount` bits just below the leading one.
        writer->Write(bitcount, (counts[i] >> drop_bits) - (1 << bitcount));
      }
    }
  }
  return ok;
}
357 
EncodeFlatHistogram(const int alphabet_size,BitWriter * writer)358 void EncodeFlatHistogram(const int alphabet_size, BitWriter* writer) {
359   // Mark non-small tree.
360   writer->Write(1, 0);
361   // Mark uniform histogram.
362   writer->Write(1, 1);
363   JXL_ASSERT(alphabet_size > 0);
364   // Encode alphabet size.
365   StoreVarLenUint8(alphabet_size - 1, writer);
366 }
367 
ComputeHistoAndDataCost(const ANSHistBin * histogram,size_t alphabet_size,uint32_t method)368 float ComputeHistoAndDataCost(const ANSHistBin* histogram, size_t alphabet_size,
369                               uint32_t method) {
370   if (method == 0) {  // Flat code
371     return ANS_LOG_TAB_SIZE + 2 +
372            EstimateDataBitsFlat(histogram, alphabet_size);
373   }
374   // Non-flat: shift = method-1.
375   uint32_t shift = method - 1;
376   std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
377   int omit_pos = 0;
378   int num_symbols;
379   int symbols[kMaxNumSymbolsForSmallCode] = {};
380   JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
381                             ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
382   SizeWriter writer;
383   // Ignore the correctness, no real encoding happens at this stage.
384   (void)EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols, shift,
385                      symbols, &writer);
386   return writer.size +
387          EstimateDataBits(histogram, counts.data(), alphabet_size);
388 }
389 
ComputeBestMethod(const ANSHistBin * histogram,size_t alphabet_size,float * cost,HistogramParams::ANSHistogramStrategy ans_histogram_strategy)390 uint32_t ComputeBestMethod(
391     const ANSHistBin* histogram, size_t alphabet_size, float* cost,
392     HistogramParams::ANSHistogramStrategy ans_histogram_strategy) {
393   size_t method = 0;
394   float fcost = ComputeHistoAndDataCost(histogram, alphabet_size, 0);
395   for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE;
396        ans_histogram_strategy != HistogramParams::ANSHistogramStrategy::kPrecise
397            ? shift += 2
398            : shift++) {
399     float c = ComputeHistoAndDataCost(histogram, alphabet_size, shift + 1);
400     if (c < fcost) {
401       method = shift + 1;
402       fcost = c;
403     } else if (ans_histogram_strategy ==
404                HistogramParams::ANSHistogramStrategy::kFast) {
405       // do not be as precise if estimating cost.
406       break;
407     }
408   }
409   *cost = fcost;
410   return method;
411 }
412 
413 }  // namespace
414 
415 // Returns an estimate of the cost of encoding this histogram and the
416 // corresponding data.
BuildAndStoreANSEncodingData(HistogramParams::ANSHistogramStrategy ans_histogram_strategy,const ANSHistBin * histogram,size_t alphabet_size,size_t log_alpha_size,bool use_prefix_code,ANSEncSymbolInfo * info,BitWriter * writer)417 size_t BuildAndStoreANSEncodingData(
418     HistogramParams::ANSHistogramStrategy ans_histogram_strategy,
419     const ANSHistBin* histogram, size_t alphabet_size, size_t log_alpha_size,
420     bool use_prefix_code, ANSEncSymbolInfo* info, BitWriter* writer) {
421   if (use_prefix_code) {
422     if (alphabet_size <= 1) return 0;
423     std::vector<uint32_t> histo(alphabet_size);
424     size_t total = 0;
425     for (size_t i = 0; i < alphabet_size; i++) {
426       histo[i] = histogram[i];
427       JXL_CHECK(histogram[i] >= 0);
428       total += histo[i];
429     }
430     size_t cost = 0;
431     {
432       std::vector<uint8_t> depths(alphabet_size);
433       std::vector<uint16_t> bits(alphabet_size);
434       BitWriter tmp_writer;
435       BitWriter* w = writer ? writer : &tmp_writer;
436       size_t start = w->BitsWritten();
437       BitWriter::Allotment allotment(
438           w, 8 * alphabet_size + 8);  // safe upper bound
439       BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(),
440                                bits.data(), w);
441       ReclaimAndCharge(w, &allotment, 0, /*aux_out=*/nullptr);
442 
443       for (size_t i = 0; i < alphabet_size; i++) {
444         info[i].bits = depths[i] == 0 ? 0 : bits[i];
445         info[i].depth = depths[i];
446       }
447       cost = w->BitsWritten() - start;
448     }
449     // Estimate data cost.
450     for (size_t i = 0; i < alphabet_size; i++) {
451       cost += histogram[i] * info[i].depth;
452     }
453     return cost;
454   }
455   JXL_ASSERT(alphabet_size <= ANS_TAB_SIZE);
456   // Ensure we ignore trailing zeros in the histogram.
457   if (alphabet_size != 0) {
458     size_t largest_symbol = 0;
459     for (size_t i = 0; i < alphabet_size; i++) {
460       if (histogram[i] != 0) largest_symbol = i;
461     }
462     alphabet_size = largest_symbol + 1;
463   }
464   float cost;
465   uint32_t method = ComputeBestMethod(histogram, alphabet_size, &cost,
466                                       ans_histogram_strategy);
467   JXL_ASSERT(cost >= 0);
468   int num_symbols;
469   int symbols[kMaxNumSymbolsForSmallCode] = {};
470   std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
471   if (!counts.empty()) {
472     size_t sum = 0;
473     for (size_t i = 0; i < counts.size(); i++) {
474       sum += counts[i];
475     }
476     if (sum == 0) {
477       counts[0] = ANS_TAB_SIZE;
478     }
479   }
480   if (method == 0) {
481     counts = CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE);
482     AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
483     InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
484     ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
485     if (writer != nullptr) {
486       EncodeFlatHistogram(alphabet_size, writer);
487     }
488     return cost;
489   }
490   int omit_pos = 0;
491   uint32_t shift = method - 1;
492   JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
493                             ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
494   AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
495   InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
496   ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
497   if (writer != nullptr) {
498     bool ok = EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols,
499                            shift, symbols, writer);
500     (void)ok;
501     JXL_DASSERT(ok);
502   }
503   return cost;
504 }
505 
ANSPopulationCost(const ANSHistBin * data,size_t alphabet_size)506 float ANSPopulationCost(const ANSHistBin* data, size_t alphabet_size) {
507   float c;
508   ComputeBestMethod(data, alphabet_size, &c,
509                     HistogramParams::ANSHistogramStrategy::kFast);
510   return c;
511 }
512 
513 template <typename Writer>
EncodeUintConfig(const HybridUintConfig uint_config,Writer * writer,size_t log_alpha_size)514 void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer,
515                       size_t log_alpha_size) {
516   writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
517                 uint_config.split_exponent);
518   if (uint_config.split_exponent == log_alpha_size) {
519     return;  // msb/lsb don't matter.
520   }
521   size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
522   writer->Write(nbits, uint_config.msb_in_token);
523   nbits = CeilLog2Nonzero(uint_config.split_exponent -
524                           uint_config.msb_in_token + 1);
525   writer->Write(nbits, uint_config.lsb_in_token);
526 }
527 template <typename Writer>
EncodeUintConfigs(const std::vector<HybridUintConfig> & uint_config,Writer * writer,size_t log_alpha_size)528 void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config,
529                        Writer* writer, size_t log_alpha_size) {
530   // TODO(veluca): RLE?
531   for (size_t i = 0; i < uint_config.size(); i++) {
532     EncodeUintConfig(uint_config[i], writer, log_alpha_size);
533   }
534 }
535 template void EncodeUintConfigs(const std::vector<HybridUintConfig>&,
536                                 BitWriter*, size_t);
537 
538 namespace {
539 
ChooseUintConfigs(const HistogramParams & params,const std::vector<std::vector<Token>> & tokens,const std::vector<uint8_t> & context_map,std::vector<Histogram> * clustered_histograms,EntropyEncodingData * codes,size_t * log_alpha_size)540 void ChooseUintConfigs(const HistogramParams& params,
541                        const std::vector<std::vector<Token>>& tokens,
542                        const std::vector<uint8_t>& context_map,
543                        std::vector<Histogram>* clustered_histograms,
544                        EntropyEncodingData* codes, size_t* log_alpha_size) {
545   codes->uint_config.resize(clustered_histograms->size());
546   if (params.uint_method == HistogramParams::HybridUintMethod::kNone) return;
547   if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
548     codes->uint_config.clear();
549     codes->uint_config.resize(clustered_histograms->size(),
550                               HybridUintConfig(2, 0, 1));
551     return;
552   }
553 
554   // Brute-force method that tries a few options.
555   std::vector<HybridUintConfig> configs;
556   if (params.uint_method == HistogramParams::HybridUintMethod::kBest) {
557     configs = {
558         HybridUintConfig(4, 2, 0),  // default
559         HybridUintConfig(4, 1, 0),  // less precise
560         HybridUintConfig(4, 2, 1),  // add sign
561         HybridUintConfig(4, 2, 2),  // add sign+parity
562         HybridUintConfig(4, 1, 2),  // add parity but less msb
563         // Same as above, but more direct coding.
564         HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0),
565         HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2),
566         HybridUintConfig(5, 1, 2),
567         // Same as above, but less direct coding.
568         HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0),
569         HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2),
570         // For near-lossless.
571         HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4),
572         HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5),
573         HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0),
574         // Other
575         HybridUintConfig(0, 0, 0),   // varlenuint
576         HybridUintConfig(2, 0, 1),   // works well for ctx map
577         HybridUintConfig(7, 0, 0),   // direct coding
578         HybridUintConfig(8, 0, 0),   // direct coding
579         HybridUintConfig(9, 0, 0),   // direct coding
580         HybridUintConfig(10, 0, 0),  // direct coding
581         HybridUintConfig(11, 0, 0),  // direct coding
582         HybridUintConfig(12, 0, 0),  // direct coding
583     };
584   } else if (params.uint_method == HistogramParams::HybridUintMethod::kFast) {
585     configs = {
586         HybridUintConfig(4, 2, 0),  // default
587         HybridUintConfig(4, 1, 2),  // add parity but less msb
588         HybridUintConfig(0, 0, 0),  // smallest histograms
589         HybridUintConfig(2, 0, 1),  // works well for ctx map
590     };
591   }
592 
593   std::vector<float> costs(clustered_histograms->size(),
594                            std::numeric_limits<float>::max());
595   std::vector<uint32_t> extra_bits(clustered_histograms->size());
596   std::vector<uint8_t> is_valid(clustered_histograms->size());
597   size_t max_alpha =
598       codes->use_prefix_code ? PREFIX_MAX_ALPHABET_SIZE : ANS_MAX_ALPHABET_SIZE;
599   for (HybridUintConfig cfg : configs) {
600     std::fill(is_valid.begin(), is_valid.end(), true);
601     std::fill(extra_bits.begin(), extra_bits.end(), 0);
602 
603     for (size_t i = 0; i < clustered_histograms->size(); i++) {
604       (*clustered_histograms)[i].Clear();
605     }
606     for (size_t i = 0; i < tokens.size(); ++i) {
607       for (size_t j = 0; j < tokens[i].size(); ++j) {
608         const Token token = tokens[i][j];
609         // TODO(veluca): do not ignore lz77 commands.
610         if (token.is_lz77_length) continue;
611         size_t histo = context_map[token.context];
612         uint32_t tok, nbits, bits;
613         cfg.Encode(token.value, &tok, &nbits, &bits);
614         if (tok >= max_alpha ||
615             (codes->lz77.enabled && tok >= codes->lz77.min_symbol)) {
616           is_valid[histo] = false;
617           continue;
618         }
619         extra_bits[histo] += nbits;
620         (*clustered_histograms)[histo].Add(tok);
621       }
622     }
623 
624     for (size_t i = 0; i < clustered_histograms->size(); i++) {
625       if (!is_valid[i]) continue;
626       float cost = (*clustered_histograms)[i].PopulationCost() + extra_bits[i];
627       if (cost < costs[i]) {
628         codes->uint_config[i] = cfg;
629         costs[i] = cost;
630       }
631     }
632   }
633 
634   // Rebuild histograms.
635   for (size_t i = 0; i < clustered_histograms->size(); i++) {
636     (*clustered_histograms)[i].Clear();
637   }
638   *log_alpha_size = 4;
639   for (size_t i = 0; i < tokens.size(); ++i) {
640     for (size_t j = 0; j < tokens[i].size(); ++j) {
641       const Token token = tokens[i][j];
642       uint32_t tok, nbits, bits;
643       size_t histo = context_map[token.context];
644       (token.is_lz77_length ? codes->lz77.length_uint_config
645                             : codes->uint_config[histo])
646           .Encode(token.value, &tok, &nbits, &bits);
647       tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
648       (*clustered_histograms)[histo].Add(tok);
649       while (tok >= (1u << *log_alpha_size)) (*log_alpha_size)++;
650     }
651   }
652 #if JXL_ENABLE_ASSERT
653   size_t max_log_alpha_size = codes->use_prefix_code ? PREFIX_MAX_BITS : 8;
654   JXL_ASSERT(*log_alpha_size <= max_log_alpha_size);
655 #endif
656 }
657 
// Accumulates one histogram per context and turns the set of histograms into
// entropy codes (context map, hybrid-uint configurations and per-cluster
// ANS/prefix code tables).
class HistogramBuilder {
 public:
  // Creates `num_contexts` empty histograms.
  explicit HistogramBuilder(const size_t num_contexts)
      : histograms_(num_contexts) {}

  // Counts one occurrence of `symbol` in the histogram of context
  // `histo_idx`, which must be a valid context index.
  void VisitSymbol(int symbol, size_t histo_idx) {
    JXL_DASSERT(histo_idx < histograms_.size());
    histograms_[histo_idx].Add(symbol);
  }

  // Clusters the collected histograms, picks hybrid-uint configurations and
  // builds the encoding tables in `codes`; if `writer` is non-null everything
  // is also written to it.
  //
  // Parameters:
  //   params          - clustering / uint-config / ANS strategy settings.
  //   tokens          - token streams, used to choose uint configs and to
  //                     rebuild the clustered histograms.
  //   codes           - output: use_prefix_code, uint_config, encoding_info.
  //   context_map     - output: cluster index for each context.
  //   use_prefix_code - selects prefix codes instead of ANS.
  //   writer          - destination, or nullptr to only estimate.
  //   layer, aux_out  - statistics destination; aux_out may be nullptr.
  //
  // Returns the accumulated cost estimate.
  // NOTE(review): when `writer` is non-null, the uint configs and prefix-code
  // alphabet sizes are written to it rather than to the SizeWriter, so their
  // bits are not part of the returned value; the context map is not counted in
  // either case — confirm callers only rely on the return value when
  // estimating.
  //
  // NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge.
  size_t BuildAndStoreEntropyCodes(
      const HistogramParams& params,
      const std::vector<std::vector<Token>>& tokens, EntropyEncodingData* codes,
      std::vector<uint8_t>* context_map, bool use_prefix_code,
      BitWriter* writer, size_t layer, AuxOut* aux_out) const {
    size_t cost = 0;
    codes->encoding_info.clear();
    // Start from a copy; with a single context no clustering happens and this
    // copy is used as is.
    std::vector<Histogram> clustered_histograms(histograms_);
    context_map->resize(histograms_.size());
    if (histograms_.size() > 1) {
      if (!ans_fuzzer_friendly_) {
        std::vector<uint32_t> histogram_symbols;
        ClusterHistograms(params, histograms_, histograms_.size(),
                          kClustersLimit, &clustered_histograms,
                          &histogram_symbols);
        for (size_t c = 0; c < histograms_.size(); ++c) {
          (*context_map)[c] = static_cast<uint8_t>(histogram_symbols[c]);
        }
      } else {
        // Fuzzer-friendly mode: all contexts share one histogram in which
        // every symbol of a power-of-two sized alphabet occurs exactly once.
        fill(context_map->begin(), context_map->end(), 0);
        size_t max_symbol = 0;
        for (const Histogram& h : histograms_) {
          max_symbol = std::max(h.data_.size(), max_symbol);
        }
        size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1);
        clustered_histograms.resize(1);
        clustered_histograms[0].Clear();
        for (size_t i = 0; i < num_symbols; i++) {
          clustered_histograms[0].Add(i);
        }
      }
      if (writer != nullptr) {
        EncodeContextMap(*context_map, clustered_histograms.size(), writer);
      }
    }
    if (aux_out != nullptr) {
      for (size_t i = 0; i < clustered_histograms.size(); ++i) {
        aux_out->layers[layer].clustered_entropy +=
            clustered_histograms[i].ShannonEntropy();
      }
    }
    codes->use_prefix_code = use_prefix_code;
    size_t log_alpha_size = codes->lz77.enabled ? 8 : 7;  // Sane default.
    if (ans_fuzzer_friendly_) {
      codes->uint_config.clear();
      codes->uint_config.resize(1, HybridUintConfig(7, 0, 0));
    } else {
      ChooseUintConfigs(params, tokens, *context_map, &clustered_histograms,
                        codes, &log_alpha_size);
    }
    // log_alpha_size is stored as (value - 5) in 2 bits below, so it must be
    // at least 5.
    if (log_alpha_size < 5) log_alpha_size = 5;
    SizeWriter size_writer;  // Used if writer == nullptr to estimate costs.
    // One bit for the use_prefix_code flag.
    cost += 1;
    if (writer) writer->Write(1, use_prefix_code);

    if (use_prefix_code) {
      log_alpha_size = PREFIX_MAX_BITS;
    } else {
      // Two bits for log_alpha_size - 5.
      cost += 2;
    }
    if (writer == nullptr) {
      EncodeUintConfigs(codes->uint_config, &size_writer, log_alpha_size);
    } else {
      if (!use_prefix_code) writer->Write(2, log_alpha_size - 5);
      EncodeUintConfigs(codes->uint_config, writer, log_alpha_size);
    }
    if (use_prefix_code) {
      // Prefix codes store each cluster's alphabet size (minus one) up front.
      for (size_t c = 0; c < clustered_histograms.size(); ++c) {
        // One past the last symbol that occurs, and at least 1.
        size_t num_symbol = 1;
        for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
          if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
        }
        if (writer) {
          StoreVarLenUint16(num_symbol - 1, writer);
        } else {
          StoreVarLenUint16(num_symbol - 1, &size_writer);
        }
      }
    }
    cost += size_writer.size;
    for (size_t c = 0; c < clustered_histograms.size(); ++c) {
      // One past the last symbol that occurs, and at least 1.
      size_t num_symbol = 1;
      for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
        if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
      }
      codes->encoding_info.emplace_back();
      codes->encoding_info.back().resize(std::max<size_t>(1, num_symbol));

      BitWriter::Allotment allotment(writer, 256 + num_symbol * 24);
      cost += BuildAndStoreANSEncodingData(
          params.ans_histogram_strategy, clustered_histograms[c].data_.data(),
          num_symbol, log_alpha_size, use_prefix_code,
          codes->encoding_info.back().data(), writer);
      allotment.FinishedHistogram(writer);
      ReclaimAndCharge(writer, &allotment, layer, aux_out);
    }
    return cost;
  }

  // Returns the (unclustered) histogram collected for context `i`.
  const Histogram& Histo(size_t i) const { return histograms_[i]; }

 private:
  // One histogram per context, indexed by context id.
  std::vector<Histogram> histograms_;
};
773 
774 class SymbolCostEstimator {
775  public:
SymbolCostEstimator(size_t num_contexts,bool force_huffman,const std::vector<std::vector<Token>> & tokens,const LZ77Params & lz77)776   SymbolCostEstimator(size_t num_contexts, bool force_huffman,
777                       const std::vector<std::vector<Token>>& tokens,
778                       const LZ77Params& lz77) {
779     HistogramBuilder builder(num_contexts);
780     // Build histograms for estimating lz77 savings.
781     HybridUintConfig uint_config;
782     for (size_t i = 0; i < tokens.size(); ++i) {
783       for (size_t j = 0; j < tokens[i].size(); ++j) {
784         const Token token = tokens[i][j];
785         uint32_t tok, nbits, bits;
786         (token.is_lz77_length ? lz77.length_uint_config : uint_config)
787             .Encode(token.value, &tok, &nbits, &bits);
788         tok += token.is_lz77_length ? lz77.min_symbol : 0;
789         builder.VisitSymbol(tok, token.context);
790       }
791     }
792     max_alphabet_size_ = 0;
793     for (size_t i = 0; i < num_contexts; i++) {
794       max_alphabet_size_ =
795           std::max(max_alphabet_size_, builder.Histo(i).data_.size());
796     }
797     bits_.resize(num_contexts * max_alphabet_size_);
798     // TODO(veluca): SIMD?
799     add_symbol_cost_.resize(num_contexts);
800     for (size_t i = 0; i < num_contexts; i++) {
801       float inv_total = 1.0f / (builder.Histo(i).total_count_ + 1e-8f);
802       float total_cost = 0;
803       for (size_t j = 0; j < builder.Histo(i).data_.size(); j++) {
804         size_t cnt = builder.Histo(i).data_[j];
805         float cost = 0;
806         if (cnt != 0 && cnt != builder.Histo(i).total_count_) {
807           cost = -FastLog2f(cnt * inv_total);
808           if (force_huffman) cost = std::ceil(cost);
809         } else if (cnt == 0) {
810           cost = ANS_LOG_TAB_SIZE;  // Highest possible cost.
811         }
812         bits_[i * max_alphabet_size_ + j] = cost;
813         total_cost += cost * builder.Histo(i).data_[j];
814       }
815       // Penalty for adding a lz77 symbol to this contest (only used for static
816       // cost model). Higher penalty for contexts that have a very low
817       // per-symbol entropy.
818       add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total);
819     }
820   }
Bits(size_t ctx,size_t sym) const821   float Bits(size_t ctx, size_t sym) const {
822     return bits_[ctx * max_alphabet_size_ + sym];
823   }
LenCost(size_t ctx,size_t len,const LZ77Params & lz77) const824   float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const {
825     uint32_t nbits, bits, tok;
826     lz77.length_uint_config.Encode(len, &tok, &nbits, &bits);
827     tok += lz77.min_symbol;
828     return nbits + Bits(ctx, tok);
829   }
DistCost(size_t len,const LZ77Params & lz77) const830   float DistCost(size_t len, const LZ77Params& lz77) const {
831     uint32_t nbits, bits, tok;
832     HybridUintConfig().Encode(len, &tok, &nbits, &bits);
833     return nbits + Bits(lz77.nonserialized_distance_context, tok);
834   }
AddSymbolCost(size_t idx) const835   float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; }
836 
837  private:
838   size_t max_alphabet_size_;
839   std::vector<float> bits_;
840   std::vector<float> add_symbol_cost_;
841 };
842 
ApplyLZ77_RLE(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)843 void ApplyLZ77_RLE(const HistogramParams& params, size_t num_contexts,
844                    const std::vector<std::vector<Token>>& tokens,
845                    LZ77Params& lz77,
846                    std::vector<std::vector<Token>>& tokens_lz77) {
847   // TODO(veluca): tune heuristics here.
848   SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
849   float bit_decrease = 0;
850   size_t total_symbols = 0;
851   tokens_lz77.resize(tokens.size());
852   std::vector<float> sym_cost;
853   HybridUintConfig uint_config;
854   for (size_t stream = 0; stream < tokens.size(); stream++) {
855     size_t distance_multiplier =
856         params.image_widths.size() > stream ? params.image_widths[stream] : 0;
857     const auto& in = tokens[stream];
858     auto& out = tokens_lz77[stream];
859     total_symbols += in.size();
860     // Cumulative sum of bit costs.
861     sym_cost.resize(in.size() + 1);
862     for (size_t i = 0; i < in.size(); i++) {
863       uint32_t tok, nbits, unused_bits;
864       uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
865       sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
866     }
867     out.reserve(in.size());
868     for (size_t i = 0; i < in.size(); i++) {
869       size_t num_to_copy = 0;
870       size_t distance_symbol = 0;  // 1 for RLE.
871       if (distance_multiplier != 0) {
872         distance_symbol = 1;  // Special distance 1 if enabled.
873         JXL_DASSERT(kSpecialDistances[1][0] == 1);
874         JXL_DASSERT(kSpecialDistances[1][1] == 0);
875       }
876       if (i > 0) {
877         for (; i + num_to_copy < in.size(); num_to_copy++) {
878           if (in[i + num_to_copy].value != in[i - 1].value) {
879             break;
880           }
881         }
882       }
883       if (num_to_copy == 0) {
884         out.push_back(in[i]);
885         continue;
886       }
887       float cost = sym_cost[i + num_to_copy] - sym_cost[i];
888       // This subtraction might overflow, but that's OK.
889       size_t lz77_len = num_to_copy - lz77.min_length;
890       float lz77_cost = num_to_copy >= lz77.min_length
891                             ? CeilLog2Nonzero(lz77_len + 1) + 1
892                             : 0;
893       if (num_to_copy < lz77.min_length || cost <= lz77_cost) {
894         for (size_t j = 0; j < num_to_copy; j++) {
895           out.push_back(in[i + j]);
896         }
897         i += num_to_copy - 1;
898         continue;
899       }
900       // Output the LZ77 length
901       out.emplace_back(in[i].context, lz77_len);
902       out.back().is_lz77_length = true;
903       i += num_to_copy - 1;
904       bit_decrease += cost - lz77_cost;
905       // Output the LZ77 copy distance.
906       out.emplace_back(lz77.nonserialized_distance_context, distance_symbol);
907     }
908   }
909 
910   if (bit_decrease > total_symbols * 0.2 + 16) {
911     lz77.enabled = true;
912   }
913 }
914 
915 // Hash chain for LZ77 matching
916 struct HashChain {
917   size_t size_;
918   std::vector<uint32_t> data_;
919 
920   unsigned hash_num_values_ = 32768;
921   unsigned hash_mask_ = hash_num_values_ - 1;
922   unsigned hash_shift_ = 5;
923 
924   std::vector<int> head;
925   std::vector<uint32_t> chain;
926   std::vector<int> val;
927 
928   // Speed up repetitions of zero
929   std::vector<int> headz;
930   std::vector<uint32_t> chainz;
931   std::vector<uint32_t> zeros;
932   uint32_t numzeros = 0;
933 
934   size_t window_size_;
935   size_t window_mask_;
936   size_t min_length_;
937   size_t max_length_;
938 
939   // Map of special distance codes.
940   std::unordered_map<int, int> special_dist_table_;
941   size_t num_special_distances_ = 0;
942 
943   uint32_t maxchainlength = 256;  // window_size_ to allow all
944 
HashChainjxl::__anon30e99be10211::HashChain945   HashChain(const Token* data, size_t size, size_t window_size,
946             size_t min_length, size_t max_length, size_t distance_multiplier)
947       : size_(size),
948         window_size_(window_size),
949         window_mask_(window_size - 1),
950         min_length_(min_length),
951         max_length_(max_length) {
952     data_.resize(size);
953     for (size_t i = 0; i < size; i++) {
954       data_[i] = data[i].value;
955     }
956 
957     head.resize(hash_num_values_, -1);
958     val.resize(window_size_, -1);
959     chain.resize(window_size_);
960     for (uint32_t i = 0; i < window_size_; ++i) {
961       chain[i] = i;  // same value as index indicates uninitialized
962     }
963 
964     zeros.resize(window_size_);
965     headz.resize(window_size_ + 1, -1);
966     chainz.resize(window_size_);
967     for (uint32_t i = 0; i < window_size_; ++i) {
968       chainz[i] = i;
969     }
970     // Translate distance to special distance code.
971     if (distance_multiplier) {
972       // Count down, so if due to small distance multiplier multiple distances
973       // map to the same code, the smallest code will be used in the end.
974       for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
975         int xi = kSpecialDistances[i][0];
976         int yi = kSpecialDistances[i][1];
977         int distance = yi * distance_multiplier + xi;
978         // Ensure that we map distance 1 to the lowest symbols.
979         if (distance < 1) distance = 1;
980         special_dist_table_[distance] = i;
981       }
982       num_special_distances_ = kNumSpecialDistances;
983     }
984   }
985 
GetHashjxl::__anon30e99be10211::HashChain986   uint32_t GetHash(size_t pos) const {
987     uint32_t result = 0;
988     if (pos + 2 < size_) {
989       // TODO(lode): take the MSB's of the uint32_t values into account as well,
990       // given that the hash code itself is less than 32 bits.
991       result ^= (uint32_t)(data_[pos + 0] << 0u);
992       result ^= (uint32_t)(data_[pos + 1] << hash_shift_);
993       result ^= (uint32_t)(data_[pos + 2] << (hash_shift_ * 2));
994     } else {
995       // No need to compute hash of last 2 bytes, the length 2 is too short.
996       return 0;
997     }
998     return result & hash_mask_;
999   }
1000 
CountZerosjxl::__anon30e99be10211::HashChain1001   uint32_t CountZeros(size_t pos, uint32_t prevzeros) const {
1002     size_t end = pos + window_size_;
1003     if (end > size_) end = size_;
1004     if (prevzeros > 0) {
1005       if (prevzeros >= window_mask_ && data_[end - 1] == 0 &&
1006           end == pos + window_size_) {
1007         return prevzeros;
1008       } else {
1009         return prevzeros - 1;
1010       }
1011     }
1012     uint32_t num = 0;
1013     while (pos + num < end && data_[pos + num] == 0) num++;
1014     return num;
1015   }
1016 
Updatejxl::__anon30e99be10211::HashChain1017   void Update(size_t pos) {
1018     uint32_t hashval = GetHash(pos);
1019     uint32_t wpos = pos & window_mask_;
1020 
1021     val[wpos] = (int)hashval;
1022     if (head[hashval] != -1) chain[wpos] = head[hashval];
1023     head[hashval] = wpos;
1024 
1025     if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0;
1026     numzeros = CountZeros(pos, numzeros);
1027 
1028     zeros[wpos] = numzeros;
1029     if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros];
1030     headz[numzeros] = wpos;
1031   }
1032 
Updatejxl::__anon30e99be10211::HashChain1033   void Update(size_t pos, size_t len) {
1034     for (size_t i = 0; i < len; i++) {
1035       Update(pos + i);
1036     }
1037   }
1038 
1039   template <typename CB>
FindMatchesjxl::__anon30e99be10211::HashChain1040   void FindMatches(size_t pos, int max_dist, const CB& found_match) const {
1041     uint32_t wpos = pos & window_mask_;
1042     uint32_t hashval = GetHash(pos);
1043     uint32_t hashpos = chain[wpos];
1044 
1045     int prev_dist = 0;
1046     int end = std::min<int>(pos + max_length_, size_);
1047     uint32_t chainlength = 0;
1048     uint32_t best_len = 0;
1049     for (;;) {
1050       int dist = (hashpos <= wpos) ? (wpos - hashpos)
1051                                    : (wpos - hashpos + window_mask_ + 1);
1052       if (dist < prev_dist) break;
1053       prev_dist = dist;
1054       uint32_t len = 0;
1055       if (dist > 0) {
1056         int i = pos;
1057         int j = pos - dist;
1058         if (numzeros > 3) {
1059           int r = std::min<int>(numzeros - 1, zeros[hashpos]);
1060           if (i + r >= end) r = end - i - 1;
1061           i += r;
1062           j += r;
1063         }
1064         while (i < end && data_[i] == data_[j]) {
1065           i++;
1066           j++;
1067         }
1068         len = i - pos;
1069         // This can trigger even if the new length is slightly smaller than the
1070         // best length, because it is possible for a slightly cheaper distance
1071         // symbol to occur.
1072         if (len >= min_length_ && len + 2 >= best_len) {
1073           auto it = special_dist_table_.find(dist);
1074           int dist_symbol = (it == special_dist_table_.end())
1075                                 ? (num_special_distances_ + dist - 1)
1076                                 : it->second;
1077           found_match(len, dist_symbol);
1078           if (len > best_len) best_len = len;
1079         }
1080       }
1081 
1082       chainlength++;
1083       if (chainlength >= maxchainlength) break;
1084 
1085       if (numzeros >= 3 && len > numzeros) {
1086         if (hashpos == chainz[hashpos]) break;
1087         hashpos = chainz[hashpos];
1088         if (zeros[hashpos] != numzeros) break;
1089       } else {
1090         if (hashpos == chain[hashpos]) break;
1091         hashpos = chain[hashpos];
1092         if (val[hashpos] != (int)hashval) break;  // outdated hash value
1093       }
1094     }
1095   }
FindMatchjxl::__anon30e99be10211::HashChain1096   void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol,
1097                  size_t* result_len) const {
1098     *result_dist_symbol = 0;
1099     *result_len = 1;
1100     FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) {
1101       if (len > *result_len ||
1102           (len == *result_len && *result_dist_symbol > dist_symbol)) {
1103         *result_len = len;
1104         *result_dist_symbol = dist_symbol;
1105       }
1106     });
1107   }
1108 };
1109 
LenCost(size_t len)1110 float LenCost(size_t len) {
1111   uint32_t nbits, bits, tok;
1112   HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits);
1113   constexpr float kCostTable[] = {
1114       2.797667318563126,  3.213177690381199,  2.5706009246743737,
1115       2.408392498667534,  2.829649191872326,  3.3923087753324577,
1116       4.029267451554331,  4.415576699706408,  4.509357574741465,
1117       9.21481543803004,   10.020590190114898, 11.858671627804766,
1118       12.45853300490526,  11.713105831990857, 12.561996324849314,
1119       13.775477692278367, 13.174027068768641,
1120   };
1121   size_t table_size = sizeof kCostTable / sizeof *kCostTable;
1122   if (tok >= table_size) tok = table_size - 1;
1123   return kCostTable[tok] + nbits;
1124 }
1125 
1126 // TODO(veluca): this does not take into account usage or non-usage of distance
1127 // multipliers.
DistCost(size_t dist)1128 float DistCost(size_t dist) {
1129   uint32_t nbits, bits, tok;
1130   HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits);
1131   constexpr float kCostTable[] = {
1132       6.368282626312716,  5.680793277090298,  8.347404197105247,
1133       7.641619201599141,  6.914328374119438,  7.959808291537444,
1134       8.70023120759855,   8.71378518934703,   9.379132523982769,
1135       9.110472749092708,  9.159029569270908,  9.430936766731973,
1136       7.278284055315169,  7.8278514904267755, 10.026641158289236,
1137       9.976049229827066,  9.64351607048908,   9.563403863480442,
1138       10.171474111762747, 10.45950155077234,  9.994813912104219,
1139       10.322524683741156, 8.465808729388186,  8.756254166066853,
1140       10.160930174662234, 10.247329273413435, 10.04090403724809,
1141       10.129398517544082, 9.342311691539546,  9.07608009102374,
1142       10.104799540677513, 10.378079384990906, 10.165828974075072,
1143       10.337595322341553, 7.940557464567944,  10.575665823319431,
1144       11.023344321751955, 10.736144698831827, 11.118277044595054,
1145       7.468468230648442,  10.738305230932939, 10.906980780216568,
1146       10.163468216353817, 10.17805759656433,  11.167283670483565,
1147       11.147050200274544, 10.517921919244333, 10.651764778156886,
1148       10.17074446448919,  11.217636876224745, 11.261630721139484,
1149       11.403140815247259, 10.892472096873417, 11.1859607804481,
1150       8.017346947551262,  7.895143720278828,  11.036577113822025,
1151       11.170562110315794, 10.326988722591086, 10.40872184751056,
1152       11.213498225466386, 11.30580635516863,  10.672272515665442,
1153       10.768069466228063, 11.145257364153565, 11.64668307145549,
1154       10.593156194627339, 11.207499484844943, 10.767517766396908,
1155       10.826629811407042, 10.737764794499988, 10.6200448518045,
1156       10.191315385198092, 8.468384171390085,  11.731295299170432,
1157       11.824619886654398, 10.41518844301179,  10.16310536548649,
1158       10.539423685097576, 10.495136599328031, 10.469112847728267,
1159       11.72057686174922,  10.910326337834674, 11.378921834673758,
1160       11.847759036098536, 11.92071647623854,  10.810628276345282,
1161       11.008601085273893, 11.910326337834674, 11.949212023423133,
1162       11.298614839104337, 11.611603659010392, 10.472930394619985,
1163       11.835564720850282, 11.523267392285337, 12.01055816679611,
1164       8.413029688994023,  11.895784139536406, 11.984679534970505,
1165       11.220654278717394, 11.716311684833672, 10.61036646226114,
1166       10.89849965960364,  10.203762898863669, 10.997560826267238,
1167       11.484217379438984, 11.792836176993665, 12.24310468755171,
1168       11.464858097919262, 12.212747017409377, 11.425595666074955,
1169       11.572048533398757, 12.742093965163013, 11.381874288645637,
1170       12.191870445817015, 11.683156920035426, 11.152442115262197,
1171       11.90303691580457,  11.653292787169159, 11.938615382266098,
1172       16.970641701570223, 16.853602280380002, 17.26240782594733,
1173       16.644655390108507, 17.14310889757499,  16.910935455445955,
1174       17.505678976959697, 17.213498225466388, 2.4162310293553024,
1175       3.494587244462329,  3.5258600986408344, 3.4959806589517095,
1176       3.098390886949687,  3.343454654302911,  3.588847442290287,
1177       4.14614790111827,   5.152948641990529,  7.433696808092598,
1178       9.716311684833672,
1179   };
1180   size_t table_size = sizeof kCostTable / sizeof *kCostTable;
1181   if (tok >= table_size) tok = table_size - 1;
1182   return kCostTable[tok] + nbits;
1183 }
1184 
ApplyLZ77_LZ77(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)1185 void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts,
1186                     const std::vector<std::vector<Token>>& tokens,
1187                     LZ77Params& lz77,
1188                     std::vector<std::vector<Token>>& tokens_lz77) {
1189   // TODO(veluca): tune heuristics here.
1190   SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
1191   float bit_decrease = 0;
1192   size_t total_symbols = 0;
1193   tokens_lz77.resize(tokens.size());
1194   HybridUintConfig uint_config;
1195   std::vector<float> sym_cost;
1196   for (size_t stream = 0; stream < tokens.size(); stream++) {
1197     size_t distance_multiplier =
1198         params.image_widths.size() > stream ? params.image_widths[stream] : 0;
1199     const auto& in = tokens[stream];
1200     auto& out = tokens_lz77[stream];
1201     total_symbols += in.size();
1202     // Cumulative sum of bit costs.
1203     sym_cost.resize(in.size() + 1);
1204     for (size_t i = 0; i < in.size(); i++) {
1205       uint32_t tok, nbits, unused_bits;
1206       uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
1207       sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
1208     }
1209 
1210     out.reserve(in.size());
1211     size_t max_distance = in.size();
1212     size_t min_length = lz77.min_length;
1213     JXL_ASSERT(min_length >= 3);
1214     size_t max_length = in.size();
1215 
1216     // Use next power of two as window size.
1217     size_t window_size = 1;
1218     while (window_size < max_distance && window_size < kWindowSize) {
1219       window_size <<= 1;
1220     }
1221 
1222     HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
1223                     distance_multiplier);
1224     size_t len, dist_symbol;
1225 
1226     const size_t max_lazy_match_len = 256;  // 0 to disable lazy matching
1227 
1228     // Whether the next symbol was already updated (to test lazy matching)
1229     bool already_updated = false;
1230     for (size_t i = 0; i < in.size(); i++) {
1231       out.push_back(in[i]);
1232       if (!already_updated) chain.Update(i);
1233       already_updated = false;
1234       chain.FindMatch(i, max_distance, &dist_symbol, &len);
1235       if (len >= min_length) {
1236         if (len < max_lazy_match_len && i + 1 < in.size()) {
1237           // Try length at next symbol lazy matching
1238           chain.Update(i + 1);
1239           already_updated = true;
1240           size_t len2, dist_symbol2;
1241           chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
1242           if (len2 > len) {
1243             // Use the lazy match. Add literal, and use the next length starting
1244             // from the next byte.
1245             ++i;
1246             already_updated = false;
1247             len = len2;
1248             dist_symbol = dist_symbol2;
1249             out.push_back(in[i]);
1250           }
1251         }
1252 
1253         float cost = sym_cost[i + len] - sym_cost[i];
1254         size_t lz77_len = len - lz77.min_length;
1255         float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
1256                           sce.AddSymbolCost(out.back().context);
1257 
1258         if (lz77_cost <= cost) {
1259           out.back().value = len - min_length;
1260           out.back().is_lz77_length = true;
1261           out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
1262           bit_decrease += cost - lz77_cost;
1263         } else {
1264           // LZ77 match ignored, and symbol already pushed. Push all other
1265           // symbols and skip.
1266           for (size_t j = 1; j < len; j++) {
1267             out.push_back(in[i + j]);
1268           }
1269         }
1270 
1271         if (already_updated) {
1272           chain.Update(i + 2, len - 2);
1273           already_updated = false;
1274         } else {
1275           chain.Update(i + 1, len - 1);
1276         }
1277         i += len - 1;
1278       } else {
1279         // Literal, already pushed
1280       }
1281     }
1282   }
1283 
1284   if (bit_decrease > total_symbols * 0.2 + 16) {
1285     lz77.enabled = true;
1286   }
1287 }
1288 
ApplyLZ77_Optimal(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)1289 void ApplyLZ77_Optimal(const HistogramParams& params, size_t num_contexts,
1290                        const std::vector<std::vector<Token>>& tokens,
1291                        LZ77Params& lz77,
1292                        std::vector<std::vector<Token>>& tokens_lz77) {
1293   std::vector<std::vector<Token>> tokens_for_cost_estimate;
1294   ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_for_cost_estimate);
1295   // If greedy-LZ77 does not give better compression than no-lz77, no reason to
1296   // run the optimal matching.
1297   if (!lz77.enabled) return;
1298   SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
1299                           tokens_for_cost_estimate, lz77);
1300   size_t total_symbols = 0;
1301   tokens_lz77.resize(tokens.size());
1302   HybridUintConfig uint_config;
1303   std::vector<float> sym_cost;
1304   std::vector<uint32_t> dist_symbols;
1305   for (size_t stream = 0; stream < tokens.size(); stream++) {
1306     size_t distance_multiplier =
1307         params.image_widths.size() > stream ? params.image_widths[stream] : 0;
1308     const auto& in = tokens[stream];
1309     auto& out = tokens_lz77[stream];
1310     total_symbols += in.size();
1311     // Cumulative sum of bit costs.
1312     sym_cost.resize(in.size() + 1);
1313     for (size_t i = 0; i < in.size(); i++) {
1314       uint32_t tok, nbits, unused_bits;
1315       uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
1316       sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
1317     }
1318 
1319     out.reserve(in.size());
1320     size_t max_distance = in.size();
1321     size_t min_length = lz77.min_length;
1322     JXL_ASSERT(min_length >= 3);
1323     size_t max_length = in.size();
1324 
1325     // Use next power of two as window size.
1326     size_t window_size = 1;
1327     while (window_size < max_distance && window_size < kWindowSize) {
1328       window_size <<= 1;
1329     }
1330 
1331     HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
1332                     distance_multiplier);
1333 
1334     struct MatchInfo {
1335       uint32_t len;
1336       uint32_t dist_symbol;
1337       uint32_t ctx;
1338       float total_cost = std::numeric_limits<float>::max();
1339     };
1340     // Total cost to encode the first N symbols.
1341     std::vector<MatchInfo> prefix_costs(in.size() + 1);
1342     prefix_costs[0].total_cost = 0;
1343 
1344     size_t rle_length = 0;
1345     size_t skip_lz77 = 0;
1346     for (size_t i = 0; i < in.size(); i++) {
1347       chain.Update(i);
1348       float lit_cost =
1349           prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
1350       if (prefix_costs[i + 1].total_cost > lit_cost) {
1351         prefix_costs[i + 1].dist_symbol = 0;
1352         prefix_costs[i + 1].len = 1;
1353         prefix_costs[i + 1].ctx = in[i].context;
1354         prefix_costs[i + 1].total_cost = lit_cost;
1355       }
1356       if (skip_lz77 > 0) {
1357         skip_lz77--;
1358         continue;
1359       }
1360       dist_symbols.clear();
1361       chain.FindMatches(i, max_distance,
1362                         [&dist_symbols](size_t len, size_t dist_symbol) {
1363                           if (dist_symbols.size() <= len) {
1364                             dist_symbols.resize(len + 1, dist_symbol);
1365                           }
1366                           if (dist_symbol < dist_symbols[len]) {
1367                             dist_symbols[len] = dist_symbol;
1368                           }
1369                         });
1370       if (dist_symbols.size() <= min_length) continue;
1371       {
1372         size_t best_cost = dist_symbols.back();
1373         for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
1374           if (dist_symbols[j] < best_cost) {
1375             best_cost = dist_symbols[j];
1376           }
1377           dist_symbols[j] = best_cost;
1378         }
1379       }
1380       for (size_t j = min_length; j < dist_symbols.size(); j++) {
1381         // Cost model that uses results from lazy LZ77.
1382         float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
1383                           sce.DistCost(dist_symbols[j], lz77);
1384         float cost = prefix_costs[i].total_cost + lz77_cost;
1385         if (prefix_costs[i + j].total_cost > cost) {
1386           prefix_costs[i + j].len = j;
1387           prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
1388           prefix_costs[i + j].ctx = in[i].context;
1389           prefix_costs[i + j].total_cost = cost;
1390         }
1391       }
1392       // We are in a RLE sequence: skip all the symbols except the first 8 and
1393       // the last 8. This avoid quadratic costs for sequences with long runs of
1394       // the same symbol.
1395       if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
1396           (dist_symbols.back() == 1 && distance_multiplier != 0)) {
1397         rle_length++;
1398       } else {
1399         rle_length = 0;
1400       }
1401       if (rle_length >= 8 && dist_symbols.size() > 9) {
1402         skip_lz77 = dist_symbols.size() - 10;
1403         rle_length = 0;
1404       }
1405     }
1406     size_t pos = in.size();
1407     while (pos > 0) {
1408       bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
1409       if (is_lz77_length) {
1410         size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
1411         out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
1412       }
1413       size_t val = is_lz77_length ? prefix_costs[pos].len - min_length
1414                                   : in[pos - 1].value;
1415       out.emplace_back(prefix_costs[pos].ctx, val);
1416       out.back().is_lz77_length = is_lz77_length;
1417       pos -= prefix_costs[pos].len;
1418     }
1419     std::reverse(out.begin(), out.end());
1420   }
1421 }
1422 
ApplyLZ77(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)1423 void ApplyLZ77(const HistogramParams& params, size_t num_contexts,
1424                const std::vector<std::vector<Token>>& tokens, LZ77Params& lz77,
1425                std::vector<std::vector<Token>>& tokens_lz77) {
1426   lz77.enabled = false;
1427   if (params.force_huffman) {
1428     lz77.min_symbol = std::min(PREFIX_MAX_ALPHABET_SIZE - 32, 512);
1429   } else {
1430     lz77.min_symbol = 224;
1431   }
1432   if (params.lz77_method == HistogramParams::LZ77Method::kNone) {
1433     return;
1434   } else if (params.lz77_method == HistogramParams::LZ77Method::kRLE) {
1435     ApplyLZ77_RLE(params, num_contexts, tokens, lz77, tokens_lz77);
1436   } else if (params.lz77_method == HistogramParams::LZ77Method::kLZ77) {
1437     ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_lz77);
1438   } else if (params.lz77_method == HistogramParams::LZ77Method::kOptimal) {
1439     ApplyLZ77_Optimal(params, num_contexts, tokens, lz77, tokens_lz77);
1440   } else {
1441     JXL_ABORT("Not implemented");
1442   }
1443 }
1444 }  // namespace
1445 
BuildAndEncodeHistograms(const HistogramParams & params,size_t num_contexts,std::vector<std::vector<Token>> & tokens,EntropyEncodingData * codes,std::vector<uint8_t> * context_map,BitWriter * writer,size_t layer,AuxOut * aux_out)1446 size_t BuildAndEncodeHistograms(const HistogramParams& params,
1447                                 size_t num_contexts,
1448                                 std::vector<std::vector<Token>>& tokens,
1449                                 EntropyEncodingData* codes,
1450                                 std::vector<uint8_t>* context_map,
1451                                 BitWriter* writer, size_t layer,
1452                                 AuxOut* aux_out) {
1453   size_t total_bits = 0;
1454   codes->lz77.nonserialized_distance_context = num_contexts;
1455   std::vector<std::vector<Token>> tokens_lz77;
1456   ApplyLZ77(params, num_contexts, tokens, codes->lz77, tokens_lz77);
1457   if (ans_fuzzer_friendly_) {
1458     codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0);
1459     codes->lz77.min_symbol = 2048;
1460   }
1461 
1462   const size_t max_contexts = std::min(num_contexts, kClustersLimit);
1463   BitWriter::Allotment allotment(writer,
1464                                  128 + num_contexts * 40 + max_contexts * 96);
1465   if (writer) {
1466     JXL_CHECK(Bundle::Write(codes->lz77, writer, layer, aux_out));
1467   } else {
1468     size_t ebits, bits;
1469     JXL_CHECK(Bundle::CanEncode(codes->lz77, &ebits, &bits));
1470     total_bits += bits;
1471   }
1472   if (codes->lz77.enabled) {
1473     if (writer) {
1474       size_t b = writer->BitsWritten();
1475       EncodeUintConfig(codes->lz77.length_uint_config, writer,
1476                        /*log_alpha_size=*/8);
1477       total_bits += writer->BitsWritten() - b;
1478     } else {
1479       SizeWriter size_writer;
1480       EncodeUintConfig(codes->lz77.length_uint_config, &size_writer,
1481                        /*log_alpha_size=*/8);
1482       total_bits += size_writer.size;
1483     }
1484     num_contexts += 1;
1485     tokens = std::move(tokens_lz77);
1486   }
1487   size_t total_tokens = 0;
1488   // Build histograms.
1489   HistogramBuilder builder(num_contexts);
1490   HybridUintConfig uint_config;  //  Default config for clustering.
1491   // Unless we are using the kContextMap histogram option.
1492   if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
1493     uint_config = HybridUintConfig(2, 0, 1);
1494   }
1495   if (ans_fuzzer_friendly_) {
1496     uint_config = HybridUintConfig(10, 0, 0);
1497   }
1498   for (size_t i = 0; i < tokens.size(); ++i) {
1499     for (size_t j = 0; j < tokens[i].size(); ++j) {
1500       const Token token = tokens[i][j];
1501       total_tokens++;
1502       uint32_t tok, nbits, bits;
1503       (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config)
1504           .Encode(token.value, &tok, &nbits, &bits);
1505       tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
1506       builder.VisitSymbol(tok, token.context);
1507     }
1508   }
1509 
1510   bool use_prefix_code =
1511       params.force_huffman || total_tokens < 100 ||
1512       params.clustering == HistogramParams::ClusteringType::kFastest ||
1513       ans_fuzzer_friendly_;
1514   if (!use_prefix_code) {
1515     bool all_singleton = true;
1516     for (size_t i = 0; i < num_contexts; i++) {
1517       if (builder.Histo(i).ShannonEntropy() >= 1e-5) {
1518         all_singleton = false;
1519       }
1520     }
1521     if (all_singleton) {
1522       use_prefix_code = true;
1523     }
1524   }
1525 
1526   // Encode histograms.
1527   total_bits += builder.BuildAndStoreEntropyCodes(params, tokens, codes,
1528                                                   context_map, use_prefix_code,
1529                                                   writer, layer, aux_out);
1530   allotment.FinishedHistogram(writer);
1531   ReclaimAndCharge(writer, &allotment, layer, aux_out);
1532 
1533   if (aux_out != nullptr) {
1534     aux_out->layers[layer].num_clustered_histograms +=
1535         codes->encoding_info.size();
1536   }
1537   return total_bits;
1538 }
1539 
WriteTokens(const std::vector<Token> & tokens,const EntropyEncodingData & codes,const std::vector<uint8_t> & context_map,BitWriter * writer)1540 size_t WriteTokens(const std::vector<Token>& tokens,
1541                    const EntropyEncodingData& codes,
1542                    const std::vector<uint8_t>& context_map, BitWriter* writer) {
1543   size_t num_extra_bits = 0;
1544   if (codes.use_prefix_code) {
1545     for (size_t i = 0; i < tokens.size(); i++) {
1546       uint32_t tok, nbits, bits;
1547       const Token& token = tokens[i];
1548       size_t histo = context_map[token.context];
1549       (token.is_lz77_length ? codes.lz77.length_uint_config
1550                             : codes.uint_config[histo])
1551           .Encode(token.value, &tok, &nbits, &bits);
1552       tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1553       // Combine two calls to the BitWriter. Equivalent to:
1554       // writer->Write(codes.encoding_info[histo][tok].depth,
1555       //               codes.encoding_info[histo][tok].bits);
1556       // writer->Write(nbits, bits);
1557       uint64_t data = codes.encoding_info[histo][tok].bits;
1558       data |= bits << codes.encoding_info[histo][tok].depth;
1559       writer->Write(codes.encoding_info[histo][tok].depth + nbits, data);
1560       num_extra_bits += nbits;
1561     }
1562     return num_extra_bits;
1563   }
1564   std::vector<uint64_t> out;
1565   std::vector<uint8_t> out_nbits;
1566   out.reserve(tokens.size());
1567   out_nbits.reserve(tokens.size());
1568   uint64_t allbits = 0;
1569   size_t numallbits = 0;
1570   // Writes in *reversed* order.
1571   auto addbits = [&](size_t bits, size_t nbits) {
1572     JXL_DASSERT(bits >> nbits == 0);
1573     if (JXL_UNLIKELY(numallbits + nbits > BitWriter::kMaxBitsPerCall)) {
1574       out.push_back(allbits);
1575       out_nbits.push_back(numallbits);
1576       numallbits = allbits = 0;
1577     }
1578     allbits <<= nbits;
1579     allbits |= bits;
1580     numallbits += nbits;
1581   };
1582   const int end = tokens.size();
1583   ANSCoder ans;
1584   for (int i = end - 1; i >= 0; --i) {
1585     const Token token = tokens[i];
1586     const uint8_t histo = context_map[token.context];
1587     uint32_t tok, nbits, bits;
1588     (token.is_lz77_length ? codes.lz77.length_uint_config
1589                           : codes.uint_config[histo])
1590         .Encode(tokens[i].value, &tok, &nbits, &bits);
1591     tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1592     const ANSEncSymbolInfo& info = codes.encoding_info[histo][tok];
1593     // Extra bits first as this is reversed.
1594     addbits(bits, nbits);
1595     num_extra_bits += nbits;
1596     uint8_t ans_nbits = 0;
1597     uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
1598     addbits(ans_bits, ans_nbits);
1599   }
1600   const uint32_t state = ans.GetState();
1601   writer->Write(32, state);
1602   writer->Write(numallbits, allbits);
1603   for (int i = out.size(); i > 0; --i) {
1604     writer->Write(out_nbits[i - 1], out[i - 1]);
1605   }
1606   return num_extra_bits;
1607 }
1608 
WriteTokens(const std::vector<Token> & tokens,const EntropyEncodingData & codes,const std::vector<uint8_t> & context_map,BitWriter * writer,size_t layer,AuxOut * aux_out)1609 void WriteTokens(const std::vector<Token>& tokens,
1610                  const EntropyEncodingData& codes,
1611                  const std::vector<uint8_t>& context_map, BitWriter* writer,
1612                  size_t layer, AuxOut* aux_out) {
1613   BitWriter::Allotment allotment(writer, 32 * tokens.size() + 32 * 1024 * 4);
1614   size_t num_extra_bits = WriteTokens(tokens, codes, context_map, writer);
1615   ReclaimAndCharge(writer, &allotment, layer, aux_out);
1616   if (aux_out != nullptr) {
1617     aux_out->layers[layer].extra_bits += num_extra_bits;
1618   }
1619 }
1620 
// Sets the file-local `ans_fuzzer_friendly_` flag.
//
// The assignment is only compiled in when JXL_IS_DEBUG_BUILD is non-zero;
// otherwise the call does nothing. This guards against accidental or
// malicious changes of the flag.
void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) {
#if JXL_IS_DEBUG_BUILD
  ans_fuzzer_friendly_ = ans_fuzzer_friendly;
#else
  (void)ans_fuzzer_friendly;  // Deliberately ignored in this configuration.
#endif
}
1626 }  // namespace jxl
1627