1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/enc_group.h"
7 
8 #include <utility>
9 
10 #include "hwy/aligned_allocator.h"
11 
12 #undef HWY_TARGET_INCLUDE
13 #define HWY_TARGET_INCLUDE "lib/jxl/enc_group.cc"
14 #include <hwy/foreach_target.h>
15 #include <hwy/highway.h>
16 
17 #include "lib/jxl/ac_strategy.h"
18 #include "lib/jxl/aux_out.h"
19 #include "lib/jxl/aux_out_fwd.h"
20 #include "lib/jxl/base/bits.h"
21 #include "lib/jxl/base/compiler_specific.h"
22 #include "lib/jxl/base/profiler.h"
23 #include "lib/jxl/common.h"
24 #include "lib/jxl/dct_util.h"
25 #include "lib/jxl/dec_transforms-inl.h"
26 #include "lib/jxl/enc_params.h"
27 #include "lib/jxl/enc_transforms-inl.h"
28 #include "lib/jxl/image.h"
29 #include "lib/jxl/quantizer-inl.h"
30 #include "lib/jxl/quantizer.h"
31 HWY_BEFORE_NAMESPACE();
32 namespace jxl {
33 namespace HWY_NAMESPACE {
34 
35 // NOTE: caller takes care of extracting quant from rect of RawQuantField.
QuantizeBlockAC(const Quantizer & quantizer,const bool error_diffusion,size_t c,int32_t quant,float qm_multiplier,size_t quant_kind,size_t xsize,size_t ysize,const float * JXL_RESTRICT block_in,int32_t * JXL_RESTRICT block_out)36 void QuantizeBlockAC(const Quantizer& quantizer, const bool error_diffusion,
37                      size_t c, int32_t quant, float qm_multiplier,
38                      size_t quant_kind, size_t xsize, size_t ysize,
39                      const float* JXL_RESTRICT block_in,
40                      int32_t* JXL_RESTRICT block_out) {
41   PROFILER_FUNC;
42   const float* JXL_RESTRICT qm = quantizer.InvDequantMatrix(quant_kind, c);
43   const float qac = quantizer.Scale() * quant;
44   // Not SIMD-fied for now.
45   float thres[4] = {0.5f, 0.6f, 0.6f, 0.65f};
46   if (c != 1) {
47     for (int i = 1; i < 4; ++i) {
48       thres[i] = 0.75f;
49     }
50   }
51 
52   if (!error_diffusion) {
53     HWY_CAPPED(float, kBlockDim) df;
54     HWY_CAPPED(int32_t, kBlockDim) di;
55     HWY_CAPPED(uint32_t, kBlockDim) du;
56     const auto quant = Set(df, qac * qm_multiplier);
57 
58     for (size_t y = 0; y < ysize * kBlockDim; y++) {
59       size_t yfix = static_cast<size_t>(y >= ysize * kBlockDim / 2) * 2;
60       const size_t off = y * kBlockDim * xsize;
61       for (size_t x = 0; x < xsize * kBlockDim; x += Lanes(df)) {
62         auto thr = Zero(df);
63         if (xsize == 1) {
64           HWY_ALIGN uint32_t kMask[kBlockDim] = {0,   0,   0,   0,
65                                                  ~0u, ~0u, ~0u, ~0u};
66           const auto mask = MaskFromVec(BitCast(df, Load(du, kMask + x)));
67           thr =
68               IfThenElse(mask, Set(df, thres[yfix + 1]), Set(df, thres[yfix]));
69         } else {
70           // Same for all lanes in the vector.
71           thr = Set(
72               df,
73               thres[yfix + static_cast<size_t>(x >= xsize * kBlockDim / 2)]);
74         }
75 
76         const auto q = Load(df, qm + off + x) * quant;
77         const auto in = Load(df, block_in + off + x);
78         const auto val = q * in;
79         const auto nzero_mask = Abs(val) >= thr;
80         const auto v = ConvertTo(di, IfThenElseZero(nzero_mask, Round(val)));
81         Store(v, di, block_out + off + x);
82       }
83     }
84     return;
85   }
86 
87 retry:
88   int hfNonZeros[4] = {};
89   float hfError[4] = {};
90   float hfMaxError[4] = {};
91   size_t hfMaxErrorIx[4] = {};
92   for (size_t y = 0; y < ysize * kBlockDim; y++) {
93     for (size_t x = 0; x < xsize * kBlockDim; x++) {
94       const size_t pos = y * kBlockDim * xsize + x;
95       if (x < xsize && y < ysize) {
96         // Ensure block is initialized
97         block_out[pos] = 0;
98         continue;
99       }
100       const size_t hfix = (static_cast<size_t>(y >= ysize * kBlockDim / 2) * 2 +
101                            static_cast<size_t>(x >= xsize * kBlockDim / 2));
102       const float val = block_in[pos] * (qm[pos] * qac * qm_multiplier);
103       float v = (std::abs(val) < thres[hfix]) ? 0 : rintf(val);
104       const float error = std::abs(val) - std::abs(v);
105       hfError[hfix] += error;
106       if (hfMaxError[hfix] < error) {
107         hfMaxError[hfix] = error;
108         hfMaxErrorIx[hfix] = pos;
109       }
110       if (v != 0.0f) {
111         hfNonZeros[hfix] += std::abs(v);
112       }
113       block_out[pos] = static_cast<int32_t>(rintf(v));
114     }
115   }
116   if (c != 1) return;
117   // TODO(veluca): include AFV?
118   const size_t kPartialBlockKinds =
119       (1 << AcStrategy::Type::IDENTITY) | (1 << AcStrategy::Type::DCT2X2) |
120       (1 << AcStrategy::Type::DCT4X4) | (1 << AcStrategy::Type::DCT4X8) |
121       (1 << AcStrategy::Type::DCT8X4);
122   if ((1 << quant_kind) & kPartialBlockKinds) return;
123   float hfErrorLimit = 0.1f * (xsize * ysize) * kDCTBlockSize * 0.25f;
124   bool goretry = false;
125   for (int i = 1; i < 4; ++i) {
126     if (hfError[i] >= hfErrorLimit &&
127         hfNonZeros[i] <= (xsize + ysize) * 0.25f) {
128       if (thres[i] >= 0.4f) {
129         thres[i] -= 0.01f;
130         goretry = true;
131       }
132     }
133   }
134   if (goretry) goto retry;
135   for (int i = 1; i < 4; ++i) {
136     if (hfError[i] >= hfErrorLimit && hfNonZeros[i] == 0) {
137       const size_t pos = hfMaxErrorIx[i];
138       if (hfMaxError[i] >= 0.4f) {
139         block_out[pos] = block_in[pos] > 0.0f ? 1.0f : -1.0f;
140       }
141     }
142   }
143 }
144 
145 // NOTE: caller takes care of extracting quant from rect of RawQuantField.
QuantizeRoundtripYBlockAC(const Quantizer & quantizer,const bool error_diffusion,int32_t quant,size_t quant_kind,size_t xsize,size_t ysize,const float * JXL_RESTRICT biases,float * JXL_RESTRICT inout,int32_t * JXL_RESTRICT quantized)146 void QuantizeRoundtripYBlockAC(const Quantizer& quantizer,
147                                const bool error_diffusion, int32_t quant,
148                                size_t quant_kind, size_t xsize, size_t ysize,
149                                const float* JXL_RESTRICT biases,
150                                float* JXL_RESTRICT inout,
151                                int32_t* JXL_RESTRICT quantized) {
152   QuantizeBlockAC(quantizer, error_diffusion, 1, quant, 1.0f, quant_kind, xsize,
153                   ysize, inout, quantized);
154 
155   PROFILER_ZONE("enc quant adjust bias");
156   const float* JXL_RESTRICT dequant_matrix =
157       quantizer.DequantMatrix(quant_kind, 1);
158 
159   HWY_CAPPED(float, kDCTBlockSize) df;
160   HWY_CAPPED(int32_t, kDCTBlockSize) di;
161   const auto inv_qac = Set(df, quantizer.inv_quant_ac(quant));
162   for (size_t k = 0; k < kDCTBlockSize * xsize * ysize; k += Lanes(df)) {
163     const auto quant = Load(di, quantized + k);
164     const auto adj_quant = AdjustQuantBias(di, 1, quant, biases);
165     const auto dequantm = Load(df, dequant_matrix + k);
166     Store(adj_quant * dequantm * inv_qac, df, inout + k);
167   }
168 }
169 
ComputeCoefficients(size_t group_idx,PassesEncoderState * enc_state,const Image3F & opsin,Image3F * dc)170 void ComputeCoefficients(size_t group_idx, PassesEncoderState* enc_state,
171                          const Image3F& opsin, Image3F* dc) {
172   PROFILER_FUNC;
173   const Rect block_group_rect = enc_state->shared.BlockGroupRect(group_idx);
174   const Rect group_rect = enc_state->shared.GroupRect(group_idx);
175   const Rect cmap_rect(
176       block_group_rect.x0() / kColorTileDimInBlocks,
177       block_group_rect.y0() / kColorTileDimInBlocks,
178       DivCeil(block_group_rect.xsize(), kColorTileDimInBlocks),
179       DivCeil(block_group_rect.ysize(), kColorTileDimInBlocks));
180 
181   const size_t xsize_blocks = block_group_rect.xsize();
182   const size_t ysize_blocks = block_group_rect.ysize();
183 
184   const size_t dc_stride = static_cast<size_t>(dc->PixelsPerRow());
185   const size_t opsin_stride = static_cast<size_t>(opsin.PixelsPerRow());
186 
187   const ImageI& full_quant_field = enc_state->shared.raw_quant_field;
188   const CompressParams& cparams = enc_state->cparams;
189 
190   // TODO(veluca): consider strategies to reduce this memory.
191   auto mem = hwy::AllocateAligned<int32_t>(3 * AcStrategy::kMaxCoeffArea);
192   auto fmem = hwy::AllocateAligned<float>(5 * AcStrategy::kMaxCoeffArea);
193   float* JXL_RESTRICT scratch_space =
194       fmem.get() + 3 * AcStrategy::kMaxCoeffArea;
195   {
196     // Only use error diffusion in Squirrel mode or slower.
197     const bool error_diffusion = cparams.speed_tier <= SpeedTier::kSquirrel;
198     constexpr HWY_CAPPED(float, kDCTBlockSize) d;
199 
200     int32_t* JXL_RESTRICT coeffs[kMaxNumPasses][3] = {};
201     size_t num_passes = enc_state->progressive_splitter.GetNumPasses();
202     JXL_DASSERT(num_passes > 0);
203     for (size_t i = 0; i < num_passes; i++) {
204       // TODO(veluca): 16-bit quantized coeffs are not implemented yet.
205       JXL_ASSERT(enc_state->coeffs[i]->Type() == ACType::k32);
206       for (size_t c = 0; c < 3; c++) {
207         coeffs[i][c] = enc_state->coeffs[i]->PlaneRow(c, group_idx, 0).ptr32;
208       }
209     }
210 
211     HWY_ALIGN float* coeffs_in = fmem.get();
212     HWY_ALIGN int32_t* quantized = mem.get();
213 
214     size_t offset = 0;
215 
216     for (size_t by = 0; by < ysize_blocks; ++by) {
217       const int32_t* JXL_RESTRICT row_quant_ac =
218           block_group_rect.ConstRow(full_quant_field, by);
219       size_t ty = by / kColorTileDimInBlocks;
220       const int8_t* JXL_RESTRICT row_cmap[3] = {
221           cmap_rect.ConstRow(enc_state->shared.cmap.ytox_map, ty),
222           nullptr,
223           cmap_rect.ConstRow(enc_state->shared.cmap.ytob_map, ty),
224       };
225       const float* JXL_RESTRICT opsin_rows[3] = {
226           group_rect.ConstPlaneRow(opsin, 0, by * kBlockDim),
227           group_rect.ConstPlaneRow(opsin, 1, by * kBlockDim),
228           group_rect.ConstPlaneRow(opsin, 2, by * kBlockDim),
229       };
230       float* JXL_RESTRICT dc_rows[3] = {
231           block_group_rect.PlaneRow(dc, 0, by),
232           block_group_rect.PlaneRow(dc, 1, by),
233           block_group_rect.PlaneRow(dc, 2, by),
234       };
235       AcStrategyRow ac_strategy_row =
236           enc_state->shared.ac_strategy.ConstRow(block_group_rect, by);
237       for (size_t tx = 0; tx < DivCeil(xsize_blocks, kColorTileDimInBlocks);
238            tx++) {
239         const auto x_factor =
240             Set(d, enc_state->shared.cmap.YtoXRatio(row_cmap[0][tx]));
241         const auto b_factor =
242             Set(d, enc_state->shared.cmap.YtoBRatio(row_cmap[2][tx]));
243         for (size_t bx = tx * kColorTileDimInBlocks;
244              bx < xsize_blocks && bx < (tx + 1) * kColorTileDimInBlocks; ++bx) {
245           const AcStrategy acs = ac_strategy_row[bx];
246           if (!acs.IsFirstBlock()) continue;
247 
248           size_t xblocks = acs.covered_blocks_x();
249           size_t yblocks = acs.covered_blocks_y();
250 
251           CoefficientLayout(&yblocks, &xblocks);
252 
253           size_t size = kDCTBlockSize * xblocks * yblocks;
254 
255           // DCT Y channel, roundtrip-quantize it and set DC.
256           const int32_t quant_ac = row_quant_ac[bx];
257           TransformFromPixels(acs.Strategy(), opsin_rows[1] + bx * kBlockDim,
258                               opsin_stride, coeffs_in + size, scratch_space);
259           DCFromLowestFrequencies(acs.Strategy(), coeffs_in + size,
260                                   dc_rows[1] + bx, dc_stride);
261           QuantizeRoundtripYBlockAC(
262               enc_state->shared.quantizer, error_diffusion, quant_ac,
263               acs.RawStrategy(), xblocks, yblocks, kDefaultQuantBias,
264               coeffs_in + size, quantized + size);
265 
266           // DCT X and B channels
267           for (size_t c : {0, 2}) {
268             TransformFromPixels(acs.Strategy(), opsin_rows[c] + bx * kBlockDim,
269                                 opsin_stride, coeffs_in + c * size,
270                                 scratch_space);
271           }
272 
273           // Unapply color correlation
274           for (size_t k = 0; k < size; k += Lanes(d)) {
275             const auto in_x = Load(d, coeffs_in + k);
276             const auto in_y = Load(d, coeffs_in + size + k);
277             const auto in_b = Load(d, coeffs_in + 2 * size + k);
278             const auto out_x = in_x - x_factor * in_y;
279             const auto out_b = in_b - b_factor * in_y;
280             Store(out_x, d, coeffs_in + k);
281             Store(out_b, d, coeffs_in + 2 * size + k);
282           }
283 
284           // Quantize X and B channels and set DC.
285           for (size_t c : {0, 2}) {
286             QuantizeBlockAC(enc_state->shared.quantizer, error_diffusion, c,
287                             quant_ac,
288                             c == 0 ? enc_state->x_qm_multiplier
289                                    : enc_state->b_qm_multiplier,
290                             acs.RawStrategy(), xblocks, yblocks,
291                             coeffs_in + c * size, quantized + c * size);
292             DCFromLowestFrequencies(acs.Strategy(), coeffs_in + c * size,
293                                     dc_rows[c] + bx, dc_stride);
294           }
295           enc_state->progressive_splitter.SplitACCoefficients(
296               quantized, size, acs, bx, by, offset, coeffs);
297           offset += size;
298         }
299       }
300     }
301   }
302 }
303 
304 // NOLINTNEXTLINE(google-readability-namespace-comments)
305 }  // namespace HWY_NAMESPACE
306 }  // namespace jxl
307 HWY_AFTER_NAMESPACE();
308 
309 #if HWY_ONCE
310 namespace jxl {
311 HWY_EXPORT(ComputeCoefficients);
ComputeCoefficients(size_t group_idx,PassesEncoderState * enc_state,const Image3F & opsin,Image3F * dc)312 void ComputeCoefficients(size_t group_idx, PassesEncoderState* enc_state,
313                          const Image3F& opsin, Image3F* dc) {
314   return HWY_DYNAMIC_DISPATCH(ComputeCoefficients)(group_idx, enc_state, opsin,
315                                                    dc);
316 }
317 
EncodeGroupTokenizedCoefficients(size_t group_idx,size_t pass_idx,size_t histogram_idx,const PassesEncoderState & enc_state,BitWriter * writer,AuxOut * aux_out)318 Status EncodeGroupTokenizedCoefficients(size_t group_idx, size_t pass_idx,
319                                         size_t histogram_idx,
320                                         const PassesEncoderState& enc_state,
321                                         BitWriter* writer, AuxOut* aux_out) {
322   // Select which histogram to use among those of the current pass.
323   const size_t num_histograms = enc_state.shared.num_histograms;
324   // num_histograms is 0 only for lossless.
325   JXL_ASSERT(num_histograms == 0 || histogram_idx < num_histograms);
326   size_t histo_selector_bits = CeilLog2Nonzero(num_histograms);
327 
328   if (histo_selector_bits != 0) {
329     BitWriter::Allotment allotment(writer, histo_selector_bits);
330     writer->Write(histo_selector_bits, histogram_idx);
331     ReclaimAndCharge(writer, &allotment, kLayerAC, aux_out);
332   }
333   WriteTokens(enc_state.passes[pass_idx].ac_tokens[group_idx],
334               enc_state.passes[pass_idx].codes,
335               enc_state.passes[pass_idx].context_map, writer, kLayerACTokens,
336               aux_out);
337 
338   return true;
339 }
340 
341 }  // namespace jxl
342 #endif  // HWY_ONCE
343