1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 #include <stddef.h>
7 #include <stdint.h>
8
9 #include <algorithm>
10 #include <numeric>
11 #include <string>
12
13 #include "lib/jxl/convolve.h"
14 #include "lib/jxl/enc_ac_strategy.h"
15 #include "lib/jxl/enc_adaptive_quantization.h"
16 #include "lib/jxl/enc_ar_control_field.h"
17 #include "lib/jxl/enc_cache.h"
18 #include "lib/jxl/enc_heuristics.h"
19 #include "lib/jxl/enc_noise.h"
20 #include "lib/jxl/gaborish.h"
21 #include "lib/jxl/gauss_blur.h"
22
23 #undef HWY_TARGET_INCLUDE
24 #define HWY_TARGET_INCLUDE "lib/jxl/enc_fast_heuristics.cc"
25 #include <hwy/foreach_target.h>
26 #include <hwy/highway.h>
27
28 HWY_BEFORE_NAMESPACE();
29 namespace jxl {
30 namespace HWY_NAMESPACE {
31 namespace {
// Fixed-width (at most 4 lanes) descriptor: used where the data layout is
// inherently 4-wide, e.g. the 4x4 pooling loops below.
using DF4 = HWY_CAPPED(float, 4);
DF4 df4;
// Full native vector width descriptor for all other SIMD loops.
HWY_FULL(float) df;
35
// Fast replacement for the full lossy-frame heuristics. In a single cheap
// pass it:
//   1) optionally undoes the gaborish smoothing,
//   2) extracts a high-frequency (HF) residual image via a Gaussian blur,
//   3) fits per-tile chroma-from-luma (CfL) factors on that HF image,
//   4) builds 4x4-pooled max/sum HF activity maps and denoises them,
//   5) collapses the three channels into one weighted HF measure, and
//   6) greedily selects AC transform sizes, the quant field and EPF
//      sharpness per 64x64 color tile from that measure.
// `modular_frame_encoder` and `linear` are unused here; they are kept so the
// signature matches the regular heuristics entry point. Returns true on
// success; per-tile work runs on `pool`.
Status Heuristics(PassesEncoderState* enc_state,
                  ModularFrameEncoder* modular_frame_encoder,
                  const ImageBundle* linear, Image3F* opsin, ThreadPool* pool,
                  AuxOut* aux_out) {
  PROFILER_ZONE("JxlLossyFrameHeuristics uninstrumented");
  CompressParams& cparams = enc_state->cparams;
  PassesSharedState& shared = enc_state->shared;
  const FrameDimensions& frame_dim = enc_state->shared.frame_dim;
  // These heuristics only apply to lossy encoding.
  JXL_CHECK(cparams.butteraugli_distance > 0);

  // TODO(veluca): make this tiled.
  if (shared.frame_header.loop_filter.gab) {
    GaborishInverse(opsin, 0.9908511000000001f, pool);
  }
  // Compute image of high frequencies by removing a blurred version.
  // TODO(veluca): certainly can be made faster, and use less memory...
  constexpr size_t pad = 16;
  Image3F padded = PadImageMirror(*opsin, pad, pad);
  // Make the image (X, Y, B-Y)
  // TODO(veluca): SubtractFrom is not parallel *and* not SIMD-fied.
  SubtractFrom(padded.Plane(1), &padded.Plane(2));
  // Ensure that OOB access for CfL does nothing. Not necessary if doing things
  // properly...
  // 64 extra zeroed columns so the vectorized CfL loop below may read a few
  // lanes past x1 without touching uninitialized memory.
  Image3F hf(padded.xsize() + 64, padded.ysize());
  ZeroFillImage(&hf);
  hf.ShrinkTo(padded.xsize(), padded.ysize());
  ImageF temp(padded.xsize(), padded.ysize());
  // TODO(veluca): consider some faster blurring method.
  // Blur sigma defining the low-frequency band; presumably tuned together
  // with the multipliers below - do not change in isolation.
  auto g = CreateRecursiveGaussian(11.415258091746161);
  for (size_t c = 0; c < 3; c++) {
    // hf.Plane(c) first receives blur(padded); the SubtractFrom then leaves
    // blur - original, i.e. the (negated) HF residual. The sign does not
    // matter for the uses below (CfL is sign-consistent, pooling uses Abs).
    FastGaussian(g, padded.Plane(c), pool, &temp, &hf.Plane(c));
    SubtractFrom(padded.Plane(c), &hf.Plane(c));
  }
  // TODO(veluca): DC CfL?
  // Per-tile chroma-from-luma: for each 64x64 tile, least-squares fit of the
  // X (c==0) and B-Y (c==2) HF residuals against the Y HF residual.
  size_t xcolortiles = DivCeil(frame_dim.xsize_blocks, kColorTileDimInBlocks);
  size_t ycolortiles = DivCeil(frame_dim.ysize_blocks, kColorTileDimInBlocks);
  JXL_RETURN_IF_ERROR(RunOnPool(
      pool, 0, xcolortiles * ycolortiles, ThreadPool::NoInit,
      [&](size_t tile_id, size_t _) {
        size_t tx = tile_id % xcolortiles;
        size_t ty = tile_id / xcolortiles;
        size_t x0 = tx * kColorTileDim;
        size_t x1 = std::min(x0 + kColorTileDim, hf.xsize());
        size_t y0 = ty * kColorTileDim;
        size_t y1 = std::min(y0 + kColorTileDim, hf.ysize());
        for (size_t c : {0, 2}) {
          static constexpr float kInvColorFactor = 1.0f / kDefaultColorFactor;
          auto ca = Zero(df);
          auto cb = Zero(df);
          const auto inv_color_factor = Set(df, kInvColorFactor);
          for (size_t y = y0; y < y1; y++) {
            const float* row_m = hf.PlaneRow(1, y);
            const float* row_s = hf.PlaneRow(c, y);
            for (size_t x = x0; x < x1; x += Lanes(df)) {
              // color residual = ax + b
              const auto a = inv_color_factor * Load(df, row_m + x);
              const auto b = Zero(df) - Load(df, row_s + x);
              // Accumulate sum(a*a) and sum(a*b) for the closed-form
              // least-squares solution used below.
              ca = MulAdd(a, a, ca);
              cb = MulAdd(a, b, cb);
            }
          }
          // Best factor minimizing sum((a*f + b)^2): f = -sum(ab)/sum(a^2).
          // The 1e-9 guards against division by zero on flat tiles.
          float best = -GetLane(SumOfLanes(df, cb)) /
                       (GetLane(SumOfLanes(df, ca)) + 1e-9f);
          int8_t& res = (c == 0 ? shared.cmap.ytox_map : shared.cmap.ytob_map)
                            .Row(ty)[tx];
          // Clamp to the int8 range of the correlation map.
          res = std::max(-128.0f, std::min(127.0f, roundf(best)));
        }
      },
      "CfL"));
  // 4x downsampled HF activity maps: per 4x4 block, the max (pooled) and the
  // sum (summed) of |hf|.
  Image3F pooled(frame_dim.xsize_padded / 4, frame_dim.ysize_padded / 4);
  Image3F summed(frame_dim.xsize_padded / 4, frame_dim.ysize_padded / 4);
  JXL_RETURN_IF_ERROR(RunOnPool(
      pool, 0, frame_dim.ysize_padded / 4, ThreadPool::NoInit,
      [&](size_t y, size_t _) {
        for (size_t c = 0; c < 3; c++) {
          float* JXL_RESTRICT row_out = pooled.PlaneRow(c, y);
          float* JXL_RESTRICT row_out_avg = summed.PlaneRow(c, y);
          const float* JXL_RESTRICT row_in[4];
          // `pad` offsets skip the mirrored border added above.
          for (size_t iy = 0; iy < 4; iy++) {
            row_in[iy] = hf.PlaneRow(c, 4 * y + pad + iy);
          }
          for (size_t x = 0; x < frame_dim.xsize_padded / 4; x++) {
            auto max = Zero(df4);
            auto sum = Zero(df4);
            for (size_t iy = 0; iy < 4; iy++) {
              for (size_t ix = 0; ix < 4; ix += Lanes(df4)) {
                const auto nn = Abs(Load(df4, row_in[iy] + x * 4 + ix + pad));
                sum += nn;
                max = IfThenElse(max > nn, max, nn);
              }
            }
            row_out_avg[x] = GetLane(SumOfLanes(df4, sum));
            row_out[x] = GetLane(MaxOfLanes(df4, max));
          }
        }
      },
      "MaxPool"));
  // TODO(veluca): better handling of the border
  // TODO(veluca): consider some faster blurring method.
  // TODO(veluca): parallelize.
  // Remove noise from the resulting image.
  // Denoise each map by replacing every pixel with max(blurred value,
  // 0.5 * original value): isolated spikes get attenuated, smooth areas keep
  // their blurred (stable) value.
  auto g2 = CreateRecursiveGaussian(2.0849544429861884);
  constexpr size_t pad2 = 16;
  Image3F summed_pad = PadImageMirror(summed, pad2, pad2);
  ImageF tmp_out(summed_pad.xsize(), summed_pad.ysize());
  ImageF tmp2(summed_pad.xsize(), summed_pad.ysize());
  Image3F pooled_pad = PadImageMirror(pooled, pad2, pad2);
  for (size_t c = 0; c < 3; c++) {
    FastGaussian(g2, summed_pad.Plane(c), pool, &tmp2, &tmp_out);
    const auto unblurred_multiplier = Set(df, 0.5f);
    for (size_t y = 0; y < summed.ysize(); y++) {
      float* row = summed.PlaneRow(c, y);
      const float* row_blur = tmp_out.Row(y + pad2);
      for (size_t x = 0; x < summed.xsize(); x += Lanes(df)) {
        const auto b = Load(df, row_blur + x + pad2);
        const auto o = Load(df, row + x) * unblurred_multiplier;
        const auto m = IfThenElse(b > o, b, o);
        Store(m, df, row + x);
      }
    }
  }
  // Same denoising for the max-pooled map.
  for (size_t c = 0; c < 3; c++) {
    FastGaussian(g2, pooled_pad.Plane(c), pool, &tmp2, &tmp_out);
    const auto unblurred_multiplier = Set(df, 0.5f);
    for (size_t y = 0; y < pooled.ysize(); y++) {
      float* row = pooled.PlaneRow(c, y);
      const float* row_blur = tmp_out.Row(y + pad2);
      for (size_t x = 0; x < pooled.xsize(); x += Lanes(df)) {
        const auto b = Load(df, row_blur + x + pad2);
        const auto o = Load(df, row + x) * unblurred_multiplier;
        const auto m = IfThenElse(b > o, b, o);
        Store(m, df, row + x);
      }
    }
  }
  // Per-channel weights used to collapse (X, Y, B-Y) activity into a single
  // scalar per pixel; presumably tuned - TODO confirm provenance.
  const static float kChannelMul[3] = {
      7.9644294909680253f,
      0.5700000183257159f,
      0.20267448837597055f,
  };
  // pooledhf44 / summedhf44: single-channel weighted combinations of the
  // three planes of `pooled` / `summed`.
  ImageF pooledhf44(pooled.xsize(), pooled.ysize());
  for (size_t y = 0; y < pooled.ysize(); y++) {
    const float* row_in_x = pooled.ConstPlaneRow(0, y);
    const float* row_in_y = pooled.ConstPlaneRow(1, y);
    const float* row_in_b = pooled.ConstPlaneRow(2, y);
    float* row_out = pooledhf44.Row(y);
    for (size_t x = 0; x < pooled.xsize(); x += Lanes(df)) {
      auto v = Set(df, kChannelMul[0]) * Load(df, row_in_x + x);
      v = MulAdd(Set(df, kChannelMul[1]), Load(df, row_in_y + x), v);
      v = MulAdd(Set(df, kChannelMul[2]), Load(df, row_in_b + x), v);
      Store(v, df, row_out + x);
    }
  }
  ImageF summedhf44(summed.xsize(), summed.ysize());
  for (size_t y = 0; y < summed.ysize(); y++) {
    const float* row_in_x = summed.ConstPlaneRow(0, y);
    const float* row_in_y = summed.ConstPlaneRow(1, y);
    const float* row_in_b = summed.ConstPlaneRow(2, y);
    float* row_out = summedhf44.Row(y);
    for (size_t x = 0; x < summed.xsize(); x += Lanes(df)) {
      auto v = Set(df, kChannelMul[0]) * Load(df, row_in_x + x);
      v = MulAdd(Set(df, kChannelMul[1]), Load(df, row_in_y + x), v);
      v = MulAdd(Set(df, kChannelMul[2]), Load(df, row_in_b + x), v);
      Store(v, df, row_out + x);
    }
  }
  aux_out->DumpPlaneNormalized("pooledhf44", pooledhf44);
  aux_out->DumpPlaneNormalized("summedhf44", summedhf44);

  // Tuned multipliers mapping butteraugli distance to DC/AC quantization.
  static const float kDcQuantMul = 0.88170190420916206;
  static const float kAcQuantMul = 2.5165738934721524;

  float dc_quant = kDcQuantMul * InitialQuantDC(cparams.butteraugli_distance);
  float ac_quant_base = kAcQuantMul / cparams.butteraugli_distance;
  ImageF quant_field(frame_dim.xsize_blocks, frame_dim.ysize_blocks);

  // The chosen_mask / indexing logic below hardcodes 8x8 blocks per tile.
  static_assert(kColorTileDim == 64, "Fix the code below");
  // min/max of summedhf44 over the area covered by AC strategy `acs` at
  // block position (bx, by). summedhf44 has 2 samples per 8x8 block, hence
  // the factor-2 scaling of the block coordinates.
  auto mmacs = [&](size_t bx, size_t by, AcStrategy acs, float& min,
                   float& max) {
    min = 1e10;
    max = 0;
    for (size_t y = 2 * by; y < 2 * (by + acs.covered_blocks_y()); y++) {
      const float* row = summedhf44.Row(y);
      for (size_t x = 2 * bx; x < 2 * (bx + acs.covered_blocks_x()); x++) {
        min = std::min(min, row[x]);
        max = std::max(max, row[x]);
      }
    }
  };
  // Multipliers for allowed range of summedhf44.
  // Larger transforms are only allowed on flatter areas (smaller max-min of
  // the HF measure), hence the decreasing multipliers down the list.
  std::pair<AcStrategy::Type, float> candidates[] = {
      // The order is such that, in case of ties, 8x8 is favoured over 4x4 which
      // is favoured over 2x2. Similarly, we prefer square transforms over
      // same-area rectangular ones.
      {AcStrategy::Type::DCT2X2, 1.5f},
      {AcStrategy::Type::DCT4X4, 1.4f},
      {AcStrategy::Type::DCT4X8, 1.2f},
      {AcStrategy::Type::DCT8X4, 1.2f},
      {AcStrategy::Type::AFV0,
       1.15f},  // doesn't really work with these heuristics
      {AcStrategy::Type::AFV1, 1.15f},
      {AcStrategy::Type::AFV2, 1.15f},
      {AcStrategy::Type::AFV3, 1.15f},
      {AcStrategy::Type::DCT, 1.0f},
      {AcStrategy::Type::DCT16X8, 0.8f},
      {AcStrategy::Type::DCT8X16, 0.8f},
      {AcStrategy::Type::DCT16X16, 0.2f},
      {AcStrategy::Type::DCT16X32, 0.2f},
      {AcStrategy::Type::DCT32X16, 0.2f},
      {AcStrategy::Type::DCT32X32, 0.2f},
      {AcStrategy::Type::DCT32X64, 0.1f},
      {AcStrategy::Type::DCT64X32, 0.1f},
      {AcStrategy::Type::DCT64X64, 0.04f},

#if 0
      {AcStrategy::Type::DCT2X2, 1e+10}, {AcStrategy::Type::DCT4X4, 2.0f},
      {AcStrategy::Type::DCT, 1.0f}, {AcStrategy::Type::DCT16X8, 1.0f},
      {AcStrategy::Type::DCT8X16, 1.0f}, {AcStrategy::Type::DCT32X8, 1.0f},
      {AcStrategy::Type::DCT8X32, 1.0f}, {AcStrategy::Type::DCT32X16, 1.0f},
      {AcStrategy::Type::DCT16X32, 1.0f}, {AcStrategy::Type::DCT64X32, 1.0f},
      {AcStrategy::Type::DCT32X64, 1.0f}, {AcStrategy::Type::DCT16X16, 1.0f},
      {AcStrategy::Type::DCT32X32, 1.0f}, {AcStrategy::Type::DCT64X64, 1.0f},
#endif
      // TODO(veluca): figure out if we want 4x8 and/or AVF.
  };
  // Allowed HF range grows with target distance: higher distance tolerates
  // larger transforms on busier areas.
  float max_range = 1e-8f + 0.5f * std::pow(cparams.butteraugli_distance, 0.5f);
  // Change quant field and sharpness amounts based on (pooled|summed)hf44, and
  // compute block sizes.
  // TODO(veluca): maybe this could be done per group: it would allow choosing
  // floating blocks better.
  JXL_RETURN_IF_ERROR(RunOnPool(
      pool, 0, xcolortiles * ycolortiles, ThreadPool::NoInit,
      [&](size_t tile_id, size_t _) {
        size_t tx = tile_id % xcolortiles;
        size_t ty = tile_id / xcolortiles;
        size_t x0 = tx * kColorTileDim / kBlockDim;
        size_t x1 = std::min(x0 + kColorTileDimInBlocks, quant_field.xsize());
        size_t y0 = ty * kColorTileDim / kBlockDim;
        size_t y1 = std::min(y0 + kColorTileDimInBlocks, quant_field.ysize());
        size_t qf_stride = quant_field.PixelsPerRow();
        size_t epf_stride = shared.epf_sharpness.PixelsPerRow();
        // 8x8 bitmap of blocks in this tile already claimed by a transform.
        bool chosen_mask[64] = {};
        for (size_t y = y0; y < y1; y++) {
          uint8_t* epf_row = shared.epf_sharpness.Row(y);
          float* qf_row = quant_field.Row(y);
          for (size_t x = x0; x < x1; x++) {
            if (chosen_mask[(y - y0) * 8 + (x - x0)]) continue;
            // Default to DCT8 just in case something funny happens in the loop
            // below.
            AcStrategy::Type best = AcStrategy::DCT;
            size_t best_covered = 1;
            float qf = ac_quant_base;
            // Greedy choice: among candidates that fit inside the tile, do
            // not overlap already-chosen blocks, and pass the flatness test,
            // pick the one covering the most blocks (>= so that later
            // entries win ties, matching the preference order above).
            for (size_t i = 0; i < sizeof(candidates) / sizeof(*candidates);
                 i++) {
              AcStrategy acs = AcStrategy::FromRawStrategy(candidates[i].first);
              if (y + acs.covered_blocks_y() > y1) continue;
              if (x + acs.covered_blocks_x() > x1) continue;
              bool fits = true;
              for (size_t iy = y; iy < y + acs.covered_blocks_y(); iy++) {
                for (size_t ix = x; ix < x + acs.covered_blocks_x(); ix++) {
                  if (chosen_mask[(iy - y0) * 8 + (ix - x0)]) {
                    fits = false;
                    break;
                  }
                }
              }
              if (!fits) continue;
              float min, max;
              mmacs(x, y, acs, min, max);
              if (max - min > max_range * candidates[i].second) continue;
              size_t cb = acs.covered_blocks_x() * acs.covered_blocks_y();
              if (cb >= best_covered) {
                best_covered = cb;
                best = candidates[i].first;
                // TODO(veluca): make this better.
                // Quieter areas (small min activity) get a higher quant
                // factor, i.e. finer quantization.
                qf = ac_quant_base /
                     (3.9312946339134007f + 2.6011435675118082f * min);
              }
            }
            shared.ac_strategy.Set(x, y, best);
            AcStrategy acs = AcStrategy::FromRawStrategy(best);
            // Mark all covered blocks and propagate the quant factor.
            for (size_t iy = y; iy < y + acs.covered_blocks_y(); iy++) {
              for (size_t ix = x; ix < x + acs.covered_blocks_x(); ix++) {
                chosen_mask[(iy - y0) * 8 + (ix - x0)] = 1;
                qf_row[ix + (iy - y) * qf_stride] = qf;
              }
            }
            // TODO
            // Placeholder: maximum EPF sharpness everywhere.
            for (size_t iy = y; iy < y + acs.covered_blocks_y(); iy++) {
              for (size_t ix = x; ix < x + acs.covered_blocks_x(); ix++) {
                epf_row[ix + (iy - y) * epf_stride] = 4;
              }
            }
          }
        }
      },
      "QF+ACS+EPF"));
  aux_out->DumpPlaneNormalized("qf", quant_field);
  aux_out->DumpPlaneNormalized("epf", shared.epf_sharpness);
  DumpAcStrategy(shared.ac_strategy, frame_dim.xsize_padded,
                 frame_dim.ysize_padded, "acs", aux_out);

  shared.quantizer.SetQuantField(dc_quant, quant_field,
                                 &shared.raw_quant_field);

  return true;
}
343 } // namespace
344 // NOLINTNEXTLINE(google-readability-namespace-comments)
345 } // namespace HWY_NAMESPACE
346 } // namespace jxl
347 HWY_AFTER_NAMESPACE();
348
349 #if HWY_ONCE
350 namespace jxl {
351 HWY_EXPORT(Heuristics);
// Public entry point: forwards to the SIMD-specialized Heuristics()
// implementation picked at runtime by Highway's dynamic dispatch.
// `cms` is accepted for interface parity with the default heuristics but is
// not used by the fast path.
Status FastEncoderHeuristics::LossyFrameHeuristics(
    PassesEncoderState* enc_state, ModularFrameEncoder* modular_frame_encoder,
    const ImageBundle* linear, Image3F* opsin, const JxlCmsInterface& cms,
    ThreadPool* pool, AuxOut* aux_out) {
  return HWY_DYNAMIC_DISPATCH(Heuristics)(enc_state, modular_frame_encoder,
                                          linear, opsin, pool, aux_out);
}
359
360 } // namespace jxl
361 #endif
362