1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/enc_modular.h"
7 
8 #include <stddef.h>
9 #include <stdint.h>
10 
11 #include <array>
12 #include <limits>
13 #include <queue>
14 #include <utility>
15 #include <vector>
16 
17 #include "lib/jxl/aux_out.h"
18 #include "lib/jxl/base/compiler_specific.h"
19 #include "lib/jxl/base/padded_bytes.h"
20 #include "lib/jxl/base/status.h"
21 #include "lib/jxl/compressed_dc.h"
22 #include "lib/jxl/dec_ans.h"
23 #include "lib/jxl/enc_bit_writer.h"
24 #include "lib/jxl/enc_cluster.h"
25 #include "lib/jxl/enc_params.h"
26 #include "lib/jxl/enc_patch_dictionary.h"
27 #include "lib/jxl/enc_quant_weights.h"
28 #include "lib/jxl/frame_header.h"
29 #include "lib/jxl/gaborish.h"
30 #include "lib/jxl/modular/encoding/context_predict.h"
31 #include "lib/jxl/modular/encoding/enc_encoding.h"
32 #include "lib/jxl/modular/encoding/encoding.h"
33 #include "lib/jxl/modular/encoding/ma_common.h"
34 #include "lib/jxl/modular/modular_image.h"
35 #include "lib/jxl/modular/options.h"
36 #include "lib/jxl/modular/transform/transform.h"
37 #include "lib/jxl/toc.h"
38 
39 namespace jxl {
40 
41 namespace {
// Default quantization factors for Squeeze residuals.
//
// The tables hold the factors used at -Q 50; other qualities scale them, the
// result is truncated to an integer and clamped to at least 1. Each table is
// indexed by the number of pixel halvings (0..15) a Squeeze channel has gone
// through.

// Global scale of the (non-XYB) quantization; lower values give higher
// quality.
constexpr float squeeze_quality_factor = 0.35;
// Extra scale applied to luma (and any non-chroma) channels relative to
// chroma; lower values give higher-quality luma.
constexpr float squeeze_luma_factor = 1.1;
// Global scale used instead of the two above when encoding XYB.
constexpr float squeeze_quality_factor_xyb = 2.4f;

// Per-component tables for XYB, in encoding order Y, X, B-Y.
constexpr float squeeze_xyb_qtable[3][16] = {
    // Y
    {163.84, 81.92, 40.96, 20.48, 10.24, 5.12, 2.56, 1.28, 0.64, 0.32, 0.16,
     0.08, 0.04, 0.02, 0.01, 0.005},
    // X
    {1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1, 0.5, 0.5, 0.5, 0.5, 0.5},
    // B-Y
    {2048, 1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1, 0.5, 0.5, 0.5, 0.5},
};

// Luma (non-chroma) table for non-XYB color spaces.
constexpr float squeeze_luma_qtable[16] = {
    163.84, 81.92, 40.96, 20.48, 10.24, 5.12, 2.56,  1.28,
    0.64,   0.32,  0.16,  0.08,  0.04,  0.02, 0.01, 0.005};

// Chroma table for non-XYB color spaces. For 8-bit input the YCoCg chroma
// range is -255..255, so the two finest layers are quantized away entirely,
// which roughly amounts to 4:2:0 chroma subsampling.
constexpr float squeeze_chroma_qtable[16] = {
    1024, 512, 256, 128, 64, 32, 16, 8, 4, 2, 1, 0.5, 0.5, 0.5, 0.5, 0.5};
69 
70 // `cutoffs` must be sorted.
MakeFixedTree(int property,const std::vector<int32_t> & cutoffs,Predictor pred,size_t num_pixels)71 Tree MakeFixedTree(int property, const std::vector<int32_t>& cutoffs,
72                    Predictor pred, size_t num_pixels) {
73   size_t log_px = CeilLog2Nonzero(num_pixels);
74   size_t min_gap = 0;
75   // Reduce fixed tree height when encoding small images.
76   if (log_px < 14) {
77     min_gap = 8 * (14 - log_px);
78   }
79   Tree tree;
80   struct NodeInfo {
81     size_t begin, end, pos;
82   };
83   std::queue<NodeInfo> q;
84   // Leaf IDs will be set by roundtrip decoding the tree.
85   tree.push_back(PropertyDecisionNode::Leaf(pred));
86   q.push(NodeInfo{0, cutoffs.size(), 0});
87   while (!q.empty()) {
88     NodeInfo info = q.front();
89     q.pop();
90     if (info.begin + min_gap >= info.end) continue;
91     uint32_t split = (info.begin + info.end) / 2;
92     tree[info.pos] =
93         PropertyDecisionNode::Split(property, cutoffs[split], tree.size());
94     q.push(NodeInfo{split + 1, info.end, tree.size()});
95     tree.push_back(PropertyDecisionNode::Leaf(pred));
96     q.push(NodeInfo{info.begin, split, tree.size()});
97     tree.push_back(PropertyDecisionNode::Leaf(pred));
98   }
99   return tree;
100 }
101 
// Returns a fixed, hand-designed MA tree for `tree_kind`, so that no tree has
// to be learned for these streams.
//
// `total_pixels` matters only for:
//   - kACMeta, where fewer than 1024 pixels yields a single Left leaf;
//   - kWPFixedDC and kGradientFixedDC, where it is forwarded to MakeFixedTree
//     to limit the tree depth on small images.
// Any other tree kind aborts.
//
// Each Split(property, value, child) node has its two children at indices
// `child` and `child + 1` of the returned vector. Judging by the numbered
// comments below, the first child is taken when the property exceeds
// `value`. Because children are referenced by position, the order of the
// push_back calls is significant.
//
// Property numbers used below, as named by the comments: 0 = channel (c),
// 2 = y, 6 = top, 7 = left.
Tree PredefinedTree(ModularOptions::TreeKind tree_kind, size_t total_pixels) {
  if (tree_kind == ModularOptions::TreeKind::kJpegTranscodeACMeta) {
    // All the data is 0, so no need for a fancy tree.
    return {PropertyDecisionNode::Leaf(Predictor::Zero)};
  }
  if (tree_kind == ModularOptions::TreeKind::kFalconACMeta) {
    // All the data is 0 except the quant field. TODO(veluca): make that 0 too.
    return {PropertyDecisionNode::Leaf(Predictor::Left)};
  }
  if (tree_kind == ModularOptions::TreeKind::kACMeta) {
    // Small image.
    if (total_pixels < 1024) {
      return {PropertyDecisionNode::Leaf(Predictor::Left)};
    }
    Tree tree;
    // 0: c > 1
    tree.push_back(PropertyDecisionNode::Split(0, 1, 1));
    // 1: c > 2
    tree.push_back(PropertyDecisionNode::Split(0, 2, 3));
    // 2: c > 0
    tree.push_back(PropertyDecisionNode::Split(0, 0, 5));
    // 3: EPF control field (all 0 or 4), top > 0
    tree.push_back(PropertyDecisionNode::Split(6, 0, 21));
    // 4: ACS+QF, y > 0
    tree.push_back(PropertyDecisionNode::Split(2, 0, 7));
    // 5: CfL x
    tree.push_back(PropertyDecisionNode::Leaf(Predictor::Gradient));
    // 6: CfL b
    tree.push_back(PropertyDecisionNode::Leaf(Predictor::Gradient));
    // 7: QF: split according to the left quant value.
    tree.push_back(PropertyDecisionNode::Split(7, 5, 9));
    // 8: ACS: split in 4 segments (8x8 from 0 to 3, large square 4-5, large
    // rectangular 6-11, 8x8 12+), according to previous ACS value.
    tree.push_back(PropertyDecisionNode::Split(7, 5, 15));
    // QF
    // 9: left > 11 (children 11, 12)
    tree.push_back(PropertyDecisionNode::Split(7, 11, 11));
    // 10: left > 3 (children 13, 14)
    tree.push_back(PropertyDecisionNode::Split(7, 3, 13));
    // 11-14: QF leaves
    tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left));
    tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left));
    tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left));
    tree.push_back(PropertyDecisionNode::Leaf(Predictor::Left));
    // ACS
    // 15: left > 11 (children 17, 18)
    tree.push_back(PropertyDecisionNode::Split(7, 11, 17));
    // 16: left > 3 (children 19, 20)
    tree.push_back(PropertyDecisionNode::Split(7, 3, 19));
    // 17-20: ACS leaves
    tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    // EPF, left > 0
    // 21: children 23, 24
    tree.push_back(PropertyDecisionNode::Split(7, 0, 23));
    // 22: children 25, 26
    tree.push_back(PropertyDecisionNode::Split(7, 0, 25));
    // 23-26: EPF leaves
    tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    tree.push_back(PropertyDecisionNode::Leaf(Predictor::Zero));
    return tree;
  }
  if (tree_kind == ModularOptions::TreeKind::kWPFixedDC) {
    // Fixed tree on the first weighted-predictor property (index
    // kNumNonrefProperties - weighted::kNumProperties), Weighted predictor.
    std::vector<int32_t> cutoffs = {
        -500, -392, -255, -191, -127, -95, -63, -47, -31, -23, -15,
        -11,  -7,   -4,   -3,   -1,   0,   1,   3,   5,   7,   11,
        15,   23,   31,   47,   63,   95,  127, 191, 255, 392, 500};
    return MakeFixedTree(kNumNonrefProperties - weighted::kNumProperties,
                         cutoffs, Predictor::Weighted, total_pixels);
  }
  if (tree_kind == ModularOptions::TreeKind::kGradientFixedDC) {
    // Same cutoffs as above, but splitting on kGradientProp with the Gradient
    // predictor.
    std::vector<int32_t> cutoffs = {
        -500, -392, -255, -191, -127, -95, -63, -47, -31, -23, -15,
        -11,  -7,   -4,   -3,   -1,   0,   1,   3,   5,   7,   11,
        15,   23,   31,   47,   63,   95,  127, 191, 255, 392, 500};
    return MakeFixedTree(kGradientProp, cutoffs, Predictor::Gradient,
                         total_pixels);
  }
  JXL_ABORT("Unreachable");
  return {};
}
178 
179 // Merges the trees in `trees` using nodes that decide on stream_id, as defined
180 // by `tree_splits`.
MergeTrees(const std::vector<Tree> & trees,const std::vector<size_t> & tree_splits,size_t begin,size_t end,Tree * tree)181 void MergeTrees(const std::vector<Tree>& trees,
182                 const std::vector<size_t>& tree_splits, size_t begin,
183                 size_t end, Tree* tree) {
184   JXL_ASSERT(trees.size() + 1 == tree_splits.size());
185   JXL_ASSERT(end > begin);
186   JXL_ASSERT(end <= trees.size());
187   if (end == begin + 1) {
188     // Insert the tree, adding the opportune offset to all child nodes.
189     // This will make the leaf IDs wrong, but subsequent roundtripping will fix
190     // them.
191     size_t sz = tree->size();
192     tree->insert(tree->end(), trees[begin].begin(), trees[begin].end());
193     for (size_t i = sz; i < tree->size(); i++) {
194       (*tree)[i].lchild += sz;
195       (*tree)[i].rchild += sz;
196     }
197     return;
198   }
199   size_t mid = (begin + end) / 2;
200   size_t splitval = tree_splits[mid] - 1;
201   size_t cur = tree->size();
202   tree->emplace_back(1 /*stream_id*/, splitval, 0, 0, Predictor::Zero, 0, 1);
203   (*tree)[cur].lchild = tree->size();
204   MergeTrees(trees, tree_splits, mid, end, tree);
205   (*tree)[cur].rchild = tree->size();
206   MergeTrees(trees, tree_splits, begin, mid, tree);
207 }
208 
QuantizeChannel(Channel & ch,const int q)209 void QuantizeChannel(Channel& ch, const int q) {
210   if (q == 1) return;
211   for (size_t y = 0; y < ch.plane.ysize(); y++) {
212     pixel_type* row = ch.plane.Row(y);
213     for (size_t x = 0; x < ch.plane.xsize(); x++) {
214       if (row[x] < 0) {
215         row[x] = -((-row[x] + q / 2) / q) * q;
216       } else {
217         row[x] = ((row[x] + q / 2) / q) * q;
218       }
219     }
220   }
221 }
222 
223 // convert binary32 float that corresponds to custom [bits]-bit float (with
224 // [exp_bits] exponent bits) to a [bits]-bit integer representation that should
225 // fit in pixel_type
float_to_int(const float * const row_in,pixel_type * const row_out,size_t xsize,unsigned int bits,unsigned int exp_bits,bool fp,float factor)226 Status float_to_int(const float* const row_in, pixel_type* const row_out,
227                     size_t xsize, unsigned int bits, unsigned int exp_bits,
228                     bool fp, float factor) {
229   JXL_ASSERT(sizeof(pixel_type) * 8 >= bits);
230   if (!fp) {
231     for (size_t x = 0; x < xsize; ++x) {
232       row_out[x] = row_in[x] * factor + 0.5f;
233     }
234     return true;
235   }
236   if (bits == 32 && fp) {
237     JXL_ASSERT(exp_bits == 8);
238     memcpy((void*)row_out, (const void*)row_in, 4 * xsize);
239     return true;
240   }
241 
242   int exp_bias = (1 << (exp_bits - 1)) - 1;
243   int max_exp = (1 << exp_bits) - 1;
244   uint32_t sign = (1u << (bits - 1));
245   int mant_bits = bits - exp_bits - 1;
246   int mant_shift = 23 - mant_bits;
247   for (size_t x = 0; x < xsize; ++x) {
248     uint32_t f;
249     memcpy(&f, &row_in[x], 4);
250     int signbit = (f >> 31);
251     f &= 0x7fffffff;
252     if (f == 0) {
253       row_out[x] = (signbit ? sign : 0);
254       continue;
255     }
256     int exp = (f >> 23) - 127;
257     if (exp == 128) return JXL_FAILURE("Inf/NaN not allowed");
258     int mantissa = (f & 0x007fffff);
259     // broke up the binary32 into its parts, now reassemble into
260     // arbitrary float
261     exp += exp_bias;
262     if (exp < 0) {  // will become a subnormal number
263       // add implicit leading 1 to mantissa
264       mantissa |= 0x00800000;
265       if (exp < -mant_bits) {
266         return JXL_FAILURE(
267             "Invalid float number: %g cannot be represented with %i "
268             "exp_bits and %i mant_bits (exp %i)",
269             row_in[x], exp_bits, mant_bits, exp);
270       }
271       mantissa >>= 1 - exp;
272       exp = 0;
273     }
274     // exp should be representable in exp_bits, otherwise input was
275     // invalid
276     if (exp > max_exp) return JXL_FAILURE("Invalid float exponent");
277     if (mantissa & ((1 << mant_shift) - 1)) {
278       return JXL_FAILURE("%g is losing precision (mant: %x)", row_in[x],
279                          mantissa);
280     }
281     mantissa >>= mant_shift;
282     f = (signbit ? sign : 0);
283     f |= (exp << mant_bits);
284     f |= mantissa;
285     row_out[x] = (pixel_type)f;
286   }
287   return true;
288 }
289 }  // namespace
290 
// Sets up per-frame modular encoding state from `frame_header` and a private
// copy of `cparams_orig`, filling in defaults that depend on the frame.
// Specifically, it:
//   - applies decoding-speed-tier overrides for lossless (quality 100/100)
//     modular mode;
//   - picks the Squeeze (`responsive`) default, the MA-tree splitting
//     heuristics properties and limits per speed tier, and the default
//     predictor;
//   - records the stream-ID boundaries (`tree_splits`) that separate
//     independently learned trees;
//   - gives every stream a copy of the resulting options.
ModularFrameEncoder::ModularFrameEncoder(const FrameHeader& frame_header,
                                         const CompressParams& cparams_orig)
    : frame_dim(frame_header.ToFrameDimensions()), cparams(cparams_orig) {
  size_t num_streams =
      ModularStreamId::Num(frame_dim, frame_header.passes.num_passes);
  // Lossless modular: trade compression for decoding speed according to
  // decoding_speed_tier.
  if (cparams.modular_mode &&
      cparams.quality_pair == std::pair<float, float>{100.0, 100.0}) {
    switch (cparams.decoding_speed_tier) {
      case 0:
        break;
      case 1:
        cparams.options.wp_tree_mode = ModularOptions::TreeMode::kWPOnly;
        break;
      case 2: {
        cparams.options.wp_tree_mode = ModularOptions::TreeMode::kGradientOnly;
        cparams.options.predictor = Predictor::Gradient;
        break;
      }
      case 3: {  // LZ77 (nb_repeats = 0) with the Gradient predictor.
        cparams.options.nb_repeats = 0;
        cparams.options.predictor = Predictor::Gradient;
        break;
      }
      default: {  // LZ77, no predictor.
        cparams.options.nb_repeats = 0;
        cparams.options.predictor = Predictor::Zero;
        break;
      }
    }
  }
  stream_images.resize(num_streams);
  // A chroma quality above 100 falls back to the luma quality.
  if (cquality > 100) cquality = quality;

  // use a sensible default if nothing explicit is specified:
  // Squeeze for lossy, no squeeze for lossless
  if (cparams.responsive < 0) {
    if (quality == 100) {
      cparams.responsive = 0;
    } else {
      cparams.responsive = 1;
    }
  }

  // Faster speed tiers use a larger node threshold for the splitting
  // heuristics.
  if (cparams.speed_tier > SpeedTier::kWombat) {
    cparams.options.splitting_heuristics_node_threshold = 192;
  } else {
    cparams.options.splitting_heuristics_node_threshold = 96;
  }
  {
    // Set properties.
    std::vector<uint32_t> prop_order;
    if (cparams.responsive) {
      // Properties in order of their likelihood of being useful for Squeeze
      // residuals.
      prop_order = {0, 1, 4, 5, 6, 7, 8, 15, 9, 10, 11, 12, 13, 14, 2, 3};
    } else {
      // Same, but for the non-Squeeze case.
      prop_order = {0, 1, 15, 9, 10, 11, 12, 13, 14, 2, 3, 4, 5, 6, 7, 8};
    }
    // Slower speed tiers consider more of the leading properties and allow
    // more distinct property values.
    switch (cparams.speed_tier) {
      case SpeedTier::kSquirrel:
        cparams.options.splitting_heuristics_properties.assign(
            prop_order.begin(), prop_order.begin() + 8);
        cparams.options.max_property_values = 32;
        break;
      case SpeedTier::kKitten:
        cparams.options.splitting_heuristics_properties.assign(
            prop_order.begin(), prop_order.begin() + 10);
        cparams.options.max_property_values = 64;
        break;
      case SpeedTier::kTortoise:
        cparams.options.splitting_heuristics_properties = prop_order;
        cparams.options.max_property_values = 256;
        break;
      default:
        cparams.options.splitting_heuristics_properties.assign(
            prop_order.begin(), prop_order.begin() + 6);
        cparams.options.max_property_values = 16;
        break;
    }
    // Properties that refer to previous channels start at
    // kNumNonrefProperties, 4 per referenced channel.
    if (cparams.speed_tier > SpeedTier::kTortoise) {
      // Gradient in previous channels.
      for (int i = 0; i < cparams.options.max_properties; i++) {
        cparams.options.splitting_heuristics_properties.push_back(
            kNumNonrefProperties + i * 4 + 3);
      }
    } else {
      // All the extra properties in Tortoise mode.
      for (int i = 0; i < cparams.options.max_properties * 4; i++) {
        cparams.options.splitting_heuristics_properties.push_back(
            kNumNonrefProperties + i);
      }
    }
  }

  if (cparams.options.predictor == static_cast<Predictor>(-1)) {
    // no explicit predictor(s) given, set a good default
    if ((cparams.speed_tier <= SpeedTier::kTortoise ||
         cparams.modular_mode == false) &&
        quality == 100 && cparams.near_lossless == false &&
        cparams.responsive == false) {
      // TODO(veluca): allow all predictors that don't break residual
      // multipliers in lossy mode.
      cparams.options.predictor = Predictor::Variable;
    } else if (cparams.near_lossless) {
      // weighted predictor for near_lossless
      cparams.options.predictor = Predictor::Weighted;
    } else if (cparams.responsive) {
      // zero predictor for Squeeze residues
      cparams.options.predictor = Predictor::Zero;
    } else if (quality < 100) {
      // If not responsive and lossy. TODO(veluca): use near_lossless instead?
      cparams.options.predictor = Predictor::Gradient;
    } else if (cparams.speed_tier < SpeedTier::kFalcon) {
      // try median and weighted predictor for anything else
      cparams.options.predictor = Predictor::Best;
    } else {
      // just weighted predictor in fastest mode
      cparams.options.predictor = Predictor::Weighted;
    }
  }
  // tree_splits lists stream-ID boundaries: 0, then (in VarDCT mode) the
  // first ID of each kind of stream, then num_streams.
  tree_splits.push_back(0);
  if (cparams.modular_mode == false) {
    cparams.options.fast_decode_multiplier = 1.0f;
    tree_splits.push_back(ModularStreamId::VarDCTDC(0).ID(frame_dim));
    tree_splits.push_back(ModularStreamId::ModularDC(0).ID(frame_dim));
    tree_splits.push_back(ModularStreamId::ACMetadata(0).ID(frame_dim));
    tree_splits.push_back(ModularStreamId::QuantTable(0).ID(frame_dim));
    tree_splits.push_back(ModularStreamId::ModularAC(0, 0).ID(frame_dim));
    ac_metadata_size.resize(frame_dim.num_dc_groups);
    extra_dc_precision.resize(frame_dim.num_dc_groups);
  }
  tree_splits.push_back(num_streams);
  cparams.options.max_chan_size = frame_dim.group_dim;

  // TODO(veluca): figure out how to use different predictor sets per channel.
  stream_options.resize(num_streams, cparams.options);
}
429 
ComputeEncodingData(const FrameHeader & frame_header,const ImageMetadata & metadata,Image3F * JXL_RESTRICT color,const std::vector<ImageF> & extra_channels,PassesEncoderState * JXL_RESTRICT enc_state,ThreadPool * pool,AuxOut * aux_out,bool do_color)430 Status ModularFrameEncoder::ComputeEncodingData(
431     const FrameHeader& frame_header, const ImageMetadata& metadata,
432     Image3F* JXL_RESTRICT color, const std::vector<ImageF>& extra_channels,
433     PassesEncoderState* JXL_RESTRICT enc_state, ThreadPool* pool,
434     AuxOut* aux_out, bool do_color) {
435   const FrameDimensions& frame_dim = enc_state->shared.frame_dim;
436 
437   if (do_color && frame_header.loop_filter.gab) {
438     GaborishInverse(color, 0.9908511000000001f, pool);
439   }
440 
441   if (do_color && metadata.bit_depth.bits_per_sample <= 16 &&
442       cparams.speed_tier < SpeedTier::kCheetah) {
443     FindBestPatchDictionary(*color, enc_state, nullptr, aux_out,
444                             cparams.color_transform == ColorTransform::kXYB);
445     PatchDictionaryEncoder::SubtractFrom(
446         enc_state->shared.image_features.patches, color);
447   }
448 
449   // Convert ImageBundle to modular Image object
450   const size_t xsize = frame_dim.xsize;
451   const size_t ysize = frame_dim.ysize;
452 
453   int nb_chans = 3;
454   if (metadata.color_encoding.IsGray() &&
455       cparams.color_transform == ColorTransform::kNone) {
456     nb_chans = 1;
457   }
458   if (!do_color) nb_chans = 0;
459 
460   nb_chans += extra_channels.size();
461 
462   bool fp = metadata.bit_depth.floating_point_sample;
463 
464   // bits_per_sample is just metadata for XYB images.
465   if (metadata.bit_depth.bits_per_sample >= 32 && do_color &&
466       cparams.color_transform != ColorTransform::kXYB) {
467     if (metadata.bit_depth.bits_per_sample == 32 && fp == false) {
468       return JXL_FAILURE("uint32_t not supported in enc_modular");
469     } else if (metadata.bit_depth.bits_per_sample > 32) {
470       return JXL_FAILURE("bits_per_sample > 32 not supported");
471     }
472   }
473 
474   int maxval =
475       (fp ? 1
476           : (1u << static_cast<uint32_t>(metadata.bit_depth.bits_per_sample)) -
477                 1);
478 
479   Image& gi = stream_images[0];
480   gi = Image(xsize, ysize, maxval, nb_chans);
481   int c = 0;
482   if (cparams.color_transform == ColorTransform::kXYB &&
483       cparams.modular_mode == true) {
484     static const float enc_factors[3] = {32768.0f, 2048.0f, 2048.0f};
485     DequantMatricesSetCustomDC(&enc_state->shared.matrices, enc_factors);
486   }
487   if (do_color) {
488     for (; c < 3; c++) {
489       if (metadata.color_encoding.IsGray() &&
490           cparams.color_transform == ColorTransform::kNone &&
491           c != (cparams.color_transform == ColorTransform::kXYB ? 1 : 0))
492         continue;
493       int c_out = c;
494       // XYB is encoded as YX(B-Y)
495       if (cparams.color_transform == ColorTransform::kXYB && c < 2)
496         c_out = 1 - c_out;
497       float factor = maxval;
498       if (cparams.color_transform == ColorTransform::kXYB)
499         factor = enc_state->shared.matrices.InvDCQuant(c);
500       if (c == 2 && cparams.color_transform == ColorTransform::kXYB) {
501         JXL_ASSERT(!fp);
502         for (size_t y = 0; y < ysize; ++y) {
503           const float* const JXL_RESTRICT row_in = color->PlaneRow(c, y);
504           pixel_type* const JXL_RESTRICT row_out = gi.channel[c_out].Row(y);
505           pixel_type* const JXL_RESTRICT row_Y = gi.channel[0].Row(y);
506           for (size_t x = 0; x < xsize; ++x) {
507             row_out[x] = row_in[x] * factor + 0.5f;
508             row_out[x] -= row_Y[x];
509           }
510         }
511       } else {
512         int bits = metadata.bit_depth.bits_per_sample;
513         int exp_bits = metadata.bit_depth.exponent_bits_per_sample;
514         gi.channel[c_out].hshift =
515             enc_state->shared.frame_header.chroma_subsampling.HShift(c);
516         gi.channel[c_out].vshift =
517             enc_state->shared.frame_header.chroma_subsampling.VShift(c);
518         size_t xsize_shifted = DivCeil(xsize, 1 << gi.channel[c_out].hshift);
519         size_t ysize_shifted = DivCeil(ysize, 1 << gi.channel[c_out].vshift);
520         gi.channel[c_out].resize(xsize_shifted, ysize_shifted);
521         for (size_t y = 0; y < ysize_shifted; ++y) {
522           const float* const JXL_RESTRICT row_in = color->PlaneRow(c, y);
523           pixel_type* const JXL_RESTRICT row_out = gi.channel[c_out].Row(y);
524           JXL_RETURN_IF_ERROR(float_to_int(row_in, row_out, xsize_shifted, bits,
525                                            exp_bits, fp, factor));
526         }
527       }
528     }
529     if (metadata.color_encoding.IsGray() &&
530         cparams.color_transform == ColorTransform::kNone)
531       c = 1;
532   }
533 
534   for (size_t ec = 0; ec < extra_channels.size(); ec++, c++) {
535     const ExtraChannelInfo& eci = metadata.extra_channel_info[ec];
536     size_t ecups = frame_header.extra_channel_upsampling[ec];
537     gi.channel[c].resize(DivCeil(frame_dim.xsize_upsampled, ecups),
538                          DivCeil(frame_dim.ysize_upsampled, ecups));
539     gi.channel[c].hshift = gi.channel[c].vshift =
540         CeilLog2Nonzero(ecups) - CeilLog2Nonzero(frame_header.upsampling);
541 
542     int bits = eci.bit_depth.bits_per_sample;
543     int exp_bits = eci.bit_depth.exponent_bits_per_sample;
544     bool fp = eci.bit_depth.floating_point_sample;
545     float factor = (fp ? 1 : ((1u << eci.bit_depth.bits_per_sample) - 1));
546     for (size_t y = 0; y < gi.channel[c].plane.ysize(); ++y) {
547       const float* const JXL_RESTRICT row_in = extra_channels[ec].Row(y);
548       pixel_type* const JXL_RESTRICT row_out = gi.channel[c].Row(y);
549       JXL_RETURN_IF_ERROR(float_to_int(row_in, row_out,
550                                        gi.channel[c].plane.xsize(), bits,
551                                        exp_bits, fp, factor));
552     }
553   }
554   JXL_ASSERT(c == nb_chans);
555 
556   // Set options and apply transformations
557 
558   if (quality < 100 || cparams.near_lossless) {
559     if (cparams.palette_colors != 0) {
560       JXL_DEBUG_V(3, "Lossy encode, not doing palette transforms");
561     }
562     if (cparams.color_transform == ColorTransform::kXYB) {
563       cparams.channel_colors_pre_transform_percent = 0;
564     }
565     cparams.channel_colors_percent = 0;
566     cparams.palette_colors = 0;
567   }
568 
569   // if few colors, do all-channel palette before trying channel palette
570   // Logic is as follows:
571   // - if you can make a palette with few colors (arbitrary threshold: 200),
572   //   then you can also make channel palettes, but they will just be extra
573   //   signaling cost for almost no benefit
574   // - if the palette needs more colors, then channel palette might help to
575   //   reduce palette signaling cost
576   if (cparams.palette_colors != 0 && cparams.speed_tier < SpeedTier::kFalcon) {
577     // all-channel palette (e.g. RGBA)
578     if (gi.nb_channels > 1) {
579       Transform maybe_palette(TransformId::kPalette);
580       maybe_palette.begin_c = gi.nb_meta_channels;
581       maybe_palette.num_c = gi.nb_channels;
582       maybe_palette.nb_colors =
583           std::min(std::min(200, (int)(xsize * ysize / 8)),
584                    std::abs(cparams.palette_colors) / 16);
585       maybe_palette.ordered_palette = cparams.palette_colors >= 0;
586       maybe_palette.lossy_palette = false;
587       gi.do_transform(maybe_palette, weighted::Header());
588     }
589   }
590 
591   // Global channel palette
592   if (cparams.channel_colors_pre_transform_percent > 0 &&
593       !cparams.lossy_palette) {
594     // single channel palette (like FLIF's ChannelCompact)
595     for (size_t i = 0; i < gi.nb_channels; i++) {
596       int min, max;
597       gi.channel[gi.nb_meta_channels + i].compute_minmax(&min, &max);
598       int64_t colors = max - min + 1;
599       JXL_DEBUG_V(10, "Channel %zu: range=%i..%i", i, min, max);
600       Transform maybe_palette_1(TransformId::kPalette);
601       maybe_palette_1.begin_c = i + gi.nb_meta_channels;
602       maybe_palette_1.num_c = 1;
603       // simple heuristic: if less than X percent of the values in the range
604       // actually occur, it is probably worth it to do a compaction
605       // (but only if the channel palette is less than 6% the size of the
606       // image itself)
607       maybe_palette_1.nb_colors = std::min(
608           (int)(xsize * ysize / 16),
609           (int)(cparams.channel_colors_pre_transform_percent / 100. * colors));
610       if (gi.do_transform(maybe_palette_1, weighted::Header())) {
611         // effective bit depth is lower, adjust quantization accordingly
612         gi.channel[gi.nb_meta_channels + i].compute_minmax(&min, &max);
613         if (max < maxval) maxval = max;
614       }
615     }
616   }
617 
618   // Global palette
619   if ((cparams.palette_colors != 0 || cparams.lossy_palette) &&
620       cparams.speed_tier < SpeedTier::kFalcon) {
621     // all-channel palette (e.g. RGBA)
622     if (gi.nb_channels > 1) {
623       Transform maybe_palette(TransformId::kPalette);
624       maybe_palette.begin_c = gi.nb_meta_channels;
625       maybe_palette.num_c = gi.nb_channels;
626       maybe_palette.nb_colors =
627           std::min((int)(xsize * ysize / 8), std::abs(cparams.palette_colors));
628       maybe_palette.ordered_palette = cparams.palette_colors >= 0;
629       maybe_palette.lossy_palette =
630           (cparams.lossy_palette && gi.nb_channels == 3);
631       if (maybe_palette.lossy_palette) {
632         maybe_palette.predictor = Predictor::Average4;
633       }
634       // TODO(veluca): use a custom weighted header if using the weighted
635       // predictor.
636       gi.do_transform(maybe_palette, weighted::Header());
637     }
638     // all-minus-one-channel palette (RGB with separate alpha, or CMY with
639     // separate K)
640     if (gi.nb_channels > 3) {
641       Transform maybe_palette_3(TransformId::kPalette);
642       maybe_palette_3.begin_c = gi.nb_meta_channels;
643       maybe_palette_3.num_c = gi.nb_channels - 1;
644       maybe_palette_3.nb_colors =
645           std::min((int)(xsize * ysize / 8), std::abs(cparams.palette_colors));
646       maybe_palette_3.ordered_palette = cparams.palette_colors >= 0;
647       maybe_palette_3.lossy_palette = cparams.lossy_palette;
648       if (maybe_palette_3.lossy_palette) {
649         maybe_palette_3.predictor = Predictor::Average4;
650       }
651       gi.do_transform(maybe_palette_3, weighted::Header());
652     }
653   }
654 
655   if (cparams.color_transform == ColorTransform::kNone && do_color && !fp) {
656     if (cparams.colorspace == 1 ||
657         (cparams.colorspace < 0 && (quality < 100 || cparams.near_lossless ||
658                                     cparams.speed_tier > SpeedTier::kHare))) {
659       Transform ycocg{TransformId::kRCT};
660       ycocg.rct_type = 6;
661       ycocg.begin_c = gi.nb_meta_channels;
662       gi.do_transform(ycocg, weighted::Header());
663     } else if (cparams.colorspace >= 2) {
664       Transform sg(TransformId::kRCT);
665       sg.begin_c = gi.nb_meta_channels;
666       sg.rct_type = cparams.colorspace - 2;
667       gi.do_transform(sg, weighted::Header());
668     }
669   }
670 
671   if (cparams.responsive && gi.nb_channels != 0) {
672     gi.do_transform(Transform(TransformId::kSqueeze),
673                     weighted::Header());  // use default squeezing
674   }
675 
676   std::vector<uint32_t> quants;
677 
678   if (quality < 100 || cquality < 100) {
679     quants.resize(gi.channel.size(), 1);
680     JXL_DEBUG_V(
681         2,
682         "Adding quantization constants corresponding to luma quality %.2f "
683         "and chroma quality %.2f",
684         quality, cquality);
685     if (!cparams.responsive) {
686       JXL_DEBUG_V(1,
687                   "Warning: lossy compression without Squeeze "
688                   "transform is just color quantization.");
689       quality = (400 + quality) / 5;
690       cquality = (400 + cquality) / 5;
691     }
692 
693     // convert 'quality' to quantization scaling factor
694     if (quality > 50) {
695       quality = 200.0 - quality * 2.0;
696     } else {
697       quality = 900.0 - quality * 16.0;
698     }
699     if (cquality > 50) {
700       cquality = 200.0 - cquality * 2.0;
701     } else {
702       cquality = 900.0 - cquality * 16.0;
703     }
704     if (cparams.color_transform != ColorTransform::kXYB) {
705       quality *= 0.01f * maxval / 255.f;
706       cquality *= 0.01f * maxval / 255.f;
707     } else {
708       quality *= 0.01f;
709       cquality *= 0.01f;
710     }
711 
712     if (cparams.options.nb_repeats == 0) {
713       return JXL_FAILURE("nb_repeats = 0 not supported with modular lossy!");
714     }
715     for (uint32_t i = gi.nb_meta_channels; i < gi.channel.size(); i++) {
716       Channel& ch = gi.channel[i];
717       int shift = ch.hcshift + ch.vcshift;  // number of pixel halvings
718       if (shift > 15) shift = 15;
719       int q;
720       // assuming default Squeeze here
721       int component = ((i - gi.nb_meta_channels) % gi.real_nb_channels);
722       // last 4 channels are final chroma residuals
723       if (gi.real_nb_channels > 2 && i >= gi.channel.size() - 4) {
724         component = 1;
725       }
726 
727       if (cparams.color_transform == ColorTransform::kXYB && component < 3) {
728         q = (component == 0 ? quality : cquality) * squeeze_quality_factor_xyb *
729             squeeze_xyb_qtable[component][shift];
730       } else {
731         if (cparams.colorspace != 0 && component > 0 && component < 3) {
732           q = cquality * squeeze_quality_factor * squeeze_chroma_qtable[shift];
733         } else {
734           q = quality * squeeze_quality_factor * squeeze_luma_factor *
735               squeeze_luma_qtable[shift];
736         }
737       }
738       if (q < 1) q = 1;
739       QuantizeChannel(gi.channel[i], q);
740       quants[i] = q;
741     }
742   }
743 
744   // Fill other groups.
745   struct GroupParams {
746     Rect rect;
747     int minShift;
748     int maxShift;
749     ModularStreamId id;
750   };
751   std::vector<GroupParams> stream_params;
752 
753   stream_options[0] = cparams.options;
754 
755   // DC
756   for (size_t group_id = 0; group_id < frame_dim.num_dc_groups; group_id++) {
757     const size_t gx = group_id % frame_dim.xsize_dc_groups;
758     const size_t gy = group_id / frame_dim.xsize_dc_groups;
759     const Rect rect(gx * frame_dim.dc_group_dim, gy * frame_dim.dc_group_dim,
760                     frame_dim.dc_group_dim, frame_dim.dc_group_dim);
761     // minShift==3 because (frame_dim.dc_group_dim >> 3) == frame_dim.group_dim
762     // maxShift==1000 is infinity
763     stream_params.push_back(
764         GroupParams{rect, 3, 1000, ModularStreamId::ModularDC(group_id)});
765   }
766   // AC global -> nothing.
767   // AC
768   for (size_t group_id = 0; group_id < frame_dim.num_groups; group_id++) {
769     const size_t gx = group_id % frame_dim.xsize_groups;
770     const size_t gy = group_id / frame_dim.xsize_groups;
771     const Rect mrect(gx * frame_dim.group_dim, gy * frame_dim.group_dim,
772                      frame_dim.group_dim, frame_dim.group_dim);
773     for (size_t i = 0; i < enc_state->progressive_splitter.GetNumPasses();
774          i++) {
775       int maxShift, minShift;
776       frame_header.passes.GetDownsamplingBracket(i, minShift, maxShift);
777       stream_params.push_back(GroupParams{
778           mrect, minShift, maxShift, ModularStreamId::ModularAC(group_id, i)});
779     }
780   }
781   gi_channel.resize(stream_images.size());
782 
783   RunOnPool(
784       pool, 0, stream_params.size(), ThreadPool::SkipInit(),
785       [&](size_t i, size_t _) {
786         stream_options[stream_params[i].id.ID(frame_dim)] = cparams.options;
787         JXL_CHECK(PrepareStreamParams(
788             stream_params[i].rect, cparams, stream_params[i].minShift,
789             stream_params[i].maxShift, stream_params[i].id, do_color));
790       },
791       "ChooseParams");
792   {
793     // Clear out channels that have been copied to groups.
794     Image& full_image = stream_images[0];
795     size_t c = full_image.nb_meta_channels;
796     for (; c < full_image.channel.size(); c++) {
797       Channel& fc = full_image.channel[c];
798       if (fc.w > frame_dim.group_dim || fc.h > frame_dim.group_dim) break;
799     }
800     for (; c < full_image.channel.size(); c++) {
801       full_image.channel[c].plane = ImageI();
802     }
803   }
804 
805   if (!quants.empty()) {
806     for (uint32_t stream_id = 0; stream_id < stream_images.size();
807          stream_id++) {
808       // skip non-modular stream_ids
809       if (stream_id > 0 && gi_channel[stream_id].empty()) continue;
810       Image& image = stream_images[stream_id];
811       const ModularOptions& options = stream_options[stream_id];
812       for (uint32_t i = image.nb_meta_channels; i < image.channel.size(); i++) {
813         if (i >= image.nb_meta_channels &&
814             (image.channel[i].w > options.max_chan_size ||
815              image.channel[i].h > options.max_chan_size)) {
816           continue;
817         }
818         if (stream_id > 0 && gi_channel[stream_id].empty()) continue;
819         size_t ch_id = stream_id == 0
820                            ? i
821                            : gi_channel[stream_id][i - image.nb_meta_channels];
822         uint32_t q = quants[ch_id];
823         // Inform the tree splitting heuristics that each channel in each group
824         // used this quantization factor. This will produce a tree with the
825         // given multipliers.
826         if (multiplier_info.empty() ||
827             multiplier_info.back().range[1][0] != stream_id ||
828             multiplier_info.back().multiplier != q) {
829           StaticPropRange range;
830           range[0] = {i, i + 1};
831           range[1] = {stream_id, stream_id + 1};
832           multiplier_info.push_back({range, (uint32_t)q});
833         } else {
834           // Previous channel in the same group had the same quantization
835           // factor. Don't provide two different ranges, as that creates
836           // unnecessary nodes.
837           multiplier_info.back().range[0][1] = i + 1;
838         }
839       }
840     }
841     // Merge group+channel settings that have the same channels and quantization
842     // factors, to avoid unnecessary nodes.
843     std::sort(multiplier_info.begin(), multiplier_info.end(),
844               [](ModularMultiplierInfo a, ModularMultiplierInfo b) {
845                 return std::make_tuple(a.range, a.multiplier) <
846                        std::make_tuple(b.range, b.multiplier);
847               });
848     size_t new_num = 1;
849     for (size_t i = 1; i < multiplier_info.size(); i++) {
850       ModularMultiplierInfo& prev = multiplier_info[new_num - 1];
851       ModularMultiplierInfo& cur = multiplier_info[i];
852       if (prev.range[0] == cur.range[0] && prev.multiplier == cur.multiplier &&
853           prev.range[1][1] == cur.range[1][0]) {
854         prev.range[1][1] = cur.range[1][1];
855       } else {
856         multiplier_info[new_num++] = multiplier_info[i];
857       }
858     }
859     multiplier_info.resize(new_num);
860   }
861 
862   return PrepareEncoding(pool, enc_state->shared.frame_dim,
863                          enc_state->heuristics.get(), aux_out);
864 }
865 
PrepareEncoding(ThreadPool * pool,const FrameDimensions & frame_dim,EncoderHeuristics * heuristics,AuxOut * aux_out)866 Status ModularFrameEncoder::PrepareEncoding(ThreadPool* pool,
867                                             const FrameDimensions& frame_dim,
868                                             EncoderHeuristics* heuristics,
869                                             AuxOut* aux_out) {
870   if (!tree.empty()) return true;
871 
872   // Compute tree.
873   size_t num_streams = stream_images.size();
874   stream_headers.resize(num_streams);
875   tokens.resize(num_streams);
876 
877   if (heuristics->CustomFixedTreeLossless(frame_dim, &tree)) {
878     // Using a fixed tree.
879   } else if (cparams.speed_tier != SpeedTier::kFalcon || quality != 100 ||
880              !cparams.modular_mode) {
881     // Avoid creating a tree with leaves that don't correspond to any pixels.
882     std::vector<size_t> useful_splits;
883     useful_splits.reserve(tree_splits.size());
884     for (size_t chunk = 0; chunk < tree_splits.size() - 1; chunk++) {
885       bool has_pixels = false;
886       size_t start = tree_splits[chunk];
887       size_t stop = tree_splits[chunk + 1];
888       for (size_t i = start; i < stop; i++) {
889         for (const Channel& c : stream_images[i].channel) {
890           if (c.w && c.h) has_pixels = true;
891         }
892       }
893       if (has_pixels) {
894         useful_splits.push_back(tree_splits[chunk]);
895       }
896     }
897     // Don't do anything if modular mode does not have any pixels in this image
898     if (useful_splits.empty()) return true;
899     useful_splits.push_back(tree_splits.back());
900 
901     std::atomic_flag invalid_force_wp = ATOMIC_FLAG_INIT;
902 
903     std::vector<Tree> trees(useful_splits.size() - 1);
904     RunOnPool(
905         pool, 0, useful_splits.size() - 1, ThreadPool::SkipInit(),
906         [&](size_t chunk, size_t _) {
907           // TODO(veluca): parallelize more.
908           size_t total_pixels = 0;
909           uint32_t start = useful_splits[chunk];
910           uint32_t stop = useful_splits[chunk + 1];
911           uint32_t max_c = 0;
912           if (stream_options[start].tree_kind !=
913               ModularOptions::TreeKind::kLearn) {
914             for (size_t i = start; i < stop; i++) {
915               for (const Channel& ch : stream_images[i].channel) {
916                 total_pixels += ch.w * ch.h;
917               }
918             }
919             trees[chunk] =
920                 PredefinedTree(stream_options[start].tree_kind, total_pixels);
921             return;
922           }
923           TreeSamples tree_samples;
924           if (!tree_samples.SetPredictor(stream_options[start].predictor,
925                                          stream_options[start].wp_tree_mode)) {
926             invalid_force_wp.test_and_set(std::memory_order_acq_rel);
927             return;
928           }
929           if (!tree_samples.SetProperties(
930                   stream_options[start].splitting_heuristics_properties,
931                   stream_options[start].wp_tree_mode)) {
932             invalid_force_wp.test_and_set(std::memory_order_acq_rel);
933             return;
934           }
935           std::vector<pixel_type> pixel_samples;
936           std::vector<pixel_type> diff_samples;
937           std::vector<uint32_t> group_pixel_count;
938           std::vector<uint32_t> channel_pixel_count;
939           for (size_t i = start; i < stop; i++) {
940             max_c = std::max<uint32_t>(stream_images[i].channel.size(), max_c);
941             CollectPixelSamples(stream_images[i], stream_options[i], i,
942                                 group_pixel_count, channel_pixel_count,
943                                 pixel_samples, diff_samples);
944           }
945           StaticPropRange range;
946           range[0] = {0, max_c};
947           range[1] = {start, stop};
948           auto local_multiplier_info = multiplier_info;
949 
950           tree_samples.PreQuantizeProperties(
951               range, local_multiplier_info, group_pixel_count,
952               channel_pixel_count, pixel_samples, diff_samples,
953               stream_options[start].max_property_values);
954           for (size_t i = start; i < stop; i++) {
955             JXL_CHECK(ModularGenericCompress(
956                 stream_images[i], stream_options[i], /*writer=*/nullptr,
957                 /*aux_out=*/nullptr, 0, i, &tree_samples, &total_pixels));
958           }
959 
960           // TODO(veluca): parallelize more.
961           trees[chunk] =
962               LearnTree(std::move(tree_samples), total_pixels,
963                         stream_options[start], local_multiplier_info, range);
964         },
965         "LearnTrees");
966     if (invalid_force_wp.test_and_set(std::memory_order_acq_rel)) {
967       return JXL_FAILURE("PrepareEncoding: force_no_wp with {Weighted}");
968     }
969     tree.clear();
970     MergeTrees(trees, useful_splits, 0, useful_splits.size() - 1, &tree);
971   } else {
972     // Fixed tree.
973     // TODO(veluca): determine cutoffs?
974     std::vector<int32_t> cutoffs = {-255, -191, -127, -95, -63, -47, -31, -23,
975                                     -15,  -11,  -7,   -5,  -3,  -1,  0,   1,
976                                     3,    5,    7,    11,  15,  23,  31,  47,
977                                     63,   95,   127,  191, 255};
978     size_t total_pixels = 0;
979     for (const Image& img : stream_images) {
980       for (const Channel& ch : img.channel) {
981         total_pixels += ch.w * ch.h;
982       }
983     }
984     tree = MakeFixedTree(kNumNonrefProperties - weighted::kNumProperties,
985                          cutoffs, Predictor::Weighted, total_pixels);
986   }
987   // TODO(veluca): do this somewhere else.
988   if (cparams.near_lossless) {
989     for (size_t i = 0; i < tree.size(); i++) {
990       tree[i].predictor_offset = 0;
991     }
992   }
993   tree_tokens.resize(1);
994   tree_tokens[0].clear();
995   Tree decoded_tree;
996   TokenizeTree(tree, &tree_tokens[0], &decoded_tree);
997   JXL_ASSERT(tree.size() == decoded_tree.size());
998   tree = std::move(decoded_tree);
999 
1000   if (WantDebugOutput(aux_out)) {
1001     PrintTree(tree, aux_out->debug_prefix + "/global_tree");
1002   }
1003 
1004   image_widths.resize(num_streams);
1005   RunOnPool(
1006       pool, 0, num_streams, ThreadPool::SkipInit(),
1007       [&](size_t stream_id, size_t _) {
1008         AuxOut my_aux_out;
1009         if (aux_out) {
1010           my_aux_out.dump_image = aux_out->dump_image;
1011           my_aux_out.debug_prefix = aux_out->debug_prefix;
1012         }
1013         tokens[stream_id].clear();
1014         JXL_CHECK(ModularGenericCompress(
1015             stream_images[stream_id], stream_options[stream_id],
1016             /*writer=*/nullptr, &my_aux_out, 0, stream_id,
1017             /*tree_samples=*/nullptr,
1018             /*total_pixels=*/nullptr,
1019             /*tree=*/&tree, /*header=*/&stream_headers[stream_id],
1020             /*tokens=*/&tokens[stream_id],
1021             /*widths=*/&image_widths[stream_id]));
1022       },
1023       "ComputeTokens");
1024   return true;
1025 }
1026 
EncodeGlobalInfo(BitWriter * writer,AuxOut * aux_out)1027 Status ModularFrameEncoder::EncodeGlobalInfo(BitWriter* writer,
1028                                              AuxOut* aux_out) {
1029   BitWriter::Allotment allotment(writer, 1);
1030   // If we are using brotli, or not using modular mode.
1031   if (tree_tokens.empty() || tree_tokens[0].empty()) {
1032     writer->Write(1, 0);
1033     ReclaimAndCharge(writer, &allotment, kLayerModularTree, aux_out);
1034     return true;
1035   }
1036   writer->Write(1, 1);
1037   ReclaimAndCharge(writer, &allotment, kLayerModularTree, aux_out);
1038 
1039   // Write tree
1040   HistogramParams params;
1041   if (cparams.speed_tier > SpeedTier::kKitten) {
1042     params.clustering = HistogramParams::ClusteringType::kFast;
1043     params.ans_histogram_strategy =
1044         HistogramParams::ANSHistogramStrategy::kApproximate;
1045     params.lz77_method = cparams.decoding_speed_tier >= 3
1046                              ? (cparams.speed_tier == SpeedTier::kFalcon
1047                                     ? HistogramParams::LZ77Method::kRLE
1048                                     : HistogramParams::LZ77Method::kLZ77)
1049                              : HistogramParams::LZ77Method::kNone;
1050     // Near-lossless DC, as well as modular mode, require choosing hybrid uint
1051     // more carefully.
1052     if ((!extra_dc_precision.empty() && extra_dc_precision[0] != 0) ||
1053         (cparams.modular_mode && cparams.speed_tier < SpeedTier::kCheetah)) {
1054       params.uint_method = HistogramParams::HybridUintMethod::kFast;
1055     } else {
1056       params.uint_method = HistogramParams::HybridUintMethod::kNone;
1057     }
1058   } else if (cparams.speed_tier <= SpeedTier::kTortoise) {
1059     params.lz77_method = HistogramParams::LZ77Method::kOptimal;
1060   } else {
1061     params.lz77_method = HistogramParams::LZ77Method::kLZ77;
1062   }
1063   if (cparams.decoding_speed_tier >= 1) {
1064     params.max_histograms = 12;
1065   }
1066   BuildAndEncodeHistograms(params, kNumTreeContexts, tree_tokens, &code,
1067                            &context_map, writer, kLayerModularTree, aux_out);
1068   WriteTokens(tree_tokens[0], code, context_map, writer, kLayerModularTree,
1069               aux_out);
1070   params.image_widths = image_widths;
1071   // Write histograms.
1072   BuildAndEncodeHistograms(params, (tree.size() + 1) / 2, tokens, &code,
1073                            &context_map, writer, kLayerModularGlobal, aux_out);
1074   return true;
1075 }
1076 
EncodeStream(BitWriter * writer,AuxOut * aux_out,size_t layer,const ModularStreamId & stream)1077 Status ModularFrameEncoder::EncodeStream(BitWriter* writer, AuxOut* aux_out,
1078                                          size_t layer,
1079                                          const ModularStreamId& stream) {
1080   size_t stream_id = stream.ID(frame_dim);
1081   if (stream_images[stream_id].real_nb_channels < 1) {
1082     return true;  // Image with no channels, header never gets decoded.
1083   }
1084   JXL_RETURN_IF_ERROR(
1085       Bundle::Write(stream_headers[stream_id], writer, layer, aux_out));
1086   WriteTokens(tokens[stream_id], code, context_map, writer, layer, aux_out);
1087   return true;
1088 }
1089 
1090 namespace {
EstimateWPCost(const Image & img,size_t i)1091 float EstimateWPCost(const Image& img, size_t i) {
1092   size_t extra_bits = 0;
1093   float histo_cost = 0;
1094   HybridUintConfig config;
1095   int32_t cutoffs[] = {-500, -392, -255, -191, -127, -95, -63, -47, -31,
1096                        -23,  -15,  -11,  -7,   -4,   -3,  -1,  0,   1,
1097                        3,    5,    7,    11,   15,   23,  31,  47,  63,
1098                        95,   127,  191,  255,  392,  500};
1099   constexpr size_t nc = sizeof(cutoffs) / sizeof(*cutoffs) + 1;
1100   Histogram histo[nc] = {};
1101   weighted::Header wp_header;
1102   PredictorMode(i, &wp_header);
1103   for (const Channel& ch : img.channel) {
1104     const intptr_t onerow = ch.plane.PixelsPerRow();
1105     weighted::State wp_state(wp_header, ch.w, ch.h);
1106     Properties properties(1);
1107     for (size_t y = 0; y < ch.h; y++) {
1108       const pixel_type* JXL_RESTRICT r = ch.Row(y);
1109       for (size_t x = 0; x < ch.w; x++) {
1110         size_t offset = 0;
1111         pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
1112         pixel_type_w top = (y ? *(r + x - onerow) : left);
1113         pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
1114         pixel_type_w topright =
1115             (x + 1 < ch.w && y ? *(r + x + 1 - onerow) : top);
1116         pixel_type_w toptop = (y > 1 ? *(r + x - onerow - onerow) : top);
1117         pixel_type guess = wp_state.Predict</*compute_properties=*/true>(
1118             x, y, ch.w, top, left, topright, topleft, toptop, &properties,
1119             offset);
1120         size_t ctx = 0;
1121         for (int c : cutoffs) {
1122           ctx += c >= properties[0];
1123         }
1124         pixel_type res = r[x] - guess;
1125         uint32_t token, nbits, bits;
1126         config.Encode(PackSigned(res), &token, &nbits, &bits);
1127         histo[ctx].Add(token);
1128         extra_bits += nbits;
1129         wp_state.UpdateErrors(r[x], x, y, ch.w);
1130       }
1131     }
1132     for (size_t h = 0; h < nc; h++) {
1133       histo_cost += histo[h].ShannonEntropy();
1134       histo[h].Clear();
1135     }
1136   }
1137   return histo_cost + extra_bits;
1138 }
1139 
EstimateCost(const Image & img)1140 float EstimateCost(const Image& img) {
1141   // TODO(veluca): consider SIMDfication of this code.
1142   size_t extra_bits = 0;
1143   float histo_cost = 0;
1144   HybridUintConfig config;
1145   uint32_t cutoffs[] = {0,  1,  3,  5,   7,   11,  15,  23, 31,
1146                         47, 63, 95, 127, 191, 255, 392, 500};
1147   constexpr size_t nc = sizeof(cutoffs) / sizeof(*cutoffs) + 1;
1148   Histogram histo[nc] = {};
1149   for (const Channel& ch : img.channel) {
1150     const intptr_t onerow = ch.plane.PixelsPerRow();
1151     for (size_t y = 0; y < ch.h; y++) {
1152       const pixel_type* JXL_RESTRICT r = ch.Row(y);
1153       for (size_t x = 0; x < ch.w; x++) {
1154         pixel_type_w left = (x ? r[x - 1] : y ? *(r + x - onerow) : 0);
1155         pixel_type_w top = (y ? *(r + x - onerow) : left);
1156         pixel_type_w topleft = (x && y ? *(r + x - 1 - onerow) : left);
1157         size_t maxdiff = std::max(std::max(left, top), topleft) -
1158                          std::min(std::min(left, top), topleft);
1159         size_t ctx = 0;
1160         for (uint32_t c : cutoffs) {
1161           ctx += c > maxdiff;
1162         }
1163         pixel_type res = r[x] - ClampedGradient(top, left, topleft);
1164         uint32_t token, nbits, bits;
1165         config.Encode(PackSigned(res), &token, &nbits, &bits);
1166         histo[ctx].Add(token);
1167         extra_bits += nbits;
1168       }
1169     }
1170     for (size_t h = 0; h < nc; h++) {
1171       histo_cost += histo[h].ShannonEntropy();
1172       histo[h].Clear();
1173     }
1174   }
1175   return histo_cost + extra_bits;
1176 }
1177 
1178 }  // namespace
1179 
PrepareStreamParams(const Rect & rect,const CompressParams & cparams,int minShift,int maxShift,const ModularStreamId & stream,bool do_color)1180 Status ModularFrameEncoder::PrepareStreamParams(const Rect& rect,
1181                                                 const CompressParams& cparams,
1182                                                 int minShift, int maxShift,
1183                                                 const ModularStreamId& stream,
1184                                                 bool do_color) {
1185   size_t stream_id = stream.ID(frame_dim);
1186   JXL_ASSERT(stream_id != 0);
1187   Image& full_image = stream_images[0];
1188   const size_t xsize = rect.xsize();
1189   const size_t ysize = rect.ysize();
1190   int maxval = full_image.maxval;
1191   Image& gi = stream_images[stream_id];
1192   gi = Image(xsize, ysize, maxval, 0);
1193   // start at the first bigger-than-frame_dim.group_dim non-metachannel
1194   size_t c = full_image.nb_meta_channels;
1195   for (; c < full_image.channel.size(); c++) {
1196     Channel& fc = full_image.channel[c];
1197     if (fc.w > frame_dim.group_dim || fc.h > frame_dim.group_dim) break;
1198   }
1199   for (; c < full_image.channel.size(); c++) {
1200     Channel& fc = full_image.channel[c];
1201     int shift = std::min(fc.hshift, fc.vshift);
1202     if (shift > maxShift) continue;
1203     if (shift < minShift) continue;
1204     Rect r(rect.x0() >> fc.hshift, rect.y0() >> fc.vshift,
1205            rect.xsize() >> fc.hshift, rect.ysize() >> fc.vshift, fc.w, fc.h);
1206     if (r.xsize() == 0 || r.ysize() == 0) continue;
1207     gi_channel[stream_id].push_back(c);
1208     Channel gc(r.xsize(), r.ysize());
1209     gc.hshift = fc.hshift;
1210     gc.vshift = fc.vshift;
1211     for (size_t y = 0; y < r.ysize(); ++y) {
1212       const pixel_type* const JXL_RESTRICT row_in = r.ConstRow(fc.plane, y);
1213       pixel_type* const JXL_RESTRICT row_out = gc.Row(y);
1214       for (size_t x = 0; x < r.xsize(); ++x) {
1215         row_out[x] = row_in[x];
1216       }
1217     }
1218     gi.channel.emplace_back(std::move(gc));
1219   }
1220   gi.nb_channels = gi.channel.size();
1221   gi.real_nb_channels = gi.nb_channels;
1222 
1223   // Do some per-group transforms
1224 
1225   float quality = cparams.quality_pair.first;
1226 
1227   // Local palette
1228   // TODO(veluca): make this work with quantize-after-prediction in lossy mode.
1229   if (quality == 100 && cparams.palette_colors != 0 &&
1230       cparams.speed_tier < SpeedTier::kCheetah) {
1231     // all-channel palette (e.g. RGBA)
1232     if (gi.nb_channels > 1) {
1233       Transform maybe_palette(TransformId::kPalette);
1234       maybe_palette.begin_c = gi.nb_meta_channels;
1235       maybe_palette.num_c = gi.nb_channels;
1236       maybe_palette.nb_colors = std::abs(cparams.palette_colors);
1237       maybe_palette.ordered_palette = cparams.palette_colors >= 0;
1238       gi.do_transform(maybe_palette, weighted::Header());
1239     }
1240     // all-minus-one-channel palette (RGB with separate alpha, or CMY with
1241     // separate K)
1242     if (gi.nb_channels > 3) {
1243       Transform maybe_palette_3(TransformId::kPalette);
1244       maybe_palette_3.begin_c = gi.nb_meta_channels;
1245       maybe_palette_3.num_c = gi.nb_channels - 1;
1246       maybe_palette_3.nb_colors = std::abs(cparams.palette_colors);
1247       maybe_palette_3.ordered_palette = cparams.palette_colors >= 0;
1248       maybe_palette_3.lossy_palette = cparams.lossy_palette;
1249       if (maybe_palette_3.lossy_palette) {
1250         maybe_palette_3.predictor = Predictor::Weighted;
1251       }
1252       gi.do_transform(maybe_palette_3, weighted::Header());
1253     }
1254   }
1255 
1256   // Local channel palette
1257   if (cparams.channel_colors_percent > 0 && quality == 100 &&
1258       !cparams.lossy_palette && cparams.speed_tier < SpeedTier::kCheetah) {
1259     // single channel palette (like FLIF's ChannelCompact)
1260     for (size_t i = 0; i < gi.nb_channels; i++) {
1261       int min, max;
1262       gi.channel[gi.nb_meta_channels + i].compute_minmax(&min, &max);
1263       int colors = max - min + 1;
1264       JXL_DEBUG_V(10, "Channel %zu: range=%i..%i", i, min, max);
1265       Transform maybe_palette_1(TransformId::kPalette);
1266       maybe_palette_1.begin_c = i + gi.nb_meta_channels;
1267       maybe_palette_1.num_c = 1;
1268       // simple heuristic: if less than X percent of the values in the range
1269       // actually occur, it is probably worth it to do a compaction
1270       // (but only if the channel palette is less than 80% the size of the
1271       // image itself)
1272       maybe_palette_1.nb_colors =
1273           std::min((int)(xsize * ysize * 0.8),
1274                    (int)(cparams.channel_colors_percent / 100. * colors));
1275       gi.do_transform(maybe_palette_1, weighted::Header());
1276     }
1277   }
1278   if (cparams.near_lossless > 0 && gi.nb_channels != 0) {
1279     Transform nl(TransformId::kNearLossless);
1280     nl.predictor = cparams.options.predictor;
1281     JXL_RETURN_IF_ERROR(nl.predictor != Predictor::Best);
1282     JXL_RETURN_IF_ERROR(nl.predictor != Predictor::Variable);
1283     nl.begin_c = gi.nb_meta_channels;
1284     if (cparams.colorspace == 0) {
1285       nl.num_c = gi.nb_channels;
1286       nl.max_delta_error = cparams.near_lossless;
1287       gi.do_transform(nl, weighted::Header());
1288     } else {
1289       nl.num_c = 1;
1290       nl.max_delta_error = cparams.near_lossless;
1291       gi.do_transform(nl, weighted::Header());
1292       nl.begin_c += 1;
1293       nl.num_c = gi.nb_channels - 1;
1294       nl.max_delta_error++;  // more loss for chroma
1295       gi.do_transform(nl, weighted::Header());
1296     }
1297   }
1298 
1299   // lossless and no specific color transform specified: try Nothing, YCoCg,
1300   // and 17 RCTs
1301   if (cparams.color_transform == ColorTransform::kNone && quality == 100 &&
1302       cparams.colorspace < 0 && gi.nb_channels > 2 && !cparams.near_lossless &&
1303       cparams.responsive == false && do_color &&
1304       cparams.speed_tier <= SpeedTier::kHare) {
1305     Transform sg(TransformId::kRCT);
1306     sg.begin_c = gi.nb_meta_channels;
1307 
1308     size_t nb_rcts_to_try = 0;
1309     switch (cparams.speed_tier) {
1310       case SpeedTier::kFalcon:
1311         nb_rcts_to_try = 0;  // Just do global YCoCg
1312         break;
1313       case SpeedTier::kCheetah:
1314         nb_rcts_to_try = 0;  // Just do global YCoCg
1315         break;
1316       case SpeedTier::kHare:
1317         nb_rcts_to_try = 4;
1318         break;
1319       case SpeedTier::kWombat:
1320         nb_rcts_to_try = 5;
1321         break;
1322       case SpeedTier::kSquirrel:
1323         nb_rcts_to_try = 7;
1324         break;
1325       case SpeedTier::kKitten:
1326         nb_rcts_to_try = 9;
1327         break;
1328       case SpeedTier::kTortoise:
1329         nb_rcts_to_try = 19;
1330         break;
1331     }
1332     float best_cost = std::numeric_limits<float>::max();
1333     size_t best_rct = 0;
1334     // These should be 19 actually different transforms; the remaining ones
1335     // are equivalent to one of these (note that the first two are do-nothing
1336     // and YCoCg) modulo channel reordering (which only matters in the case of
1337     // MA-with-prev-channels-properties) and/or sign (e.g. RmG vs GmR)
1338     for (int i : {0 * 7 + 0, 0 * 7 + 6, 0 * 7 + 5, 1 * 7 + 3, 3 * 7 + 5,
1339                   5 * 7 + 5, 1 * 7 + 5, 2 * 7 + 5, 1 * 7 + 1, 0 * 7 + 4,
1340                   1 * 7 + 2, 2 * 7 + 1, 2 * 7 + 2, 2 * 7 + 3, 4 * 7 + 4,
1341                   4 * 7 + 5, 0 * 7 + 2, 0 * 7 + 1, 0 * 7 + 3}) {
1342       if (nb_rcts_to_try == 0) break;
1343       int num_transforms_to_keep = gi.transform.size();
1344       sg.rct_type = i;
1345       gi.do_transform(sg, weighted::Header());
1346       float cost = EstimateCost(gi);
1347       if (cost < best_cost) {
1348         best_rct = i;
1349         best_cost = cost;
1350       }
1351       nb_rcts_to_try--;
1352       // Ensure we do not clamp channels to their supposed range, as this
1353       // otherwise breaks in the presence of patches.
1354       gi.undo_transforms(weighted::Header(), num_transforms_to_keep == 0
1355                                                  ? -1
1356                                                  : num_transforms_to_keep);
1357     }
1358     // Apply the best RCT to the image for future encoding.
1359     sg.rct_type = best_rct;
1360     gi.do_transform(sg, weighted::Header());
1361   } else {
1362     // No need to try anything, just use the default options.
1363   }
1364   size_t nb_wp_modes = 0;
1365   switch (cparams.speed_tier) {
1366     case SpeedTier::kFalcon:
1367       nb_wp_modes = 1;
1368       break;
1369     case SpeedTier::kCheetah:
1370       nb_wp_modes = 1;
1371       break;
1372     case SpeedTier::kHare:
1373       nb_wp_modes = 1;
1374       break;
1375     case SpeedTier::kWombat:
1376       nb_wp_modes = 1;
1377       break;
1378     case SpeedTier::kSquirrel:
1379       nb_wp_modes = 1;
1380       break;
1381     case SpeedTier::kKitten:
1382       nb_wp_modes = 2;
1383       break;
1384     case SpeedTier::kTortoise:
1385       nb_wp_modes = 5;
1386       break;
1387   }
1388   if (nb_wp_modes > 1 &&
1389       (stream_options[stream_id].predictor == Predictor::Weighted ||
1390        stream_options[stream_id].predictor == Predictor::Best ||
1391        stream_options[stream_id].predictor == Predictor::Variable)) {
1392     float best_cost = std::numeric_limits<float>::max();
1393     stream_options[stream_id].wp_mode = 0;
1394     for (size_t i = 0; i < nb_wp_modes; i++) {
1395       float cost = EstimateWPCost(gi, i);
1396       if (cost < best_cost) {
1397         best_cost = cost;
1398         stream_options[stream_id].wp_mode = i;
1399       }
1400     }
1401   }
1402   return true;
1403 }
1404 
QuantizeWP(const int32_t * qrow,size_t onerow,size_t c,size_t x,size_t y,size_t w,weighted::State * wp_state,float value,float inv_factor)1405 int QuantizeWP(const int32_t* qrow, size_t onerow, size_t c, size_t x, size_t y,
1406                size_t w, weighted::State* wp_state, float value,
1407                float inv_factor) {
1408   float svalue = value * inv_factor;
1409   PredictionResult pred =
1410       PredictNoTreeWP(w, qrow + x, onerow, x, y, Predictor::Weighted, wp_state);
1411   svalue -= pred.guess;
1412   int residual = roundf(svalue);
1413   if (residual > 2 || residual < -2) residual = roundf(svalue * 0.5) * 2;
1414   return residual + pred.guess;
1415 }
1416 
QuantizeGradient(const int32_t * qrow,size_t onerow,size_t c,size_t x,size_t y,size_t w,float value,float inv_factor)1417 int QuantizeGradient(const int32_t* qrow, size_t onerow, size_t c, size_t x,
1418                      size_t y, size_t w, float value, float inv_factor) {
1419   float svalue = value * inv_factor;
1420   PredictionResult pred =
1421       PredictNoTreeNoWP(w, qrow + x, onerow, x, y, Predictor::Gradient);
1422   svalue -= pred.guess;
1423   int residual = roundf(svalue);
1424   if (residual > 2 || residual < -2) residual = roundf(svalue * 0.5) * 2;
1425   return residual + pred.guess;
1426 }
1427 
AddVarDCTDC(const Image3F & dc,size_t group_index,bool nl_dc,PassesEncoderState * enc_state)1428 void ModularFrameEncoder::AddVarDCTDC(const Image3F& dc, size_t group_index,
1429                                       bool nl_dc,
1430                                       PassesEncoderState* enc_state) {
1431   const Rect r = enc_state->shared.DCGroupRect(group_index);
1432   extra_dc_precision[group_index] = nl_dc ? 1 : 0;
1433   float mul = 1 << extra_dc_precision[group_index];
1434 
1435   size_t stream_id = ModularStreamId::VarDCTDC(group_index).ID(frame_dim);
1436   stream_options[stream_id].max_chan_size = 0xFFFFFF;
1437   stream_options[stream_id].predictor = Predictor::Weighted;
1438   stream_options[stream_id].wp_tree_mode = ModularOptions::TreeMode::kWPOnly;
1439   if (cparams.speed_tier >= SpeedTier::kSquirrel) {
1440     stream_options[stream_id].tree_kind = ModularOptions::TreeKind::kWPFixedDC;
1441   }
1442   if (cparams.decoding_speed_tier >= 1) {
1443     stream_options[stream_id].tree_kind =
1444         ModularOptions::TreeKind::kGradientFixedDC;
1445   }
1446 
1447   stream_images[stream_id] = Image(r.xsize(), r.ysize(), 255, 3);
1448   if (nl_dc && stream_options[stream_id].tree_kind ==
1449                    ModularOptions::TreeKind::kGradientFixedDC) {
1450     JXL_ASSERT(enc_state->shared.frame_header.chroma_subsampling.Is444());
1451     for (size_t c : {1, 0, 2}) {
1452       float inv_factor = enc_state->shared.quantizer.GetInvDcStep(c) * mul;
1453       float y_factor = enc_state->shared.quantizer.GetDcStep(1) / mul;
1454       float cfl_factor = enc_state->shared.cmap.DCFactors()[c];
1455       for (size_t y = 0; y < r.ysize(); y++) {
1456         int32_t* quant_row =
1457             stream_images[stream_id].channel[c < 2 ? c ^ 1 : c].plane.Row(y);
1458         size_t stride = stream_images[stream_id]
1459                             .channel[c < 2 ? c ^ 1 : c]
1460                             .plane.PixelsPerRow();
1461         const float* row = r.ConstPlaneRow(dc, c, y);
1462         if (c == 1) {
1463           for (size_t x = 0; x < r.xsize(); x++) {
1464             quant_row[x] = QuantizeGradient(quant_row, stride, c, x, y,
1465                                             r.xsize(), row[x], inv_factor);
1466           }
1467         } else {
1468           int32_t* quant_row_y =
1469               stream_images[stream_id].channel[0].plane.Row(y);
1470           for (size_t x = 0; x < r.xsize(); x++) {
1471             quant_row[x] = QuantizeGradient(
1472                 quant_row, stride, c, x, y, r.xsize(),
1473                 row[x] - quant_row_y[x] * (y_factor * cfl_factor), inv_factor);
1474           }
1475         }
1476       }
1477     }
1478   } else if (nl_dc) {
1479     JXL_ASSERT(enc_state->shared.frame_header.chroma_subsampling.Is444());
1480     for (size_t c : {1, 0, 2}) {
1481       float inv_factor = enc_state->shared.quantizer.GetInvDcStep(c) * mul;
1482       float y_factor = enc_state->shared.quantizer.GetDcStep(1) / mul;
1483       float cfl_factor = enc_state->shared.cmap.DCFactors()[c];
1484       weighted::Header header;
1485       weighted::State wp_state(header, r.xsize(), r.ysize());
1486       for (size_t y = 0; y < r.ysize(); y++) {
1487         int32_t* quant_row =
1488             stream_images[stream_id].channel[c < 2 ? c ^ 1 : c].plane.Row(y);
1489         size_t stride = stream_images[stream_id]
1490                             .channel[c < 2 ? c ^ 1 : c]
1491                             .plane.PixelsPerRow();
1492         const float* row = r.ConstPlaneRow(dc, c, y);
1493         if (c == 1) {
1494           for (size_t x = 0; x < r.xsize(); x++) {
1495             quant_row[x] = QuantizeWP(quant_row, stride, c, x, y, r.xsize(),
1496                                       &wp_state, row[x], inv_factor);
1497             wp_state.UpdateErrors(quant_row[x], x, y, r.xsize());
1498           }
1499         } else {
1500           int32_t* quant_row_y =
1501               stream_images[stream_id].channel[0].plane.Row(y);
1502           for (size_t x = 0; x < r.xsize(); x++) {
1503             quant_row[x] = QuantizeWP(
1504                 quant_row, stride, c, x, y, r.xsize(), &wp_state,
1505                 row[x] - quant_row_y[x] * (y_factor * cfl_factor), inv_factor);
1506             wp_state.UpdateErrors(quant_row[x], x, y, r.xsize());
1507           }
1508         }
1509       }
1510     }
1511   } else if (enc_state->shared.frame_header.chroma_subsampling.Is444()) {
1512     for (size_t c : {1, 0, 2}) {
1513       float inv_factor = enc_state->shared.quantizer.GetInvDcStep(c) * mul;
1514       float y_factor = enc_state->shared.quantizer.GetDcStep(1) / mul;
1515       float cfl_factor = enc_state->shared.cmap.DCFactors()[c];
1516       for (size_t y = 0; y < r.ysize(); y++) {
1517         int32_t* quant_row =
1518             stream_images[stream_id].channel[c < 2 ? c ^ 1 : c].plane.Row(y);
1519         const float* row = r.ConstPlaneRow(dc, c, y);
1520         if (c == 1) {
1521           for (size_t x = 0; x < r.xsize(); x++) {
1522             quant_row[x] = roundf(row[x] * inv_factor);
1523           }
1524         } else {
1525           int32_t* quant_row_y =
1526               stream_images[stream_id].channel[0].plane.Row(y);
1527           for (size_t x = 0; x < r.xsize(); x++) {
1528             quant_row[x] =
1529                 roundf((row[x] - quant_row_y[x] * (y_factor * cfl_factor)) *
1530                        inv_factor);
1531           }
1532         }
1533       }
1534     }
1535   } else {
1536     for (size_t c : {1, 0, 2}) {
1537       Rect rect(
1538           r.x0() >> enc_state->shared.frame_header.chroma_subsampling.HShift(c),
1539           r.y0() >> enc_state->shared.frame_header.chroma_subsampling.VShift(c),
1540           r.xsize() >>
1541               enc_state->shared.frame_header.chroma_subsampling.HShift(c),
1542           r.ysize() >>
1543               enc_state->shared.frame_header.chroma_subsampling.VShift(c));
1544       float inv_factor = enc_state->shared.quantizer.GetInvDcStep(c) * mul;
1545       size_t ys = rect.ysize();
1546       size_t xs = rect.xsize();
1547       Channel& ch = stream_images[stream_id].channel[c < 2 ? c ^ 1 : c];
1548       ch.w = xs;
1549       ch.h = ys;
1550       ch.resize();
1551       for (size_t y = 0; y < ys; y++) {
1552         int32_t* quant_row = ch.plane.Row(y);
1553         const float* row = rect.ConstPlaneRow(dc, c, y);
1554         for (size_t x = 0; x < xs; x++) {
1555           quant_row[x] = roundf(row[x] * inv_factor);
1556         }
1557       }
1558     }
1559   }
1560 
1561   DequantDC(r, &enc_state->shared.dc_storage, &enc_state->shared.quant_dc,
1562             stream_images[stream_id], enc_state->shared.quantizer.MulDC(),
1563             1.0 / mul, enc_state->shared.cmap.DCFactors(),
1564             enc_state->shared.frame_header.chroma_subsampling,
1565             enc_state->shared.block_ctx_map);
1566 }
1567 
AddACMetadata(size_t group_index,bool jpeg_transcode,PassesEncoderState * enc_state)1568 void ModularFrameEncoder::AddACMetadata(size_t group_index, bool jpeg_transcode,
1569                                         PassesEncoderState* enc_state) {
1570   const Rect r = enc_state->shared.DCGroupRect(group_index);
1571   size_t stream_id = ModularStreamId::ACMetadata(group_index).ID(frame_dim);
1572   stream_options[stream_id].max_chan_size = 0xFFFFFF;
1573   stream_options[stream_id].wp_tree_mode = ModularOptions::TreeMode::kNoWP;
1574   if (jpeg_transcode) {
1575     stream_options[stream_id].tree_kind =
1576         ModularOptions::TreeKind::kJpegTranscodeACMeta;
1577   } else if (cparams.speed_tier == SpeedTier::kFalcon) {
1578     stream_options[stream_id].tree_kind =
1579         ModularOptions::TreeKind::kFalconACMeta;
1580   } else if (cparams.speed_tier > SpeedTier::kKitten) {
1581     stream_options[stream_id].tree_kind = ModularOptions::TreeKind::kACMeta;
1582   }
1583   // If we are using a non-constant CfL field, and are in a slow enough mode,
1584   // re-enable tree computation for it.
1585   if (cparams.speed_tier < SpeedTier::kSquirrel &&
1586       cparams.force_cfl_jpeg_recompression) {
1587     stream_options[stream_id].tree_kind = ModularOptions::TreeKind::kLearn;
1588   }
1589   // YToX, YToB, ACS + QF, EPF
1590   Image& image = stream_images[stream_id];
1591   image = Image(r.xsize(), r.ysize(), 255, 4);
1592   static_assert(kColorTileDimInBlocks == 8, "Color tile size changed");
1593   Rect cr(r.x0() >> 3, r.y0() >> 3, (r.xsize() + 7) >> 3, (r.ysize() + 7) >> 3);
1594   image.channel[0] = Channel(cr.xsize(), cr.ysize(), 3, 3);
1595   image.channel[1] = Channel(cr.xsize(), cr.ysize(), 3, 3);
1596   image.channel[2] = Channel(r.xsize() * r.ysize(), 2, 0, 0);
1597   ConvertPlaneAndClamp(cr, enc_state->shared.cmap.ytox_map,
1598                        Rect(image.channel[0].plane), &image.channel[0].plane);
1599   ConvertPlaneAndClamp(cr, enc_state->shared.cmap.ytob_map,
1600                        Rect(image.channel[1].plane), &image.channel[1].plane);
1601   size_t num = 0;
1602   for (size_t y = 0; y < r.ysize(); y++) {
1603     AcStrategyRow row_acs = enc_state->shared.ac_strategy.ConstRow(r, y);
1604     const int* row_qf = r.ConstRow(enc_state->shared.raw_quant_field, y);
1605     const uint8_t* row_epf = r.ConstRow(enc_state->shared.epf_sharpness, y);
1606     int* out_acs = image.channel[2].plane.Row(0);
1607     int* out_qf = image.channel[2].plane.Row(1);
1608     int* row_out_epf = image.channel[3].plane.Row(y);
1609     for (size_t x = 0; x < r.xsize(); x++) {
1610       row_out_epf[x] = row_epf[x];
1611       if (!row_acs[x].IsFirstBlock()) continue;
1612       out_acs[num] = row_acs[x].RawStrategy();
1613       out_qf[num] = row_qf[x] - 1;
1614       num++;
1615     }
1616   }
1617   image.channel[2].w = num;
1618   image.channel[2].resize();
1619   ac_metadata_size[group_index] = num;
1620 }
1621 
EncodeQuantTable(size_t size_x,size_t size_y,BitWriter * writer,const QuantEncoding & encoding,size_t idx,ModularFrameEncoder * modular_frame_encoder)1622 void ModularFrameEncoder::EncodeQuantTable(
1623     size_t size_x, size_t size_y, BitWriter* writer,
1624     const QuantEncoding& encoding, size_t idx,
1625     ModularFrameEncoder* modular_frame_encoder) {
1626   JXL_ASSERT(encoding.qraw.qtable != nullptr);
1627   JXL_ASSERT(size_x * size_y * 3 == encoding.qraw.qtable->size());
1628   JXL_CHECK(F16Coder::Write(encoding.qraw.qtable_den, writer));
1629   if (modular_frame_encoder) {
1630     JXL_CHECK(modular_frame_encoder->EncodeStream(
1631         writer, nullptr, 0, ModularStreamId::QuantTable(idx)));
1632     return;
1633   }
1634   Image image(size_x, size_y, 255, 3);
1635   for (size_t c = 0; c < 3; c++) {
1636     for (size_t y = 0; y < size_y; y++) {
1637       int* JXL_RESTRICT row = image.channel[c].Row(y);
1638       for (size_t x = 0; x < size_x; x++) {
1639         row[x] = (*encoding.qraw.qtable)[c * size_x * size_y + y * size_x + x];
1640       }
1641     }
1642   }
1643   ModularOptions cfopts;
1644   JXL_CHECK(ModularGenericCompress(image, cfopts, writer));
1645 }
1646 
AddQuantTable(size_t size_x,size_t size_y,const QuantEncoding & encoding,size_t idx)1647 void ModularFrameEncoder::AddQuantTable(size_t size_x, size_t size_y,
1648                                         const QuantEncoding& encoding,
1649                                         size_t idx) {
1650   size_t stream_id = ModularStreamId::QuantTable(idx).ID(frame_dim);
1651   JXL_ASSERT(encoding.qraw.qtable != nullptr);
1652   JXL_ASSERT(size_x * size_y * 3 == encoding.qraw.qtable->size());
1653   Image& image = stream_images[stream_id];
1654   image = Image(size_x, size_y, 255, 3);
1655   for (size_t c = 0; c < 3; c++) {
1656     for (size_t y = 0; y < size_y; y++) {
1657       int* JXL_RESTRICT row = image.channel[c].Row(y);
1658       for (size_t x = 0; x < size_x; x++) {
1659         row[x] = (*encoding.qraw.qtable)[c * size_x * size_y + y * size_x + x];
1660       }
1661     }
1662   }
1663 }
1664 }  // namespace jxl
1665