1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/modular/encoding/encoding.h"
7 
8 #include <stdint.h>
9 #include <stdlib.h>
10 
11 #include <queue>
12 
13 #include "lib/jxl/modular/encoding/context_predict.h"
14 #include "lib/jxl/modular/options.h"
15 
16 namespace jxl {
17 
18 // Removes all nodes that use a static property (i.e. channel or group ID) from
19 // the tree and collapses each node on even levels with its two children to
20 // produce a flatter tree. Also computes whether the resulting tree requires
21 // using the weighted predictor.
FilterTree(const Tree & global_tree,std::array<pixel_type,kNumStaticProperties> & static_props,size_t * num_props,bool * use_wp,bool * wp_only,bool * gradient_only)22 FlatTree FilterTree(const Tree &global_tree,
23                     std::array<pixel_type, kNumStaticProperties> &static_props,
24                     size_t *num_props, bool *use_wp, bool *wp_only,
25                     bool *gradient_only) {
26   *num_props = 0;
27   bool has_wp = false;
28   bool has_non_wp = false;
29   *gradient_only = true;
30   const auto mark_property = [&](int32_t p) {
31     if (p == kWPProp) {
32       has_wp = true;
33     } else if (p >= kNumStaticProperties) {
34       has_non_wp = true;
35     }
36     if (p >= kNumStaticProperties && p != kGradientProp) {
37       *gradient_only = false;
38     }
39   };
40   FlatTree output;
41   std::queue<size_t> nodes;
42   nodes.push(0);
43   // Produces a trimmed and flattened tree by doing a BFS visit of the original
44   // tree, ignoring branches that are known to be false and proceeding two
45   // levels at a time to collapse nodes in a flatter tree; if an inner parent
46   // node has a leaf as a child, the leaf is duplicated and an implicit fake
47   // node is added. This allows to reduce the number of branches when traversing
48   // the resulting flat tree.
49   while (!nodes.empty()) {
50     size_t cur = nodes.front();
51     nodes.pop();
52     // Skip nodes that we can decide now, by jumping directly to their children.
53     while (global_tree[cur].property < kNumStaticProperties &&
54            global_tree[cur].property != -1) {
55       if (static_props[global_tree[cur].property] > global_tree[cur].splitval) {
56         cur = global_tree[cur].lchild;
57       } else {
58         cur = global_tree[cur].rchild;
59       }
60     }
61     FlatDecisionNode flat;
62     if (global_tree[cur].property == -1) {
63       flat.property0 = -1;
64       flat.childID = global_tree[cur].lchild;
65       flat.predictor = global_tree[cur].predictor;
66       flat.predictor_offset = global_tree[cur].predictor_offset;
67       flat.multiplier = global_tree[cur].multiplier;
68       *gradient_only &= flat.predictor == Predictor::Gradient;
69       has_wp |= flat.predictor == Predictor::Weighted;
70       has_non_wp |= flat.predictor != Predictor::Weighted;
71       output.push_back(flat);
72       continue;
73     }
74     flat.childID = output.size() + nodes.size() + 1;
75 
76     flat.property0 = global_tree[cur].property;
77     *num_props = std::max<size_t>(flat.property0 + 1, *num_props);
78     flat.splitval0 = global_tree[cur].splitval;
79 
80     for (size_t i = 0; i < 2; i++) {
81       size_t cur_child =
82           i == 0 ? global_tree[cur].lchild : global_tree[cur].rchild;
83       // Skip nodes that we can decide now.
84       while (global_tree[cur_child].property < kNumStaticProperties &&
85              global_tree[cur_child].property != -1) {
86         if (static_props[global_tree[cur_child].property] >
87             global_tree[cur_child].splitval) {
88           cur_child = global_tree[cur_child].lchild;
89         } else {
90           cur_child = global_tree[cur_child].rchild;
91         }
92       }
93       // We ended up in a leaf, add a dummy decision and two copies of the leaf.
94       if (global_tree[cur_child].property == -1) {
95         flat.properties[i] = 0;
96         flat.splitvals[i] = 0;
97         nodes.push(cur_child);
98         nodes.push(cur_child);
99       } else {
100         flat.properties[i] = global_tree[cur_child].property;
101         flat.splitvals[i] = global_tree[cur_child].splitval;
102         nodes.push(global_tree[cur_child].lchild);
103         nodes.push(global_tree[cur_child].rchild);
104         *num_props = std::max<size_t>(flat.properties[i] + 1, *num_props);
105       }
106     }
107 
108     for (size_t j = 0; j < 2; j++) mark_property(flat.properties[j]);
109     mark_property(flat.property0);
110     output.push_back(flat);
111   }
112   if (*num_props > kNumNonrefProperties) {
113     *num_props =
114         DivCeil(*num_props - kNumNonrefProperties, kExtraPropsPerChannel) *
115             kExtraPropsPerChannel +
116         kNumNonrefProperties;
117   } else {
118     *num_props = kNumNonrefProperties;
119   }
120   *use_wp = has_wp;
121   *wp_only = has_wp && !has_non_wp;
122 
123   return output;
124 }
125 
// Decodes, in place, every pixel of channel `chan` of `image`.
//
// The MA tree `global_tree` is first specialized for this channel index and
// `group_id` with FilterTree. Leaf context indices are then remapped through
// `context_map`, and symbols are read from `reader` / `br`. Depending on the
// shape of the filtered tree, one of several specialized loops ("tracks") is
// used. Every track builds each pixel with make_pixel():
//   UnpackSigned(symbol) * multiplier + offset
// where the offset argument carries the prediction.
//
// This function returns true on every path visible here. Truncation is
// detected by the caller through br->AllReadsWithinBounds().
Status DecodeModularChannelMAANS(BitReader *br, ANSSymbolReader *reader,
                                 const std::vector<uint8_t> &context_map,
                                 const Tree &global_tree,
                                 const weighted::Header &wp_header,
                                 pixel_type chan, size_t group_id,
                                 Image *image) {
  Channel &channel = image->channel[chan];

  // Static property values: channel index and group ID. FilterTree uses them
  // to resolve every static-property decision of the tree up front.
  std::array<pixel_type, kNumStaticProperties> static_props = {
      {chan, (int)group_id}};

  // zero pixel channel? could happen
  if (channel.w == 0 || channel.h == 0) return true;

  bool tree_has_wp_prop_or_pred = false;
  bool is_wp_only = false;
  bool is_gradient_only = false;
  size_t num_props;
  FlatTree tree =
      FilterTree(global_tree, static_props, &num_props,
                 &tree_has_wp_prop_or_pred, &is_wp_only, &is_gradient_only);

  // From here on, tree lookup returns a *clustered* context ID.
  // This avoids an extra memory lookup after tree traversal.
  for (size_t i = 0; i < tree.size(); i++) {
    if (tree[i].property0 == -1) {
      tree[i].childID = context_map[tree[i].childID];
    }
  }

  JXL_DEBUG_V(3, "Decoded MA tree with %zu nodes", tree.size());

  // MAANS decode
  // Turns a decoded symbol into a pixel. The symbol is expected to fit in 32
  // bits (debug-asserted). The result is computed in pixel_type_w and
  // truncated to pixel_type on return.
  const auto make_pixel = [](uint64_t v, pixel_type multiplier,
                             pixel_type_w offset) -> pixel_type {
    JXL_DASSERT((v & 0xFFFFFFFF) == v);
    pixel_type_w val = UnpackSigned(v);
    // if it overflows, it overflows, and we have a problem anyway
    return val * multiplier + offset;
  };

  if (tree.size() == 1) {
    // special optimized case: no meta-adaptation, so no need
    // to compute properties.
    // The single flat node is a leaf, so every pixel shares one predictor,
    // offset, multiplier and (clustered) context.
    Predictor predictor = tree[0].predictor;
    int64_t offset = tree[0].predictor_offset;
    int32_t multiplier = tree[0].multiplier;
    size_t ctx_id = tree[0].childID;
    if (predictor == Predictor::Zero) {
      uint32_t value;
      if (reader->IsSingleValueAndAdvance(ctx_id, &value,
                                          channel.w * channel.h)) {
        // Special-case: histogram has a single symbol, with no extra bits, and
        // we use ANS mode.
        // The whole channel is therefore one constant value.
        JXL_DEBUG_V(8, "Fastest track.");
        pixel_type v = make_pixel(value, multiplier, offset);
        for (size_t y = 0; y < channel.h; y++) {
          pixel_type *JXL_RESTRICT r = channel.Row(y);
          std::fill(r, r + channel.w, v);
        }

      } else {
        // Zero predictor: each pixel depends only on its own symbol.
        JXL_DEBUG_V(8, "Fast track.");
        for (size_t y = 0; y < channel.h; y++) {
          pixel_type *JXL_RESTRICT r = channel.Row(y);
          for (size_t x = 0; x < channel.w; x++) {
            uint32_t v = reader->ReadHybridUintClustered(ctx_id, br);
            r[x] = make_pixel(v, multiplier, offset);
          }
        }
      }
    } else if (predictor == Predictor::Gradient && offset == 0 &&
               multiplier == 1) {
      // Plain clamped-gradient prediction, computed inline. `onerow` is the
      // row stride in pixels, so `r + x - onerow` is the pixel above.
      // Borders: left falls back to top (0 at the origin), top falls back to
      // left, and topleft falls back to left.
      JXL_DEBUG_V(8, "Gradient very fast track.");
      const intptr_t onerow = channel.plane.PixelsPerRow();
      for (size_t y = 0; y < channel.h; y++) {
        pixel_type *JXL_RESTRICT r = channel.Row(y);
        for (size_t x = 0; x < channel.w; x++) {
          pixel_type left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
          pixel_type top = (y ? *(r + x - onerow) : left);
          pixel_type topleft = (x && y ? *(r + x - 1 - onerow) : left);
          pixel_type guess = ClampedGradient(top, left, topleft);
          uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
          r[x] = make_pixel(v, 1, guess);
        }
      }
    } else if (predictor != Predictor::Weighted) {
      // special optimized case: no wp
      // Any other non-weighted predictor goes through PredictNoTreeNoWP.
      JXL_DEBUG_V(8, "Quite fast track.");
      const intptr_t onerow = channel.plane.PixelsPerRow();
      for (size_t y = 0; y < channel.h; y++) {
        pixel_type *JXL_RESTRICT r = channel.Row(y);
        for (size_t x = 0; x < channel.w; x++) {
          PredictionResult pred =
              PredictNoTreeNoWP(channel.w, r + x, onerow, x, y, predictor);
          pixel_type_w g = pred.guess + offset;
          uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
          // NOTE: pred.multiplier is unset.
          r[x] = make_pixel(v, multiplier, g);
        }
      }
    } else {
      // Weighted predictor: its state must see every decoded pixel (via
      // UpdateErrors) before the next prediction.
      JXL_DEBUG_V(8, "Somewhat fast track.");
      const intptr_t onerow = channel.plane.PixelsPerRow();
      weighted::State wp_state(wp_header, channel.w, channel.h);
      for (size_t y = 0; y < channel.h; y++) {
        pixel_type *JXL_RESTRICT r = channel.Row(y);
        for (size_t x = 0; x < channel.w; x++) {
          pixel_type_w g = PredictNoTreeWP(channel.w, r + x, onerow, x, y,
                                           predictor, &wp_state)
                               .guess +
                           offset;
          uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
          r[x] = make_pixel(v, multiplier, g);
          wp_state.UpdateErrors(r[x], x, y, channel.w);
        }
      }
    }
    return true;
  }

  // Check if this tree is a WP-only tree with a small enough property value
  // range.
  // The tables are indexed by kPropRangeFast + clamp(property,
  // -kPropRangeFast, kPropRangeFast - 1). TreeToLookupTable's return value
  // replaces the corresponding flag, so a fast track is only used when the
  // tree could be converted (presumably this requires a single property and
  // splits inside that range — confirm in TreeToLookupTable).
  // Initialized to avoid clang-tidy complaining.
  uint8_t context_lookup[2 * kPropRangeFast] = {};
  int8_t multipliers[2 * kPropRangeFast] = {};
  int8_t offsets[2 * kPropRangeFast] = {};
  if (is_wp_only) {
    is_wp_only = TreeToLookupTable(tree, context_lookup, offsets, multipliers);
  }
  if (is_gradient_only) {
    is_gradient_only =
        TreeToLookupTable(tree, context_lookup, offsets, multipliers);
  }

  if (is_gradient_only) {
    // Gradient-only tree: the decision property is top + left - topleft and
    // the predictor is the clamped gradient, so the tree reduces to the tables.
    JXL_DEBUG_V(8, "Gradient fast track.");
    const intptr_t onerow = channel.plane.PixelsPerRow();
    for (size_t y = 0; y < channel.h; y++) {
      pixel_type *JXL_RESTRICT r = channel.Row(y);
      for (size_t x = 0; x < channel.w; x++) {
        pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
        pixel_type_w top = (y ? *(r + x - onerow) : left);
        pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
        int32_t guess = ClampedGradient(top, left, topleft);
        uint32_t pos =
            kPropRangeFast +
            std::min<pixel_type_w>(
                std::max<pixel_type_w>(-kPropRangeFast, top + left - topleft),
                kPropRangeFast - 1);
        uint32_t ctx_id = context_lookup[pos];
        uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
        r[x] = make_pixel(v, multipliers[pos],
                          static_cast<pixel_type_w>(offsets[pos]) + guess);
      }
    }
  } else if (is_wp_only) {
    // WP-only tree: the weighted predictor yields both the guess and the single
    // property (properties[0]) used to index the tables.
    JXL_DEBUG_V(8, "WP fast track.");
    const intptr_t onerow = channel.plane.PixelsPerRow();
    weighted::State wp_state(wp_header, channel.w, channel.h);
    Properties properties(1);
    for (size_t y = 0; y < channel.h; y++) {
      pixel_type *JXL_RESTRICT r = channel.Row(y);
      for (size_t x = 0; x < channel.w; x++) {
        // Offset into `properties`, which has a single slot here.
        size_t offset = 0;
        pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
        pixel_type_w top = (y ? *(r + x - onerow) : left);
        pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
        pixel_type_w topright =
            (x + 1 < channel.w && y ? *(r + x + 1 - onerow) : top);
        pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top);
        int32_t guess = wp_state.Predict</*compute_properties=*/true>(
            x, y, channel.w, top, left, topright, topleft, toptop, &properties,
            offset);
        uint32_t pos =
            kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]),
                                      kPropRangeFast - 1);
        uint32_t ctx_id = context_lookup[pos];
        uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
        r[x] = make_pixel(v, multipliers[pos],
                          static_cast<pixel_type_w>(offsets[pos]) + guess);
        wp_state.UpdateErrors(r[x], x, y, channel.w);
      }
    }
  } else if (!tree_has_wp_prop_or_pred) {
    // special optimized case: the weighted predictor and its properties are not
    // used, so no need to compute weights and properties.
    // General tree walk (MATreeLookup) with properties and reference-channel
    // values computed per row and per pixel.
    JXL_DEBUG_V(8, "Slow track.");
    MATreeLookup tree_lookup(tree);
    Properties properties = Properties(num_props);
    const intptr_t onerow = channel.plane.PixelsPerRow();
    Channel references(properties.size() - kNumNonrefProperties, channel.w);
    for (size_t y = 0; y < channel.h; y++) {
      pixel_type *JXL_RESTRICT p = channel.Row(y);
      PrecomputeReferences(channel, y, *image, chan, &references);
      InitPropsRow(&properties, static_props, y);
      for (size_t x = 0; x < channel.w; x++) {
        PredictionResult res =
            PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y,
                            tree_lookup, references);
        uint64_t v = reader->ReadHybridUintClustered(res.context, br);
        p[x] = make_pixel(v, res.multiplier, res.guess);
      }
    }
  } else {
    // Fully general case: tree walk plus weighted-predictor state, which is
    // updated after every decoded pixel.
    JXL_DEBUG_V(8, "Slowest track.");
    MATreeLookup tree_lookup(tree);
    Properties properties = Properties(num_props);
    const intptr_t onerow = channel.plane.PixelsPerRow();
    Channel references(properties.size() - kNumNonrefProperties, channel.w);
    weighted::State wp_state(wp_header, channel.w, channel.h);
    for (size_t y = 0; y < channel.h; y++) {
      pixel_type *JXL_RESTRICT p = channel.Row(y);
      InitPropsRow(&properties, static_props, y);
      PrecomputeReferences(channel, y, *image, chan, &references);
      for (size_t x = 0; x < channel.w; x++) {
        PredictionResult res =
            PredictTreeWP(&properties, channel.w, p + x, onerow, x, y,
                          tree_lookup, references, &wp_state);
        uint64_t v = reader->ReadHybridUintClustered(res.context, br);
        p[x] = make_pixel(v, res.multiplier, res.guess);
        wp_state.UpdateErrors(p[x], x, y, channel.w);
      }
    }
  }
  return true;
}
354 
GroupHeader()355 GroupHeader::GroupHeader() { Bundle::Init(this); }
356 
ValidateChannelDimensions(const Image & image,const ModularOptions & options)357 Status ValidateChannelDimensions(const Image &image,
358                                  const ModularOptions &options) {
359   size_t nb_channels = image.channel.size();
360   for (bool is_dc : {true, false}) {
361     size_t group_dim = options.group_dim * (is_dc ? kBlockDim : 1);
362     size_t c = image.nb_meta_channels;
363     for (; c < nb_channels; c++) {
364       const Channel &ch = image.channel[c];
365       if (ch.w > options.group_dim || ch.h > options.group_dim) break;
366     }
367     for (; c < nb_channels; c++) {
368       const Channel &ch = image.channel[c];
369       if (ch.w == 0 || ch.h == 0) continue;  // skip empty
370       bool is_dc_channel = std::min(ch.hshift, ch.vshift) >= 3;
371       if (is_dc_channel != is_dc) continue;
372       size_t tile_dim = group_dim >> std::max(ch.hshift, ch.vshift);
373       if (tile_dim == 0) {
374         return JXL_FAILURE("Inconsistent transforms");
375       }
376     }
377   }
378   return true;
379 }
380 
ModularDecode(BitReader * br,Image & image,GroupHeader & header,size_t group_id,ModularOptions * options,const Tree * global_tree,const ANSCode * global_code,const std::vector<uint8_t> * global_ctx_map,bool allow_truncated_group)381 Status ModularDecode(BitReader *br, Image &image, GroupHeader &header,
382                      size_t group_id, ModularOptions *options,
383                      const Tree *global_tree, const ANSCode *global_code,
384                      const std::vector<uint8_t> *global_ctx_map,
385                      bool allow_truncated_group) {
386   if (image.channel.empty()) return true;
387 
388   // decode transforms
389   JXL_RETURN_IF_ERROR(Bundle::Read(br, &header));
390   JXL_DEBUG_V(3, "Image data underwent %zu transformations: ",
391               header.transforms.size());
392   image.transform = header.transforms;
393   for (Transform &transform : image.transform) {
394     JXL_RETURN_IF_ERROR(transform.MetaApply(image));
395   }
396   if (image.error) {
397     return JXL_FAILURE("Corrupt file. Aborting.");
398   }
399   if (br->AllReadsWithinBounds()) {
400     // Only check if the transforms list is complete.
401     JXL_RETURN_IF_ERROR(ValidateChannelDimensions(image, *options));
402   }
403 
404   size_t nb_channels = image.channel.size();
405 
406   size_t num_chans = 0;
407   size_t distance_multiplier = 0;
408   for (size_t i = 0; i < nb_channels; i++) {
409     Channel &channel = image.channel[i];
410     if (!channel.w || !channel.h) {
411       continue;  // skip empty channels
412     }
413     if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size ||
414                                         channel.h > options->max_chan_size)) {
415       break;
416     }
417     if (channel.w > distance_multiplier) {
418       distance_multiplier = channel.w;
419     }
420     num_chans++;
421   }
422   if (num_chans == 0) return true;
423 
424   // Read tree.
425   Tree tree_storage;
426   std::vector<uint8_t> context_map_storage;
427   ANSCode code_storage;
428   const Tree *tree = &tree_storage;
429   const ANSCode *code = &code_storage;
430   const std::vector<uint8_t> *context_map = &context_map_storage;
431   if (!header.use_global_tree) {
432     size_t max_tree_size = 1024;
433     for (size_t i = 0; i < nb_channels; i++) {
434       Channel &channel = image.channel[i];
435       if (!channel.w || !channel.h) {
436         continue;  // skip empty channels
437       }
438       if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size ||
439                                           channel.h > options->max_chan_size)) {
440         break;
441       }
442       size_t pixels = channel.w * channel.h;
443       if (pixels / channel.w != channel.h) {
444         return JXL_FAILURE("Tree size overflow");
445       }
446       max_tree_size += pixels;
447       if (max_tree_size < pixels) return JXL_FAILURE("Tree size overflow");
448     }
449 
450     JXL_RETURN_IF_ERROR(DecodeTree(br, &tree_storage, max_tree_size));
451     JXL_RETURN_IF_ERROR(DecodeHistograms(br, (tree_storage.size() + 1) / 2,
452                                          &code_storage, &context_map_storage));
453   } else {
454     if (!global_tree || !global_code || !global_ctx_map ||
455         global_tree->empty()) {
456       return JXL_FAILURE("No global tree available but one was requested");
457     }
458     tree = global_tree;
459     code = global_code;
460     context_map = global_ctx_map;
461   }
462 
463   // Read channels
464   ANSSymbolReader reader(code, br, distance_multiplier);
465   for (size_t i = 0; i < nb_channels; i++) {
466     Channel &channel = image.channel[i];
467     if (!channel.w || !channel.h) {
468       continue;  // skip empty channels
469     }
470     if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size ||
471                                         channel.h > options->max_chan_size)) {
472       break;
473     }
474     JXL_RETURN_IF_ERROR(DecodeModularChannelMAANS(br, &reader, *context_map,
475                                                   *tree, header.wp_header, i,
476                                                   group_id, &image));
477     // Truncated group.
478     if (!br->AllReadsWithinBounds()) {
479       if (!allow_truncated_group) return JXL_FAILURE("Truncated input");
480       ZeroFillImage(&channel.plane);
481       while (++i < nb_channels) ZeroFillImage(&image.channel[i].plane);
482       return Status(StatusCode::kNotEnoughBytes);
483     }
484   }
485   if (!reader.CheckANSFinalState()) {
486     return JXL_FAILURE("ANS decode final state failed");
487   }
488   return true;
489 }
490 
ModularGenericDecompress(BitReader * br,Image & image,GroupHeader * header,size_t group_id,ModularOptions * options,int undo_transforms,const Tree * tree,const ANSCode * code,const std::vector<uint8_t> * ctx_map,bool allow_truncated_group)491 Status ModularGenericDecompress(BitReader *br, Image &image,
492                                 GroupHeader *header, size_t group_id,
493                                 ModularOptions *options, int undo_transforms,
494                                 const Tree *tree, const ANSCode *code,
495                                 const std::vector<uint8_t> *ctx_map,
496                                 bool allow_truncated_group) {
497 #ifdef JXL_ENABLE_ASSERT
498   std::vector<std::pair<uint32_t, uint32_t>> req_sizes(image.channel.size());
499   for (size_t c = 0; c < req_sizes.size(); c++) {
500     req_sizes[c] = {image.channel[c].w, image.channel[c].h};
501   }
502 #endif
503   GroupHeader local_header;
504   if (header == nullptr) header = &local_header;
505   auto dec_status = ModularDecode(br, image, *header, group_id, options, tree,
506                                   code, ctx_map, allow_truncated_group);
507   if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status);
508   if (dec_status.IsFatalError()) return dec_status;
509   image.undo_transforms(header->wp_header, undo_transforms);
510   if (image.error) return JXL_FAILURE("Corrupt file. Aborting.");
511   size_t bit_pos = br->TotalBitsConsumed();
512   JXL_DEBUG_V(4, "Modular-decoded a %zux%zu nbchans=%zu image from %zu bytes",
513               image.w, image.h, image.channel.size(),
514               (br->TotalBitsConsumed() - bit_pos) / 8);
515   (void)bit_pos;
516 #ifdef JXL_ENABLE_ASSERT
517   // Check that after applying all transforms we are back to the requested image
518   // sizes, otherwise there's a programming error with the transformations.
519   if (undo_transforms == -1 || undo_transforms == 0) {
520     JXL_ASSERT(image.channel.size() == req_sizes.size());
521     for (size_t c = 0; c < req_sizes.size(); c++) {
522       JXL_ASSERT(req_sizes[c].first == image.channel[c].w);
523       JXL_ASSERT(req_sizes[c].second == image.channel[c].h);
524     }
525   }
526 #endif
527   return dec_status;
528 }
529 
530 }  // namespace jxl
531