1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/enc_ans.h"
7 
8 #include <stdint.h>
9 
10 #include <algorithm>
11 #include <array>
12 #include <cmath>
13 #include <limits>
14 #include <numeric>
15 #include <type_traits>
16 #include <unordered_map>
17 #include <utility>
18 #include <vector>
19 
20 #include "lib/jxl/ans_common.h"
21 #include "lib/jxl/aux_out.h"
22 #include "lib/jxl/aux_out_fwd.h"
23 #include "lib/jxl/base/bits.h"
24 #include "lib/jxl/dec_ans.h"
25 #include "lib/jxl/enc_cluster.h"
26 #include "lib/jxl/enc_context_map.h"
27 #include "lib/jxl/enc_huffman.h"
28 #include "lib/jxl/fast_math-inl.h"
29 #include "lib/jxl/fields.h"
30 
31 namespace jxl {
32 
33 namespace {
34 
// When true, HistogramBuilder skips clustering (all contexts share a single
// histogram) and uses a fixed hybrid-uint config.
bool ans_fuzzer_friendly_ = false;

// Capacity of the `symbols` arrays: NormalizeCounts records the indices of at
// most this many nonzero symbols.
constexpr int kMaxNumSymbolsForSmallCode = 4;
38 
ANSBuildInfoTable(const ANSHistBin * counts,const AliasTable::Entry * table,size_t alphabet_size,size_t log_alpha_size,ANSEncSymbolInfo * info)39 void ANSBuildInfoTable(const ANSHistBin* counts, const AliasTable::Entry* table,
40                        size_t alphabet_size, size_t log_alpha_size,
41                        ANSEncSymbolInfo* info) {
42   size_t log_entry_size = ANS_LOG_TAB_SIZE - log_alpha_size;
43   size_t entry_size_minus_1 = (1 << log_entry_size) - 1;
44   // create valid alias table for empty streams.
45   for (size_t s = 0; s < std::max<size_t>(1, alphabet_size); ++s) {
46     const ANSHistBin freq = s == alphabet_size ? ANS_TAB_SIZE : counts[s];
47     info[s].freq_ = static_cast<uint16_t>(freq);
48 #ifdef USE_MULT_BY_RECIPROCAL
49     if (freq != 0) {
50       info[s].ifreq_ =
51           ((1ull << RECIPROCAL_PRECISION) + info[s].freq_ - 1) / info[s].freq_;
52     } else {
53       info[s].ifreq_ = 1;  // shouldn't matter (symbol shouldn't occur), but...
54     }
55 #endif
56     info[s].reverse_map_.resize(freq);
57   }
58   for (int i = 0; i < ANS_TAB_SIZE; i++) {
59     AliasTable::Symbol s =
60         AliasTable::Lookup(table, i, log_entry_size, entry_size_minus_1);
61     info[s.value].reverse_map_[s.offset] = i;
62   }
63 }
64 
EstimateDataBits(const ANSHistBin * histogram,const ANSHistBin * counts,size_t len)65 float EstimateDataBits(const ANSHistBin* histogram, const ANSHistBin* counts,
66                        size_t len) {
67   float sum = 0.0f;
68   int total_histogram = 0;
69   int total_counts = 0;
70   for (size_t i = 0; i < len; ++i) {
71     total_histogram += histogram[i];
72     total_counts += counts[i];
73     if (histogram[i] > 0) {
74       JXL_ASSERT(counts[i] > 0);
75       // += histogram[i] * -log(counts[i]/total_counts)
76       sum += histogram[i] *
77              std::max(0.0f, ANS_LOG_TAB_SIZE - FastLog2f(counts[i]));
78     }
79   }
80   if (total_histogram > 0) {
81     JXL_ASSERT(total_counts == ANS_TAB_SIZE);
82   }
83   return sum;
84 }
85 
EstimateDataBitsFlat(const ANSHistBin * histogram,size_t len)86 float EstimateDataBitsFlat(const ANSHistBin* histogram, size_t len) {
87   const float flat_bits = std::max(FastLog2f(len), 0.0f);
88   int total_histogram = 0;
89   for (size_t i = 0; i < len; ++i) {
90     total_histogram += histogram[i];
91   }
92   return total_histogram * flat_bits;
93 }
94 
95 // Static Huffman code for encoding logcounts. The last symbol is used as RLE
96 // sequence.
97 static const uint8_t kLogCountBitLengths[ANS_LOG_TAB_SIZE + 2] = {
98     5, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 6, 7, 7,
99 };
100 static const uint8_t kLogCountSymbols[ANS_LOG_TAB_SIZE + 2] = {
101     17, 11, 15, 3, 9, 7, 4, 2, 5, 6, 0, 33, 1, 65,
102 };
103 
104 // Returns the difference between largest count that can be represented and is
105 // smaller than "count" and smallest representable count larger than "count".
SmallestIncrement(uint32_t count,uint32_t shift)106 static int SmallestIncrement(uint32_t count, uint32_t shift) {
107   int bits = count == 0 ? -1 : FloorLog2Nonzero(count);
108   int drop_bits = bits - GetPopulationCountPrecision(bits, shift);
109   return drop_bits < 0 ? 1 : (1 << drop_bits);
110 }
111 
// Converts the real-valued `targets` for symbols [0, max_symbol) into integer
// `counts` that sum to exactly `table_size`.
//
// - Targets in (0, 1) get a count of 1.
// - Targets >= 1 are first scaled by `discount_ratio`, which shrinks them so
//   the forced 1s above still fit in the table. They are then truncated,
//   clamped to [1, table_size - 1], and rounded to a multiple of
//   SmallestIncrement(count, shift), so each count is storable at the
//   precision implied by `shift`.
//   - With minimize_error_of_sum == false, each count is rounded up when its
//     own target exceeds count + inc / 2.
//   - With minimize_error_of_sum == true, the decision instead compares
//     against the running difference between the summed targets and the
//     summed counts, to keep the total close to the target total.
// - Entries whose target is <= 0 are not written.
//
// The symbol with the largest floor(log2(count)) absorbs the difference
// between the rounded sum and `table_size`. Among symbols with target >= 1,
// the first one reaching that log wins. Its index is stored in *omit_pos.
//
// Returns false if that adjusted count ended up non-positive.
template <bool minimize_error_of_sum>
bool RebalanceHistogram(const float* targets, int max_symbol, int table_size,
                        uint32_t shift, int* omit_pos, ANSHistBin* counts) {
  int sum = 0;
  float sum_nonrounded = 0.0;
  int remainder_pos = 0;  // if all of them are handled in first loop
  int remainder_log = -1;
  // Pass 1: symbols with a fractional target get the minimum count of 1.
  for (int n = 0; n < max_symbol; ++n) {
    if (targets[n] > 0 && targets[n] < 1.0f) {
      counts[n] = 1;
      sum_nonrounded += targets[n];
      sum += counts[n];
    }
  }
  // Scale factor for the remaining targets, so that they share the table
  // space not taken by the forced 1s.
  const float discount_ratio =
      (table_size - sum) / (table_size - sum_nonrounded);
  JXL_ASSERT(discount_ratio > 0);
  JXL_ASSERT(discount_ratio <= 1.0f);
  // Invariant for minimize_error_of_sum == true:
  // abs(sum - sum_nonrounded)
  //   <= SmallestIncrement(max(targets[])) + max_symbol
  // Pass 2: symbols with a target >= 1.
  for (int n = 0; n < max_symbol; ++n) {
    if (targets[n] >= 1.0f) {
      sum_nonrounded += targets[n];
      counts[n] =
          static_cast<ANSHistBin>(targets[n] * discount_ratio);  // truncate
      if (counts[n] == 0) counts[n] = 1;
      if (counts[n] == table_size) counts[n] = table_size - 1;
      // Round the count to the closest nonzero multiple of SmallestIncrement
      // (when minimize_error_of_sum is false) or one of two closest so as to
      // keep the sum as close as possible to sum_nonrounded.
      int inc = SmallestIncrement(counts[n], shift);
      // Round down to a multiple of `inc`. This clears its low bits, which
      // relies on `inc` being a power of two (SmallestIncrement returns 1 or
      // 1 << k).
      counts[n] -= counts[n] & (inc - 1);
      // TODO(robryk): Should we rescale targets[n]?
      const float target =
          minimize_error_of_sum ? (sum_nonrounded - sum) : targets[n];
      if (counts[n] == 0 ||
          (target > counts[n] + inc / 2 && counts[n] + inc < table_size)) {
        counts[n] += inc;
      }
      sum += counts[n];
      // Track the first symbol with the largest log2(count); it will absorb
      // the rounding error below.
      const int count_log = FloorLog2Nonzero(static_cast<uint32_t>(counts[n]));
      if (count_log > remainder_log) {
        remainder_pos = n;
        remainder_log = count_log;
      }
    }
  }
  // NOTE(review): remainder_pos starts at 0 and is only ever assigned
  // non-negative indices, so this assert can never fire. If the intent was to
  // detect "no symbol with target >= 1", it would need to test
  // remainder_log != -1 instead — confirm before changing.
  JXL_ASSERT(remainder_pos != -1);
  // NOTE: This is the only place where counts could go negative. We could
  // detect that, return false and make ANSHistBin uint32_t.
  counts[remainder_pos] -= sum - table_size;
  *omit_pos = remainder_pos;
  return counts[remainder_pos] > 0;
}
167 
NormalizeCounts(ANSHistBin * counts,int * omit_pos,const int length,const int precision_bits,uint32_t shift,int * num_symbols,int * symbols)168 Status NormalizeCounts(ANSHistBin* counts, int* omit_pos, const int length,
169                        const int precision_bits, uint32_t shift,
170                        int* num_symbols, int* symbols) {
171   const int32_t table_size = 1 << precision_bits;  // target sum / table size
172   uint64_t total = 0;
173   int max_symbol = 0;
174   int symbol_count = 0;
175   for (int n = 0; n < length; ++n) {
176     total += counts[n];
177     if (counts[n] > 0) {
178       if (symbol_count < kMaxNumSymbolsForSmallCode) {
179         symbols[symbol_count] = n;
180       }
181       ++symbol_count;
182       max_symbol = n + 1;
183     }
184   }
185   *num_symbols = symbol_count;
186   if (symbol_count == 0) {
187     return true;
188   }
189   if (symbol_count == 1) {
190     counts[symbols[0]] = table_size;
191     return true;
192   }
193   if (symbol_count > table_size)
194     return JXL_FAILURE("Too many entries in an ANS histogram");
195 
196   const float norm = 1.f * table_size / total;
197   std::vector<float> targets(max_symbol);
198   for (size_t n = 0; n < targets.size(); ++n) {
199     targets[n] = norm * counts[n];
200   }
201   if (!RebalanceHistogram<false>(&targets[0], max_symbol, table_size, shift,
202                                  omit_pos, counts)) {
203     // Use an alternative rebalancing mechanism if the one above failed
204     // to create a histogram that is positive wherever the original one was.
205     if (!RebalanceHistogram<true>(&targets[0], max_symbol, table_size, shift,
206                                   omit_pos, counts)) {
207       return JXL_FAILURE("Logic error: couldn't rebalance a histogram");
208     }
209   }
210   return true;
211 }
212 
// Bit-counting stand-in for a BitWriter.
// `size` accumulates the number of bits that Write() would have emitted; the
// written values themselves are discarded.
struct SizeWriter {
  size_t size = 0;
  void Write(size_t num, size_t /*bits*/) { size += num; }
};
217 
218 template <typename Writer>
StoreVarLenUint8(size_t n,Writer * writer)219 void StoreVarLenUint8(size_t n, Writer* writer) {
220   JXL_DASSERT(n <= 255);
221   if (n == 0) {
222     writer->Write(1, 0);
223   } else {
224     writer->Write(1, 1);
225     size_t nbits = FloorLog2Nonzero(n);
226     writer->Write(3, nbits);
227     writer->Write(nbits, n - (1ULL << nbits));
228   }
229 }
230 
231 template <typename Writer>
StoreVarLenUint16(size_t n,Writer * writer)232 void StoreVarLenUint16(size_t n, Writer* writer) {
233   JXL_DASSERT(n <= 65535);
234   if (n == 0) {
235     writer->Write(1, 0);
236   } else {
237     writer->Write(1, 1);
238     size_t nbits = FloorLog2Nonzero(n);
239     writer->Write(4, nbits);
240     writer->Write(nbits, n - (1ULL << nbits));
241   }
242 }
243 
// Writes the description of the normalized histogram `counts` (over
// `alphabet_size` symbols) to `writer`. The writer is either a BitWriter or a
// SizeWriter that only counts bits.
//
// Layout, as written below:
// - num_symbols <= 2 ("small tree"):
//   - a 1 bit;
//   - 1 bit holding num_symbols - 1 (0 when there are no symbols);
//   - each index in `symbols` as a VarLenUint8 (a single 0 when
//     num_symbols == 0, i.e. the same bits as one symbol with index 0);
//   - for two symbols, counts[symbols[0]] in ANS_LOG_TAB_SIZE bits.
// - otherwise:
//   - two 0 bits (not small, not flat);
//   - `shift` in an Elias gamma-like code;
//   - length - 3 as a VarLenUint8, where `length` is one past the last
//     nonzero count or omit_pos;
//   - the per-symbol log-counts, using kLogCountBitLengths /
//     kLogCountSymbols with run-length coding of repeated counts;
//   - the low-order bits of every count with log-count > 1, except the one
//     at omit_pos.
//
// Returns false only when length - 3 exceeds 255. In that case 255 is written
// in its place, so the output does not describe `counts` correctly.
template <typename Writer>
bool EncodeCounts(const ANSHistBin* counts, const int alphabet_size,
                  const int omit_pos, const int num_symbols, uint32_t shift,
                  const int* symbols, Writer* writer) {
  bool ok = true;
  if (num_symbols <= 2) {
    // Small tree marker to encode 1-2 symbols.
    writer->Write(1, 1);
    if (num_symbols == 0) {
      writer->Write(1, 0);
      StoreVarLenUint8(0, writer);
    } else {
      writer->Write(1, num_symbols - 1);
      for (int i = 0; i < num_symbols; ++i) {
        StoreVarLenUint8(symbols[i], writer);
      }
    }
    if (num_symbols == 2) {
      writer->Write(ANS_LOG_TAB_SIZE, counts[symbols[0]]);
    }
  } else {
    // Mark non-small tree.
    writer->Write(1, 0);
    // Mark non-flat histogram.
    writer->Write(1, 0);

    // Precompute sequences for RLE encoding. Contains the number of identical
    // values starting at a given index. Only contains the value at the first
    // element of the series.
    std::vector<uint32_t> same(alphabet_size, 0);
    int last = 0;
    for (int i = 1; i < alphabet_size; i++) {
      // Store the sequence length once different symbol reached, or we're at
      // the end, or the length is longer than we can encode, or we are at
      // the omit_pos. We don't support including the omit_pos in an RLE
      // sequence because this value may use a different amount of log2 bits
      // than standard, it is too complex to handle in the decoder.
      if (counts[i] != counts[last] || i + 1 == alphabet_size ||
          (i - last) >= 255 || i == omit_pos || i == omit_pos + 1) {
        same[last] = (i - last);
        last = i + 1;
      }
    }

    // logcounts[i] = floor(log2(counts[i])) + 1 for nonzero counts, 0 for
    // zero counts. `length` ends up as one past the last nonzero count or
    // omit_pos, whichever is larger.
    int length = 0;
    std::vector<int> logcounts(alphabet_size);
    int omit_log = 0;
    for (int i = 0; i < alphabet_size; ++i) {
      JXL_ASSERT(counts[i] <= ANS_TAB_SIZE);
      JXL_ASSERT(counts[i] >= 0);
      if (i == omit_pos) {
        length = i + 1;
      } else if (counts[i] > 0) {
        logcounts[i] = FloorLog2Nonzero(static_cast<uint32_t>(counts[i])) + 1;
        length = i + 1;
        // The log-count stored for omit_pos is the maximum over the other
        // symbols, where symbols before omit_pos contribute logcount + 1.
        // NOTE(review): presumably the decoder relies on this to locate the
        // omitted symbol — confirm against dec_ans.
        if (i < omit_pos) {
          omit_log = std::max(omit_log, logcounts[i] + 1);
        } else {
          omit_log = std::max(omit_log, logcounts[i]);
        }
      }
    }
    logcounts[omit_pos] = omit_log;

    // Elias gamma-like code for shift. Only difference is that if the number
    // of bits to be encoded is equal to FloorLog2(ANS_LOG_TAB_SIZE+1), we skip
    // the terminating 0 in unary coding.
    // Written as: `log` one-bits, an optional 0, then the low `log` bits of
    // shift + 1.
    int upper_bound_log = FloorLog2Nonzero(ANS_LOG_TAB_SIZE + 1);
    int log = FloorLog2Nonzero(shift + 1);
    writer->Write(log, (1 << log) - 1);
    if (log != upper_bound_log) writer->Write(1, 0);
    writer->Write(log, ((1 << log) - 1) & (shift + 1));

    // Since num_symbols >= 3, we know that length >= 3, therefore we encode
    // length - 3.
    if (length - 3 > 255) {
      // Pretend that everything is OK, but complain about correctness later.
      StoreVarLenUint8(255, writer);
      ok = false;
    } else {
      StoreVarLenUint8(length - 3, writer);
    }

    // The logcount values are encoded with a static Huffman code.
    // A run of `same[i - 1]` equal counts starting at i - 1, longer than
    // kMinReps, is written as:
    //   - the log-count of element i - 1 (written normally on the previous
    //     iteration);
    //   - the RLE symbol;
    //   - VarLenUint8(same[i - 1] - kMinReps - 1).
    // The RLE symbol and its argument cover the remaining same[i - 1] - 1
    // elements. The `i += same[i - 1] - 2` plus the loop's ++i skips them.
    static const size_t kMinReps = 4;
    size_t rep = ANS_LOG_TAB_SIZE + 1;
    for (int i = 0; i < length; ++i) {
      if (i > 0 && same[i - 1] > kMinReps) {
        // Encode the RLE symbol and skip the repeated ones.
        writer->Write(kLogCountBitLengths[rep], kLogCountSymbols[rep]);
        StoreVarLenUint8(same[i - 1] - kMinReps - 1, writer);
        i += same[i - 1] - 2;
        continue;
      }
      writer->Write(kLogCountBitLengths[logcounts[i]],
                    kLogCountSymbols[logcounts[i]]);
    }
    // Second pass, with the same RLE skipping.
    // For each remaining count >= 2 other than omit_pos, write the `bitcount`
    // bits following its leading 1. Before that, the low `drop_bits` bits are
    // dropped; they must be zero at this precision.
    for (int i = 0; i < length; ++i) {
      if (i > 0 && same[i - 1] > kMinReps) {
        // Skip symbols encoded by RLE.
        i += same[i - 1] - 2;
        continue;
      }
      if (logcounts[i] > 1 && i != omit_pos) {
        int bitcount = GetPopulationCountPrecision(logcounts[i] - 1, shift);
        int drop_bits = logcounts[i] - 1 - bitcount;
        JXL_CHECK((counts[i] & ((1 << drop_bits) - 1)) == 0);
        writer->Write(bitcount, (counts[i] >> drop_bits) - (1 << bitcount));
      }
    }
  }
  return ok;
}
357 
EncodeFlatHistogram(const int alphabet_size,BitWriter * writer)358 void EncodeFlatHistogram(const int alphabet_size, BitWriter* writer) {
359   // Mark non-small tree.
360   writer->Write(1, 0);
361   // Mark uniform histogram.
362   writer->Write(1, 1);
363   JXL_ASSERT(alphabet_size > 0);
364   // Encode alphabet size.
365   StoreVarLenUint8(alphabet_size - 1, writer);
366 }
367 
ComputeHistoAndDataCost(const ANSHistBin * histogram,size_t alphabet_size,uint32_t method)368 float ComputeHistoAndDataCost(const ANSHistBin* histogram, size_t alphabet_size,
369                               uint32_t method) {
370   if (method == 0) {  // Flat code
371     return ANS_LOG_TAB_SIZE + 2 +
372            EstimateDataBitsFlat(histogram, alphabet_size);
373   }
374   // Non-flat: shift = method-1.
375   uint32_t shift = method - 1;
376   std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
377   int omit_pos = 0;
378   int num_symbols;
379   int symbols[kMaxNumSymbolsForSmallCode] = {};
380   JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
381                             ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
382   SizeWriter writer;
383   // Ignore the correctness, no real encoding happens at this stage.
384   (void)EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols, shift,
385                      symbols, &writer);
386   return writer.size +
387          EstimateDataBits(histogram, counts.data(), alphabet_size);
388 }
389 
ComputeBestMethod(const ANSHistBin * histogram,size_t alphabet_size,float * cost,HistogramParams::ANSHistogramStrategy ans_histogram_strategy)390 uint32_t ComputeBestMethod(
391     const ANSHistBin* histogram, size_t alphabet_size, float* cost,
392     HistogramParams::ANSHistogramStrategy ans_histogram_strategy) {
393   size_t method = 0;
394   float fcost = ComputeHistoAndDataCost(histogram, alphabet_size, 0);
395   auto try_shift = [&](size_t shift) {
396     float c = ComputeHistoAndDataCost(histogram, alphabet_size, shift + 1);
397     if (c < fcost) {
398       method = shift + 1;
399       fcost = c;
400     }
401   };
402   switch (ans_histogram_strategy) {
403     case HistogramParams::ANSHistogramStrategy::kPrecise: {
404       for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift++) {
405         try_shift(shift);
406       }
407       break;
408     }
409     case HistogramParams::ANSHistogramStrategy::kApproximate: {
410       for (uint32_t shift = 0; shift <= ANS_LOG_TAB_SIZE; shift += 2) {
411         try_shift(shift);
412       }
413       break;
414     }
415     case HistogramParams::ANSHistogramStrategy::kFast: {
416       try_shift(0);
417       try_shift(ANS_LOG_TAB_SIZE / 2);
418       try_shift(ANS_LOG_TAB_SIZE);
419       break;
420     }
421   };
422   *cost = fcost;
423   return method;
424 }
425 
426 }  // namespace
427 
428 // Returns an estimate of the cost of encoding this histogram and the
429 // corresponding data.
BuildAndStoreANSEncodingData(HistogramParams::ANSHistogramStrategy ans_histogram_strategy,const ANSHistBin * histogram,size_t alphabet_size,size_t log_alpha_size,bool use_prefix_code,ANSEncSymbolInfo * info,BitWriter * writer)430 size_t BuildAndStoreANSEncodingData(
431     HistogramParams::ANSHistogramStrategy ans_histogram_strategy,
432     const ANSHistBin* histogram, size_t alphabet_size, size_t log_alpha_size,
433     bool use_prefix_code, ANSEncSymbolInfo* info, BitWriter* writer) {
434   if (use_prefix_code) {
435     if (alphabet_size <= 1) return 0;
436     std::vector<uint32_t> histo(alphabet_size);
437     for (size_t i = 0; i < alphabet_size; i++) {
438       histo[i] = histogram[i];
439       JXL_CHECK(histogram[i] >= 0);
440     }
441     size_t cost = 0;
442     {
443       std::vector<uint8_t> depths(alphabet_size);
444       std::vector<uint16_t> bits(alphabet_size);
445       BitWriter tmp_writer;
446       BitWriter* w = writer ? writer : &tmp_writer;
447       size_t start = w->BitsWritten();
448       BitWriter::Allotment allotment(
449           w, 8 * alphabet_size + 8);  // safe upper bound
450       BuildAndStoreHuffmanTree(histo.data(), alphabet_size, depths.data(),
451                                bits.data(), w);
452       ReclaimAndCharge(w, &allotment, 0, /*aux_out=*/nullptr);
453 
454       for (size_t i = 0; i < alphabet_size; i++) {
455         info[i].bits = depths[i] == 0 ? 0 : bits[i];
456         info[i].depth = depths[i];
457       }
458       cost = w->BitsWritten() - start;
459     }
460     // Estimate data cost.
461     for (size_t i = 0; i < alphabet_size; i++) {
462       cost += histogram[i] * info[i].depth;
463     }
464     return cost;
465   }
466   JXL_ASSERT(alphabet_size <= ANS_TAB_SIZE);
467   // Ensure we ignore trailing zeros in the histogram.
468   if (alphabet_size != 0) {
469     size_t largest_symbol = 0;
470     for (size_t i = 0; i < alphabet_size; i++) {
471       if (histogram[i] != 0) largest_symbol = i;
472     }
473     alphabet_size = largest_symbol + 1;
474   }
475   float cost;
476   uint32_t method = ComputeBestMethod(histogram, alphabet_size, &cost,
477                                       ans_histogram_strategy);
478   JXL_ASSERT(cost >= 0);
479   int num_symbols;
480   int symbols[kMaxNumSymbolsForSmallCode] = {};
481   std::vector<ANSHistBin> counts(histogram, histogram + alphabet_size);
482   if (!counts.empty()) {
483     size_t sum = 0;
484     for (size_t i = 0; i < counts.size(); i++) {
485       sum += counts[i];
486     }
487     if (sum == 0) {
488       counts[0] = ANS_TAB_SIZE;
489     }
490   }
491   if (method == 0) {
492     counts = CreateFlatHistogram(alphabet_size, ANS_TAB_SIZE);
493     AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
494     InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
495     ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
496     if (writer != nullptr) {
497       EncodeFlatHistogram(alphabet_size, writer);
498     }
499     return cost;
500   }
501   int omit_pos = 0;
502   uint32_t shift = method - 1;
503   JXL_CHECK(NormalizeCounts(counts.data(), &omit_pos, alphabet_size,
504                             ANS_LOG_TAB_SIZE, shift, &num_symbols, symbols));
505   AliasTable::Entry a[ANS_MAX_ALPHABET_SIZE];
506   InitAliasTable(counts, ANS_TAB_SIZE, log_alpha_size, a);
507   ANSBuildInfoTable(counts.data(), a, alphabet_size, log_alpha_size, info);
508   if (writer != nullptr) {
509     bool ok = EncodeCounts(counts.data(), alphabet_size, omit_pos, num_symbols,
510                            shift, symbols, writer);
511     (void)ok;
512     JXL_DASSERT(ok);
513   }
514   return cost;
515 }
516 
ANSPopulationCost(const ANSHistBin * data,size_t alphabet_size)517 float ANSPopulationCost(const ANSHistBin* data, size_t alphabet_size) {
518   float c;
519   ComputeBestMethod(data, alphabet_size, &c,
520                     HistogramParams::ANSHistogramStrategy::kFast);
521   return c;
522 }
523 
524 template <typename Writer>
EncodeUintConfig(const HybridUintConfig uint_config,Writer * writer,size_t log_alpha_size)525 void EncodeUintConfig(const HybridUintConfig uint_config, Writer* writer,
526                       size_t log_alpha_size) {
527   writer->Write(CeilLog2Nonzero(log_alpha_size + 1),
528                 uint_config.split_exponent);
529   if (uint_config.split_exponent == log_alpha_size) {
530     return;  // msb/lsb don't matter.
531   }
532   size_t nbits = CeilLog2Nonzero(uint_config.split_exponent + 1);
533   writer->Write(nbits, uint_config.msb_in_token);
534   nbits = CeilLog2Nonzero(uint_config.split_exponent -
535                           uint_config.msb_in_token + 1);
536   writer->Write(nbits, uint_config.lsb_in_token);
537 }
538 template <typename Writer>
EncodeUintConfigs(const std::vector<HybridUintConfig> & uint_config,Writer * writer,size_t log_alpha_size)539 void EncodeUintConfigs(const std::vector<HybridUintConfig>& uint_config,
540                        Writer* writer, size_t log_alpha_size) {
541   // TODO(veluca): RLE?
542   for (size_t i = 0; i < uint_config.size(); i++) {
543     EncodeUintConfig(uint_config[i], writer, log_alpha_size);
544   }
545 }
546 template void EncodeUintConfigs(const std::vector<HybridUintConfig>&,
547                                 BitWriter*, size_t);
548 
549 namespace {
550 
// Chooses a HybridUintConfig for every clustered histogram according to
// params.uint_method. Reads codes->use_prefix_code and codes->lz77, and
// writes codes->uint_config.
//
// - kNone: only resizes codes->uint_config; nothing else is touched.
// - kContextMap: every histogram gets HybridUintConfig(2, 0, 1); the
//   histograms and *log_alpha_size are left unchanged.
// - kBest / kFast: brute-force search over a fixed candidate list.
//   - For each candidate, rebuild the histograms from `tokens` (skipping
//     LZ77 length tokens) and estimate each histogram's cost.
//   - Keep the cheapest valid candidate per histogram.
//   - Finally, rebuild `clustered_histograms` with the chosen configs. LZ77
//     length tokens use lz77.length_uint_config, offset by lz77.min_symbol.
//   - Set *log_alpha_size to the smallest value >= 4 with every token
//     < 2^(*log_alpha_size).
void ChooseUintConfigs(const HistogramParams& params,
                       const std::vector<std::vector<Token>>& tokens,
                       const std::vector<uint8_t>& context_map,
                       std::vector<Histogram>* clustered_histograms,
                       EntropyEncodingData* codes, size_t* log_alpha_size) {
  codes->uint_config.resize(clustered_histograms->size());

  if (params.uint_method == HistogramParams::HybridUintMethod::kNone) return;
  if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
    codes->uint_config.clear();
    codes->uint_config.resize(clustered_histograms->size(),
                              HybridUintConfig(2, 0, 1));
    return;
  }

  // Brute-force method that tries a few options.
  std::vector<HybridUintConfig> configs;
  if (params.uint_method == HistogramParams::HybridUintMethod::kBest) {
    configs = {
        HybridUintConfig(4, 2, 0),  // default
        HybridUintConfig(4, 1, 0),  // less precise
        HybridUintConfig(4, 2, 1),  // add sign
        HybridUintConfig(4, 2, 2),  // add sign+parity
        HybridUintConfig(4, 1, 2),  // add parity but less msb
        // Same as above, but more direct coding.
        HybridUintConfig(5, 2, 0), HybridUintConfig(5, 1, 0),
        HybridUintConfig(5, 2, 1), HybridUintConfig(5, 2, 2),
        HybridUintConfig(5, 1, 2),
        // Same as above, but less direct coding.
        HybridUintConfig(3, 2, 0), HybridUintConfig(3, 1, 0),
        HybridUintConfig(3, 2, 1), HybridUintConfig(3, 1, 2),
        // For near-lossless.
        HybridUintConfig(4, 1, 3), HybridUintConfig(5, 1, 4),
        HybridUintConfig(5, 2, 3), HybridUintConfig(6, 1, 5),
        HybridUintConfig(6, 2, 4), HybridUintConfig(6, 0, 0),
        // Other
        HybridUintConfig(0, 0, 0),   // varlenuint
        HybridUintConfig(2, 0, 1),   // works well for ctx map
        HybridUintConfig(7, 0, 0),   // direct coding
        HybridUintConfig(8, 0, 0),   // direct coding
        HybridUintConfig(9, 0, 0),   // direct coding
        HybridUintConfig(10, 0, 0),  // direct coding
        HybridUintConfig(11, 0, 0),  // direct coding
        HybridUintConfig(12, 0, 0),  // direct coding
    };
  } else if (params.uint_method == HistogramParams::HybridUintMethod::kFast) {
    configs = {
        HybridUintConfig(4, 2, 0),  // default
        HybridUintConfig(4, 1, 2),  // add parity but less msb
        HybridUintConfig(0, 0, 0),  // smallest histograms
        HybridUintConfig(2, 0, 1),  // works well for ctx map
    };
  }

  // Best cost found so far for each histogram; max() means "none valid yet".
  std::vector<float> costs(clustered_histograms->size(),
                           std::numeric_limits<float>::max());
  std::vector<uint32_t> extra_bits(clustered_histograms->size());
  std::vector<uint8_t> is_valid(clustered_histograms->size());
  size_t max_alpha =
      codes->use_prefix_code ? PREFIX_MAX_ALPHABET_SIZE : ANS_MAX_ALPHABET_SIZE;
  for (HybridUintConfig cfg : configs) {
    std::fill(is_valid.begin(), is_valid.end(), true);
    std::fill(extra_bits.begin(), extra_bits.end(), 0);

    for (size_t i = 0; i < clustered_histograms->size(); i++) {
      (*clustered_histograms)[i].Clear();
    }
    for (size_t i = 0; i < tokens.size(); ++i) {
      for (size_t j = 0; j < tokens[i].size(); ++j) {
        const Token token = tokens[i][j];
        // TODO(veluca): do not ignore lz77 commands.
        if (token.is_lz77_length) continue;
        size_t histo = context_map[token.context];
        uint32_t tok, nbits, bits;
        cfg.Encode(token.value, &tok, &nbits, &bits);
        // A token that does not fit the alphabet, or that collides with the
        // LZ77 length symbols, rules this config out for the histogram.
        if (tok >= max_alpha ||
            (codes->lz77.enabled && tok >= codes->lz77.min_symbol)) {
          is_valid[histo] = false;
          continue;
        }
        extra_bits[histo] += nbits;
        (*clustered_histograms)[histo].Add(tok);
      }
    }

    for (size_t i = 0; i < clustered_histograms->size(); i++) {
      if (!is_valid[i]) continue;
      float cost = (*clustered_histograms)[i].PopulationCost() + extra_bits[i];
      // add signaling cost of the hybriduintconfig itself
      // These match the widths EncodeUintConfig uses for msb_in_token and
      // lsb_in_token. The split_exponent field is omitted here; its width
      // depends only on log_alpha_size.
      cost += CeilLog2Nonzero(cfg.split_exponent + 1);
      cost += CeilLog2Nonzero(cfg.split_exponent - cfg.msb_in_token + 1);
      // Strict '<': on ties the earlier candidate in `configs` wins.
      if (cost < costs[i]) {
        codes->uint_config[i] = cfg;
        costs[i] = cost;
      }
    }
  }

  // Rebuild histograms.
  for (size_t i = 0; i < clustered_histograms->size(); i++) {
    (*clustered_histograms)[i].Clear();
  }
  *log_alpha_size = 4;
  for (size_t i = 0; i < tokens.size(); ++i) {
    for (size_t j = 0; j < tokens[i].size(); ++j) {
      const Token token = tokens[i][j];
      uint32_t tok, nbits, bits;
      size_t histo = context_map[token.context];
      (token.is_lz77_length ? codes->lz77.length_uint_config
                            : codes->uint_config[histo])
          .Encode(token.value, &tok, &nbits, &bits);
      tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
      (*clustered_histograms)[histo].Add(tok);
      // Grow log_alpha_size until this token fits.
      while (tok >= (1u << *log_alpha_size)) (*log_alpha_size)++;
    }
  }
#if JXL_ENABLE_ASSERT
  size_t max_log_alpha_size = codes->use_prefix_code ? PREFIX_MAX_BITS : 8;
  JXL_ASSERT(*log_alpha_size <= max_log_alpha_size);
#endif
}
672 
673 class HistogramBuilder {
674  public:
HistogramBuilder(const size_t num_contexts)675   explicit HistogramBuilder(const size_t num_contexts)
676       : histograms_(num_contexts) {}
677 
VisitSymbol(int symbol,size_t histo_idx)678   void VisitSymbol(int symbol, size_t histo_idx) {
679     JXL_DASSERT(histo_idx < histograms_.size());
680     histograms_[histo_idx].Add(symbol);
681   }
682 
683   // NOTE: `layer` is only for clustered_entropy; caller does ReclaimAndCharge.
BuildAndStoreEntropyCodes(const HistogramParams & params,const std::vector<std::vector<Token>> & tokens,EntropyEncodingData * codes,std::vector<uint8_t> * context_map,bool use_prefix_code,BitWriter * writer,size_t layer,AuxOut * aux_out) const684   size_t BuildAndStoreEntropyCodes(
685       const HistogramParams& params,
686       const std::vector<std::vector<Token>>& tokens, EntropyEncodingData* codes,
687       std::vector<uint8_t>* context_map, bool use_prefix_code,
688       BitWriter* writer, size_t layer, AuxOut* aux_out) const {
689     size_t cost = 0;
690     codes->encoding_info.clear();
691     std::vector<Histogram> clustered_histograms(histograms_);
692     context_map->resize(histograms_.size());
693     if (histograms_.size() > 1) {
694       if (!ans_fuzzer_friendly_) {
695         std::vector<uint32_t> histogram_symbols;
696         ClusterHistograms(params, histograms_, histograms_.size(),
697                           kClustersLimit, &clustered_histograms,
698                           &histogram_symbols);
699         for (size_t c = 0; c < histograms_.size(); ++c) {
700           (*context_map)[c] = static_cast<uint8_t>(histogram_symbols[c]);
701         }
702       } else {
703         fill(context_map->begin(), context_map->end(), 0);
704         size_t max_symbol = 0;
705         for (const Histogram& h : histograms_) {
706           max_symbol = std::max(h.data_.size(), max_symbol);
707         }
708         size_t num_symbols = 1 << CeilLog2Nonzero(max_symbol + 1);
709         clustered_histograms.resize(1);
710         clustered_histograms[0].Clear();
711         for (size_t i = 0; i < num_symbols; i++) {
712           clustered_histograms[0].Add(i);
713         }
714       }
715       if (writer != nullptr) {
716         EncodeContextMap(*context_map, clustered_histograms.size(), writer);
717       }
718     }
719     if (aux_out != nullptr) {
720       for (size_t i = 0; i < clustered_histograms.size(); ++i) {
721         aux_out->layers[layer].clustered_entropy +=
722             clustered_histograms[i].ShannonEntropy();
723       }
724     }
725     codes->use_prefix_code = use_prefix_code;
726     size_t log_alpha_size = codes->lz77.enabled ? 8 : 7;  // Sane default.
727     if (ans_fuzzer_friendly_) {
728       codes->uint_config.clear();
729       codes->uint_config.resize(1, HybridUintConfig(7, 0, 0));
730     } else {
731       ChooseUintConfigs(params, tokens, *context_map, &clustered_histograms,
732                         codes, &log_alpha_size);
733     }
734     if (log_alpha_size < 5) log_alpha_size = 5;
735     SizeWriter size_writer;  // Used if writer == nullptr to estimate costs.
736     cost += 1;
737     if (writer) writer->Write(1, use_prefix_code);
738 
739     if (use_prefix_code) {
740       log_alpha_size = PREFIX_MAX_BITS;
741     } else {
742       cost += 2;
743     }
744     if (writer == nullptr) {
745       EncodeUintConfigs(codes->uint_config, &size_writer, log_alpha_size);
746     } else {
747       if (!use_prefix_code) writer->Write(2, log_alpha_size - 5);
748       EncodeUintConfigs(codes->uint_config, writer, log_alpha_size);
749     }
750     if (use_prefix_code) {
751       for (size_t c = 0; c < clustered_histograms.size(); ++c) {
752         size_t num_symbol = 1;
753         for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
754           if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
755         }
756         if (writer) {
757           StoreVarLenUint16(num_symbol - 1, writer);
758         } else {
759           StoreVarLenUint16(num_symbol - 1, &size_writer);
760         }
761       }
762     }
763     cost += size_writer.size;
764     for (size_t c = 0; c < clustered_histograms.size(); ++c) {
765       size_t num_symbol = 1;
766       for (size_t i = 0; i < clustered_histograms[c].data_.size(); i++) {
767         if (clustered_histograms[c].data_[i]) num_symbol = i + 1;
768       }
769       codes->encoding_info.emplace_back();
770       codes->encoding_info.back().resize(std::max<size_t>(1, num_symbol));
771 
772       BitWriter::Allotment allotment(writer, 256 + num_symbol * 24);
773       cost += BuildAndStoreANSEncodingData(
774           params.ans_histogram_strategy, clustered_histograms[c].data_.data(),
775           num_symbol, log_alpha_size, use_prefix_code,
776           codes->encoding_info.back().data(), writer);
777       allotment.FinishedHistogram(writer);
778       ReclaimAndCharge(writer, &allotment, layer, aux_out);
779     }
780     return cost;
781   }
782 
Histo(size_t i) const783   const Histogram& Histo(size_t i) const { return histograms_[i]; }
784 
785  private:
786   std::vector<Histogram> histograms_;
787 };
788 
789 class SymbolCostEstimator {
790  public:
SymbolCostEstimator(size_t num_contexts,bool force_huffman,const std::vector<std::vector<Token>> & tokens,const LZ77Params & lz77)791   SymbolCostEstimator(size_t num_contexts, bool force_huffman,
792                       const std::vector<std::vector<Token>>& tokens,
793                       const LZ77Params& lz77) {
794     HistogramBuilder builder(num_contexts);
795     // Build histograms for estimating lz77 savings.
796     HybridUintConfig uint_config;
797     for (size_t i = 0; i < tokens.size(); ++i) {
798       for (size_t j = 0; j < tokens[i].size(); ++j) {
799         const Token token = tokens[i][j];
800         uint32_t tok, nbits, bits;
801         (token.is_lz77_length ? lz77.length_uint_config : uint_config)
802             .Encode(token.value, &tok, &nbits, &bits);
803         tok += token.is_lz77_length ? lz77.min_symbol : 0;
804         builder.VisitSymbol(tok, token.context);
805       }
806     }
807     max_alphabet_size_ = 0;
808     for (size_t i = 0; i < num_contexts; i++) {
809       max_alphabet_size_ =
810           std::max(max_alphabet_size_, builder.Histo(i).data_.size());
811     }
812     bits_.resize(num_contexts * max_alphabet_size_);
813     // TODO(veluca): SIMD?
814     add_symbol_cost_.resize(num_contexts);
815     for (size_t i = 0; i < num_contexts; i++) {
816       float inv_total = 1.0f / (builder.Histo(i).total_count_ + 1e-8f);
817       float total_cost = 0;
818       for (size_t j = 0; j < builder.Histo(i).data_.size(); j++) {
819         size_t cnt = builder.Histo(i).data_[j];
820         float cost = 0;
821         if (cnt != 0 && cnt != builder.Histo(i).total_count_) {
822           cost = -FastLog2f(cnt * inv_total);
823           if (force_huffman) cost = std::ceil(cost);
824         } else if (cnt == 0) {
825           cost = ANS_LOG_TAB_SIZE;  // Highest possible cost.
826         }
827         bits_[i * max_alphabet_size_ + j] = cost;
828         total_cost += cost * builder.Histo(i).data_[j];
829       }
830       // Penalty for adding a lz77 symbol to this contest (only used for static
831       // cost model). Higher penalty for contexts that have a very low
832       // per-symbol entropy.
833       add_symbol_cost_[i] = std::max(0.0f, 6.0f - total_cost * inv_total);
834     }
835   }
Bits(size_t ctx,size_t sym) const836   float Bits(size_t ctx, size_t sym) const {
837     return bits_[ctx * max_alphabet_size_ + sym];
838   }
LenCost(size_t ctx,size_t len,const LZ77Params & lz77) const839   float LenCost(size_t ctx, size_t len, const LZ77Params& lz77) const {
840     uint32_t nbits, bits, tok;
841     lz77.length_uint_config.Encode(len, &tok, &nbits, &bits);
842     tok += lz77.min_symbol;
843     return nbits + Bits(ctx, tok);
844   }
DistCost(size_t len,const LZ77Params & lz77) const845   float DistCost(size_t len, const LZ77Params& lz77) const {
846     uint32_t nbits, bits, tok;
847     HybridUintConfig().Encode(len, &tok, &nbits, &bits);
848     return nbits + Bits(lz77.nonserialized_distance_context, tok);
849   }
AddSymbolCost(size_t idx) const850   float AddSymbolCost(size_t idx) const { return add_symbol_cost_[idx]; }
851 
852  private:
853   size_t max_alphabet_size_;
854   std::vector<float> bits_;
855   std::vector<float> add_symbol_cost_;
856 };
857 
ApplyLZ77_RLE(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)858 void ApplyLZ77_RLE(const HistogramParams& params, size_t num_contexts,
859                    const std::vector<std::vector<Token>>& tokens,
860                    LZ77Params& lz77,
861                    std::vector<std::vector<Token>>& tokens_lz77) {
862   // TODO(veluca): tune heuristics here.
863   SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
864   float bit_decrease = 0;
865   size_t total_symbols = 0;
866   tokens_lz77.resize(tokens.size());
867   std::vector<float> sym_cost;
868   HybridUintConfig uint_config;
869   for (size_t stream = 0; stream < tokens.size(); stream++) {
870     size_t distance_multiplier =
871         params.image_widths.size() > stream ? params.image_widths[stream] : 0;
872     const auto& in = tokens[stream];
873     auto& out = tokens_lz77[stream];
874     total_symbols += in.size();
875     // Cumulative sum of bit costs.
876     sym_cost.resize(in.size() + 1);
877     for (size_t i = 0; i < in.size(); i++) {
878       uint32_t tok, nbits, unused_bits;
879       uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
880       sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
881     }
882     out.reserve(in.size());
883     for (size_t i = 0; i < in.size(); i++) {
884       size_t num_to_copy = 0;
885       size_t distance_symbol = 0;  // 1 for RLE.
886       if (distance_multiplier != 0) {
887         distance_symbol = 1;  // Special distance 1 if enabled.
888         JXL_DASSERT(kSpecialDistances[1][0] == 1);
889         JXL_DASSERT(kSpecialDistances[1][1] == 0);
890       }
891       if (i > 0) {
892         for (; i + num_to_copy < in.size(); num_to_copy++) {
893           if (in[i + num_to_copy].value != in[i - 1].value) {
894             break;
895           }
896         }
897       }
898       if (num_to_copy == 0) {
899         out.push_back(in[i]);
900         continue;
901       }
902       float cost = sym_cost[i + num_to_copy] - sym_cost[i];
903       // This subtraction might overflow, but that's OK.
904       size_t lz77_len = num_to_copy - lz77.min_length;
905       float lz77_cost = num_to_copy >= lz77.min_length
906                             ? CeilLog2Nonzero(lz77_len + 1) + 1
907                             : 0;
908       if (num_to_copy < lz77.min_length || cost <= lz77_cost) {
909         for (size_t j = 0; j < num_to_copy; j++) {
910           out.push_back(in[i + j]);
911         }
912         i += num_to_copy - 1;
913         continue;
914       }
915       // Output the LZ77 length
916       out.emplace_back(in[i].context, lz77_len);
917       out.back().is_lz77_length = true;
918       i += num_to_copy - 1;
919       bit_decrease += cost - lz77_cost;
920       // Output the LZ77 copy distance.
921       out.emplace_back(lz77.nonserialized_distance_context, distance_symbol);
922     }
923   }
924 
925   if (bit_decrease > total_symbols * 0.2 + 16) {
926     lz77.enabled = true;
927   }
928 }
929 
// Hash chain for LZ77 matching.
//
// Indexes positions of a sequence of token values by a hash of the next three
// values, so that earlier positions that may start a match can be enumerated
// newest-first. Only the last `window_size` positions are tracked: every
// per-position array is indexed by `pos & window_mask_`, which assumes
// `window_size` is a power of two (window_mask_ = window_size - 1). A second
// chain keyed by the length of the zero run starting at each position speeds
// up matching inside long runs of zeros.
struct HashChain {
  // Number of values in data_.
  size_t size_;
  // Copy of the token values (Token::value) being matched.
  std::vector<uint32_t> data_;

