1 // Copyright (c) the JPEG XL Project Authors. All rights reserved.
2 //
3 // Use of this source code is governed by a BSD-style
4 // license that can be found in the LICENSE file.
5 
6 // Fast SIMD floating-point (I)DCT, any power of two.
7 
8 #if defined(LIB_JXL_DCT_INL_H_) == defined(HWY_TARGET_TOGGLE)
9 #ifdef LIB_JXL_DCT_INL_H_
10 #undef LIB_JXL_DCT_INL_H_
11 #else
12 #define LIB_JXL_DCT_INL_H_
13 #endif
14 
15 #include <stddef.h>
16 
17 #include <hwy/highway.h>
18 
19 #include "lib/jxl/dct_block-inl.h"
20 #include "lib/jxl/dct_scales.h"
21 #include "lib/jxl/transpose-inl.h"
22 HWY_BEFORE_NAMESPACE();
23 namespace jxl {
24 namespace HWY_NAMESPACE {
25 namespace {
26 
// Selects the float vector descriptor used for a given column count:
// capped to SZ lanes when SZ is a nonzero compile-time constant.
template <size_t SZ>
struct FVImpl {
  using type = HWY_CAPPED(float, SZ);
};

// SZ == 0 means "size only known at runtime": use the full native vector.
template <>
struct FVImpl<0> {
  using type = HWY_FULL(float);
};

// FV<SZ> is the float vector descriptor type used throughout this file.
template <size_t SZ>
using FV = typename FVImpl<SZ>::type;
39 
40 // Implementation of Lowest Complexity Self Recursive Radix-2 DCT II/III
// Algorithms, by Sirani M. Perera and Jianhua Liu.
42 
43 template <size_t N, size_t SZ>
44 struct CoeffBundle {
45   static void AddReverse(const float* JXL_RESTRICT ain1,
46                          const float* JXL_RESTRICT ain2,
47                          float* JXL_RESTRICT aout) {
48     for (size_t i = 0; i < N; i++) {
49       auto in1 = Load(FV<SZ>(), ain1 + i * SZ);
50       auto in2 = Load(FV<SZ>(), ain2 + (N - i - 1) * SZ);
51       Store(in1 + in2, FV<SZ>(), aout + i * SZ);
52     }
53   }
54   static void SubReverse(const float* JXL_RESTRICT ain1,
55                          const float* JXL_RESTRICT ain2,
56                          float* JXL_RESTRICT aout) {
57     for (size_t i = 0; i < N; i++) {
58       auto in1 = Load(FV<SZ>(), ain1 + i * SZ);
59       auto in2 = Load(FV<SZ>(), ain2 + (N - i - 1) * SZ);
60       Store(in1 - in2, FV<SZ>(), aout + i * SZ);
61     }
62   }
63   static void B(float* JXL_RESTRICT coeff) {
64     auto sqrt2 = Set(FV<SZ>(), kSqrt2);
65     auto in1 = Load(FV<SZ>(), coeff);
66     auto in2 = Load(FV<SZ>(), coeff + SZ);
67     Store(MulAdd(in1, sqrt2, in2), FV<SZ>(), coeff);
68     for (size_t i = 1; i + 1 < N; i++) {
69       auto in1 = Load(FV<SZ>(), coeff + i * SZ);
70       auto in2 = Load(FV<SZ>(), coeff + (i + 1) * SZ);
71       Store(in1 + in2, FV<SZ>(), coeff + i * SZ);
72     }
73   }
74   static void BTranspose(float* JXL_RESTRICT coeff) {
75     for (size_t i = N - 1; i > 0; i--) {
76       auto in1 = Load(FV<SZ>(), coeff + i * SZ);
77       auto in2 = Load(FV<SZ>(), coeff + (i - 1) * SZ);
78       Store(in1 + in2, FV<SZ>(), coeff + i * SZ);
79     }
80     auto sqrt2 = Set(FV<SZ>(), kSqrt2);
81     auto in1 = Load(FV<SZ>(), coeff);
82     Store(in1 * sqrt2, FV<SZ>(), coeff);
83   }
84   // Ideally optimized away by compiler (except the multiply).
85   static void InverseEvenOdd(const float* JXL_RESTRICT ain,
86                              float* JXL_RESTRICT aout) {
87     for (size_t i = 0; i < N / 2; i++) {
88       auto in1 = Load(FV<SZ>(), ain + i * SZ);
89       Store(in1, FV<SZ>(), aout + 2 * i * SZ);
90     }
91     for (size_t i = N / 2; i < N; i++) {
92       auto in1 = Load(FV<SZ>(), ain + i * SZ);
93       Store(in1, FV<SZ>(), aout + (2 * (i - N / 2) + 1) * SZ);
94     }
95   }
96   // Ideally optimized away by compiler.
97   static void ForwardEvenOdd(const float* JXL_RESTRICT ain, size_t ain_stride,
98                              float* JXL_RESTRICT aout) {
99     for (size_t i = 0; i < N / 2; i++) {
100       auto in1 = LoadU(FV<SZ>(), ain + 2 * i * ain_stride);
101       Store(in1, FV<SZ>(), aout + i * SZ);
102     }
103     for (size_t i = N / 2; i < N; i++) {
104       auto in1 = LoadU(FV<SZ>(), ain + (2 * (i - N / 2) + 1) * ain_stride);
105       Store(in1, FV<SZ>(), aout + i * SZ);
106     }
107   }
108   // Invoked on full vector.
109   static void Multiply(float* JXL_RESTRICT coeff) {
110     for (size_t i = 0; i < N / 2; i++) {
111       auto in1 = Load(FV<SZ>(), coeff + (N / 2 + i) * SZ);
112       auto mul = Set(FV<SZ>(), WcMultipliers<N>::kMultipliers[i]);
113       Store(in1 * mul, FV<SZ>(), coeff + (N / 2 + i) * SZ);
114     }
115   }
116   static void MultiplyAndAdd(const float* JXL_RESTRICT coeff,
117                              float* JXL_RESTRICT out, size_t out_stride) {
118     for (size_t i = 0; i < N / 2; i++) {
119       auto mul = Set(FV<SZ>(), WcMultipliers<N>::kMultipliers[i]);
120       auto in1 = Load(FV<SZ>(), coeff + i * SZ);
121       auto in2 = Load(FV<SZ>(), coeff + (N / 2 + i) * SZ);
122       auto out1 = MulAdd(mul, in2, in1);
123       auto out2 = NegMulAdd(mul, in2, in1);
124       StoreU(out1, FV<SZ>(), out + i * out_stride);
125       StoreU(out2, FV<SZ>(), out + (N - i - 1) * out_stride);
126     }
127   }
128   template <typename Block>
129   static void LoadFromBlock(const Block& in, size_t off,
130                             float* JXL_RESTRICT coeff) {
131     for (size_t i = 0; i < N; i++) {
132       Store(in.LoadPart(FV<SZ>(), i, off), FV<SZ>(), coeff + i * SZ);
133     }
134   }
135   template <typename Block>
136   static void StoreToBlockAndScale(const float* JXL_RESTRICT coeff,
137                                    const Block& out, size_t off) {
138     auto mul = Set(FV<SZ>(), 1.0f / N);
139     for (size_t i = 0; i < N; i++) {
140       out.StorePart(FV<SZ>(), mul * Load(FV<SZ>(), coeff + i * SZ), i, off);
141     }
142   }
143 };
144 
// In-place 1D DCT-II over N rows of SZ-lane vectors stored contiguously.
template <size_t N, size_t SZ>
struct DCT1DImpl;

// Base case: a 1-point DCT is the identity, so there is nothing to do.
template <size_t SZ>
struct DCT1DImpl<1, SZ> {
  JXL_INLINE void operator()(float* JXL_RESTRICT mem) {}
};
152 
153 template <size_t SZ>
154 struct DCT1DImpl<2, SZ> {
155   JXL_INLINE void operator()(float* JXL_RESTRICT mem) {
156     auto in1 = Load(FV<SZ>(), mem);
157     auto in2 = Load(FV<SZ>(), mem + SZ);
158     Store(in1 + in2, FV<SZ>(), mem);
159     Store(in1 - in2, FV<SZ>(), mem + SZ);
160   }
161 };
162 
// General case: self-recursive radix-2 decomposition of an N-point DCT-II
// into two N/2-point DCTs (Perera & Liu), operating on SZ lanes at once.
template <size_t N, size_t SZ>
struct DCT1DImpl {
  void operator()(float* JXL_RESTRICT mem) {
    // This is relatively small (4kB with 64-DCT and AVX-512)
    HWY_ALIGN float tmp[N * SZ];
    // First half of tmp: sums of mirrored input pairs -> recursive DCT.
    CoeffBundle<N / 2, SZ>::AddReverse(mem, mem + N / 2 * SZ, tmp);
    DCT1DImpl<N / 2, SZ>()(tmp);
    // Second half: differences, scaled by per-frequency multipliers,
    // then a recursive DCT followed by the B matrix.
    CoeffBundle<N / 2, SZ>::SubReverse(mem, mem + N / 2 * SZ, tmp + N / 2 * SZ);
    CoeffBundle<N, SZ>::Multiply(tmp);
    DCT1DImpl<N / 2, SZ>()(tmp + N / 2 * SZ);
    CoeffBundle<N / 2, SZ>::B(tmp + N / 2 * SZ);
    // Even output indices come from the first half, odd from the second.
    CoeffBundle<N, SZ>::InverseEvenOdd(tmp, mem);
  }
};
177 
// 1D IDCT (DCT-III) of N rows: reads rows spaced `from_stride` floats
// apart and writes rows spaced `to_stride` floats apart.
template <size_t N, size_t SZ>
struct IDCT1DImpl;

// Base case: a 1-point IDCT is a plain copy; with a single row the stride
// parameters are irrelevant.
template <size_t SZ>
struct IDCT1DImpl<1, SZ> {
  JXL_INLINE void operator()(const float* from, size_t from_stride, float* to,
                             size_t to_stride) {
    StoreU(LoadU(FV<SZ>(), from), FV<SZ>(), to);
  }
};
188 
189 template <size_t SZ>
190 struct IDCT1DImpl<2, SZ> {
191   JXL_INLINE void operator()(const float* from, size_t from_stride, float* to,
192                              size_t to_stride) {
193     JXL_DASSERT(from_stride >= SZ);
194     JXL_DASSERT(to_stride >= SZ);
195     auto in1 = LoadU(FV<SZ>(), from);
196     auto in2 = LoadU(FV<SZ>(), from + from_stride);
197     StoreU(in1 + in2, FV<SZ>(), to);
198     StoreU(in1 - in2, FV<SZ>(), to + to_stride);
199   }
200 };
201 
// General case: radix-2 recursion mirroring DCT1DImpl, reversing its steps.
template <size_t N, size_t SZ>
struct IDCT1DImpl {
  void operator()(const float* from, size_t from_stride, float* to,
                  size_t to_stride) {
    JXL_DASSERT(from_stride >= SZ);
    JXL_DASSERT(to_stride >= SZ);
    // This is relatively small (4kB with 64-DCT and AVX-512)
    HWY_ALIGN float tmp[N * SZ];
    // De-interleave even/odd coefficients into the two halves of tmp.
    CoeffBundle<N, SZ>::ForwardEvenOdd(from, from_stride, tmp);
    // Recursive IDCT on the even half; B^T then recursive IDCT on the odd
    // half (transposed order relative to the forward transform).
    IDCT1DImpl<N / 2, SZ>()(tmp, SZ, tmp, SZ);
    CoeffBundle<N / 2, SZ>::BTranspose(tmp + N / 2 * SZ);
    IDCT1DImpl<N / 2, SZ>()(tmp + N / 2 * SZ, SZ, tmp + N / 2 * SZ, SZ);
    // Final butterfly with the per-frequency multipliers writes the output.
    CoeffBundle<N, SZ>::MultiplyAndAdd(tmp, to, to_stride);
  }
};
217 
218 template <size_t N, size_t M_or_0, typename FromBlock, typename ToBlock>
219 void DCT1DWrapper(const FromBlock& from, const ToBlock& to, size_t Mp) {
220   size_t M = M_or_0 != 0 ? M_or_0 : Mp;
221   constexpr size_t SZ = MaxLanes(FV<M_or_0>());
222   HWY_ALIGN float tmp[N * SZ];
223   for (size_t i = 0; i < M; i += Lanes(FV<M_or_0>())) {
224     // TODO(veluca): consider removing the temporary memory here (as is done in
225     // IDCT), if it turns out that some compilers don't optimize away the loads
226     // and this is performance-critical.
227     CoeffBundle<N, SZ>::LoadFromBlock(from, i, tmp);
228     DCT1DImpl<N, SZ>()(tmp);
229     CoeffBundle<N, SZ>::StoreToBlockAndScale(tmp, to, i);
230   }
231 }
232 
233 template <size_t N, size_t M_or_0, typename FromBlock, typename ToBlock>
234 void IDCT1DWrapper(const FromBlock& from, const ToBlock& to, size_t Mp) {
235   size_t M = M_or_0 != 0 ? M_or_0 : Mp;
236   constexpr size_t SZ = MaxLanes(FV<M_or_0>());
237   for (size_t i = 0; i < M; i += Lanes(FV<M_or_0>())) {
238     IDCT1DImpl<N, SZ>()(from.Address(0, i), from.Stride(), to.Address(0, i),
239                         to.Stride());
240   }
241 }
242 
// Dispatcher for the forward 1D DCT over an N x M block. Default variant
// (M no larger than the native vector width): pass M as a compile-time
// constant so the whole transform can be inlined.
template <size_t N, size_t M, typename = void>
struct DCT1D {
  template <typename FromBlock, typename ToBlock>
  void operator()(const FromBlock& from, const ToBlock& to) {
    return DCT1DWrapper<N, M>(from, to, M);
  }
};
250 
// Wide variant, selected when M exceeds the native vector width: instantiate
// the wrapper with M_or_0 == 0 (runtime M, full vectors) and call it through
// NoInlineWrapper to limit code size.
template <size_t N, size_t M>
struct DCT1D<N, M, typename std::enable_if<(M > MaxLanes(FV<0>()))>::type> {
  template <typename FromBlock, typename ToBlock>
  void operator()(const FromBlock& from, const ToBlock& to) {
    return NoInlineWrapper(DCT1DWrapper<N, 0, FromBlock, ToBlock>, from, to, M);
  }
};
258 
// Dispatcher for the inverse 1D DCT over an N x M block. Default variant
// (M no larger than the native vector width): M is a compile-time constant.
template <size_t N, size_t M, typename = void>
struct IDCT1D {
  template <typename FromBlock, typename ToBlock>
  void operator()(const FromBlock& from, const ToBlock& to) {
    return IDCT1DWrapper<N, M>(from, to, M);
  }
};
266 
// Wide variant, selected when M exceeds the native vector width: runtime M
// with full vectors, called out-of-line to limit code size.
template <size_t N, size_t M>
struct IDCT1D<N, M, typename std::enable_if<(M > MaxLanes(FV<0>()))>::type> {
  template <typename FromBlock, typename ToBlock>
  void operator()(const FromBlock& from, const ToBlock& to) {
    return NoInlineWrapper(IDCT1DWrapper<N, 0, FromBlock, ToBlock>, from, to,
                           M);
  }
};
275 
276 // Computes the maybe-transposed, scaled DCT of a block, that needs to be
277 // HWY_ALIGN'ed.
template <size_t ROWS, size_t COLS>
struct ComputeScaledDCT {
  // scratch_space must be aligned, and should have space for ROWS*COLS
  // floats.
  template <class From>
  HWY_MAYBE_UNUSED void operator()(const From& from, float* to,
                                   float* JXL_RESTRICT scratch_space) {
    float* JXL_RESTRICT block = scratch_space;
    if (ROWS < COLS) {
      // Column DCT, transpose, row DCT, transpose back: `to` ends up in
      // natural ROWS x COLS order.
      DCT1D<ROWS, COLS>()(from, DCTTo(block, COLS));
      Transpose<ROWS, COLS>::Run(DCTFrom(block, COLS), DCTTo(to, ROWS));
      DCT1D<COLS, ROWS>()(DCTFrom(to, ROWS), DCTTo(block, ROWS));
      Transpose<COLS, ROWS>::Run(DCTFrom(block, ROWS), DCTTo(to, COLS));
    } else {
      // ROWS >= COLS: the final transpose is skipped, so `to` is left in
      // transposed (COLS x ROWS) layout -- hence "maybe-transposed".
      DCT1D<ROWS, COLS>()(from, DCTTo(to, COLS));
      Transpose<ROWS, COLS>::Run(DCTFrom(to, COLS), DCTTo(block, ROWS));
      DCT1D<COLS, ROWS>()(DCTFrom(block, ROWS), DCTTo(to, ROWS));
    }
  }
};
298 // Computes the maybe-transposed, scaled IDCT of a block, that needs to be
299 // HWY_ALIGN'ed.
template <size_t ROWS, size_t COLS>
struct ComputeScaledIDCT {
  // scratch_space must be aligned, and should have space for ROWS*COLS
  // floats.
  // NOTE: `from` is overwritten -- it doubles as additional scratch.
  template <class To>
  HWY_MAYBE_UNUSED void operator()(float* JXL_RESTRICT from, const To& to,
                                   float* JXL_RESTRICT scratch_space) {
    float* JXL_RESTRICT block = scratch_space;
    // Reverse the steps done in ComputeScaledDCT.
    if (ROWS < COLS) {
      // Forward path produced natural ROWS x COLS order; undo both
      // transposes while inverting each 1D transform.
      Transpose<ROWS, COLS>::Run(DCTFrom(from, COLS), DCTTo(block, ROWS));
      IDCT1D<COLS, ROWS>()(DCTFrom(block, ROWS), DCTTo(from, ROWS));
      Transpose<COLS, ROWS>::Run(DCTFrom(from, ROWS), DCTTo(block, COLS));
      IDCT1D<ROWS, COLS>()(DCTFrom(block, COLS), to);
    } else {
      // Input is in transposed (COLS x ROWS) layout, matching the forward
      // transform's output for this case.
      IDCT1D<COLS, ROWS>()(DCTFrom(from, ROWS), DCTTo(block, ROWS));
      Transpose<COLS, ROWS>::Run(DCTFrom(block, ROWS), DCTTo(from, COLS));
      IDCT1D<ROWS, COLS>()(DCTFrom(from, COLS), to);
    }
  }
};
321 
322 }  // namespace
323 // NOLINTNEXTLINE(google-readability-namespace-comments)
324 }  // namespace HWY_NAMESPACE
325 }  // namespace jxl
326 HWY_AFTER_NAMESPACE();
327 #endif  // LIB_JXL_DCT_INL_H_
328