// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "lib/jxl/modular/encoding/enc_ma.h"

#include <algorithm>
#include <limits>
#include <numeric>
#include <queue>
#include <random>
#include <unordered_map>
#include <unordered_set>

#include "lib/jxl/modular/encoding/ma_common.h"

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "lib/jxl/modular/encoding/enc_ma.cc"
#include <hwy/foreach_target.h>
#include <hwy/highway.h>

#ifndef LIB_JXL_ENC_MODULAR_ENCODING_MA_
#define LIB_JXL_ENC_MODULAR_ENCODING_MA_
namespace {
struct Rng {
  uint64_t s[2];
  explicit Rng(size_t seed)
      : s{0x94D049BB133111EBull, 0xBF58476D1CE4E5B9ull + seed} {}
  // Xorshift128+ adapted from xorshift128+-inl.h
  uint64_t operator()() {
    uint64_t s1 = s[0];
    const uint64_t s0 = s[1];
    const uint64_t bits = s1 + s0;  // b, c
    s[0] = s0;
    s1 ^= s1 << 23;
    s1 ^= s0 ^ (s1 >> 18) ^ (s0 >> 5);
    s[1] = s1;
    return bits;
  }
  static constexpr uint64_t max() { return ~0ULL; }
  static constexpr uint64_t min() { return 0; }
};
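// Rng provides operator(), min() and max(), which is the interface standard
// <random> distributions expect, so it can directly drive the
// std::geometric_distribution used in CollectPixelSamples below.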
}  // namespace
#endif

#include "lib/jxl/enc_ans.h"
#include "lib/jxl/fast_math-inl.h"
#include "lib/jxl/modular/encoding/context_predict.h"
#include "lib/jxl/modular/options.h"
HWY_BEFORE_NAMESPACE();
namespace jxl {
namespace HWY_NAMESPACE {

const HWY_FULL(float) df;
const HWY_FULL(int32_t) di;
size_t Padded(size_t x) { return RoundUpTo(x, Lanes(df)); }

float EstimateBits(const int32_t *counts, int32_t *rounded_counts,
                   size_t num_symbols) {
  // Try to approximate the effect of rounding up nonzero probabilities.
  int32_t total = std::accumulate(counts, counts + num_symbols, 0);
  const auto min = Set(di, (total + ANS_TAB_SIZE - 1) >> ANS_LOG_TAB_SIZE);
  const auto zero_i = Zero(di);
  for (size_t i = 0; i < num_symbols; i += Lanes(df)) {
    auto counts_v = LoadU(di, &counts[i]);
    counts_v = IfThenElse(counts_v == zero_i, zero_i,
                          IfThenElse(counts_v < min, min, counts_v));
    StoreU(counts_v, di, &rounded_counts[i]);
  }
  // Compute entropy of the "rounded" probabilities.
  const auto zero = Zero(df);
  const size_t total_scalar =
      std::accumulate(rounded_counts, rounded_counts + num_symbols, 0);
  const auto inv_total = Set(df, 1.0f / total_scalar);
  auto bits_lanes = Zero(df);
  auto total_v = Set(di, total_scalar);
  for (size_t i = 0; i < num_symbols; i += Lanes(df)) {
    const auto counts_v = ConvertTo(df, LoadU(di, &counts[i]));
    const auto round_counts_v = LoadU(di, &rounded_counts[i]);
    const auto probs = ConvertTo(df, round_counts_v) * inv_total;
    const auto nbps = IfThenElse(round_counts_v == total_v, BitCast(di, zero),
                                 BitCast(di, FastLog2f(df, probs)));
    bits_lanes -=
        IfThenElse(counts_v == zero, zero, counts_v * BitCast(df, nbps));
  }
  return GetLane(SumOfLanes(bits_lanes));
}
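// The value returned above is the Shannon cost of coding `counts` with the
// rounded distribution: bits = -sum_i counts[i] * log2(rounded[i] / total).
// Illustrative example (numbers chosen for exposition): counts = {3, 1}
// gives total = 4, probabilities {0.75, 0.25}, and
// bits = -(3 * log2(0.75) + 1 * log2(0.25)) ~= 3.24.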

void MakeSplitNode(size_t pos, int property, int splitval, Predictor lpred,
                   int64_t loff, Predictor rpred, int64_t roff, Tree *tree) {
  // Note that the tree splits on *strictly greater*.
  (*tree)[pos].lchild = tree->size();
  (*tree)[pos].rchild = tree->size() + 1;
  (*tree)[pos].splitval = splitval;
  (*tree)[pos].property = property;
  tree->emplace_back();
  tree->back().property = -1;
  tree->back().predictor = rpred;
  tree->back().predictor_offset = roff;
  tree->back().multiplier = 1;
  tree->emplace_back();
  tree->back().property = -1;
  tree->back().predictor = lpred;
  tree->back().predictor_offset = loff;
  tree->back().multiplier = 1;
}
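// After MakeSplitNode, node `pos` tests (property > splitval) and both
// children start out as leaves (property == -1) with multiplier 1; the
// caller may later turn either leaf into a further split.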

enum class IntersectionType { kNone, kPartial, kInside };
IntersectionType BoxIntersects(StaticPropRange needle, StaticPropRange haystack,
                               uint32_t &partial_axis, uint32_t &partial_val) {
  bool partial = false;
  for (size_t i = 0; i < kNumStaticProperties; i++) {
    if (haystack[i][0] >= needle[i][1]) {
      return IntersectionType::kNone;
    }
    if (haystack[i][1] <= needle[i][0]) {
      return IntersectionType::kNone;
    }
    if (haystack[i][0] <= needle[i][0] && haystack[i][1] >= needle[i][1]) {
      continue;
    }
    partial = true;
    partial_axis = i;
    if (haystack[i][0] > needle[i][0] && haystack[i][0] < needle[i][1]) {
      partial_val = haystack[i][0] - 1;
    } else {
      JXL_DASSERT(haystack[i][1] > needle[i][0] &&
                  haystack[i][1] < needle[i][1]);
      partial_val = haystack[i][1] - 1;
    }
  }
  return partial ? IntersectionType::kPartial : IntersectionType::kInside;
}
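// Illustrative example: with needle [0, 8) x [0, 8) and haystack
// [2, 5) x [0, 8), the boxes overlap only partially along axis 0, so this
// returns kPartial with partial_axis = 0 and partial_val = 1: splitting the
// needle on (property > 1) peels off the part below the haystack.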

void SplitTreeSamples(TreeSamples &tree_samples, size_t begin, size_t pos,
                      size_t end, size_t prop) {
  auto cmp = [&](size_t a, size_t b) {
    return int32_t(tree_samples.Property(prop, a)) -
           int32_t(tree_samples.Property(prop, b));
  };
  Rng rng(0);
  while (end > begin + 1) {
    {
      JXL_ASSERT(end > begin);  // silence clang-tidy.
      size_t pivot = rng() % (end - begin) + begin;
      tree_samples.Swap(begin, pivot);
    }
    size_t pivot_begin = begin;
    size_t pivot_end = pivot_begin + 1;
    for (size_t i = begin + 1; i < end; i++) {
      JXL_DASSERT(i >= pivot_end);
      JXL_DASSERT(pivot_end > pivot_begin);
      int32_t cmp_result = cmp(i, pivot_begin);
      if (cmp_result < 0) {  // i < pivot, move pivot forward and put i before
                             // the pivot.
        tree_samples.ThreeShuffle(pivot_begin, pivot_end, i);
        pivot_begin++;
        pivot_end++;
      } else if (cmp_result == 0) {
        tree_samples.Swap(pivot_end, i);
        pivot_end++;
      }
    }
    JXL_DASSERT(pivot_begin >= begin);
    JXL_DASSERT(pivot_end > pivot_begin);
    JXL_DASSERT(pivot_end <= end);
    for (size_t i = begin; i < pivot_begin; i++) {
      JXL_DASSERT(cmp(i, pivot_begin) < 0);
    }
    for (size_t i = pivot_end; i < end; i++) {
      JXL_DASSERT(cmp(i, pivot_begin) > 0);
    }
    for (size_t i = pivot_begin; i < pivot_end; i++) {
      JXL_DASSERT(cmp(i, pivot_begin) == 0);
    }
    // We now have that [begin, pivot_begin) is < pivot, [pivot_begin,
    // pivot_end) is = pivot, and [pivot_end, end) is > pivot.
    // If pos falls in the first or the last interval, we continue in that
    // interval; otherwise, we are done.
    if (pivot_begin > pos) {
      end = pivot_begin;
    } else if (pivot_end < pos) {
      begin = pivot_end;
    } else {
      break;
    }
  }
}
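// SplitTreeSamples is a quickselect-style three-way partition: each round
// splits [begin, end) into samples <, ==, and > a randomly chosen pivot,
// then iterates only into the side that still contains `pos`. This takes
// expected linear time and only orders samples relative to `pos`, which is
// all FindBestSplit needs.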

void FindBestSplit(TreeSamples &tree_samples, float threshold,
                   const std::vector<ModularMultiplierInfo> &mul_info,
                   StaticPropRange initial_static_prop_range,
                   float fast_decode_multiplier, Tree *tree) {
  struct NodeInfo {
    size_t pos;
    size_t begin;
    size_t end;
    uint64_t used_properties;
    StaticPropRange static_prop_range;
  };
  std::vector<NodeInfo> nodes;
  nodes.push_back(NodeInfo{0, 0, tree_samples.NumDistinctSamples(), 0,
                           initial_static_prop_range});

  size_t num_predictors = tree_samples.NumPredictors();
  size_t num_properties = tree_samples.NumProperties();

  // TODO(veluca): consider parallelizing the search (processing multiple nodes
  // at a time).
  while (!nodes.empty()) {
    size_t pos = nodes.back().pos;
    size_t begin = nodes.back().begin;
    size_t end = nodes.back().end;
    uint64_t used_properties = nodes.back().used_properties;
    StaticPropRange static_prop_range = nodes.back().static_prop_range;
    nodes.pop_back();
    if (begin == end) continue;

    struct SplitInfo {
      size_t prop = 0;
      uint32_t val = 0;
      size_t pos = 0;
      float lcost = std::numeric_limits<float>::max();
      float rcost = std::numeric_limits<float>::max();
      Predictor lpred = Predictor::Zero;
      Predictor rpred = Predictor::Zero;
      float Cost() { return lcost + rcost; }
    };
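
    // Four candidate splits are tracked separately so that splits that are
    // cheaper to decode can be preferred: splits on a static property that
    // make one side constant (zero entropy), splits on a static property,
    // fully general splits, and splits that avoid introducing the (slower)
    // Weighted predictor. The final choice between them happens after the
    // property scan below.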

    SplitInfo best_split_static_constant;
    SplitInfo best_split_static;
    SplitInfo best_split_nonstatic;
    SplitInfo best_split_nowp;

    JXL_DASSERT(begin <= end);
    JXL_DASSERT(end <= tree_samples.NumDistinctSamples());

    // Compute the maximum token in the range.
    size_t max_symbols = 0;
    for (size_t pred = 0; pred < num_predictors; pred++) {
      for (size_t i = begin; i < end; i++) {
        uint32_t tok = tree_samples.Token(pred, i);
        max_symbols = max_symbols > tok + 1 ? max_symbols : tok + 1;
      }
    }
    max_symbols = Padded(max_symbols);
    std::vector<int32_t> rounded_counts(max_symbols);
    std::vector<int32_t> counts(max_symbols * num_predictors);
    std::vector<uint32_t> tot_extra_bits(num_predictors);
    for (size_t pred = 0; pred < num_predictors; pred++) {
      for (size_t i = begin; i < end; i++) {
        counts[pred * max_symbols + tree_samples.Token(pred, i)] +=
            tree_samples.Count(i);
        tot_extra_bits[pred] +=
            tree_samples.NBits(pred, i) * tree_samples.Count(i);
      }
    }

    float base_bits;
    {
      size_t pred = tree_samples.PredictorIndex((*tree)[pos].predictor);
      base_bits = EstimateBits(counts.data() + pred * max_symbols,
                               rounded_counts.data(), max_symbols) +
                  tot_extra_bits[pred];
    }
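    // base_bits is the cost of keeping this node a leaf, coded with its
    // current predictor; a split is adopted only if it beats this baseline
    // by more than `threshold` (checked at the end of the loop body).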

    SplitInfo *best = &best_split_nonstatic;

    SplitInfo forced_split;
    // The multiplier ranges cut halfway through the current ranges of static
    // properties. We do this even if the current node is not a leaf, to
    // minimize the number of nodes in the resulting tree.
    for (size_t i = 0; i < mul_info.size(); i++) {
      uint32_t axis, val;
      IntersectionType t =
          BoxIntersects(static_prop_range, mul_info[i].range, axis, val);
      if (t == IntersectionType::kNone) continue;
      if (t == IntersectionType::kInside) {
        (*tree)[pos].multiplier = mul_info[i].multiplier;
        break;
      }
      if (t == IntersectionType::kPartial) {
        forced_split.val = tree_samples.QuantizeProperty(axis, val);
        forced_split.prop = axis;
        forced_split.lcost = forced_split.rcost = base_bits / 2 - threshold;
        forced_split.lpred = forced_split.rpred = (*tree)[pos].predictor;
        best = &forced_split;
        best->pos = begin;
        JXL_ASSERT(best->prop == tree_samples.PropertyFromIndex(best->prop));
        for (size_t x = begin; x < end; x++) {
          if (tree_samples.Property(best->prop, x) <= best->val) {
            best->pos++;
          }
        }
        break;
      }
    }

    if (best != &forced_split) {
      std::vector<int> prop_value_used_count;
      std::vector<int> count_increase;
      std::vector<size_t> extra_bits_increase;
      // For each property, compute which of its values are used, and what
      // tokens correspond to those usages. Then, iterate through the values,
      // and compute the entropy of each side of the split (of the form `prop >
      // threshold`). Finally, find the split that minimizes the cost.
      struct CostInfo {
        float cost = std::numeric_limits<float>::max();
        float extra_cost = 0;
        float Cost() const { return cost + extra_cost; }
        Predictor pred;  // will be uninitialized in some cases, but never used.
      };
      std::vector<CostInfo> costs_l;
      std::vector<CostInfo> costs_r;

      std::vector<int32_t> counts_above(max_symbols);
      std::vector<int32_t> counts_below(max_symbols);

      // The lower the threshold, the higher the expected noisiness of the
      // estimate. Thus, discourage changing predictors.
      float change_pred_penalty = 800.0f / (100.0f + threshold);
      for (size_t prop = 0; prop < num_properties && base_bits > threshold;
           prop++) {
        costs_l.clear();
        costs_r.clear();
        size_t prop_size = tree_samples.NumPropertyValues(prop);
        if (extra_bits_increase.size() < prop_size) {
          count_increase.resize(prop_size * max_symbols);
          extra_bits_increase.resize(prop_size);
        }
        // Clear prop_value_used_count (which cannot be cleared "on the go")
        prop_value_used_count.clear();
        prop_value_used_count.resize(prop_size);

        size_t first_used = prop_size;
        size_t last_used = 0;

        // TODO(veluca): consider finding multiple splits along a single
        // property at the same time, possibly with a bottom-up approach.
        for (size_t i = begin; i < end; i++) {
          size_t p = tree_samples.Property(prop, i);
          prop_value_used_count[p]++;
          last_used = std::max(last_used, p);
          first_used = std::min(first_used, p);
        }
        costs_l.resize(last_used - first_used);
        costs_r.resize(last_used - first_used);
        // For all predictors, compute the right and left costs of each split.
        for (size_t pred = 0; pred < num_predictors; pred++) {
          // Compute cost and histogram increments for each property value.
          for (size_t i = begin; i < end; i++) {
            size_t p = tree_samples.Property(prop, i);
            size_t cnt = tree_samples.Count(i);
            size_t sym = tree_samples.Token(pred, i);
            count_increase[p * max_symbols + sym] += cnt;
            extra_bits_increase[p] += tree_samples.NBits(pred, i) * cnt;
          }
          memcpy(counts_above.data(), counts.data() + pred * max_symbols,
                 max_symbols * sizeof counts_above[0]);
          memset(counts_below.data(), 0, max_symbols * sizeof counts_below[0]);
          size_t extra_bits_below = 0;
          // Exclude last used: this ensures neither counts_above nor
          // counts_below is empty.
          for (size_t i = first_used; i < last_used; i++) {
            if (!prop_value_used_count[i]) continue;
            extra_bits_below += extra_bits_increase[i];
            // The increase for this property value has been used, and will not
            // be used again: clear it. Also below.
            extra_bits_increase[i] = 0;
            for (size_t sym = 0; sym < max_symbols; sym++) {
              counts_above[sym] -= count_increase[i * max_symbols + sym];
              counts_below[sym] += count_increase[i * max_symbols + sym];
              count_increase[i * max_symbols + sym] = 0;
            }
            float rcost = EstimateBits(counts_above.data(),
                                       rounded_counts.data(), max_symbols) +
                          tot_extra_bits[pred] - extra_bits_below;
            float lcost = EstimateBits(counts_below.data(),
                                       rounded_counts.data(), max_symbols) +
                          extra_bits_below;
            JXL_DASSERT(extra_bits_below <= tot_extra_bits[pred]);
            float penalty = 0;
            // Never discourage moving away from the Weighted predictor.
            if (tree_samples.PredictorFromIndex(pred) !=
                    (*tree)[pos].predictor &&
                (*tree)[pos].predictor != Predictor::Weighted) {
              penalty = change_pred_penalty;
            }
            // If everything else is equal, disfavour Weighted (slower) and
            // favour Zero (faster if it's the only predictor used in a
            // group+channel combination)
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Weighted) {
              penalty += 1e-8;
            }
            if (tree_samples.PredictorFromIndex(pred) == Predictor::Zero) {
              penalty -= 1e-8;
            }
            if (rcost + penalty < costs_r[i - first_used].Cost()) {
              costs_r[i - first_used].cost = rcost;
              costs_r[i - first_used].extra_cost = penalty;
              costs_r[i - first_used].pred =
                  tree_samples.PredictorFromIndex(pred);
            }
            if (lcost + penalty < costs_l[i - first_used].Cost()) {
              costs_l[i - first_used].cost = lcost;
              costs_l[i - first_used].extra_cost = penalty;
              costs_l[i - first_used].pred =
                  tree_samples.PredictorFromIndex(pred);
            }
          }
        }
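        // Note that the histograms were maintained incrementally above: each
        // property value moves its token counts from counts_above to
        // counts_below exactly once, so every predictor costs two
        // EstimateBits calls per distinct property value, instead of
        // recomputing both histograms from scratch for every candidate split.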
        // Iterate through the possible splits and find the one with the
        // minimum sum of costs of the two sides.
        size_t split = begin;
        for (size_t i = first_used; i < last_used; i++) {
          if (!prop_value_used_count[i]) continue;
          split += prop_value_used_count[i];
          float rcost = costs_r[i - first_used].cost;
          float lcost = costs_l[i - first_used].cost;
          // WP was not used yet, and this split would use the WP property or
          // predictor.
          bool adds_wp =
              (tree_samples.PropertyFromIndex(prop) == kWPProp &&
               (used_properties & (1LU << prop)) == 0) ||
              ((costs_l[i - first_used].pred == Predictor::Weighted ||
                costs_r[i - first_used].pred == Predictor::Weighted) &&
               (*tree)[pos].predictor != Predictor::Weighted);
          bool zero_entropy_side = rcost == 0 || lcost == 0;

          SplitInfo &best =
              prop < kNumStaticProperties
                  ? (zero_entropy_side ? best_split_static_constant
                                       : best_split_static)
                  : (adds_wp ? best_split_nonstatic : best_split_nowp);
          if (lcost + rcost < best.Cost()) {
            best.prop = prop;
            best.val = i;
            best.pos = split;
            best.lcost = lcost;
            best.lpred = costs_l[i - first_used].pred;
            best.rcost = rcost;
            best.rpred = costs_r[i - first_used].pred;
          }
        }
        // Clear extra_bits_increase and count_increase for last_used.
        extra_bits_increase[last_used] = 0;
        for (size_t sym = 0; sym < max_symbols; sym++) {
          count_increase[last_used * max_symbols + sym] = 0;
        }
      }

      // Try to avoid introducing WP.
      if (best_split_nowp.Cost() + threshold < base_bits &&
          best_split_nowp.Cost() <= fast_decode_multiplier * best->Cost()) {
        best = &best_split_nowp;
      }
      // Split along static props if possible and not significantly more
      // expensive.
      if (best_split_static.Cost() + threshold < base_bits &&
          best_split_static.Cost() <= fast_decode_multiplier * best->Cost()) {
        best = &best_split_static;
      }
      // Split along static props to create constant nodes if possible.
      if (best_split_static_constant.Cost() + threshold < base_bits) {
        best = &best_split_static_constant;
      }
    }

    if (best->Cost() + threshold < base_bits) {
      uint32_t p = tree_samples.PropertyFromIndex(best->prop);
      pixel_type dequant =
          tree_samples.UnquantizeProperty(best->prop, best->val);
      // Split node and try to split children.
      MakeSplitNode(pos, p, dequant, best->lpred, 0, best->rpred, 0, tree);
      // "Sort" according to winning property
      SplitTreeSamples(tree_samples, begin, best->pos, end, best->prop);
      if (p >= kNumStaticProperties) {
        used_properties |= 1LU << best->prop;
      }
      auto new_sp_range = static_prop_range;
      if (p < kNumStaticProperties) {
        JXL_ASSERT(static_cast<uint32_t>(dequant + 1) <= new_sp_range[p][1]);
        new_sp_range[p][1] = dequant + 1;
        JXL_ASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
      }
      nodes.push_back(NodeInfo{(*tree)[pos].rchild, begin, best->pos,
                               used_properties, new_sp_range});
      new_sp_range = static_prop_range;
      if (p < kNumStaticProperties) {
        JXL_ASSERT(new_sp_range[p][0] <= static_cast<uint32_t>(dequant + 1));
        new_sp_range[p][0] = dequant + 1;
        JXL_ASSERT(new_sp_range[p][0] < new_sp_range[p][1]);
      }
      nodes.push_back(NodeInfo{(*tree)[pos].lchild, best->pos, end,
                               used_properties, new_sp_range});
    }
  }
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace jxl
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace jxl {

HWY_EXPORT(FindBestSplit);  // Local function.

void ComputeBestTree(TreeSamples &tree_samples, float threshold,
                     const std::vector<ModularMultiplierInfo> &mul_info,
                     StaticPropRange static_prop_range,
                     float fast_decode_multiplier, Tree *tree) {
  // TODO(veluca): take into account that different contexts can have different
  // uint configs.
  //
  // Initialize tree.
  tree->emplace_back();
  tree->back().property = -1;
  tree->back().predictor = tree_samples.PredictorFromIndex(0);
  tree->back().predictor_offset = 0;
  tree->back().multiplier = 1;
  JXL_ASSERT(tree_samples.NumProperties() < 64);

  JXL_ASSERT(tree_samples.NumDistinctSamples() <=
             std::numeric_limits<uint32_t>::max());
  HWY_DYNAMIC_DISPATCH(FindBestSplit)
  (tree_samples, threshold, mul_info, static_prop_range, fast_decode_multiplier,
   tree);
}

constexpr int TreeSamples::kPropertyRange;
constexpr uint32_t TreeSamples::kDedupEntryUnused;

Status TreeSamples::SetPredictor(Predictor predictor,
                                 ModularOptions::TreeMode wp_tree_mode) {
  if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) {
    predictors = {Predictor::Weighted};
    residuals.resize(1);
    return true;
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP &&
      predictor == Predictor::Weighted) {
    return JXL_FAILURE("Invalid predictor settings");
  }
  if (predictor == Predictor::Variable) {
    for (size_t i = 0; i < kNumModularPredictors; i++) {
      predictors.push_back(static_cast<Predictor>(i));
    }
    std::swap(predictors[0], predictors[static_cast<int>(Predictor::Weighted)]);
    std::swap(predictors[1], predictors[static_cast<int>(Predictor::Gradient)]);
  } else if (predictor == Predictor::Best) {
    predictors = {Predictor::Weighted, Predictor::Gradient};
  } else {
    predictors = {predictor};
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) {
    auto wp_it =
        std::find(predictors.begin(), predictors.end(), Predictor::Weighted);
    if (wp_it != predictors.end()) {
      predictors.erase(wp_it);
    }
  }
  residuals.resize(predictors.size());
  return true;
}

Status TreeSamples::SetProperties(const std::vector<uint32_t> &properties,
                                  ModularOptions::TreeMode wp_tree_mode) {
  props_to_use = properties;
  if (wp_tree_mode == ModularOptions::TreeMode::kWPOnly) {
    props_to_use = {static_cast<uint32_t>(kWPProp)};
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kGradientOnly) {
    props_to_use = {static_cast<uint32_t>(kGradientProp)};
  }
  if (wp_tree_mode == ModularOptions::TreeMode::kNoWP) {
    auto it = std::find(props_to_use.begin(), props_to_use.end(), kWPProp);
    if (it != props_to_use.end()) {
      props_to_use.erase(it);
    }
  }
  if (props_to_use.empty()) {
    return JXL_FAILURE("Invalid property set configuration");
  }
  props.resize(props_to_use.size());
  return true;
}

void TreeSamples::InitTable(size_t size) {
  JXL_DASSERT((size & (size - 1)) == 0);
  if (dedup_table_.size() == size) return;
  dedup_table_.resize(size, kDedupEntryUnused);
  for (size_t i = 0; i < NumDistinctSamples(); i++) {
    if (sample_counts[i] != std::numeric_limits<uint16_t>::max()) {
      AddToTable(i);
    }
  }
}
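// The table size must be a power of two, as asserted above: Hash1 and Hash2
// reduce their 64-bit hash with `& (dedup_table_.size() - 1)`, which is only
// a valid modulus for power-of-two sizes.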

bool TreeSamples::AddToTableAndMerge(size_t a) {
  size_t pos1 = Hash1(a);
  size_t pos2 = Hash2(a);
  if (dedup_table_[pos1] != kDedupEntryUnused &&
      IsSameSample(a, dedup_table_[pos1])) {
    JXL_DASSERT(sample_counts[a] == 1);
    sample_counts[dedup_table_[pos1]]++;
    // Remove from hash table samples that are saturated.
    if (sample_counts[dedup_table_[pos1]] ==
        std::numeric_limits<uint16_t>::max()) {
      dedup_table_[pos1] = kDedupEntryUnused;
    }
    return true;
  }
  if (dedup_table_[pos2] != kDedupEntryUnused &&
      IsSameSample(a, dedup_table_[pos2])) {
    JXL_DASSERT(sample_counts[a] == 1);
    sample_counts[dedup_table_[pos2]]++;
    // Remove from hash table samples that are saturated.
    if (sample_counts[dedup_table_[pos2]] ==
        std::numeric_limits<uint16_t>::max()) {
      dedup_table_[pos2] = kDedupEntryUnused;
    }
    return true;
  }
  AddToTable(a);
  return false;
}

void TreeSamples::AddToTable(size_t a) {
  size_t pos1 = Hash1(a);
  size_t pos2 = Hash2(a);
  if (dedup_table_[pos1] == kDedupEntryUnused) {
    dedup_table_[pos1] = a;
  } else if (dedup_table_[pos2] == kDedupEntryUnused) {
    dedup_table_[pos2] = a;
  }
}
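// A sample can only ever live in one of its two candidate buckets (Hash1 or
// Hash2). If both are already occupied, the sample is simply not inserted:
// deduplication is best-effort, and a missed merge only costs some redundant
// samples, not correctness.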

void TreeSamples::PrepareForSamples(size_t num_samples) {
  for (auto &res : residuals) {
    res.reserve(res.size() + num_samples);
  }
  for (auto &p : props) {
    p.reserve(p.size() + num_samples);
  }
  size_t total_num_samples = num_samples + sample_counts.size();
  size_t next_pow2 = 1LLU << CeilLog2Nonzero(total_num_samples * 3 / 2);
  InitTable(next_pow2);
}
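// Sizing the table to the next power of two >= 1.5x the total sample count
// keeps the load factor at roughly 2/3 or below, so most samples find one of
// their two candidate buckets free.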

size_t TreeSamples::Hash1(size_t a) const {
  constexpr uint64_t constant = 0x1e35a7bd;
  uint64_t h = constant;
  for (const auto &r : residuals) {
    h = h * constant + r[a].tok;
    h = h * constant + r[a].nbits;
  }
  for (const auto &p : props) {
    h = h * constant + p[a];
  }
  return (h >> 16) & (dedup_table_.size() - 1);
}
size_t TreeSamples::Hash2(size_t a) const {
  constexpr uint64_t constant = 0x1e35a7bd1e35a7bd;
  uint64_t h = constant;
  for (const auto &p : props) {
    h = h * constant ^ p[a];
  }
  for (const auto &r : residuals) {
    h = h * constant ^ r[a].tok;
    h = h * constant ^ r[a].nbits;
  }
  return (h >> 16) & (dedup_table_.size() - 1);
}

bool TreeSamples::IsSameSample(size_t a, size_t b) const {
  bool ret = true;
  for (const auto &r : residuals) {
    if (r[a].tok != r[b].tok) {
      ret = false;
    }
    if (r[a].nbits != r[b].nbits) {
      ret = false;
    }
  }
  for (const auto &p : props) {
    if (p[a] != p[b]) {
      ret = false;
    }
  }
  return ret;
}

void TreeSamples::AddSample(pixel_type_w pixel, const Properties &properties,
                            const pixel_type_w *predictions) {
  for (size_t i = 0; i < predictors.size(); i++) {
    pixel_type v = pixel - predictions[static_cast<int>(predictors[i])];
    uint32_t tok, nbits, bits;
    HybridUintConfig(4, 1, 2).Encode(PackSigned(v), &tok, &nbits, &bits);
    JXL_DASSERT(tok < 256);
    JXL_DASSERT(nbits < 256);
    residuals[i].emplace_back(
        ResidualToken{static_cast<uint8_t>(tok), static_cast<uint8_t>(nbits)});
  }
  for (size_t i = 0; i < props_to_use.size(); i++) {
    props[i].push_back(QuantizeProperty(i, properties[props_to_use[i]]));
  }
  sample_counts.push_back(1);
  num_samples++;
  if (AddToTableAndMerge(sample_counts.size() - 1)) {
    for (auto &r : residuals) r.pop_back();
    for (auto &p : props) p.pop_back();
    sample_counts.pop_back();
  }
}
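// AddSample appends optimistically: if AddToTableAndMerge finds an identical
// earlier sample, it bumps that sample's count and the freshly appended copy
// is popped right back off.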

void TreeSamples::Swap(size_t a, size_t b) {
  if (a == b) return;
  for (auto &r : residuals) {
    std::swap(r[a], r[b]);
  }
  for (auto &p : props) {
    std::swap(p[a], p[b]);
  }
  std::swap(sample_counts[a], sample_counts[b]);
}

void TreeSamples::ThreeShuffle(size_t a, size_t b, size_t c) {
  if (b == c) return Swap(a, b);
  for (auto &r : residuals) {
    auto tmp = r[a];
    r[a] = r[c];
    r[c] = r[b];
    r[b] = tmp;
  }
  for (auto &p : props) {
    auto tmp = p[a];
    p[a] = p[c];
    p[c] = p[b];
    p[b] = tmp;
  }
  auto tmp = sample_counts[a];
  sample_counts[a] = sample_counts[c];
  sample_counts[c] = sample_counts[b];
  sample_counts[b] = tmp;
}
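// ThreeShuffle rotates three elements (a, b, c receive the old values of
// c, a, b respectively). This is exactly the move the partition loop in
// SplitTreeSamples needs: the new element lands in front of the pivot run,
// the run shifts right by one, and the displaced element goes past it.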

namespace {
std::vector<int> QuantizeHistogram(const std::vector<uint32_t> &histogram,
                                   size_t num_chunks) {
  if (histogram.empty()) return {};
  // TODO(veluca): selecting distinct quantiles is likely not the best
  // way to go about this.
  std::vector<int> thresholds;
  size_t sum = std::accumulate(histogram.begin(), histogram.end(), 0LU);
  size_t cumsum = 0;
  size_t threshold = 0;
  for (size_t i = 0; i + 1 < histogram.size(); i++) {
    cumsum += histogram[i];
    if (cumsum > (threshold + 1) * sum / num_chunks) {
      thresholds.push_back(i);
      while (cumsum >= (threshold + 1) * sum / num_chunks) threshold++;
    }
  }
  return thresholds;
}
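// Example (illustrative numbers): QuantizeHistogram({5, 1, 1, 1}, 2) has
// sum = 8, and the cumulative count first exceeds sum / 2 at bin 0, so it
// returns {0}: a single threshold that separates the dominant value 0 from
// everything above it.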

std::vector<int> QuantizeSamples(const std::vector<int32_t> &samples,
                                 size_t num_chunks) {
  if (samples.empty()) return {};
  int min = *std::min_element(samples.begin(), samples.end());
  constexpr int kRange = 512;
  min = std::min(std::max(min, -kRange), kRange);
  std::vector<uint32_t> counts(2 * kRange + 1);
  for (int s : samples) {
    uint32_t sample_offset = std::min(std::max(s, -kRange), kRange) - min;
    counts[sample_offset]++;
  }
  std::vector<int> thresholds = QuantizeHistogram(counts, num_chunks);
  for (auto &v : thresholds) v += min;
  return thresholds;
}
}  // namespace

void TreeSamples::PreQuantizeProperties(
    const StaticPropRange &range,
    const std::vector<ModularMultiplierInfo> &multiplier_info,
    const std::vector<uint32_t> &group_pixel_count,
    const std::vector<uint32_t> &channel_pixel_count,
    std::vector<pixel_type> &pixel_samples,
    std::vector<pixel_type> &diff_samples, size_t max_property_values) {
  // If we have forced splits because of multipliers, choose channel and group
  // thresholds accordingly.
  std::vector<int32_t> group_multiplier_thresholds;
  std::vector<int32_t> channel_multiplier_thresholds;
  for (const auto &v : multiplier_info) {
    if (v.range[0][0] != range[0][0]) {
      channel_multiplier_thresholds.push_back(v.range[0][0] - 1);
    }
    if (v.range[0][1] != range[0][1]) {
      channel_multiplier_thresholds.push_back(v.range[0][1] - 1);
    }
    if (v.range[1][0] != range[1][0]) {
      group_multiplier_thresholds.push_back(v.range[1][0] - 1);
    }
    if (v.range[1][1] != range[1][1]) {
      group_multiplier_thresholds.push_back(v.range[1][1] - 1);
    }
  }
  std::sort(channel_multiplier_thresholds.begin(),
            channel_multiplier_thresholds.end());
  channel_multiplier_thresholds.resize(
      std::unique(channel_multiplier_thresholds.begin(),
                  channel_multiplier_thresholds.end()) -
      channel_multiplier_thresholds.begin());
  std::sort(group_multiplier_thresholds.begin(),
            group_multiplier_thresholds.end());
  group_multiplier_thresholds.resize(
      std::unique(group_multiplier_thresholds.begin(),
                  group_multiplier_thresholds.end()) -
      group_multiplier_thresholds.begin());

  compact_properties.resize(props_to_use.size());
  auto quantize_channel = [&]() {
    if (!channel_multiplier_thresholds.empty()) {
      return channel_multiplier_thresholds;
    }
    return QuantizeHistogram(channel_pixel_count, max_property_values);
  };
  auto quantize_group_id = [&]() {
    if (!group_multiplier_thresholds.empty()) {
      return group_multiplier_thresholds;
    }
    return QuantizeHistogram(group_pixel_count, max_property_values);
  };
  auto quantize_coordinate = [&]() {
    std::vector<int> quantized;
    quantized.reserve(max_property_values - 1);
    for (size_t i = 0; i + 1 < max_property_values; i++) {
      quantized.push_back((i + 1) * 256 / max_property_values - 1);
    }
    return quantized;
  };
  std::vector<int> abs_pixel_thr;
  std::vector<int> pixel_thr;
  auto quantize_pixel_property = [&]() {
    if (pixel_thr.empty()) {
      pixel_thr = QuantizeSamples(pixel_samples, max_property_values);
    }
    return pixel_thr;
  };
  auto quantize_abs_pixel_property = [&]() {
    if (abs_pixel_thr.empty()) {
      quantize_pixel_property();  // Compute the non-abs thresholds.
      for (auto &v : pixel_samples) v = std::abs(v);
      abs_pixel_thr = QuantizeSamples(pixel_samples, max_property_values);
    }
    return abs_pixel_thr;
  };
  std::vector<int> abs_diff_thr;
  std::vector<int> diff_thr;
  auto quantize_diff_property = [&]() {
    if (diff_thr.empty()) {
      diff_thr = QuantizeSamples(diff_samples, max_property_values);
    }
    return diff_thr;
  };
  auto quantize_abs_diff_property = [&]() {
    if (abs_diff_thr.empty()) {
      quantize_diff_property();  // Compute the non-abs thresholds.
      for (auto &v : diff_samples) v = std::abs(v);
      abs_diff_thr = QuantizeSamples(diff_samples, max_property_values);
    }
    return abs_diff_thr;
  };
  auto quantize_wp = [&]() {
    if (max_property_values < 32) {
      return std::vector<int>{-127, -63, -31, -15, -7, -3, -1, 0,
                              1,    3,   7,   15,  31, 63, 127};
    }
    if (max_property_values < 64) {
      return std::vector<int>{-255, -191, -127, -95, -63, -47, -31, -23,
                              -15,  -11,  -7,   -5,  -3,  -1,  0,   1,
                              3,    5,    7,    11,  15,  23,  31,  47,
                              63,   95,   127,  191, 255};
    }
    return std::vector<int>{
        -255, -223, -191, -159, -127, -111, -95, -79, -63, -55, -47,
        -39,  -31,  -27,  -23,  -19,  -15,  -13, -11, -9,  -7,  -6,
        -5,   -4,   -3,   -2,   -1,   0,    1,   2,   3,   4,   5,
        6,    7,    9,    11,   13,   15,   19,  23,  27,  31,  39,
        47,   55,   63,   79,   95,   111,  127, 159, 191, 223, 255};
  };

  property_mapping.resize(props_to_use.size());
  for (size_t i = 0; i < props_to_use.size(); i++) {
    if (props_to_use[i] == 0) {
      compact_properties[i] = quantize_channel();
    } else if (props_to_use[i] == 1) {
      compact_properties[i] = quantize_group_id();
    } else if (props_to_use[i] == 2 || props_to_use[i] == 3) {
      compact_properties[i] = quantize_coordinate();
    } else if (props_to_use[i] == 6 || props_to_use[i] == 7 ||
               props_to_use[i] == 8 ||
               (props_to_use[i] >= kNumNonrefProperties &&
                (props_to_use[i] - kNumNonrefProperties) % 4 == 1)) {
      compact_properties[i] = quantize_pixel_property();
    } else if (props_to_use[i] == 4 || props_to_use[i] == 5 ||
               (props_to_use[i] >= kNumNonrefProperties &&
                (props_to_use[i] - kNumNonrefProperties) % 4 == 0)) {
      compact_properties[i] = quantize_abs_pixel_property();
    } else if (props_to_use[i] >= kNumNonrefProperties &&
               (props_to_use[i] - kNumNonrefProperties) % 4 == 2) {
      compact_properties[i] = quantize_abs_diff_property();
    } else if (props_to_use[i] == kWPProp) {
      compact_properties[i] = quantize_wp();
    } else {
      compact_properties[i] = quantize_diff_property();
    }
    property_mapping[i].resize(kPropertyRange * 2 + 1);
    size_t mapped = 0;
    for (size_t j = 0; j < property_mapping[i].size(); j++) {
      while (mapped < compact_properties[i].size() &&
             static_cast<int>(j) - kPropertyRange >
                 compact_properties[i][mapped]) {
        mapped++;
      }
      // property_mapping[i][j] is the number of thresholds in
      // compact_properties[i] that are strictly below j - kPropertyRange.
      // All values in the interval (compact_properties[i][mapped - 1],
      // compact_properties[i][mapped]] thus share a bucket: the decision
      // nodes in the tree split on (property) > threshold, so everything
      // that is not greater than a given threshold should be clustered
      // together.
      property_mapping[i][j] = mapped;
    }
  }
}

void CollectPixelSamples(const Image &image, const ModularOptions &options,
                         size_t group_id,
                         std::vector<uint32_t> &group_pixel_count,
                         std::vector<uint32_t> &channel_pixel_count,
                         std::vector<pixel_type> &pixel_samples,
                         std::vector<pixel_type> &diff_samples) {
  if (options.nb_repeats == 0) return;
  if (group_pixel_count.size() <= group_id) {
    group_pixel_count.resize(group_id + 1);
  }
  if (channel_pixel_count.size() < image.channel.size()) {
    channel_pixel_count.resize(image.channel.size());
  }
  Rng rng(group_id);
  // Sample 10% of the final number of samples for property quantization.
  float fraction = options.nb_repeats * 0.1;
  std::geometric_distribution<uint32_t> dist(fraction);
  size_t total_pixels = 0;
  std::vector<size_t> channel_ids;
  for (size_t i = 0; i < image.channel.size(); i++) {
    if (image.channel[i].w <= 1 || image.channel[i].h == 0) {
      continue;  // skip empty or width-1 channels.
    }
    if (i >= image.nb_meta_channels &&
        (image.channel[i].w > options.max_chan_size ||
         image.channel[i].h > options.max_chan_size)) {
      break;
    }
    channel_ids.push_back(i);
    group_pixel_count[group_id] += image.channel[i].w * image.channel[i].h;
    channel_pixel_count[i] += image.channel[i].w * image.channel[i].h;
    total_pixels += image.channel[i].w * image.channel[i].h;
  }
  if (channel_ids.empty()) return;
  pixel_samples.reserve(pixel_samples.size() + fraction * total_pixels);
  diff_samples.reserve(diff_samples.size() + fraction * total_pixels);
  size_t i = 0;
  size_t y = 0;
  size_t x = 0;
  auto advance = [&](size_t amount) {
    x += amount;
    // Detect row overflow (rare).
    while (x >= image.channel[channel_ids[i]].w) {
      x -= image.channel[channel_ids[i]].w;
      y++;
      // Detect end-of-channel (even rarer).
      if (y == image.channel[channel_ids[i]].h) {
        i++;
        y = 0;
        if (i >= channel_ids.size()) {
          return;
        }
      }
    }
  };
  advance(dist(rng));
  for (; i < channel_ids.size(); advance(dist(rng) + 1)) {
    const pixel_type *row = image.channel[channel_ids[i]].Row(y);
    pixel_samples.push_back(row[x]);
    size_t xp = x == 0 ? 1 : x - 1;
    diff_samples.push_back(row[x] - row[xp]);
  }
}
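// Advancing by a geometric_distribution(fraction) gap between samples picks
// each pixel with probability roughly `fraction` (the expected stride is
// about 1 / fraction), so the sampling rate named above is reached without a
// per-pixel random draw.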

// TODO(veluca): very simple encoding scheme. This should be improved.
void TokenizeTree(const Tree &tree, std::vector<Token> *tokens,
                  Tree *decoder_tree) {
  JXL_ASSERT(tree.size() <= kMaxTreeSize);
  std::queue<int> q;
  q.push(0);
  size_t leaf_id = 0;
  decoder_tree->clear();
  while (!q.empty()) {
    int cur = q.front();
    q.pop();
    JXL_ASSERT(tree[cur].property >= -1);
    tokens->emplace_back(kPropertyContext, tree[cur].property + 1);
    if (tree[cur].property == -1) {
      tokens->emplace_back(kPredictorContext,
                           static_cast<int>(tree[cur].predictor));
      tokens->emplace_back(kOffsetContext,
                           PackSigned(tree[cur].predictor_offset));
      uint32_t mul_log = Num0BitsBelowLS1Bit_Nonzero(tree[cur].multiplier);
      uint32_t mul_bits = (tree[cur].multiplier >> mul_log) - 1;
      tokens->emplace_back(kMultiplierLogContext, mul_log);
      tokens->emplace_back(kMultiplierBitsContext, mul_bits);
      JXL_ASSERT(tree[cur].predictor < Predictor::Best);
      decoder_tree->emplace_back(-1, 0, leaf_id++, 0, tree[cur].predictor,
                                 tree[cur].predictor_offset,
                                 tree[cur].multiplier);
      continue;
    }
    decoder_tree->emplace_back(tree[cur].property, tree[cur].splitval,
                               decoder_tree->size() + q.size() + 1,
                               decoder_tree->size() + q.size() + 2,
                               Predictor::Zero, 0, 1);
    q.push(tree[cur].lchild);
    q.push(tree[cur].rchild);
    tokens->emplace_back(kSplitValContext, PackSigned(tree[cur].splitval));
  }
}
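// Because nodes are emitted in breadth-first order, the children of the node
// being written always end up at decoder_tree positions size() + q.size() + 1
// and size() + q.size() + 2, so child indices never have to be stored in the
// bitstream.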

}  // namespace jxl
#endif  // HWY_ONCE