1 // Copyright 2021 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/inverse_transform.h"
16 #include "src/utils/cpu.h"
17 
18 #if LIBGAV1_ENABLE_NEON && LIBGAV1_MAX_BITDEPTH >= 10
19 
#include <arm_neon.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>
25 
26 #include "src/dsp/arm/common_neon.h"
27 #include "src/dsp/constants.h"
28 #include "src/dsp/dsp.h"
29 #include "src/utils/array_2d.h"
30 #include "src/utils/common.h"
31 #include "src/utils/compiler_attributes.h"
32 #include "src/utils/constants.h"
33 
34 namespace libgav1 {
35 namespace dsp {
36 namespace {
37 
38 // Include the constants and utility functions inside the anonymous namespace.
39 #include "src/dsp/inverse_transform.inc"
40 
41 //------------------------------------------------------------------------------
42 
Transpose4x4(const int32x4_t in[4],int32x4_t out[4])43 LIBGAV1_ALWAYS_INLINE void Transpose4x4(const int32x4_t in[4],
44                                         int32x4_t out[4]) {
45   // in:
46   // 00 01 02 03
47   // 10 11 12 13
48   // 20 21 22 23
49   // 30 31 32 33
50 
51   // 00 10 02 12   a.val[0]
52   // 01 11 03 13   a.val[1]
53   // 20 30 22 32   b.val[0]
54   // 21 31 23 33   b.val[1]
55   const int32x4x2_t a = vtrnq_s32(in[0], in[1]);
56   const int32x4x2_t b = vtrnq_s32(in[2], in[3]);
57   out[0] = vextq_s32(vextq_s32(a.val[0], a.val[0], 2), b.val[0], 2);
58   out[1] = vextq_s32(vextq_s32(a.val[1], a.val[1], 2), b.val[1], 2);
59   out[2] = vextq_s32(a.val[0], vextq_s32(b.val[0], b.val[0], 2), 2);
60   out[3] = vextq_s32(a.val[1], vextq_s32(b.val[1], b.val[1], 2), 2);
61   // out:
62   // 00 10 20 30
63   // 01 11 21 31
64   // 02 12 22 32
65   // 03 13 23 33
66 }
67 
68 //------------------------------------------------------------------------------
69 template <int store_count>
StoreDst(int32_t * LIBGAV1_RESTRICT dst,int32_t stride,int32_t idx,const int32x4_t * const s)70 LIBGAV1_ALWAYS_INLINE void StoreDst(int32_t* LIBGAV1_RESTRICT dst,
71                                     int32_t stride, int32_t idx,
72                                     const int32x4_t* const s) {
73   assert(store_count % 4 == 0);
74   for (int i = 0; i < store_count; i += 4) {
75     vst1q_s32(&dst[i * stride + idx], s[i]);
76     vst1q_s32(&dst[(i + 1) * stride + idx], s[i + 1]);
77     vst1q_s32(&dst[(i + 2) * stride + idx], s[i + 2]);
78     vst1q_s32(&dst[(i + 3) * stride + idx], s[i + 3]);
79   }
80 }
81 
82 template <int load_count>
LoadSrc(const int32_t * LIBGAV1_RESTRICT src,int32_t stride,int32_t idx,int32x4_t * x)83 LIBGAV1_ALWAYS_INLINE void LoadSrc(const int32_t* LIBGAV1_RESTRICT src,
84                                    int32_t stride, int32_t idx, int32x4_t* x) {
85   assert(load_count % 4 == 0);
86   for (int i = 0; i < load_count; i += 4) {
87     x[i] = vld1q_s32(&src[i * stride + idx]);
88     x[i + 1] = vld1q_s32(&src[(i + 1) * stride + idx]);
89     x[i + 2] = vld1q_s32(&src[(i + 2) * stride + idx]);
90     x[i + 3] = vld1q_s32(&src[(i + 3) * stride + idx]);
91   }
92 }
93 
94 // Butterfly rotate 4 values.
ButterflyRotation_4(int32x4_t * a,int32x4_t * b,const int angle,const bool flip)95 LIBGAV1_ALWAYS_INLINE void ButterflyRotation_4(int32x4_t* a, int32x4_t* b,
96                                                const int angle,
97                                                const bool flip) {
98   const int32_t cos128 = Cos128(angle);
99   const int32_t sin128 = Sin128(angle);
100   const int32x4_t acc_x = vmulq_n_s32(*a, cos128);
101   const int32x4_t acc_y = vmulq_n_s32(*a, sin128);
102   // The max range for the input is 18 bits. The cos128/sin128 is 13 bits,
103   // which leaves 1 bit for the add/subtract. For 10bpp, x/y will fit in a 32
104   // bit lane.
105   const int32x4_t x0 = vmlsq_n_s32(acc_x, *b, sin128);
106   const int32x4_t y0 = vmlaq_n_s32(acc_y, *b, cos128);
107   const int32x4_t x = vrshrq_n_s32(x0, 12);
108   const int32x4_t y = vrshrq_n_s32(y0, 12);
109   if (flip) {
110     *a = y;
111     *b = x;
112   } else {
113     *a = x;
114     *b = y;
115   }
116 }
117 
ButterflyRotation_FirstIsZero(int32x4_t * a,int32x4_t * b,const int angle,const bool flip)118 LIBGAV1_ALWAYS_INLINE void ButterflyRotation_FirstIsZero(int32x4_t* a,
119                                                          int32x4_t* b,
120                                                          const int angle,
121                                                          const bool flip) {
122   const int32_t cos128 = Cos128(angle);
123   const int32_t sin128 = Sin128(angle);
124   assert(sin128 <= 0xfff);
125   const int32x4_t x0 = vmulq_n_s32(*b, -sin128);
126   const int32x4_t y0 = vmulq_n_s32(*b, cos128);
127   const int32x4_t x = vrshrq_n_s32(x0, 12);
128   const int32x4_t y = vrshrq_n_s32(y0, 12);
129   if (flip) {
130     *a = y;
131     *b = x;
132   } else {
133     *a = x;
134     *b = y;
135   }
136 }
137 
ButterflyRotation_SecondIsZero(int32x4_t * a,int32x4_t * b,const int angle,const bool flip)138 LIBGAV1_ALWAYS_INLINE void ButterflyRotation_SecondIsZero(int32x4_t* a,
139                                                           int32x4_t* b,
140                                                           const int angle,
141                                                           const bool flip) {
142   const int32_t cos128 = Cos128(angle);
143   const int32_t sin128 = Sin128(angle);
144   const int32x4_t x0 = vmulq_n_s32(*a, cos128);
145   const int32x4_t y0 = vmulq_n_s32(*a, sin128);
146   const int32x4_t x = vrshrq_n_s32(x0, 12);
147   const int32x4_t y = vrshrq_n_s32(y0, 12);
148   if (flip) {
149     *a = y;
150     *b = x;
151   } else {
152     *a = x;
153     *b = y;
154   }
155 }
156 
HadamardRotation(int32x4_t * a,int32x4_t * b,bool flip)157 LIBGAV1_ALWAYS_INLINE void HadamardRotation(int32x4_t* a, int32x4_t* b,
158                                             bool flip) {
159   int32x4_t x, y;
160   if (flip) {
161     y = vqaddq_s32(*b, *a);
162     x = vqsubq_s32(*b, *a);
163   } else {
164     x = vqaddq_s32(*a, *b);
165     y = vqsubq_s32(*a, *b);
166   }
167   *a = x;
168   *b = y;
169 }
170 
HadamardRotation(int32x4_t * a,int32x4_t * b,bool flip,const int32x4_t min,const int32x4_t max)171 LIBGAV1_ALWAYS_INLINE void HadamardRotation(int32x4_t* a, int32x4_t* b,
172                                             bool flip, const int32x4_t min,
173                                             const int32x4_t max) {
174   int32x4_t x, y;
175   if (flip) {
176     y = vqaddq_s32(*b, *a);
177     x = vqsubq_s32(*b, *a);
178   } else {
179     x = vqaddq_s32(*a, *b);
180     y = vqsubq_s32(*a, *b);
181   }
182   *a = vmaxq_s32(vminq_s32(x, max), min);
183   *b = vmaxq_s32(vminq_s32(y, max), min);
184 }
185 
186 using ButterflyRotationFunc = void (*)(int32x4_t* a, int32x4_t* b, int angle,
187                                        bool flip);
188 
189 //------------------------------------------------------------------------------
190 // Discrete Cosine Transforms (DCT).
191 
192 template <int width>
DctDcOnly(void * dest,int adjusted_tx_height,bool should_round,int row_shift)193 LIBGAV1_ALWAYS_INLINE bool DctDcOnly(void* dest, int adjusted_tx_height,
194                                      bool should_round, int row_shift) {
195   if (adjusted_tx_height > 1) return false;
196 
197   auto* dst = static_cast<int32_t*>(dest);
198   const int32x4_t v_src = vdupq_n_s32(dst[0]);
199   const uint32x4_t v_mask = vdupq_n_u32(should_round ? 0xffffffff : 0);
200   const int32x4_t v_src_round =
201       vqrdmulhq_n_s32(v_src, kTransformRowMultiplier << (31 - 12));
202   const int32x4_t s0 = vbslq_s32(v_mask, v_src_round, v_src);
203   const int32_t cos128 = Cos128(32);
204   const int32x4_t xy = vqrdmulhq_n_s32(s0, cos128 << (31 - 12));
205   // vqrshlq_s32 will shift right if shift value is negative.
206   const int32x4_t xy_shifted = vqrshlq_s32(xy, vdupq_n_s32(-row_shift));
207   // Clamp result to signed 16 bits.
208   const int32x4_t result = vmovl_s16(vqmovn_s32(xy_shifted));
209   if (width == 4) {
210     vst1q_s32(dst, result);
211   } else {
212     for (int i = 0; i < width; i += 4) {
213       vst1q_s32(dst, result);
214       dst += 4;
215     }
216   }
217   return true;
218 }
219 
220 template <int height>
DctDcOnlyColumn(void * dest,int adjusted_tx_height,int width)221 LIBGAV1_ALWAYS_INLINE bool DctDcOnlyColumn(void* dest, int adjusted_tx_height,
222                                            int width) {
223   if (adjusted_tx_height > 1) return false;
224 
225   auto* dst = static_cast<int32_t*>(dest);
226   const int32_t cos128 = Cos128(32);
227 
228   // Calculate dc values for first row.
229   if (width == 4) {
230     const int32x4_t v_src = vld1q_s32(dst);
231     const int32x4_t xy = vqrdmulhq_n_s32(v_src, cos128 << (31 - 12));
232     vst1q_s32(dst, xy);
233   } else {
234     int i = 0;
235     do {
236       const int32x4_t v_src = vld1q_s32(&dst[i]);
237       const int32x4_t xy = vqrdmulhq_n_s32(v_src, cos128 << (31 - 12));
238       vst1q_s32(&dst[i], xy);
239       i += 4;
240     } while (i < width);
241   }
242 
243   // Copy first row to the rest of the block.
244   for (int y = 1; y < height; ++y) {
245     memcpy(&dst[y * width], dst, width * sizeof(dst[0]));
246   }
247   return true;
248 }
249 
250 template <ButterflyRotationFunc butterfly_rotation,
251           bool is_fast_butterfly = false>
Dct4Stages(int32x4_t * s,const int32x4_t min,const int32x4_t max,const bool is_last_stage)252 LIBGAV1_ALWAYS_INLINE void Dct4Stages(int32x4_t* s, const int32x4_t min,
253                                       const int32x4_t max,
254                                       const bool is_last_stage) {
255   // stage 12.
256   if (is_fast_butterfly) {
257     ButterflyRotation_SecondIsZero(&s[0], &s[1], 32, true);
258     ButterflyRotation_SecondIsZero(&s[2], &s[3], 48, false);
259   } else {
260     butterfly_rotation(&s[0], &s[1], 32, true);
261     butterfly_rotation(&s[2], &s[3], 48, false);
262   }
263 
264   // stage 17.
265   if (is_last_stage) {
266     HadamardRotation(&s[0], &s[3], false);
267     HadamardRotation(&s[1], &s[2], false);
268   } else {
269     HadamardRotation(&s[0], &s[3], false, min, max);
270     HadamardRotation(&s[1], &s[2], false, min, max);
271   }
272 }
273 
274 template <ButterflyRotationFunc butterfly_rotation>
Dct4_NEON(void * dest,int32_t step,bool is_row,int row_shift)275 LIBGAV1_ALWAYS_INLINE void Dct4_NEON(void* dest, int32_t step, bool is_row,
276                                      int row_shift) {
277   auto* const dst = static_cast<int32_t*>(dest);
278   // When |is_row| is true, set range to the row range, otherwise, set to the
279   // column range.
280   const int32_t range = is_row ? kBitdepth10 + 7 : 15;
281   const int32x4_t min = vdupq_n_s32(-(1 << range));
282   const int32x4_t max = vdupq_n_s32((1 << range) - 1);
283   int32x4_t s[4], x[4];
284 
285   LoadSrc<4>(dst, step, 0, x);
286   if (is_row) {
287     Transpose4x4(x, x);
288   }
289 
290   // stage 1.
291   // kBitReverseLookup 0, 2, 1, 3
292   s[0] = x[0];
293   s[1] = x[2];
294   s[2] = x[1];
295   s[3] = x[3];
296 
297   Dct4Stages<butterfly_rotation>(s, min, max, /*is_last_stage=*/true);
298 
299   if (is_row) {
300     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
301     for (auto& i : s) {
302       i = vmovl_s16(vqmovn_s32(vqrshlq_s32(i, v_row_shift)));
303     }
304     Transpose4x4(s, s);
305   }
306   StoreDst<4>(dst, step, 0, s);
307 }
308 
309 template <ButterflyRotationFunc butterfly_rotation,
310           bool is_fast_butterfly = false>
Dct8Stages(int32x4_t * s,const int32x4_t min,const int32x4_t max,const bool is_last_stage)311 LIBGAV1_ALWAYS_INLINE void Dct8Stages(int32x4_t* s, const int32x4_t min,
312                                       const int32x4_t max,
313                                       const bool is_last_stage) {
314   // stage 8.
315   if (is_fast_butterfly) {
316     ButterflyRotation_SecondIsZero(&s[4], &s[7], 56, false);
317     ButterflyRotation_FirstIsZero(&s[5], &s[6], 24, false);
318   } else {
319     butterfly_rotation(&s[4], &s[7], 56, false);
320     butterfly_rotation(&s[5], &s[6], 24, false);
321   }
322 
323   // stage 13.
324   HadamardRotation(&s[4], &s[5], false, min, max);
325   HadamardRotation(&s[6], &s[7], true, min, max);
326 
327   // stage 18.
328   butterfly_rotation(&s[6], &s[5], 32, true);
329 
330   // stage 22.
331   if (is_last_stage) {
332     HadamardRotation(&s[0], &s[7], false);
333     HadamardRotation(&s[1], &s[6], false);
334     HadamardRotation(&s[2], &s[5], false);
335     HadamardRotation(&s[3], &s[4], false);
336   } else {
337     HadamardRotation(&s[0], &s[7], false, min, max);
338     HadamardRotation(&s[1], &s[6], false, min, max);
339     HadamardRotation(&s[2], &s[5], false, min, max);
340     HadamardRotation(&s[3], &s[4], false, min, max);
341   }
342 }
343 
344 // Process dct8 rows or columns, depending on the |is_row| flag.
345 template <ButterflyRotationFunc butterfly_rotation>
Dct8_NEON(void * dest,int32_t step,bool is_row,int row_shift)346 LIBGAV1_ALWAYS_INLINE void Dct8_NEON(void* dest, int32_t step, bool is_row,
347                                      int row_shift) {
348   auto* const dst = static_cast<int32_t*>(dest);
349   const int32_t range = is_row ? kBitdepth10 + 7 : 15;
350   const int32x4_t min = vdupq_n_s32(-(1 << range));
351   const int32x4_t max = vdupq_n_s32((1 << range) - 1);
352   int32x4_t s[8], x[8];
353 
354   if (is_row) {
355     LoadSrc<4>(dst, step, 0, &x[0]);
356     LoadSrc<4>(dst, step, 4, &x[4]);
357     Transpose4x4(&x[0], &x[0]);
358     Transpose4x4(&x[4], &x[4]);
359   } else {
360     LoadSrc<8>(dst, step, 0, &x[0]);
361   }
362 
363   // stage 1.
364   // kBitReverseLookup 0, 4, 2, 6, 1, 5, 3, 7,
365   s[0] = x[0];
366   s[1] = x[4];
367   s[2] = x[2];
368   s[3] = x[6];
369   s[4] = x[1];
370   s[5] = x[5];
371   s[6] = x[3];
372   s[7] = x[7];
373 
374   Dct4Stages<butterfly_rotation>(s, min, max, /*is_last_stage=*/false);
375   Dct8Stages<butterfly_rotation>(s, min, max, /*is_last_stage=*/true);
376 
377   if (is_row) {
378     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
379     for (auto& i : s) {
380       i = vmovl_s16(vqmovn_s32(vqrshlq_s32(i, v_row_shift)));
381     }
382     Transpose4x4(&s[0], &s[0]);
383     Transpose4x4(&s[4], &s[4]);
384     StoreDst<4>(dst, step, 0, &s[0]);
385     StoreDst<4>(dst, step, 4, &s[4]);
386   } else {
387     StoreDst<8>(dst, step, 0, &s[0]);
388   }
389 }
390 
391 template <ButterflyRotationFunc butterfly_rotation,
392           bool is_fast_butterfly = false>
Dct16Stages(int32x4_t * s,const int32x4_t min,const int32x4_t max,const bool is_last_stage)393 LIBGAV1_ALWAYS_INLINE void Dct16Stages(int32x4_t* s, const int32x4_t min,
394                                        const int32x4_t max,
395                                        const bool is_last_stage) {
396   // stage 5.
397   if (is_fast_butterfly) {
398     ButterflyRotation_SecondIsZero(&s[8], &s[15], 60, false);
399     ButterflyRotation_FirstIsZero(&s[9], &s[14], 28, false);
400     ButterflyRotation_SecondIsZero(&s[10], &s[13], 44, false);
401     ButterflyRotation_FirstIsZero(&s[11], &s[12], 12, false);
402   } else {
403     butterfly_rotation(&s[8], &s[15], 60, false);
404     butterfly_rotation(&s[9], &s[14], 28, false);
405     butterfly_rotation(&s[10], &s[13], 44, false);
406     butterfly_rotation(&s[11], &s[12], 12, false);
407   }
408 
409   // stage 9.
410   HadamardRotation(&s[8], &s[9], false, min, max);
411   HadamardRotation(&s[10], &s[11], true, min, max);
412   HadamardRotation(&s[12], &s[13], false, min, max);
413   HadamardRotation(&s[14], &s[15], true, min, max);
414 
415   // stage 14.
416   butterfly_rotation(&s[14], &s[9], 48, true);
417   butterfly_rotation(&s[13], &s[10], 112, true);
418 
419   // stage 19.
420   HadamardRotation(&s[8], &s[11], false, min, max);
421   HadamardRotation(&s[9], &s[10], false, min, max);
422   HadamardRotation(&s[12], &s[15], true, min, max);
423   HadamardRotation(&s[13], &s[14], true, min, max);
424 
425   // stage 23.
426   butterfly_rotation(&s[13], &s[10], 32, true);
427   butterfly_rotation(&s[12], &s[11], 32, true);
428 
429   // stage 26.
430   if (is_last_stage) {
431     HadamardRotation(&s[0], &s[15], false);
432     HadamardRotation(&s[1], &s[14], false);
433     HadamardRotation(&s[2], &s[13], false);
434     HadamardRotation(&s[3], &s[12], false);
435     HadamardRotation(&s[4], &s[11], false);
436     HadamardRotation(&s[5], &s[10], false);
437     HadamardRotation(&s[6], &s[9], false);
438     HadamardRotation(&s[7], &s[8], false);
439   } else {
440     HadamardRotation(&s[0], &s[15], false, min, max);
441     HadamardRotation(&s[1], &s[14], false, min, max);
442     HadamardRotation(&s[2], &s[13], false, min, max);
443     HadamardRotation(&s[3], &s[12], false, min, max);
444     HadamardRotation(&s[4], &s[11], false, min, max);
445     HadamardRotation(&s[5], &s[10], false, min, max);
446     HadamardRotation(&s[6], &s[9], false, min, max);
447     HadamardRotation(&s[7], &s[8], false, min, max);
448   }
449 }
450 
451 // Process dct16 rows or columns, depending on the |is_row| flag.
452 template <ButterflyRotationFunc butterfly_rotation>
Dct16_NEON(void * dest,int32_t step,bool is_row,int row_shift)453 LIBGAV1_ALWAYS_INLINE void Dct16_NEON(void* dest, int32_t step, bool is_row,
454                                       int row_shift) {
455   auto* const dst = static_cast<int32_t*>(dest);
456   const int32_t range = is_row ? kBitdepth10 + 7 : 15;
457   const int32x4_t min = vdupq_n_s32(-(1 << range));
458   const int32x4_t max = vdupq_n_s32((1 << range) - 1);
459   int32x4_t s[16], x[16];
460 
461   if (is_row) {
462     for (int idx = 0; idx < 16; idx += 8) {
463       LoadSrc<4>(dst, step, idx, &x[idx]);
464       LoadSrc<4>(dst, step, idx + 4, &x[idx + 4]);
465       Transpose4x4(&x[idx], &x[idx]);
466       Transpose4x4(&x[idx + 4], &x[idx + 4]);
467     }
468   } else {
469     LoadSrc<16>(dst, step, 0, &x[0]);
470   }
471 
472   // stage 1
473   // kBitReverseLookup 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15,
474   s[0] = x[0];
475   s[1] = x[8];
476   s[2] = x[4];
477   s[3] = x[12];
478   s[4] = x[2];
479   s[5] = x[10];
480   s[6] = x[6];
481   s[7] = x[14];
482   s[8] = x[1];
483   s[9] = x[9];
484   s[10] = x[5];
485   s[11] = x[13];
486   s[12] = x[3];
487   s[13] = x[11];
488   s[14] = x[7];
489   s[15] = x[15];
490 
491   Dct4Stages<butterfly_rotation>(s, min, max, /*is_last_stage=*/false);
492   Dct8Stages<butterfly_rotation>(s, min, max, /*is_last_stage=*/false);
493   Dct16Stages<butterfly_rotation>(s, min, max, /*is_last_stage=*/true);
494 
495   if (is_row) {
496     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
497     for (auto& i : s) {
498       i = vmovl_s16(vqmovn_s32(vqrshlq_s32(i, v_row_shift)));
499     }
500     for (int idx = 0; idx < 16; idx += 8) {
501       Transpose4x4(&s[idx], &s[idx]);
502       Transpose4x4(&s[idx + 4], &s[idx + 4]);
503       StoreDst<4>(dst, step, idx, &s[idx]);
504       StoreDst<4>(dst, step, idx + 4, &s[idx + 4]);
505     }
506   } else {
507     StoreDst<16>(dst, step, 0, &s[0]);
508   }
509 }
510 
511 template <ButterflyRotationFunc butterfly_rotation,
512           bool is_fast_butterfly = false>
Dct32Stages(int32x4_t * s,const int32x4_t min,const int32x4_t max,const bool is_last_stage)513 LIBGAV1_ALWAYS_INLINE void Dct32Stages(int32x4_t* s, const int32x4_t min,
514                                        const int32x4_t max,
515                                        const bool is_last_stage) {
516   // stage 3
517   if (is_fast_butterfly) {
518     ButterflyRotation_SecondIsZero(&s[16], &s[31], 62, false);
519     ButterflyRotation_FirstIsZero(&s[17], &s[30], 30, false);
520     ButterflyRotation_SecondIsZero(&s[18], &s[29], 46, false);
521     ButterflyRotation_FirstIsZero(&s[19], &s[28], 14, false);
522     ButterflyRotation_SecondIsZero(&s[20], &s[27], 54, false);
523     ButterflyRotation_FirstIsZero(&s[21], &s[26], 22, false);
524     ButterflyRotation_SecondIsZero(&s[22], &s[25], 38, false);
525     ButterflyRotation_FirstIsZero(&s[23], &s[24], 6, false);
526   } else {
527     butterfly_rotation(&s[16], &s[31], 62, false);
528     butterfly_rotation(&s[17], &s[30], 30, false);
529     butterfly_rotation(&s[18], &s[29], 46, false);
530     butterfly_rotation(&s[19], &s[28], 14, false);
531     butterfly_rotation(&s[20], &s[27], 54, false);
532     butterfly_rotation(&s[21], &s[26], 22, false);
533     butterfly_rotation(&s[22], &s[25], 38, false);
534     butterfly_rotation(&s[23], &s[24], 6, false);
535   }
536 
537   // stage 6.
538   HadamardRotation(&s[16], &s[17], false, min, max);
539   HadamardRotation(&s[18], &s[19], true, min, max);
540   HadamardRotation(&s[20], &s[21], false, min, max);
541   HadamardRotation(&s[22], &s[23], true, min, max);
542   HadamardRotation(&s[24], &s[25], false, min, max);
543   HadamardRotation(&s[26], &s[27], true, min, max);
544   HadamardRotation(&s[28], &s[29], false, min, max);
545   HadamardRotation(&s[30], &s[31], true, min, max);
546 
547   // stage 10.
548   butterfly_rotation(&s[30], &s[17], 24 + 32, true);
549   butterfly_rotation(&s[29], &s[18], 24 + 64 + 32, true);
550   butterfly_rotation(&s[26], &s[21], 24, true);
551   butterfly_rotation(&s[25], &s[22], 24 + 64, true);
552 
553   // stage 15.
554   HadamardRotation(&s[16], &s[19], false, min, max);
555   HadamardRotation(&s[17], &s[18], false, min, max);
556   HadamardRotation(&s[20], &s[23], true, min, max);
557   HadamardRotation(&s[21], &s[22], true, min, max);
558   HadamardRotation(&s[24], &s[27], false, min, max);
559   HadamardRotation(&s[25], &s[26], false, min, max);
560   HadamardRotation(&s[28], &s[31], true, min, max);
561   HadamardRotation(&s[29], &s[30], true, min, max);
562 
563   // stage 20.
564   butterfly_rotation(&s[29], &s[18], 48, true);
565   butterfly_rotation(&s[28], &s[19], 48, true);
566   butterfly_rotation(&s[27], &s[20], 48 + 64, true);
567   butterfly_rotation(&s[26], &s[21], 48 + 64, true);
568 
569   // stage 24.
570   HadamardRotation(&s[16], &s[23], false, min, max);
571   HadamardRotation(&s[17], &s[22], false, min, max);
572   HadamardRotation(&s[18], &s[21], false, min, max);
573   HadamardRotation(&s[19], &s[20], false, min, max);
574   HadamardRotation(&s[24], &s[31], true, min, max);
575   HadamardRotation(&s[25], &s[30], true, min, max);
576   HadamardRotation(&s[26], &s[29], true, min, max);
577   HadamardRotation(&s[27], &s[28], true, min, max);
578 
579   // stage 27.
580   butterfly_rotation(&s[27], &s[20], 32, true);
581   butterfly_rotation(&s[26], &s[21], 32, true);
582   butterfly_rotation(&s[25], &s[22], 32, true);
583   butterfly_rotation(&s[24], &s[23], 32, true);
584 
585   // stage 29.
586   if (is_last_stage) {
587     HadamardRotation(&s[0], &s[31], false);
588     HadamardRotation(&s[1], &s[30], false);
589     HadamardRotation(&s[2], &s[29], false);
590     HadamardRotation(&s[3], &s[28], false);
591     HadamardRotation(&s[4], &s[27], false);
592     HadamardRotation(&s[5], &s[26], false);
593     HadamardRotation(&s[6], &s[25], false);
594     HadamardRotation(&s[7], &s[24], false);
595     HadamardRotation(&s[8], &s[23], false);
596     HadamardRotation(&s[9], &s[22], false);
597     HadamardRotation(&s[10], &s[21], false);
598     HadamardRotation(&s[11], &s[20], false);
599     HadamardRotation(&s[12], &s[19], false);
600     HadamardRotation(&s[13], &s[18], false);
601     HadamardRotation(&s[14], &s[17], false);
602     HadamardRotation(&s[15], &s[16], false);
603   } else {
604     HadamardRotation(&s[0], &s[31], false, min, max);
605     HadamardRotation(&s[1], &s[30], false, min, max);
606     HadamardRotation(&s[2], &s[29], false, min, max);
607     HadamardRotation(&s[3], &s[28], false, min, max);
608     HadamardRotation(&s[4], &s[27], false, min, max);
609     HadamardRotation(&s[5], &s[26], false, min, max);
610     HadamardRotation(&s[6], &s[25], false, min, max);
611     HadamardRotation(&s[7], &s[24], false, min, max);
612     HadamardRotation(&s[8], &s[23], false, min, max);
613     HadamardRotation(&s[9], &s[22], false, min, max);
614     HadamardRotation(&s[10], &s[21], false, min, max);
615     HadamardRotation(&s[11], &s[20], false, min, max);
616     HadamardRotation(&s[12], &s[19], false, min, max);
617     HadamardRotation(&s[13], &s[18], false, min, max);
618     HadamardRotation(&s[14], &s[17], false, min, max);
619     HadamardRotation(&s[15], &s[16], false, min, max);
620   }
621 }
622 
623 // Process dct32 rows or columns, depending on the |is_row| flag.
Dct32_NEON(void * dest,const int32_t step,const bool is_row,int row_shift)624 LIBGAV1_ALWAYS_INLINE void Dct32_NEON(void* dest, const int32_t step,
625                                       const bool is_row, int row_shift) {
626   auto* const dst = static_cast<int32_t*>(dest);
627   const int32_t range = is_row ? kBitdepth10 + 7 : 15;
628   const int32x4_t min = vdupq_n_s32(-(1 << range));
629   const int32x4_t max = vdupq_n_s32((1 << range) - 1);
630   int32x4_t s[32], x[32];
631 
632   if (is_row) {
633     for (int idx = 0; idx < 32; idx += 8) {
634       LoadSrc<4>(dst, step, idx, &x[idx]);
635       LoadSrc<4>(dst, step, idx + 4, &x[idx + 4]);
636       Transpose4x4(&x[idx], &x[idx]);
637       Transpose4x4(&x[idx + 4], &x[idx + 4]);
638     }
639   } else {
640     LoadSrc<32>(dst, step, 0, &x[0]);
641   }
642 
643   // stage 1
644   // kBitReverseLookup
645   // 0, 16, 8, 24, 4, 20, 12, 28, 2, 18, 10, 26, 6, 22, 14, 30,
646   s[0] = x[0];
647   s[1] = x[16];
648   s[2] = x[8];
649   s[3] = x[24];
650   s[4] = x[4];
651   s[5] = x[20];
652   s[6] = x[12];
653   s[7] = x[28];
654   s[8] = x[2];
655   s[9] = x[18];
656   s[10] = x[10];
657   s[11] = x[26];
658   s[12] = x[6];
659   s[13] = x[22];
660   s[14] = x[14];
661   s[15] = x[30];
662 
663   // 1, 17, 9, 25, 5, 21, 13, 29, 3, 19, 11, 27, 7, 23, 15, 31,
664   s[16] = x[1];
665   s[17] = x[17];
666   s[18] = x[9];
667   s[19] = x[25];
668   s[20] = x[5];
669   s[21] = x[21];
670   s[22] = x[13];
671   s[23] = x[29];
672   s[24] = x[3];
673   s[25] = x[19];
674   s[26] = x[11];
675   s[27] = x[27];
676   s[28] = x[7];
677   s[29] = x[23];
678   s[30] = x[15];
679   s[31] = x[31];
680 
681   Dct4Stages<ButterflyRotation_4>(s, min, max, /*is_last_stage=*/false);
682   Dct8Stages<ButterflyRotation_4>(s, min, max, /*is_last_stage=*/false);
683   Dct16Stages<ButterflyRotation_4>(s, min, max, /*is_last_stage=*/false);
684   Dct32Stages<ButterflyRotation_4>(s, min, max, /*is_last_stage=*/true);
685 
686   if (is_row) {
687     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
688     for (int idx = 0; idx < 32; idx += 8) {
689       int32x4_t output[8];
690       Transpose4x4(&s[idx], &output[0]);
691       Transpose4x4(&s[idx + 4], &output[4]);
692       for (auto& o : output) {
693         o = vmovl_s16(vqmovn_s32(vqrshlq_s32(o, v_row_shift)));
694       }
695       StoreDst<4>(dst, step, idx, &output[0]);
696       StoreDst<4>(dst, step, idx + 4, &output[4]);
697     }
698   } else {
699     StoreDst<32>(dst, step, 0, &s[0]);
700   }
701 }
702 
Dct64_NEON(void * dest,int32_t step,bool is_row,int row_shift)703 void Dct64_NEON(void* dest, int32_t step, bool is_row, int row_shift) {
704   auto* const dst = static_cast<int32_t*>(dest);
705   const int32_t range = is_row ? kBitdepth10 + 7 : 15;
706   const int32x4_t min = vdupq_n_s32(-(1 << range));
707   const int32x4_t max = vdupq_n_s32((1 << range) - 1);
708   int32x4_t s[64], x[32];
709 
710   if (is_row) {
711     // The last 32 values of every row are always zero if the |tx_width| is
712     // 64.
713     for (int idx = 0; idx < 32; idx += 8) {
714       LoadSrc<4>(dst, step, idx, &x[idx]);
715       LoadSrc<4>(dst, step, idx + 4, &x[idx + 4]);
716       Transpose4x4(&x[idx], &x[idx]);
717       Transpose4x4(&x[idx + 4], &x[idx + 4]);
718     }
719   } else {
720     // The last 32 values of every column are always zero if the |tx_height| is
721     // 64.
722     LoadSrc<32>(dst, step, 0, &x[0]);
723   }
724 
725   // stage 1
726   // kBitReverseLookup
727   // 0, 32, 16, 48, 8, 40, 24, 56, 4, 36, 20, 52, 12, 44, 28, 60,
728   s[0] = x[0];
729   s[2] = x[16];
730   s[4] = x[8];
731   s[6] = x[24];
732   s[8] = x[4];
733   s[10] = x[20];
734   s[12] = x[12];
735   s[14] = x[28];
736 
737   // 2, 34, 18, 50, 10, 42, 26, 58, 6, 38, 22, 54, 14, 46, 30, 62,
738   s[16] = x[2];
739   s[18] = x[18];
740   s[20] = x[10];
741   s[22] = x[26];
742   s[24] = x[6];
743   s[26] = x[22];
744   s[28] = x[14];
745   s[30] = x[30];
746 
747   // 1, 33, 17, 49, 9, 41, 25, 57, 5, 37, 21, 53, 13, 45, 29, 61,
748   s[32] = x[1];
749   s[34] = x[17];
750   s[36] = x[9];
751   s[38] = x[25];
752   s[40] = x[5];
753   s[42] = x[21];
754   s[44] = x[13];
755   s[46] = x[29];
756 
757   // 3, 35, 19, 51, 11, 43, 27, 59, 7, 39, 23, 55, 15, 47, 31, 63
758   s[48] = x[3];
759   s[50] = x[19];
760   s[52] = x[11];
761   s[54] = x[27];
762   s[56] = x[7];
763   s[58] = x[23];
764   s[60] = x[15];
765   s[62] = x[31];
766 
767   Dct4Stages<ButterflyRotation_4, /*is_fast_butterfly=*/true>(
768       s, min, max, /*is_last_stage=*/false);
769   Dct8Stages<ButterflyRotation_4, /*is_fast_butterfly=*/true>(
770       s, min, max, /*is_last_stage=*/false);
771   Dct16Stages<ButterflyRotation_4, /*is_fast_butterfly=*/true>(
772       s, min, max, /*is_last_stage=*/false);
773   Dct32Stages<ButterflyRotation_4, /*is_fast_butterfly=*/true>(
774       s, min, max, /*is_last_stage=*/false);
775 
776   //-- start dct 64 stages
777   // stage 2.
778   ButterflyRotation_SecondIsZero(&s[32], &s[63], 63 - 0, false);
779   ButterflyRotation_FirstIsZero(&s[33], &s[62], 63 - 32, false);
780   ButterflyRotation_SecondIsZero(&s[34], &s[61], 63 - 16, false);
781   ButterflyRotation_FirstIsZero(&s[35], &s[60], 63 - 48, false);
782   ButterflyRotation_SecondIsZero(&s[36], &s[59], 63 - 8, false);
783   ButterflyRotation_FirstIsZero(&s[37], &s[58], 63 - 40, false);
784   ButterflyRotation_SecondIsZero(&s[38], &s[57], 63 - 24, false);
785   ButterflyRotation_FirstIsZero(&s[39], &s[56], 63 - 56, false);
786   ButterflyRotation_SecondIsZero(&s[40], &s[55], 63 - 4, false);
787   ButterflyRotation_FirstIsZero(&s[41], &s[54], 63 - 36, false);
788   ButterflyRotation_SecondIsZero(&s[42], &s[53], 63 - 20, false);
789   ButterflyRotation_FirstIsZero(&s[43], &s[52], 63 - 52, false);
790   ButterflyRotation_SecondIsZero(&s[44], &s[51], 63 - 12, false);
791   ButterflyRotation_FirstIsZero(&s[45], &s[50], 63 - 44, false);
792   ButterflyRotation_SecondIsZero(&s[46], &s[49], 63 - 28, false);
793   ButterflyRotation_FirstIsZero(&s[47], &s[48], 63 - 60, false);
794 
795   // stage 4.
796   HadamardRotation(&s[32], &s[33], false, min, max);
797   HadamardRotation(&s[34], &s[35], true, min, max);
798   HadamardRotation(&s[36], &s[37], false, min, max);
799   HadamardRotation(&s[38], &s[39], true, min, max);
800   HadamardRotation(&s[40], &s[41], false, min, max);
801   HadamardRotation(&s[42], &s[43], true, min, max);
802   HadamardRotation(&s[44], &s[45], false, min, max);
803   HadamardRotation(&s[46], &s[47], true, min, max);
804   HadamardRotation(&s[48], &s[49], false, min, max);
805   HadamardRotation(&s[50], &s[51], true, min, max);
806   HadamardRotation(&s[52], &s[53], false, min, max);
807   HadamardRotation(&s[54], &s[55], true, min, max);
808   HadamardRotation(&s[56], &s[57], false, min, max);
809   HadamardRotation(&s[58], &s[59], true, min, max);
810   HadamardRotation(&s[60], &s[61], false, min, max);
811   HadamardRotation(&s[62], &s[63], true, min, max);
812 
813   // stage 7.
814   ButterflyRotation_4(&s[62], &s[33], 60 - 0, true);
815   ButterflyRotation_4(&s[61], &s[34], 60 - 0 + 64, true);
816   ButterflyRotation_4(&s[58], &s[37], 60 - 32, true);
817   ButterflyRotation_4(&s[57], &s[38], 60 - 32 + 64, true);
818   ButterflyRotation_4(&s[54], &s[41], 60 - 16, true);
819   ButterflyRotation_4(&s[53], &s[42], 60 - 16 + 64, true);
820   ButterflyRotation_4(&s[50], &s[45], 60 - 48, true);
821   ButterflyRotation_4(&s[49], &s[46], 60 - 48 + 64, true);
822 
823   // stage 11.
824   HadamardRotation(&s[32], &s[35], false, min, max);
825   HadamardRotation(&s[33], &s[34], false, min, max);
826   HadamardRotation(&s[36], &s[39], true, min, max);
827   HadamardRotation(&s[37], &s[38], true, min, max);
828   HadamardRotation(&s[40], &s[43], false, min, max);
829   HadamardRotation(&s[41], &s[42], false, min, max);
830   HadamardRotation(&s[44], &s[47], true, min, max);
831   HadamardRotation(&s[45], &s[46], true, min, max);
832   HadamardRotation(&s[48], &s[51], false, min, max);
833   HadamardRotation(&s[49], &s[50], false, min, max);
834   HadamardRotation(&s[52], &s[55], true, min, max);
835   HadamardRotation(&s[53], &s[54], true, min, max);
836   HadamardRotation(&s[56], &s[59], false, min, max);
837   HadamardRotation(&s[57], &s[58], false, min, max);
838   HadamardRotation(&s[60], &s[63], true, min, max);
839   HadamardRotation(&s[61], &s[62], true, min, max);
840 
841   // stage 16.
842   ButterflyRotation_4(&s[61], &s[34], 56, true);
843   ButterflyRotation_4(&s[60], &s[35], 56, true);
844   ButterflyRotation_4(&s[59], &s[36], 56 + 64, true);
845   ButterflyRotation_4(&s[58], &s[37], 56 + 64, true);
846   ButterflyRotation_4(&s[53], &s[42], 56 - 32, true);
847   ButterflyRotation_4(&s[52], &s[43], 56 - 32, true);
848   ButterflyRotation_4(&s[51], &s[44], 56 - 32 + 64, true);
849   ButterflyRotation_4(&s[50], &s[45], 56 - 32 + 64, true);
850 
851   // stage 21.
852   HadamardRotation(&s[32], &s[39], false, min, max);
853   HadamardRotation(&s[33], &s[38], false, min, max);
854   HadamardRotation(&s[34], &s[37], false, min, max);
855   HadamardRotation(&s[35], &s[36], false, min, max);
856   HadamardRotation(&s[40], &s[47], true, min, max);
857   HadamardRotation(&s[41], &s[46], true, min, max);
858   HadamardRotation(&s[42], &s[45], true, min, max);
859   HadamardRotation(&s[43], &s[44], true, min, max);
860   HadamardRotation(&s[48], &s[55], false, min, max);
861   HadamardRotation(&s[49], &s[54], false, min, max);
862   HadamardRotation(&s[50], &s[53], false, min, max);
863   HadamardRotation(&s[51], &s[52], false, min, max);
864   HadamardRotation(&s[56], &s[63], true, min, max);
865   HadamardRotation(&s[57], &s[62], true, min, max);
866   HadamardRotation(&s[58], &s[61], true, min, max);
867   HadamardRotation(&s[59], &s[60], true, min, max);
868 
869   // stage 25.
870   ButterflyRotation_4(&s[59], &s[36], 48, true);
871   ButterflyRotation_4(&s[58], &s[37], 48, true);
872   ButterflyRotation_4(&s[57], &s[38], 48, true);
873   ButterflyRotation_4(&s[56], &s[39], 48, true);
874   ButterflyRotation_4(&s[55], &s[40], 112, true);
875   ButterflyRotation_4(&s[54], &s[41], 112, true);
876   ButterflyRotation_4(&s[53], &s[42], 112, true);
877   ButterflyRotation_4(&s[52], &s[43], 112, true);
878 
879   // stage 28.
880   HadamardRotation(&s[32], &s[47], false, min, max);
881   HadamardRotation(&s[33], &s[46], false, min, max);
882   HadamardRotation(&s[34], &s[45], false, min, max);
883   HadamardRotation(&s[35], &s[44], false, min, max);
884   HadamardRotation(&s[36], &s[43], false, min, max);
885   HadamardRotation(&s[37], &s[42], false, min, max);
886   HadamardRotation(&s[38], &s[41], false, min, max);
887   HadamardRotation(&s[39], &s[40], false, min, max);
888   HadamardRotation(&s[48], &s[63], true, min, max);
889   HadamardRotation(&s[49], &s[62], true, min, max);
890   HadamardRotation(&s[50], &s[61], true, min, max);
891   HadamardRotation(&s[51], &s[60], true, min, max);
892   HadamardRotation(&s[52], &s[59], true, min, max);
893   HadamardRotation(&s[53], &s[58], true, min, max);
894   HadamardRotation(&s[54], &s[57], true, min, max);
895   HadamardRotation(&s[55], &s[56], true, min, max);
896 
897   // stage 30.
898   ButterflyRotation_4(&s[55], &s[40], 32, true);
899   ButterflyRotation_4(&s[54], &s[41], 32, true);
900   ButterflyRotation_4(&s[53], &s[42], 32, true);
901   ButterflyRotation_4(&s[52], &s[43], 32, true);
902   ButterflyRotation_4(&s[51], &s[44], 32, true);
903   ButterflyRotation_4(&s[50], &s[45], 32, true);
904   ButterflyRotation_4(&s[49], &s[46], 32, true);
905   ButterflyRotation_4(&s[48], &s[47], 32, true);
906 
907   // stage 31.
908   for (int i = 0; i < 32; i += 4) {
909     HadamardRotation(&s[i], &s[63 - i], false, min, max);
910     HadamardRotation(&s[i + 1], &s[63 - i - 1], false, min, max);
911     HadamardRotation(&s[i + 2], &s[63 - i - 2], false, min, max);
912     HadamardRotation(&s[i + 3], &s[63 - i - 3], false, min, max);
913   }
914   //-- end dct 64 stages
915   if (is_row) {
916     const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
917     for (int idx = 0; idx < 64; idx += 8) {
918       int32x4_t output[8];
919       Transpose4x4(&s[idx], &output[0]);
920       Transpose4x4(&s[idx + 4], &output[4]);
921       for (auto& o : output) {
922         o = vmovl_s16(vqmovn_s32(vqrshlq_s32(o, v_row_shift)));
923       }
924       StoreDst<4>(dst, step, idx, &output[0]);
925       StoreDst<4>(dst, step, idx + 4, &output[4]);
926     }
927   } else {
928     StoreDst<64>(dst, step, 0, &s[0]);
929   }
930 }
931 
932 //------------------------------------------------------------------------------
933 // Asymmetric Discrete Sine Transforms (ADST).
// 4-point inverse ADST on four lanes at a time, in place on |dest|.
// For row transforms the 4x4 block is transposed on load, rounded by
// |row_shift|, saturated to the int16_t range and transposed back on store.
LIBGAV1_ALWAYS_INLINE void Adst4_NEON(void* dest, int32_t step, bool is_row,
                                      int row_shift) {
  auto* const dst = static_cast<int32_t*>(dest);
  int32x4_t in[4];
  LoadSrc<4>(dst, step, 0, in);
  if (is_row) Transpose4x4(in, in);

  // stage 1: products of in[3].
  const int32x4_t in3_k1 = vmulq_n_s32(in[3], kAdst4Multiplier[1]);
  const int32x4_t in3_k3 = vmulq_n_s32(in[3], kAdst4Multiplier[3]);

  // stage 2: in[0] - in[2] + in[3].
  const int32x4_t mixed = vaddq_s32(vsubq_s32(in[0], in[2]), in[3]);

  // stage 3: k0 * in[0] + k3 * in[2] and k1 * in[0] - k0 * in[2].
  int32x4_t acc0 = vmlaq_n_s32(vmulq_n_s32(in[0], kAdst4Multiplier[0]), in[2],
                               kAdst4Multiplier[3]);
  int32x4_t acc1 = vmlsq_n_s32(vmulq_n_s32(in[0], kAdst4Multiplier[1]), in[2],
                               kAdst4Multiplier[0]);
  const int32x4_t in1_k2 = vmulq_n_s32(in[1], kAdst4Multiplier[2]);
  const int32x4_t mixed_k2 = vmulq_n_s32(mixed, kAdst4Multiplier[2]);

  // stage 4.
  acc0 = vaddq_s32(acc0, in3_k1);
  acc1 = vsubq_s32(acc1, in3_k3);

  // stages 5 and 6: combine, then apply a rounding right shift by 12.
  int32x4_t out[4];
  out[0] = vrshrq_n_s32(vaddq_s32(acc0, in1_k2), 12);
  out[1] = vrshrq_n_s32(vaddq_s32(acc1, in1_k2), 12);
  out[2] = vrshrq_n_s32(mixed_k2, 12);
  out[3] = vrshrq_n_s32(vsubq_s32(vaddq_s32(acc0, acc1), in1_k2), 12);

  if (is_row) {
    // A negative shift count makes vqrshlq_s32 a rounding right shift.
    const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
    for (int i = 0; i < 4; ++i) {
      out[i] = vmovl_s16(vqmovn_s32(vqrshlq_s32(out[i], v_row_shift)));
    }
    Transpose4x4(out, out);
  }
  StoreDst<4>(dst, step, 0, out);
}
988 
// Lane-wise multipliers for the DC-only 4-point ADST. The fourth lane repeats
// the second multiplier. 16-byte aligned so it can be loaded with one
// vld1q_s32.
alignas(16) constexpr int32_t kAdst4DcOnlyMultiplier[4] = {
    1321, 2482, 3344, 2482};
991 
// DC-only 4-point inverse ADST row transform. Returns false (touching
// nothing) when adjusted_tx_height > 1. Otherwise it overwrites dst[0..3]
// with the four outputs derived from dst[0] and returns true.
LIBGAV1_ALWAYS_INLINE bool Adst4DcOnly(void* dest, int adjusted_tx_height,
                                       bool should_round, int row_shift) {
  if (adjusted_tx_height > 1) return false;

  auto* dst = static_cast<int32_t*>(dest);

  // Broadcast dst[0]. When should_round is set, replace it with a rounded
  // multiply by kTransformRowMultiplier / 4096 (vqrdmulh with the
  // multiplier pre-shifted left by 31 - 12).
  const int32x4_t dc = vdupq_n_s32(dst[0]);
  const int32x4_t dc_scaled =
      vqrdmulhq_n_s32(dc, kTransformRowMultiplier << (31 - 12));
  const uint32x4_t select_scaled = vdupq_n_u32(should_round ? 0xffffffff : 0);
  const int32x4_t input = vbslq_s32(select_scaled, dc_scaled, dc);

  // Lanes: in*1321, in*2482, in*3344, in*2482.
  const int32x4_t products =
      vmulq_s32(vld1q_s32(kAdst4DcOnlyMultiplier), input);
  // Lanes: 0, 0, 0, in*1321 (lane 0 of |products| moved into lane 3).
  const int32x4_t lane0_to_lane3 = vextq_s32(vdupq_n_s32(0), products, 1);
  const int32x4_t rounded =
      vrshrq_n_s32(vaddq_s32(products, lane0_to_lane3), 12);

  // A negative shift count makes vqrshlq_s32 a rounding right shift. The
  // result is saturated to the int16_t range and widened back to int32_t.
  const int32x4_t shifted = vqrshlq_s32(rounded, vdupq_n_s32(-row_shift));
  vst1q_s32(dst, vmovl_s16(vqmovn_s32(shifted)));
  return true;
}
1022 
// DC-only 4-point inverse ADST column transform. Returns false when
// adjusted_tx_height > 1. Otherwise, for each group of four columns, it reads
// row 0 and writes the four outputs to rows 0-3 (row pitch |width|).
LIBGAV1_ALWAYS_INLINE bool Adst4DcOnlyColumn(void* dest, int adjusted_tx_height,
                                             int width) {
  if (adjusted_tx_height > 1) return false;

  auto* dst = static_cast<int32_t*>(dest);

  int column = 0;
  do {
    int32_t* const top = dst + column;
    const int32x4_t in = vld1q_s32(top);
    const int32x4_t prod0 = vmulq_n_s32(in, kAdst4Multiplier[0]);
    const int32x4_t prod1 = vmulq_n_s32(in, kAdst4Multiplier[1]);
    const int32x4_t prod2 = vmulq_n_s32(in, kAdst4Multiplier[2]);

    vst1q_s32(top, vrshrq_n_s32(prod0, 12));
    vst1q_s32(top + width, vrshrq_n_s32(prod1, 12));
    vst1q_s32(top + 2 * width, vrshrq_n_s32(prod2, 12));
    // The last output is the sum of the first two products.
    vst1q_s32(top + 3 * width, vrshrq_n_s32(vaddq_s32(prod0, prod1), 12));
    column += 4;
  } while (column < width);

  return true;
}
1057 
// 8-point inverse ADST applied to four lanes at a time, in place on |dest|.
//
// Row transforms (|is_row| == true) load two 4x4 blocks with LoadSrc<4> and
// transpose each one before the butterfly network. On output they round by
// |row_shift|, saturate to the int16_t range, transpose back and store with
// StoreDst<4>. Column transforms use LoadSrc<8>/StoreDst<8> directly.
// |butterfly_rotation| is the rotation primitive used by every butterfly
// stage.
template <ButterflyRotationFunc butterfly_rotation>
LIBGAV1_ALWAYS_INLINE void Adst8_NEON(void* dest, int32_t step, bool is_row,
                                      int row_shift) {
  auto* const dst = static_cast<int32_t*>(dest);
  // Bounds handed to every HadamardRotation call (presumably used to clamp
  // its outputs): [-(1 << range), (1 << range) - 1].
  const int32_t range = is_row ? kBitdepth10 + 7 : 15;
  const int32x4_t min = vdupq_n_s32(-(1 << range));
  const int32x4_t max = vdupq_n_s32((1 << range) - 1);
  int32x4_t s[8], x[8];

  if (is_row) {
    LoadSrc<4>(dst, step, 0, &x[0]);
    LoadSrc<4>(dst, step, 4, &x[4]);
    Transpose4x4(&x[0], &x[0]);
    Transpose4x4(&x[4], &x[4]);
  } else {
    LoadSrc<8>(dst, step, 0, &x[0]);
  }

  // stage 1: reorder into butterfly order. Even slots take x[7], x[5], x[3],
  // x[1]; odd slots take x[0], x[2], x[4], x[6].
  s[0] = x[7];
  s[1] = x[0];
  s[2] = x[5];
  s[3] = x[2];
  s[4] = x[3];
  s[5] = x[4];
  s[6] = x[1];
  s[7] = x[6];

  // stage 2: rotate each (s[2k], s[2k + 1]) pair by angle 60 - 16k.
  butterfly_rotation(&s[0], &s[1], 60 - 0, true);
  butterfly_rotation(&s[2], &s[3], 60 - 16, true);
  butterfly_rotation(&s[4], &s[5], 60 - 32, true);
  butterfly_rotation(&s[6], &s[7], 60 - 48, true);

  // stage 3.
  HadamardRotation(&s[0], &s[4], false, min, max);
  HadamardRotation(&s[1], &s[5], false, min, max);
  HadamardRotation(&s[2], &s[6], false, min, max);
  HadamardRotation(&s[3], &s[7], false, min, max);

  // stage 4: note the swapped operand order in the second rotation.
  butterfly_rotation(&s[4], &s[5], 48 - 0, true);
  butterfly_rotation(&s[7], &s[6], 48 - 32, true);

  // stage 5.
  HadamardRotation(&s[0], &s[2], false, min, max);
  HadamardRotation(&s[4], &s[6], false, min, max);
  HadamardRotation(&s[1], &s[3], false, min, max);
  HadamardRotation(&s[5], &s[7], false, min, max);

  // stage 6.
  butterfly_rotation(&s[2], &s[3], 32, true);
  butterfly_rotation(&s[6], &s[7], 32, true);

  // stage 7: output permutation. Odd-indexed outputs are negated with
  // vqnegq_s32, which saturates rather than wrapping on INT32_MIN.
  x[0] = s[0];
  x[1] = vqnegq_s32(s[4]);
  x[2] = s[6];
  x[3] = vqnegq_s32(s[2]);
  x[4] = s[3];
  x[5] = vqnegq_s32(s[7]);
  x[6] = s[5];
  x[7] = vqnegq_s32(s[1]);

  if (is_row) {
    // A negative shift count makes vqrshlq_s32 a rounding right shift. Each
    // result is saturated to the int16_t range and widened back to int32_t.
    const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
    for (auto& i : x) {
      i = vmovl_s16(vqmovn_s32(vqrshlq_s32(i, v_row_shift)));
    }
    Transpose4x4(&x[0], &x[0]);
    Transpose4x4(&x[4], &x[4]);
    StoreDst<4>(dst, step, 0, &x[0]);
    StoreDst<4>(dst, step, 4, &x[4]);
  } else {
    StoreDst<8>(dst, step, 0, &x[0]);
  }
}
1135 
// DC-only 8-point inverse ADST row transform. Returns false when
// adjusted_tx_height > 1. Otherwise it overwrites dst[0..7] with the eight
// outputs derived from dst[0] and returns true.
LIBGAV1_ALWAYS_INLINE bool Adst8DcOnly(void* dest, int adjusted_tx_height,
                                       bool should_round, int row_shift) {
  if (adjusted_tx_height > 1) return false;

  auto* dst = static_cast<int32_t*>(dest);

  // Broadcast dst[0]. When should_round is set, replace it with a rounded
  // multiply by kTransformRowMultiplier / 4096.
  const int32x4_t dc = vdupq_n_s32(dst[0]);
  const int32x4_t dc_scaled =
      vqrdmulhq_n_s32(dc, kTransformRowMultiplier << (31 - 12));
  const uint32x4_t select_scaled = vdupq_n_u32(should_round ? 0xffffffff : 0);

  int32x4_t s[8];
  // stage 1.
  s[1] = vbslq_s32(select_scaled, dc_scaled, dc);
  // stage 2.
  ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true);
  // stage 3 (only copies remain, since every other input is zero).
  s[4] = s[0];
  s[5] = s[1];
  // stage 4.
  ButterflyRotation_4(&s[4], &s[5], 48, true);
  // stage 5.
  s[2] = s[0];
  s[3] = s[1];
  s[6] = s[4];
  s[7] = s[5];
  // stage 6.
  ButterflyRotation_4(&s[2], &s[3], 32, true);
  ButterflyRotation_4(&s[6], &s[7], 32, true);

  // stage 7: output permutation; odd outputs use saturating negation.
  const int32x4_t out[8] = {s[0], vqnegq_s32(s[4]), s[6], vqnegq_s32(s[2]),
                            s[3], vqnegq_s32(s[7]), s[5], vqnegq_s32(s[1])};

  // A negative shift count makes vqrshlq_s32 a rounding right shift.
  const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
  for (int i = 0; i < 8; ++i) {
    const int32x4_t narrowed =
        vmovl_s16(vqmovn_s32(vqrshlq_s32(out[i], v_row_shift)));
    // All lanes hold the same value (the input was broadcast), so only lane 0
    // is stored.
    vst1q_lane_s32(&dst[i], narrowed, 0);
  }
  return true;
}
1189 
// DC-only 8-point inverse ADST column transform. Returns false when
// adjusted_tx_height > 1. Otherwise, for each group of four columns, it reads
// row 0 and writes the eight outputs to rows 0-7 (row pitch |width|).
LIBGAV1_ALWAYS_INLINE bool Adst8DcOnlyColumn(void* dest, int adjusted_tx_height,
                                             int width) {
  if (adjusted_tx_height > 1) return false;

  auto* dst = static_cast<int32_t*>(dest);

  int column = 0;
  do {
    int32x4_t s[8];
    // stage 1.
    s[1] = vld1q_s32(dst);
    // stage 2.
    ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true);
    // stage 3 (only copies remain, since every other input is zero).
    s[4] = s[0];
    s[5] = s[1];
    // stage 4.
    ButterflyRotation_4(&s[4], &s[5], 48, true);
    // stage 5.
    s[2] = s[0];
    s[3] = s[1];
    s[6] = s[4];
    s[7] = s[5];
    // stage 6.
    ButterflyRotation_4(&s[2], &s[3], 32, true);
    ButterflyRotation_4(&s[6], &s[7], 32, true);

    // stage 7: output permutation; odd outputs use saturating negation.
    const int32x4_t out[8] = {s[0], vqnegq_s32(s[4]), s[6], vqnegq_s32(s[2]),
                              s[3], vqnegq_s32(s[7]), s[5], vqnegq_s32(s[1])};
    for (int row = 0; row < 8; ++row) {
      vst1q_s32(dst + row * width, out[row]);
    }
    dst += 4;
    column += 4;
  } while (column < width);

  return true;
}
1243 
// 16-point inverse ADST applied to four lanes at a time, in place on |dest|.
//
// Row transforms (|is_row| == true) load sixteen vectors as four 4x4 blocks
// with LoadSrc<4>, transposing each block. On output they round by
// |row_shift|, saturate to the int16_t range, transpose each block back and
// store it with StoreDst<4>. Column transforms use LoadSrc<16>/StoreDst<16>
// directly. |butterfly_rotation| is the rotation primitive used by every
// butterfly stage.
template <ButterflyRotationFunc butterfly_rotation>
LIBGAV1_ALWAYS_INLINE void Adst16_NEON(void* dest, int32_t step, bool is_row,
                                       int row_shift) {
  auto* const dst = static_cast<int32_t*>(dest);
  // Bounds handed to every HadamardRotation call (presumably used to clamp
  // its outputs): [-(1 << range), (1 << range) - 1].
  const int32_t range = is_row ? kBitdepth10 + 7 : 15;
  const int32x4_t min = vdupq_n_s32(-(1 << range));
  const int32x4_t max = vdupq_n_s32((1 << range) - 1);
  int32x4_t s[16], x[16];

  if (is_row) {
    for (int idx = 0; idx < 16; idx += 8) {
      LoadSrc<4>(dst, step, idx, &x[idx]);
      LoadSrc<4>(dst, step, idx + 4, &x[idx + 4]);
      Transpose4x4(&x[idx], &x[idx]);
      Transpose4x4(&x[idx + 4], &x[idx + 4]);
    }
  } else {
    LoadSrc<16>(dst, step, 0, &x[0]);
  }

  // stage 1: reorder into butterfly order. Even slots take x[15], x[13], ...,
  // x[1]; odd slots take x[0], x[2], ..., x[14].
  s[0] = x[15];
  s[1] = x[0];
  s[2] = x[13];
  s[3] = x[2];
  s[4] = x[11];
  s[5] = x[4];
  s[6] = x[9];
  s[7] = x[6];
  s[8] = x[7];
  s[9] = x[8];
  s[10] = x[5];
  s[11] = x[10];
  s[12] = x[3];
  s[13] = x[12];
  s[14] = x[1];
  s[15] = x[14];

  // stage 2: rotate each (s[2k], s[2k + 1]) pair by angle 62 - 8k.
  butterfly_rotation(&s[0], &s[1], 62 - 0, true);
  butterfly_rotation(&s[2], &s[3], 62 - 8, true);
  butterfly_rotation(&s[4], &s[5], 62 - 16, true);
  butterfly_rotation(&s[6], &s[7], 62 - 24, true);
  butterfly_rotation(&s[8], &s[9], 62 - 32, true);
  butterfly_rotation(&s[10], &s[11], 62 - 40, true);
  butterfly_rotation(&s[12], &s[13], 62 - 48, true);
  butterfly_rotation(&s[14], &s[15], 62 - 56, true);

  // stage 3: pair s[i] with s[i + 8].
  HadamardRotation(&s[0], &s[8], false, min, max);
  HadamardRotation(&s[1], &s[9], false, min, max);
  HadamardRotation(&s[2], &s[10], false, min, max);
  HadamardRotation(&s[3], &s[11], false, min, max);
  HadamardRotation(&s[4], &s[12], false, min, max);
  HadamardRotation(&s[5], &s[13], false, min, max);
  HadamardRotation(&s[6], &s[14], false, min, max);
  HadamardRotation(&s[7], &s[15], false, min, max);

  // stage 4: only the upper half (s[8..15]) is rotated; note the swapped
  // operand order for the (13, 12) and (15, 14) pairs.
  butterfly_rotation(&s[8], &s[9], 56 - 0, true);
  butterfly_rotation(&s[13], &s[12], 8 + 0, true);
  butterfly_rotation(&s[10], &s[11], 56 - 32, true);
  butterfly_rotation(&s[15], &s[14], 8 + 32, true);

  // stage 5: pair s[i] with s[i + 4] within each half.
  HadamardRotation(&s[0], &s[4], false, min, max);
  HadamardRotation(&s[8], &s[12], false, min, max);
  HadamardRotation(&s[1], &s[5], false, min, max);
  HadamardRotation(&s[9], &s[13], false, min, max);
  HadamardRotation(&s[2], &s[6], false, min, max);
  HadamardRotation(&s[10], &s[14], false, min, max);
  HadamardRotation(&s[3], &s[7], false, min, max);
  HadamardRotation(&s[11], &s[15], false, min, max);

  // stage 6.
  butterfly_rotation(&s[4], &s[5], 48 - 0, true);
  butterfly_rotation(&s[12], &s[13], 48 - 0, true);
  butterfly_rotation(&s[7], &s[6], 48 - 32, true);
  butterfly_rotation(&s[15], &s[14], 48 - 32, true);

  // stage 7: pair s[i] with s[i + 2].
  HadamardRotation(&s[0], &s[2], false, min, max);
  HadamardRotation(&s[4], &s[6], false, min, max);
  HadamardRotation(&s[8], &s[10], false, min, max);
  HadamardRotation(&s[12], &s[14], false, min, max);
  HadamardRotation(&s[1], &s[3], false, min, max);
  HadamardRotation(&s[5], &s[7], false, min, max);
  HadamardRotation(&s[9], &s[11], false, min, max);
  HadamardRotation(&s[13], &s[15], false, min, max);

  // stage 8.
  butterfly_rotation(&s[2], &s[3], 32, true);
  butterfly_rotation(&s[6], &s[7], 32, true);
  butterfly_rotation(&s[10], &s[11], 32, true);
  butterfly_rotation(&s[14], &s[15], 32, true);

  // stage 9: output permutation. Odd-indexed outputs are negated with
  // vqnegq_s32, which saturates rather than wrapping on INT32_MIN.
  x[0] = s[0];
  x[1] = vqnegq_s32(s[8]);
  x[2] = s[12];
  x[3] = vqnegq_s32(s[4]);
  x[4] = s[6];
  x[5] = vqnegq_s32(s[14]);
  x[6] = s[10];
  x[7] = vqnegq_s32(s[2]);
  x[8] = s[3];
  x[9] = vqnegq_s32(s[11]);
  x[10] = s[15];
  x[11] = vqnegq_s32(s[7]);
  x[12] = s[5];
  x[13] = vqnegq_s32(s[13]);
  x[14] = s[9];
  x[15] = vqnegq_s32(s[1]);

  if (is_row) {
    // A negative shift count makes vqrshlq_s32 a rounding right shift. Each
    // result is saturated to the int16_t range and widened back to int32_t.
    const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
    for (auto& i : x) {
      i = vmovl_s16(vqmovn_s32(vqrshlq_s32(i, v_row_shift)));
    }
    for (int idx = 0; idx < 16; idx += 8) {
      Transpose4x4(&x[idx], &x[idx]);
      Transpose4x4(&x[idx + 4], &x[idx + 4]);
      StoreDst<4>(dst, step, idx, &x[idx]);
      StoreDst<4>(dst, step, idx + 4, &x[idx + 4]);
    }
  } else {
    StoreDst<16>(dst, step, 0, &x[0]);
  }
}
1373 
// Shared butterfly network for the DC-only 16-point ADST paths
// (Adst16DcOnly and Adst16DcOnlyColumn).
//
// On entry s[1] must hold the input. s[0..15] are used as scratch, and
// x[0..15] receive the sixteen outputs. s[0] is not initialized by the
// callers; ButterflyRotation_FirstIsZero presumably treats its first operand
// as zero, as its name suggests (confirm against its definition). Stages 3, 5
// and 7, which are Hadamard steps in Adst16_NEON, appear here as plain copies.
LIBGAV1_ALWAYS_INLINE void Adst16DcOnlyInternal(int32x4_t* s, int32x4_t* x) {
  // stage 2.
  ButterflyRotation_FirstIsZero(&s[0], &s[1], 62, true);

  // stage 3.
  s[8] = s[0];
  s[9] = s[1];

  // stage 4.
  ButterflyRotation_4(&s[8], &s[9], 56, true);

  // stage 5.
  s[4] = s[0];
  s[12] = s[8];
  s[5] = s[1];
  s[13] = s[9];

  // stage 6.
  ButterflyRotation_4(&s[4], &s[5], 48, true);
  ButterflyRotation_4(&s[12], &s[13], 48, true);

  // stage 7.
  s[2] = s[0];
  s[6] = s[4];
  s[10] = s[8];
  s[14] = s[12];
  s[3] = s[1];
  s[7] = s[5];
  s[11] = s[9];
  s[15] = s[13];

  // stage 8.
  ButterflyRotation_4(&s[2], &s[3], 32, true);
  ButterflyRotation_4(&s[6], &s[7], 32, true);
  ButterflyRotation_4(&s[10], &s[11], 32, true);
  ButterflyRotation_4(&s[14], &s[15], 32, true);

  // stage 9: output permutation (same as Adst16_NEON). Odd-indexed outputs
  // are negated with saturation.
  x[0] = s[0];
  x[1] = vqnegq_s32(s[8]);
  x[2] = s[12];
  x[3] = vqnegq_s32(s[4]);
  x[4] = s[6];
  x[5] = vqnegq_s32(s[14]);
  x[6] = s[10];
  x[7] = vqnegq_s32(s[2]);
  x[8] = s[3];
  x[9] = vqnegq_s32(s[11]);
  x[10] = s[15];
  x[11] = vqnegq_s32(s[7]);
  x[12] = s[5];
  x[13] = vqnegq_s32(s[13]);
  x[14] = s[9];
  x[15] = vqnegq_s32(s[1]);
}
1429 
// DC-only 16-point inverse ADST row transform. Returns false when
// adjusted_tx_height > 1. Otherwise it overwrites dst[0..15] with the
// outputs derived from dst[0] and returns true.
LIBGAV1_ALWAYS_INLINE bool Adst16DcOnly(void* dest, int adjusted_tx_height,
                                        bool should_round, int row_shift) {
  if (adjusted_tx_height > 1) return false;

  auto* dst = static_cast<int32_t*>(dest);

  // Broadcast dst[0]. When should_round is set, replace it with a rounded
  // multiply by kTransformRowMultiplier / 4096.
  const int32x4_t dc = vdupq_n_s32(dst[0]);
  const int32x4_t dc_scaled =
      vqrdmulhq_n_s32(dc, kTransformRowMultiplier << (31 - 12));
  const uint32x4_t select_scaled = vdupq_n_u32(should_round ? 0xffffffff : 0);

  int32x4_t s[16];
  int32x4_t out[16];
  // stage 1: the input goes into s[1]; the shared network does the rest.
  s[1] = vbslq_s32(select_scaled, dc_scaled, dc);
  Adst16DcOnlyInternal(s, out);

  // A negative shift count makes vqrshlq_s32 a rounding right shift.
  const int32x4_t v_row_shift = vdupq_n_s32(-row_shift);
  for (int i = 0; i < 16; ++i) {
    const int32x4_t narrowed =
        vmovl_s16(vqmovn_s32(vqrshlq_s32(out[i], v_row_shift)));
    // Every lane holds the same value, so only lane 0 is stored.
    vst1q_lane_s32(dst + i, narrowed, 0);
  }
  return true;
}
1454 
// DC-only 16-point inverse ADST column transform. Returns false when
// adjusted_tx_height > 1. Otherwise, for each group of four columns, it reads
// row 0 and writes the sixteen outputs to rows 0-15 (row pitch |width|).
LIBGAV1_ALWAYS_INLINE bool Adst16DcOnlyColumn(void* dest,
                                              int adjusted_tx_height,
                                              int width) {
  if (adjusted_tx_height > 1) return false;

  auto* dst = static_cast<int32_t*>(dest);
  int column = 0;
  do {
    int32x4_t s[16];
    int32x4_t out[16];
    // stage 1.
    s[1] = vld1q_s32(dst);
    Adst16DcOnlyInternal(s, out);

    for (int row = 0; row < 16; ++row) {
      vst1q_s32(dst + row * width, out[row]);
    }
    dst += 4;
    column += 4;
  } while (column < width);

  return true;
}
1480 
1481 //------------------------------------------------------------------------------
1482 // Identity Transforms.
1483 
// 4-point identity row transform on a 4x4 block of int32_t values, in place.
// Each value is multiplied by kIdentity4Multiplier, right-shifted by
// 12 + |shift| with rounding, saturated to the int16_t range and widened back.
LIBGAV1_ALWAYS_INLINE void Identity4_NEON(void* dest, int32_t step, int shift) {
  auto* const dst = static_cast<int32_t*>(dest);
  // Combined rounding offset. 1 << 11 rounds the >> 12; when shift == 1 the
  // extra 1 << 12 rounds the additional >> 1.
  const int32x4_t rounding = vdupq_n_s32((1 + (shift << 1)) << 11);
  // With a negative count, vqshlq_s32 performs a (non-rounding) right shift.
  const int32x4_t right_shift = vdupq_n_s32(-(12 + shift));
  int i = 0;
  do {
    int32_t* const row = dst + i * step;
    const int32x4_t scaled =
        vmlaq_n_s32(rounding, vld1q_s32(row), kIdentity4Multiplier);
    const int32x4_t shifted = vqshlq_s32(scaled, right_shift);
    vst1q_s32(row, vmovl_s16(vqmovn_s32(shifted)));
  } while (++i < 4);
}
1497 
// DC-only 4-point identity row transform. Returns false when
// adjusted_tx_height > 1. Otherwise it rewrites dst[0] in place and returns
// true.
LIBGAV1_ALWAYS_INLINE bool Identity4DcOnly(void* dest, int adjusted_tx_height,
                                           bool should_round, int tx_height) {
  if (adjusted_tx_height > 1) return false;

  auto* dst = static_cast<int32_t*>(dest);

  // Broadcast dst[0]. When should_round is set, replace it with a rounded
  // multiply by kTransformRowMultiplier / 4096.
  const int32x4_t dc = vdupq_n_s32(dst[0]);
  const int32x4_t dc_scaled =
      vqrdmulhq_n_s32(dc, kTransformRowMultiplier << (31 - 12));
  const uint32x4_t select_scaled = vdupq_n_u32(should_round ? 0xffffffff : 0);
  const int32x4_t input = vbslq_s32(select_scaled, dc_scaled, dc);

  // Blocks with tx_height >= 16 get one extra bit of right shift.
  const int extra_shift = (tx_height >= 16) ? 1 : 0;
  const int32x4_t rounding = vdupq_n_s32((1 + (extra_shift << 1)) << 11);
  const int32x4_t scaled = vmlaq_n_s32(rounding, input, kIdentity4Multiplier);
  // With a negative count, vqshlq_s32 performs a (non-rounding) right shift.
  const int32x4_t shifted =
      vqshlq_s32(scaled, vdupq_n_s32(-(12 + extra_shift)));
  vst1q_lane_s32(dst, vmovl_s16(vqmovn_s32(shifted)), 0);
  return true;
}
1517 
1518 template <int identity_size>
IdentityColumnStoreToFrame(Array2DView<uint16_t> frame,const int start_x,const int start_y,const int tx_width,const int tx_height,const int32_t * LIBGAV1_RESTRICT source)1519 LIBGAV1_ALWAYS_INLINE void IdentityColumnStoreToFrame(
1520     Array2DView<uint16_t> frame, const int start_x, const int start_y,
1521     const int tx_width, const int tx_height,
1522     const int32_t* LIBGAV1_RESTRICT source) {
1523   static_assert(identity_size == 4 || identity_size == 8 ||
1524                     identity_size == 16 || identity_size == 32,
1525                 "Invalid identity_size.");
1526   const int stride = frame.columns();
1527   uint16_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
1528   const int32x4_t v_dual_round = vdupq_n_s32((1 + (1 << 4)) << 11);
1529   const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
1530 
1531   if (identity_size < 32) {
1532     if (tx_width == 4) {
1533       int i = 0;
1534       do {
1535         int32x4x2_t v_src, v_dst_i, a, b;
1536         v_src.val[0] = vld1q_s32(&source[i * 4]);
1537         v_src.val[1] = vld1q_s32(&source[(i * 4) + 4]);
1538         if (identity_size == 4) {
1539           v_dst_i.val[0] =
1540               vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity4Multiplier);
1541           v_dst_i.val[1] =
1542               vmlaq_n_s32(v_dual_round, v_src.val[1], kIdentity4Multiplier);
1543           a.val[0] = vshrq_n_s32(v_dst_i.val[0], 4 + 12);
1544           a.val[1] = vshrq_n_s32(v_dst_i.val[1], 4 + 12);
1545         } else if (identity_size == 8) {
1546           v_dst_i.val[0] = vaddq_s32(v_src.val[0], v_src.val[0]);
1547           v_dst_i.val[1] = vaddq_s32(v_src.val[1], v_src.val[1]);
1548           a.val[0] = vrshrq_n_s32(v_dst_i.val[0], 4);
1549           a.val[1] = vrshrq_n_s32(v_dst_i.val[1], 4);
1550         } else {  // identity_size == 16
1551           v_dst_i.val[0] =
1552               vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity16Multiplier);
1553           v_dst_i.val[1] =
1554               vmlaq_n_s32(v_dual_round, v_src.val[1], kIdentity16Multiplier);
1555           a.val[0] = vshrq_n_s32(v_dst_i.val[0], 4 + 12);
1556           a.val[1] = vshrq_n_s32(v_dst_i.val[1], 4 + 12);
1557         }
1558         uint16x4x2_t frame_data;
1559         frame_data.val[0] = vld1_u16(dst);
1560         frame_data.val[1] = vld1_u16(dst + stride);
1561         b.val[0] = vaddw_s16(a.val[0], vreinterpret_s16_u16(frame_data.val[0]));
1562         b.val[1] = vaddw_s16(a.val[1], vreinterpret_s16_u16(frame_data.val[1]));
1563         vst1_u16(dst, vmin_u16(vqmovun_s32(b.val[0]), v_max_bitdepth));
1564         vst1_u16(dst + stride, vmin_u16(vqmovun_s32(b.val[1]), v_max_bitdepth));
1565         dst += stride << 1;
1566         i += 2;
1567       } while (i < tx_height);
1568     } else {
1569       int i = 0;
1570       do {
1571         const int row = i * tx_width;
1572         int j = 0;
1573         do {
1574           int32x4x2_t v_src, v_dst_i, a, b;
1575           v_src.val[0] = vld1q_s32(&source[row + j]);
1576           v_src.val[1] = vld1q_s32(&source[row + j + 4]);
1577           if (identity_size == 4) {
1578             v_dst_i.val[0] =
1579                 vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity4Multiplier);
1580             v_dst_i.val[1] =
1581                 vmlaq_n_s32(v_dual_round, v_src.val[1], kIdentity4Multiplier);
1582             a.val[0] = vshrq_n_s32(v_dst_i.val[0], 4 + 12);
1583             a.val[1] = vshrq_n_s32(v_dst_i.val[1], 4 + 12);
1584           } else if (identity_size == 8) {
1585             v_dst_i.val[0] = vaddq_s32(v_src.val[0], v_src.val[0]);
1586             v_dst_i.val[1] = vaddq_s32(v_src.val[1], v_src.val[1]);
1587             a.val[0] = vrshrq_n_s32(v_dst_i.val[0], 4);
1588             a.val[1] = vrshrq_n_s32(v_dst_i.val[1], 4);
1589           } else {  // identity_size == 16
1590             v_dst_i.val[0] =
1591                 vmlaq_n_s32(v_dual_round, v_src.val[0], kIdentity16Multiplier);
1592             v_dst_i.val[1] =
1593                 vmlaq_n_s32(v_dual_round, v_src.val[1], kIdentity16Multiplier);
1594             a.val[0] = vshrq_n_s32(v_dst_i.val[0], 4 + 12);
1595             a.val[1] = vshrq_n_s32(v_dst_i.val[1], 4 + 12);
1596           }
1597           uint16x4x2_t frame_data;
1598           frame_data.val[0] = vld1_u16(dst + j);
1599           frame_data.val[1] = vld1_u16(dst + j + 4);
1600           b.val[0] =
1601               vaddw_s16(a.val[0], vreinterpret_s16_u16(frame_data.val[0]));
1602           b.val[1] =
1603               vaddw_s16(a.val[1], vreinterpret_s16_u16(frame_data.val[1]));
1604           vst1_u16(dst + j, vmin_u16(vqmovun_s32(b.val[0]), v_max_bitdepth));
1605           vst1_u16(dst + j + 4,
1606                    vmin_u16(vqmovun_s32(b.val[1]), v_max_bitdepth));
1607           j += 8;
1608         } while (j < tx_width);
1609         dst += stride;
1610       } while (++i < tx_height);
1611     }
1612   } else {
1613     int i = 0;
1614     do {
1615       const int row = i * tx_width;
1616       int j = 0;
1617       do {
1618         const int32x4_t v_dst_i = vld1q_s32(&source[row + j]);
1619         const uint16x4_t frame_data = vld1_u16(dst + j);
1620         const int32x4_t a = vrshrq_n_s32(v_dst_i, 2);
1621         const int32x4_t b = vaddw_s16(a, vreinterpret_s16_u16(frame_data));
1622         const uint16x4_t d = vmin_u16(vqmovun_s32(b), v_max_bitdepth);
1623         vst1_u16(dst + j, d);
1624         j += 4;
1625       } while (j < tx_width);
1626       dst += stride;
1627     } while (++i < tx_height);
1628   }
1629 }
1630 
// Runs an identity row pass followed by the identity4 column pass over
// |source|, adds the result to the frame block at (start_x, start_y) and clamps
// each pixel to [0, (1 << kBitdepth10) - 1].
// When tx_width == 4 the row pass scales by kIdentity4Multiplier; otherwise it
// scales by kTransformRowMultiplier and then doubles (with saturation).
LIBGAV1_ALWAYS_INLINE void Identity4RowColumnStoreToFrame(
    Array2DView<uint16_t> frame, const int start_x, const int start_y,
    const int tx_width, const int tx_height,
    const int32_t* LIBGAV1_RESTRICT source) {
  const int stride = frame.columns();
  uint16_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
  // Rounding term for the ">> 12" that follows each multiply.
  const int32x4_t v_round = vdupq_n_s32(1 << 11);
  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);

  // Finishes the column pass on 4 values with a rounding shift by (4 + 12),
  // adds them to the 4 pixels at |pixels| and stores the clamped sum.
  const auto add_to_pixels = [&v_max_bitdepth](uint16_t* pixels,
                                               const int32x4_t column) {
    const int32x4_t residual = vrshrq_n_s32(column, 4 + 12);
    const int32x4_t sum =
        vaddw_s16(residual, vreinterpret_s16_u16(vld1_u16(pixels)));
    vst1_u16(pixels, vmin_u16(vqmovun_s32(sum), v_max_bitdepth));
  };

  int i = 0;
  if (tx_width == 4) {
    do {
      const int32x4_t v_row = vshrq_n_s32(
          vmlaq_n_s32(v_round, vld1q_s32(&source[i * 4]),
                      kIdentity4Multiplier),
          12);
      add_to_pixels(dst, vmlaq_n_s32(v_round, v_row, kIdentity4Multiplier));
      dst += stride;
    } while (++i < tx_height);
    return;
  }

  do {
    const int32_t* const src_row = &source[i * tx_width];
    int j = 0;
    do {
      // Two groups of 4 coefficients per step of 8.
      for (int k = 0; k < 8; k += 4) {
        const int32x4_t v_scaled = vshrq_n_s32(
            vmlaq_n_s32(v_round, vld1q_s32(&src_row[j + k]),
                        kTransformRowMultiplier),
            12);
        const int32x4_t v_row = vqaddq_s32(v_scaled, v_scaled);
        add_to_pixels(dst + j + k,
                      vmlaq_n_s32(v_round, v_row, kIdentity4Multiplier));
      }
      j += 8;
    } while (j < tx_width);
    dst += stride;
  } while (++i < tx_height);
}
1688 
// Identity8 row transform for tx_height == 32 on 4 rows of 8 coefficients,
// |step| elements apart. The identity8 scaling (x2) and the row shift (2) fold
// together: (((A * 2) + 2) >> 2) == ((A + 1) >> 1). Results are clamped to
// the int16_t range.
LIBGAV1_ALWAYS_INLINE void Identity8Row32_NEON(void* dest, int32_t step) {
  auto* const dst = static_cast<int32_t*>(dest);

  for (int row = 0; row < 4; ++row) {
    int32_t* const row_ptr = dst + row * step;
    for (int col = 0; col < 8; col += 4) {
      const int32x4_t v_halved = vrshrq_n_s32(vld1q_s32(row_ptr + col), 1);
      vst1q_s32(row_ptr + col, vmovl_s16(vqmovn_s32(v_halved)));
    }
  }
}
1704 
// Identity8 row transform on 4 rows of 8 coefficients, |step| elements apart:
// each value is doubled with saturation and then clamped to the int16_t range.
LIBGAV1_ALWAYS_INLINE void Identity8Row4_NEON(void* dest, int32_t step) {
  auto* const dst = static_cast<int32_t*>(dest);

  for (int row = 0; row < 4; ++row) {
    int32_t* const row_ptr = dst + row * step;
    for (int col = 0; col < 8; col += 4) {
      const int32x4_t v_in = vld1q_s32(row_ptr + col);
      const int32x4_t v_doubled = vqaddq_s32(v_in, v_in);
      vst1q_s32(row_ptr + col, vmovl_s16(vqmovn_s32(v_doubled)));
    }
  }
}
1717 
// DC-only fast path for the identity8 row transform. Returns false and leaves
// |dest| untouched when adjusted_tx_height > 1. Otherwise dst[0] becomes
// clamp_int16(RoundingShift(2 * x, row_shift)), where x is dst[0], first
// scaled as Round2(x * kTransformRowMultiplier, 12) when |should_round|.
LIBGAV1_ALWAYS_INLINE bool Identity8DcOnly(void* dest, int adjusted_tx_height,
                                           bool should_round, int row_shift) {
  if (adjusted_tx_height > 1) return false;

  auto* dst = static_cast<int32_t*>(dest);
  const int32x4_t v_dc = vdupq_n_s32(dst[0]);
  const int32x4_t v_in =
      should_round
          ? vqrdmulhq_n_s32(v_dc, kTransformRowMultiplier << (31 - 12))
          : v_dc;
  const int32x4_t v_doubled = vaddq_s32(v_in, v_in);
  // A negative shift count makes vqrshlq_s32 a rounding right shift.
  const int32x4_t v_out = vqrshlq_s32(v_doubled, vdupq_n_s32(-row_shift));
  vst1q_lane_s32(dst, vmovl_s16(vqmovn_s32(v_out)), 0);
  return true;
}
1733 
// Identity16 row transform on 4 rows of 16 coefficients, |step| elements
// apart. Computes Round2(Round2(x * kIdentity16Multiplier, 12), shift) as a
// single truncating right shift by (12 + shift) after adding a combined
// rounding term, then clamps each result to the int16_t range.
LIBGAV1_ALWAYS_INLINE void Identity16Row_NEON(void* dest, int32_t step,
                                              int shift) {
  auto* const dst = static_cast<int32_t*>(dest);
  // Combined rounding term: 2^11 for the inner Round2(., 12), plus
  // 2^(shift - 1) << 12 == 2^(11 + shift) for the outer Round2(., shift) when
  // shift > 0. The shorthand (1 + (shift << 1)) << 11 only agrees with this
  // for shift <= 2.
  const int32_t dual_round = (1 << 11) + (shift > 0 ? (1 << (11 + shift)) : 0);
  const int32x4_t v_dual_round = vdupq_n_s32(dual_round);
  // A negative count makes vqshlq_s32 a truncating right shift.
  const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));

  for (int i = 0; i < 4; ++i) {
    int32_t* const row = &dst[i * step];
    for (int j = 0; j < 16; j += 4) {
      const int32x4_t v_src = vld1q_s32(&row[j]);
      const int32x4_t v_src_mult =
          vmlaq_n_s32(v_dual_round, v_src, kIdentity16Multiplier);
      const int32x4_t v_shifted = vqshlq_s32(v_src_mult, v_shift);
      vst1q_s32(&row[j], vmovl_s16(vqmovn_s32(v_shifted)));
    }
  }
}
1756 
// DC-only fast path for the identity16 row transform. Returns false and leaves
// |dest| untouched when adjusted_tx_height > 1. Otherwise dst[0] (optionally
// pre-scaled by Round2(x * kTransformRowMultiplier, 12) when |should_round|)
// becomes clamp_int16(Round2(Round2(x * kIdentity16Multiplier, 12), shift)).
LIBGAV1_ALWAYS_INLINE bool Identity16DcOnly(void* dest, int adjusted_tx_height,
                                            bool should_round, int shift) {
  if (adjusted_tx_height > 1) return false;

  auto* dst = static_cast<int32_t*>(dest);
  const int32x4_t v_src0 = vdupq_n_s32(dst[0]);
  const uint32x4_t v_mask = vdupq_n_u32(should_round ? 0xffffffff : 0);
  const int32x4_t v_src_round =
      vqrdmulhq_n_s32(v_src0, kTransformRowMultiplier << (31 - 12));
  const int32x4_t v_src = vbslq_s32(v_mask, v_src_round, v_src0);
  // Combined rounding term for Round2(Round2(., 12), shift): 2^11 plus
  // 2^(11 + shift) when shift > 0. The shorthand (1 + (shift << 1)) << 11
  // only agrees with this for shift <= 2.
  const int32_t dual_round = (1 << 11) + (shift > 0 ? (1 << (11 + shift)) : 0);
  const int32x4_t v_dual_round = vdupq_n_s32(dual_round);
  const int32x4_t v_src_mult_lo =
      vmlaq_n_s32(v_dual_round, v_src, kIdentity16Multiplier);
  // A negative count makes vqshlq_s32 a truncating right shift.
  const int32x4_t dst_0 = vqshlq_s32(v_src_mult_lo, vdupq_n_s32(-(12 + shift)));
  vst1q_lane_s32(dst, vmovl_s16(vqmovn_s32(dst_0)), 0);
  return true;
}
1774 
// Identity32 row transform for tx_height == 16 on 4 rows of 32 coefficients,
// |step| elements apart. The identity32 scaling (x4) and the row shift (1)
// fold together: (((A * 4) + 1) >> 1) == (A * 2). The doubling saturates.
LIBGAV1_ALWAYS_INLINE void Identity32Row16_NEON(void* dest,
                                                const int32_t step) {
  auto* const dst = static_cast<int32_t*>(dest);

  for (int row = 0; row < 4; ++row) {
    int32_t* const row_ptr = dst + row * step;
    for (int col = 0; col < 32; col += 4) {
      const int32x4_t v_in = vld1q_s32(row_ptr + col);
      vst1q_s32(row_ptr + col, vqaddq_s32(v_in, v_in));
    }
  }
}
1790 
// DC-only fast path for the identity32 row transform. Returns false and leaves
// |dest| untouched when adjusted_tx_height > 1. Otherwise dst[0] is scaled by
// Round2(x * kTransformRowMultiplier, 12) and then doubled (saturating).
LIBGAV1_ALWAYS_INLINE bool Identity32DcOnly(void* dest,
                                            int adjusted_tx_height) {
  if (adjusted_tx_height > 1) return false;

  auto* dst = static_cast<int32_t*>(dest);
  const int32x2_t v_scaled =
      vqrdmulh_n_s32(vdup_n_s32(dst[0]), kTransformRowMultiplier << (31 - 12));
  // For tx_height == 16 the identity32 scaling (x4) and the row shift (1)
  // fold together: (((A * 4) + 1) >> 1) == (A * 2).
  vst1_lane_s32(dst, vqadd_s32(v_scaled, v_scaled), 0);
  return true;
}
1806 
1807 //------------------------------------------------------------------------------
1808 // Walsh Hadamard Transform.
1809 
// Inverse 4x4 Walsh-Hadamard transform: runs the 1-D WHT over all 4 rows and
// then all 4 columns of |source|, adds the result to the 4x4 block at |dst|
// and clamps each pixel to [0, (1 << kBitdepth10) - 1].
LIBGAV1_ALWAYS_INLINE void Wht4_NEON(uint16_t* LIBGAV1_RESTRICT dst,
                                     const int dst_stride,
                                     const void* LIBGAV1_RESTRICT source,
                                     const int adjusted_tx_height) {
  const auto* const src = static_cast<const int32_t*>(source);
  // out[r] holds output row r, one column per lane.
  int32x4_t out[4];

  if (adjusted_tx_height == 1) {
    // Only src[0] is nonzero, so the result has the form
    //       f   h   h   h
    //       g   i   i   i
    //       g   i   i   i
    //       g   i   i   i
    // The row pass (input pre-shifted by 2) turns row 0 into
    // [x - (x >> 1), x >> 1, x >> 1, x >> 1] with x = src[0] >> 2. The column
    // pass then applies the same pattern (without a pre-shift) to each column.
    const int32_t dc = src[0];
    const int32_t row0_col0 = (dc >> 2) - (dc >> 3);
    const int32_t f = row0_col0 - (row0_col0 >> 1);
    const int32_t g = row0_col0 >> 1;
    const int32_t h = (dc >> 3) - (dc >> 4);
    const int32_t i = dc >> 4;
    out[0] = vsetq_lane_s32(f, vdupq_n_s32(h), 0);
    out[1] = vsetq_lane_s32(g, vdupq_n_s32(i), 0);
    out[2] = out[1];
    out[3] = out[1];
  } else {
    // One 1-D WHT on 4 lanes in parallel. Inputs T[0..3] are passed in as
    // (a, c, d, b); outputs T[0..3] come back in (a, b, c, d).
    const auto wht4 = [](int32x4_t& a, int32x4_t& b, int32x4_t& c,
                         int32x4_t& d) {
      a = vaddq_s32(a, c);
      d = vsubq_s32(d, b);
      const int32x4_t e = vhsubq_s32(a, d);  // e = (a - d) >> 1
      b = vsubq_s32(e, b);
      c = vsubq_s32(e, c);
      a = vsubq_s32(a, b);
      d = vaddq_s32(d, c);
    };

    // De-interleaving load: lane r of cols.val[k] is src[4 * r + k], so each
    // lane carries one row through the row pass.
    const int32x4x4_t cols = vld4q_s32(src);
    int32x4_t a = vshrq_n_s32(cols.val[0], 2);
    int32x4_t c = vshrq_n_s32(cols.val[1], 2);
    int32x4_t d = vshrq_n_s32(cols.val[2], 2);
    int32x4_t b = vshrq_n_s32(cols.val[3], 2);
    wht4(a, b, c, d);

    // Transpose so that each lane carries one column through the column pass.
    const int32x4_t row_out[4] = {a, b, c, d};
    int32x4_t t[4];
    Transpose4x4(row_out, t);
    a = t[0];
    c = t[1];
    d = t[2];
    b = t[3];
    wht4(a, b, c, d);

    out[0] = a;
    out[1] = b;
    out[2] = c;
    out[3] = d;
  }

  // Add to the frame and clamp to the 10-bit pixel range.
  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
  for (const int32x4_t& residual : out) {
    const uint16x4_t pixels = vld1_u16(dst);
    const int32x4_t sum = vaddw_s16(residual, vreinterpret_s16_u16(pixels));
    vst1_u16(dst, vmin_u16(vqmovun_s32(sum), v_max_bitdepth));
    dst += dst_stride;
  }
}
1887 
1888 //------------------------------------------------------------------------------
1889 // row/column transform loops
1890 
// Reverses the order of the coefficients within each of the |tx_height| rows
// of |source| (row length |tx_width|), in place.
template <int tx_height>
LIBGAV1_ALWAYS_INLINE void FlipColumns(int32_t* source, int tx_width) {
  if (tx_width >= 16) {
    // Reverse each full row by swapping mirrored 4-lane chunks and reversing
    // the lanes inside each chunk. The previous version reversed every 16
    // values independently, which left rows of width 32/64 only partially
    // reversed. For width 16 the result is unchanged.
    for (int row = 0; row < tx_height; ++row) {
      int32_t* const p = &source[row * tx_width];
      for (int lo = 0, hi = tx_width - 4; lo < hi; lo += 4, hi -= 4) {
        // 00 01 02 03
        const int32x4_t a = vld1q_s32(&p[lo]);
        const int32x4_t b = vld1q_s32(&p[hi]);
        // 01 00 03 02
        const int32x4_t a_rev = vrev64q_s32(a);
        const int32x4_t b_rev = vrev64q_s32(b);
        // 03 02 01 00
        vst1q_s32(&p[lo], vextq_s32(b_rev, b_rev, 2));
        vst1q_s32(&p[hi], vextq_s32(a_rev, a_rev, 2));
      }
    }
  } else if (tx_width == 8) {
    for (int i = 0; i < 8 * tx_height; i += 8) {
      // 00 01 02 03
      const int32x4_t a = vld1q_s32(&source[i]);
      const int32x4_t b = vld1q_s32(&source[i + 4]);
      // 01 00 03 02
      const int32x4_t a_rev = vrev64q_s32(a);
      const int32x4_t b_rev = vrev64q_s32(b);
      // 03 02 01 00
      vst1q_s32(&source[i], vextq_s32(b_rev, b_rev, 2));
      vst1q_s32(&source[i + 4], vextq_s32(a_rev, a_rev, 2));
    }
  } else {
    // tx_width == 4: process two rows per iteration.
    for (int i = 0; i < 4 * tx_height; i += 8) {
      // 00 01 02 03
      const int32x4_t a = vld1q_s32(&source[i]);
      const int32x4_t b = vld1q_s32(&source[i + 4]);
      // 01 00 03 02
      const int32x4_t a_rev = vrev64q_s32(a);
      const int32x4_t b_rev = vrev64q_s32(b);
      // 03 02 01 00
      vst1q_s32(&source[i], vextq_s32(a_rev, a_rev, 2));
      vst1q_s32(&source[i + 4], vextq_s32(b_rev, b_rev, 2));
    }
  }
}
1940 
// Scales the first tx_width * num_rows coefficients of |source| in place by
// Round2(x * kTransformRowMultiplier, 12), with saturation. Works on 8
// coefficients per iteration (two rows when tx_width == 4).
template <int tx_width>
LIBGAV1_ALWAYS_INLINE void ApplyRounding(int32_t* source, int num_rows) {
  // vqrdmulhq_n_s32(x, m << 19) == (x * m + (1 << 11)) >> 12, saturated.
  const int32_t multiplier = kTransformRowMultiplier << (31 - 12);
  const int total = tx_width * num_rows;
  int offset = 0;
  do {
    int32_t* const p = source + offset;
    const int32x4_t lo = vld1q_s32(p);
    const int32x4_t hi = vld1q_s32(p + 4);
    vst1q_s32(p, vqrdmulhq_n_s32(lo, multiplier));
    vst1q_s32(p + 4, vqrdmulhq_n_s32(hi, multiplier));
    offset += 8;
  } while (offset < total);
}
1957 
// Applies a saturating, rounding right shift by |row_shift| to the first
// tx_width * num_rows coefficients of |source|, in place. Works on 8
// coefficients per iteration (two rows when tx_width == 4).
template <int tx_width>
LIBGAV1_ALWAYS_INLINE void RowShift(int32_t* source, int num_rows,
                                    int row_shift) {
  // A negative shift count makes vqrshlq_s32 shift right with rounding.
  const int32x4_t v_shift = vdupq_n_s32(-row_shift);
  const int total = tx_width * num_rows;
  int offset = 0;
  do {
    int32_t* const p = source + offset;
    const int32x4_t lo = vld1q_s32(p);
    const int32x4_t hi = vld1q_s32(p + 4);
    vst1q_s32(p, vqrshlq_s32(lo, v_shift));
    vst1q_s32(p + 4, vqrshlq_s32(hi, v_shift));
    offset += 8;
  } while (offset < total);
}
1974 
// Adds the residual block |source| (tx_height rows of tx_width values), after
// a rounding right shift by 4, to the frame at (start_x, start_y) and clamps
// each pixel to [0, (1 << kBitdepth10) - 1]. When |enable_flip_rows| is set
// and |tx_type| is in kTransformFlipRowsMask, source rows are read bottom-up.
template <int tx_height, bool enable_flip_rows = false>
LIBGAV1_ALWAYS_INLINE void StoreToFrameWithRound(
    Array2DView<uint16_t> frame, const int start_x, const int start_y,
    const int tx_width, const int32_t* LIBGAV1_RESTRICT source,
    TransformType tx_type) {
  const bool flip_rows =
      enable_flip_rows && kTransformFlipRowsMask.Contains(tx_type);

  if (tx_width == 4) {
    const int stride = frame.columns();
    const uint16x4_t v_max = vdup_n_u16((1 << kBitdepth10) - 1);
    uint16_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
    for (int i = 0; i < tx_height; ++i, dst += stride) {
      const int src_row = flip_rows ? tx_height - 1 - i : i;
      const int32x4_t scaled =
          vrshrq_n_s32(vld1q_s32(&source[src_row * 4]), 4);
      // Unsigned widening add gives the same bits as a signed add.
      const uint32x4_t sum =
          vaddw_u16(vreinterpretq_u32_s32(scaled), vld1_u16(dst));
      vst1_u16(dst,
               vmin_u16(vqmovun_s32(vreinterpretq_s32_u32(sum)), v_max));
    }
    return;
  }

  const uint16x8_t v_max = vdupq_n_u16((1 << kBitdepth10) - 1);
  for (int i = 0; i < tx_height; ++i) {
    const int src_row = flip_rows ? tx_height - 1 - i : i;
    const int32_t* const residual = &source[src_row * tx_width];
    uint16_t* const out = frame[start_y + i] + start_x;
    int j = 0;
    do {
      const uint16x8_t pixels = vld1q_u16(out + j);
      const int32x4_t lo = vrshrq_n_s32(vld1q_s32(&residual[j]), 4);
      const int32x4_t hi = vrshrq_n_s32(vld1q_s32(&residual[j + 4]), 4);
      const uint32x4_t sum_lo =
          vaddw_u16(vreinterpretq_u32_s32(lo), vget_low_u16(pixels));
      const uint32x4_t sum_hi =
          vaddw_u16(vreinterpretq_u32_s32(hi), vget_high_u16(pixels));
      const uint16x8_t narrowed =
          vcombine_u16(vqmovun_s32(vreinterpretq_s32_u32(sum_lo)),
                       vqmovun_s32(vreinterpretq_s32_u32(sum_hi)));
      vst1q_u16(out + j, vminq_u16(narrowed, v_max));
      j += 8;
    } while (j < tx_width);
  }
}
2021 
// Row pass of the 4-point inverse DCT for 4xN transforms. The DC-only case is
// handled separately. Otherwise the kTransformRowMultiplier rounding is applied
// when tx_height == 8, and Dct4_NEON runs on groups of 4 rows with a row shift
// of 1 when tx_height == 16 (0 otherwise).
void Dct4TransformLoopRow_NEON(TransformType /*tx_type*/, TransformSize tx_size,
                               int adjusted_tx_height, void* src_buffer,
                               int /*start_x*/, int /*start_y*/,
                               void* /*dst_frame*/) {
  auto* src = static_cast<int32_t*>(src_buffer);
  const int tx_height = kTransformHeight[tx_size];
  const bool should_round = (tx_height == 8);
  const int row_shift = static_cast<int>(tx_height == 16);

  if (DctDcOnly<4>(src, adjusted_tx_height, should_round, row_shift)) {
    return;
  }

  if (should_round) {
    ApplyRounding<4>(src, adjusted_tx_height);
  }

  // The loop below consumes 4 rows per iteration and stops only when the
  // remaining count reaches exactly zero.
  assert(adjusted_tx_height % 4 == 0);
  // Process 4 1d dct4 rows in parallel per iteration.
  int i = adjusted_tx_height;
  auto* data = src;
  do {
    Dct4_NEON<ButterflyRotation_4>(data, /*step=*/4, /*is_row=*/true,
                                   row_shift);
    data += 16;
    i -= 4;
  } while (i != 0);
}
2049 
// Column pass of the 4-point inverse DCT. Optionally flips columns, runs
// Dct4_NEON on groups of 4 columns (unless the DC-only column path handled the
// block), then adds the result to the frame.
void Dct4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
                                  int adjusted_tx_height,
                                  void* LIBGAV1_RESTRICT src_buffer,
                                  int start_x, int start_y,
                                  void* LIBGAV1_RESTRICT dst_frame) {
  auto* src = static_cast<int32_t*>(src_buffer);
  const int tx_width = kTransformWidth[tx_size];

  if (kTransformFlipColumnsMask.Contains(tx_type)) {
    FlipColumns<4>(src, tx_width);
  }

  if (!DctDcOnlyColumn<4>(src, adjusted_tx_height, tx_width)) {
    // Process 4 1d dct4 columns in parallel per iteration.
    int i = tx_width;
    auto* data = src;
    do {
      // The third argument is is_row (the row loop passes /*is_row=*/true).
      Dct4_NEON<ButterflyRotation_4>(data, tx_width, /*is_row=*/false,
                                     /*row_shift=*/0);
      data += 4;
      i -= 4;
    } while (i != 0);
  }

  auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
  StoreToFrameWithRound<4>(frame, start_x, start_y, tx_width, src, tx_type);
}
2077 
// Row pass of the 8-point inverse DCT. The DC-only case is handled separately.
// Otherwise the kTransformRowMultiplier rounding is applied when kShouldRound
// says so, and Dct8_NEON runs on groups of 4 rows with the size's row shift.
void Dct8TransformLoopRow_NEON(TransformType /*tx_type*/, TransformSize tx_size,
                               int adjusted_tx_height, void* src_buffer,
                               int /*start_x*/, int /*start_y*/,
                               void* /*dst_frame*/) {
  auto* src = static_cast<int32_t*>(src_buffer);
  const bool should_round = kShouldRound[tx_size];
  const uint8_t row_shift = kTransformRowShift[tx_size];

  if (DctDcOnly<8>(src, adjusted_tx_height, should_round, row_shift)) {
    return;
  }

  if (should_round) {
    ApplyRounding<8>(src, adjusted_tx_height);
  }

  // The loop below consumes 4 rows per iteration and stops only when the
  // remaining count reaches exactly zero.
  assert(adjusted_tx_height % 4 == 0);
  // Process 4 1d dct8 rows in parallel per iteration.
  int i = adjusted_tx_height;
  auto* data = src;
  do {
    Dct8_NEON<ButterflyRotation_4>(data, /*step=*/8, /*is_row=*/true,
                                   row_shift);
    data += 32;
    i -= 4;
  } while (i != 0);
}
2104 
// Column pass of the 8-point inverse DCT. Optionally flips columns, runs
// Dct8_NEON on groups of 4 columns (unless the DC-only column path handled the
// block), then adds the result to the frame.
void Dct8TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
                                  int adjusted_tx_height,
                                  void* LIBGAV1_RESTRICT src_buffer,
                                  int start_x, int start_y,
                                  void* LIBGAV1_RESTRICT dst_frame) {
  auto* src = static_cast<int32_t*>(src_buffer);
  const int tx_width = kTransformWidth[tx_size];

  if (kTransformFlipColumnsMask.Contains(tx_type)) FlipColumns<8>(src, tx_width);

  const bool dc_only = DctDcOnlyColumn<8>(src, adjusted_tx_height, tx_width);
  if (!dc_only) {
    // Each call handles 4 adjacent columns; tx_width is passed as the step.
    int32_t* column = src;
    const int32_t* const columns_end = src + tx_width;
    do {
      Dct8_NEON<ButterflyRotation_4>(column, tx_width, /*is_row=*/false,
                                     /*row_shift=*/0);
      column += 4;
    } while (column != columns_end);
  }

  auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
  StoreToFrameWithRound<8>(frame, start_x, start_y, tx_width, src, tx_type);
}
2131 
// Row pass of the 16-point inverse DCT. The DC-only case is handled separately.
// Otherwise the kTransformRowMultiplier rounding is applied when kShouldRound
// says so, and Dct16_NEON runs on groups of 4 rows with the size's row shift.
void Dct16TransformLoopRow_NEON(TransformType /*tx_type*/,
                                TransformSize tx_size, int adjusted_tx_height,
                                void* src_buffer, int /*start_x*/,
                                int /*start_y*/, void* /*dst_frame*/) {
  auto* src = static_cast<int32_t*>(src_buffer);
  const bool should_round = kShouldRound[tx_size];
  const uint8_t row_shift = kTransformRowShift[tx_size];

  if (DctDcOnly<16>(src, adjusted_tx_height, should_round, row_shift)) return;

  if (should_round) ApplyRounding<16>(src, adjusted_tx_height);

  assert(adjusted_tx_height % 4 == 0);
  // Each call transforms a group of 4 rows (4 * 16 coefficients).
  int32_t* group = src;
  const int32_t* const groups_end = src + adjusted_tx_height * 16;
  do {
    Dct16_NEON<ButterflyRotation_4>(group, 16, /*is_row=*/true, row_shift);
    group += 4 * 16;
  } while (group != groups_end);
}
2158 
// Column pass of the 16-point inverse DCT. Optionally flips columns, runs
// Dct16_NEON on groups of 4 columns (unless the DC-only column path handled
// the block), then adds the result to the frame.
void Dct16TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
                                   int adjusted_tx_height,
                                   void* LIBGAV1_RESTRICT src_buffer,
                                   int start_x, int start_y,
                                   void* LIBGAV1_RESTRICT dst_frame) {
  auto* src = static_cast<int32_t*>(src_buffer);
  const int tx_width = kTransformWidth[tx_size];

  if (kTransformFlipColumnsMask.Contains(tx_type)) {
    FlipColumns<16>(src, tx_width);
  }

  const bool dc_only = DctDcOnlyColumn<16>(src, adjusted_tx_height, tx_width);
  if (!dc_only) {
    // Each call handles 4 adjacent columns; tx_width is passed as the step.
    int32_t* column = src;
    const int32_t* const columns_end = src + tx_width;
    do {
      Dct16_NEON<ButterflyRotation_4>(column, tx_width, /*is_row=*/false,
                                      /*row_shift=*/0);
      column += 4;
    } while (column != columns_end);
  }

  auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
  StoreToFrameWithRound<16>(frame, start_x, start_y, tx_width, src, tx_type);
}
2185 
// Row pass of the 32-point inverse DCT. The DC-only case is handled separately.
// Otherwise the kTransformRowMultiplier rounding is applied when kShouldRound
// says so, and Dct32_NEON runs on groups of 4 rows with the size's row shift.
void Dct32TransformLoopRow_NEON(TransformType /*tx_type*/,
                                TransformSize tx_size, int adjusted_tx_height,
                                void* src_buffer, int /*start_x*/,
                                int /*start_y*/, void* /*dst_frame*/) {
  auto* src = static_cast<int32_t*>(src_buffer);
  const bool should_round = kShouldRound[tx_size];
  const uint8_t row_shift = kTransformRowShift[tx_size];

  if (DctDcOnly<32>(src, adjusted_tx_height, should_round, row_shift)) return;

  if (should_round) ApplyRounding<32>(src, adjusted_tx_height);

  assert(adjusted_tx_height % 4 == 0);
  // Each call transforms a group of 4 rows (4 * 32 coefficients).
  int32_t* group = src;
  const int32_t* const groups_end = src + adjusted_tx_height * 32;
  do {
    Dct32_NEON(group, 32, /*is_row=*/true, row_shift);
    group += 4 * 32;
  } while (group != groups_end);
}
2212 
// Column pass of the 32-point inverse DCT. Optionally flips columns, runs
// Dct32_NEON on groups of 4 columns (unless the DC-only column path handled
// the block), then adds the result to the frame.
void Dct32TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
                                   int adjusted_tx_height,
                                   void* LIBGAV1_RESTRICT src_buffer,
                                   int start_x, int start_y,
                                   void* LIBGAV1_RESTRICT dst_frame) {
  auto* src = static_cast<int32_t*>(src_buffer);
  const int tx_width = kTransformWidth[tx_size];

  if (kTransformFlipColumnsMask.Contains(tx_type)) {
    FlipColumns<32>(src, tx_width);
  }

  const bool dc_only = DctDcOnlyColumn<32>(src, adjusted_tx_height, tx_width);
  if (!dc_only) {
    // Each call handles 4 adjacent columns; tx_width is passed as the step.
    int32_t* column = src;
    const int32_t* const columns_end = src + tx_width;
    do {
      Dct32_NEON(column, tx_width, /*is_row=*/false, /*row_shift=*/0);
      column += 4;
    } while (column != columns_end);
  }

  auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
  StoreToFrameWithRound<32>(frame, start_x, start_y, tx_width, src, tx_type);
}
2238 
// Row pass of the 64-point inverse DCT. The DC-only case is handled separately.
// Otherwise the kTransformRowMultiplier rounding is applied when kShouldRound
// says so, and Dct64_NEON runs on groups of 4 rows with the size's row shift.
void Dct64TransformLoopRow_NEON(TransformType /*tx_type*/,
                                TransformSize tx_size, int adjusted_tx_height,
                                void* src_buffer, int /*start_x*/,
                                int /*start_y*/, void* /*dst_frame*/) {
  auto* src = static_cast<int32_t*>(src_buffer);
  const bool should_round = kShouldRound[tx_size];
  const uint8_t row_shift = kTransformRowShift[tx_size];

  if (DctDcOnly<64>(src, adjusted_tx_height, should_round, row_shift)) return;

  if (should_round) ApplyRounding<64>(src, adjusted_tx_height);

  assert(adjusted_tx_height % 4 == 0);
  // Each call transforms a group of 4 rows (4 * 64 coefficients).
  int32_t* group = src;
  const int32_t* const groups_end = src + adjusted_tx_height * 64;
  do {
    Dct64_NEON(group, 64, /*is_row=*/true, row_shift);
    group += 4 * 64;
  } while (group != groups_end);
}
2265 
// Column pass of the 64-point inverse DCT. Optionally flips columns, runs
// Dct64_NEON on groups of 4 columns (unless the DC-only column path handled
// the block), then adds the result to the frame.
void Dct64TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
                                   int adjusted_tx_height,
                                   void* LIBGAV1_RESTRICT src_buffer,
                                   int start_x, int start_y,
                                   void* LIBGAV1_RESTRICT dst_frame) {
  auto* src = static_cast<int32_t*>(src_buffer);
  const int tx_width = kTransformWidth[tx_size];

  if (kTransformFlipColumnsMask.Contains(tx_type)) {
    FlipColumns<64>(src, tx_width);
  }

  const bool dc_only = DctDcOnlyColumn<64>(src, adjusted_tx_height, tx_width);
  if (!dc_only) {
    // Each call handles 4 adjacent columns; tx_width is passed as the step.
    int32_t* column = src;
    const int32_t* const columns_end = src + tx_width;
    do {
      Dct64_NEON(column, tx_width, /*is_row=*/false, /*row_shift=*/0);
      column += 4;
    } while (column != columns_end);
  }

  auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
  StoreToFrameWithRound<64>(frame, start_x, start_y, tx_width, src, tx_type);
}
2291 
// Row pass of the 4-point inverse ADST for 4xN transforms. The DC-only case is
// handled separately. Otherwise the kTransformRowMultiplier rounding is applied
// when tx_height == 8, and Adst4_NEON runs on groups of 4 rows with a row shift
// of 1 when tx_height == 16 (0 otherwise).
void Adst4TransformLoopRow_NEON(TransformType /*tx_type*/,
                                TransformSize tx_size, int adjusted_tx_height,
                                void* src_buffer, int /*start_x*/,
                                int /*start_y*/, void* /*dst_frame*/) {
  auto* src = static_cast<int32_t*>(src_buffer);
  const int tx_height = kTransformHeight[tx_size];
  const int row_shift = static_cast<int>(tx_height == 16);
  const bool should_round = (tx_height == 8);

  if (Adst4DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
    return;
  }

  if (should_round) {
    ApplyRounding<4>(src, adjusted_tx_height);
  }

  // The loop below consumes 4 rows per iteration and stops only when the
  // remaining count reaches exactly zero.
  assert(adjusted_tx_height % 4 == 0);
  // Process 4 1d adst4 rows in parallel per iteration.
  int i = adjusted_tx_height;
  auto* data = src;
  do {
    Adst4_NEON(data, /*step=*/4, /*is_row=*/true, row_shift);
    data += 16;
    i -= 4;
  } while (i != 0);
}
2318 
Adst4TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2319 void Adst4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
2320                                    int adjusted_tx_height,
2321                                    void* LIBGAV1_RESTRICT src_buffer,
2322                                    int start_x, int start_y,
2323                                    void* LIBGAV1_RESTRICT dst_frame) {
2324   auto* src = static_cast<int32_t*>(src_buffer);
2325   const int tx_width = kTransformWidth[tx_size];
2326 
2327   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2328     FlipColumns<4>(src, tx_width);
2329   }
2330 
2331   if (!Adst4DcOnlyColumn(src, adjusted_tx_height, tx_width)) {
2332     // Process 4 1d adst4 columns in parallel per iteration.
2333     int i = tx_width;
2334     auto* data = src;
2335     do {
2336       Adst4_NEON(data, tx_width, /*is_row=*/false, /*row_shift=*/0);
2337       data += 4;
2338       i -= 4;
2339     } while (i != 0);
2340   }
2341 
2342   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2343   StoreToFrameWithRound<4, /*enable_flip_rows=*/true>(frame, start_x, start_y,
2344                                                       tx_width, src, tx_type);
2345 }
2346 
Adst8TransformLoopRow_NEON(TransformType,TransformSize tx_size,int adjusted_tx_height,void * src_buffer,int,int,void *)2347 void Adst8TransformLoopRow_NEON(TransformType /*tx_type*/,
2348                                 TransformSize tx_size, int adjusted_tx_height,
2349                                 void* src_buffer, int /*start_x*/,
2350                                 int /*start_y*/, void* /*dst_frame*/) {
2351   auto* src = static_cast<int32_t*>(src_buffer);
2352   const bool should_round = kShouldRound[tx_size];
2353   const uint8_t row_shift = kTransformRowShift[tx_size];
2354 
2355   if (Adst8DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
2356     return;
2357   }
2358 
2359   if (should_round) {
2360     ApplyRounding<8>(src, adjusted_tx_height);
2361   }
2362 
2363   // Process 4 1d adst8 rows in parallel per iteration.
2364   assert(adjusted_tx_height % 4 == 0);
2365   int i = adjusted_tx_height;
2366   auto* data = src;
2367   do {
2368     Adst8_NEON<ButterflyRotation_4>(data, /*step=*/8,
2369                                     /*transpose=*/true, row_shift);
2370     data += 32;
2371     i -= 4;
2372   } while (i != 0);
2373 }
2374 
Adst8TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2375 void Adst8TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
2376                                    int adjusted_tx_height,
2377                                    void* LIBGAV1_RESTRICT src_buffer,
2378                                    int start_x, int start_y,
2379                                    void* LIBGAV1_RESTRICT dst_frame) {
2380   auto* src = static_cast<int32_t*>(src_buffer);
2381   const int tx_width = kTransformWidth[tx_size];
2382 
2383   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2384     FlipColumns<8>(src, tx_width);
2385   }
2386 
2387   if (!Adst8DcOnlyColumn(src, adjusted_tx_height, tx_width)) {
2388     // Process 4 1d adst8 columns in parallel per iteration.
2389     int i = tx_width;
2390     auto* data = src;
2391     do {
2392       Adst8_NEON<ButterflyRotation_4>(data, tx_width, /*transpose=*/false,
2393                                       /*row_shift=*/0);
2394       data += 4;
2395       i -= 4;
2396     } while (i != 0);
2397   }
2398   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2399   StoreToFrameWithRound<8, /*enable_flip_rows=*/true>(frame, start_x, start_y,
2400                                                       tx_width, src, tx_type);
2401 }
2402 
Adst16TransformLoopRow_NEON(TransformType,TransformSize tx_size,int adjusted_tx_height,void * src_buffer,int,int,void *)2403 void Adst16TransformLoopRow_NEON(TransformType /*tx_type*/,
2404                                  TransformSize tx_size, int adjusted_tx_height,
2405                                  void* src_buffer, int /*start_x*/,
2406                                  int /*start_y*/, void* /*dst_frame*/) {
2407   auto* src = static_cast<int32_t*>(src_buffer);
2408   const bool should_round = kShouldRound[tx_size];
2409   const uint8_t row_shift = kTransformRowShift[tx_size];
2410 
2411   if (Adst16DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
2412     return;
2413   }
2414 
2415   if (should_round) {
2416     ApplyRounding<16>(src, adjusted_tx_height);
2417   }
2418 
2419   assert(adjusted_tx_height % 4 == 0);
2420   int i = adjusted_tx_height;
2421   do {
2422     // Process 4 1d adst16 rows in parallel per iteration.
2423     Adst16_NEON<ButterflyRotation_4>(src, 16, /*is_row=*/true, row_shift);
2424     src += 64;
2425     i -= 4;
2426   } while (i != 0);
2427 }
2428 
Adst16TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2429 void Adst16TransformLoopColumn_NEON(TransformType tx_type,
2430                                     TransformSize tx_size,
2431                                     int adjusted_tx_height,
2432                                     void* LIBGAV1_RESTRICT src_buffer,
2433                                     int start_x, int start_y,
2434                                     void* LIBGAV1_RESTRICT dst_frame) {
2435   auto* src = static_cast<int32_t*>(src_buffer);
2436   const int tx_width = kTransformWidth[tx_size];
2437 
2438   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2439     FlipColumns<16>(src, tx_width);
2440   }
2441 
2442   if (!Adst16DcOnlyColumn(src, adjusted_tx_height, tx_width)) {
2443     int i = tx_width;
2444     auto* data = src;
2445     do {
2446       // Process 4 1d adst16 columns in parallel per iteration.
2447       Adst16_NEON<ButterflyRotation_4>(data, tx_width, /*is_row=*/false,
2448                                        /*row_shift=*/0);
2449       data += 4;
2450       i -= 4;
2451     } while (i != 0);
2452   }
2453   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2454   StoreToFrameWithRound<16, /*enable_flip_rows=*/true>(frame, start_x, start_y,
2455                                                        tx_width, src, tx_type);
2456 }
2457 
Identity4TransformLoopRow_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * src_buffer,int,int,void *)2458 void Identity4TransformLoopRow_NEON(TransformType tx_type,
2459                                     TransformSize tx_size,
2460                                     int adjusted_tx_height, void* src_buffer,
2461                                     int /*start_x*/, int /*start_y*/,
2462                                     void* /*dst_frame*/) {
2463   // Special case: Process row calculations during column transform call.
2464   // Improves performance.
2465   if (tx_type == kTransformTypeIdentityIdentity &&
2466       tx_size == kTransformSize4x4) {
2467     return;
2468   }
2469 
2470   auto* src = static_cast<int32_t*>(src_buffer);
2471   const int tx_height = kTransformHeight[tx_size];
2472   const bool should_round = (tx_height == 8);
2473 
2474   if (Identity4DcOnly(src, adjusted_tx_height, should_round, tx_height)) {
2475     return;
2476   }
2477 
2478   if (should_round) {
2479     ApplyRounding<4>(src, adjusted_tx_height);
2480   }
2481 
2482   const int shift = tx_height > 8 ? 1 : 0;
2483   int i = adjusted_tx_height;
2484   do {
2485     Identity4_NEON(src, /*step=*/4, shift);
2486     src += 16;
2487     i -= 4;
2488   } while (i != 0);
2489 }
2490 
Identity4TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2491 void Identity4TransformLoopColumn_NEON(TransformType tx_type,
2492                                        TransformSize tx_size,
2493                                        int adjusted_tx_height,
2494                                        void* LIBGAV1_RESTRICT src_buffer,
2495                                        int start_x, int start_y,
2496                                        void* LIBGAV1_RESTRICT dst_frame) {
2497   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2498   auto* src = static_cast<int32_t*>(src_buffer);
2499   const int tx_width = kTransformWidth[tx_size];
2500 
2501   // Special case: Process row calculations during column transform call.
2502   if (tx_type == kTransformTypeIdentityIdentity &&
2503       (tx_size == kTransformSize4x4 || tx_size == kTransformSize8x4)) {
2504     Identity4RowColumnStoreToFrame(frame, start_x, start_y, tx_width,
2505                                    adjusted_tx_height, src);
2506     return;
2507   }
2508 
2509   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2510     FlipColumns<4>(src, tx_width);
2511   }
2512 
2513   IdentityColumnStoreToFrame<4>(frame, start_x, start_y, tx_width,
2514                                 adjusted_tx_height, src);
2515 }
2516 
Identity8TransformLoopRow_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * src_buffer,int,int,void *)2517 void Identity8TransformLoopRow_NEON(TransformType tx_type,
2518                                     TransformSize tx_size,
2519                                     int adjusted_tx_height, void* src_buffer,
2520                                     int /*start_x*/, int /*start_y*/,
2521                                     void* /*dst_frame*/) {
2522   // Special case: Process row calculations during column transform call.
2523   // Improves performance.
2524   if (tx_type == kTransformTypeIdentityIdentity &&
2525       tx_size == kTransformSize8x4) {
2526     return;
2527   }
2528 
2529   auto* src = static_cast<int32_t*>(src_buffer);
2530   const int tx_height = kTransformHeight[tx_size];
2531   const bool should_round = kShouldRound[tx_size];
2532   const uint8_t row_shift = kTransformRowShift[tx_size];
2533 
2534   if (Identity8DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
2535     return;
2536   }
2537   if (should_round) {
2538     ApplyRounding<8>(src, adjusted_tx_height);
2539   }
2540 
2541   // When combining the identity8 multiplier with the row shift, the
2542   // calculations for tx_height == 8 and tx_height == 16 can be simplified
2543   // from ((A * 2) + 1) >> 1) to A. For 10bpp, A must be clamped to a signed 16
2544   // bit value.
2545   if ((tx_height & 0x18) != 0) {
2546     for (int i = 0; i < tx_height; ++i) {
2547       const int32x4_t v_src_lo = vld1q_s32(&src[i * 8]);
2548       const int32x4_t v_src_hi = vld1q_s32(&src[(i * 8) + 4]);
2549       vst1q_s32(&src[i * 8], vmovl_s16(vqmovn_s32(v_src_lo)));
2550       vst1q_s32(&src[(i * 8) + 4], vmovl_s16(vqmovn_s32(v_src_hi)));
2551     }
2552     return;
2553   }
2554   if (tx_height == 32) {
2555     int i = adjusted_tx_height;
2556     do {
2557       Identity8Row32_NEON(src, /*step=*/8);
2558       src += 32;
2559       i -= 4;
2560     } while (i != 0);
2561     return;
2562   }
2563 
2564   assert(tx_size == kTransformSize8x4);
2565   int i = adjusted_tx_height;
2566   do {
2567     Identity8Row4_NEON(src, /*step=*/8);
2568     src += 32;
2569     i -= 4;
2570   } while (i != 0);
2571 }
2572 
Identity8TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2573 void Identity8TransformLoopColumn_NEON(TransformType tx_type,
2574                                        TransformSize tx_size,
2575                                        int adjusted_tx_height,
2576                                        void* LIBGAV1_RESTRICT src_buffer,
2577                                        int start_x, int start_y,
2578                                        void* LIBGAV1_RESTRICT dst_frame) {
2579   auto* src = static_cast<int32_t*>(src_buffer);
2580   const int tx_width = kTransformWidth[tx_size];
2581 
2582   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2583     FlipColumns<8>(src, tx_width);
2584   }
2585   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2586   IdentityColumnStoreToFrame<8>(frame, start_x, start_y, tx_width,
2587                                 adjusted_tx_height, src);
2588 }
2589 
Identity16TransformLoopRow_NEON(TransformType,TransformSize tx_size,int adjusted_tx_height,void * src_buffer,int,int,void *)2590 void Identity16TransformLoopRow_NEON(TransformType /*tx_type*/,
2591                                      TransformSize tx_size,
2592                                      int adjusted_tx_height, void* src_buffer,
2593                                      int /*start_x*/, int /*start_y*/,
2594                                      void* /*dst_frame*/) {
2595   auto* src = static_cast<int32_t*>(src_buffer);
2596   const bool should_round = kShouldRound[tx_size];
2597   const uint8_t row_shift = kTransformRowShift[tx_size];
2598 
2599   if (Identity16DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
2600     return;
2601   }
2602 
2603   if (should_round) {
2604     ApplyRounding<16>(src, adjusted_tx_height);
2605   }
2606   int i = adjusted_tx_height;
2607   do {
2608     Identity16Row_NEON(src, /*step=*/16, row_shift);
2609     src += 64;
2610     i -= 4;
2611   } while (i != 0);
2612 }
2613 
Identity16TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2614 void Identity16TransformLoopColumn_NEON(TransformType tx_type,
2615                                         TransformSize tx_size,
2616                                         int adjusted_tx_height,
2617                                         void* LIBGAV1_RESTRICT src_buffer,
2618                                         int start_x, int start_y,
2619                                         void* LIBGAV1_RESTRICT dst_frame) {
2620   auto* src = static_cast<int32_t*>(src_buffer);
2621   const int tx_width = kTransformWidth[tx_size];
2622 
2623   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2624     FlipColumns<16>(src, tx_width);
2625   }
2626   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2627   IdentityColumnStoreToFrame<16>(frame, start_x, start_y, tx_width,
2628                                  adjusted_tx_height, src);
2629 }
2630 
Identity32TransformLoopRow_NEON(TransformType,TransformSize tx_size,int adjusted_tx_height,void * src_buffer,int,int,void *)2631 void Identity32TransformLoopRow_NEON(TransformType /*tx_type*/,
2632                                      TransformSize tx_size,
2633                                      int adjusted_tx_height, void* src_buffer,
2634                                      int /*start_x*/, int /*start_y*/,
2635                                      void* /*dst_frame*/) {
2636   const int tx_height = kTransformHeight[tx_size];
2637 
2638   // When combining the identity32 multiplier with the row shift, the
2639   // calculations for tx_height == 8 and tx_height == 32 can be simplified
2640   // from ((A * 4) + 2) >> 2) to A.
2641   if ((tx_height & 0x28) != 0) {
2642     return;
2643   }
2644 
2645   // Process kTransformSize32x16. The src is always rounded before the identity
2646   // transform and shifted by 1 afterwards.
2647   auto* src = static_cast<int32_t*>(src_buffer);
2648   if (Identity32DcOnly(src, adjusted_tx_height)) {
2649     return;
2650   }
2651 
2652   assert(tx_size == kTransformSize32x16);
2653   ApplyRounding<32>(src, adjusted_tx_height);
2654   int i = adjusted_tx_height;
2655   do {
2656     Identity32Row16_NEON(src, /*step=*/32);
2657     src += 128;
2658     i -= 4;
2659   } while (i != 0);
2660 }
2661 
Identity32TransformLoopColumn_NEON(TransformType,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2662 void Identity32TransformLoopColumn_NEON(TransformType /*tx_type*/,
2663                                         TransformSize tx_size,
2664                                         int adjusted_tx_height,
2665                                         void* LIBGAV1_RESTRICT src_buffer,
2666                                         int start_x, int start_y,
2667                                         void* LIBGAV1_RESTRICT dst_frame) {
2668   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2669   auto* src = static_cast<int32_t*>(src_buffer);
2670   const int tx_width = kTransformWidth[tx_size];
2671 
2672   IdentityColumnStoreToFrame<32>(frame, start_x, start_y, tx_width,
2673                                  adjusted_tx_height, src);
2674 }
2675 
Wht4TransformLoopRow_NEON(TransformType tx_type,TransformSize tx_size,int,void *,int,int,void *)2676 void Wht4TransformLoopRow_NEON(TransformType tx_type, TransformSize tx_size,
2677                                int /*adjusted_tx_height*/, void* /*src_buffer*/,
2678                                int /*start_x*/, int /*start_y*/,
2679                                void* /*dst_frame*/) {
2680   assert(tx_type == kTransformTypeDctDct);
2681   assert(tx_size == kTransformSize4x4);
2682   static_cast<void>(tx_type);
2683   static_cast<void>(tx_size);
2684   // Do both row and column transforms in the column-transform pass.
2685 }
2686 
Wht4TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2687 void Wht4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
2688                                   int adjusted_tx_height,
2689                                   void* LIBGAV1_RESTRICT src_buffer,
2690                                   int start_x, int start_y,
2691                                   void* LIBGAV1_RESTRICT dst_frame) {
2692   assert(tx_type == kTransformTypeDctDct);
2693   assert(tx_size == kTransformSize4x4);
2694   static_cast<void>(tx_type);
2695   static_cast<void>(tx_size);
2696 
2697   // Process 4 1d wht4 rows and columns in parallel.
2698   const auto* src = static_cast<int32_t*>(src_buffer);
2699   auto& frame = *static_cast<Array2DView<uint16_t>*>(dst_frame);
2700   uint16_t* dst = frame[start_y] + start_x;
2701   const int dst_stride = frame.columns();
2702   Wht4_NEON(dst, dst_stride, src, adjusted_tx_height);
2703 }
2704 
2705 //------------------------------------------------------------------------------
2706 
Init10bpp()2707 void Init10bpp() {
2708   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
2709   assert(dsp != nullptr);
2710   // Maximum transform size for Dct is 64.
2711   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kRow] =
2712       Dct4TransformLoopRow_NEON;
2713   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kColumn] =
2714       Dct4TransformLoopColumn_NEON;
2715   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kRow] =
2716       Dct8TransformLoopRow_NEON;
2717   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kColumn] =
2718       Dct8TransformLoopColumn_NEON;
2719   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kRow] =
2720       Dct16TransformLoopRow_NEON;
2721   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kColumn] =
2722       Dct16TransformLoopColumn_NEON;
2723   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kRow] =
2724       Dct32TransformLoopRow_NEON;
2725   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kColumn] =
2726       Dct32TransformLoopColumn_NEON;
2727   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kRow] =
2728       Dct64TransformLoopRow_NEON;
2729   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kColumn] =
2730       Dct64TransformLoopColumn_NEON;
2731 
2732   // Maximum transform size for Adst is 16.
2733   dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kRow] =
2734       Adst4TransformLoopRow_NEON;
2735   dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kColumn] =
2736       Adst4TransformLoopColumn_NEON;
2737   dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kRow] =
2738       Adst8TransformLoopRow_NEON;
2739   dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kColumn] =
2740       Adst8TransformLoopColumn_NEON;
2741   dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kRow] =
2742       Adst16TransformLoopRow_NEON;
2743   dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kColumn] =
2744       Adst16TransformLoopColumn_NEON;
2745 
2746   // Maximum transform size for Identity transform is 32.
2747   dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kRow] =
2748       Identity4TransformLoopRow_NEON;
2749   dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kColumn] =
2750       Identity4TransformLoopColumn_NEON;
2751   dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kRow] =
2752       Identity8TransformLoopRow_NEON;
2753   dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kColumn] =
2754       Identity8TransformLoopColumn_NEON;
2755   dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kRow] =
2756       Identity16TransformLoopRow_NEON;
2757   dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kColumn] =
2758       Identity16TransformLoopColumn_NEON;
2759   dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kRow] =
2760       Identity32TransformLoopRow_NEON;
2761   dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kColumn] =
2762       Identity32TransformLoopColumn_NEON;
2763 
2764   // Maximum transform size for Wht is 4.
2765   dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kRow] =
2766       Wht4TransformLoopRow_NEON;
2767   dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kColumn] =
2768       Wht4TransformLoopColumn_NEON;
2769 }
2770 
2771 }  // namespace
2772 
InverseTransformInit10bpp_NEON()2773 void InverseTransformInit10bpp_NEON() { Init10bpp(); }
2774 
2775 }  // namespace dsp
2776 }  // namespace libgav1
2777 #else   // !LIBGAV1_ENABLE_NEON || LIBGAV1_MAX_BITDEPTH < 10
namespace libgav1 {
namespace dsp {

// Fallback used when NEON is disabled or the maximum bitdepth is below 10:
// there are no NEON 10bpp inverse transforms to install.
void InverseTransformInit10bpp_NEON() {}

}  // namespace dsp
}  // namespace libgav1
2785 #endif  // LIBGAV1_ENABLE_NEON && LIBGAV1_MAX_BITDEPTH >= 10
2786