1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include <stddef.h>
7 #include <stdint.h>
8 
9 #include <algorithm>
10 #include <numeric>
11 #include <string>
12 
13 #include "lib/jxl/convolve.h"
14 #include "lib/jxl/enc_ac_strategy.h"
15 #include "lib/jxl/enc_adaptive_quantization.h"
16 #include "lib/jxl/enc_ar_control_field.h"
17 #include "lib/jxl/enc_cache.h"
18 #include "lib/jxl/enc_heuristics.h"
19 #include "lib/jxl/enc_noise.h"
20 #include "lib/jxl/gaborish.h"
21 #include "lib/jxl/gauss_blur.h"
22 
23 #undef HWY_TARGET_INCLUDE
24 #define HWY_TARGET_INCLUDE "lib/jxl/enc_fast_heuristics.cc"
25 #include <hwy/foreach_target.h>
26 #include <hwy/highway.h>
27 
28 HWY_BEFORE_NAMESPACE();
29 namespace jxl {
30 namespace HWY_NAMESPACE {
31 namespace {
32 using DF4 = HWY_CAPPED(float, 4);
33 DF4 df4;
34 HWY_FULL(float) df;
35 
Heuristics(PassesEncoderState * enc_state,ModularFrameEncoder * modular_frame_encoder,const ImageBundle * linear,Image3F * opsin,ThreadPool * pool,AuxOut * aux_out)36 Status Heuristics(PassesEncoderState* enc_state,
37                   ModularFrameEncoder* modular_frame_encoder,
38                   const ImageBundle* linear, Image3F* opsin, ThreadPool* pool,
39                   AuxOut* aux_out) {
40   PROFILER_ZONE("JxlLossyFrameHeuristics uninstrumented");
41   CompressParams& cparams = enc_state->cparams;
42   PassesSharedState& shared = enc_state->shared;
43   const FrameDimensions& frame_dim = enc_state->shared.frame_dim;
44   JXL_CHECK(cparams.butteraugli_distance > 0);
45 
46   // TODO(veluca): make this tiled.
47   if (shared.frame_header.loop_filter.gab) {
48     GaborishInverse(opsin, 0.9908511000000001f, pool);
49   }
50   // Compute image of high frequencies by removing a blurred version.
51   // TODO(veluca): certainly can be made faster, and use less memory...
52   constexpr size_t pad = 16;
53   Image3F padded = PadImageMirror(*opsin, pad, pad);
54   // Make the image (X, Y, B-Y)
55   // TODO(veluca): SubtractFrom is not parallel *and* not SIMD-fied.
56   SubtractFrom(padded.Plane(1), &padded.Plane(2));
57   // Ensure that OOB access for CfL does nothing. Not necessary if doing things
58   // properly...
59   Image3F hf(padded.xsize() + 64, padded.ysize());
60   ZeroFillImage(&hf);
61   hf.ShrinkTo(padded.xsize(), padded.ysize());
62   ImageF temp(padded.xsize(), padded.ysize());
63   // TODO(veluca): consider some faster blurring method.
64   auto g = CreateRecursiveGaussian(11.415258091746161);
65   for (size_t c = 0; c < 3; c++) {
66     FastGaussian(g, padded.Plane(c), pool, &temp, &hf.Plane(c));
67     SubtractFrom(padded.Plane(c), &hf.Plane(c));
68   }
69   // TODO(veluca): DC CfL?
70   size_t xcolortiles = DivCeil(frame_dim.xsize_blocks, kColorTileDimInBlocks);
71   size_t ycolortiles = DivCeil(frame_dim.ysize_blocks, kColorTileDimInBlocks);
72   JXL_RETURN_IF_ERROR(RunOnPool(
73       pool, 0, xcolortiles * ycolortiles, ThreadPool::NoInit,
74       [&](size_t tile_id, size_t _) {
75         size_t tx = tile_id % xcolortiles;
76         size_t ty = tile_id / xcolortiles;
77         size_t x0 = tx * kColorTileDim;
78         size_t x1 = std::min(x0 + kColorTileDim, hf.xsize());
79         size_t y0 = ty * kColorTileDim;
80         size_t y1 = std::min(y0 + kColorTileDim, hf.ysize());
81         for (size_t c : {0, 2}) {
82           static constexpr float kInvColorFactor = 1.0f / kDefaultColorFactor;
83           auto ca = Zero(df);
84           auto cb = Zero(df);
85           const auto inv_color_factor = Set(df, kInvColorFactor);
86           for (size_t y = y0; y < y1; y++) {
87             const float* row_m = hf.PlaneRow(1, y);
88             const float* row_s = hf.PlaneRow(c, y);
89             for (size_t x = x0; x < x1; x += Lanes(df)) {
90               // color residual = ax + b
91               const auto a = inv_color_factor * Load(df, row_m + x);
92               const auto b = Zero(df) - Load(df, row_s + x);
93               ca = MulAdd(a, a, ca);
94               cb = MulAdd(a, b, cb);
95             }
96           }
97           float best = -GetLane(SumOfLanes(df, cb)) /
98                        (GetLane(SumOfLanes(df, ca)) + 1e-9f);
99           int8_t& res = (c == 0 ? shared.cmap.ytox_map : shared.cmap.ytob_map)
100                             .Row(ty)[tx];
101           res = std::max(-128.0f, std::min(127.0f, roundf(best)));
102         }
103       },
104       "CfL"));
105   Image3F pooled(frame_dim.xsize_padded / 4, frame_dim.ysize_padded / 4);
106   Image3F summed(frame_dim.xsize_padded / 4, frame_dim.ysize_padded / 4);
107   JXL_RETURN_IF_ERROR(RunOnPool(
108       pool, 0, frame_dim.ysize_padded / 4, ThreadPool::NoInit,
109       [&](size_t y, size_t _) {
110         for (size_t c = 0; c < 3; c++) {
111           float* JXL_RESTRICT row_out = pooled.PlaneRow(c, y);
112           float* JXL_RESTRICT row_out_avg = summed.PlaneRow(c, y);
113           const float* JXL_RESTRICT row_in[4];
114           for (size_t iy = 0; iy < 4; iy++) {
115             row_in[iy] = hf.PlaneRow(c, 4 * y + pad + iy);
116           }
117           for (size_t x = 0; x < frame_dim.xsize_padded / 4; x++) {
118             auto max = Zero(df4);
119             auto sum = Zero(df4);
120             for (size_t iy = 0; iy < 4; iy++) {
121               for (size_t ix = 0; ix < 4; ix += Lanes(df4)) {
122                 const auto nn = Abs(Load(df4, row_in[iy] + x * 4 + ix + pad));
123                 sum += nn;
124                 max = IfThenElse(max > nn, max, nn);
125               }
126             }
127             row_out_avg[x] = GetLane(SumOfLanes(df4, sum));
128             row_out[x] = GetLane(MaxOfLanes(df4, max));
129           }
130         }
131       },
132       "MaxPool"));
133   // TODO(veluca): better handling of the border
134   // TODO(veluca): consider some faster blurring method.
135   // TODO(veluca): parallelize.
136   // Remove noise from the resulting image.
137   auto g2 = CreateRecursiveGaussian(2.0849544429861884);
138   constexpr size_t pad2 = 16;
139   Image3F summed_pad = PadImageMirror(summed, pad2, pad2);
140   ImageF tmp_out(summed_pad.xsize(), summed_pad.ysize());
141   ImageF tmp2(summed_pad.xsize(), summed_pad.ysize());
142   Image3F pooled_pad = PadImageMirror(pooled, pad2, pad2);
143   for (size_t c = 0; c < 3; c++) {
144     FastGaussian(g2, summed_pad.Plane(c), pool, &tmp2, &tmp_out);
145     const auto unblurred_multiplier = Set(df, 0.5f);
146     for (size_t y = 0; y < summed.ysize(); y++) {
147       float* row = summed.PlaneRow(c, y);
148       const float* row_blur = tmp_out.Row(y + pad2);
149       for (size_t x = 0; x < summed.xsize(); x += Lanes(df)) {
150         const auto b = Load(df, row_blur + x + pad2);
151         const auto o = Load(df, row + x) * unblurred_multiplier;
152         const auto m = IfThenElse(b > o, b, o);
153         Store(m, df, row + x);
154       }
155     }
156   }
157   for (size_t c = 0; c < 3; c++) {
158     FastGaussian(g2, pooled_pad.Plane(c), pool, &tmp2, &tmp_out);
159     const auto unblurred_multiplier = Set(df, 0.5f);
160     for (size_t y = 0; y < pooled.ysize(); y++) {
161       float* row = pooled.PlaneRow(c, y);
162       const float* row_blur = tmp_out.Row(y + pad2);
163       for (size_t x = 0; x < pooled.xsize(); x += Lanes(df)) {
164         const auto b = Load(df, row_blur + x + pad2);
165         const auto o = Load(df, row + x) * unblurred_multiplier;
166         const auto m = IfThenElse(b > o, b, o);
167         Store(m, df, row + x);
168       }
169     }
170   }
171   const static float kChannelMul[3] = {
172       7.9644294909680253f,
173       0.5700000183257159f,
174       0.20267448837597055f,
175   };
176   ImageF pooledhf44(pooled.xsize(), pooled.ysize());
177   for (size_t y = 0; y < pooled.ysize(); y++) {
178     const float* row_in_x = pooled.ConstPlaneRow(0, y);
179     const float* row_in_y = pooled.ConstPlaneRow(1, y);
180     const float* row_in_b = pooled.ConstPlaneRow(2, y);
181     float* row_out = pooledhf44.Row(y);
182     for (size_t x = 0; x < pooled.xsize(); x += Lanes(df)) {
183       auto v = Set(df, kChannelMul[0]) * Load(df, row_in_x + x);
184       v = MulAdd(Set(df, kChannelMul[1]), Load(df, row_in_y + x), v);
185       v = MulAdd(Set(df, kChannelMul[2]), Load(df, row_in_b + x), v);
186       Store(v, df, row_out + x);
187     }
188   }
189   ImageF summedhf44(summed.xsize(), summed.ysize());
190   for (size_t y = 0; y < summed.ysize(); y++) {
191     const float* row_in_x = summed.ConstPlaneRow(0, y);
192     const float* row_in_y = summed.ConstPlaneRow(1, y);
193     const float* row_in_b = summed.ConstPlaneRow(2, y);
194     float* row_out = summedhf44.Row(y);
195     for (size_t x = 0; x < summed.xsize(); x += Lanes(df)) {
196       auto v = Set(df, kChannelMul[0]) * Load(df, row_in_x + x);
197       v = MulAdd(Set(df, kChannelMul[1]), Load(df, row_in_y + x), v);
198       v = MulAdd(Set(df, kChannelMul[2]), Load(df, row_in_b + x), v);
199       Store(v, df, row_out + x);
200     }
201   }
202   aux_out->DumpPlaneNormalized("pooledhf44", pooledhf44);
203   aux_out->DumpPlaneNormalized("summedhf44", summedhf44);
204 
205   static const float kDcQuantMul = 0.88170190420916206;
206   static const float kAcQuantMul = 2.5165738934721524;
207 
208   float dc_quant = kDcQuantMul * InitialQuantDC(cparams.butteraugli_distance);
209   float ac_quant_base = kAcQuantMul / cparams.butteraugli_distance;
210   ImageF quant_field(frame_dim.xsize_blocks, frame_dim.ysize_blocks);
211 
212   static_assert(kColorTileDim == 64, "Fix the code below");
213   auto mmacs = [&](size_t bx, size_t by, AcStrategy acs, float& min,
214                    float& max) {
215     min = 1e10;
216     max = 0;
217     for (size_t y = 2 * by; y < 2 * (by + acs.covered_blocks_y()); y++) {
218       const float* row = summedhf44.Row(y);
219       for (size_t x = 2 * bx; x < 2 * (bx + acs.covered_blocks_x()); x++) {
220         min = std::min(min, row[x]);
221         max = std::max(max, row[x]);
222       }
223     }
224   };
225   // Multipliers for allowed range of summedhf44.
226   std::pair<AcStrategy::Type, float> candidates[] = {
227     // The order is such that, in case of ties, 8x8 is favoured over 4x4 which
228     // is favoured over 2x2. Similarly, we prefer square transforms over
229     // same-area rectangular ones.
230     {AcStrategy::Type::DCT2X2, 1.5f},
231     {AcStrategy::Type::DCT4X4, 1.4f},
232     {AcStrategy::Type::DCT4X8, 1.2f},
233     {AcStrategy::Type::DCT8X4, 1.2f},
234     {AcStrategy::Type::AFV0,
235      1.15f},  // doesn't really work with these heuristics
236     {AcStrategy::Type::AFV1, 1.15f},
237     {AcStrategy::Type::AFV2, 1.15f},
238     {AcStrategy::Type::AFV3, 1.15f},
239     {AcStrategy::Type::DCT, 1.0f},
240     {AcStrategy::Type::DCT16X8, 0.8f},
241     {AcStrategy::Type::DCT8X16, 0.8f},
242     {AcStrategy::Type::DCT16X16, 0.2f},
243     {AcStrategy::Type::DCT16X32, 0.2f},
244     {AcStrategy::Type::DCT32X16, 0.2f},
245     {AcStrategy::Type::DCT32X32, 0.2f},
246     {AcStrategy::Type::DCT32X64, 0.1f},
247     {AcStrategy::Type::DCT64X32, 0.1f},
248     {AcStrategy::Type::DCT64X64, 0.04f},
249 
250 #if 0
251       {AcStrategy::Type::DCT2X2, 1e+10},  {AcStrategy::Type::DCT4X4, 2.0f},
252       {AcStrategy::Type::DCT, 1.0f},      {AcStrategy::Type::DCT16X8, 1.0f},
253       {AcStrategy::Type::DCT8X16, 1.0f},  {AcStrategy::Type::DCT32X8, 1.0f},
254       {AcStrategy::Type::DCT8X32, 1.0f},  {AcStrategy::Type::DCT32X16, 1.0f},
255       {AcStrategy::Type::DCT16X32, 1.0f}, {AcStrategy::Type::DCT64X32, 1.0f},
256       {AcStrategy::Type::DCT32X64, 1.0f}, {AcStrategy::Type::DCT16X16, 1.0f},
257       {AcStrategy::Type::DCT32X32, 1.0f}, {AcStrategy::Type::DCT64X64, 1.0f},
258 #endif
259     // TODO(veluca): figure out if we want 4x8 and/or AVF.
260   };
261   float max_range = 1e-8f + 0.5f * std::pow(cparams.butteraugli_distance, 0.5f);
262   // Change quant field and sharpness amounts based on (pooled|summed)hf44, and
263   // compute block sizes.
264   // TODO(veluca): maybe this could be done per group: it would allow choosing
265   // floating blocks better.
266   JXL_RETURN_IF_ERROR(RunOnPool(
267       pool, 0, xcolortiles * ycolortiles, ThreadPool::NoInit,
268       [&](size_t tile_id, size_t _) {
269         size_t tx = tile_id % xcolortiles;
270         size_t ty = tile_id / xcolortiles;
271         size_t x0 = tx * kColorTileDim / kBlockDim;
272         size_t x1 = std::min(x0 + kColorTileDimInBlocks, quant_field.xsize());
273         size_t y0 = ty * kColorTileDim / kBlockDim;
274         size_t y1 = std::min(y0 + kColorTileDimInBlocks, quant_field.ysize());
275         size_t qf_stride = quant_field.PixelsPerRow();
276         size_t epf_stride = shared.epf_sharpness.PixelsPerRow();
277         bool chosen_mask[64] = {};
278         for (size_t y = y0; y < y1; y++) {
279           uint8_t* epf_row = shared.epf_sharpness.Row(y);
280           float* qf_row = quant_field.Row(y);
281           for (size_t x = x0; x < x1; x++) {
282             if (chosen_mask[(y - y0) * 8 + (x - x0)]) continue;
283             // Default to DCT8 just in case something funny happens in the loop
284             // below.
285             AcStrategy::Type best = AcStrategy::DCT;
286             size_t best_covered = 1;
287             float qf = ac_quant_base;
288             for (size_t i = 0; i < sizeof(candidates) / sizeof(*candidates);
289                  i++) {
290               AcStrategy acs = AcStrategy::FromRawStrategy(candidates[i].first);
291               if (y + acs.covered_blocks_y() > y1) continue;
292               if (x + acs.covered_blocks_x() > x1) continue;
293               bool fits = true;
294               for (size_t iy = y; iy < y + acs.covered_blocks_y(); iy++) {
295                 for (size_t ix = x; ix < x + acs.covered_blocks_x(); ix++) {
296                   if (chosen_mask[(iy - y0) * 8 + (ix - x0)]) {
297                     fits = false;
298                     break;
299                   }
300                 }
301               }
302               if (!fits) continue;
303               float min, max;
304               mmacs(x, y, acs, min, max);
305               if (max - min > max_range * candidates[i].second) continue;
306               size_t cb = acs.covered_blocks_x() * acs.covered_blocks_y();
307               if (cb >= best_covered) {
308                 best_covered = cb;
309                 best = candidates[i].first;
310                 // TODO(veluca): make this better.
311                 qf = ac_quant_base /
312                      (3.9312946339134007f + 2.6011435675118082f * min);
313               }
314             }
315             shared.ac_strategy.Set(x, y, best);
316             AcStrategy acs = AcStrategy::FromRawStrategy(best);
317             for (size_t iy = y; iy < y + acs.covered_blocks_y(); iy++) {
318               for (size_t ix = x; ix < x + acs.covered_blocks_x(); ix++) {
319                 chosen_mask[(iy - y0) * 8 + (ix - x0)] = 1;
320                 qf_row[ix + (iy - y) * qf_stride] = qf;
321               }
322             }
323             // TODO
324             for (size_t iy = y; iy < y + acs.covered_blocks_y(); iy++) {
325               for (size_t ix = x; ix < x + acs.covered_blocks_x(); ix++) {
326                 epf_row[ix + (iy - y) * epf_stride] = 4;
327               }
328             }
329           }
330         }
331       },
332       "QF+ACS+EPF"));
333   aux_out->DumpPlaneNormalized("qf", quant_field);
334   aux_out->DumpPlaneNormalized("epf", shared.epf_sharpness);
335   DumpAcStrategy(shared.ac_strategy, frame_dim.xsize_padded,
336                  frame_dim.ysize_padded, "acs", aux_out);
337 
338   shared.quantizer.SetQuantField(dc_quant, quant_field,
339                                  &shared.raw_quant_field);
340 
341   return true;
342 }
343 }  // namespace
344 // NOLINTNEXTLINE(google-readability-namespace-comments)
345 }  // namespace HWY_NAMESPACE
346 }  // namespace jxl
347 HWY_AFTER_NAMESPACE();
348 
349 #if HWY_ONCE
350 namespace jxl {
351 HWY_EXPORT(Heuristics);
LossyFrameHeuristics(PassesEncoderState * enc_state,ModularFrameEncoder * modular_frame_encoder,const ImageBundle * linear,Image3F * opsin,const JxlCmsInterface & cms,ThreadPool * pool,AuxOut * aux_out)352 Status FastEncoderHeuristics::LossyFrameHeuristics(
353     PassesEncoderState* enc_state, ModularFrameEncoder* modular_frame_encoder,
354     const ImageBundle* linear, Image3F* opsin, const JxlCmsInterface& cms,
355     ThreadPool* pool, AuxOut* aux_out) {
356   return HWY_DYNAMIC_DISPATCH(Heuristics)(enc_state, modular_frame_encoder,
357                                           linear, opsin, pool, aux_out);
358 }
359 
360 }  // namespace jxl
361 #endif
362