  // Number of hash buckets; hash_mask_ relies on this being a power of two.
  unsigned hash_num_values_ = 32768;
  unsigned hash_mask_ = hash_num_values_ - 1;
  // Shift applied to each successive value when combining them in GetHash().
  unsigned hash_shift_ = 5;

  // head[h]: most recently inserted window position with hash h, or -1.
  std::vector<int> head;
  // chain[w]: an earlier window position with the same hash as w; equal to w
  // when there is none.
  std::vector<uint32_t> chain;
  // val[w]: hash stored for window position w (-1 if never written); used to
  // detect chain links pointing at a position that has since been reused.
  std::vector<int> val;

  // Speed up repetitions of zero
  // Same scheme as head/chain/val, but keyed by zeros[w], the zero-run length
  // computed for window position w. numzeros is the zero-run length computed
  // by the most recent Update() call.
  std::vector<int> headz;
  std::vector<uint32_t> chainz;
  std::vector<uint32_t> zeros;
  uint32_t numzeros = 0;

  size_t window_size_;
  size_t window_mask_;
  // Matches shorter than min_length_ are not reported; lengths are capped at
  // max_length_.
  size_t min_length_;
  size_t max_length_;

  // Map of special distance codes.
  // Key: distance in positions; value: index into kSpecialDistances.
  std::unordered_map<int, int> special_dist_table_;
  // kNumSpecialDistances when special distances are in use, else 0. Ordinary
  // distances are reported as num_special_distances_ + dist - 1.
  size_t num_special_distances_ = 0;

