1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/dec_modular.h"
7 
8 #include <stdint.h>
9 
10 #include <vector>
11 
12 #include "lib/jxl/frame_header.h"
13 
14 #undef HWY_TARGET_INCLUDE
15 #define HWY_TARGET_INCLUDE "lib/jxl/dec_modular.cc"
16 #include <hwy/foreach_target.h>
17 #include <hwy/highway.h>
18 
19 #include "lib/jxl/alpha.h"
20 #include "lib/jxl/base/compiler_specific.h"
21 #include "lib/jxl/base/span.h"
22 #include "lib/jxl/base/status.h"
23 #include "lib/jxl/compressed_dc.h"
24 #include "lib/jxl/epf.h"
25 #include "lib/jxl/modular/encoding/encoding.h"
26 #include "lib/jxl/modular/modular_image.h"
27 HWY_BEFORE_NAMESPACE();
28 namespace jxl {
29 namespace HWY_NAMESPACE {
30 
31 // These templates are not found via ADL.
32 using hwy::HWY_NAMESPACE::Rebind;
33 
MultiplySum(const size_t xsize,const pixel_type * const JXL_RESTRICT row_in,const pixel_type * const JXL_RESTRICT row_in_Y,const float factor,float * const JXL_RESTRICT row_out)34 void MultiplySum(const size_t xsize,
35                  const pixel_type* const JXL_RESTRICT row_in,
36                  const pixel_type* const JXL_RESTRICT row_in_Y,
37                  const float factor, float* const JXL_RESTRICT row_out) {
38   const HWY_FULL(float) df;
39   const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
40   const auto factor_v = Set(df, factor);
41   for (size_t x = 0; x < xsize; x += Lanes(di)) {
42     const auto in = Load(di, row_in + x) + Load(di, row_in_Y + x);
43     const auto out = ConvertTo(df, in) * factor_v;
44     Store(out, df, row_out + x);
45   }
46 }
47 
RgbFromSingle(const size_t xsize,const pixel_type * const JXL_RESTRICT row_in,const float factor,Image3F * decoded,size_t,size_t y)48 void RgbFromSingle(const size_t xsize,
49                    const pixel_type* const JXL_RESTRICT row_in,
50                    const float factor, Image3F* decoded, size_t /*c*/,
51                    size_t y) {
52   const HWY_FULL(float) df;
53   const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
54 
55   float* const JXL_RESTRICT row_out_r = decoded->PlaneRow(0, y);
56   float* const JXL_RESTRICT row_out_g = decoded->PlaneRow(1, y);
57   float* const JXL_RESTRICT row_out_b = decoded->PlaneRow(2, y);
58 
59   const auto factor_v = Set(df, factor);
60   for (size_t x = 0; x < xsize; x += Lanes(di)) {
61     const auto in = Load(di, row_in + x);
62     const auto out = ConvertTo(df, in) * factor_v;
63     Store(out, df, row_out_r + x);
64     Store(out, df, row_out_g + x);
65     Store(out, df, row_out_b + x);
66   }
67 }
68 
69 // Same signature as RgbFromSingle so we can assign to the same pointer.
SingleFromSingle(const size_t xsize,const pixel_type * const JXL_RESTRICT row_in,const float factor,Image3F * decoded,size_t c,size_t y)70 void SingleFromSingle(const size_t xsize,
71                       const pixel_type* const JXL_RESTRICT row_in,
72                       const float factor, Image3F* decoded, size_t c,
73                       size_t y) {
74   const HWY_FULL(float) df;
75   const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
76 
77   float* const JXL_RESTRICT row_out = decoded->PlaneRow(c, y);
78 
79   const auto factor_v = Set(df, factor);
80   for (size_t x = 0; x < xsize; x += Lanes(di)) {
81     const auto in = Load(di, row_in + x);
82     const auto out = ConvertTo(df, in) * factor_v;
83     Store(out, df, row_out + x);
84   }
85 }
86 // NOLINTNEXTLINE(google-readability-namespace-comments)
87 }  // namespace HWY_NAMESPACE
88 }  // namespace jxl
89 HWY_AFTER_NAMESPACE();
90 
91 #if HWY_ONCE
92 namespace jxl {
// Dynamic-dispatch tables for the SIMD row helpers defined above; they are
// invoked through HWY_DYNAMIC_DISPATCH in FinalizeDecoding.
HWY_EXPORT(MultiplySum);       // Local function
HWY_EXPORT(RgbFromSingle);     // Local function
HWY_EXPORT(SingleFromSingle);  // Local function
96 
97 // convert custom [bits]-bit float (with [exp_bits] exponent bits) stored as int
98 // back to binary32 float
int_to_float(const pixel_type * const JXL_RESTRICT row_in,float * const JXL_RESTRICT row_out,const size_t xsize,const int bits,const int exp_bits)99 void int_to_float(const pixel_type* const JXL_RESTRICT row_in,
100                   float* const JXL_RESTRICT row_out, const size_t xsize,
101                   const int bits, const int exp_bits) {
102   if (bits == 32) {
103     JXL_ASSERT(sizeof(pixel_type) == sizeof(float));
104     JXL_ASSERT(exp_bits == 8);
105     memcpy(row_out, row_in, xsize * sizeof(float));
106     return;
107   }
108   int exp_bias = (1 << (exp_bits - 1)) - 1;
109   int sign_shift = bits - 1;
110   int mant_bits = bits - exp_bits - 1;
111   int mant_shift = 23 - mant_bits;
112   for (size_t x = 0; x < xsize; ++x) {
113     uint32_t f;
114     memcpy(&f, &row_in[x], 4);
115     int signbit = (f >> sign_shift);
116     f &= (1 << sign_shift) - 1;
117     if (f == 0) {
118       row_out[x] = (signbit ? -0.f : 0.f);
119       continue;
120     }
121     int exp = (f >> mant_bits);
122     int mantissa = (f & ((1 << mant_bits) - 1));
123     mantissa <<= mant_shift;
124     // Try to normalize only if there is space for maneuver.
125     if (exp == 0 && exp_bits < 8) {
126       // subnormal number
127       while ((mantissa & 0x800000) == 0) {
128         mantissa <<= 1;
129         exp--;
130       }
131       exp++;
132       // remove leading 1 because it is implicit now
133       mantissa &= 0x7fffff;
134     }
135     exp -= exp_bias;
136     // broke up the arbitrary float into its parts, now reassemble into
137     // binary32
138     exp += 127;
139     JXL_ASSERT(exp >= 0);
140     f = (signbit ? 0x80000000 : 0);
141     f |= (exp << 23);
142     f |= mantissa;
143     memcpy(&row_out[x], &f, 4);
144   }
145 }
146 
DecodeGlobalInfo(BitReader * reader,const FrameHeader & frame_header,bool allow_truncated_group)147 Status ModularFrameDecoder::DecodeGlobalInfo(BitReader* reader,
148                                              const FrameHeader& frame_header,
149                                              bool allow_truncated_group) {
150   bool decode_color = frame_header.encoding == FrameEncoding::kModular;
151   const auto& metadata = frame_header.nonserialized_metadata->m;
152   bool is_gray = metadata.color_encoding.IsGray();
153   size_t nb_chans = 3;
154   if (is_gray && frame_header.color_transform == ColorTransform::kNone) {
155     nb_chans = 1;
156   }
157   bool has_tree = reader->ReadBits(1);
158   if (has_tree) {
159     size_t tree_size_limit =
160         1024 + frame_dim.xsize * frame_dim.ysize * nb_chans;
161     JXL_RETURN_IF_ERROR(DecodeTree(reader, &tree, tree_size_limit));
162     JXL_RETURN_IF_ERROR(
163         DecodeHistograms(reader, (tree.size() + 1) / 2, &code, &context_map));
164   }
165   do_color = decode_color;
166   if (!do_color) nb_chans = 0;
167   size_t nb_extra = metadata.extra_channel_info.size();
168 
169   bool fp = metadata.bit_depth.floating_point_sample;
170 
171   // bits_per_sample is just metadata for XYB images.
172   if (metadata.bit_depth.bits_per_sample >= 32 && do_color &&
173       frame_header.color_transform != ColorTransform::kXYB) {
174     if (metadata.bit_depth.bits_per_sample == 32 && fp == false) {
175       // TODO(lode): does modular support uint32_t? maxval is signed int so
176       // cannot represent 32 bits.
177       return JXL_FAILURE("uint32_t not supported in dec_modular");
178     } else if (metadata.bit_depth.bits_per_sample > 32) {
179       return JXL_FAILURE("bits_per_sample > 32 not supported");
180     }
181   }
182   // TODO(lode): must handle metadata.floating_point_channel?
183   int maxval =
184       (fp ? 1
185           : (1u << static_cast<uint32_t>(metadata.bit_depth.bits_per_sample)) -
186                 1);
187 
188   Image gi(frame_dim.xsize, frame_dim.ysize, maxval, nb_chans + nb_extra);
189 
190   if (frame_header.color_transform == ColorTransform::kYCbCr) {
191     for (size_t c = 0; c < nb_chans; c++) {
192       gi.channel[c].hshift = frame_header.chroma_subsampling.HShift(c);
193       gi.channel[c].vshift = frame_header.chroma_subsampling.VShift(c);
194       size_t xsize_shifted = DivCeil(frame_dim.xsize, 1 << gi.channel[c].hshift);
195       size_t ysize_shifted = DivCeil(frame_dim.ysize, 1 << gi.channel[c].vshift);
196       gi.channel[c].resize(xsize_shifted, ysize_shifted);
197     }
198   }
199 
200   for (size_t ec = 0, c = nb_chans; ec < nb_extra; ec++, c++) {
201     size_t ecups = frame_header.extra_channel_upsampling[ec];
202     gi.channel[c].resize(DivCeil(frame_dim.xsize_upsampled, ecups),
203                          DivCeil(frame_dim.ysize_upsampled, ecups));
204     gi.channel[c].hshift = gi.channel[c].vshift =
205         CeilLog2Nonzero(ecups) - CeilLog2Nonzero(frame_header.upsampling);
206   }
207 
208   ModularOptions options;
209   options.max_chan_size = frame_dim.group_dim;
210   Status dec_status = ModularGenericDecompress(
211       reader, gi, &global_header, ModularStreamId::Global().ID(frame_dim),
212       &options,
213       /*undo_transforms=*/-2, &tree, &code, &context_map,
214       allow_truncated_group);
215   if (!allow_truncated_group) JXL_RETURN_IF_ERROR(dec_status);
216   if (dec_status.IsFatalError()) {
217     return JXL_FAILURE("Failed to decode global modular info");
218   }
219 
220   // TODO(eustas): are we sure this can be done after partial decode?
221   // ensure all the channel buffers are allocated
222   have_something = false;
223   for (size_t c = 0; c < gi.channel.size(); c++) {
224     Channel& gic = gi.channel[c];
225     if (c >= gi.nb_meta_channels && gic.w < frame_dim.group_dim &&
226         gic.h < frame_dim.group_dim)
227       have_something = true;
228     gic.resize();
229   }
230   full_image = std::move(gi);
231   return dec_status;
232 }
233 
DecodeGroup(const Rect & rect,BitReader * reader,int minShift,int maxShift,const ModularStreamId & stream,bool zerofill)234 Status ModularFrameDecoder::DecodeGroup(const Rect& rect, BitReader* reader,
235                                         int minShift, int maxShift,
236                                         const ModularStreamId& stream,
237                                         bool zerofill) {
238   JXL_DASSERT(stream.kind == ModularStreamId::kModularDC ||
239               stream.kind == ModularStreamId::kModularAC);
240   const size_t xsize = rect.xsize();
241   const size_t ysize = rect.ysize();
242   int maxval = full_image.maxval;
243   Image gi(xsize, ysize, maxval, 0);
244   // start at the first bigger-than-groupsize non-metachannel
245   size_t c = full_image.nb_meta_channels;
246   for (; c < full_image.channel.size(); c++) {
247     Channel& fc = full_image.channel[c];
248     if (fc.w > frame_dim.group_dim || fc.h > frame_dim.group_dim) break;
249   }
250   size_t beginc = c;
251   for (; c < full_image.channel.size(); c++) {
252     Channel& fc = full_image.channel[c];
253     int shift = std::min(fc.hshift, fc.vshift);
254     if (shift > maxShift) continue;
255     if (shift < minShift) continue;
256     Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift,
257            rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h);
258     if (r.xsize() == 0 || r.ysize() == 0) continue;
259     Channel gc(r.xsize(), r.ysize());
260     gc.hshift = fc.hshift;
261     gc.vshift = fc.vshift;
262     gi.channel.emplace_back(std::move(gc));
263   }
264   gi.nb_channels = gi.channel.size();
265   gi.real_nb_channels = gi.nb_channels;
266   if (zerofill) {
267     int gic = 0;
268     for (c = beginc; c < full_image.channel.size(); c++) {
269       Channel& fc = full_image.channel[c];
270       int shift = std::min(fc.hshift, fc.vshift);
271       if (shift > maxShift) continue;
272       if (shift < minShift) continue;
273       Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift,
274              rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h);
275       if (r.xsize() == 0 || r.ysize() == 0) continue;
276       for (size_t y = 0; y < r.ysize(); ++y) {
277         pixel_type* const JXL_RESTRICT row_out = r.Row(&fc.plane, y);
278         memset(row_out, 0, r.xsize() * sizeof(*row_out));
279       }
280       gic++;
281     }
282     return true;
283   }
284   ModularOptions options;
285   if (!ModularGenericDecompress(
286           reader, gi, /*header=*/nullptr, stream.ID(frame_dim), &options,
287           /*undo_transforms=*/-1, &tree, &code, &context_map))
288     return JXL_FAILURE("Failed to decode modular group");
289   int gic = 0;
290   for (c = beginc; c < full_image.channel.size(); c++) {
291     Channel& fc = full_image.channel[c];
292     int shift = std::min(fc.hshift, fc.vshift);
293     if (shift > maxShift) continue;
294     if (shift < minShift) continue;
295     Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift,
296            rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h);
297     if (r.xsize() == 0 || r.ysize() == 0) continue;
298     for (size_t y = 0; y < r.ysize(); ++y) {
299       pixel_type* const JXL_RESTRICT row_out = r.Row(&fc.plane, y);
300       const pixel_type* const JXL_RESTRICT row_in = gi.channel[gic].Row(y);
301       for (size_t x = 0; x < r.xsize(); ++x) {
302         row_out[x] = row_in[x];
303       }
304     }
305     gic++;
306   }
307   return true;
308 }
DecodeVarDCTDC(size_t group_id,BitReader * reader,PassesDecoderState * dec_state)309 Status ModularFrameDecoder::DecodeVarDCTDC(size_t group_id, BitReader* reader,
310                                            PassesDecoderState* dec_state) {
311   const Rect r = dec_state->shared->DCGroupRect(group_id);
312   // TODO(eustas): investigate if we could reduce the impact of
313   //               EvalRationalPolynomial; generally speaking, the limit is
314   //               2**(128/(3*magic)), where 128 comes from IEEE 754 exponent,
315   //               3 comes from XybToRgb that cubes the values, and "magic" is
316   //               the sum of all other contributions. 2**18 is known to lead
317   //               to NaN on input found by fuzzing (see commit message).
318   constexpr const int kRawDcLimit = 1 << 17;
319   Image image(r.xsize(), r.ysize(), kRawDcLimit, 3);
320   image.minval = -kRawDcLimit;
321   size_t stream_id = ModularStreamId::VarDCTDC(group_id).ID(frame_dim);
322   reader->Refill();
323   size_t extra_precision = reader->ReadFixedBits<2>();
324   float mul = 1.0f / (1 << extra_precision);
325   ModularOptions options;
326   for (size_t c = 0; c < 3; c++) {
327     Channel& ch = image.channel[c < 2 ? c ^ 1 : c];
328     ch.w >>= dec_state->shared->frame_header.chroma_subsampling.HShift(c);
329     ch.h >>= dec_state->shared->frame_header.chroma_subsampling.VShift(c);
330     ch.resize();
331   }
332   if (!ModularGenericDecompress(
333           reader, image, /*header=*/nullptr, stream_id, &options,
334           /*undo_transforms=*/0, &tree, &code, &context_map)) {
335     return JXL_FAILURE("Failed to decode modular DC group");
336   }
337   DequantDC(r, &dec_state->shared_storage.dc_storage,
338             &dec_state->shared_storage.quant_dc, image,
339             dec_state->shared->quantizer.MulDC(), mul,
340             dec_state->shared->cmap.DCFactors(),
341             dec_state->shared->frame_header.chroma_subsampling,
342             dec_state->shared->block_ctx_map);
343   return true;
344 }
345 
DecodeAcMetadata(size_t group_id,BitReader * reader,PassesDecoderState * dec_state)346 Status ModularFrameDecoder::DecodeAcMetadata(size_t group_id, BitReader* reader,
347                                              PassesDecoderState* dec_state) {
348   const Rect r = dec_state->shared->DCGroupRect(group_id);
349   size_t upper_bound = r.xsize() * r.ysize();
350   reader->Refill();
351   size_t count = reader->ReadBits(CeilLog2Nonzero(upper_bound)) + 1;
352   size_t stream_id = ModularStreamId::ACMetadata(group_id).ID(frame_dim);
353   // YToX, YToB, ACS + QF, EPF
354   Image image(r.xsize(), r.ysize(), 255, 4);
355   static_assert(kColorTileDimInBlocks == 8, "Color tile size changed");
356   Rect cr(r.x0() >> 3, r.y0() >> 3, (r.xsize() + 7) >> 3, (r.ysize() + 7) >> 3);
357   image.channel[0] = Channel(cr.xsize(), cr.ysize(), 3, 3);
358   image.channel[1] = Channel(cr.xsize(), cr.ysize(), 3, 3);
359   image.channel[2] = Channel(count, 2, 0, 0);
360   ModularOptions options;
361   if (!ModularGenericDecompress(
362           reader, image, /*header=*/nullptr, stream_id, &options,
363           /*undo_transforms=*/-1, &tree, &code, &context_map)) {
364     return JXL_FAILURE("Failed to decode AC metadata");
365   }
366   ConvertPlaneAndClamp(Rect(image.channel[0].plane), image.channel[0].plane, cr,
367                        &dec_state->shared_storage.cmap.ytox_map);
368   ConvertPlaneAndClamp(Rect(image.channel[1].plane), image.channel[1].plane, cr,
369                        &dec_state->shared_storage.cmap.ytob_map);
370   size_t num = 0;
371   bool is444 = dec_state->shared->frame_header.chroma_subsampling.Is444();
372   auto& ac_strategy = dec_state->shared_storage.ac_strategy;
373   size_t xlim = std::min(ac_strategy.xsize(), r.x0() + r.xsize());
374   size_t ylim = std::min(ac_strategy.ysize(), r.y0() + r.ysize());
375   uint32_t local_used_acs = 0;
376   for (size_t iy = 0; iy < r.ysize(); iy++) {
377     size_t y = r.y0() + iy;
378     int* row_qf = r.Row(&dec_state->shared_storage.raw_quant_field, iy);
379     uint8_t* row_epf = r.Row(&dec_state->shared_storage.epf_sharpness, iy);
380     int* row_in_1 = image.channel[2].plane.Row(0);
381     int* row_in_2 = image.channel[2].plane.Row(1);
382     int* row_in_3 = image.channel[3].plane.Row(iy);
383     for (size_t ix = 0; ix < r.xsize(); ix++) {
384       size_t x = r.x0() + ix;
385       int sharpness = row_in_3[ix];
386       if (sharpness < 0 || sharpness >= LoopFilter::kEpfSharpEntries) {
387         return JXL_FAILURE("Corrupted sharpness field");
388       }
389       row_epf[ix] = sharpness;
390       if (ac_strategy.IsValid(x, y)) {
391         continue;
392       }
393 
394       if (num >= count) return JXL_FAILURE("Corrupted stream");
395 
396       if (!AcStrategy::IsRawStrategyValid(row_in_1[num])) {
397         return JXL_FAILURE("Invalid AC strategy");
398       }
399       local_used_acs |= 1u << row_in_1[num];
400       AcStrategy acs = AcStrategy::FromRawStrategy(row_in_1[num]);
401       if ((acs.covered_blocks_x() > 1 || acs.covered_blocks_y() > 1) &&
402           !is444) {
403         return JXL_FAILURE(
404             "AC strategy not compatible with chroma subsampling");
405       }
406       // Ensure that blocks do not overflow *AC* groups.
407       size_t next_x_ac_block = (x / kGroupDimInBlocks + 1) * kGroupDimInBlocks;
408       size_t next_y_ac_block = (y / kGroupDimInBlocks + 1) * kGroupDimInBlocks;
409       size_t next_x_dct_block = x + acs.covered_blocks_x();
410       size_t next_y_dct_block = y + acs.covered_blocks_y();
411       if (next_x_dct_block > next_x_ac_block || next_x_dct_block > xlim) {
412         return JXL_FAILURE("Invalid AC strategy, x overflow");
413       }
414       if (next_y_dct_block > next_y_ac_block || next_y_dct_block > ylim) {
415         return JXL_FAILURE("Invalid AC strategy, y overflow");
416       }
417       JXL_RETURN_IF_ERROR(
418           ac_strategy.SetNoBoundsCheck(x, y, AcStrategy::Type(row_in_1[num])));
419       row_qf[ix] =
420           1 + std::max(0, std::min(Quantizer::kQuantMax - 1, row_in_2[num]));
421       num++;
422     }
423   }
424   dec_state->used_acs |= local_used_acs;
425   if (dec_state->shared->frame_header.loop_filter.epf_iters > 0) {
426     ComputeSigma(r, dec_state);
427   }
428   return true;
429 }
430 
FinalizeDecoding(PassesDecoderState * dec_state,jxl::ThreadPool * pool,ImageBundle * output)431 Status ModularFrameDecoder::FinalizeDecoding(PassesDecoderState* dec_state,
432                                              jxl::ThreadPool* pool,
433                                              ImageBundle* output) {
434   Image& gi = full_image;
435   size_t xsize = gi.w;
436   size_t ysize = gi.h;
437 
438   const auto& frame_header = dec_state->shared->frame_header;
439   const auto* metadata = frame_header.nonserialized_metadata;
440 
441   // Don't use threads if total image size is smaller than a group
442   if (xsize * ysize < frame_dim.group_dim * frame_dim.group_dim) pool = nullptr;
443 
444   // Undo the global transforms
445   gi.undo_transforms(global_header.wp_header, -1, pool);
446   if (gi.error) return JXL_FAILURE("Undoing transforms failed");
447 
448   auto& decoded = dec_state->decoded;
449 
450   int c = 0;
451   if (do_color) {
452     const bool rgb_from_gray =
453         metadata->m.color_encoding.IsGray() &&
454         frame_header.color_transform == ColorTransform::kNone;
455     const bool fp = metadata->m.bit_depth.floating_point_sample;
456 
457     for (; c < 3; c++) {
458       float factor = 1.f / (float)full_image.maxval;
459       int c_in = c;
460       if (frame_header.color_transform == ColorTransform::kXYB) {
461         factor = dec_state->shared->matrices.DCQuants()[c];
462         // XYB is encoded as YX(B-Y)
463         if (c < 2) c_in = 1 - c;
464       } else if (rgb_from_gray) {
465         c_in = 0;
466       }
467       // TODO(eustas): could we detect it on earlier stage?
468       if (gi.channel[c_in].w == 0 || gi.channel[c_in].h == 0) {
469         return JXL_FAILURE("Empty image");
470       }
471       size_t xsize_shifted = DivCeil(xsize, 1 << gi.channel[c_in].hshift);
472       size_t ysize_shifted = DivCeil(ysize, 1 << gi.channel[c_in].vshift);
473       if (ysize_shifted != gi.channel[c_in].h || xsize_shifted != gi.channel[c_in].w) {
474             return JXL_FAILURE("Dimension mismatch");
475       }
476       if (frame_header.color_transform == ColorTransform::kXYB && c == 2) {
477         JXL_ASSERT(!fp);
478         RunOnPool(
479             pool, 0, ysize_shifted, jxl::ThreadPool::SkipInit(),
480             [&](const int task, const int thread) {
481               const size_t y = task;
482               const pixel_type* const JXL_RESTRICT row_in =
483                   gi.channel[c_in].Row(y);
484               const pixel_type* const JXL_RESTRICT row_in_Y =
485                   gi.channel[0].Row(y);
486               float* const JXL_RESTRICT row_out = decoded.PlaneRow(c, y);
487               HWY_DYNAMIC_DISPATCH(MultiplySum)
488               (xsize_shifted, row_in, row_in_Y, factor, row_out);
489             },
490             "ModularIntToFloat");
491       } else if (fp) {
492         int bits = metadata->m.bit_depth.bits_per_sample;
493         int exp_bits = metadata->m.bit_depth.exponent_bits_per_sample;
494         RunOnPool(
495             pool, 0, ysize_shifted, jxl::ThreadPool::SkipInit(),
496             [&](const int task, const int thread) {
497               const size_t y = task;
498               const pixel_type* const JXL_RESTRICT row_in =
499                   gi.channel[c_in].Row(y);
500               float* const JXL_RESTRICT row_out = decoded.PlaneRow(c, y);
501               int_to_float(row_in, row_out, xsize_shifted, bits, exp_bits);
502             },
503             "ModularIntToFloat_losslessfloat");
504       } else {
505         RunOnPool(
506             pool, 0, ysize_shifted, jxl::ThreadPool::SkipInit(),
507             [&](const int task, const int thread) {
508               const size_t y = task;
509               const pixel_type* const JXL_RESTRICT row_in =
510                   gi.channel[c_in].Row(y);
511               if (rgb_from_gray) {
512                 HWY_DYNAMIC_DISPATCH(RgbFromSingle)
513                 (xsize_shifted, row_in, factor, &decoded, c, y);
514               } else {
515                 HWY_DYNAMIC_DISPATCH(SingleFromSingle)
516                 (xsize_shifted, row_in, factor, &decoded, c, y);
517               }
518             },
519             "ModularIntToFloat");
520       }
521       if (rgb_from_gray) {
522         break;
523       }
524     }
525     if (rgb_from_gray) {
526       c = 1;
527     }
528   }
529   for (size_t ec = 0; ec < dec_state->extra_channels.size(); ec++, c++) {
530     const ExtraChannelInfo& eci = output->metadata()->extra_channel_info[ec];
531     int bits = eci.bit_depth.bits_per_sample;
532     int exp_bits = eci.bit_depth.exponent_bits_per_sample;
533     bool fp = eci.bit_depth.floating_point_sample;
534     JXL_ASSERT(fp || bits < 32);
535     const float mul = fp ? 0 : (1.0f / ((1u << bits) - 1));
536     size_t ecups = frame_header.extra_channel_upsampling[ec];
537     const size_t ec_xsize = DivCeil(frame_dim.xsize_upsampled, ecups);
538     const size_t ec_ysize = DivCeil(frame_dim.ysize_upsampled, ecups);
539     for (size_t y = 0; y < ec_ysize; ++y) {
540       float* const JXL_RESTRICT row_out = dec_state->extra_channels[ec].Row(y);
541       const pixel_type* const JXL_RESTRICT row_in = gi.channel[c].Row(y);
542       if (fp) {
543         int_to_float(row_in, row_out, ec_xsize, bits, exp_bits);
544       } else {
545         for (size_t x = 0; x < ec_xsize; ++x) {
546           row_out[x] = row_in[x] * mul;
547         }
548       }
549     }
550   }
551   return true;
552 }
553 
554 static constexpr const float kAlmostZero = 1e-8f;
555 
DecodeQuantTable(size_t required_size_x,size_t required_size_y,BitReader * br,QuantEncoding * encoding,size_t idx,ModularFrameDecoder * modular_frame_decoder)556 Status ModularFrameDecoder::DecodeQuantTable(
557     size_t required_size_x, size_t required_size_y, BitReader* br,
558     QuantEncoding* encoding, size_t idx,
559     ModularFrameDecoder* modular_frame_decoder) {
560   JXL_RETURN_IF_ERROR(F16Coder::Read(br, &encoding->qraw.qtable_den));
561   if (encoding->qraw.qtable_den < kAlmostZero) {
562     // qtable[] values are already checked for <= 0 so the denominator may not
563     // be negative.
564     return JXL_FAILURE("Invalid qtable_den: value too small");
565   }
566   Image image(required_size_x, required_size_y, 255, 3);
567   ModularOptions options;
568   if (modular_frame_decoder) {
569     JXL_RETURN_IF_ERROR(ModularGenericDecompress(
570         br, image, /*header=*/nullptr,
571         ModularStreamId::QuantTable(idx).ID(modular_frame_decoder->frame_dim),
572         &options, /*undo_transforms=*/-1, &modular_frame_decoder->tree,
573         &modular_frame_decoder->code, &modular_frame_decoder->context_map));
574   } else {
575     JXL_RETURN_IF_ERROR(ModularGenericDecompress(br, image, /*header=*/nullptr,
576                                                  0, &options,
577                                                  /*undo_transforms=*/-1));
578   }
579   if (!encoding->qraw.qtable) {
580     encoding->qraw.qtable = new std::vector<int>();
581   }
582   encoding->qraw.qtable->resize(required_size_x * required_size_y * 3);
583   for (size_t c = 0; c < 3; c++) {
584     for (size_t y = 0; y < required_size_y; y++) {
585       int* JXL_RESTRICT row = image.channel[c].Row(y);
586       for (size_t x = 0; x < required_size_x; x++) {
587         (*encoding->qraw.qtable)[c * required_size_x * required_size_y +
588                                  y * required_size_x + x] = row[x];
589         if (row[x] <= 0) {
590           return JXL_FAILURE("Invalid raw quantization table");
591         }
592       }
593     }
594   }
595   return true;
596 }
597 
598 }  // namespace jxl
599 #endif  // HWY_ONCE
600