1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/dec_modular.h"
7 
8 #include <stdint.h>
9 
10 #include <vector>
11 
12 #include "lib/jxl/frame_header.h"
13 
14 #undef HWY_TARGET_INCLUDE
15 #define HWY_TARGET_INCLUDE "lib/jxl/dec_modular.cc"
16 #include <hwy/foreach_target.h>
17 #include <hwy/highway.h>
18 
19 #include "lib/jxl/alpha.h"
20 #include "lib/jxl/base/compiler_specific.h"
21 #include "lib/jxl/base/span.h"
22 #include "lib/jxl/base/status.h"
23 #include "lib/jxl/compressed_dc.h"
24 #include "lib/jxl/epf.h"
25 #include "lib/jxl/modular/encoding/encoding.h"
26 #include "lib/jxl/modular/modular_image.h"
27 #include "lib/jxl/modular/transform/transform.h"
28 
29 HWY_BEFORE_NAMESPACE();
30 namespace jxl {
31 namespace HWY_NAMESPACE {
32 
33 // These templates are not found via ADL.
34 using hwy::HWY_NAMESPACE::Rebind;
35 
MultiplySum(const size_t xsize,const pixel_type * const JXL_RESTRICT row_in,const pixel_type * const JXL_RESTRICT row_in_Y,const float factor,float * const JXL_RESTRICT row_out)36 void MultiplySum(const size_t xsize,
37                  const pixel_type* const JXL_RESTRICT row_in,
38                  const pixel_type* const JXL_RESTRICT row_in_Y,
39                  const float factor, float* const JXL_RESTRICT row_out) {
40   const HWY_FULL(float) df;
41   const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
42   const auto factor_v = Set(df, factor);
43   for (size_t x = 0; x < xsize; x += Lanes(di)) {
44     const auto in = Load(di, row_in + x) + Load(di, row_in_Y + x);
45     const auto out = ConvertTo(df, in) * factor_v;
46     Store(out, df, row_out + x);
47   }
48 }
49 
RgbFromSingle(const size_t xsize,const pixel_type * const JXL_RESTRICT row_in,const float factor,Image3F * decoded,size_t,size_t y,Rect & rect)50 void RgbFromSingle(const size_t xsize,
51                    const pixel_type* const JXL_RESTRICT row_in,
52                    const float factor, Image3F* decoded, size_t /*c*/, size_t y,
53                    Rect& rect) {
54   JXL_DASSERT(xsize <= rect.xsize());
55   const HWY_FULL(float) df;
56   const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
57 
58   float* const JXL_RESTRICT row_out_r = rect.PlaneRow(decoded, 0, y);
59   float* const JXL_RESTRICT row_out_g = rect.PlaneRow(decoded, 1, y);
60   float* const JXL_RESTRICT row_out_b = rect.PlaneRow(decoded, 2, y);
61 
62   const auto factor_v = Set(df, factor);
63   for (size_t x = 0; x < xsize; x += Lanes(di)) {
64     const auto in = Load(di, row_in + x);
65     const auto out = ConvertTo(df, in) * factor_v;
66     Store(out, df, row_out_r + x);
67     Store(out, df, row_out_g + x);
68     Store(out, df, row_out_b + x);
69   }
70 }
71 
72 // Same signature as RgbFromSingle so we can assign to the same pointer.
SingleFromSingle(const size_t xsize,const pixel_type * const JXL_RESTRICT row_in,const float factor,Image3F * decoded,size_t c,size_t y,Rect & rect)73 void SingleFromSingle(const size_t xsize,
74                       const pixel_type* const JXL_RESTRICT row_in,
75                       const float factor, Image3F* decoded, size_t c, size_t y,
76                       Rect& rect) {
77   JXL_DASSERT(xsize <= rect.xsize());
78   const HWY_FULL(float) df;
79   const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
80 
81   float* const JXL_RESTRICT row_out = rect.PlaneRow(decoded, c, y);
82 
83   const auto factor_v = Set(df, factor);
84   for (size_t x = 0; x < xsize; x += Lanes(di)) {
85     const auto in = Load(di, row_in + x);
86     const auto out = ConvertTo(df, in) * factor_v;
87     Store(out, df, row_out + x);
88   }
89 }
90 // NOLINTNEXTLINE(google-readability-namespace-comments)
91 }  // namespace HWY_NAMESPACE
92 }  // namespace jxl
93 HWY_AFTER_NAMESPACE();
94 
95 #if HWY_ONCE
96 namespace jxl {
97 HWY_EXPORT(MultiplySum);       // Local function
98 HWY_EXPORT(RgbFromSingle);     // Local function
99 HWY_EXPORT(SingleFromSingle);  // Local function
100 
// Converts a row of custom floating-point samples back to binary32 floats.
// Each input sample holds, in its low `bits` bits, a float with 1 sign bit,
// `exp_bits` exponent bits and (bits - exp_bits - 1) mantissa bits.
//
// row_in:   xsize input samples; their raw 32-bit patterns are reinterpreted.
// row_out:  xsize output floats.
// bits:     total width of the custom float. 32 means the input already holds
//           binary32 bit patterns, which are copied verbatim.
// exp_bits: exponent width; must be 8 when bits == 32.
//
// NOTE(review): an all-ones exponent is not special-cased, so such values are
// decoded as finite numbers rather than Inf/NaN — confirm this is intended.
void int_to_float(const pixel_type* const JXL_RESTRICT row_in,
                  float* const JXL_RESTRICT row_out, const size_t xsize,
                  const int bits, const int exp_bits) {
  if (bits == 32) {
    // Input is already in binary32 layout: reinterpret the bits directly.
    JXL_ASSERT(sizeof(pixel_type) == sizeof(float));
    JXL_ASSERT(exp_bits == 8);
    memcpy(row_out, row_in, xsize * sizeof(float));
    return;
  }
  int exp_bias = (1 << (exp_bits - 1)) - 1;
  int sign_shift = bits - 1;
  int mant_bits = bits - exp_bits - 1;
  // Shift that aligns the mantissa with binary32's 23-bit mantissa field;
  // assumes mant_bits <= 23 (TODO confirm this is validated by the header).
  int mant_shift = 23 - mant_bits;
  for (size_t x = 0; x < xsize; ++x) {
    uint32_t f;
    // Copy the raw bit pattern to avoid a signed/unsigned conversion.
    memcpy(&f, &row_in[x], 4);
    int signbit = (f >> sign_shift);
    f &= (1 << sign_shift) - 1;
    if (f == 0) {
      // Preserve the sign of zero.
      row_out[x] = (signbit ? -0.f : 0.f);
      continue;
    }
    int exp = (f >> mant_bits);
    int mantissa = (f & ((1 << mant_bits) - 1));
    mantissa <<= mant_shift;
    // Try to normalize only if there is space for maneuver.
    // With exp_bits == 8 the bias equals binary32's (127), so a subnormal
    // input maps directly onto a binary32 subnormal (exponent field 0).
    if (exp == 0 && exp_bits < 8) {
      // subnormal number: shift until the leading 1 reaches bit 23, lowering
      // the exponent once per shift.
      while ((mantissa & 0x800000) == 0) {
        mantissa <<= 1;
        exp--;
      }
      exp++;
      // remove leading 1 because it is implicit now
      mantissa &= 0x7fffff;
    }
    exp -= exp_bias;
    // broke up the arbitrary float into its parts, now reassemble into
    // binary32
    exp += 127;
    JXL_ASSERT(exp >= 0);
    f = (signbit ? 0x80000000 : 0);
    f |= (exp << 23);
    f |= mantissa;
    memcpy(&row_out[x], &f, 4);
  }
}
150 
DecodeGlobalInfo(BitReader * reader,const FrameHeader & frame_header,bool allow_truncated_group)151 Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader,
152                                              const FrameHeader& frame_header,
153                                              bool allow_truncated_group) {
154   bool decode_color = frame_header.encoding == FrameEncoding::kModular;
155   const auto& metadata = frame_header.nonserialized_metadata->m;
156   bool is_gray = metadata.color_encoding.IsGray();
157   size_t nb_chans = 3;
158   if (is_gray && frame_header.color_transform == ColorTransform::kNone) {
159     nb_chans = 1;
160   }
161   bool has_tree = reader->ReadBits(1);
162   if (has_tree) {
163     size_t tree_size_limit =
164         1024 + frame_dim.xsize * frame_dim.ysize * nb_chans / 16;
165     JXL_RETURN_IF_ERROR(DecodeTree(reader, &tree, tree_size_limit));
166     JXL_RETURN_IF_ERROR(
167         DecodeHistograms(reader, (tree.size() + 1) / 2, &code, &context_map));
168   }
169   do_color = decode_color;
170   if (!do_color) nb_chans = 0;
171   size_t nb_extra = metadata.extra_channel_info.size();
172 
173   bool fp = metadata.bit_depth.floating_point_sample;
174 
175   // bits_per_sample is just metadata for XYB images.
176   if (metadata.bit_depth.bits_per_sample >= 32 && do_color &&
177       frame_header.color_transform != ColorTransform::kXYB) {
178     if (metadata.bit_depth.bits_per_sample == 32 && fp == false) {
179       return JXL_FAILURE("uint32_t not supported in dec_modular");
180     } else if (metadata.bit_depth.bits_per_sample > 32) {
181       return JXL_FAILURE("bits_per_sample > 32 not supported");
182     }
183   }
184 
185   Image gi(frame_dim.xsize, frame_dim.ysize, metadata.bit_depth.bits_per_sample,
186            nb_chans + nb_extra);
187 
188   all_same_shift = true;
189   if (frame_header.color_transform == ColorTransform::kYCbCr) {
190     for (size_t c = 0; c < nb_chans; c++) {
191       gi.channel[c].hshift = frame_header.chroma_subsampling.HShift(c);
192       gi.channel[c].vshift = frame_header.chroma_subsampling.VShift(c);
193       size_t xsize_shifted =
194           DivCeil(frame_dim.xsize, 1 << gi.channel[c].hshift);
195       size_t ysize_shifted =
196           DivCeil(frame_dim.ysize, 1 << gi.channel[c].vshift);
197       gi.channel[c].shrink(xsize_shifted, ysize_shifted);
198       if (gi.channel[c].hshift != gi.channel[0].hshift ||
199           gi.channel[c].vshift != gi.channel[0].vshift)
200         all_same_shift = false;
201     }
202   }
203 
204   for (size_t ec = 0, c = nb_chans; ec < nb_extra; ec++, c++) {
205     size_t ecups = frame_header.extra_channel_upsampling[ec];
206     gi.channel[c].shrink(DivCeil(frame_dim.xsize_upsampled, ecups),
207                          DivCeil(frame_dim.ysize_upsampled, ecups));
208     gi.channel[c].hshift = gi.channel[c].vshift =
209         CeilLog2Nonzero(ecups) - CeilLog2Nonzero(frame_header.upsampling);
210     if (gi.channel[c].hshift != gi.channel[0].hshift ||
211         gi.channel[c].vshift != gi.channel[0].vshift)
212       all_same_shift = false;
213   }
214 
215   ModularOptions options;
216   options.max_chan_size = frame_dim.group_dim;
217   options.group_dim = frame_dim.group_dim;
218   Status dec_status = ModularGenericDecompress(
219       reader, gi, &global_header, ModularStreamId::Global().ID(frame_dim),
220       &options,
221       /*undo_transforms=*/-2, &tree, &code, &context_map,
222       allow_truncated_group);
223   if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status);
224   if (dec_status.IsFatalError()) {
225     return JXL_FAILURE("Failed to decode global modular info");
226   }
227 
228   // TODO(eustas): are we sure this can be done after partial decode?
229   have_something = false;
230   for (size_t c = 0; c < gi.channel.size(); c++) {
231     Channel& gic = gi.channel[c];
232     if (c >= gi.nb_meta_channels && gic.w <= frame_dim.group_dim &&
233         gic.h <= frame_dim.group_dim)
234       have_something = true;
235   }
236   // move global transforms to groups if possible
237   if (!have_something && all_same_shift) {
238     if (gi.transform.size() == 1 && gi.transform[0].id == TransformId::kRCT) {
239       global_transform = gi.transform;
240       gi.transform.clear();
241       // TODO(jon): also move no-delta-palette out (trickier though)
242     }
243   }
244   full_image = std::move(gi);
245   return dec_status;
246 }
247 
MaybeDropFullImage()248 void ModularFrameDecoder::MaybeDropFullImage() {
249   if (full_image.transform.empty() && !have_something && all_same_shift) {
250     use_full_image = false;
251     for (auto& ch : full_image.channel) {
252       // keep metadata on channels around, but dealloc their planes
253       ch.plane = Plane<pixel_type>();
254     }
255   }
256 }
257 
// Decodes, or zero-fills, the modular data of one group covering `rect`.
//
// Channels handled are those from the first non-meta channel of full_image
// that is larger than a group in either dimension, onwards. Of those, only
// channels whose min(hshift, vshift) lies in [minShift, maxShift] and whose
// subsampled part of `rect` is non-empty are included.
//
// zerofill: nothing is read from `reader`; the covered area is set to zero.
// With use_full_image, decoded samples are copied into full_image; with
// zerofill, the area is zeroed there in place. Otherwise the group is
// finished immediately. The global transforms moved to group level are
// undone, and the result is converted into dec_state->decoded through
// ModularImageToDecodedRect.
Status ModularFrameDecoder::DecodeGroup(const Rect& rect, BitReader* reader,
                                        int minShift, int maxShift,
                                        const ModularStreamId& stream,
                                        bool zerofill,
                                        PassesDecoderState* dec_state,
                                        ImageBundle* output) {
  JXL_DASSERT(stream.kind == ModularStreamId::kModularDC ||
              stream.kind == ModularStreamId::kModularAC);
  const size_t xsize = rect.xsize();
  const size_t ysize = rect.ysize();
  // Group-local image; channels are appended below.
  Image gi(xsize, ysize, full_image.bitdepth, 0);
  // start at the first bigger-than-groupsize non-metachannel
  size_t c = full_image.nb_meta_channels;
  for (; c < full_image.channel.size(); c++) {
    Channel& fc = full_image.channel[c];
    if (fc.w > frame_dim.group_dim || fc.h > frame_dim.group_dim) break;
  }
  size_t beginc = c;
  // Collect the group-local channels, or zero the area in full_image.
  for (; c < full_image.channel.size(); c++) {
    Channel& fc = full_image.channel[c];
    int shift = std::min(fc.hshift, fc.vshift);
    if (shift > maxShift) continue;
    if (shift < minShift) continue;
    // This group's area in the channel's subsampled coordinates. The last
    // two arguments presumably clamp it to the channel size.
    Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift,
           rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h);
    if (r.xsize() == 0 || r.ysize() == 0) continue;
    if (zerofill && use_full_image) {
      for (size_t y = 0; y < r.ysize(); ++y) {
        pixel_type* const JXL_RESTRICT row_out = r.Row(&fc.plane, y);
        memset(row_out, 0, r.xsize() * sizeof(*row_out));
      }
    } else {
      Channel gc(r.xsize(), r.ysize());
      if (zerofill) ZeroFillImage(&gc.plane);
      gc.hshift = fc.hshift;
      gc.vshift = fc.vshift;
      gi.channel.emplace_back(std::move(gc));
    }
  }
  if (zerofill && use_full_image) return true;
  // Return early if there's nothing to decode. Otherwise there might be
  // problems later (in ModularImageToDecodedRect).
  if (gi.channel.empty()) return true;
  ModularOptions options;
  if (!zerofill) {
    if (!ModularGenericDecompress(
            reader, gi, /*header=*/nullptr, stream.ID(frame_dim), &options,
            /*undo_transforms=*/-1, &tree, &code, &context_map)) {
      return JXL_FAILURE("Failed to decode modular group");
    }
  }
  // Undo global transforms that have been pushed to the group level
  if (!use_full_image) {
    for (auto t : global_transform) {
      JXL_RETURN_IF_ERROR(t.Inverse(gi, global_header.wp_header));
    }
    JXL_RETURN_IF_ERROR(ModularImageToDecodedRect(
        gi, dec_state, nullptr, output, rect.Crop(dec_state->decoded)));
    return true;
  }
  // Copy the decoded group back into full_image. This loop must select
  // exactly the same channels as the collection loop above, so that `gic`
  // walks gi.channel in step with them.
  int gic = 0;
  for (c = beginc; c < full_image.channel.size(); c++) {
    Channel& fc = full_image.channel[c];
    int shift = std::min(fc.hshift, fc.vshift);
    if (shift > maxShift) continue;
    if (shift < minShift) continue;
    Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift,
           rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h);
    if (r.xsize() == 0 || r.ysize() == 0) continue;
    JXL_ASSERT(use_full_image);
    CopyImageTo(/*rect_from=*/Rect(0, 0, r.xsize(), r.ysize()),
                /*from=*/gi.channel[gic].plane,
                /*rect_to=*/r, /*to=*/&fc.plane);
    gic++;
  }
  return true;
}
DecodeVarDCTDC(size_t group_id,BitReader * reader,PassesDecoderState * dec_state)335 Status ModularFrameDecoder::DecodeVarDCTDC(size_t group_id, BitReader* reader,
336                                            PassesDecoderState* dec_state) {
337   const Rect r = dec_state->shared->DCGroupRect(group_id);
338   // TODO(eustas): investigate if we could reduce the impact of
339   //               EvalRationalPolynomial; generally speaking, the limit is
340   //               2**(128/(3*magic)), where 128 comes from IEEE 754 exponent,
341   //               3 comes from XybToRgb that cubes the values, and "magic" is
342   //               the sum of all other contributions. 2**18 is known to lead
343   //               to NaN on input found by fuzzing (see commit message).
344   Image image(r.xsize(), r.ysize(), full_image.bitdepth, 3);
345   size_t stream_id = ModularStreamId::VarDCTDC(group_id).ID(frame_dim);
346   reader->Refill();
347   size_t extra_precision = reader->ReadFixedBits<2>();
348   float mul = 1.0f / (1 << extra_precision);
349   ModularOptions options;
350   for (size_t c = 0; c < 3; c++) {
351     Channel& ch = image.channel[c < 2 ? c ^ 1 : c];
352     ch.w >>= dec_state->shared->frame_header.chroma_subsampling.HShift(c);
353     ch.h >>= dec_state->shared->frame_header.chroma_subsampling.VShift(c);
354     ch.shrink();
355   }
356   if (!ModularGenericDecompress(
357           reader, image, /*header=*/nullptr, stream_id, &options,
358           /*undo_transforms=*/-1, &tree, &code, &context_map)) {
359     return JXL_FAILURE("Failed to decode modular DC group");
360   }
361   DequantDC(r, &dec_state->shared_storage.dc_storage,
362             &dec_state->shared_storage.quant_dc, image,
363             dec_state->shared->quantizer.MulDC(), mul,
364             dec_state->shared->cmap.DCFactors(),
365             dec_state->shared->frame_header.chroma_subsampling,
366             dec_state->shared->block_ctx_map);
367   return true;
368 }
369 
// Decodes the AC metadata of DC group `group_id` into
// dec_state->shared_storage. The modular image has four channels:
//  - channels 0/1: YToX / YToB color-correlation values, one per 8x8 blocks
//    (kColorTileDimInBlocks), written to cmap.ytox_map / cmap.ytob_map.
//  - channel 2: `count` columns. Row 0 holds the raw AC strategy and row 1
//    the quantization field value of each varblock whose strategy is set here.
//  - channel 3: per-block EPF sharpness, written to epf_sharpness.
// `r` is in block units, since it indexes ac_strategy and raw_quant_field
// directly. The function also ORs the strategies seen into
// dec_state->used_acs and computes the EPF sigma when epf_iters > 0.
Status ModularFrameDecoder::DecodeAcMetadata(size_t group_id, BitReader* reader,
                                             PassesDecoderState* dec_state) {
  const Rect r = dec_state->shared->DCGroupRect(group_id);
  size_t upper_bound = r.xsize() * r.ysize();
  reader->Refill();
  // Number of coded varblocks, stored minus one in just enough bits to
  // represent the number of blocks in the group.
  size_t count = reader->ReadBits(CeilLog2Nonzero(upper_bound)) + 1;
  size_t stream_id = ModularStreamId::ACMetadata(group_id).ID(frame_dim);
  // YToX, YToB, ACS + QF, EPF
  Image image(r.xsize(), r.ysize(), full_image.bitdepth, 4);
  static_assert(kColorTileDimInBlocks == 8, "Color tile size changed");
  Rect cr(r.x0() >> 3, r.y0() >> 3, (r.xsize() + 7) >> 3, (r.ysize() + 7) >> 3);
  image.channel[0] = Channel(cr.xsize(), cr.ysize(), 3, 3);
  image.channel[1] = Channel(cr.xsize(), cr.ysize(), 3, 3);
  image.channel[2] = Channel(count, 2, 0, 0);
  ModularOptions options;
  if (!ModularGenericDecompress(
          reader, image, /*header=*/nullptr, stream_id, &options,
          /*undo_transforms=*/-1, &tree, &code, &context_map)) {
    return JXL_FAILURE("Failed to decode AC metadata");
  }
  ConvertPlaneAndClamp(Rect(image.channel[0].plane), image.channel[0].plane, cr,
                       &dec_state->shared_storage.cmap.ytox_map);
  ConvertPlaneAndClamp(Rect(image.channel[1].plane), image.channel[1].plane, cr,
                       &dec_state->shared_storage.cmap.ytob_map);
  // Index of the next unused (strategy, qf) column in channel 2.
  size_t num = 0;
  bool is444 = dec_state->shared->frame_header.chroma_subsampling.Is444();
  auto& ac_strategy = dec_state->shared_storage.ac_strategy;
  // Varblocks may not extend past the AC strategy image or this DC group.
  size_t xlim = std::min(ac_strategy.xsize(), r.x0() + r.xsize());
  size_t ylim = std::min(ac_strategy.ysize(), r.y0() + r.ysize());
  uint32_t local_used_acs = 0;
  for (size_t iy = 0; iy < r.ysize(); iy++) {
    size_t y = r.y0() + iy;
    int* row_qf = r.Row(&dec_state->shared_storage.raw_quant_field, iy);
    uint8_t* row_epf = r.Row(&dec_state->shared_storage.epf_sharpness, iy);
    int* row_in_1 = image.channel[2].plane.Row(0);
    int* row_in_2 = image.channel[2].plane.Row(1);
    int* row_in_3 = image.channel[3].plane.Row(iy);
    for (size_t ix = 0; ix < r.xsize(); ix++) {
      size_t x = r.x0() + ix;
      int sharpness = row_in_3[ix];
      if (sharpness < 0 || sharpness >= LoopFilter::kEpfSharpEntries) {
        return JXL_FAILURE("Corrupted sharpness field");
      }
      row_epf[ix] = sharpness;
      // Block already has a strategy (presumably it is covered by a
      // multi-block varblock placed earlier); no entry is consumed.
      if (ac_strategy.IsValid(x, y)) {
        continue;
      }

      if (num >= count) return JXL_FAILURE("Corrupted stream");

      if (!AcStrategy::IsRawStrategyValid(row_in_1[num])) {
        return JXL_FAILURE("Invalid AC strategy");
      }
      local_used_acs |= 1u << row_in_1[num];
      AcStrategy acs = AcStrategy::FromRawStrategy(row_in_1[num]);
      if ((acs.covered_blocks_x() > 1 || acs.covered_blocks_y() > 1) &&
          !is444) {
        return JXL_FAILURE(
            "AC strategy not compatible with chroma subsampling");
      }
      // Ensure that blocks do not overflow *AC* groups.
      size_t next_x_ac_block = (x / kGroupDimInBlocks + 1) * kGroupDimInBlocks;
      size_t next_y_ac_block = (y / kGroupDimInBlocks + 1) * kGroupDimInBlocks;
      size_t next_x_dct_block = x + acs.covered_blocks_x();
      size_t next_y_dct_block = y + acs.covered_blocks_y();
      if (next_x_dct_block > next_x_ac_block || next_x_dct_block > xlim) {
        return JXL_FAILURE("Invalid AC strategy, x overflow");
      }
      if (next_y_dct_block > next_y_ac_block || next_y_dct_block > ylim) {
        return JXL_FAILURE("Invalid AC strategy, y overflow");
      }
      JXL_RETURN_IF_ERROR(
          ac_strategy.SetNoBoundsCheck(x, y, AcStrategy::Type(row_in_1[num])));
      // Stored quant field value is clamped to [0, kQuantMax - 1] and offset
      // by one, giving a range of [1, kQuantMax].
      row_qf[ix] =
          1 + std::max(0, std::min(Quantizer::kQuantMax - 1, row_in_2[num]));
      num++;
    }
  }
  dec_state->used_acs |= local_used_acs;
  if (dec_state->shared->frame_header.loop_filter.epf_iters > 0) {
    ComputeSigma(r, dec_state);
  }
  return true;
}
454 
ModularImageToDecodedRect(Image & gi,PassesDecoderState * dec_state,jxl::ThreadPool * pool,ImageBundle * output,Rect rect)455 Status ModularFrameDecoder::ModularImageToDecodedRect(
456     Image& gi, PassesDecoderState* dec_state, jxl::ThreadPool* pool,
457     ImageBundle* output, Rect rect) {
458   auto& decoded = dec_state->decoded;
459   const auto& frame_header = dec_state->shared->frame_header;
460   const auto* metadata = frame_header.nonserialized_metadata;
461   size_t xsize = rect.xsize();
462   size_t ysize = rect.ysize();
463   if (!xsize || !ysize) {
464     return true;
465   }
466   JXL_DASSERT(rect.IsInside(decoded));
467 
468   size_t c = 0;
469   if (do_color) {
470     const bool rgb_from_gray =
471         metadata->m.color_encoding.IsGray() &&
472         frame_header.color_transform == ColorTransform::kNone;
473     const bool fp = metadata->m.bit_depth.floating_point_sample &&
474                     frame_header.color_transform != ColorTransform::kXYB;
475     for (; c < 3; c++) {
476       float factor = full_image.bitdepth < 32
477                          ? 1.f / ((1u << full_image.bitdepth) - 1)
478                          : 0;
479       size_t c_in = c;
480       if (frame_header.color_transform == ColorTransform::kXYB) {
481         factor = dec_state->shared->matrices.DCQuants()[c];
482         // XYB is encoded as YX(B-Y)
483         if (c < 2) c_in = 1 - c;
484       } else if (rgb_from_gray) {
485         c_in = 0;
486       }
487       JXL_ASSERT(c_in < gi.channel.size());
488       Channel& ch_in = gi.channel[c_in];
489       // TODO(eustas): could we detect it on earlier stage?
490       if (ch_in.w == 0 || ch_in.h == 0) {
491         return JXL_FAILURE("Empty image");
492       }
493       size_t xsize_shifted = DivCeil(xsize, 1 << ch_in.hshift);
494       size_t ysize_shifted = DivCeil(ysize, 1 << ch_in.vshift);
495       Rect r(rect.x0() >> ch_in.hshift, rect.y0() >> ch_in.vshift,
496              rect.xsize() >> ch_in.hshift, rect.ysize() >> ch_in.vshift,
497              DivCeil(decoded.xsize(), 1 << ch_in.hshift),
498              DivCeil(decoded.ysize(), 1 << ch_in.vshift));
499       if (r.ysize() != ch_in.h || r.xsize() != ch_in.w) {
500         return JXL_FAILURE(
501             "Dimension mismatch: trying to fit a %zux%zu modular channel into "
502             "a %zux%zu rect",
503             ch_in.w, ch_in.h, r.xsize(), r.ysize());
504       }
505       if (frame_header.color_transform == ColorTransform::kXYB && c == 2) {
506         JXL_ASSERT(!fp);
507         RunOnPool(
508             pool, 0, ysize_shifted, jxl::ThreadPool::SkipInit(),
509             [&](const int task, const int thread) {
510               const size_t y = task;
511               const pixel_type* const JXL_RESTRICT row_in = ch_in.Row(y);
512               const pixel_type* const JXL_RESTRICT row_in_Y =
513                   gi.channel[0].Row(y);
514               float* const JXL_RESTRICT row_out = r.PlaneRow(&decoded, c, y);
515               HWY_DYNAMIC_DISPATCH(MultiplySum)
516               (xsize_shifted, row_in, row_in_Y, factor, row_out);
517             },
518             "ModularIntToFloat");
519       } else if (fp) {
520         int bits = metadata->m.bit_depth.bits_per_sample;
521         int exp_bits = metadata->m.bit_depth.exponent_bits_per_sample;
522         RunOnPool(
523             pool, 0, ysize_shifted, jxl::ThreadPool::SkipInit(),
524             [&](const int task, const int thread) {
525               const size_t y = task;
526               const pixel_type* const JXL_RESTRICT row_in = ch_in.Row(y);
527               float* const JXL_RESTRICT row_out = r.PlaneRow(&decoded, c, y);
528               int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits);
529             },
530             "ModularIntToFloat_losslessfloat");
531       } else {
532         RunOnPool(
533             pool, 0, ysize_shifted, jxl::ThreadPool::SkipInit(),
534             [&](const int task, const int thread) {
535               const size_t y = task;
536               const pixel_type* const JXL_RESTRICT row_in = ch_in.Row(y);
537               if (rgb_from_gray) {
538                 HWY_DYNAMIC_DISPATCH(RgbFromSingle)
539                 (xsize_shifted, row_in, factor, &decoded, c, y, r);
540               } else {
541                 HWY_DYNAMIC_DISPATCH(SingleFromSingle)
542                 (xsize_shifted, row_in, factor, &decoded, c, y, r);
543               }
544             },
545             "ModularIntToFloat");
546       }
547       if (rgb_from_gray) {
548         break;
549       }
550     }
551     if (rgb_from_gray) {
552       c = 1;
553     }
554   }
555   for (size_t ec = 0; ec < dec_state->extra_channels.size(); ec++, c++) {
556     const ExtraChannelInfo& eci = output->metadata()->extra_channel_info[ec];
557     int bits = eci.bit_depth.bits_per_sample;
558     int exp_bits = eci.bit_depth.exponent_bits_per_sample;
559     bool fp = eci.bit_depth.floating_point_sample;
560     JXL_ASSERT(fp || bits < 32);
561     const float mul = fp ? 0 : (1.0f / ((1u << bits) - 1));
562     size_t ecups = frame_header.extra_channel_upsampling[ec];
563     const size_t ec_xsize = DivCeil(frame_dim.xsize_upsampled, ecups);
564     const size_t ec_ysize = DivCeil(frame_dim.ysize_upsampled, ecups);
565     JXL_ASSERT(c < gi.channel.size());
566     Channel& ch_in = gi.channel[c];
567     // For x0, y0 there's no need to do a DivCeil().
568     JXL_DASSERT(rect.x0() % (1ul << ch_in.hshift) == 0);
569     JXL_DASSERT(rect.y0() % (1ul << ch_in.vshift) == 0);
570     Rect r(rect.x0() >> ch_in.hshift, rect.y0() >> ch_in.vshift,
571            DivCeil(rect.xsize(), 1lu << ch_in.hshift),
572            DivCeil(rect.ysize(), 1lu << ch_in.vshift), ec_xsize, ec_ysize);
573 
574     JXL_DASSERT(r.IsInside(dec_state->extra_channels[ec]));
575     JXL_DASSERT(Rect(0, 0, r.xsize(), r.ysize()).IsInside(ch_in.plane));
576     for (size_t y = 0; y < r.ysize(); ++y) {
577       float* const JXL_RESTRICT row_out =
578           r.Row(&dec_state->extra_channels[ec], y);
579       const pixel_type* const JXL_RESTRICT row_in = ch_in.Row(y);
580       if (fp) {
581         int_to_float(row_in, row_out, r.xsize(), bits, exp_bits);
582       } else {
583         for (size_t x = 0; x < r.xsize(); ++x) {
584           row_out[x] = row_in[x] * mul;
585         }
586       }
587     }
588     JXL_CHECK_IMAGE_INITIALIZED(dec_state->extra_channels[ec], r);
589   }
590   return true;
591 }
592 
FinalizeDecoding(PassesDecoderState * dec_state,jxl::ThreadPool * pool,ImageBundle * output)593 Status ModularFrameDecoder::FinalizeDecoding(PassesDecoderState* dec_state,
594                                              jxl::ThreadPool* pool,
595                                              ImageBundle* output) {
596   if (!use_full_image) return true;
597   Image& gi = full_image;
598   size_t xsize = gi.w;
599   size_t ysize = gi.h;
600 
601   // Don't use threads if total image size is smaller than a group
602   if (xsize * ysize < frame_dim.group_dim * frame_dim.group_dim) pool = nullptr;
603 
604   // Undo the global transforms
605   gi.undo_transforms(global_header.wp_header, -1, pool);
606   for (auto t : global_transform) {
607     JXL_RETURN_IF_ERROR(t.Inverse(gi, global_header.wp_header));
608   }
609   if (gi.error) return JXL_FAILURE("Undoing transforms failed");
610 
611   auto& decoded = dec_state->decoded;
612 
613   JXL_RETURN_IF_ERROR(
614       ModularImageToDecodedRect(gi, dec_state, pool, output, Rect(decoded)));
615   return true;
616 }
617 
618 static constexpr const float kAlmostZero = 1e-8f;
619 
DecodeQuantTable(size_t required_size_x,size_t required_size_y,BitReader * br,QuantEncoding * encoding,size_t idx,ModularFrameDecoder * modular_frame_decoder)620 Status ModularFrameDecoder::DecodeQuantTable(
621     size_t required_size_x, size_t required_size_y, BitReader* br,
622     QuantEncoding* encoding, size_t idx,
623     ModularFrameDecoder* modular_frame_decoder) {
624   JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->qraw.qtable_den));
625   if (encoding->qraw.qtable_den < kAlmostZero) {
626     // qtable[] values are already checked for <= 0 so the denominator may not
627     // be negative.
628     return JXL_FAILURE("Invalid qtable_den: value too small");
629   }
630   Image image(required_size_x, required_size_y, 8, 3);
631   ModularOptions options;
632   if (modular_frame_decoder) {
633     JXL_RETURN_IF_ERROR(ModularGenericDecompress(
634         br, image, /*header=*/nullptr,
635         ModularStreamId::QuantTable(idx).ID(modular_frame_decoder->frame_dim),
636         &options, /*undo_transforms=*/-1, &modular_frame_decoder->tree,
637         &modular_frame_decoder->code, &modular_frame_decoder->context_map));
638   } else {
639     JXL_RETURN_IF_ERROR(ModularGenericDecompress(br, image, /*header=*/nullptr,
640                                                  0, &options,
641                                                  /*undo_transforms=*/-1));
642   }
643   if (!encoding->qraw.qtable) {
644     encoding->qraw.qtable = new std::vector<int>();
645   }
646   encoding->qraw.qtable->resize(required_size_x * required_size_y * 3);
647   for (size_t c = 0; c < 3; c++) {
648     for (size_t y = 0; y < required_size_y; y++) {
649       int* JXL_RESTRICT row = image.channel[c].Row(y);
650       for (size_t x = 0; x < required_size_x; x++) {
651         (*encoding->qraw.qtable)[c * required_size_x * required_size_y +
652                                  y * required_size_x + x] = row[x];
653         if (row[x] <= 0) {
654           return JXL_FAILURE("Invalid raw quantization table");
655         }
656       }
657     }
658   }
659   return true;
660 }
661 
662 }  // namespace jxl
663 #endif  // HWY_ONCE
664