  // Maximum number of chain links visited per FindMatches() call.
  uint32_t maxchainlength = 256;  // window_size_ to allow all

  // data/size: the tokens to index (only Token::value is used).
  // window_size: number of positions tracked; assumed to be a power of two.
  // distance_multiplier: when nonzero, enables special distance codes: entry i
  // of kSpecialDistances, (x, y), maps distance y * distance_multiplier + x
  // (clamped to at least 1) to code i.
  HashChain(const Token* data, size_t size, size_t window_size,
            size_t min_length, size_t max_length, size_t distance_multiplier)
      : size_(size),
        window_size_(window_size),
        window_mask_(window_size - 1),
        min_length_(min_length),
        max_length_(max_length) {
    data_.resize(size);
    for (size_t i = 0; i < size; i++) {
      data_[i] = data[i].value;
    }

    head.resize(hash_num_values_, -1);
    val.resize(window_size_, -1);
    chain.resize(window_size_);
    for (uint32_t i = 0; i < window_size_; ++i) {
      chain[i] = i;  // same value as index indicates uninitialized
    }

    zeros.resize(window_size_);
    // Zero-run lengths range from 0 to window_size_, hence the + 1.
    headz.resize(window_size_ + 1, -1);
    chainz.resize(window_size_);
    for (uint32_t i = 0; i < window_size_; ++i) {
      chainz[i] = i;
    }
    // Translate distance to special distance code.
    if (distance_multiplier) {
      // Count down, so if due to small distance multiplier multiple distances
      // map to the same code, the smallest code will be used in the end.
      for (int i = kNumSpecialDistances - 1; i >= 0; --i) {
        int xi = kSpecialDistances[i][0];
        int yi = kSpecialDistances[i][1];
        int distance = yi * distance_multiplier + xi;
        // Ensure that we map distance 1 to the lowest symbols.
        if (distance < 1) distance = 1;
        special_dist_table_[distance] = i;
      }
      num_special_distances_ = kNumSpecialDistances;
    }
  }

