1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/dec_frame.h"
7 
8 #include <stddef.h>
9 #include <stdint.h>
10 
11 #include <algorithm>
12 #include <atomic>
13 #include <hwy/aligned_allocator.h>
14 #include <numeric>
15 #include <utility>
16 #include <vector>
17 
18 #include "lib/jxl/ac_context.h"
19 #include "lib/jxl/ac_strategy.h"
20 #include "lib/jxl/ans_params.h"
21 #include "lib/jxl/base/bits.h"
22 #include "lib/jxl/base/compiler_specific.h"
23 #include "lib/jxl/base/data_parallel.h"
24 #include "lib/jxl/base/profiler.h"
25 #include "lib/jxl/base/status.h"
26 #include "lib/jxl/chroma_from_luma.h"
27 #include "lib/jxl/coeff_order.h"
28 #include "lib/jxl/coeff_order_fwd.h"
29 #include "lib/jxl/color_management.h"
30 #include "lib/jxl/common.h"
31 #include "lib/jxl/compressed_dc.h"
32 #include "lib/jxl/dec_ans.h"
33 #include "lib/jxl/dec_bit_reader.h"
34 #include "lib/jxl/dec_cache.h"
35 #include "lib/jxl/dec_group.h"
36 #include "lib/jxl/dec_modular.h"
37 #include "lib/jxl/dec_params.h"
38 #include "lib/jxl/dec_patch_dictionary.h"
39 #include "lib/jxl/dec_reconstruct.h"
40 #include "lib/jxl/dec_upsample.h"
41 #include "lib/jxl/dec_xyb.h"
42 #include "lib/jxl/fields.h"
43 #include "lib/jxl/filters.h"
44 #include "lib/jxl/frame_header.h"
45 #include "lib/jxl/image.h"
46 #include "lib/jxl/image_bundle.h"
47 #include "lib/jxl/image_ops.h"
48 #include "lib/jxl/jpeg/jpeg_data.h"
49 #include "lib/jxl/loop_filter.h"
50 #include "lib/jxl/luminance.h"
51 #include "lib/jxl/passes_state.h"
52 #include "lib/jxl/quant_weights.h"
53 #include "lib/jxl/quantizer.h"
54 #include "lib/jxl/splines.h"
55 #include "lib/jxl/toc.h"
56 
57 namespace jxl {
58 
59 namespace {
DecodeGlobalDCInfo(BitReader * reader,bool is_jpeg,PassesDecoderState * state,ThreadPool * pool)60 Status DecodeGlobalDCInfo(BitReader* reader, bool is_jpeg,
61                           PassesDecoderState* state, ThreadPool* pool) {
62   PROFILER_FUNC;
63   JXL_RETURN_IF_ERROR(state->shared_storage.quantizer.Decode(reader));
64 
65   JXL_RETURN_IF_ERROR(
66       DecodeBlockCtxMap(reader, &state->shared_storage.block_ctx_map));
67 
68   JXL_RETURN_IF_ERROR(state->shared_storage.cmap.DecodeDC(reader));
69 
70   // Pre-compute info for decoding a group.
71   if (is_jpeg) {
72     state->shared_storage.quantizer.ClearDCMul();  // Don't dequant DC
73   }
74 
75   state->shared_storage.ac_strategy.FillInvalid();
76   return true;
77 }
78 }  // namespace
79 
DecodeFrameHeader(BitReader * JXL_RESTRICT reader,FrameHeader * JXL_RESTRICT frame_header)80 Status DecodeFrameHeader(BitReader* JXL_RESTRICT reader,
81                          FrameHeader* JXL_RESTRICT frame_header) {
82   JXL_ASSERT(frame_header->nonserialized_metadata != nullptr);
83   JXL_RETURN_IF_ERROR(ReadFrameHeader(reader, frame_header));
84   return true;
85 }
86 
SkipFrame(const CodecMetadata & metadata,BitReader * JXL_RESTRICT reader,bool is_preview)87 Status SkipFrame(const CodecMetadata& metadata, BitReader* JXL_RESTRICT reader,
88                  bool is_preview) {
89   FrameHeader header(&metadata);
90   header.nonserialized_is_preview = is_preview;
91   JXL_RETURN_IF_ERROR(DecodeFrameHeader(reader, &header));
92 
93   // Read TOC.
94   std::vector<uint64_t> group_offsets;
95   std::vector<uint32_t> group_sizes;
96   uint64_t groups_total_size;
97   const bool has_ac_global = true;
98   const FrameDimensions frame_dim = header.ToFrameDimensions();
99   const size_t toc_entries =
100       NumTocEntries(frame_dim.num_groups, frame_dim.num_dc_groups,
101                     header.passes.num_passes, has_ac_global);
102   JXL_RETURN_IF_ERROR(ReadGroupOffsets(toc_entries, reader, &group_offsets,
103                                        &group_sizes, &groups_total_size));
104 
105   // Pretend all groups are read.
106   reader->SkipBits(groups_total_size * kBitsPerByte);
107   if (reader->TotalBitsConsumed() > reader->TotalBytes() * kBitsPerByte) {
108     return JXL_FAILURE("Group code extends after stream end");
109   }
110 
111   return true;
112 }
113 
GetReaderForSection(size_t num_groups,size_t num_passes,size_t group_codes_begin,const std::vector<uint64_t> & group_offsets,const std::vector<uint32_t> & group_sizes,BitReader * JXL_RESTRICT reader,BitReader * JXL_RESTRICT store,size_t index)114 static BitReader* GetReaderForSection(
115     size_t num_groups, size_t num_passes, size_t group_codes_begin,
116     const std::vector<uint64_t>& group_offsets,
117     const std::vector<uint32_t>& group_sizes, BitReader* JXL_RESTRICT reader,
118     BitReader* JXL_RESTRICT store, size_t index) {
119   if (num_groups == 1 && num_passes == 1) return reader;
120   const size_t group_offset = group_codes_begin + group_offsets[index];
121   const size_t next_group_offset =
122       group_codes_begin + group_offsets[index] + group_sizes[index];
123   // The order of these variables must be:
124   // group_codes_begin <= group_offset <= next_group_offset <= file.size()
125   JXL_DASSERT(group_codes_begin <= group_offset);
126   JXL_DASSERT(group_offset <= next_group_offset);
127   JXL_DASSERT(next_group_offset <= reader->TotalBytes());
128   const size_t group_size = next_group_offset - group_offset;
129   const size_t remaining_size = reader->TotalBytes() - group_offset;
130   const size_t size = std::min(group_size + 8, remaining_size);
131   *store =
132       BitReader(Span<const uint8_t>(reader->FirstByte() + group_offset, size));
133   return store;
134 }
135 
DecodeFrame(const DecompressParams & dparams,PassesDecoderState * dec_state,ThreadPool * JXL_RESTRICT pool,BitReader * JXL_RESTRICT reader,ImageBundle * decoded,const CodecMetadata & metadata,const SizeConstraints * constraints,bool is_preview)136 Status DecodeFrame(const DecompressParams& dparams,
137                    PassesDecoderState* dec_state, ThreadPool* JXL_RESTRICT pool,
138                    BitReader* JXL_RESTRICT reader, ImageBundle* decoded,
139                    const CodecMetadata& metadata,
140                    const SizeConstraints* constraints, bool is_preview) {
141   PROFILER_ZONE("DecodeFrame uninstrumented");
142 
143   FrameDecoder frame_decoder(dec_state, metadata, pool);
144 
145   frame_decoder.SetFrameSizeLimits(constraints);
146 
147   JXL_RETURN_IF_ERROR(frame_decoder.InitFrame(
148       reader, decoded, is_preview, dparams.allow_partial_files,
149       dparams.allow_partial_files && dparams.allow_more_progressive_steps));
150 
151   // Handling of progressive decoding.
152   {
153     const FrameHeader& frame_header = frame_decoder.GetFrameHeader();
154     size_t max_passes = dparams.max_passes;
155     size_t max_downsampling = std::max(
156         dparams.max_downsampling >> (frame_header.dc_level * 3), size_t(1));
157     // TODO(veluca): deal with downsamplings >= 8.
158     if (max_downsampling >= 8) {
159       max_passes = 0;
160     } else {
161       for (uint32_t i = 0; i < frame_header.passes.num_downsample; ++i) {
162         if (max_downsampling >= frame_header.passes.downsample[i] &&
163             max_passes > frame_header.passes.last_pass[i]) {
164           max_passes = frame_header.passes.last_pass[i] + 1;
165         }
166       }
167     }
168     // Do not use downsampling for kReferenceOnly frames.
169     if (frame_header.frame_type == FrameType::kReferenceOnly) {
170       max_passes = frame_header.passes.num_passes;
171     }
172     max_passes = std::min<size_t>(max_passes, frame_header.passes.num_passes);
173     frame_decoder.SetMaxPasses(max_passes);
174   }
175   frame_decoder.SetRenderSpotcolors(dparams.render_spotcolors);
176 
177   size_t processed_bytes = reader->TotalBitsConsumed() / kBitsPerByte;
178 
179   Status close_ok = true;
180   std::vector<std::unique_ptr<BitReader>> section_readers;
181   {
182     std::vector<std::unique_ptr<BitReaderScopedCloser>> section_closers;
183     std::vector<FrameDecoder::SectionInfo> section_info;
184     std::vector<FrameDecoder::SectionStatus> section_status;
185     size_t bytes_to_skip = 0;
186     for (size_t i = 0; i < frame_decoder.NumSections(); i++) {
187       size_t b = frame_decoder.SectionOffsets()[i];
188       size_t e = b + frame_decoder.SectionSizes()[i];
189       bytes_to_skip += e - b;
190       size_t pos = reader->TotalBitsConsumed() / kBitsPerByte;
191       if (pos + e <= reader->TotalBytes()) {
192         auto br = make_unique<BitReader>(
193             Span<const uint8_t>(reader->FirstByte() + b + pos, e - b));
194         section_info.emplace_back(FrameDecoder::SectionInfo{br.get(), i});
195         section_closers.emplace_back(
196             make_unique<BitReaderScopedCloser>(br.get(), &close_ok));
197         section_readers.emplace_back(std::move(br));
198       } else if (!dparams.allow_partial_files) {
199         return JXL_FAILURE("Premature end of stream.");
200       }
201     }
202     // Skip over the to-be-decoded sections.
203     reader->SkipBits(kBitsPerByte * bytes_to_skip);
204     section_status.resize(section_info.size());
205 
206     JXL_RETURN_IF_ERROR(frame_decoder.ProcessSections(
207         section_info.data(), section_info.size(), section_status.data()));
208 
209     for (size_t i = 0; i < section_status.size(); i++) {
210       auto s = section_status[i];
211       if (s == FrameDecoder::kDone) {
212         processed_bytes += frame_decoder.SectionSizes()[i];
213         continue;
214       }
215       if (dparams.allow_more_progressive_steps && s == FrameDecoder::kPartial) {
216         continue;
217       }
218       if (dparams.max_downsampling > 1 && s == FrameDecoder::kSkipped) {
219         continue;
220       }
221       return JXL_FAILURE("Invalid section %zu status: %d", section_info[i].id,
222                          s);
223     }
224   }
225 
226   JXL_RETURN_IF_ERROR(close_ok);
227 
228   JXL_RETURN_IF_ERROR(frame_decoder.FinalizeFrame());
229   decoded->SetDecodedBytes(processed_bytes);
230   return true;
231 }
232 
// Starts decoding a new frame.
//
// Parses the frame header and the TOC from `br` and validates the frame
// (dimensions against constraints_, group-code overflow, chroma subsampling vs
// adaptive DC smoothing). Then re-initializes the shared decoder state and the
// modular decoder. When `decoded` is a JPEG reconstruction target, it also
// sizes the JPEG component data. After a successful return, `br` is positioned
// right after the TOC.
//
// `allow_partial_frames` / `allow_partial_dc_global` are stored in members for
// use by later section processing. The previous frame must have been finalized
// (asserted via is_finalized_).
Status FrameDecoder::InitFrame(BitReader* JXL_RESTRICT br, ImageBundle* decoded,
                               bool is_preview, bool allow_partial_frames,
                               bool allow_partial_dc_global) {
  PROFILER_FUNC;
  decoded_ = decoded;
  JXL_ASSERT(is_finalized_);

  allow_partial_frames_ = allow_partial_frames;
  allow_partial_dc_global_ = allow_partial_dc_global;

  // Reset the dequantization matrices to their default values.
  dec_state_->shared_storage.matrices = DequantMatrices();

  frame_header_.nonserialized_is_preview = is_preview;
  JXL_RETURN_IF_ERROR(DecodeFrameHeader(br, &frame_header_));
  frame_dim_ = frame_header_.ToFrameDimensions();

  const size_t num_passes = frame_header_.passes.num_passes;
  const size_t xsize = frame_dim_.xsize;
  const size_t ysize = frame_dim_.ysize;
  const size_t num_groups = frame_dim_.num_groups;

  // Check validity of frame dimensions.
  JXL_RETURN_IF_ERROR(VerifyDimensions(constraints_, xsize, ysize));

  // If the previous frame was not a kRegularFrame, `decoded` may have different
  // dimensions; must reset to avoid errors.
  decoded->RemoveColor();
  decoded->ClearExtraChannels();

  // Read TOC. The AC global section is always counted.
  uint64_t groups_total_size;
  const bool has_ac_global = true;
  const size_t toc_entries = NumTocEntries(num_groups, frame_dim_.num_dc_groups,
                                           num_passes, has_ac_global);
  JXL_RETURN_IF_ERROR(ReadGroupOffsets(toc_entries, br, &section_offsets_,
                                       &section_sizes_, &groups_total_size));

  // Debug-only check that the TOC ended on a byte boundary; section data
  // starts at this byte.
  JXL_DASSERT((br->TotalBitsConsumed() % kBitsPerByte) == 0);
  const size_t group_codes_begin = br->TotalBitsConsumed() / kBitsPerByte;
  JXL_DASSERT(!section_offsets_.empty());

  // Overflow check.
  if (group_codes_begin + groups_total_size < group_codes_begin) {
    return JXL_FAILURE("Invalid group codes");
  }

  // VarDCT frames may only use chroma subsampling when adaptive DC smoothing
  // is disabled.
  if (!frame_header_.chroma_subsampling.Is444() &&
      !(frame_header_.flags & FrameHeader::kSkipAdaptiveDCSmoothing) &&
      frame_header_.encoding == FrameEncoding::kVarDCT) {
    return JXL_FAILURE(
        "Non-444 chroma subsampling is not allowed when adaptive DC "
        "smoothing is enabled");
  }
  JXL_RETURN_IF_ERROR(
      InitializePassesSharedState(frame_header_, &dec_state_->shared_storage));
  dec_state_->Init();
  modular_frame_decoder_.Init(frame_dim_);

  if (decoded->IsJPEG()) {
    if (frame_header_.encoding == FrameEncoding::kModular) {
      return JXL_FAILURE("Cannot output JPEG from Modular");
    }
    jpeg::JPEGData* jpeg_data = decoded->jpeg_data.get();
    if (jpeg_data->components.size() != 1 &&
        jpeg_data->components.size() != 3) {
      return JXL_FAILURE("Invalid number of components");
    }
    if (frame_header_.nonserialized_metadata->m.xyb_encoded) {
      return JXL_FAILURE("Cannot decode to JPEG an XYB image");
    }
    decoded->jpeg_data->width = frame_dim_.xsize;
    decoded->jpeg_data->height = frame_dim_.ysize;
    // Decoder channels 0 and 1 are swapped relative to JPEG component order
    // (index c ^ 1 for c < 2); channel 2 maps to itself. A single-component
    // (grayscale) JPEG takes the full block dimensions.
    if (jpeg_data->components.size() == 1) {
      jpeg_data->components[0].width_in_blocks = frame_dim_.xsize_blocks;
      jpeg_data->components[0].height_in_blocks = frame_dim_.ysize_blocks;
    } else {
      for (size_t c = 0; c < 3; c++) {
        jpeg_data->components[c < 2 ? c ^ 1 : c].width_in_blocks =
            frame_dim_.xsize_blocks >>
            frame_header_.chroma_subsampling.HShift(c);
        jpeg_data->components[c < 2 ? c ^ 1 : c].height_in_blocks =
            frame_dim_.ysize_blocks >>
            frame_header_.chroma_subsampling.VShift(c);
      }
    }
    // Sampling factors are derived from the raw subsampling shifts using the
    // same channel swap as above.
    for (size_t c = 0; c < jpeg_data->components.size(); c++) {
      jpeg_data->components[c].h_samp_factor =
          1 << frame_header_.chroma_subsampling.RawHShift(c < 2 ? c ^ 1 : c);
      jpeg_data->components[c].v_samp_factor =
          1 << frame_header_.chroma_subsampling.RawVShift(c < 2 ? c ^ 1 : c);
    }
    // One kDCTBlockSize-coefficient block per (width x height) block.
    for (auto& v : jpeg_data->components) {
      v.coeffs.resize(v.width_in_blocks * v.height_in_blocks *
                      jxl::kDCTBlockSize);
    }
  }

  // Clear the state.
  decoded_dc_global_ = false;
  decoded_ac_global_ = false;
  is_finalized_ = false;
  finalized_dc_ = false;
  decoded_dc_groups_.clear();
  decoded_dc_groups_.resize(frame_dim_.num_dc_groups);
  decoded_passes_per_ac_group_.clear();
  decoded_passes_per_ac_group_.resize(frame_dim_.num_groups, 0);
  processed_section_.clear();
  processed_section_.resize(section_offsets_.size());
  // Default: decode every pass; callers may lower this via SetMaxPasses.
  max_passes_ = frame_header_.passes.num_passes;
  num_renders_ = 0;

  return true;
}
347 
ProcessDCGlobal(BitReader * br)348 Status FrameDecoder::ProcessDCGlobal(BitReader* br) {
349   PROFILER_FUNC;
350   PassesSharedState& shared = dec_state_->shared_storage;
351   if (shared.frame_header.flags & FrameHeader::kPatches) {
352     bool uses_extra_channels = false;
353     JXL_RETURN_IF_ERROR(shared.image_features.patches.Decode(
354         br, frame_dim_.xsize_padded, frame_dim_.ysize_padded,
355         &uses_extra_channels));
356     if (uses_extra_channels && frame_header_.upsampling != 1) {
357       for (size_t ecups : frame_header_.extra_channel_upsampling) {
358         if (ecups != frame_header_.upsampling) {
359           return JXL_FAILURE(
360               "Cannot use extra channels in patches if color channels are "
361               "subsampled differently from extra channels");
362         }
363       }
364     }
365   } else {
366     shared.image_features.patches.Clear();
367   }
368   if (shared.frame_header.flags & FrameHeader::kSplines) {
369     JXL_RETURN_IF_ERROR(shared.image_features.splines.Decode(
370         br, frame_dim_.xsize * frame_dim_.ysize));
371   }
372   if (shared.frame_header.flags & FrameHeader::kNoise) {
373     JXL_RETURN_IF_ERROR(DecodeNoise(br, &shared.image_features.noise_params));
374   }
375 
376   JXL_RETURN_IF_ERROR(dec_state_->shared_storage.matrices.DecodeDC(br));
377   if (frame_header_.encoding == FrameEncoding::kVarDCT) {
378     JXL_RETURN_IF_ERROR(
379         jxl::DecodeGlobalDCInfo(br, decoded_->IsJPEG(), dec_state_, pool_));
380   }
381   Status dec_status = modular_frame_decoder_.DecodeGlobalInfo(
382       br, frame_header_, allow_partial_dc_global_);
383   if (dec_status.IsFatalError()) return dec_status;
384   if (dec_status) {
385     decoded_dc_global_ = true;
386   }
387   return dec_status;
388 }
389 
ProcessDCGroup(size_t dc_group_id,BitReader * br)390 Status FrameDecoder::ProcessDCGroup(size_t dc_group_id, BitReader* br) {
391   PROFILER_FUNC;
392   const size_t gx = dc_group_id % frame_dim_.xsize_dc_groups;
393   const size_t gy = dc_group_id / frame_dim_.xsize_dc_groups;
394   if (frame_header_.encoding == FrameEncoding::kVarDCT &&
395       !(frame_header_.flags & FrameHeader::kUseDcFrame)) {
396     JXL_RETURN_IF_ERROR(
397         modular_frame_decoder_.DecodeVarDCTDC(dc_group_id, br, dec_state_));
398   }
399   const Rect mrect(gx * frame_dim_.dc_group_dim, gy * frame_dim_.dc_group_dim,
400                    frame_dim_.dc_group_dim, frame_dim_.dc_group_dim);
401   JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeGroup(
402       mrect, br, 3, 1000, ModularStreamId::ModularDC(dc_group_id),
403       /*zerofill=*/false));
404   if (frame_header_.encoding == FrameEncoding::kVarDCT) {
405     JXL_RETURN_IF_ERROR(
406         modular_frame_decoder_.DecodeAcMetadata(dc_group_id, br, dec_state_));
407   }
408   decoded_dc_groups_[dc_group_id] = true;
409   return true;
410 }
411 
FinalizeDC()412 void FrameDecoder::FinalizeDC() {
413   // Do Adaptive DC smoothing if enabled. This *must* happen between all the
414   // ProcessDCGroup and ProcessACGroup.
415   if (frame_header_.encoding == FrameEncoding::kVarDCT &&
416       !(frame_header_.flags & FrameHeader::kSkipAdaptiveDCSmoothing) &&
417       !(frame_header_.flags & FrameHeader::kUseDcFrame)) {
418     AdaptiveDCSmoothing(dec_state_->shared->quantizer.MulDC(),
419                         &dec_state_->shared_storage.dc_storage, pool_);
420   }
421 
422   finalized_dc_ = true;
423 }
424 
AllocateOutput()425 void FrameDecoder::AllocateOutput() {
426   const CodecMetadata& metadata = *frame_header_.nonserialized_metadata;
427   if (dec_state_->rgb_output == nullptr && !dec_state_->pixel_callback) {
428     decoded_->SetFromImage(Image3F(frame_dim_.xsize_upsampled_padded,
429                                    frame_dim_.ysize_upsampled_padded),
430                            dec_state_->output_encoding_info.color_encoding);
431   }
432   dec_state_->extra_channels.clear();
433   if (metadata.m.num_extra_channels > 0) {
434     for (size_t i = 0; i < metadata.m.num_extra_channels; i++) {
435       uint32_t ecups = frame_header_.extra_channel_upsampling[i];
436       dec_state_->extra_channels.emplace_back(
437           DivCeil(frame_dim_.xsize_upsampled_padded, ecups),
438           DivCeil(frame_dim_.ysize_upsampled_padded, ecups));
439 #if MEMORY_SANITIZER
440       // Avoid errors due to loading vectors on the outermost padding.
441       for (size_t y = 0; y < DivCeil(frame_dim_.ysize_upsampled_padded, ecups);
442            y++) {
443         for (size_t x = DivCeil(frame_dim_.xsize_upsampled, ecups);
444              x < DivCeil(frame_dim_.xsize_upsampled_padded, ecups); x++) {
445           dec_state_->extra_channels.back().Row(y)[x] = 0;
446         }
447       }
448 #endif
449     }
450   }
451   decoded_->origin = dec_state_->shared->frame_header.frame_origin;
452 }
453 
ProcessACGlobal(BitReader * br)454 Status FrameDecoder::ProcessACGlobal(BitReader* br) {
455   JXL_CHECK(finalized_dc_);
456   JXL_CHECK(decoded_->HasColor() || dec_state_->rgb_output != nullptr ||
457             !!dec_state_->pixel_callback);
458 
459   // Decode AC group.
460   if (frame_header_.encoding == FrameEncoding::kVarDCT) {
461     JXL_RETURN_IF_ERROR(dec_state_->shared_storage.matrices.Decode(
462         br, &modular_frame_decoder_));
463 
464     size_t num_histo_bits =
465         CeilLog2Nonzero(dec_state_->shared->frame_dim.num_groups);
466     dec_state_->shared_storage.num_histograms =
467         1 + br->ReadBits(num_histo_bits);
468 
469     dec_state_->code.resize(kMaxNumPasses);
470     dec_state_->context_map.resize(kMaxNumPasses);
471     // Read coefficient orders and histograms.
472     size_t max_num_bits_ac = 0;
473     for (size_t i = 0;
474          i < dec_state_->shared_storage.frame_header.passes.num_passes; i++) {
475       uint16_t used_orders = U32Coder::Read(kOrderEnc, br);
476       JXL_RETURN_IF_ERROR(DecodeCoeffOrders(
477           used_orders, dec_state_->used_acs,
478           &dec_state_->shared_storage
479                .coeff_orders[i * dec_state_->shared_storage.coeff_order_size],
480           br));
481       size_t num_contexts =
482           dec_state_->shared->num_histograms *
483           dec_state_->shared_storage.block_ctx_map.NumACContexts();
484       JXL_RETURN_IF_ERROR(DecodeHistograms(
485           br, num_contexts, &dec_state_->code[i], &dec_state_->context_map[i]));
486       // Add extra values to enable the cheat in hot loop of DecodeACVarBlock.
487       dec_state_->context_map[i].resize(
488           num_contexts + kZeroDensityContextLimit - kZeroDensityContextCount);
489       max_num_bits_ac =
490           std::max(max_num_bits_ac, dec_state_->code[i].max_num_bits);
491     }
492     max_num_bits_ac += CeilLog2Nonzero(
493         dec_state_->shared_storage.frame_header.passes.num_passes);
494     // 16-bit buffer for decoding to JPEG are not implemented.
495     // TODO(veluca): figure out the exact limit - 16 should still work with
496     // 16-bit buffers, but we are excluding it for safety.
497     bool use_16_bit = max_num_bits_ac < 16 && !decoded_->IsJPEG();
498     bool store = frame_header_.passes.num_passes > 1;
499     size_t xs = store ? kGroupDim * kGroupDim : 0;
500     size_t ys = store ? frame_dim_.num_groups : 0;
501     if (use_16_bit) {
502       dec_state_->coefficients = make_unique<ACImageT<int16_t>>(xs, ys);
503     } else {
504       dec_state_->coefficients = make_unique<ACImageT<int32_t>>(xs, ys);
505     }
506     if (store) {
507       dec_state_->coefficients->ZeroFill();
508     }
509   }
510 
511   // Set JPEG decoding data.
512   if (decoded_->IsJPEG()) {
513     decoded_->color_transform = frame_header_.color_transform;
514     decoded_->chroma_subsampling = frame_header_.chroma_subsampling;
515     const std::vector<QuantEncoding>& qe =
516         dec_state_->shared_storage.matrices.encodings();
517     if (qe.empty() || qe[0].mode != QuantEncoding::Mode::kQuantModeRAW ||
518         std::abs(qe[0].qraw.qtable_den - 1.f / (8 * 255)) > 1e-8f) {
519       return JXL_FAILURE(
520           "Quantization table is not a JPEG quantization table.");
521     }
522     auto jpeg_c_map = JpegOrder(frame_header_.color_transform,
523                                 decoded_->jpeg_data->components.size() == 1);
524     for (size_t c = 0; c < 3; c++) {
525       if (c != 1 && decoded_->jpeg_data->components.size() == 1) {
526         continue;
527       }
528       size_t jpeg_channel = jpeg_c_map[c];
529       size_t qpos = decoded_->jpeg_data->components[jpeg_channel].quant_idx;
530       JXL_CHECK(qpos != decoded_->jpeg_data->quant.size());
531       for (size_t x = 0; x < 8; x++) {
532         for (size_t y = 0; y < 8; y++) {
533           decoded_->jpeg_data->quant[qpos].values[x * 8 + y] =
534               (*qe[0].qraw.qtable)[c * 64 + y * 8 + x];
535         }
536       }
537     }
538   }
539   // Set memory buffer for pre-color-transform frame, if needed.
540   if (frame_header_.needs_color_transform() &&
541       frame_header_.save_before_color_transform) {
542     dec_state_->pre_color_transform_frame =
543         Image3F(frame_dim_.xsize_upsampled, frame_dim_.ysize_upsampled);
544   } else {
545     // clear pre_color_transform_frame to ensure that previously moved-from
546     // images are not used.
547     dec_state_->pre_color_transform_frame = Image3F();
548   }
549   decoded_ac_global_ = true;
550   return true;
551 }
552 
ProcessACGroup(size_t ac_group_id,BitReader * JXL_RESTRICT * br,size_t num_passes,size_t thread,bool force_draw,bool dc_only)553 Status FrameDecoder::ProcessACGroup(size_t ac_group_id,
554                                     BitReader* JXL_RESTRICT* br,
555                                     size_t num_passes, size_t thread,
556                                     bool force_draw, bool dc_only) {
557   PROFILER_ZONE("process_group");
558   const size_t gx = ac_group_id % frame_dim_.xsize_groups;
559   const size_t gy = ac_group_id / frame_dim_.xsize_groups;
560   const size_t x = gx * frame_dim_.group_dim;
561   const size_t y = gy * frame_dim_.group_dim;
562 
563   if (frame_header_.encoding == FrameEncoding::kVarDCT) {
564     group_dec_caches_[thread].InitOnce(frame_header_.passes.num_passes,
565                                        dec_state_->used_acs);
566     JXL_RETURN_IF_ERROR(DecodeGroup(
567         br, num_passes, ac_group_id, dec_state_, &group_dec_caches_[thread],
568         thread, decoded_, decoded_passes_per_ac_group_[ac_group_id], force_draw,
569         dc_only));
570   }
571 
572   // don't limit to image dimensions here (is done in DecodeGroup)
573   const Rect mrect(x, y, frame_dim_.group_dim, frame_dim_.group_dim);
574   for (size_t i = 0; i < frame_header_.passes.num_passes; i++) {
575     int minShift, maxShift;
576     frame_header_.passes.GetDownsamplingBracket(i, minShift, maxShift);
577     if (i >= decoded_passes_per_ac_group_[ac_group_id] &&
578         i < decoded_passes_per_ac_group_[ac_group_id] + num_passes) {
579       JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeGroup(
580           mrect, br[i - decoded_passes_per_ac_group_[ac_group_id]], minShift,
581           maxShift, ModularStreamId::ModularAC(ac_group_id, i),
582           /*zerofill=*/false));
583     } else if (i >= decoded_passes_per_ac_group_[ac_group_id] + num_passes &&
584                force_draw) {
585       JXL_RETURN_IF_ERROR(modular_frame_decoder_.DecodeGroup(
586           mrect, nullptr, minShift, maxShift,
587           ModularStreamId::ModularAC(ac_group_id, i), /*zerofill=*/true));
588     }
589   }
590   decoded_passes_per_ac_group_[ac_group_id] += num_passes;
591   return true;
592 }
593 
// Processes a batch of `num` sections of the current frame, given in any order
// and possibly only a subset of the frame. Per-section results are written to
// `section_status` (kDone, kPartial, kSkipped or kDuplicate).
//
// Section id layout: id 0 is DC global; ids 1..num_dc_groups are the DC
// groups; id num_dc_groups + 1 is AC global; the remaining ids are AC groups,
// ordered pass by pass (ac_idx / num_groups is the pass).
//
// Work is done in dependency order:
//  1. DC global.
//  2. DC groups via RunOnPool, only once DC global is decoded.
//  3. FinalizeDC + AllocateOutput, once every DC group is decoded.
//  4. AC global, once DC is finalized.
//  5. AC groups via RunOnPool, once AC global is decoded.
// Per group, AC passes are only decoded as a contiguous run that starts at the
// first pass not yet decoded and stops before max_passes_. Sections ending
// up kSkipped or kPartial are un-marked in processed_section_, so they can be
// supplied again in a later call.
Status FrameDecoder::ProcessSections(const SectionInfo* sections, size_t num,
                                     SectionStatus* section_status) {
  if (num == 0) return true;  // Nothing to process
  std::fill(section_status, section_status + num, SectionStatus::kSkipped);
  // For every kind of section: index into `sections` where it was supplied in
  // this call, or `num` if it was not supplied.
  size_t dc_global_sec = num;
  size_t ac_global_sec = num;
  std::vector<size_t> dc_group_sec(frame_dim_.num_dc_groups, num);
  std::vector<std::vector<size_t>> ac_group_sec(
      frame_dim_.num_groups,
      std::vector<size_t>(frame_header_.passes.num_passes, num));
  // Number of new, consecutive AC passes to decode for each group.
  std::vector<size_t> num_ac_passes(frame_dim_.num_groups);
  if (frame_dim_.num_groups == 1 && frame_header_.passes.num_passes == 1) {
    // Single-section frame: section 0 holds DC global, the only DC group, AC
    // global and the only AC pass.
    JXL_ASSERT(num == 1);
    JXL_ASSERT(sections[0].id == 0);
    if (processed_section_[0] == false) {
      processed_section_[0] = true;
      ac_group_sec[0].resize(1);
      dc_global_sec = ac_global_sec = dc_group_sec[0] = ac_group_sec[0][0] = 0;
      num_ac_passes[0] = 1;
    } else {
      section_status[0] = SectionStatus::kDuplicate;
    }
  } else {
    size_t ac_global_index = frame_dim_.num_dc_groups + 1;
    for (size_t i = 0; i < num; i++) {
      JXL_ASSERT(sections[i].id < processed_section_.size());
      if (processed_section_[sections[i].id]) {
        section_status[i] = SectionStatus::kDuplicate;
        continue;
      }
      if (sections[i].id == 0) {
        dc_global_sec = i;
      } else if (sections[i].id < ac_global_index) {
        dc_group_sec[sections[i].id - 1] = i;
      } else if (sections[i].id == ac_global_index) {
        ac_global_sec = i;
      } else {
        size_t ac_idx = sections[i].id - ac_global_index - 1;
        size_t acg = ac_idx % frame_dim_.num_groups;
        size_t acp = ac_idx / frame_dim_.num_groups;
        if (acp >= frame_header_.passes.num_passes) {
          return JXL_FAILURE("Invalid section ID");
        }
        // Passes beyond the requested limit stay kSkipped and unprocessed.
        if (acp >= max_passes_) {
          continue;
        }
        ac_group_sec[acg][acp] = i;
      }
      processed_section_[sections[i].id] = true;
    }
    // Count number of new passes per group.
    for (size_t g = 0; g < ac_group_sec.size(); g++) {
      size_t j = 0;
      for (; j + decoded_passes_per_ac_group_[g] < max_passes_; j++) {
        if (ac_group_sec[g][j + decoded_passes_per_ac_group_[g]] == num) {
          break;
        }
      }
      num_ac_passes[g] = j;
    }
  }
  if (dc_global_sec != num) {
    Status dc_global_status = ProcessDCGlobal(sections[dc_global_sec].br);
    if (dc_global_status.IsFatalError()) return dc_global_status;
    // A non-fatal failure means DC global was only partially available.
    if (dc_global_status) {
      section_status[dc_global_sec] = SectionStatus::kDone;
    } else {
      section_status[dc_global_sec] = SectionStatus::kPartial;
    }
  }

  std::atomic<bool> has_error{false};
  if (decoded_dc_global_) {
    RunOnPool(
        pool_, 0, dc_group_sec.size(), ThreadPool::SkipInit(),
        [this, &dc_group_sec, &num, &sections, &section_status, &has_error](
            size_t i, size_t thread) {
          if (dc_group_sec[i] != num) {
            if (!ProcessDCGroup(i, sections[dc_group_sec[i]].br)) {
              has_error = true;
            } else {
              section_status[dc_group_sec[i]] = SectionStatus::kDone;
            }
          }
        },
        "DecodeDCGroup");
  }
  if (has_error) return JXL_FAILURE("Error in DC group");

  // All DC groups decoded (minimum element is true) and not yet finalized.
  if (*std::min_element(decoded_dc_groups_.begin(), decoded_dc_groups_.end()) ==
          true &&
      !finalized_dc_) {
    FinalizeDC();
    AllocateOutput();
  }

  if (finalized_dc_) dec_state_->EnsureBordersStorage();
  if (finalized_dc_ && ac_global_sec != num && !decoded_ac_global_) {
    dec_state_->InitForAC(pool_);
    JXL_RETURN_IF_ERROR(ProcessACGlobal(sections[ac_global_sec].br));
    section_status[ac_global_sec] = SectionStatus::kDone;
  }

  if (decoded_ac_global_) {
    // The decoded image requires padding for filtering. ProcessACGlobal added
    // the padding, however when Flush is used, the image is shrunk to the
    // output size. Add the padding back here. This is a cheap opeartion
    // since the image has the original allocated size. The memory and original
    // size are already there, but for safety we require the indicated xsize and
    // ysize dimensions match the working area, see PlaneRowBoundsCheck.
    decoded_->ShrinkTo(frame_dim_.xsize_upsampled_padded,
                       frame_dim_.ysize_upsampled_padded);

    // Mark all the AC groups that we received as not complete yet.
    for (size_t i = 0; i < ac_group_sec.size(); i++) {
      if (num_ac_passes[i] == 0) continue;
      dec_state_->group_border_assigner.ClearDone(i);
    }

    RunOnPool(
        pool_, 0, ac_group_sec.size(),
        [this](size_t num_threads) {
          PrepareStorage(num_threads, decoded_passes_per_ac_group_.size());
          return true;
        },
        [this, &ac_group_sec, &num_ac_passes, &num, &sections, &section_status,
         &has_error](size_t g, size_t thread) {
          if (num_ac_passes[g] == 0) {  // no new AC pass, nothing to do.
            return;
          }
          (void)num;
          size_t first_pass = decoded_passes_per_ac_group_[g];
          // One reader per new pass, starting at the first undecoded pass.
          BitReader* JXL_RESTRICT readers[kMaxNumPasses];
          for (size_t i = 0; i < num_ac_passes[g]; i++) {
            JXL_ASSERT(ac_group_sec[g][first_pass + i] != num);
            readers[i] = sections[ac_group_sec[g][first_pass + i]].br;
          }
          if (!ProcessACGroup(g, readers, num_ac_passes[g],
                              GetStorageLocation(thread, g),
                              /*force_draw=*/false, /*dc_only=*/false)) {
            has_error = true;
          } else {
            for (size_t i = 0; i < num_ac_passes[g]; i++) {
              section_status[ac_group_sec[g][first_pass + i]] =
                  SectionStatus::kDone;
            }
          }
        },
        "DecodeGroup");
  }
  if (has_error) return JXL_FAILURE("Error in AC group");

  // Allow sections that were not (fully) used to be supplied again later.
  for (size_t i = 0; i < num; i++) {
    if (section_status[i] == SectionStatus::kSkipped ||
        section_status[i] == SectionStatus::kPartial) {
      processed_section_[sections[i].id] = false;
    }
  }
  return true;
}
754 
Flush()755 Status FrameDecoder::Flush() {
756   bool has_blending = frame_header_.blending_info.mode != BlendMode::kReplace ||
757                       frame_header_.custom_size_or_origin;
758   for (const auto& blending_info_ec :
759        frame_header_.extra_channel_blending_info) {
760     if (blending_info_ec.mode != BlendMode::kReplace) has_blending = true;
761   }
762   // No early Flush() if blending is enabled.
763   if (has_blending && !is_finalized_) {
764     return false;
765   }
766   // No early Flush() - nothing to do - if the frame is a kSkipProgressive
767   // frame.
768   if (frame_header_.frame_type == FrameType::kSkipProgressive &&
769       !is_finalized_) {
770     return true;
771   }
772   if (decoded_->IsJPEG()) {
773     // Nothing to do.
774     return true;
775   }
776   uint32_t completely_decoded_ac_pass = *std::min_element(
777       decoded_passes_per_ac_group_.begin(), decoded_passes_per_ac_group_.end());
778   if (completely_decoded_ac_pass < frame_header_.passes.num_passes) {
779     // We don't have all AC yet: force a draw of all the missing areas.
780     // Mark all sections as not complete.
781     for (size_t i = 0; i < decoded_passes_per_ac_group_.size(); i++) {
782       if (decoded_passes_per_ac_group_[i] == frame_header_.passes.num_passes)
783         continue;
784       dec_state_->group_border_assigner.ClearDone(i);
785     }
786     std::atomic<bool> has_error{false};
787     RunOnPool(
788         pool_, 0, decoded_passes_per_ac_group_.size(),
789         [this](size_t num_threads) {
790           PrepareStorage(num_threads, decoded_passes_per_ac_group_.size());
791           return true;
792         },
793         [this, &has_error](size_t g, size_t thread) {
794           if (decoded_passes_per_ac_group_[g] ==
795               frame_header_.passes.num_passes) {
796             // This group was drawn already, nothing to do.
797             return;
798           }
799           BitReader* JXL_RESTRICT readers[kMaxNumPasses] = {};
800           bool ok = ProcessACGroup(
801               g, readers, /*num_passes=*/0, GetStorageLocation(thread, g),
802               /*force_draw=*/true, /*dc_only=*/!decoded_ac_global_);
803           if (!ok) has_error = true;
804         },
805         "ForceDrawGroup");
806     if (has_error) {
807       return JXL_FAILURE("Drawing groups failed");
808     }
809   }
810   // TODO(veluca): the rest of this function should be removed once we have full
811   // support for per-group decoding.
812 
813   // undo global modular transforms and copy int pixel buffers to float ones
814   JXL_RETURN_IF_ERROR(
815       modular_frame_decoder_.FinalizeDecoding(dec_state_, pool_, decoded_));
816 
817   JXL_RETURN_IF_ERROR(FinalizeFrameDecoding(decoded_, dec_state_, pool_,
818                                             /*force_fir=*/false,
819                                             /*skip_blending=*/false));
820 
821   num_renders_++;
822   return true;
823 }
824 
FinalizeFrame()825 Status FrameDecoder::FinalizeFrame() {
826   if (is_finalized_) {
827     return JXL_FAILURE("FinalizeFrame called multiple times");
828   }
829   is_finalized_ = true;
830   if (decoded_->IsJPEG()) {
831     // Nothing to do.
832     return true;
833   }
834   if (!finalized_dc_) {
835     // We don't have all of DC: EPF might not behave correctly (and is not
836     // particularly useful anyway on upsampling results), so we disable it.
837     dec_state_->shared_storage.frame_header.loop_filter.epf_iters = 0;
838   }
839   if ((!decoded_dc_global_ || !decoded_ac_global_ ||
840        *std::min_element(decoded_dc_groups_.begin(),
841                          decoded_dc_groups_.end()) != 1 ||
842        *std::min_element(decoded_passes_per_ac_group_.begin(),
843                          decoded_passes_per_ac_group_.end()) < max_passes_) &&
844       !allow_partial_frames_) {
845     return JXL_FAILURE(
846         "FinalizeFrame called before the frame was fully decoded");
847   }
848 
849   JXL_RETURN_IF_ERROR(Flush());
850 
851   if (dec_state_->shared->frame_header.CanBeReferenced()) {
852     size_t id = dec_state_->shared->frame_header.save_as_reference;
853     if (dec_state_->pre_color_transform_frame.xsize() == 0) {
854       dec_state_->shared_storage.reference_frames[id].storage =
855           decoded_->Copy();
856     } else {
857       dec_state_->shared_storage.reference_frames[id].storage =
858           ImageBundle(decoded_->metadata());
859       dec_state_->shared_storage.reference_frames[id].storage.SetFromImage(
860           std::move(dec_state_->pre_color_transform_frame),
861           decoded_->c_current());
862       if (decoded_->HasExtraChannels()) {
863         const std::vector<ImageF>* ecs = &dec_state_->pre_color_transform_ec;
864         if (ecs->empty()) ecs = &decoded_->extra_channels();
865         std::vector<ImageF> extra_channels;
866         for (const auto& ec : *ecs) {
867           extra_channels.push_back(CopyImage(ec));
868         }
869         dec_state_->shared_storage.reference_frames[id]
870             .storage.SetExtraChannels(std::move(extra_channels));
871       }
872     }
873     dec_state_->shared_storage.reference_frames[id].frame =
874         &dec_state_->shared_storage.reference_frames[id].storage;
875     dec_state_->shared_storage.reference_frames[id].ib_is_in_xyb =
876         dec_state_->shared->frame_header.save_before_color_transform;
877   }
878   if (dec_state_->shared->frame_header.dc_level != 0) {
879     dec_state_->shared_storage
880         .dc_frames[dec_state_->shared->frame_header.dc_level - 1] =
881         std::move(*decoded_->color());
882     decoded_->RemoveColor();
883   }
884   if (frame_header_.nonserialized_is_preview) {
885     // Fix possible larger image size (multiple of kBlockDim)
886     // TODO(lode): verify if and when that happens.
887     decoded_->ShrinkTo(frame_dim_.xsize, frame_dim_.ysize);
888   } else if (!decoded_->IsJPEG()) {
889     // A kRegularFrame is blended with the other frames, and thus results in a
890     // coalesced frame of size equal to image dimensions. Other frames are not
891     // blended, thus their final size is the size that was defined in the
892     // frame_header.
893     if (frame_header_.frame_type == kRegularFrame ||
894         frame_header_.frame_type == kSkipProgressive) {
895       decoded_->ShrinkTo(
896           dec_state_->shared->frame_header.nonserialized_metadata->xsize(),
897           dec_state_->shared->frame_header.nonserialized_metadata->ysize());
898     } else {
899       // xsize_upsampled is the actual frame size, after any upsampling has been
900       // applied.
901       decoded_->ShrinkTo(frame_dim_.xsize_upsampled,
902                          frame_dim_.ysize_upsampled);
903     }
904   }
905 
906   if (render_spotcolors_) {
907     for (size_t i = 0; i < decoded_->extra_channels().size(); i++) {
908       // Don't use Find() because there may be multiple spot color channels.
909       const ExtraChannelInfo& eci = decoded_->metadata()->extra_channel_info[i];
910       if (eci.type == ExtraChannel::kOptional) {
911         continue;
912       }
913       if (eci.type == ExtraChannel::kUnknown ||
914           (int(ExtraChannel::kReserved0) <= int(eci.type) &&
915            int(eci.type) <= int(ExtraChannel::kReserved7))) {
916         return JXL_FAILURE(
917             "Unknown extra channel (bits %u, shift %u, name '%s')\n",
918             eci.bit_depth.bits_per_sample, eci.dim_shift, eci.name.c_str());
919       }
920       if (eci.type == ExtraChannel::kSpotColor) {
921         float scale = eci.spot_color[3];
922         for (size_t c = 0; c < 3; c++) {
923           for (size_t y = 0; y < decoded_->ysize(); y++) {
924             float* JXL_RESTRICT p = decoded_->color()->Plane(c).Row(y);
925             const float* JXL_RESTRICT s =
926                 decoded_->extra_channels()[i].ConstRow(y);
927             for (size_t x = 0; x < decoded_->xsize(); x++) {
928               float mix = scale * s[x];
929               p[x] = mix * eci.spot_color[c] + (1.0 - mix) * p[x];
930             }
931           }
932         }
933       }
934     }
935   }
936   return true;
937 }
938 
939 }  // namespace jxl
940