1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 #include "lib/jxl/compressed_dc.h"
7 
8 #include <stdint.h>
9 #include <stdlib.h>
10 #include <string.h>
11 
12 #include <algorithm>
13 #include <array>
14 #include <memory>
15 #include <utility>
16 #include <vector>
17 
18 #undef HWY_TARGET_INCLUDE
19 #define HWY_TARGET_INCLUDE "lib/jxl/compressed_dc.cc"
20 #include <hwy/aligned_allocator.h>
21 #include <hwy/foreach_target.h>
22 #include <hwy/highway.h>
23 
24 #include "lib/jxl/ac_strategy.h"
25 #include "lib/jxl/ans_params.h"
26 #include "lib/jxl/aux_out.h"
27 #include "lib/jxl/aux_out_fwd.h"
28 #include "lib/jxl/base/bits.h"
29 #include "lib/jxl/base/compiler_specific.h"
30 #include "lib/jxl/base/data_parallel.h"
31 #include "lib/jxl/base/padded_bytes.h"
32 #include "lib/jxl/base/profiler.h"
33 #include "lib/jxl/base/status.h"
34 #include "lib/jxl/chroma_from_luma.h"
35 #include "lib/jxl/common.h"
36 #include "lib/jxl/dec_ans.h"
37 #include "lib/jxl/dec_bit_reader.h"
38 #include "lib/jxl/dec_cache.h"
39 #include "lib/jxl/entropy_coder.h"
40 #include "lib/jxl/image.h"
41 HWY_BEFORE_NAMESPACE();
42 namespace jxl {
43 namespace HWY_NAMESPACE {
44 
45 using D = HWY_FULL(float);
46 using DScalar = HWY_CAPPED(float, 1);
47 
48 // These templates are not found via ADL.
49 using hwy::HWY_NAMESPACE::Rebind;
50 using hwy::HWY_NAMESPACE::Vec;
51 
// Weights of the 3x3 kernel used by the adaptive DC smoothing:
//   w1 - each of the 4 edge-adjacent neighbours (left/right/top/bottom),
//   w2 - each of the 4 diagonal neighbours,
//   w0 - the centre pixel, derived so all nine weights sum to 1 (up to float
//        rounding).
// TODO(veluca): optimize constants.
constexpr float w1 = 0.20345139757231578f;
constexpr float w2 = 0.0334829185968739f;
constexpr float w0 = 1.0f - 4.0f * (w1 + w2);
56 
// Returns the lane-wise maximum of `a` and `b`.
//
// clang <= 8.0 fails to compile Max() for the AVX3 target ("Do not know how
// to split the result of this operator"). For that combination the maximum is
// expressed as a compare-and-select instead.
//
// HWY_COMPILER_CLANG expands to 0 on non-clang compilers. Without the explicit
// non-zero check, `HWY_COMPILER_CLANG <= 800` would also be true for GCC and
// MSVC, and the workaround would wrongly apply to every AVX3 build with them.
template <class V>
V MaxWorkaround(V a, V b) {
#if (HWY_TARGET == HWY_AVX3) && HWY_COMPILER_CLANG && \
    HWY_COMPILER_CLANG <= 800
  return IfThenElse(a > b, a, b);
#else
  return Max(a, b);
#endif
}
66 
67 template <typename D>
ComputePixelChannel(const D d,const float dc_factor,const float * JXL_RESTRICT row_top,const float * JXL_RESTRICT row,const float * JXL_RESTRICT row_bottom,Vec<D> * JXL_RESTRICT mc,Vec<D> * JXL_RESTRICT sm,Vec<D> * JXL_RESTRICT gap,size_t x)68 JXL_INLINE void ComputePixelChannel(const D d, const float dc_factor,
69                                     const float* JXL_RESTRICT row_top,
70                                     const float* JXL_RESTRICT row,
71                                     const float* JXL_RESTRICT row_bottom,
72                                     Vec<D>* JXL_RESTRICT mc,
73                                     Vec<D>* JXL_RESTRICT sm,
74                                     Vec<D>* JXL_RESTRICT gap, size_t x) {
75   const auto tl = LoadU(d, row_top + x - 1);
76   const auto tc = Load(d, row_top + x);
77   const auto tr = LoadU(d, row_top + x + 1);
78 
79   const auto ml = LoadU(d, row + x - 1);
80   *mc = Load(d, row + x);
81   const auto mr = LoadU(d, row + x + 1);
82 
83   const auto bl = LoadU(d, row_bottom + x - 1);
84   const auto bc = Load(d, row_bottom + x);
85   const auto br = LoadU(d, row_bottom + x + 1);
86 
87   const auto w_center = Set(d, w0);
88   const auto w_side = Set(d, w1);
89   const auto w_corner = Set(d, w2);
90 
91   const auto corner = tl + tr + bl + br;
92   const auto side = ml + mr + tc + bc;
93   *sm = corner * w_corner + side * w_side + *mc * w_center;
94 
95   const auto dc_quant = Set(d, dc_factor);
96   *gap = MaxWorkaround(*gap, Abs((*mc - *sm) / dc_quant));
97 }
98 
99 template <typename D>
ComputePixel(const float * JXL_RESTRICT dc_factors,const float * JXL_RESTRICT * JXL_RESTRICT rows_top,const float * JXL_RESTRICT * JXL_RESTRICT rows,const float * JXL_RESTRICT * JXL_RESTRICT rows_bottom,float * JXL_RESTRICT * JXL_RESTRICT out_rows,size_t x)100 JXL_INLINE void ComputePixel(
101     const float* JXL_RESTRICT dc_factors,
102     const float* JXL_RESTRICT* JXL_RESTRICT rows_top,
103     const float* JXL_RESTRICT* JXL_RESTRICT rows,
104     const float* JXL_RESTRICT* JXL_RESTRICT rows_bottom,
105     float* JXL_RESTRICT* JXL_RESTRICT out_rows, size_t x) {
106   const D d;
107   auto mc_x = Undefined(d);
108   auto mc_y = Undefined(d);
109   auto mc_b = Undefined(d);
110   auto sm_x = Undefined(d);
111   auto sm_y = Undefined(d);
112   auto sm_b = Undefined(d);
113   auto gap = Set(d, 0.5f);
114   ComputePixelChannel(d, dc_factors[0], rows_top[0], rows[0], rows_bottom[0],
115                       &mc_x, &sm_x, &gap, x);
116   ComputePixelChannel(d, dc_factors[1], rows_top[1], rows[1], rows_bottom[1],
117                       &mc_y, &sm_y, &gap, x);
118   ComputePixelChannel(d, dc_factors[2], rows_top[2], rows[2], rows_bottom[2],
119                       &mc_b, &sm_b, &gap, x);
120   auto factor = MulAdd(Set(d, -4.0f), gap, Set(d, 3.0f));
121   factor = ZeroIfNegative(factor);
122 
123   auto out = MulAdd(sm_x - mc_x, factor, mc_x);
124   Store(out, d, out_rows[0] + x);
125   out = MulAdd(sm_y - mc_y, factor, mc_y);
126   Store(out, d, out_rows[1] + x);
127   out = MulAdd(sm_b - mc_b, factor, mc_b);
128   Store(out, d, out_rows[2] + x);
129 }
130 
AdaptiveDCSmoothing(const float * dc_factors,Image3F * dc,ThreadPool * pool)131 void AdaptiveDCSmoothing(const float* dc_factors, Image3F* dc,
132                          ThreadPool* pool) {
133   const size_t xsize = dc->xsize();
134   const size_t ysize = dc->ysize();
135   if (ysize <= 2 || xsize <= 2) return;
136 
137   // TODO(veluca): use tile-based processing?
138   // TODO(veluca): decide if changes to the y channel should be propagated to
139   // the x and b channels through color correlation.
140   JXL_ASSERT(w1 + w2 < 0.25f);
141 
142   PROFILER_FUNC;
143 
144   Image3F smoothed(xsize, ysize);
145   // Fill in borders that the loop below will not. First and last are unused.
146   for (size_t c = 0; c < 3; c++) {
147     for (size_t y : {size_t(0), ysize - 1}) {
148       memcpy(smoothed.PlaneRow(c, y), dc->PlaneRow(c, y),
149              xsize * sizeof(float));
150     }
151   }
152   auto process_row = [&](const uint32_t y, size_t /*thread*/) {
153     const float* JXL_RESTRICT rows_top[3]{
154         dc->ConstPlaneRow(0, y - 1),
155         dc->ConstPlaneRow(1, y - 1),
156         dc->ConstPlaneRow(2, y - 1),
157     };
158     const float* JXL_RESTRICT rows[3] = {
159         dc->ConstPlaneRow(0, y),
160         dc->ConstPlaneRow(1, y),
161         dc->ConstPlaneRow(2, y),
162     };
163     const float* JXL_RESTRICT rows_bottom[3] = {
164         dc->ConstPlaneRow(0, y + 1),
165         dc->ConstPlaneRow(1, y + 1),
166         dc->ConstPlaneRow(2, y + 1),
167     };
168     float* JXL_RESTRICT rows_out[3] = {
169         smoothed.PlaneRow(0, y),
170         smoothed.PlaneRow(1, y),
171         smoothed.PlaneRow(2, y),
172     };
173     for (size_t x : {size_t(0), xsize - 1}) {
174       for (size_t c = 0; c < 3; c++) {
175         rows_out[c][x] = rows[c][x];
176       }
177     }
178 
179     size_t x = 1;
180     // First pixels
181     const size_t N = Lanes(D());
182     for (; x < std::min(N, xsize - 1); x++) {
183       ComputePixel<DScalar>(dc_factors, rows_top, rows, rows_bottom, rows_out,
184                             x);
185     }
186     // Full vectors.
187     for (; x + N <= xsize - 1; x += N) {
188       ComputePixel<D>(dc_factors, rows_top, rows, rows_bottom, rows_out, x);
189     }
190     // Last pixels.
191     for (; x < xsize - 1; x++) {
192       ComputePixel<DScalar>(dc_factors, rows_top, rows, rows_bottom, rows_out,
193                             x);
194     }
195   };
196   JXL_CHECK(RunOnPool(pool, 1, ysize - 1, ThreadPool::NoInit, process_row,
197                       "DCSmoothingRow"));
198   dc->Swap(smoothed);
199 }
200 
201 // DC dequantization.
DequantDC(const Rect & r,Image3F * dc,ImageB * quant_dc,const Image & in,const float * dc_factors,float mul,const float * cfl_factors,YCbCrChromaSubsampling chroma_subsampling,const BlockCtxMap & bctx)202 void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in,
203                const float* dc_factors, float mul, const float* cfl_factors,
204                YCbCrChromaSubsampling chroma_subsampling,
205                const BlockCtxMap& bctx) {
206   const HWY_FULL(float) df;
207   const Rebind<pixel_type, HWY_FULL(float)> di;  // assumes pixel_type <= float
208   if (chroma_subsampling.Is444()) {
209     const auto fac_x = Set(df, dc_factors[0] * mul);
210     const auto fac_y = Set(df, dc_factors[1] * mul);
211     const auto fac_b = Set(df, dc_factors[2] * mul);
212     const auto cfl_fac_x = Set(df, cfl_factors[0]);
213     const auto cfl_fac_b = Set(df, cfl_factors[2]);
214     for (size_t y = 0; y < r.ysize(); y++) {
215       float* dec_row_x = r.PlaneRow(dc, 0, y);
216       float* dec_row_y = r.PlaneRow(dc, 1, y);
217       float* dec_row_b = r.PlaneRow(dc, 2, y);
218       const int32_t* quant_row_x = in.channel[1].plane.Row(y);
219       const int32_t* quant_row_y = in.channel[0].plane.Row(y);
220       const int32_t* quant_row_b = in.channel[2].plane.Row(y);
221       for (size_t x = 0; x < r.xsize(); x += Lanes(di)) {
222         const auto in_q_x = Load(di, quant_row_x + x);
223         const auto in_q_y = Load(di, quant_row_y + x);
224         const auto in_q_b = Load(di, quant_row_b + x);
225         const auto in_x = ConvertTo(df, in_q_x) * fac_x;
226         const auto in_y = ConvertTo(df, in_q_y) * fac_y;
227         const auto in_b = ConvertTo(df, in_q_b) * fac_b;
228         Store(in_y, df, dec_row_y + x);
229         Store(MulAdd(in_y, cfl_fac_x, in_x), df, dec_row_x + x);
230         Store(MulAdd(in_y, cfl_fac_b, in_b), df, dec_row_b + x);
231       }
232     }
233   } else {
234     for (size_t c : {1, 0, 2}) {
235       Rect rect(r.x0() >> chroma_subsampling.HShift(c),
236                 r.y0() >> chroma_subsampling.VShift(c),
237                 r.xsize() >> chroma_subsampling.HShift(c),
238                 r.ysize() >> chroma_subsampling.VShift(c));
239       const auto fac = Set(df, dc_factors[c] * mul);
240       const Channel& ch = in.channel[c < 2 ? c ^ 1 : c];
241       for (size_t y = 0; y < rect.ysize(); y++) {
242         const int32_t* quant_row = ch.plane.Row(y);
243         float* row = rect.PlaneRow(dc, c, y);
244         for (size_t x = 0; x < rect.xsize(); x += Lanes(di)) {
245           const auto in_q = Load(di, quant_row + x);
246           const auto in = ConvertTo(df, in_q) * fac;
247           Store(in, df, row + x);
248         }
249       }
250     }
251   }
252   if (bctx.num_dc_ctxs <= 1) {
253     for (size_t y = 0; y < r.ysize(); y++) {
254       uint8_t* qdc_row = r.Row(quant_dc, y);
255       memset(qdc_row, 0, sizeof(*qdc_row) * r.xsize());
256     }
257   } else {
258     for (size_t y = 0; y < r.ysize(); y++) {
259       uint8_t* qdc_row_val = r.Row(quant_dc, y);
260       const int32_t* quant_row_x =
261           in.channel[1].plane.Row(y >> chroma_subsampling.VShift(0));
262       const int32_t* quant_row_y =
263           in.channel[0].plane.Row(y >> chroma_subsampling.VShift(1));
264       const int32_t* quant_row_b =
265           in.channel[2].plane.Row(y >> chroma_subsampling.VShift(2));
266       for (size_t x = 0; x < r.xsize(); x++) {
267         int bucket_x = 0, bucket_y = 0, bucket_b = 0;
268         for (int t : bctx.dc_thresholds[0]) {
269           if (quant_row_x[x >> chroma_subsampling.HShift(0)] > t) bucket_x++;
270         }
271         for (int t : bctx.dc_thresholds[1]) {
272           if (quant_row_y[x >> chroma_subsampling.HShift(1)] > t) bucket_y++;
273         }
274         for (int t : bctx.dc_thresholds[2]) {
275           if (quant_row_b[x >> chroma_subsampling.HShift(2)] > t) bucket_b++;
276         }
277         int bucket = bucket_x;
278         bucket *= bctx.dc_thresholds[2].size() + 1;
279         bucket += bucket_b;
280         bucket *= bctx.dc_thresholds[1].size() + 1;
281         bucket += bucket_y;
282         qdc_row_val[x] = bucket;
283       }
284     }
285   }
286 }
287 
288 // NOLINTNEXTLINE(google-readability-namespace-comments)
289 }  // namespace HWY_NAMESPACE
290 }  // namespace jxl
291 HWY_AFTER_NAMESPACE();
292 
293 #if HWY_ONCE
294 namespace jxl {
295 
296 HWY_EXPORT(DequantDC);
297 HWY_EXPORT(AdaptiveDCSmoothing);
AdaptiveDCSmoothing(const float * dc_factors,Image3F * dc,ThreadPool * pool)298 void AdaptiveDCSmoothing(const float* dc_factors, Image3F* dc,
299                          ThreadPool* pool) {
300   return HWY_DYNAMIC_DISPATCH(AdaptiveDCSmoothing)(dc_factors, dc, pool);
301 }
302 
DequantDC(const Rect & r,Image3F * dc,ImageB * quant_dc,const Image & in,const float * dc_factors,float mul,const float * cfl_factors,YCbCrChromaSubsampling chroma_subsampling,const BlockCtxMap & bctx)303 void DequantDC(const Rect& r, Image3F* dc, ImageB* quant_dc, const Image& in,
304                const float* dc_factors, float mul, const float* cfl_factors,
305                YCbCrChromaSubsampling chroma_subsampling,
306                const BlockCtxMap& bctx) {
307   return HWY_DYNAMIC_DISPATCH(DequantDC)(r, dc, quant_dc, in, dc_factors, mul,
308                                          cfl_factors, chroma_subsampling, bctx);
309 }
310 
311 }  // namespace jxl
312 #endif  // HWY_ONCE
313