  // Hash of data_[pos], data_[pos + 1] and data_[pos + 2], masked to
  // hash_num_values_ buckets. Returns 0 when fewer than three values remain.
  uint32_t GetHash(size_t pos) const {
    uint32_t result = 0;
    if (pos + 2 < size_) {
      // TODO(lode): take the MSB's of the uint32_t values into account as well,
      // given that the hash code itself is less than 32 bits.
      result ^= (uint32_t)(data_[pos + 0] << 0u);
      result ^= (uint32_t)(data_[pos + 1] << hash_shift_);
      result ^= (uint32_t)(data_[pos + 2] << (hash_shift_ * 2));
    } else {
      // No need to compute hash of the last 2 values, the length 2 is too
      // short.
      return 0;
    }
    return result & hash_mask_;
  }

  // Number of consecutive zero values starting at `pos`, capped at the window
  // size and at the end of data_. When `prevzeros` (the count for pos - 1) is
  // nonzero, it is derived incrementally: it stays the same if the run already
  // fills the whole window and still continues at its end, and otherwise
  // drops by one.
  uint32_t CountZeros(size_t pos, uint32_t prevzeros) const {
    size_t end = pos + window_size_;
    if (end > size_) end = size_;
    if (prevzeros > 0) {
      if (prevzeros >= window_mask_ && data_[end - 1] == 0 &&
          end == pos + window_size_) {
        return prevzeros;
      } else {
        return prevzeros - 1;
      }
    }
    uint32_t num = 0;
    while (pos + num < end && data_[pos + num] == 0) num++;
    return num;
  }

