1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 #include "lib/jxl/modular/encoding/encoding.h"
7
8 #include <stdint.h>
9 #include <stdlib.h>
10
11 #include <queue>
12
13 #include "lib/jxl/modular/encoding/context_predict.h"
14 #include "lib/jxl/modular/options.h"
15
16 namespace jxl {
17
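// Removes all nodes that use a static property (i.e. channel or group ID) from
// the tree, resolving them with the given static property values, and
// collapses each node on even levels with its two children to produce a
// flatter tree. Also computes how many properties the resulting tree needs,
// whether it uses the weighted predictor (and whether it uses nothing else),
// and whether it uses only the gradient property and predictor.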
//
// Parameters:
//   global_tree:   the full MA tree; node 0 is the root and is accessed
//                  unconditionally, so the tree must be non-empty.
//   static_props:  values of the static properties; nodes testing them are
//                  decided here (value > splitval -> lchild, else rchild).
//   num_props:     set to the number of properties the flat tree needs: at
//                  least kNumNonrefProperties, with any excess rounded up to a
//                  multiple of kExtraPropsPerChannel.
//   use_wp:        set if the WP property or the Weighted predictor is used.
//   wp_only:       set if the WP property/predictor is used and no other
//                  non-static property and no other predictor is used.
//   gradient_only: set if every non-static property tested is kGradientProp
//                  and every leaf uses Predictor::Gradient.
// Returns the flattened tree. Leaves have property0 == -1 and keep the
// original leaf's lchild in childID; an inner flat node's four children are
// stored at childID..childID+3 (the first two for properties[0], the last two
// for properties[1]).
FlatTree FilterTree(const Tree &global_tree,
                    std::array<pixel_type, kNumStaticProperties> &static_props,
                    size_t *num_props, bool *use_wp, bool *wp_only,
                    bool *gradient_only) {
  *num_props = 0;
  bool has_wp = false;
  bool has_non_wp = false;
  *gradient_only = true;
  // Records which kind of non-static property a decision node tests.
  // Static properties (p < kNumStaticProperties, including the dummy
  // property 0 used for duplicated leaves) do not affect any flag.
  const auto mark_property = [&](int32_t p) {
    if (p == kWPProp) {
      has_wp = true;
    } else if (p >= kNumStaticProperties) {
      has_non_wp = true;
    }
    if (p >= kNumStaticProperties && p != kGradientProp) {
      *gradient_only = false;
    }
  };
  FlatTree output;
  std::queue<size_t> nodes;
  nodes.push(0);
  // Produces a trimmed and flattened tree by doing a BFS visit of the original
  // tree, ignoring branches that are known to be false and proceeding two
  // levels at a time to collapse nodes in a flatter tree; if an inner parent
  // node has a leaf as a child, the leaf is duplicated and an implicit fake
  // node is added. This allows to reduce the number of branches when traversing
  // the resulting flat tree.
  while (!nodes.empty()) {
    size_t cur = nodes.front();
    nodes.pop();
    // Skip nodes that we can decide now, by jumping directly to their children.
    while (global_tree[cur].property < kNumStaticProperties &&
           global_tree[cur].property != -1) {
      if (static_props[global_tree[cur].property] > global_tree[cur].splitval) {
        cur = global_tree[cur].lchild;
      } else {
        cur = global_tree[cur].rchild;
      }
    }
    FlatDecisionNode flat;
    if (global_tree[cur].property == -1) {
      // Leaf: copy the prediction parameters. childID carries the leaf's
      // lchild, which the decoder later maps through the context map.
      flat.property0 = -1;
      flat.childID = global_tree[cur].lchild;
      flat.predictor = global_tree[cur].predictor;
      flat.predictor_offset = global_tree[cur].predictor_offset;
      flat.multiplier = global_tree[cur].multiplier;
      *gradient_only &= flat.predictor == Predictor::Gradient;
      has_wp |= flat.predictor == Predictor::Weighted;
      has_non_wp |= flat.predictor != Predictor::Weighted;
      output.push_back(flat);
      continue;
    }
    // In BFS order this node lands at index output.size(), the still-queued
    // nodes occupy the next nodes.size() slots, and the four children pushed
    // below come right after them.
    flat.childID = output.size() + nodes.size() + 1;

    flat.property0 = global_tree[cur].property;
    *num_props = std::max<size_t>(flat.property0 + 1, *num_props);
    flat.splitval0 = global_tree[cur].splitval;

    for (size_t i = 0; i < 2; i++) {
      size_t cur_child =
          i == 0 ? global_tree[cur].lchild : global_tree[cur].rchild;
      // Skip nodes that we can decide now.
      while (global_tree[cur_child].property < kNumStaticProperties &&
             global_tree[cur_child].property != -1) {
        if (static_props[global_tree[cur_child].property] >
            global_tree[cur_child].splitval) {
          cur_child = global_tree[cur_child].lchild;
        } else {
          cur_child = global_tree[cur_child].rchild;
        }
      }
      // We ended up in a leaf, add a dummy decision and two copies of the leaf.
      if (global_tree[cur_child].property == -1) {
        flat.properties[i] = 0;
        flat.splitvals[i] = 0;
        nodes.push(cur_child);
        nodes.push(cur_child);
      } else {
        flat.properties[i] = global_tree[cur_child].property;
        flat.splitvals[i] = global_tree[cur_child].splitval;
        nodes.push(global_tree[cur_child].lchild);
        nodes.push(global_tree[cur_child].rchild);
        *num_props = std::max<size_t>(flat.properties[i] + 1, *num_props);
      }
    }

    for (size_t j = 0; j < 2; j++) mark_property(flat.properties[j]);
    mark_property(flat.property0);
    output.push_back(flat);
  }
  // Always provide the non-reference properties; reference properties come
  // in groups of kExtraPropsPerChannel, so round any excess up to a full group.
  if (*num_props > kNumNonrefProperties) {
    *num_props =
        DivCeil(*num_props - kNumNonrefProperties, kExtraPropsPerChannel) *
            kExtraPropsPerChannel +
        kNumNonrefProperties;
  } else {
    *num_props = kNumNonrefProperties;
  }
  *use_wp = has_wp;
  *wp_only = has_wp && !has_non_wp;

  return output;
}
125
// Decodes the pixels of channel `chan` of `image` with the MA tree
// `global_tree`, reading symbols from `reader`/`br`. Leaf contexts of the tree
// are mapped through `context_map` to clustered context IDs, and `wp_header`
// configures the weighted predictor state when it is needed.
// Depending on the shape of the filtered tree, one of several specialized
// decoding loops is used (single leaf, gradient-only / WP-only lookup tables,
// or the general tree walk with or without the weighted predictor).
// This function always returns true; truncated input is detected by the
// caller via the bit reader.
Status DecodeModularChannelMAANS(BitReader *br, ANSSymbolReader *reader,
                                 const std::vector<uint8_t> &context_map,
                                 const Tree &global_tree,
                                 const weighted::Header &wp_header,
                                 pixel_type chan, size_t group_id,
                                 Image *image) {
  Channel &channel = image->channel[chan];

  // Static properties: [0] = channel index, [1] = group ID.
  std::array<pixel_type, kNumStaticProperties> static_props = {
      {chan, (int)group_id}};
  // Nodes testing these static properties are resolved by FilterTree below.

  // zero pixel channel? could happen
  if (channel.w == 0 || channel.h == 0) return true;

  bool tree_has_wp_prop_or_pred = false;
  bool is_wp_only = false;
  bool is_gradient_only = false;
  size_t num_props;
  FlatTree tree =
      FilterTree(global_tree, static_props, &num_props,
                 &tree_has_wp_prop_or_pred, &is_wp_only, &is_gradient_only);

  // From here on, tree lookup returns a *clustered* context ID.
  // This avoids an extra memory lookup after tree traversal.
  for (size_t i = 0; i < tree.size(); i++) {
    if (tree[i].property0 == -1) {
      tree[i].childID = context_map[tree[i].childID];
    }
  }

  JXL_DEBUG_V(3, "Decoded MA tree with %zu nodes", tree.size());

  // MAANS decode
  // Turns a decoded (packed signed) symbol into a pixel value:
  // UnpackSigned(v) * multiplier + offset.
  const auto make_pixel = [](uint64_t v, pixel_type multiplier,
                             pixel_type_w offset) -> pixel_type {
    JXL_DASSERT((v & 0xFFFFFFFF) == v);
    pixel_type_w val = UnpackSigned(v);
    // if it overflows, it overflows, and we have a problem anyway
    return val * multiplier + offset;
  };

  if (tree.size() == 1) {
    // special optimized case: no meta-adaptation, so no need
    // to compute properties.
    Predictor predictor = tree[0].predictor;
    int64_t offset = tree[0].predictor_offset;
    int32_t multiplier = tree[0].multiplier;
    size_t ctx_id = tree[0].childID;
    if (predictor == Predictor::Zero) {
      uint32_t value;
      if (reader->IsSingleValueAndAdvance(ctx_id, &value,
                                          channel.w * channel.h)) {
        // Special-case: histogram has a single symbol, with no extra bits, and
        // we use ANS mode.
        JXL_DEBUG_V(8, "Fastest track.");
        pixel_type v = make_pixel(value, multiplier, offset);
        for (size_t y = 0; y < channel.h; y++) {
          pixel_type *JXL_RESTRICT r = channel.Row(y);
          std::fill(r, r + channel.w, v);
        }

      } else {
        JXL_DEBUG_V(8, "Fast track.");
        for (size_t y = 0; y < channel.h; y++) {
          pixel_type *JXL_RESTRICT r = channel.Row(y);
          for (size_t x = 0; x < channel.w; x++) {
            uint32_t v = reader->ReadHybridUintClustered(ctx_id, br);
            r[x] = make_pixel(v, multiplier, offset);
          }
        }
      }
    } else if (predictor == Predictor::Gradient && offset == 0 &&
               multiplier == 1) {
      JXL_DEBUG_V(8, "Gradient very fast track.");
      const intptr_t onerow = channel.plane.PixelsPerRow();
      for (size_t y = 0; y < channel.h; y++) {
        pixel_type *JXL_RESTRICT r = channel.Row(y);
        for (size_t x = 0; x < channel.w; x++) {
          // Border handling: missing left uses top (or 0 on the first pixel);
          // missing top / top-left fall back to left.
          pixel_type left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
          pixel_type top = (y ? *(r + x - onerow) : left);
          pixel_type topleft = (x && y ? *(r + x - 1 - onerow) : left);
          pixel_type guess = ClampedGradient(top, left, topleft);
          uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
          r[x] = make_pixel(v, 1, guess);
        }
      }
    } else if (predictor != Predictor::Weighted) {
      // special optimized case: no wp
      JXL_DEBUG_V(8, "Quite fast track.");
      const intptr_t onerow = channel.plane.PixelsPerRow();
      for (size_t y = 0; y < channel.h; y++) {
        pixel_type *JXL_RESTRICT r = channel.Row(y);
        for (size_t x = 0; x < channel.w; x++) {
          PredictionResult pred =
              PredictNoTreeNoWP(channel.w, r + x, onerow, x, y, predictor);
          pixel_type_w g = pred.guess + offset;
          uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
          // NOTE: pred.multiplier is unset.
          r[x] = make_pixel(v, multiplier, g);
        }
      }
    } else {
      // Single leaf using the weighted predictor: its error state must be
      // updated after every decoded pixel.
      JXL_DEBUG_V(8, "Somewhat fast track.");
      const intptr_t onerow = channel.plane.PixelsPerRow();
      weighted::State wp_state(wp_header, channel.w, channel.h);
      for (size_t y = 0; y < channel.h; y++) {
        pixel_type *JXL_RESTRICT r = channel.Row(y);
        for (size_t x = 0; x < channel.w; x++) {
          pixel_type_w g = PredictNoTreeWP(channel.w, r + x, onerow, x, y,
                                           predictor, &wp_state)
                               .guess +
                           offset;
          uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
          r[x] = make_pixel(v, multiplier, g);
          wp_state.UpdateErrors(r[x], x, y, channel.w);
        }
      }
    }
    return true;
  }

  // Check if this tree is a WP-only tree with a small enough property value
  // range.
  // Initialized to avoid clang-tidy complaining.
  // The tables are indexed by the (clamped) tested property value shifted by
  // kPropRangeFast. Each flag below stays set only if TreeToLookupTable
  // succeeds in filling them.
  uint8_t context_lookup[2 * kPropRangeFast] = {};
  int8_t multipliers[2 * kPropRangeFast] = {};
  int8_t offsets[2 * kPropRangeFast] = {};
  if (is_wp_only) {
    is_wp_only = TreeToLookupTable(tree, context_lookup, offsets, multipliers);
  }
  if (is_gradient_only) {
    is_gradient_only =
        TreeToLookupTable(tree, context_lookup, offsets, multipliers);
  }

  if (is_gradient_only) {
    JXL_DEBUG_V(8, "Gradient fast track.");
    const intptr_t onerow = channel.plane.PixelsPerRow();
    for (size_t y = 0; y < channel.h; y++) {
      pixel_type *JXL_RESTRICT r = channel.Row(y);
      for (size_t x = 0; x < channel.w; x++) {
        pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
        pixel_type_w top = (y ? *(r + x - onerow) : left);
        pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
        int32_t guess = ClampedGradient(top, left, topleft);
        // Table index: top + left - topleft clamped to
        // [-kPropRangeFast, kPropRangeFast - 1], shifted to be non-negative.
        uint32_t pos =
            kPropRangeFast +
            std::min<pixel_type_w>(
                std::max<pixel_type_w>(-kPropRangeFast, top + left - topleft),
                kPropRangeFast - 1);
        uint32_t ctx_id = context_lookup[pos];
        uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
        r[x] = make_pixel(v, multipliers[pos],
                          static_cast<pixel_type_w>(offsets[pos]) + guess);
      }
    }
  } else if (is_wp_only) {
    JXL_DEBUG_V(8, "WP fast track.");
    const intptr_t onerow = channel.plane.PixelsPerRow();
    weighted::State wp_state(wp_header, channel.w, channel.h);
    // Only one property slot: Predict<true> writes the WP property into
    // properties[0], which then indexes the lookup tables.
    Properties properties(1);
    for (size_t y = 0; y < channel.h; y++) {
      pixel_type *JXL_RESTRICT r = channel.Row(y);
      for (size_t x = 0; x < channel.w; x++) {
        size_t offset = 0;
        pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
        pixel_type_w top = (y ? *(r + x - onerow) : left);
        pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
        pixel_type_w topright =
            (x + 1 < channel.w && y ? *(r + x + 1 - onerow) : top);
        pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top);
        int32_t guess = wp_state.Predict</*compute_properties=*/true>(
            x, y, channel.w, top, left, topright, topleft, toptop, &properties,
            offset);
        uint32_t pos =
            kPropRangeFast + std::min(std::max(-kPropRangeFast, properties[0]),
                                      kPropRangeFast - 1);
        uint32_t ctx_id = context_lookup[pos];
        uint64_t v = reader->ReadHybridUintClustered(ctx_id, br);
        r[x] = make_pixel(v, multipliers[pos],
                          static_cast<pixel_type_w>(offsets[pos]) + guess);
        wp_state.UpdateErrors(r[x], x, y, channel.w);
      }
    }
  } else if (!tree_has_wp_prop_or_pred) {
    // special optimized case: the weighted predictor and its properties are not
    // used, so no need to compute weights and properties.
    JXL_DEBUG_V(8, "Slow track.");
    MATreeLookup tree_lookup(tree);
    Properties properties = Properties(num_props);
    const intptr_t onerow = channel.plane.PixelsPerRow();
    // Per-row buffer for properties derived from previously decoded channels.
    Channel references(properties.size() - kNumNonrefProperties, channel.w);
    for (size_t y = 0; y < channel.h; y++) {
      pixel_type *JXL_RESTRICT p = channel.Row(y);
      PrecomputeReferences(channel, y, *image, chan, &references);
      InitPropsRow(&properties, static_props, y);
      for (size_t x = 0; x < channel.w; x++) {
        PredictionResult res =
            PredictTreeNoWP(&properties, channel.w, p + x, onerow, x, y,
                            tree_lookup, references);
        uint64_t v = reader->ReadHybridUintClustered(res.context, br);
        p[x] = make_pixel(v, res.multiplier, res.guess);
      }
    }
  } else {
    // General case: full tree walk with the weighted predictor state kept up
    // to date after every pixel.
    JXL_DEBUG_V(8, "Slowest track.");
    MATreeLookup tree_lookup(tree);
    Properties properties = Properties(num_props);
    const intptr_t onerow = channel.plane.PixelsPerRow();
    Channel references(properties.size() - kNumNonrefProperties, channel.w);
    weighted::State wp_state(wp_header, channel.w, channel.h);
    for (size_t y = 0; y < channel.h; y++) {
      pixel_type *JXL_RESTRICT p = channel.Row(y);
      InitPropsRow(&properties, static_props, y);
      PrecomputeReferences(channel, y, *image, chan, &references);
      for (size_t x = 0; x < channel.w; x++) {
        PredictionResult res =
            PredictTreeWP(&properties, channel.w, p + x, onerow, x, y,
                          tree_lookup, references, &wp_state);
        uint64_t v = reader->ReadHybridUintClustered(res.context, br);
        p[x] = make_pixel(v, res.multiplier, res.guess);
        wp_state.UpdateErrors(p[x], x, y, channel.w);
      }
    }
  }
  return true;
}
354
GroupHeader()355 GroupHeader::GroupHeader() { Bundle::Init(this); }
356
ValidateChannelDimensions(const Image & image,const ModularOptions & options)357 Status ValidateChannelDimensions(const Image &image,
358 const ModularOptions &options) {
359 size_t nb_channels = image.channel.size();
360 for (bool is_dc : {true, false}) {
361 size_t group_dim = options.group_dim * (is_dc ? kBlockDim : 1);
362 size_t c = image.nb_meta_channels;
363 for (; c < nb_channels; c++) {
364 const Channel &ch = image.channel[c];
365 if (ch.w > options.group_dim || ch.h > options.group_dim) break;
366 }
367 for (; c < nb_channels; c++) {
368 const Channel &ch = image.channel[c];
369 if (ch.w == 0 || ch.h == 0) continue; // skip empty
370 bool is_dc_channel = std::min(ch.hshift, ch.vshift) >= 3;
371 if (is_dc_channel != is_dc) continue;
372 size_t tile_dim = group_dim >> std::max(ch.hshift, ch.vshift);
373 if (tile_dim == 0) {
374 return JXL_FAILURE("Inconsistent transforms");
375 }
376 }
377 }
378 return true;
379 }
380
ModularDecode(BitReader * br,Image & image,GroupHeader & header,size_t group_id,ModularOptions * options,const Tree * global_tree,const ANSCode * global_code,const std::vector<uint8_t> * global_ctx_map,bool allow_truncated_group)381 Status ModularDecode(BitReader *br, Image &image, GroupHeader &header,
382 size_t group_id, ModularOptions *options,
383 const Tree *global_tree, const ANSCode *global_code,
384 const std::vector<uint8_t> *global_ctx_map,
385 bool allow_truncated_group) {
386 if (image.channel.empty()) return true;
387
388 // decode transforms
389 JXL_RETURN_IF_ERROR(Bundle::Read(br, &header));
390 JXL_DEBUG_V(3, "Image data underwent %zu transformations: ",
391 header.transforms.size());
392 image.transform = header.transforms;
393 for (Transform &transform : image.transform) {
394 JXL_RETURN_IF_ERROR(transform.MetaApply(image));
395 }
396 if (image.error) {
397 return JXL_FAILURE("Corrupt file. Aborting.");
398 }
399 if (br->AllReadsWithinBounds()) {
400 // Only check if the transforms list is complete.
401 JXL_RETURN_IF_ERROR(ValidateChannelDimensions(image, *options));
402 }
403
404 size_t nb_channels = image.channel.size();
405
406 size_t num_chans = 0;
407 size_t distance_multiplier = 0;
408 for (size_t i = 0; i < nb_channels; i++) {
409 Channel &channel = image.channel[i];
410 if (!channel.w || !channel.h) {
411 continue; // skip empty channels
412 }
413 if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size ||
414 channel.h > options->max_chan_size)) {
415 break;
416 }
417 if (channel.w > distance_multiplier) {
418 distance_multiplier = channel.w;
419 }
420 num_chans++;
421 }
422 if (num_chans == 0) return true;
423
424 // Read tree.
425 Tree tree_storage;
426 std::vector<uint8_t> context_map_storage;
427 ANSCode code_storage;
428 const Tree *tree = &tree_storage;
429 const ANSCode *code = &code_storage;
430 const std::vector<uint8_t> *context_map = &context_map_storage;
431 if (!header.use_global_tree) {
432 size_t max_tree_size = 1024;
433 for (size_t i = 0; i < nb_channels; i++) {
434 Channel &channel = image.channel[i];
435 if (!channel.w || !channel.h) {
436 continue; // skip empty channels
437 }
438 if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size ||
439 channel.h > options->max_chan_size)) {
440 break;
441 }
442 size_t pixels = channel.w * channel.h;
443 if (pixels / channel.w != channel.h) {
444 return JXL_FAILURE("Tree size overflow");
445 }
446 max_tree_size += pixels;
447 if (max_tree_size < pixels) return JXL_FAILURE("Tree size overflow");
448 }
449
450 JXL_RETURN_IF_ERROR(DecodeTree(br, &tree_storage, max_tree_size));
451 JXL_RETURN_IF_ERROR(DecodeHistograms(br, (tree_storage.size() + 1) / 2,
452 &code_storage, &context_map_storage));
453 } else {
454 if (!global_tree || !global_code || !global_ctx_map ||
455 global_tree->empty()) {
456 return JXL_FAILURE("No global tree available but one was requested");
457 }
458 tree = global_tree;
459 code = global_code;
460 context_map = global_ctx_map;
461 }
462
463 // Read channels
464 ANSSymbolReader reader(code, br, distance_multiplier);
465 for (size_t i = 0; i < nb_channels; i++) {
466 Channel &channel = image.channel[i];
467 if (!channel.w || !channel.h) {
468 continue; // skip empty channels
469 }
470 if (i >= image.nb_meta_channels && (channel.w > options->max_chan_size ||
471 channel.h > options->max_chan_size)) {
472 break;
473 }
474 JXL_RETURN_IF_ERROR(DecodeModularChannelMAANS(br, &reader, *context_map,
475 *tree, header.wp_header, i,
476 group_id, &image));
477 // Truncated group.
478 if (!br->AllReadsWithinBounds()) {
479 if (!allow_truncated_group) return JXL_FAILURE("Truncated input");
480 ZeroFillImage(&channel.plane);
481 while (++i < nb_channels) ZeroFillImage(&image.channel[i].plane);
482 return Status(StatusCode::kNotEnoughBytes);
483 }
484 }
485 if (!reader.CheckANSFinalState()) {
486 return JXL_FAILURE("ANS decode final state failed");
487 }
488 return true;
489 }
490
ModularGenericDecompress(BitReader * br,Image & image,GroupHeader * header,size_t group_id,ModularOptions * options,int undo_transforms,const Tree * tree,const ANSCode * code,const std::vector<uint8_t> * ctx_map,bool allow_truncated_group)491 Status ModularGenericDecompress(BitReader *br, Image &image,
492 GroupHeader *header, size_t group_id,
493 ModularOptions *options, int undo_transforms,
494 const Tree *tree, const ANSCode *code,
495 const std::vector<uint8_t> *ctx_map,
496 bool allow_truncated_group) {
497 #ifdef JXL_ENABLE_ASSERT
498 std::vector<std::pair<uint32_t, uint32_t>> req_sizes(image.channel.size());
499 for (size_t c = 0; c < req_sizes.size(); c++) {
500 req_sizes[c] = {image.channel[c].w, image.channel[c].h};
501 }
502 #endif
503 GroupHeader local_header;
504 if (header == nullptr) header = &local_header;
505 auto dec_status = ModularDecode(br, image, *header, group_id, options, tree,
506 code, ctx_map, allow_truncated_group);
507 if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status);
508 if (dec_status.IsFatalError()) return dec_status;
509 image.undo_transforms(header->wp_header, undo_transforms);
510 if (image.error) return JXL_FAILURE("Corrupt file. Aborting.");
511 size_t bit_pos = br->TotalBitsConsumed();
512 JXL_DEBUG_V(4, "Modular-decoded a %zux%zu nbchans=%zu image from %zu bytes",
513 image.w, image.h, image.channel.size(),
514 (br->TotalBitsConsumed() - bit_pos) / 8);
515 (void)bit_pos;
516 #ifdef JXL_ENABLE_ASSERT
517 // Check that after applying all transforms we are back to the requested image
518 // sizes, otherwise there's a programming error with the transformations.
519 if (undo_transforms == -1 || undo_transforms == 0) {
520 JXL_ASSERT(image.channel.size() == req_sizes.size());
521 for (size_t c = 0; c < req_sizes.size(); c++) {
522 JXL_ASSERT(req_sizes[c].first == image.channel[c].w);
523 JXL_ASSERT(req_sizes[c].second == image.channel[c].h);
524 }
525 }
526 #endif
527 return dec_status;
528 }
529
530 } // namespace jxl
531