1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5
6 #include "lib/jxl/enc_ar_control_field.h"
7
#include <stdint.h>
#include <stdlib.h>

#include <algorithm>
#include <cmath>
12
13 #undef HWY_TARGET_INCLUDE
14 #define HWY_TARGET_INCLUDE "lib/jxl/enc_ar_control_field.cc"
15 #include <hwy/foreach_target.h>
16 #include <hwy/highway.h>
17
18 #include "lib/jxl/ac_strategy.h"
19 #include "lib/jxl/base/compiler_specific.h"
20 #include "lib/jxl/base/data_parallel.h"
21 #include "lib/jxl/base/status.h"
22 #include "lib/jxl/chroma_from_luma.h"
23 #include "lib/jxl/common.h"
24 #include "lib/jxl/enc_adaptive_quantization.h"
25 #include "lib/jxl/enc_params.h"
26 #include "lib/jxl/image.h"
27 #include "lib/jxl/image_bundle.h"
28 #include "lib/jxl/image_ops.h"
29 #include "lib/jxl/quant_weights.h"
30 #include "lib/jxl/quantizer.h"
31
32 HWY_BEFORE_NAMESPACE();
33 namespace jxl {
34 namespace HWY_NAMESPACE {
35 namespace {
36
ProcessTile(const Image3F & opsin,PassesEncoderState * enc_state,const Rect & rect,ArControlFieldHeuristics::TempImages * temp_image)37 void ProcessTile(const Image3F& opsin, PassesEncoderState* enc_state,
38 const Rect& rect,
39 ArControlFieldHeuristics::TempImages* temp_image) {
40 constexpr size_t N = kBlockDim;
41 ImageB* JXL_RESTRICT epf_sharpness = &enc_state->shared.epf_sharpness;
42 ImageF* JXL_RESTRICT quant = &enc_state->initial_quant_field;
43 JXL_ASSERT(
44 epf_sharpness->xsize() == enc_state->shared.frame_dim.xsize_blocks &&
45 epf_sharpness->ysize() == enc_state->shared.frame_dim.ysize_blocks);
46
47 if (enc_state->cparams.butteraugli_distance < kMinButteraugliForDynamicAR ||
48 enc_state->cparams.speed_tier > SpeedTier::kWombat ||
49 enc_state->shared.frame_header.loop_filter.epf_iters == 0) {
50 FillPlane(static_cast<uint8_t>(4), epf_sharpness, rect);
51 return;
52 }
53
54 // Likely better to have a higher X weight, like:
55 // const float kChannelWeights[3] = {47.0f, 4.35f, 0.287f};
56 const float kChannelWeights[3] = {4.35f, 4.35f, 0.287f};
57 const float kChannelWeightsLapNeg[3] = {-0.125f * kChannelWeights[0],
58 -0.125f * kChannelWeights[1],
59 -0.125f * kChannelWeights[2]};
60 const size_t sharpness_stride =
61 static_cast<size_t>(epf_sharpness->PixelsPerRow());
62
63 size_t by0 = rect.y0();
64 size_t by1 = rect.y0() + rect.ysize();
65 size_t bx0 = rect.x0();
66 size_t bx1 = rect.x0() + rect.xsize();
67 temp_image->InitOnce();
68 ImageF& laplacian_sqrsum = temp_image->laplacian_sqrsum;
69 // Calculate the L2 of the 3x3 Laplacian in an integral transform
70 // (for example 32x32 dct). This relates to transforms ability
71 // to propagate artefacts.
72 size_t y0 = by0 == 0 ? 2 : 0;
73 size_t y1 = by1 * N + 4 <= opsin.ysize() + 2 ? (by1 - by0) * N + 4
74 : opsin.ysize() + 2 - by0 * N;
75 size_t x0 = bx0 == 0 ? 2 : 0;
76 size_t x1 = bx1 * N + 4 <= opsin.xsize() + 2 ? (bx1 - bx0) * N + 4
77 : opsin.xsize() + 2 - bx0 * N;
78 HWY_FULL(float) df;
79 for (size_t y = y0; y < y1; y++) {
80 float* JXL_RESTRICT laplacian_sqrsum_row = laplacian_sqrsum.Row(y);
81 size_t cy = y + by0 * N - 2;
82 const float* JXL_RESTRICT in_row_t[3];
83 const float* JXL_RESTRICT in_row[3];
84 const float* JXL_RESTRICT in_row_b[3];
85 for (size_t c = 0; c < 3; c++) {
86 in_row_t[c] = opsin.PlaneRow(c, cy > 0 ? cy - 1 : cy);
87 in_row[c] = opsin.PlaneRow(c, cy);
88 in_row_b[c] = opsin.PlaneRow(c, cy + 1 < opsin.ysize() ? cy + 1 : cy);
89 }
90 auto compute_laplacian_scalar = [&](size_t x) {
91 size_t cx = x + bx0 * N - 2;
92 const size_t prevX = cx >= 1 ? cx - 1 : cx;
93 const size_t nextX = cx + 1 < opsin.xsize() ? cx + 1 : cx;
94 float sumsqr = 0;
95 for (size_t c = 0; c < 3; c++) {
96 float laplacian =
97 kChannelWeights[c] * in_row[c][cx] +
98 kChannelWeightsLapNeg[c] *
99 (in_row[c][prevX] + in_row[c][nextX] + in_row_b[c][prevX] +
100 in_row_b[c][cx] + in_row_b[c][nextX] + in_row_t[c][prevX] +
101 in_row_t[c][cx] + in_row_t[c][nextX]);
102 sumsqr += laplacian * laplacian;
103 }
104 laplacian_sqrsum_row[x] = sumsqr;
105 };
106 size_t x = x0;
107 for (; x + bx0 * N < 3; x++) {
108 compute_laplacian_scalar(x);
109 }
110 // Interior. One extra pixel of border as the last pixel is special.
111 for (; x + Lanes(df) <= x1 && x + Lanes(df) + bx0 * N - 1 <= opsin.xsize();
112 x += Lanes(df)) {
113 size_t cx = x + bx0 * N - 2;
114 auto sumsqr = Zero(df);
115 for (size_t c = 0; c < 3; c++) {
116 auto laplacian =
117 LoadU(df, in_row[c] + cx) * Set(df, kChannelWeights[c]);
118 auto sum_oth0 = LoadU(df, in_row[c] + cx - 1);
119 auto sum_oth1 = LoadU(df, in_row[c] + cx + 1);
120 auto sum_oth2 = LoadU(df, in_row_t[c] + cx - 1);
121 auto sum_oth3 = LoadU(df, in_row_t[c] + cx);
122 sum_oth0 += LoadU(df, in_row_t[c] + cx + 1);
123 sum_oth1 += LoadU(df, in_row_b[c] + cx - 1);
124 sum_oth2 += LoadU(df, in_row_b[c] + cx);
125 sum_oth3 += LoadU(df, in_row_b[c] + cx + 1);
126 sum_oth0 += sum_oth1;
127 sum_oth2 += sum_oth3;
128 sum_oth0 += sum_oth2;
129 laplacian =
130 MulAdd(Set(df, kChannelWeightsLapNeg[c]), sum_oth0, laplacian);
131 sumsqr = MulAdd(laplacian, laplacian, sumsqr);
132 }
133 StoreU(sumsqr, df, laplacian_sqrsum_row + x);
134 }
135 for (; x < x1; x++) {
136 compute_laplacian_scalar(x);
137 }
138 }
139 HWY_CAPPED(float, 4) df4;
140 // Calculate the L2 of the 3x3 Laplacian in 4x4 blocks within the area
141 // of the integral transform. Sample them within the integral transform
142 // with two offsets (0,0) and (-2, -2) pixels (sqrsum_00 and sqrsum_22,
143 // respectively).
144 ImageF& sqrsum_00 = temp_image->sqrsum_00;
145 size_t sqrsum_00_stride = sqrsum_00.PixelsPerRow();
146 float* JXL_RESTRICT sqrsum_00_row = sqrsum_00.Row(0);
147 for (size_t y = 0; y < (by1 - by0) * 2; y++) {
148 const float* JXL_RESTRICT rows_in[4];
149 for (size_t iy = 0; iy < 4; iy++) {
150 rows_in[iy] = laplacian_sqrsum.ConstRow(y * 4 + iy + 2);
151 }
152 float* JXL_RESTRICT row_out = sqrsum_00_row + y * sqrsum_00_stride;
153 for (size_t x = 0; x < (bx1 - bx0) * 2; x++) {
154 auto sum = Zero(df4);
155 for (size_t iy = 0; iy < 4; iy++) {
156 for (size_t ix = 0; ix < 4; ix += Lanes(df4)) {
157 sum += LoadU(df4, rows_in[iy] + x * 4 + ix + 2);
158 }
159 }
160 row_out[x] = GetLane(Sqrt(SumOfLanes(sum))) * (1.0f / 4.0f);
161 }
162 }
163 // Indexing iy and ix is a bit tricky as we include a 2 pixel border
164 // around the block for evenness calculations. This is similar to what
165 // we did in guetzli for the observability of artefacts, except there
166 // the element is a sliding 5x5, not sparsely sampled 4x4 box like here.
167 ImageF& sqrsum_22 = temp_image->sqrsum_22;
168 size_t sqrsum_22_stride = sqrsum_22.PixelsPerRow();
169 float* JXL_RESTRICT sqrsum_22_row = sqrsum_22.Row(0);
170 for (size_t y = 0; y < (by1 - by0) * 2 + 1; y++) {
171 const float* JXL_RESTRICT rows_in[4];
172 for (size_t iy = 0; iy < 4; iy++) {
173 rows_in[iy] = laplacian_sqrsum.ConstRow(y * 4 + iy);
174 }
175 float* JXL_RESTRICT row_out = sqrsum_22_row + y * sqrsum_22_stride;
176 // ignore pixels outside the image.
177 // Y coordinates are relative to by0*8+y*4.
178 size_t sy = y * 4 + by0 * 8 > 0 ? 0 : 2;
179 size_t ey = y * 4 + by0 * 8 + 4 <= opsin.ysize() + 2
180 ? 4
181 : opsin.ysize() - y * 4 - by0 * 8 + 2;
182 for (size_t x = 0; x < (bx1 - bx0) * 2 + 1; x++) {
183 // ignore pixels outside the image.
184 // X coordinates are relative to bx0*8.
185 size_t sx = x * 4 + bx0 * 8 > 0 ? x * 4 : x * 4 + 2;
186 size_t ex = x * 4 + bx0 * 8 + 4 <= opsin.xsize() + 2
187 ? x * 4 + 4
188 : opsin.xsize() - bx0 * 8 + 2;
189 if (ex - sx == 4 && ey - sy == 4) {
190 auto sum = Zero(df4);
191 for (size_t iy = 0; iy < 4; iy++) {
192 for (size_t ix = 0; ix < 4; ix += Lanes(df4)) {
193 sum += Load(df4, rows_in[iy] + sx + ix);
194 }
195 }
196 row_out[x] = GetLane(Sqrt(SumOfLanes(sum))) * (1.0f / 4.0f);
197 } else {
198 float sum = 0;
199 for (size_t iy = sy; iy < ey; iy++) {
200 for (size_t ix = sx; ix < ex; ix++) {
201 sum += rows_in[iy][ix];
202 }
203 }
204 row_out[x] = std::sqrt(sum / ((ex - sx) * (ey - sy)));
205 }
206 }
207 }
208 for (size_t by = by0; by < by1; by++) {
209 AcStrategyRow acs_row = enc_state->shared.ac_strategy.ConstRow(by);
210 uint8_t* JXL_RESTRICT out_row = epf_sharpness->Row(by);
211 float* JXL_RESTRICT quant_row = quant->Row(by);
212 for (size_t bx = bx0; bx < bx1; bx++) {
213 AcStrategy acs = acs_row[bx];
214 if (!acs.IsFirstBlock()) continue;
215 // The errors are going to be linear to the quantization value in this
216 // locality. We only have access to the initial quant field here.
217 float quant_val = 1.0f / quant_row[bx];
218
219 const auto sq00 = [&](size_t y, size_t x) {
220 return sqrsum_00_row[((by - by0) * 2 + y) * sqrsum_00_stride +
221 (bx - bx0) * 2 + x];
222 };
223 const auto sq22 = [&](size_t y, size_t x) {
224 return sqrsum_22_row[((by - by0) * 2 + y) * sqrsum_22_stride +
225 (bx - bx0) * 2 + x];
226 };
227 float sqrsum_integral_transform = 0;
228 for (size_t iy = 0; iy < acs.covered_blocks_y() * 2; iy++) {
229 for (size_t ix = 0; ix < acs.covered_blocks_x() * 2; ix++) {
230 sqrsum_integral_transform += sq00(iy, ix) * sq00(iy, ix);
231 }
232 }
233 sqrsum_integral_transform /=
234 4 * acs.covered_blocks_x() * acs.covered_blocks_y();
235 sqrsum_integral_transform = std::sqrt(sqrsum_integral_transform);
236 // If masking is high or amplitude of the artefacts is low, then no
237 // smoothing is needed.
238 for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
239 for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
240 // Five 4x4 blocks for masking estimation, all within the
241 // 8x8 area.
242 float minval_1 = std::min(sq00(2 * iy + 0, 2 * ix + 0),
243 sq00(2 * iy + 0, 2 * ix + 1));
244 float minval_2 = std::min(sq00(2 * iy + 1, 2 * ix + 0),
245 sq00(2 * iy + 1, 2 * ix + 1));
246 float minval = std::min(minval_1, minval_2);
247 minval = std::min(minval, sq22(2 * iy + 1, 2 * ix + 1));
248 // Nine more 4x4 blocks for masking estimation, includes
249 // the 2 pixel area around the 8x8 block being controlled.
250 float minval2_1 = std::min(sq22(2 * iy + 0, 2 * ix + 0),
251 sq22(2 * iy + 0, 2 * ix + 1));
252 float minval2_2 = std::min(sq22(2 * iy + 0, 2 * ix + 2),
253 sq22(2 * iy + 1, 2 * ix + 0));
254 float minval2_3 = std::min(sq22(2 * iy + 1, 2 * ix + 1),
255 sq22(2 * iy + 1, 2 * ix + 2));
256 float minval2_4 = std::min(sq22(2 * iy + 2, 2 * ix + 0),
257 sq22(2 * iy + 2, 2 * ix + 1));
258 float minval2_5 = std::min(minval2_1, minval2_2);
259 float minval2_6 = std::min(minval2_3, minval2_4);
260 float minval2 = std::min(minval2_5, minval2_6);
261 minval2 = std::min(minval2, sq22(2 * iy + 2, 2 * ix + 2));
262 float minval3 = std::min(minval, minval2);
263 minval *= 0.125f;
264 minval += 0.625f * minval3;
265 minval +=
266 0.125f * std::min(1.5f * minval3, sq22(2 * iy + 1, 2 * ix + 1));
267 minval += 0.125f * minval2;
268 // Larger kBias, less smoothing for low intensity changes.
269 float kDeltaLimit = 3.2;
270 float bias = 0.0625f * quant_val;
271 float delta =
272 (sqrsum_integral_transform + (kDeltaLimit + 0.05) * bias) /
273 (minval + bias);
274 int out = 4;
275 if (delta > kDeltaLimit) {
276 out = 4; // smooth
277 } else {
278 out = 0;
279 }
280 // 'threshold' is separate from 'bias' for easier tuning of these
281 // heuristics.
282 float threshold = 0.0625f * quant_val;
283 const float kSmoothLimit = 0.085f;
284 float smooth = 0.20f * (sq00(2 * iy + 0, 2 * ix + 0) +
285 sq00(2 * iy + 0, 2 * ix + 1) +
286 sq00(2 * iy + 1, 2 * ix + 0) +
287 sq00(2 * iy + 1, 2 * ix + 1) + minval);
288 if (smooth < kSmoothLimit * threshold) {
289 out = 4;
290 }
291 out_row[bx + sharpness_stride * iy + ix] = out;
292 }
293 }
294 }
295 }
296 }
297
298 } // namespace
299 // NOLINTNEXTLINE(google-readability-namespace-comments)
300 } // namespace HWY_NAMESPACE
301 } // namespace jxl
302 HWY_AFTER_NAMESPACE();
303
304 #if HWY_ONCE
305 namespace jxl {
306 HWY_EXPORT(ProcessTile);
307
RunRect(const Rect & block_rect,const Image3F & opsin,PassesEncoderState * enc_state,size_t thread)308 void ArControlFieldHeuristics::RunRect(const Rect& block_rect,
309 const Image3F& opsin,
310 PassesEncoderState* enc_state,
311 size_t thread) {
312 HWY_DYNAMIC_DISPATCH(ProcessTile)
313 (opsin, enc_state, block_rect, &temp_images[thread]);
314 }
315
316 } // namespace jxl
317
318 #endif
319