  // Inserts position `pos` into both the hash chain and the zero-run chain,
  // overwriting whatever occupied window slot pos & window_mask_.
  void Update(size_t pos) {
    uint32_t hashval = GetHash(pos);
    uint32_t wpos = pos & window_mask_;

    val[wpos] = (int)hashval;
    if (head[hashval] != -1) chain[wpos] = head[hashval];
    head[hashval] = wpos;

    // A value change ends any zero run, so recount from scratch.
    if (pos > 0 && data_[pos] != data_[pos - 1]) numzeros = 0;
    numzeros = CountZeros(pos, numzeros);

    zeros[wpos] = numzeros;
    if (headz[numzeros] != -1) chainz[wpos] = headz[numzeros];
    headz[numzeros] = wpos;
  }

  // Inserts positions pos .. pos + len - 1, in order.
  void Update(size_t pos, size_t len) {
    for (size_t i = 0; i < len; i++) {
      Update(pos + i);
    }
  }

  // Walks earlier positions that share pos's hash (or, inside zero runs, its
  // zero-run length), newest first, and calls found_match(len, dist_symbol)
  // for every candidate whose match length is at least min_length_ and at
  // least the best length so far minus 2. Lengths are capped at max_length_
  // and at the end of data_. dist_symbol is the special distance code when
  // the distance is in special_dist_table_, otherwise
  // num_special_distances_ + dist - 1.
  // Relies on numzeros, i.e. on Update(pos) being the most recent update;
  // the callers in this file call Update(pos) right before searching at pos.
  // NOTE(review): max_dist is not used; the search is bounded only by the
  // window size and maxchainlength.
  template <typename CB>
  void FindMatches(size_t pos, int max_dist, const CB& found_match) const {
    uint32_t wpos = pos & window_mask_;
    uint32_t hashval = GetHash(pos);
    uint32_t hashpos = chain[wpos];

    int prev_dist = 0;
    int end = std::min<int>(pos + max_length_, size_);
    uint32_t chainlength = 0;
    uint32_t best_len = 0;
    for (;;) {
      // Distance back to the candidate, modulo the window size.
      int dist = (hashpos <= wpos) ? (wpos - hashpos)
                                   : (wpos - hashpos + window_mask_ + 1);
      // Distances must grow along the chain; a smaller one means the chain
      // wrapped around the window.
      if (dist < prev_dist) break;
      prev_dist = dist;
      uint32_t len = 0;
      if (dist > 0) {
        int i = pos;
        int j = pos - dist;
        if (numzeros > 3) {
          // Inside a zero run: skip the leading values both sides are
          // expected to share, leaving at least one comparison in range.
          int r = std::min<int>(numzeros - 1, zeros[hashpos]);
          if (i + r >= end) r = end - i - 1;
          i += r;
          j += r;
        }
        while (i < end && data_[i] == data_[j]) {
          i++;
          j++;
        }
        len = i - pos;
        // This can trigger even if the new length is slightly smaller than the
        // best length, because it is possible for a slightly cheaper distance
        // symbol to occur.
        if (len >= min_length_ && len + 2 >= best_len) {
          auto it = special_dist_table_.find(dist);
          int dist_symbol = (it == special_dist_table_.end())
                                ? (num_special_distances_ + dist - 1)
                                : it->second;
          found_match(len, dist_symbol);
          if (len > best_len) best_len = len;
        }
      }

      chainlength++;
      if (chainlength >= maxchainlength) break;

      if (numzeros >= 3 && len > numzeros) {
        // Follow the zero-run chain; stop at a self-link or a candidate whose
        // zero run differs.
        if (hashpos == chainz[hashpos]) break;
        hashpos = chainz[hashpos];
        if (zeros[hashpos] != numzeros) break;
      } else {
        if (hashpos == chain[hashpos]) break;
        hashpos = chain[hashpos];
        if (val[hashpos] != (int)hashval) break;  // outdated hash value
      }
    }
  }

