// Copyright (c) the JPEG XL Project Authors. All rights reserved.
//
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

#include "lib/jxl/enc_group.h"

#include <utility>

#include "hwy/aligned_allocator.h"

#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "lib/jxl/enc_group.cc"
#include <hwy/foreach_target.h>
#include <hwy/highway.h>

#include "lib/jxl/ac_strategy.h"
#include "lib/jxl/aux_out.h"
#include "lib/jxl/aux_out_fwd.h"
#include "lib/jxl/base/bits.h"
#include "lib/jxl/base/compiler_specific.h"
#include "lib/jxl/base/profiler.h"
#include "lib/jxl/common.h"
#include "lib/jxl/dct_util.h"
#include "lib/jxl/dec_transforms-inl.h"
#include "lib/jxl/enc_params.h"
#include "lib/jxl/enc_transforms-inl.h"
#include "lib/jxl/image.h"
#include "lib/jxl/quantizer-inl.h"
#include "lib/jxl/quantizer.h"
HWY_BEFORE_NAMESPACE();
namespace jxl {
namespace HWY_NAMESPACE {

// NOTE: caller takes care of extracting quant from rect of RawQuantField.
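// Quantizes the AC coefficients of one variable-sized transform: `block_in`
// holds xsize*ysize 8x8 blocks of float DCT coefficients, and the quantized
// integers are written to `block_out`. With error diffusion, the per-quadrant
// zeroing thresholds are adaptively lowered (see the retry loop below) so
// that more coefficients survive in quadrants with high rounding error.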
void QuantizeBlockAC(const Quantizer& quantizer, const bool error_diffusion,
                     size_t c, int32_t quant, float qm_multiplier,
                     size_t quant_kind, size_t xsize, size_t ysize,
                     const float* JXL_RESTRICT block_in,
                     int32_t* JXL_RESTRICT block_out) {
  PROFILER_FUNC;
  const float* JXL_RESTRICT qm = quantizer.InvDequantMatrix(quant_kind, c);
  const float qac = quantizer.Scale() * quant;
  // Not SIMD-fied for now.
  float thres[4] = {0.5f, 0.6f, 0.6f, 0.65f};
  if (c != 1) {
    for (int i = 1; i < 4; ++i) {
      thres[i] = 0.75f;
    }
  }

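  // Fast path: vectorized hard thresholding. Each coefficient is scaled by
  // its quantization matrix entry and the AC quant step, rounded, and zeroed
  // when its magnitude falls below the threshold of its frequency quadrant.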
  if (!error_diffusion) {
    HWY_CAPPED(float, kBlockDim) df;
    HWY_CAPPED(int32_t, kBlockDim) di;
    HWY_CAPPED(uint32_t, kBlockDim) du;
    const auto quant = Set(df, qac * qm_multiplier);

    for (size_t y = 0; y < ysize * kBlockDim; y++) {
      size_t yfix = static_cast<size_t>(y >= ysize * kBlockDim / 2) * 2;
      const size_t off = y * kBlockDim * xsize;
      for (size_t x = 0; x < xsize * kBlockDim; x += Lanes(df)) {
        auto thr = Zero(df);
        if (xsize == 1) {
          HWY_ALIGN uint32_t kMask[kBlockDim] = {0,   0,   0,   0,
                                                 ~0u, ~0u, ~0u, ~0u};
          const auto mask = MaskFromVec(BitCast(df, Load(du, kMask + x)));
          thr =
              IfThenElse(mask, Set(df, thres[yfix + 1]), Set(df, thres[yfix]));
        } else {
          // Same for all lanes in the vector.
          thr = Set(
              df,
              thres[yfix + static_cast<size_t>(x >= xsize * kBlockDim / 2)]);
        }

        const auto q = Load(df, qm + off + x) * quant;
        const auto in = Load(df, block_in + off + x);
        const auto val = q * in;
        const auto nzero_mask = Abs(val) >= thr;
        const auto v = ConvertTo(di, IfThenElseZero(nzero_mask, Round(val)));
        Store(v, di, block_out + off + x);
      }
    }
    return;
  }

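  // Error-diffusion path: quantize one coefficient at a time, tracking per
  // frequency quadrant the accumulated rounding error, the largest single
  // error, and the summed magnitude of surviving (non-zero) values.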
retry:
  int hfNonZeros[4] = {};
  float hfError[4] = {};
  float hfMaxError[4] = {};
  size_t hfMaxErrorIx[4] = {};
  for (size_t y = 0; y < ysize * kBlockDim; y++) {
    for (size_t x = 0; x < xsize * kBlockDim; x++) {
      const size_t pos = y * kBlockDim * xsize + x;
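      // The top-left xsize*ysize corner holds the lowest-frequency
      // coefficients, which are coded through the DC path
      // (DCFromLowestFrequencies); only zero-initialize them here.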
      if (x < xsize && y < ysize) {
        // Ensure block is initialized
        block_out[pos] = 0;
        continue;
      }
      const size_t hfix =
          (static_cast<size_t>(y >= ysize * kBlockDim / 2) * 2 +
           static_cast<size_t>(x >= xsize * kBlockDim / 2));
      const float val = block_in[pos] * (qm[pos] * qac * qm_multiplier);
      float v = (std::abs(val) < thres[hfix]) ? 0 : rintf(val);
      const float error = std::abs(val) - std::abs(v);
      hfError[hfix] += error;
      if (hfMaxError[hfix] < error) {
        hfMaxError[hfix] = error;
        hfMaxErrorIx[hfix] = pos;
      }
      if (v != 0.0f) {
        hfNonZeros[hfix] += std::abs(v);
      }
      block_out[pos] = static_cast<int32_t>(rintf(v));
    }
  }
  if (c != 1) return;
  // TODO(veluca): include AFV?
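  // Transform kinds for which the quadrant-based retry heuristic below is
  // skipped (their coefficient layout differs from a plain 8x8 DCT).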
  const size_t kPartialBlockKinds =
      (1 << AcStrategy::Type::IDENTITY) | (1 << AcStrategy::Type::DCT2X2) |
      (1 << AcStrategy::Type::DCT4X4) | (1 << AcStrategy::Type::DCT4X8) |
      (1 << AcStrategy::Type::DCT8X4);
  if ((1 << quant_kind) & kPartialBlockKinds) return;
  float hfErrorLimit = 0.1f * (xsize * ysize) * kDCTBlockSize * 0.25f;
  bool goretry = false;
  for (int i = 1; i < 4; ++i) {
    if (hfError[i] >= hfErrorLimit &&
        hfNonZeros[i] <= (xsize + ysize) * 0.25f) {
      if (thres[i] >= 0.4f) {
        thres[i] -= 0.01f;
        goretry = true;
      }
    }
  }
  if (goretry) goto retry;
  for (int i = 1; i < 4; ++i) {
    if (hfError[i] >= hfErrorLimit && hfNonZeros[i] == 0) {
      const size_t pos = hfMaxErrorIx[i];
      if (hfMaxError[i] >= 0.4f) {
        block_out[pos] = block_in[pos] > 0.0f ? 1.0f : -1.0f;
      }
    }
  }
}

// NOTE: caller takes care of extracting quant from rect of RawQuantField.
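// Quantizes a Y-channel block and immediately dequantizes it back into
// `inout` (with the decoder-side quantization bias applied), so that the
// caller predicts the X and B channels from the Y values the decoder will
// actually reconstruct.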
void QuantizeRoundtripYBlockAC(const Quantizer& quantizer,
                               const bool error_diffusion, int32_t quant,
                               size_t quant_kind, size_t xsize, size_t ysize,
                               const float* JXL_RESTRICT biases,
                               float* JXL_RESTRICT inout,
                               int32_t* JXL_RESTRICT quantized) {
  QuantizeBlockAC(quantizer, error_diffusion, 1, quant, 1.0f, quant_kind, xsize,
                  ysize, inout, quantized);

  PROFILER_ZONE("enc quant adjust bias");
  const float* JXL_RESTRICT dequant_matrix =
      quantizer.DequantMatrix(quant_kind, 1);

  HWY_CAPPED(float, kDCTBlockSize) df;
  HWY_CAPPED(int32_t, kDCTBlockSize) di;
  const auto inv_qac = Set(df, quantizer.inv_quant_ac(quant));
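  // Reconstruct each coefficient as the bias-adjusted quantized value times
  // the dequant matrix entry times the inverse AC quant step.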
  for (size_t k = 0; k < kDCTBlockSize * xsize * ysize; k += Lanes(df)) {
    const auto quant = Load(di, quantized + k);
    const auto adj_quant = AdjustQuantBias(di, 1, quant, biases);
    const auto dequantm = Load(df, dequant_matrix + k);
    Store(adj_quant * dequantm * inv_qac, df, inout + k);
  }
}

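// For one group: computes the per-block DCT (per the chosen AC strategy),
// derives the DC image from the lowest frequencies, decorrelates X and B
// against the roundtripped Y (chroma from luma), quantizes all AC
// coefficients, and splits them across the progressive passes.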
void ComputeCoefficients(size_t group_idx, PassesEncoderState* enc_state,
                         const Image3F& opsin, Image3F* dc) {
  PROFILER_FUNC;
  const Rect block_group_rect = enc_state->shared.BlockGroupRect(group_idx);
  const Rect group_rect = enc_state->shared.GroupRect(group_idx);
  const Rect cmap_rect(
      block_group_rect.x0() / kColorTileDimInBlocks,
      block_group_rect.y0() / kColorTileDimInBlocks,
      DivCeil(block_group_rect.xsize(), kColorTileDimInBlocks),
      DivCeil(block_group_rect.ysize(), kColorTileDimInBlocks));

  const size_t xsize_blocks = block_group_rect.xsize();
  const size_t ysize_blocks = block_group_rect.ysize();

  const size_t dc_stride = static_cast<size_t>(dc->PixelsPerRow());
  const size_t opsin_stride = static_cast<size_t>(opsin.PixelsPerRow());

  const ImageI& full_quant_field = enc_state->shared.raw_quant_field;
  const CompressParams& cparams = enc_state->cparams;

  // TODO(veluca): consider strategies to reduce this memory.
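  // `mem` holds the quantized coefficients of all three channels; `fmem`
  // holds their float DCT coefficients plus two coefficient areas of
  // transform scratch space.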
  auto mem = hwy::AllocateAligned<int32_t>(3 * AcStrategy::kMaxCoeffArea);
  auto fmem = hwy::AllocateAligned<float>(5 * AcStrategy::kMaxCoeffArea);
  float* JXL_RESTRICT scratch_space =
      fmem.get() + 3 * AcStrategy::kMaxCoeffArea;
  {
    // Only use error diffusion in Squirrel mode or slower.
    const bool error_diffusion = cparams.speed_tier <= SpeedTier::kSquirrel;
    constexpr HWY_CAPPED(float, kDCTBlockSize) d;

    int32_t* JXL_RESTRICT coeffs[kMaxNumPasses][3] = {};
    size_t num_passes = enc_state->progressive_splitter.GetNumPasses();
    JXL_DASSERT(num_passes > 0);
    for (size_t i = 0; i < num_passes; i++) {
      // TODO(veluca): 16-bit quantized coeffs are not implemented yet.
      JXL_ASSERT(enc_state->coeffs[i]->Type() == ACType::k32);
      for (size_t c = 0; c < 3; c++) {
        coeffs[i][c] = enc_state->coeffs[i]->PlaneRow(c, group_idx, 0).ptr32;
      }
    }

    HWY_ALIGN float* coeffs_in = fmem.get();
    HWY_ALIGN int32_t* quantized = mem.get();

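    // Running offset of the current block within this group's per-pass
    // coefficient rows; advanced by each block's coefficient count.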
    size_t offset = 0;

    for (size_t by = 0; by < ysize_blocks; ++by) {
      const int32_t* JXL_RESTRICT row_quant_ac =
          block_group_rect.ConstRow(full_quant_field, by);
      size_t ty = by / kColorTileDimInBlocks;
      const int8_t* JXL_RESTRICT row_cmap[3] = {
          cmap_rect.ConstRow(enc_state->shared.cmap.ytox_map, ty),
          nullptr,
          cmap_rect.ConstRow(enc_state->shared.cmap.ytob_map, ty),
      };
      const float* JXL_RESTRICT opsin_rows[3] = {
          group_rect.ConstPlaneRow(opsin, 0, by * kBlockDim),
          group_rect.ConstPlaneRow(opsin, 1, by * kBlockDim),
          group_rect.ConstPlaneRow(opsin, 2, by * kBlockDim),
      };
      float* JXL_RESTRICT dc_rows[3] = {
          block_group_rect.PlaneRow(dc, 0, by),
          block_group_rect.PlaneRow(dc, 1, by),
          block_group_rect.PlaneRow(dc, 2, by),
      };
      AcStrategyRow ac_strategy_row =
          enc_state->shared.ac_strategy.ConstRow(block_group_rect, by);
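      // Chroma-from-luma factors are constant within one color tile, so they
      // are loaded once per tile rather than once per block.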
      for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks);
           tx++) {
        const auto x_factor =
            Set(d, enc_state->shared.cmap.YtoXRatio(row_cmap[0][tx]));
        const auto b_factor =
            Set(d, enc_state->shared.cmap.YtoBRatio(row_cmap[2][tx]));
        for (size_t bx = tx * kColorTileDimInBlocks;
             bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks; ++bx) {
          const AcStrategy acs = ac_strategy_row[bx];
          if (!acs.IsFirstBlock()) continue;

          size_t xblocks = acs.covered_blocks_x();
          size_t yblocks = acs.covered_blocks_y();

          CoefficientLayout(&yblocks, &xblocks);

          size_t size = kDCTBlockSize * xblocks * yblocks;

          // DCT Y channel, roundtrip-quantize it and set DC.
          const int32_t quant_ac = row_quant_ac[bx];
          TransformFromPixels(acs.Strategy(), opsin_rows[1] + bx * kBlockDim,
                              opsin_stride, coeffs_in + size, scratch_space);
          DCFromLowestFrequencies(acs.Strategy(), coeffs_in + size,
                                  dc_rows[1] + bx, dc_stride);
          QuantizeRoundtripYBlockAC(
              enc_state->shared.quantizer, error_diffusion, quant_ac,
              acs.RawStrategy(), xblocks, yblocks, kDefaultQuantBias,
              coeffs_in + size, quantized + size);

          // DCT X and B channels
          for (size_t c : {0, 2}) {
            TransformFromPixels(acs.Strategy(), opsin_rows[c] + bx * kBlockDim,
                                opsin_stride, coeffs_in + c * size,
                                scratch_space);
          }

          // Unapply color correlation
          for (size_t k = 0; k < size; k += Lanes(d)) {
            const auto in_x = Load(d, coeffs_in + k);
            const auto in_y = Load(d, coeffs_in + size + k);
            const auto in_b = Load(d, coeffs_in + 2 * size + k);
            const auto out_x = in_x - x_factor * in_y;
            const auto out_b = in_b - b_factor * in_y;
            Store(out_x, d, coeffs_in + k);
            Store(out_b, d, coeffs_in + 2 * size + k);
          }

          // Quantize X and B channels and set DC.
          for (size_t c : {0, 2}) {
            QuantizeBlockAC(enc_state->shared.quantizer, error_diffusion, c,
                            quant_ac,
                            c == 0 ? enc_state->x_qm_multiplier
                                   : enc_state->b_qm_multiplier,
                            acs.RawStrategy(), xblocks, yblocks,
                            coeffs_in + c * size, quantized + c * size);
            DCFromLowestFrequencies(acs.Strategy(), coeffs_in + c * size,
                                    dc_rows[c] + bx, dc_stride);
          }
          enc_state->progressive_splitter.SplitACCoefficients(
              quantized, size, acs, bx, by, offset, coeffs);
          offset += size;
        }
      }
    }
  }
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace jxl
HWY_AFTER_NAMESPACE();

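// The code below is compiled only once (not per SIMD target); it forwards to
// the best available target via Highway's dynamic dispatch.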
#if HWY_ONCE
namespace jxl {
HWY_EXPORT(ComputeCoefficients);
void ComputeCoefficients(size_t group_idx, PassesEncoderState* enc_state,
                         const Image3F& opsin, Image3F* dc) {
  return HWY_DYNAMIC_DISPATCH(ComputeCoefficients)(group_idx, enc_state, opsin,
                                                   dc);
}

Status EncodeGroupTokenizedCoefficients(size_t group_idx, size_t pass_idx,
                                        size_t histogram_idx,
                                        const PassesEncoderState& enc_state,
                                        BitWriter* writer, AuxOut* aux_out) {
  // Select which histogram to use among those of the current pass.
  const size_t num_histograms = enc_state.shared.num_histograms;
  // num_histograms is 0 only for lossless.
  JXL_ASSERT(num_histograms == 0 || histogram_idx < num_histograms);
  size_t histo_selector_bits = CeilLog2Nonzero(num_histograms);

  if (histo_selector_bits != 0) {
    BitWriter::Allotment allotment(writer, histo_selector_bits);
    writer->Write(histo_selector_bits, histogram_idx);
    ReclaimAndCharge(writer, &allotment, kLayerAC, aux_out);
  }
  WriteTokens(enc_state.passes[pass_idx].ac_tokens[group_idx],
              enc_state.passes[pass_idx].codes,
              enc_state.passes[pass_idx].context_map, writer, kLayerACTokens,
              aux_out);

  return true;
}

}  // namespace jxl
#endif  // HWY_ONCE