1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/enc_heuristics.h"
7 
8 #include <stddef.h>
9 #include <stdint.h>
10 
11 #include <algorithm>
12 #include <numeric>
13 #include <string>
14 
15 #include "lib/jxl/enc_ac_strategy.h"
16 #include "lib/jxl/enc_adaptive_quantization.h"
17 #include "lib/jxl/enc_ar_control_field.h"
18 #include "lib/jxl/enc_cache.h"
19 #include "lib/jxl/enc_chroma_from_luma.h"
20 #include "lib/jxl/enc_modular.h"
21 #include "lib/jxl/enc_noise.h"
22 #include "lib/jxl/enc_patch_dictionary.h"
23 #include "lib/jxl/enc_quant_weights.h"
24 #include "lib/jxl/enc_splines.h"
25 #include "lib/jxl/enc_xyb.h"
26 #include "lib/jxl/gaborish.h"
27 
28 namespace jxl {
29 namespace {
FindBestBlockEntropyModel(PassesEncoderState & enc_state)30 void FindBestBlockEntropyModel(PassesEncoderState& enc_state) {
31   if (enc_state.cparams.decoding_speed_tier >= 1) {
32     static constexpr uint8_t kSimpleCtxMap[] = {
33         // Cluster all blocks together
34         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,  //
35         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  //
36         1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,  //
37     };
38     static_assert(
39         3 * kNumOrders == sizeof(kSimpleCtxMap) / sizeof *kSimpleCtxMap,
40         "Update simple context map");
41 
42     auto bcm = enc_state.shared.block_ctx_map;
43     bcm.ctx_map.assign(std::begin(kSimpleCtxMap), std::end(kSimpleCtxMap));
44     bcm.num_ctxs = 2;
45     bcm.num_dc_ctxs = 1;
46     return;
47   }
48   if (enc_state.cparams.speed_tier == SpeedTier::kFalcon) {
49     return;
50   }
51   const ImageI& rqf = enc_state.shared.raw_quant_field;
52   // No need to change context modeling for small images.
53   size_t tot = rqf.xsize() * rqf.ysize();
54   size_t size_for_ctx_model =
55       (1 << 10) * enc_state.cparams.butteraugli_distance;
56   if (tot < size_for_ctx_model) return;
57 
58   struct OccCounters {
59     // count the occurrences of each qf value and each strategy type.
60     OccCounters(const ImageI& rqf, const AcStrategyImage& ac_strategy) {
61       for (size_t y = 0; y < rqf.ysize(); y++) {
62         const int32_t* qf_row = rqf.Row(y);
63         AcStrategyRow acs_row = ac_strategy.ConstRow(y);
64         for (size_t x = 0; x < rqf.xsize(); x++) {
65           int ord = kStrategyOrder[acs_row[x].RawStrategy()];
66           int qf = qf_row[x] - 1;
67           qf_counts[qf]++;
68           qf_ord_counts[ord][qf]++;
69           ord_counts[ord]++;
70         }
71       }
72     }
73 
74     size_t qf_counts[256] = {};
75     size_t qf_ord_counts[kNumOrders][256] = {};
76     size_t ord_counts[kNumOrders] = {};
77   };
78   // The OccCounters struct is too big to allocate on the stack.
79   std::unique_ptr<OccCounters> counters(
80       new OccCounters(rqf, enc_state.shared.ac_strategy));
81 
82   // Splitting the context model according to the quantization field seems to
83   // mostly benefit only large images.
84   size_t size_for_qf_split = (1 << 13) * enc_state.cparams.butteraugli_distance;
85   size_t num_qf_segments = tot < size_for_qf_split ? 1 : 2;
86   std::vector<uint32_t>& qft = enc_state.shared.block_ctx_map.qf_thresholds;
87   qft.clear();
88   // Divide the quant field in up to num_qf_segments segments.
89   size_t cumsum = 0;
90   size_t next = 1;
91   size_t last_cut = 256;
92   size_t cut = tot * next / num_qf_segments;
93   for (uint32_t j = 0; j < 256; j++) {
94     cumsum += counters->qf_counts[j];
95     if (cumsum > cut) {
96       if (j != 0) {
97         qft.push_back(j);
98       }
99       last_cut = j;
100       while (cumsum > cut) {
101         next++;
102         cut = tot * next / num_qf_segments;
103       }
104     } else if (next > qft.size() + 1) {
105       if (j - 1 == last_cut && j != 0) {
106         qft.push_back(j);
107       }
108     }
109   }
110 
111   // Count the occurrences of each segment.
112   std::vector<size_t> counts(kNumOrders * (qft.size() + 1));
113   size_t qft_pos = 0;
114   for (size_t j = 0; j < 256; j++) {
115     if (qft_pos < qft.size() && j == qft[qft_pos]) {
116       qft_pos++;
117     }
118     for (size_t i = 0; i < kNumOrders; i++) {
119       counts[qft_pos + i * (qft.size() + 1)] += counters->qf_ord_counts[i][j];
120     }
121   }
122 
123   // Repeatedly merge the lowest-count pair.
124   std::vector<uint8_t> remap((qft.size() + 1) * kNumOrders);
125   std::iota(remap.begin(), remap.end(), 0);
126   std::vector<uint8_t> clusters(remap);
127   // This is O(n^2 log n), but n <= 14.
128   while (clusters.size() > 5) {
129     std::sort(clusters.begin(), clusters.end(),
130               [&](int a, int b) { return counts[a] > counts[b]; });
131     counts[clusters[clusters.size() - 2]] += counts[clusters.back()];
132     counts[clusters.back()] = 0;
133     remap[clusters.back()] = clusters[clusters.size() - 2];
134     clusters.pop_back();
135   }
136   for (size_t i = 0; i < remap.size(); i++) {
137     while (remap[remap[i]] != remap[i]) {
138       remap[i] = remap[remap[i]];
139     }
140   }
141   // Relabel starting from 0.
142   std::vector<uint8_t> remap_remap(remap.size(), remap.size());
143   size_t num = 0;
144   for (size_t i = 0; i < remap.size(); i++) {
145     if (remap_remap[remap[i]] == remap.size()) {
146       remap_remap[remap[i]] = num++;
147     }
148     remap[i] = remap_remap[remap[i]];
149   }
150   // Write the block context map.
151   auto& ctx_map = enc_state.shared.block_ctx_map.ctx_map;
152   ctx_map = remap;
153   ctx_map.resize(remap.size() * 3);
154   for (size_t i = remap.size(); i < remap.size() * 3; i++) {
155     ctx_map[i] = remap[i % remap.size()] + num;
156   }
157   enc_state.shared.block_ctx_map.num_ctxs =
158       *std::max_element(ctx_map.begin(), ctx_map.end()) + 1;
159 }
160 
161 // Returns the target size based on whether bitrate or direct targetsize is
162 // given.
TargetSize(const CompressParams & cparams,const FrameDimensions & frame_dim)163 size_t TargetSize(const CompressParams& cparams,
164                   const FrameDimensions& frame_dim) {
165   if (cparams.target_size > 0) {
166     return cparams.target_size;
167   }
168   if (cparams.target_bitrate > 0.0) {
169     return 0.5 + cparams.target_bitrate * frame_dim.xsize * frame_dim.ysize /
170                      kBitsPerByte;
171   }
172   return 0;
173 }
174 }  // namespace
175 
FindBestDequantMatrices(const CompressParams & cparams,const Image3F & opsin,ModularFrameEncoder * modular_frame_encoder,DequantMatrices * dequant_matrices)176 void FindBestDequantMatrices(const CompressParams& cparams,
177                              const Image3F& opsin,
178                              ModularFrameEncoder* modular_frame_encoder,
179                              DequantMatrices* dequant_matrices) {
180   // TODO(veluca): quant matrices for no-gaborish.
181   // TODO(veluca): heuristics for in-bitstream quant tables.
182   *dequant_matrices = DequantMatrices();
183   if (cparams.max_error_mode) {
184     // Set numerators of all quantization matrices to constant values.
185     float weights[3][1] = {{1.0f / cparams.max_error[0]},
186                            {1.0f / cparams.max_error[1]},
187                            {1.0f / cparams.max_error[2]}};
188     DctQuantWeightParams dct_params(weights);
189     std::vector<QuantEncoding> encodings(DequantMatrices::kNum,
190                                          QuantEncoding::DCT(dct_params));
191     DequantMatricesSetCustom(dequant_matrices, encodings,
192                              modular_frame_encoder);
193     float dc_weights[3] = {1.0f / cparams.max_error[0],
194                            1.0f / cparams.max_error[1],
195                            1.0f / cparams.max_error[2]};
196     DequantMatricesSetCustomDC(dequant_matrices, dc_weights);
197   }
198 }
199 
HandlesColorConversion(const CompressParams & cparams,const ImageBundle & ib)200 bool DefaultEncoderHeuristics::HandlesColorConversion(
201     const CompressParams& cparams, const ImageBundle& ib) {
202   return cparams.noise != Override::kOn && cparams.patches != Override::kOn &&
203          cparams.speed_tier >= SpeedTier::kWombat && cparams.resampling == 1 &&
204          cparams.color_transform == ColorTransform::kXYB &&
205          !cparams.modular_mode && !ib.HasAlpha();
206 }
207 
LossyFrameHeuristics(PassesEncoderState * enc_state,ModularFrameEncoder * modular_frame_encoder,const ImageBundle * original_pixels,Image3F * opsin,ThreadPool * pool,AuxOut * aux_out)208 Status DefaultEncoderHeuristics::LossyFrameHeuristics(
209     PassesEncoderState* enc_state, ModularFrameEncoder* modular_frame_encoder,
210     const ImageBundle* original_pixels, Image3F* opsin, ThreadPool* pool,
211     AuxOut* aux_out) {
212   PROFILER_ZONE("JxlLossyFrameHeuristics uninstrumented");
213 
214   CompressParams& cparams = enc_state->cparams;
215   PassesSharedState& shared = enc_state->shared;
216 
217   // Compute parameters for noise synthesis.
218   if (shared.frame_header.flags & FrameHeader::kNoise) {
219     PROFILER_ZONE("enc GetNoiseParam");
220     // Don't start at zero amplitude since adding noise is expensive -- it
221     // significantly slows down decoding, and this is unlikely to
222     // completely go away even with advanced optimizations. After the
223     // kNoiseModelingRampUpDistanceRange we have reached the full level,
224     // i.e. noise is no longer represented by the compressed image, so we
225     // can add full noise by the noise modeling itself.
226     static const float kNoiseModelingRampUpDistanceRange = 0.6;
227     static const float kNoiseLevelAtStartOfRampUp = 0.25;
228     static const float kNoiseRampupStart = 1.0;
229     // TODO(user) test and properly select quality_coef with smooth
230     // filter
231     float quality_coef = 1.0f;
232     const float rampup = (cparams.butteraugli_distance - kNoiseRampupStart) /
233                          kNoiseModelingRampUpDistanceRange;
234     if (rampup < 1.0f) {
235       quality_coef = kNoiseLevelAtStartOfRampUp +
236                      (1.0f - kNoiseLevelAtStartOfRampUp) * rampup;
237     }
238     if (rampup < 0.0f) {
239       quality_coef = kNoiseRampupStart;
240     }
241     if (!GetNoiseParameter(*opsin, &shared.image_features.noise_params,
242                            quality_coef)) {
243       shared.frame_header.flags &= ~FrameHeader::kNoise;
244     }
245   }
246   if (enc_state->shared.frame_header.upsampling != 1 && !cparams.already_downsampled) {
247     // In VarDCT mode, LossyFrameHeuristics takes care of running downsampling
248     // after noise, if necessary.
249     DownsampleImage(opsin, cparams.resampling);
250     PadImageToBlockMultipleInPlace(opsin);
251   }
252 
253   const FrameDimensions& frame_dim = enc_state->shared.frame_dim;
254   size_t target_size = TargetSize(cparams, frame_dim);
255   size_t opsin_target_size = target_size;
256   if (cparams.target_size > 0 || cparams.target_bitrate > 0.0) {
257     cparams.target_size = opsin_target_size;
258   } else if (cparams.butteraugli_distance < 0) {
259     return JXL_FAILURE("Expected non-negative distance");
260   }
261 
262   // Find and subtract splines.
263   if (cparams.speed_tier <= SpeedTier::kSquirrel) {
264     shared.image_features.splines = FindSplines(*opsin);
265     JXL_RETURN_IF_ERROR(
266         shared.image_features.splines.SubtractFrom(opsin, shared.cmap));
267   }
268 
269   // Find and subtract patches/dots.
270   if (ApplyOverride(cparams.patches,
271                     cparams.speed_tier <= SpeedTier::kSquirrel)) {
272     FindBestPatchDictionary(*opsin, enc_state, pool, aux_out);
273     PatchDictionaryEncoder::SubtractFrom(shared.image_features.patches, opsin);
274   }
275 
276   static const float kAcQuant = 0.79f;
277   const float quant_dc = InitialQuantDC(cparams.butteraugli_distance);
278   Quantizer& quantizer = enc_state->shared.quantizer;
279   // We don't know the quant field yet, but for computing the global scale
280   // assuming that it will be the same as for Falcon mode is good enough.
281   quantizer.ComputeGlobalScaleAndQuant(
282       quant_dc, kAcQuant / cparams.butteraugli_distance, 0);
283 
284   // TODO(veluca): we can now run all the code from here to FindBestQuantizer
285   // (excluded) one rect at a time. Do that.
286 
287   // Dependency graph:
288   //
289   // input: either XYB or input image
290   //
291   // input image -> XYB [optional]
292   // XYB -> initial quant field
293   // XYB -> Gaborished XYB
294   // Gaborished XYB -> CfL1
295   // initial quant field, Gaborished XYB, CfL1 -> ACS
296   // initial quant field, ACS, Gaborished XYB -> EPF control field
297   // initial quant field -> adjusted initial quant field
298   // adjusted initial quant field, ACS -> raw quant field
299   // raw quant field, ACS, Gaborished XYB -> CfL2
300   //
301   // output: Gaborished XYB, CfL, ACS, raw quant field, EPF control field.
302 
303   ArControlFieldHeuristics ar_heuristics;
304   AcStrategyHeuristics acs_heuristics;
305   CfLHeuristics cfl_heuristics;
306 
307   if (!opsin->xsize()) {
308     JXL_ASSERT(HandlesColorConversion(cparams, *original_pixels));
309     *opsin = Image3F(RoundUpToBlockDim(original_pixels->xsize()),
310                      RoundUpToBlockDim(original_pixels->ysize()));
311     opsin->ShrinkTo(original_pixels->xsize(), original_pixels->ysize());
312     ToXYB(*original_pixels, pool, opsin, /*linear=*/nullptr);
313     PadImageToBlockMultipleInPlace(opsin);
314   }
315 
316   // Compute an initial estimate of the quantization field.
317   // Call InitialQuantField only in Hare mode or slower. Otherwise, rely
318   // on simple heuristics in FindBestAcStrategy, or set a constant for Falcon
319   // mode.
320   if (cparams.speed_tier > SpeedTier::kHare || cparams.uniform_quant > 0) {
321     enc_state->initial_quant_field =
322         ImageF(shared.frame_dim.xsize_blocks, shared.frame_dim.ysize_blocks);
323     if (cparams.speed_tier == SpeedTier::kFalcon || cparams.uniform_quant > 0) {
324       float q = cparams.uniform_quant > 0
325                     ? cparams.uniform_quant
326                     : kAcQuant / cparams.butteraugli_distance;
327       FillImage(q, &enc_state->initial_quant_field);
328     }
329   } else {
330     // Call this here, as it relies on pre-gaborish values.
331     float butteraugli_distance_for_iqf = cparams.butteraugli_distance;
332     if (!shared.frame_header.loop_filter.gab) {
333       butteraugli_distance_for_iqf *= 0.73f;
334     }
335     enc_state->initial_quant_field = InitialQuantField(
336         butteraugli_distance_for_iqf, *opsin, shared.frame_dim, pool, 1.0f,
337         &enc_state->initial_quant_masking);
338   }
339 
340   // TODO(veluca): do something about animations.
341 
342   // Apply inverse-gaborish.
343   if (shared.frame_header.loop_filter.gab) {
344     GaborishInverse(opsin, 0.9908511000000001f, pool);
345   }
346 
347   cfl_heuristics.Init(*opsin);
348   acs_heuristics.Init(*opsin, enc_state);
349 
350   auto process_tile = [&](size_t tid, size_t thread) {
351     size_t n_enc_tiles =
352         DivCeil(enc_state->shared.frame_dim.xsize_blocks, kEncTileDimInBlocks);
353     size_t tx = tid % n_enc_tiles;
354     size_t ty = tid / n_enc_tiles;
355     size_t by0 = ty * kEncTileDimInBlocks;
356     size_t by1 = std::min((ty + 1) * kEncTileDimInBlocks,
357                           enc_state->shared.frame_dim.ysize_blocks);
358     size_t bx0 = tx * kEncTileDimInBlocks;
359     size_t bx1 = std::min((tx + 1) * kEncTileDimInBlocks,
360                           enc_state->shared.frame_dim.xsize_blocks);
361     Rect r(bx0, by0, bx1 - bx0, by1 - by0);
362 
363     // For speeds up to Wombat, we only compute the color correlation map
364     // once we know the transform type and the quantization map.
365     if (cparams.speed_tier <= SpeedTier::kSquirrel) {
366       cfl_heuristics.ComputeTile(r, *opsin, enc_state->shared.matrices,
367                                  /*ac_strategy=*/nullptr,
368                                  /*quantizer=*/nullptr, /*fast=*/false, thread,
369                                  &enc_state->shared.cmap);
370     }
371 
372     // Choose block sizes.
373     acs_heuristics.ProcessRect(r);
374 
375     // Choose amount of post-processing smoothing.
376     // TODO(veluca): should this go *after* AdjustQuantField?
377     ar_heuristics.RunRect(r, *opsin, enc_state, thread);
378 
379     // Always set the initial quant field, so we can compute the CfL map with
380     // more accuracy. The initial quant field might change in slower modes, but
381     // adjusting the quant field with butteraugli when all the other encoding
382     // parameters are fixed is likely a more reliable choice anyway.
383     AdjustQuantField(enc_state->shared.ac_strategy, r,
384                      &enc_state->initial_quant_field);
385     quantizer.SetQuantFieldRect(enc_state->initial_quant_field, r,
386                                 &enc_state->shared.raw_quant_field);
387 
388     // Compute a non-default CfL map if we are at Hare speed, or slower.
389     if (cparams.speed_tier <= SpeedTier::kHare) {
390       cfl_heuristics.ComputeTile(
391           r, *opsin, enc_state->shared.matrices, &enc_state->shared.ac_strategy,
392           &enc_state->shared.quantizer,
393           /*fast=*/cparams.speed_tier >= SpeedTier::kWombat, thread,
394           &enc_state->shared.cmap);
395     }
396   };
397   RunOnPool(
398       pool, 0,
399       DivCeil(enc_state->shared.frame_dim.xsize_blocks, kEncTileDimInBlocks) *
400           DivCeil(enc_state->shared.frame_dim.ysize_blocks,
401                   kEncTileDimInBlocks),
402       [&](const size_t num_threads) {
403         ar_heuristics.PrepareForThreads(num_threads);
404         cfl_heuristics.PrepareForThreads(num_threads);
405         return true;
406       },
407       process_tile, "Enc Heuristics");
408 
409   acs_heuristics.Finalize(aux_out);
410   if (cparams.speed_tier <= SpeedTier::kHare) {
411     cfl_heuristics.ComputeDC(/*fast=*/cparams.speed_tier >= SpeedTier::kWombat,
412                              &enc_state->shared.cmap);
413   }
414 
415   FindBestDequantMatrices(cparams, *opsin, modular_frame_encoder,
416                           &enc_state->shared.matrices);
417 
418   // Refine quantization levels.
419   FindBestQuantizer(original_pixels, *opsin, enc_state, pool, aux_out);
420 
421   // Choose a context model that depends on the amount of quantization for AC.
422   if (cparams.speed_tier != SpeedTier::kFalcon) {
423     FindBestBlockEntropyModel(*enc_state);
424   }
425   return true;
426 }
427 
428 }  // namespace jxl
429