  // Best match at `pos`: the longest one, ties broken by the smaller distance
  // symbol. Leaves *result_len = 1 and *result_dist_symbol = 0 if nothing
  // longer than 1 was reported.
  void FindMatch(size_t pos, int max_dist, size_t* result_dist_symbol,
                 size_t* result_len) const {
    *result_dist_symbol = 0;
    *result_len = 1;
    FindMatches(pos, max_dist, [&](size_t len, size_t dist_symbol) {
      if (len > *result_len ||
          (len == *result_len && *result_dist_symbol > dist_symbol)) {
        *result_len = len;
        *result_dist_symbol = dist_symbol;
      }
    });
  }
};
1124 
LenCost(size_t len)1125 float LenCost(size_t len) {
1126   uint32_t nbits, bits, tok;
1127   HybridUintConfig(1, 0, 0).Encode(len, &tok, &nbits, &bits);
1128   constexpr float kCostTable[] = {
1129       2.797667318563126,  3.213177690381199,  2.5706009246743737,
1130       2.408392498667534,  2.829649191872326,  3.3923087753324577,
1131       4.029267451554331,  4.415576699706408,  4.509357574741465,
1132       9.21481543803004,   10.020590190114898, 11.858671627804766,
1133       12.45853300490526,  11.713105831990857, 12.561996324849314,
1134       13.775477692278367, 13.174027068768641,
1135   };
1136   size_t table_size = sizeof kCostTable / sizeof *kCostTable;
1137   if (tok >= table_size) tok = table_size - 1;
1138   return kCostTable[tok] + nbits;
1139 }
1140 
// Static (data-independent) estimate, in bits, of coding an LZ77 distance
// symbol `dist` (ApplyLZ77_LZ77 passes the HashChain distance symbol here).
// The estimate is the table cost of its HybridUintConfig(7, 0, 0) token plus
// the token's raw extra bits; tokens beyond the table use its last entry.
// NOTE(review): the origin of the table values is not visible here;
// presumably they were fitted offline.
// TODO(veluca): this does not take into account usage or non-usage of distance
// multipliers.
float DistCost(size_t dist) {
  uint32_t nbits, bits, tok;
  HybridUintConfig(7, 0, 0).Encode(dist, &tok, &nbits, &bits);
  constexpr float kCostTable[] = {
      6.368282626312716,  5.680793277090298,  8.347404197105247,
      7.641619201599141,  6.914328374119438,  7.959808291537444,
      8.70023120759855,   8.71378518934703,   9.379132523982769,
      9.110472749092708,  9.159029569270908,  9.430936766731973,
      7.278284055315169,  7.8278514904267755, 10.026641158289236,
      9.976049229827066,  9.64351607048908,   9.563403863480442,
      10.171474111762747, 10.45950155077234,  9.994813912104219,
      10.322524683741156, 8.465808729388186,  8.756254166066853,
      10.160930174662234, 10.247329273413435, 10.04090403724809,
      10.129398517544082, 9.342311691539546,  9.07608009102374,
      10.104799540677513, 10.378079384990906, 10.165828974075072,
      10.337595322341553, 7.940557464567944,  10.575665823319431,
      11.023344321751955, 10.736144698831827, 11.118277044595054,
      7.468468230648442,  10.738305230932939, 10.906980780216568,
      10.163468216353817, 10.17805759656433,  11.167283670483565,
      11.147050200274544, 10.517921919244333, 10.651764778156886,
      10.17074446448919,  11.217636876224745, 11.261630721139484,
      11.403140815247259, 10.892472096873417, 11.1859607804481,
      8.017346947551262,  7.895143720278828,  11.036577113822025,
      11.170562110315794, 10.326988722591086, 10.40872184751056,
      11.213498225466386, 11.30580635516863,  10.672272515665442,
      10.768069466228063, 11.145257364153565, 11.64668307145549,
      10.593156194627339, 11.207499484844943, 10.767517766396908,
      10.826629811407042, 10.737764794499988, 10.6200448518045,
      10.191315385198092, 8.468384171390085,  11.731295299170432,
      11.824619886654398, 10.41518844301179,  10.16310536548649,
      10.539423685097576, 10.495136599328031, 10.469112847728267,
      11.72057686174922,  10.910326337834674, 11.378921834673758,
      11.847759036098536, 11.92071647623854,  10.810628276345282,
      11.008601085273893, 11.910326337834674, 11.949212023423133,
      11.298614839104337, 11.611603659010392, 10.472930394619985,
      11.835564720850282, 11.523267392285337, 12.01055816679611,
      8.413029688994023,  11.895784139536406, 11.984679534970505,
      11.220654278717394, 11.716311684833672, 10.61036646226114,
      10.89849965960364,  10.203762898863669, 10.997560826267238,
      11.484217379438984, 11.792836176993665, 12.24310468755171,
      11.464858097919262, 12.212747017409377, 11.425595666074955,
      11.572048533398757, 12.742093965163013, 11.381874288645637,
      12.191870445817015, 11.683156920035426, 11.152442115262197,
      11.90303691580457,  11.653292787169159, 11.938615382266098,
      16.970641701570223, 16.853602280380002, 17.26240782594733,
      16.644655390108507, 17.14310889757499,  16.910935455445955,
      17.505678976959697, 17.213498225466388, 2.4162310293553024,
      3.494587244462329,  3.5258600986408344, 3.4959806589517095,
      3.098390886949687,  3.343454654302911,  3.588847442290287,
      4.14614790111827,   5.152948641990529,  7.433696808092598,
      9.716311684833672,
  };
  size_t table_size = sizeof kCostTable / sizeof *kCostTable;
  // Clamp tokens past the end of the table to its last entry.
  if (tok >= table_size) tok = table_size - 1;
  return kCostTable[tok] + nbits;
}
1199 
// Greedy LZ77 with one-step lazy matching.
//
// For each stream, finds the best earlier match at every position with a
// HashChain. If the match at the next position is longer, that one is used
// instead ("lazy" matching). A match becomes an LZ77 length token, placed in
// the context of the match's first token and holding len - lz77.min_length,
// followed by a distance token in lz77.nonserialized_distance_context. This
// happens only if the static estimate LenCost + DistCost +
// sce.AddSymbolCost() does not exceed the estimated cost of the same tokens
// coded as literals. Results go to tokens_lz77, one output stream per input
// stream. lz77.enabled is set to true when the total estimated saving exceeds
// 0.2 bits per input token plus 16 bits; it is never cleared here.
void ApplyLZ77_LZ77(const HistogramParams& params, size_t num_contexts,
                    const std::vector<std::vector<Token>>& tokens,
                    LZ77Params& lz77,
                    std::vector<std::vector<Token>>& tokens_lz77) {
  // TODO(veluca): tune heuristics here.
  SymbolCostEstimator sce(num_contexts, params.force_huffman, tokens, lz77);
  float bit_decrease = 0;
  size_t total_symbols = 0;
  tokens_lz77.resize(tokens.size());
  HybridUintConfig uint_config;
  std::vector<float> sym_cost;
  for (size_t stream = 0; stream < tokens.size(); stream++) {
    size_t distance_multiplier =
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
    const auto& in = tokens[stream];
    auto& out = tokens_lz77[stream];
    total_symbols += in.size();
    // Cumulative sum of bit costs.
    // sym_cost[k] is the estimated literal cost of in[0..k). sym_cost[0] is
    // never written, so it keeps the 0 from the first resize.
    sym_cost.resize(in.size() + 1);
    for (size_t i = 0; i < in.size(); i++) {
      uint32_t tok, nbits, unused_bits;
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
    }

    out.reserve(in.size());
    size_t max_distance = in.size();
    size_t min_length = lz77.min_length;
    JXL_ASSERT(min_length >= 3);
    size_t max_length = in.size();

    // Use next power of two as window size.
    // (Smallest power of two >= in.size(), stopping once kWindowSize is
    // reached.)
    size_t window_size = 1;
    while (window_size < max_distance && window_size < kWindowSize) {
      window_size <<= 1;
    }

    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
                    distance_multiplier);
    size_t len, dist_symbol;

    const size_t max_lazy_match_len = 256;  // 0 to disable lazy matching

    // Whether the next symbol was already updated (to test lazy matching)
    bool already_updated = false;
    for (size_t i = 0; i < in.size(); i++) {
      // Tentatively emit in[i] as a literal; it is turned into the length
      // token below if a match is taken.
      out.push_back(in[i]);
      if (!already_updated) chain.Update(i);
      already_updated = false;
      chain.FindMatch(i, max_distance, &dist_symbol, &len);
      if (len >= min_length) {
        if (len < max_lazy_match_len && i + 1 < in.size()) {
          // Try length at next symbol lazy matching
          chain.Update(i + 1);
          already_updated = true;
          size_t len2, dist_symbol2;
          chain.FindMatch(i + 1, max_distance, &dist_symbol2, &len2);
          if (len2 > len) {
            // Use the lazy match. Keep in[i] as a literal, and use the length
            // starting from the next token.
            ++i;
            already_updated = false;
            len = len2;
            dist_symbol = dist_symbol2;
            out.push_back(in[i]);
          }
        }

        float cost = sym_cost[i + len] - sym_cost[i];
        size_t lz77_len = len - lz77.min_length;
        float lz77_cost = LenCost(lz77_len) + DistCost(dist_symbol) +
                          sce.AddSymbolCost(out.back().context);

        if (lz77_cost <= cost) {
          out.back().value = len - min_length;
          out.back().is_lz77_length = true;
          out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
          bit_decrease += cost - lz77_cost;
        } else {
          // LZ77 match ignored, and symbol already pushed. Push all other
          // symbols and skip.
          for (size_t j = 1; j < len; j++) {
            out.push_back(in[i + j]);
          }
        }

        // Insert the remaining positions covered by the match into the chain.
        // Position i + 1 is already there if the lazy probe ran and was not
        // taken.
        if (already_updated) {
          chain.Update(i + 2, len - 2);
          already_updated = false;
        } else {
          chain.Update(i + 1, len - 1);
        }
        i += len - 1;
      } else {
        // Literal, already pushed
      }
    }
  }

  if (bit_decrease > total_symbols * 0.2 + 16) {
    lz77.enabled = true;
  }
}
1303 
// LZ77 parse that minimizes the estimated total cost.
//
// First runs ApplyLZ77_LZ77 into a scratch buffer. If that greedy pass does
// not enable LZ77, this returns without touching tokens_lz77. Otherwise:
//   * The greedy output feeds a SymbolCostEstimator with num_contexts + 1
//     contexts. The extra context holds the distance tokens:
//     BuildAndEncodeHistograms sets lz77.nonserialized_distance_context to
//     num_contexts before calling ApplyLZ77.
//   * A dynamic program over prefix_costs finds, for every prefix of the
//     stream, the cheapest last step: one literal, or one match reported by
//     HashChain::FindMatches.
//   * The cheapest path is traced back from the end and written to
//     tokens_lz77.
// lz77.enabled is left as set by the greedy pass.
void ApplyLZ77_Optimal(const HistogramParams& params, size_t num_contexts,
                       const std::vector<std::vector<Token>>& tokens,
                       LZ77Params& lz77,
                       std::vector<std::vector<Token>>& tokens_lz77) {
  std::vector<std::vector<Token>> tokens_for_cost_estimate;
  ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_for_cost_estimate);
  // If greedy-LZ77 does not give better compression than no-lz77, no reason to
  // run the optimal matching.
  if (!lz77.enabled) return;
  SymbolCostEstimator sce(num_contexts + 1, params.force_huffman,
                          tokens_for_cost_estimate, lz77);
  tokens_lz77.resize(tokens.size());
  HybridUintConfig uint_config;
  std::vector<float> sym_cost;
  std::vector<uint32_t> dist_symbols;
  for (size_t stream = 0; stream < tokens.size(); stream++) {
    size_t distance_multiplier =
        params.image_widths.size() > stream ? params.image_widths[stream] : 0;
    const auto& in = tokens[stream];
    auto& out = tokens_lz77[stream];
    // Cumulative sum of bit costs.
    // sym_cost[k] is the estimated literal cost of in[0..k). sym_cost[0] is
    // never written, so it keeps the 0 from the first resize.
    sym_cost.resize(in.size() + 1);
    for (size_t i = 0; i < in.size(); i++) {
      uint32_t tok, nbits, unused_bits;
      uint_config.Encode(in[i].value, &tok, &nbits, &unused_bits);
      sym_cost[i + 1] = sce.Bits(in[i].context, tok) + nbits + sym_cost[i];
    }

    out.reserve(in.size());
    size_t max_distance = in.size();
    size_t min_length = lz77.min_length;
    JXL_ASSERT(min_length >= 3);
    size_t max_length = in.size();

    // Use next power of two as window size.
    size_t window_size = 1;
    while (window_size < max_distance && window_size < kWindowSize) {
      window_size <<= 1;
    }

    HashChain chain(in.data(), in.size(), window_size, min_length, max_length,
                    distance_multiplier);

    // Last step of the cheapest known encoding of a prefix. dist_symbol == 0
    // means a literal (len == 1); otherwise dist_symbol - 1 is the distance
    // symbol of a match of length len. ctx is the context of the step's
    // first token.
    struct MatchInfo {
      uint32_t len;
      uint32_t dist_symbol;
      uint32_t ctx;
      float total_cost = std::numeric_limits<float>::max();
    };
    // Total cost to encode the first N symbols.
    std::vector<MatchInfo> prefix_costs(in.size() + 1);
    prefix_costs[0].total_cost = 0;

    size_t rle_length = 0;
    size_t skip_lz77 = 0;
    for (size_t i = 0; i < in.size(); i++) {
      chain.Update(i);
      // Relax the literal edge i -> i + 1.
      float lit_cost =
          prefix_costs[i].total_cost + sym_cost[i + 1] - sym_cost[i];
      if (prefix_costs[i + 1].total_cost > lit_cost) {
        prefix_costs[i + 1].dist_symbol = 0;
        prefix_costs[i + 1].len = 1;
        prefix_costs[i + 1].ctx = in[i].context;
        prefix_costs[i + 1].total_cost = lit_cost;
      }
      if (skip_lz77 > 0) {
        skip_lz77--;
        continue;
      }
      // dist_symbols[len]: smallest distance symbol reported for a match of
      // exactly `len`. Growing the vector fills the new, shorter lengths with
      // the current symbol.
      dist_symbols.clear();
      chain.FindMatches(i, max_distance,
                        [&dist_symbols](size_t len, size_t dist_symbol) {
                          if (dist_symbols.size() <= len) {
                            dist_symbols.resize(len + 1, dist_symbol);
                          }
                          if (dist_symbol < dist_symbols[len]) {
                            dist_symbols[len] = dist_symbol;
                          }
                        });
      if (dist_symbols.size() <= min_length) continue;
      {
        // Suffix minimum: a match of length L can also be used for any
        // shorter length, so dist_symbols[j] becomes the smallest symbol over
        // lengths >= j. (best_cost holds a distance symbol, not a cost.)
        size_t best_cost = dist_symbols.back();
        for (size_t j = dist_symbols.size() - 1; j >= min_length; j--) {
          if (dist_symbols[j] < best_cost) {
            best_cost = dist_symbols[j];
          }
          dist_symbols[j] = best_cost;
        }
      }
      // Relax the match edges i -> i + j for every usable length j.
      for (size_t j = min_length; j < dist_symbols.size(); j++) {
        // Cost model that uses results from greedy (lazy) LZ77.
        float lz77_cost = sce.LenCost(in[i].context, j - min_length, lz77) +
                          sce.DistCost(dist_symbols[j], lz77);
        float cost = prefix_costs[i].total_cost + lz77_cost;
        if (prefix_costs[i + j].total_cost > cost) {
          prefix_costs[i + j].len = j;
          prefix_costs[i + j].dist_symbol = dist_symbols[j] + 1;
          prefix_costs[i + j].ctx = in[i].context;
          prefix_costs[i + j].total_cost = cost;
        }
      }
      // We are in a RLE sequence: skip all the symbols except the first 8 and
      // the last 8. This avoids quadratic costs for sequences with long runs
      // of the same symbol.
      if ((dist_symbols.back() == 0 && distance_multiplier == 0) ||
          (dist_symbols.back() == 1 && distance_multiplier != 0)) {
        rle_length++;
      } else {
        rle_length = 0;
      }
      if (rle_length >= 8 && dist_symbols.size() > 9) {
        skip_lz77 = dist_symbols.size() - 10;
        rle_length = 0;
      }
    }
    // Trace the cheapest path back from the end. Tokens are produced in
    // reverse (a distance token before its length token) and the whole
    // stream is reversed afterwards.
    size_t pos = in.size();
    while (pos > 0) {
      bool is_lz77_length = prefix_costs[pos].dist_symbol != 0;
      if (is_lz77_length) {
        size_t dist_symbol = prefix_costs[pos].dist_symbol - 1;
        out.emplace_back(lz77.nonserialized_distance_context, dist_symbol);
      }
      size_t val = is_lz77_length ? prefix_costs[pos].len - min_length
                                  : in[pos - 1].value;
      out.emplace_back(prefix_costs[pos].ctx, val);
      out.back().is_lz77_length = is_lz77_length;
      pos -= prefix_costs[pos].len;
    }
    std::reverse(out.begin(), out.end());
  }
}
1435 
ApplyLZ77(const HistogramParams & params,size_t num_contexts,const std::vector<std::vector<Token>> & tokens,LZ77Params & lz77,std::vector<std::vector<Token>> & tokens_lz77)1436 void ApplyLZ77(const HistogramParams& params, size_t num_contexts,
1437                const std::vector<std::vector<Token>>& tokens, LZ77Params& lz77,
1438                std::vector<std::vector<Token>>& tokens_lz77) {
1439   lz77.enabled = false;
1440   if (params.force_huffman) {
1441     lz77.min_symbol = std::min(PREFIX_MAX_ALPHABET_SIZE - 32, 512);
1442   } else {
1443     lz77.min_symbol = 224;
1444   }
1445   if (params.lz77_method == HistogramParams::LZ77Method::kNone) {
1446     return;
1447   } else if (params.lz77_method == HistogramParams::LZ77Method::kRLE) {
1448     ApplyLZ77_RLE(params, num_contexts, tokens, lz77, tokens_lz77);
1449   } else if (params.lz77_method == HistogramParams::LZ77Method::kLZ77) {
1450     ApplyLZ77_LZ77(params, num_contexts, tokens, lz77, tokens_lz77);
1451   } else if (params.lz77_method == HistogramParams::LZ77Method::kOptimal) {
1452     ApplyLZ77_Optimal(params, num_contexts, tokens, lz77, tokens_lz77);
1453   } else {
1454     JXL_ABORT("Not implemented");
1455   }
1456 }
1457 }  // namespace
1458 
BuildAndEncodeHistograms(const HistogramParams & params,size_t num_contexts,std::vector<std::vector<Token>> & tokens,EntropyEncodingData * codes,std::vector<uint8_t> * context_map,BitWriter * writer,size_t layer,AuxOut * aux_out)1459 size_t BuildAndEncodeHistograms(const HistogramParams& params,
1460                                 size_t num_contexts,
1461                                 std::vector<std::vector<Token>>& tokens,
1462                                 EntropyEncodingData* codes,
1463                                 std::vector<uint8_t>* context_map,
1464                                 BitWriter* writer, size_t layer,
1465                                 AuxOut* aux_out) {
1466   size_t total_bits = 0;
1467   codes->lz77.nonserialized_distance_context = num_contexts;
1468   std::vector<std::vector<Token>> tokens_lz77;
1469   ApplyLZ77(params, num_contexts, tokens, codes->lz77, tokens_lz77);
1470   if (ans_fuzzer_friendly_) {
1471     codes->lz77.length_uint_config = HybridUintConfig(10, 0, 0);
1472     codes->lz77.min_symbol = 2048;
1473   }
1474 
1475   const size_t max_contexts = std::min(num_contexts, kClustersLimit);
1476   BitWriter::Allotment allotment(writer,
1477                                  128 + num_contexts * 40 + max_contexts * 96);
1478   if (writer) {
1479     JXL_CHECK(Bundle::Write(codes->lz77, writer, layer, aux_out));
1480   } else {
1481     size_t ebits, bits;
1482     JXL_CHECK(Bundle::CanEncode(codes->lz77, &ebits, &bits));
1483     total_bits += bits;
1484   }
1485   if (codes->lz77.enabled) {
1486     if (writer) {
1487       size_t b = writer->BitsWritten();
1488       EncodeUintConfig(codes->lz77.length_uint_config, writer,
1489                        /*log_alpha_size=*/8);
1490       total_bits += writer->BitsWritten() - b;
1491     } else {
1492       SizeWriter size_writer;
1493       EncodeUintConfig(codes->lz77.length_uint_config, &size_writer,
1494                        /*log_alpha_size=*/8);
1495       total_bits += size_writer.size;
1496     }
1497     num_contexts += 1;
1498     tokens = std::move(tokens_lz77);
1499   }
1500   size_t total_tokens = 0;
1501   // Build histograms.
1502   HistogramBuilder builder(num_contexts);
1503   HybridUintConfig uint_config;  //  Default config for clustering.
1504   // Unless we are using the kContextMap histogram option.
1505   if (params.uint_method == HistogramParams::HybridUintMethod::kContextMap) {
1506     uint_config = HybridUintConfig(2, 0, 1);
1507   }
1508   if (ans_fuzzer_friendly_) {
1509     uint_config = HybridUintConfig(10, 0, 0);
1510   }
1511   for (size_t i = 0; i < tokens.size(); ++i) {
1512     if (codes->lz77.enabled) {
1513       for (size_t j = 0; j < tokens[i].size(); ++j) {
1514         const Token& token = tokens[i][j];
1515         total_tokens++;
1516         uint32_t tok, nbits, bits;
1517         (token.is_lz77_length ? codes->lz77.length_uint_config : uint_config)
1518             .Encode(token.value, &tok, &nbits, &bits);
1519         tok += token.is_lz77_length ? codes->lz77.min_symbol : 0;
1520         builder.VisitSymbol(tok, token.context);
1521       }
1522     } else if (num_contexts == 1) {
1523       for (size_t j = 0; j < tokens[i].size(); ++j) {
1524         const Token& token = tokens[i][j];
1525         total_tokens++;
1526         uint32_t tok, nbits, bits;
1527         uint_config.Encode(token.value, &tok, &nbits, &bits);
1528         builder.VisitSymbol(tok, /*token.context=*/0);
1529       }
1530     } else {
1531       for (size_t j = 0; j < tokens[i].size(); ++j) {
1532         const Token& token = tokens[i][j];
1533         total_tokens++;
1534         uint32_t tok, nbits, bits;
1535         uint_config.Encode(token.value, &tok, &nbits, &bits);
1536         builder.VisitSymbol(tok, token.context);
1537       }
1538     }
1539   }
1540 
1541   bool use_prefix_code =
1542       params.force_huffman || total_tokens < 100 ||
1543       params.clustering == HistogramParams::ClusteringType::kFastest ||
1544       ans_fuzzer_friendly_;
1545   if (!use_prefix_code) {
1546     bool all_singleton = true;
1547     for (size_t i = 0; i < num_contexts; i++) {
1548       if (builder.Histo(i).ShannonEntropy() >= 1e-5) {
1549         all_singleton = false;
1550       }
1551     }
1552     if (all_singleton) {
1553       use_prefix_code = true;
1554     }
1555   }
1556 
1557   // Encode histograms.
1558   total_bits += builder.BuildAndStoreEntropyCodes(params, tokens, codes,
1559                                                   context_map, use_prefix_code,
1560                                                   writer, layer, aux_out);
1561   allotment.FinishedHistogram(writer);
1562   ReclaimAndCharge(writer, &allotment, layer, aux_out);
1563 
1564   if (aux_out != nullptr) {
1565     aux_out->layers[layer].num_clustered_histograms +=
1566         codes->encoding_info.size();
1567   }
1568   return total_bits;
1569 }
1570 
WriteTokens(const std::vector<Token> & tokens,const EntropyEncodingData & codes,const std::vector<uint8_t> & context_map,BitWriter * writer)1571 size_t WriteTokens(const std::vector<Token>& tokens,
1572                    const EntropyEncodingData& codes,
1573                    const std::vector<uint8_t>& context_map, BitWriter* writer) {
1574   size_t num_extra_bits = 0;
1575   if (codes.use_prefix_code) {
1576     for (size_t i = 0; i < tokens.size(); i++) {
1577       uint32_t tok, nbits, bits;
1578       const Token& token = tokens[i];
1579       size_t histo = context_map[token.context];
1580       (token.is_lz77_length ? codes.lz77.length_uint_config
1581                             : codes.uint_config[histo])
1582           .Encode(token.value, &tok, &nbits, &bits);
1583       tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1584       // Combine two calls to the BitWriter. Equivalent to:
1585       // writer->Write(codes.encoding_info[histo][tok].depth,
1586       //               codes.encoding_info[histo][tok].bits);
1587       // writer->Write(nbits, bits);
1588       uint64_t data = codes.encoding_info[histo][tok].bits;
1589       data |= bits << codes.encoding_info[histo][tok].depth;
1590       writer->Write(codes.encoding_info[histo][tok].depth + nbits, data);
1591       num_extra_bits += nbits;
1592     }
1593     return num_extra_bits;
1594   }
1595   std::vector<uint64_t> out;
1596   std::vector<uint8_t> out_nbits;
1597   out.reserve(tokens.size());
1598   out_nbits.reserve(tokens.size());
1599   uint64_t allbits = 0;
1600   size_t numallbits = 0;
1601   // Writes in *reversed* order.
1602   auto addbits = [&](size_t bits, size_t nbits) {
1603     if (JXL_UNLIKELY(nbits)) {
1604       JXL_DASSERT(bits >> nbits == 0);
1605       if (JXL_UNLIKELY(numallbits + nbits > BitWriter::kMaxBitsPerCall)) {
1606         out.push_back(allbits);
1607         out_nbits.push_back(numallbits);
1608         numallbits = allbits = 0;
1609       }
1610       allbits <<= nbits;
1611       allbits |= bits;
1612       numallbits += nbits;
1613     }
1614   };
1615   const int end = tokens.size();
1616   ANSCoder ans;
1617   if (codes.lz77.enabled || context_map.size() > 1) {
1618     for (int i = end - 1; i >= 0; --i) {
1619       const Token token = tokens[i];
1620       const uint8_t histo = context_map[token.context];
1621       uint32_t tok, nbits, bits;
1622       (token.is_lz77_length ? codes.lz77.length_uint_config
1623                             : codes.uint_config[histo])
1624           .Encode(tokens[i].value, &tok, &nbits, &bits);
1625       tok += token.is_lz77_length ? codes.lz77.min_symbol : 0;
1626       const ANSEncSymbolInfo& info = codes.encoding_info[histo][tok];
1627       // Extra bits first as this is reversed.
1628       addbits(bits, nbits);
1629       num_extra_bits += nbits;
1630       uint8_t ans_nbits = 0;
1631       uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
1632       addbits(ans_bits, ans_nbits);
1633     }
1634   } else {
1635     for (int i = end - 1; i >= 0; --i) {
1636       uint32_t tok, nbits, bits;
1637       codes.uint_config[0].Encode(tokens[i].value, &tok, &nbits, &bits);
1638       const ANSEncSymbolInfo& info = codes.encoding_info[0][tok];
1639       // Extra bits first as this is reversed.
1640       addbits(bits, nbits);
1641       num_extra_bits += nbits;
1642       uint8_t ans_nbits = 0;
1643       uint32_t ans_bits = ans.PutSymbol(info, &ans_nbits);
1644       addbits(ans_bits, ans_nbits);
1645     }
1646   }
1647   const uint32_t state = ans.GetState();
1648   writer->Write(32, state);
1649   writer->Write(numallbits, allbits);
1650   for (int i = out.size(); i > 0; --i) {
1651     writer->Write(out_nbits[i - 1], out[i - 1]);
1652   }
1653   return num_extra_bits;
1654 }
1655 
WriteTokens(const std::vector<Token> & tokens,const EntropyEncodingData & codes,const std::vector<uint8_t> & context_map,BitWriter * writer,size_t layer,AuxOut * aux_out)1656 void WriteTokens(const std::vector<Token>& tokens,
1657                  const EntropyEncodingData& codes,
1658                  const std::vector<uint8_t>& context_map, BitWriter* writer,
1659                  size_t layer, AuxOut* aux_out) {
1660   BitWriter::Allotment allotment(writer, 32 * tokens.size() + 32 * 1024 * 4);
1661   size_t num_extra_bits = WriteTokens(tokens, codes, context_map, writer);
1662   ReclaimAndCharge(writer, &allotment, layer, aux_out);
1663   if (aux_out != nullptr) {
1664     aux_out->layers[layer].extra_bits += num_extra_bits;
1665   }
1666 }
1667 
// Toggles the file-local `ans_fuzzer_friendly_` flag. In this file, that flag
// forces prefix codes and HybridUintConfig(10, 0, 0) for values and LZ77
// lengths, and sets the LZ77 min_symbol to 2048.
// The flag can only be changed in debug builds (JXL_IS_DEBUG_BUILD). In other
// builds this call is a no-op.
void SetANSFuzzerFriendly(bool ans_fuzzer_friendly) {
#if JXL_IS_DEBUG_BUILD  // Guard against accidental / malicious changes.
  ans_fuzzer_friendly_ = ans_fuzzer_friendly;
#else
  (void)ans_fuzzer_friendly;  // Unused outside debug builds.
#endif
}
1673 }  // namespace jxl
1674