1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/enc_ar_control_field.h"
7 
8 #include <stdint.h>
9 #include <stdlib.h>
10 
11 #include <algorithm>
12 
13 #undef HWY_TARGET_INCLUDE
14 #define HWY_TARGET_INCLUDE "lib/jxl/enc_ar_control_field.cc"
15 #include <hwy/foreach_target.h>
16 #include <hwy/highway.h>
17 
18 #include "lib/jxl/ac_strategy.h"
19 #include "lib/jxl/base/compiler_specific.h"
20 #include "lib/jxl/base/data_parallel.h"
21 #include "lib/jxl/base/status.h"
22 #include "lib/jxl/chroma_from_luma.h"
23 #include "lib/jxl/common.h"
24 #include "lib/jxl/enc_adaptive_quantization.h"
25 #include "lib/jxl/enc_params.h"
26 #include "lib/jxl/image.h"
27 #include "lib/jxl/image_bundle.h"
28 #include "lib/jxl/image_ops.h"
29 #include "lib/jxl/quant_weights.h"
30 #include "lib/jxl/quantizer.h"
31 
32 HWY_BEFORE_NAMESPACE();
33 namespace jxl {
34 namespace HWY_NAMESPACE {
35 namespace {
36 
ProcessTile(const Image3F & opsin,PassesEncoderState * enc_state,const Rect & rect,ArControlFieldHeuristics::TempImages * temp_image)37 void ProcessTile(const Image3F& opsin, PassesEncoderState* enc_state,
38                  const Rect& rect,
39                  ArControlFieldHeuristics::TempImages* temp_image) {
40   constexpr size_t N = kBlockDim;
41   ImageB* JXL_RESTRICT epf_sharpness = &enc_state->shared.epf_sharpness;
42   ImageF* JXL_RESTRICT quant = &enc_state->initial_quant_field;
43   JXL_ASSERT(
44       epf_sharpness->xsize() == enc_state->shared.frame_dim.xsize_blocks &&
45       epf_sharpness->ysize() == enc_state->shared.frame_dim.ysize_blocks);
46 
47   if (enc_state->cparams.butteraugli_distance < kMinButteraugliForDynamicAR ||
48       enc_state->cparams.speed_tier > SpeedTier::kWombat ||
49       enc_state->shared.frame_header.loop_filter.epf_iters == 0) {
50     FillPlane(static_cast<uint8_t>(4), epf_sharpness, rect);
51     return;
52   }
53 
54   // Likely better to have a higher X weight, like:
55   // const float kChannelWeights[3] = {47.0f, 4.35f, 0.287f};
56   const float kChannelWeights[3] = {4.35f, 4.35f, 0.287f};
57   const float kChannelWeightsLapNeg[3] = {-0.125f * kChannelWeights[0],
58                                           -0.125f * kChannelWeights[1],
59                                           -0.125f * kChannelWeights[2]};
60   const size_t sharpness_stride =
61       static_cast<size_t>(epf_sharpness->PixelsPerRow());
62 
63   size_t by0 = rect.y0();
64   size_t by1 = rect.y0() + rect.ysize();
65   size_t bx0 = rect.x0();
66   size_t bx1 = rect.x0() + rect.xsize();
67   temp_image->InitOnce();
68   ImageF& laplacian_sqrsum = temp_image->laplacian_sqrsum;
69   // Calculate the L2 of the 3x3 Laplacian in an integral transform
70   // (for example 32x32 dct). This relates to transforms ability
71   // to propagate artefacts.
72   size_t y0 = by0 == 0 ? 2 : 0;
73   size_t y1 = by1 * N + 4 <= opsin.ysize() + 2 ? (by1 - by0) * N + 4
74                                                : opsin.ysize() + 2 - by0 * N;
75   size_t x0 = bx0 == 0 ? 2 : 0;
76   size_t x1 = bx1 * N + 4 <= opsin.xsize() + 2 ? (bx1 - bx0) * N + 4
77                                                : opsin.xsize() + 2 - bx0 * N;
78   HWY_FULL(float) df;
79   for (size_t y = y0; y < y1; y++) {
80     float* JXL_RESTRICT laplacian_sqrsum_row = laplacian_sqrsum.Row(y);
81     size_t cy = y + by0 * N - 2;
82     const float* JXL_RESTRICT in_row_t[3];
83     const float* JXL_RESTRICT in_row[3];
84     const float* JXL_RESTRICT in_row_b[3];
85     for (size_t c = 0; c < 3; c++) {
86       in_row_t[c] = opsin.PlaneRow(c, cy > 0 ? cy - 1 : cy);
87       in_row[c] = opsin.PlaneRow(c, cy);
88       in_row_b[c] = opsin.PlaneRow(c, cy + 1 < opsin.ysize() ? cy + 1 : cy);
89     }
90     auto compute_laplacian_scalar = [&](size_t x) {
91       size_t cx = x + bx0 * N - 2;
92       const size_t prevX = cx >= 1 ? cx - 1 : cx;
93       const size_t nextX = cx + 1 < opsin.xsize() ? cx + 1 : cx;
94       float sumsqr = 0;
95       for (size_t c = 0; c < 3; c++) {
96         float laplacian =
97             kChannelWeights[c] * in_row[c][cx] +
98             kChannelWeightsLapNeg[c] *
99                 (in_row[c][prevX] + in_row[c][nextX] + in_row_b[c][prevX] +
100                  in_row_b[c][cx] + in_row_b[c][nextX] + in_row_t[c][prevX] +
101                  in_row_t[c][cx] + in_row_t[c][nextX]);
102         sumsqr += laplacian * laplacian;
103       }
104       laplacian_sqrsum_row[x] = sumsqr;
105     };
106     size_t x = x0;
107     for (; x + bx0 * N < 3; x++) {
108       compute_laplacian_scalar(x);
109     }
110     // Interior. One extra pixel of border as the last pixel is special.
111     for (; x + Lanes(df) <= x1 && x + Lanes(df) + bx0 * N - 1 <= opsin.xsize();
112          x += Lanes(df)) {
113       size_t cx = x + bx0 * N - 2;
114       auto sumsqr = Zero(df);
115       for (size_t c = 0; c < 3; c++) {
116         auto laplacian =
117             LoadU(df, in_row[c] + cx) * Set(df, kChannelWeights[c]);
118         auto sum_oth0 = LoadU(df, in_row[c] + cx - 1);
119         auto sum_oth1 = LoadU(df, in_row[c] + cx + 1);
120         auto sum_oth2 = LoadU(df, in_row_t[c] + cx - 1);
121         auto sum_oth3 = LoadU(df, in_row_t[c] + cx);
122         sum_oth0 += LoadU(df, in_row_t[c] + cx + 1);
123         sum_oth1 += LoadU(df, in_row_b[c] + cx - 1);
124         sum_oth2 += LoadU(df, in_row_b[c] + cx);
125         sum_oth3 += LoadU(df, in_row_b[c] + cx + 1);
126         sum_oth0 += sum_oth1;
127         sum_oth2 += sum_oth3;
128         sum_oth0 += sum_oth2;
129         laplacian =
130             MulAdd(Set(df, kChannelWeightsLapNeg[c]), sum_oth0, laplacian);
131         sumsqr = MulAdd(laplacian, laplacian, sumsqr);
132       }
133       StoreU(sumsqr, df, laplacian_sqrsum_row + x);
134     }
135     for (; x < x1; x++) {
136       compute_laplacian_scalar(x);
137     }
138   }
139   HWY_CAPPED(float, 4) df4;
140   // Calculate the L2 of the 3x3 Laplacian in 4x4 blocks within the area
141   // of the integral transform. Sample them within the integral transform
142   // with two offsets (0,0) and (-2, -2) pixels (sqrsum_00 and sqrsum_22,
143   //  respectively).
144   ImageF& sqrsum_00 = temp_image->sqrsum_00;
145   size_t sqrsum_00_stride = sqrsum_00.PixelsPerRow();
146   float* JXL_RESTRICT sqrsum_00_row = sqrsum_00.Row(0);
147   for (size_t y = 0; y < (by1 - by0) * 2; y++) {
148     const float* JXL_RESTRICT rows_in[4];
149     for (size_t iy = 0; iy < 4; iy++) {
150       rows_in[iy] = laplacian_sqrsum.ConstRow(y * 4 + iy + 2);
151     }
152     float* JXL_RESTRICT row_out = sqrsum_00_row + y * sqrsum_00_stride;
153     for (size_t x = 0; x < (bx1 - bx0) * 2; x++) {
154       auto sum = Zero(df4);
155       for (size_t iy = 0; iy < 4; iy++) {
156         for (size_t ix = 0; ix < 4; ix += Lanes(df4)) {
157           sum += LoadU(df4, rows_in[iy] + x * 4 + ix + 2);
158         }
159       }
160       row_out[x] = GetLane(Sqrt(SumOfLanes(sum))) * (1.0f / 4.0f);
161     }
162   }
163   // Indexing iy and ix is a bit tricky as we include a 2 pixel border
164   // around the block for evenness calculations. This is similar to what
165   // we did in guetzli for the observability of artefacts, except there
166   // the element is a sliding 5x5, not sparsely sampled 4x4 box like here.
167   ImageF& sqrsum_22 = temp_image->sqrsum_22;
168   size_t sqrsum_22_stride = sqrsum_22.PixelsPerRow();
169   float* JXL_RESTRICT sqrsum_22_row = sqrsum_22.Row(0);
170   for (size_t y = 0; y < (by1 - by0) * 2 + 1; y++) {
171     const float* JXL_RESTRICT rows_in[4];
172     for (size_t iy = 0; iy < 4; iy++) {
173       rows_in[iy] = laplacian_sqrsum.ConstRow(y * 4 + iy);
174     }
175     float* JXL_RESTRICT row_out = sqrsum_22_row + y * sqrsum_22_stride;
176     // ignore pixels outside the image.
177     // Y coordinates are relative to by0*8+y*4.
178     size_t sy = y * 4 + by0 * 8 > 0 ? 0 : 2;
179     size_t ey = y * 4 + by0 * 8 + 4 <= opsin.ysize() + 2
180                     ? 4
181                     : opsin.ysize() - y * 4 - by0 * 8 + 2;
182     for (size_t x = 0; x < (bx1 - bx0) * 2 + 1; x++) {
183       // ignore pixels outside the image.
184       // X coordinates are relative to bx0*8.
185       size_t sx = x * 4 + bx0 * 8 > 0 ? x * 4 : x * 4 + 2;
186       size_t ex = x * 4 + bx0 * 8 + 4 <= opsin.xsize() + 2
187                       ? x * 4 + 4
188                       : opsin.xsize() - bx0 * 8 + 2;
189       if (ex - sx == 4 && ey - sy == 4) {
190         auto sum = Zero(df4);
191         for (size_t iy = 0; iy < 4; iy++) {
192           for (size_t ix = 0; ix < 4; ix += Lanes(df4)) {
193             sum += Load(df4, rows_in[iy] + sx + ix);
194           }
195         }
196         row_out[x] = GetLane(Sqrt(SumOfLanes(sum))) * (1.0f / 4.0f);
197       } else {
198         float sum = 0;
199         for (size_t iy = sy; iy < ey; iy++) {
200           for (size_t ix = sx; ix < ex; ix++) {
201             sum += rows_in[iy][ix];
202           }
203         }
204         row_out[x] = std::sqrt(sum / ((ex - sx) * (ey - sy)));
205       }
206     }
207   }
208   for (size_t by = by0; by < by1; by++) {
209     AcStrategyRow acs_row = enc_state->shared.ac_strategy.ConstRow(by);
210     uint8_t* JXL_RESTRICT out_row = epf_sharpness->Row(by);
211     float* JXL_RESTRICT quant_row = quant->Row(by);
212     for (size_t bx = bx0; bx < bx1; bx++) {
213       AcStrategy acs = acs_row[bx];
214       if (!acs.IsFirstBlock()) continue;
215       // The errors are going to be linear to the quantization value in this
216       // locality. We only have access to the initial quant field here.
217       float quant_val = 1.0f / quant_row[bx];
218 
219       const auto sq00 = [&](size_t y, size_t x) {
220         return sqrsum_00_row[((by - by0) * 2 + y) * sqrsum_00_stride +
221                              (bx - bx0) * 2 + x];
222       };
223       const auto sq22 = [&](size_t y, size_t x) {
224         return sqrsum_22_row[((by - by0) * 2 + y) * sqrsum_22_stride +
225                              (bx - bx0) * 2 + x];
226       };
227       float sqrsum_integral_transform = 0;
228       for (size_t iy = 0; iy < acs.covered_blocks_y() * 2; iy++) {
229         for (size_t ix = 0; ix < acs.covered_blocks_x() * 2; ix++) {
230           sqrsum_integral_transform += sq00(iy, ix) * sq00(iy, ix);
231         }
232       }
233       sqrsum_integral_transform /=
234           4 * acs.covered_blocks_x() * acs.covered_blocks_y();
235       sqrsum_integral_transform = std::sqrt(sqrsum_integral_transform);
236       // If masking is high or amplitude of the artefacts is low, then no
237       // smoothing is needed.
238       for (size_t iy = 0; iy < acs.covered_blocks_y(); iy++) {
239         for (size_t ix = 0; ix < acs.covered_blocks_x(); ix++) {
240           // Five 4x4 blocks for masking estimation, all within the
241           // 8x8 area.
242           float minval_1 = std::min(sq00(2 * iy + 0, 2 * ix + 0),
243                                     sq00(2 * iy + 0, 2 * ix + 1));
244           float minval_2 = std::min(sq00(2 * iy + 1, 2 * ix + 0),
245                                     sq00(2 * iy + 1, 2 * ix + 1));
246           float minval = std::min(minval_1, minval_2);
247           minval = std::min(minval, sq22(2 * iy + 1, 2 * ix + 1));
248           // Nine more 4x4 blocks for masking estimation, includes
249           // the 2 pixel area around the 8x8 block being controlled.
250           float minval2_1 = std::min(sq22(2 * iy + 0, 2 * ix + 0),
251                                      sq22(2 * iy + 0, 2 * ix + 1));
252           float minval2_2 = std::min(sq22(2 * iy + 0, 2 * ix + 2),
253                                      sq22(2 * iy + 1, 2 * ix + 0));
254           float minval2_3 = std::min(sq22(2 * iy + 1, 2 * ix + 1),
255                                      sq22(2 * iy + 1, 2 * ix + 2));
256           float minval2_4 = std::min(sq22(2 * iy + 2, 2 * ix + 0),
257                                      sq22(2 * iy + 2, 2 * ix + 1));
258           float minval2_5 = std::min(minval2_1, minval2_2);
259           float minval2_6 = std::min(minval2_3, minval2_4);
260           float minval2 = std::min(minval2_5, minval2_6);
261           minval2 = std::min(minval2, sq22(2 * iy + 2, 2 * ix + 2));
262           float minval3 = std::min(minval, minval2);
263           minval *= 0.125f;
264           minval += 0.625f * minval3;
265           minval +=
266               0.125f * std::min(1.5f * minval3, sq22(2 * iy + 1, 2 * ix + 1));
267           minval += 0.125f * minval2;
268           // Larger kBias, less smoothing for low intensity changes.
269           float kDeltaLimit = 3.2;
270           float bias = 0.0625f * quant_val;
271           float delta =
272               (sqrsum_integral_transform + (kDeltaLimit + 0.05) * bias) /
273               (minval + bias);
274           int out = 4;
275           if (delta > kDeltaLimit) {
276             out = 4;  // smooth
277           } else {
278             out = 0;
279           }
280           // 'threshold' is separate from 'bias' for easier tuning of these
281           // heuristics.
282           float threshold = 0.0625f * quant_val;
283           const float kSmoothLimit = 0.085f;
284           float smooth = 0.20f * (sq00(2 * iy + 0, 2 * ix + 0) +
285                                   sq00(2 * iy + 0, 2 * ix + 1) +
286                                   sq00(2 * iy + 1, 2 * ix + 0) +
287                                   sq00(2 * iy + 1, 2 * ix + 1) + minval);
288           if (smooth < kSmoothLimit * threshold) {
289             out = 4;
290           }
291           out_row[bx + sharpness_stride * iy + ix] = out;
292         }
293       }
294     }
295   }
296 }
297 
298 }  // namespace
299 // NOLINTNEXTLINE(google-readability-namespace-comments)
300 }  // namespace HWY_NAMESPACE
301 }  // namespace jxl
302 HWY_AFTER_NAMESPACE();
303 
#if HWY_ONCE
namespace jxl {
HWY_EXPORT(ProcessTile);

// Runs the AR control field heuristic on one block-unit rectangle.
//
// Selects the ProcessTile implementation for the current CPU target via
// Highway dynamic dispatch. The scratch images used are `temp_images[thread]`.
// NOTE(review): presumably temp_images holds one entry per worker thread, so
// concurrent tiles do not share buffers. Confirm it is sized to the thread
// count.
void ArControlFieldHeuristics::RunRect(const Rect& block_rect,
                                       const Image3F& opsin,
                                       PassesEncoderState* enc_state,
                                       size_t thread) {
  auto* const scratch = &temp_images[thread];
  HWY_DYNAMIC_DISPATCH(ProcessTile)(opsin, enc_state, block_rect, scratch);
}

}  // namespace jxl

#endif  // HWY_ONCE
319