1 /*
2 * Copyright(c) 2019 Intel Corporation
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at https://www.aomedia.org/license/software-license. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at https://www.aomedia.org/license/patent-license.
10 */
11 
12 #include "EbDefinitions.h"
13 
14 #if EN_AVX512_SUPPORT
15 #include <assert.h>
16 #include "aom_dsp_rtcd.h"
17 #include "EbTransforms.h"
18 #include <immintrin.h>
19 #include "transpose_encoder_avx512.h"
20 
// Cosine/sine coefficient tables (defined in common transform code), selected
// by the per-pass cos-bit precision.
const int32_t *cospi_arr(int32_t n);
const int32_t *sinpi_arr(int32_t n);

void av1_transform_config(TxType tx_type, TxSize tx_size, Txfm2dFlipCfg *cfg);

// Signature shared by the 1-D forward transform kernels in this file:
// `bit` is the rounding shift, `num_cols` the number of 16-lane column groups.
typedef void (*fwd_transform_1d_avx512)(const __m512i *in, __m512i *out, const int8_t bit,
                                        const int32_t num_cols);
28 
// Butterfly rotation on 16 int32 lanes:
//   out0 = (in0*ww0 + in1*ww1 + r) >> bit
//   out1 = (in0*ww1 - in1*ww0 + r) >> bit
// `r` is the pre-broadcast rounding constant (1 << (bit - 1)).
#define btf_32_type0_avx512_new(ww0, ww1, in0, in1, out0, out1, r, bit) \
    do {                                                                \
        const __m512i in0_w0 = _mm512_mullo_epi32(in0, ww0);            \
        const __m512i in1_w1 = _mm512_mullo_epi32(in1, ww1);            \
        out0                 = _mm512_add_epi32(in0_w0, in1_w1);        \
        out0                 = _mm512_add_epi32(out0, r);               \
        out0                 = _mm512_srai_epi32(out0, (uint8_t)bit);   \
        const __m512i in0_w1 = _mm512_mullo_epi32(in0, ww1);            \
        const __m512i in1_w0 = _mm512_mullo_epi32(in1, ww0);            \
        out1                 = _mm512_sub_epi32(in0_w1, in1_w0);        \
        out1                 = _mm512_add_epi32(out1, r);               \
        out1                 = _mm512_srai_epi32(out1, (uint8_t)bit);   \
    } while (0)

// out0 = in0*w0 + in1*w1
// out1 = in1*w0 - in0*w1
// Implemented by handing type0 the weights and inputs with both pairs swapped.
#define btf_32_type1_avx512_new(ww0, ww1, in0, in1, out0, out1, r, bit) \
    do { btf_32_type0_avx512_new(ww1, ww0, in1, in0, out0, out1, r, bit); } while (0)
47 
// Per-transform-size shift tables (the fwd_shift_* arrays are defined in common
// code), indexed by TxSize. Each table supplies the bit shifts applied before,
// between, and after the two 1-D transform passes.
static const int8_t *fwd_txfm_shift_ls[TX_SIZES_ALL] = {
    fwd_shift_4x4,   fwd_shift_8x8,   fwd_shift_16x16, fwd_shift_32x32, fwd_shift_64x64,
    fwd_shift_4x8,   fwd_shift_8x4,   fwd_shift_8x16,  fwd_shift_16x8,  fwd_shift_16x32,
    fwd_shift_32x16, fwd_shift_32x64, fwd_shift_64x32, fwd_shift_4x16,  fwd_shift_16x4,
    fwd_shift_8x32,  fwd_shift_32x8,  fwd_shift_16x64, fwd_shift_64x16,
};
54 
load_buffer_16x16_avx512(const int16_t * input,__m512i * out,int32_t stride,int32_t flipud,int32_t fliplr,const int8_t shift)55 static INLINE void load_buffer_16x16_avx512(const int16_t *input, __m512i *out, int32_t stride,
56                                             int32_t flipud, int32_t fliplr, const int8_t shift) {
57     __m256i temp[16];
58     uint8_t ushift = (uint8_t)shift;
59     if (flipud) {
60         /* load rows upside down (bottom to top) */
61         for (int32_t i = 0; i < 16; i++) {
62             int idx   = 15 - i;
63             temp[idx] = _mm256_loadu_si256((const __m256i *)(input + i * stride));
64             out[idx]  = _mm512_cvtepi16_epi32(temp[idx]);
65             out[idx]  = _mm512_slli_epi32(out[idx], ushift);
66         }
67     } else {
68         /* load rows normally */
69         for (int32_t i = 0; i < 16; i++) {
70             temp[i] = _mm256_loadu_si256((const __m256i *)(input + i * stride));
71             out[i]  = _mm512_cvtepi16_epi32(temp[i]);
72             out[i]  = _mm512_slli_epi32(out[i], ushift);
73         }
74     }
75 
76     if (fliplr) {
77         /*flip columns left to right*/
78         uint32_t idx[] = {15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
79         __m512i  index = _mm512_loadu_si512(idx);
80 
81         for (int32_t i = 0; i < 16; i++) { out[i] = _mm512_permutexvar_epi32(index, out[i]); }
82     }
83 }
84 
/*
 * Identity transform for 16*col_num vectors: each lane is scaled by
 * 2*new_sqrt2 in Q12 fixed point, with rounding. The cos-bit argument is
 * unused (kept for the shared fwd_transform_1d_avx512 signature).
 */
static void fidtx16x16_avx512(const __m512i *in, __m512i *out, const int8_t bit,
                              const int32_t col_num) {
    (void)bit;
    const uint8_t shift = 12; // new_sqrt2_bits = 12
    const __m512i scale = _mm512_set1_epi32(2 * 5793); // 2 * new_sqrt2
    const __m512i round = _mm512_set1_epi32(1 << (shift - 1));
    const int32_t total = 16 * col_num;

    for (int32_t idx = 0; idx < total; idx++) {
        const __m512i scaled = _mm512_add_epi32(_mm512_mullo_epi32(in[idx], scale), round);
        out[idx]             = _mm512_srai_epi32(scaled, shift);
    }
}
100 
col_txfm_16x16_rounding_avx512(__m512i * in,const int8_t shift)101 static INLINE void col_txfm_16x16_rounding_avx512(__m512i *in, const int8_t shift) {
102     uint8_t       ushift   = (uint8_t)shift;
103     const __m512i rounding = _mm512_set1_epi32(1 << (ushift - 1));
104     for (int32_t i = 0; i < 16; i++) {
105         in[i] = _mm512_add_epi32(in[i], rounding);
106         in[i] = _mm512_srai_epi32(in[i], ushift);
107     }
108 }
109 
write_buffer_16x16(const __m512i * res,int32_t * output)110 static INLINE void write_buffer_16x16(const __m512i *res, int32_t *output) {
111     int32_t fact = -1, index = -1;
112     for (int32_t i = 0; i < 8; i++) {
113         _mm512_storeu_si512((__m512i *)(output + (++fact) * 32), res[++index]);
114         _mm512_storeu_si512((__m512i *)(output + (fact)*32 + 16), res[++index]);
115     }
116 }
117 
half_btf_avx512(const __m512i * w0,const __m512i * n0,const __m512i * w1,const __m512i * n1,const __m512i * rounding,uint8_t bit)118 static INLINE __m512i half_btf_avx512(const __m512i *w0, const __m512i *n0, const __m512i *w1,
119                                       const __m512i *n1, const __m512i *rounding, uint8_t bit) {
120     __m512i x, y;
121 
122     x = _mm512_mullo_epi32(*w0, *n0);
123     y = _mm512_mullo_epi32(*w1, *n1);
124     x = _mm512_add_epi32(x, y);
125     x = _mm512_add_epi32(x, *rounding);
126     x = _mm512_srai_epi32(x, bit);
127     return x;
128 }
129 
/*
 * 16-point forward ADST over col_num groups of 16 lanes each.
 * `in`/`out` hold 16*col_num __m512i vectors; element r of column group c
 * lives at index r*col_num + c. `bit` selects the cospi table and the
 * rounding shift applied after every multiply stage. The u[]/v[] arrays are
 * updated in place stage by stage; statement order within a stage is
 * significant (later lines read values overwritten by earlier ones).
 */
static void fadst16x16_avx512(const __m512i *in, __m512i *out, const int8_t bit,
                              const int32_t col_num) {
    const int32_t *cospi    = cospi_arr(bit);
    const __m512i  cospi32  = _mm512_set1_epi32(cospi[32]);
    const __m512i  cospi48  = _mm512_set1_epi32(cospi[48]);
    const __m512i  cospi16  = _mm512_set1_epi32(cospi[16]);
    const __m512i  cospim16 = _mm512_set1_epi32(-cospi[16]);
    const __m512i  cospim48 = _mm512_set1_epi32(-cospi[48]);
    const __m512i  cospi8   = _mm512_set1_epi32(cospi[8]);
    const __m512i  cospi56  = _mm512_set1_epi32(cospi[56]);
    const __m512i  cospim56 = _mm512_set1_epi32(-cospi[56]);
    const __m512i  cospim8  = _mm512_set1_epi32(-cospi[8]);
    const __m512i  cospi24  = _mm512_set1_epi32(cospi[24]);
    const __m512i  cospim24 = _mm512_set1_epi32(-cospi[24]);
    const __m512i  cospim40 = _mm512_set1_epi32(-cospi[40]);
    const __m512i  cospi40  = _mm512_set1_epi32(cospi[40]);
    const __m512i  cospi2   = _mm512_set1_epi32(cospi[2]);
    const __m512i  cospi62  = _mm512_set1_epi32(cospi[62]);
    const __m512i  cospim2  = _mm512_set1_epi32(-cospi[2]);
    const __m512i  cospi10  = _mm512_set1_epi32(cospi[10]);
    const __m512i  cospi54  = _mm512_set1_epi32(cospi[54]);
    const __m512i  cospim10 = _mm512_set1_epi32(-cospi[10]);
    const __m512i  cospi18  = _mm512_set1_epi32(cospi[18]);
    const __m512i  cospi46  = _mm512_set1_epi32(cospi[46]);
    const __m512i  cospim18 = _mm512_set1_epi32(-cospi[18]);
    const __m512i  cospi26  = _mm512_set1_epi32(cospi[26]);
    const __m512i  cospi38  = _mm512_set1_epi32(cospi[38]);
    const __m512i  cospim26 = _mm512_set1_epi32(-cospi[26]);
    const __m512i  cospi34  = _mm512_set1_epi32(cospi[34]);
    const __m512i  cospi30  = _mm512_set1_epi32(cospi[30]);
    const __m512i  cospim34 = _mm512_set1_epi32(-cospi[34]);
    const __m512i  cospi42  = _mm512_set1_epi32(cospi[42]);
    const __m512i  cospi22  = _mm512_set1_epi32(cospi[22]);
    const __m512i  cospim42 = _mm512_set1_epi32(-cospi[42]);
    const __m512i  cospi50  = _mm512_set1_epi32(cospi[50]);
    const __m512i  cospi14  = _mm512_set1_epi32(cospi[14]);
    const __m512i  cospim50 = _mm512_set1_epi32(-cospi[50]);
    const __m512i  cospi58  = _mm512_set1_epi32(cospi[58]);
    const __m512i  cospi6   = _mm512_set1_epi32(cospi[6]);
    const __m512i  cospim58 = _mm512_set1_epi32(-cospi[58]);
    const __m512i  rnding   = _mm512_set1_epi32(1 << (bit - 1));
    const __m512i  zeroes   = _mm512_setzero_si512();

    __m512i u[16], v[16], x, y;
    int32_t col;

    for (col = 0; col < col_num; ++col) {
        // stage 1: negate the inputs whose signs the ADST flow graph flips
        u[1]  = _mm512_sub_epi32(zeroes, in[15 * col_num + col]);
        u[2]  = _mm512_sub_epi32(zeroes, in[7 * col_num + col]);
        u[4]  = _mm512_sub_epi32(zeroes, in[3 * col_num + col]);
        u[7]  = _mm512_sub_epi32(zeroes, in[11 * col_num + col]);
        u[8]  = _mm512_sub_epi32(zeroes, in[1 * col_num + col]);
        u[11] = _mm512_sub_epi32(zeroes, in[9 * col_num + col]);
        u[13] = _mm512_sub_epi32(zeroes, in[13 * col_num + col]);
        u[14] = _mm512_sub_epi32(zeroes, in[5 * col_num + col]);

        // stage 2: four cospi32 rotations (sum/difference, round, shift)
        x    = _mm512_mullo_epi32(u[2], cospi32);
        y    = _mm512_mullo_epi32(in[8 * col_num + col], cospi32);
        v[2] = _mm512_add_epi32(x, y);
        v[2] = _mm512_add_epi32(v[2], rnding);
        v[2] = _mm512_srai_epi32(v[2], (uint8_t)bit);

        v[3] = _mm512_sub_epi32(x, y);
        v[3] = _mm512_add_epi32(v[3], rnding);
        v[3] = _mm512_srai_epi32(v[3], (uint8_t)bit);

        x    = _mm512_mullo_epi32(in[4 * col_num + col], cospi32);
        y    = _mm512_mullo_epi32(u[7], cospi32);
        v[6] = _mm512_add_epi32(x, y);
        v[6] = _mm512_add_epi32(v[6], rnding);
        v[6] = _mm512_srai_epi32(v[6], (uint8_t)bit);

        v[7] = _mm512_sub_epi32(x, y);
        v[7] = _mm512_add_epi32(v[7], rnding);
        v[7] = _mm512_srai_epi32(v[7], (uint8_t)bit);

        x     = _mm512_mullo_epi32(in[6 * col_num + col], cospi32);
        y     = _mm512_mullo_epi32(u[11], cospi32);
        v[10] = _mm512_add_epi32(x, y);
        v[10] = _mm512_add_epi32(v[10], rnding);
        v[10] = _mm512_srai_epi32(v[10], (uint8_t)bit);

        v[11] = _mm512_sub_epi32(x, y);
        v[11] = _mm512_add_epi32(v[11], rnding);
        v[11] = _mm512_srai_epi32(v[11], (uint8_t)bit);

        x     = _mm512_mullo_epi32(u[14], cospi32);
        y     = _mm512_mullo_epi32(in[10 * col_num + col], cospi32);
        v[14] = _mm512_add_epi32(x, y);
        v[14] = _mm512_add_epi32(v[14], rnding);
        v[14] = _mm512_srai_epi32(v[14], (uint8_t)bit);

        v[15] = _mm512_sub_epi32(x, y);
        v[15] = _mm512_add_epi32(v[15], rnding);
        v[15] = _mm512_srai_epi32(v[15], (uint8_t)bit);

        // stage 3: add/sub butterflies; the sub-first/add-second ordering
        // lets u[] entries be overwritten in place safely
        u[0]  = _mm512_add_epi32(in[0 * col_num + col], v[2]);
        u[3]  = _mm512_sub_epi32(u[1], v[3]);
        u[1]  = _mm512_add_epi32(u[1], v[3]);
        u[2]  = _mm512_sub_epi32(in[0 * col_num + col], v[2]);
        u[6]  = _mm512_sub_epi32(u[4], v[6]);
        u[4]  = _mm512_add_epi32(u[4], v[6]);
        u[5]  = _mm512_add_epi32(in[12 * col_num + col], v[7]);
        u[7]  = _mm512_sub_epi32(in[12 * col_num + col], v[7]);
        u[10] = _mm512_sub_epi32(u[8], v[10]);
        u[8]  = _mm512_add_epi32(u[8], v[10]);
        u[9]  = _mm512_add_epi32(in[14 * col_num + col], v[11]);
        u[11] = _mm512_sub_epi32(in[14 * col_num + col], v[11]);
        u[12] = _mm512_add_epi32(in[2 * col_num + col], v[14]);
        u[15] = _mm512_sub_epi32(u[13], v[15]);
        u[13] = _mm512_add_epi32(u[13], v[15]);
        u[14] = _mm512_sub_epi32(in[2 * col_num + col], v[14]);

        // stage 4: cospi16/cospi48 rotations
        v[4]  = half_btf_avx512(&cospi16, &u[4], &cospi48, &u[5], &rnding, bit);
        v[5]  = half_btf_avx512(&cospi48, &u[4], &cospim16, &u[5], &rnding, bit);
        v[6]  = half_btf_avx512(&cospim48, &u[6], &cospi16, &u[7], &rnding, bit);
        v[7]  = half_btf_avx512(&cospi16, &u[6], &cospi48, &u[7], &rnding, bit);
        v[12] = half_btf_avx512(&cospi16, &u[12], &cospi48, &u[13], &rnding, bit);
        v[13] = half_btf_avx512(&cospi48, &u[12], &cospim16, &u[13], &rnding, bit);
        v[14] = half_btf_avx512(&cospim48, &u[14], &cospi16, &u[15], &rnding, bit);
        v[15] = half_btf_avx512(&cospi16, &u[14], &cospi48, &u[15], &rnding, bit);

        // stage 5: add/sub butterflies
        u[4]  = _mm512_sub_epi32(u[0], v[4]);
        u[0]  = _mm512_add_epi32(u[0], v[4]);
        u[5]  = _mm512_sub_epi32(u[1], v[5]);
        u[1]  = _mm512_add_epi32(u[1], v[5]);
        u[6]  = _mm512_sub_epi32(u[2], v[6]);
        u[2]  = _mm512_add_epi32(u[2], v[6]);
        u[7]  = _mm512_sub_epi32(u[3], v[7]);
        u[3]  = _mm512_add_epi32(u[3], v[7]);
        u[12] = _mm512_sub_epi32(u[8], v[12]);
        u[8]  = _mm512_add_epi32(u[8], v[12]);
        u[13] = _mm512_sub_epi32(u[9], v[13]);
        u[9]  = _mm512_add_epi32(u[9], v[13]);
        u[14] = _mm512_sub_epi32(u[10], v[14]);
        u[10] = _mm512_add_epi32(u[10], v[14]);
        u[15] = _mm512_sub_epi32(u[11], v[15]);
        u[11] = _mm512_add_epi32(u[11], v[15]);

        // stage 6: cospi8/24/40/56 rotations
        v[8]  = half_btf_avx512(&cospi8, &u[8], &cospi56, &u[9], &rnding, bit);
        v[9]  = half_btf_avx512(&cospi56, &u[8], &cospim8, &u[9], &rnding, bit);
        v[10] = half_btf_avx512(&cospi40, &u[10], &cospi24, &u[11], &rnding, bit);
        v[11] = half_btf_avx512(&cospi24, &u[10], &cospim40, &u[11], &rnding, bit);
        v[12] = half_btf_avx512(&cospim56, &u[12], &cospi8, &u[13], &rnding, bit);
        v[13] = half_btf_avx512(&cospi8, &u[12], &cospi56, &u[13], &rnding, bit);
        v[14] = half_btf_avx512(&cospim24, &u[14], &cospi40, &u[15], &rnding, bit);
        v[15] = half_btf_avx512(&cospi40, &u[14], &cospi24, &u[15], &rnding, bit);

        // stage 7: add/sub butterflies
        u[8]  = _mm512_sub_epi32(u[0], v[8]);
        u[0]  = _mm512_add_epi32(u[0], v[8]);
        u[9]  = _mm512_sub_epi32(u[1], v[9]);
        u[1]  = _mm512_add_epi32(u[1], v[9]);
        u[10] = _mm512_sub_epi32(u[2], v[10]);
        u[2]  = _mm512_add_epi32(u[2], v[10]);
        u[11] = _mm512_sub_epi32(u[3], v[11]);
        u[3]  = _mm512_add_epi32(u[3], v[11]);
        u[12] = _mm512_sub_epi32(u[4], v[12]);
        u[4]  = _mm512_add_epi32(u[4], v[12]);
        u[13] = _mm512_sub_epi32(u[5], v[13]);
        u[5]  = _mm512_add_epi32(u[5], v[13]);
        u[14] = _mm512_sub_epi32(u[6], v[14]);
        u[6]  = _mm512_add_epi32(u[6], v[14]);
        u[15] = _mm512_sub_epi32(u[7], v[15]);
        u[7]  = _mm512_add_epi32(u[7], v[15]);

        // stage 8: final odd-cospi rotations, written straight to out[] in
        // the ADST output (bit-reversed-like) ordering
        out[15 * col_num + col] = half_btf_avx512(&cospi2, &u[0], &cospi62, &u[1], &rnding, bit);
        out[0 * col_num + col]  = half_btf_avx512(&cospi62, &u[0], &cospim2, &u[1], &rnding, bit);
        out[13 * col_num + col] = half_btf_avx512(&cospi10, &u[2], &cospi54, &u[3], &rnding, bit);
        out[2 * col_num + col]  = half_btf_avx512(&cospi54, &u[2], &cospim10, &u[3], &rnding, bit);
        out[11 * col_num + col] = half_btf_avx512(&cospi18, &u[4], &cospi46, &u[5], &rnding, bit);
        out[4 * col_num + col]  = half_btf_avx512(&cospi46, &u[4], &cospim18, &u[5], &rnding, bit);
        out[9 * col_num + col]  = half_btf_avx512(&cospi26, &u[6], &cospi38, &u[7], &rnding, bit);
        out[6 * col_num + col]  = half_btf_avx512(&cospi38, &u[6], &cospim26, &u[7], &rnding, bit);
        out[7 * col_num + col]  = half_btf_avx512(&cospi34, &u[8], &cospi30, &u[9], &rnding, bit);
        out[8 * col_num + col]  = half_btf_avx512(&cospi30, &u[8], &cospim34, &u[9], &rnding, bit);
        out[5 * col_num + col]  = half_btf_avx512(&cospi42, &u[10], &cospi22, &u[11], &rnding, bit);
        out[10 * col_num + col] = half_btf_avx512(&cospi22, &u[10], &cospim42, &u[11], &rnding, bit);
        out[3 * col_num + col]  = half_btf_avx512(&cospi50, &u[12], &cospi14, &u[13], &rnding, bit);
        out[12 * col_num + col] = half_btf_avx512(&cospi14, &u[12], &cospim50, &u[13], &rnding, bit);
        out[1 * col_num + col]  = half_btf_avx512(&cospi58, &u[14], &cospi6, &u[15], &rnding, bit);
        out[14 * col_num + col] = half_btf_avx512(&cospi6, &u[14], &cospim58, &u[15], &rnding, bit);
    }
}
/*
 * 16-point forward DCT over col_num groups of 16 lanes each.
 * `in`/`out` hold 16*col_num __m512i vectors; element r of column group c
 * lives at index r*col_num + c. `bit` selects the cospi table and the
 * rounding shift applied after every multiply stage. u[]/v[] are updated in
 * place; statement order within a stage is significant.
 */
static void fdct16x16_avx512(const __m512i *in, __m512i *out, const int8_t bit,
                             const int32_t col_num) {
    const int32_t *cospi    = cospi_arr(bit);
    const __m512i  cospi32  = _mm512_set1_epi32(cospi[32]);
    const __m512i  cospim32 = _mm512_set1_epi32(-cospi[32]);
    const __m512i  cospi48  = _mm512_set1_epi32(cospi[48]);
    const __m512i  cospi16  = _mm512_set1_epi32(cospi[16]);
    const __m512i  cospim48 = _mm512_set1_epi32(-cospi[48]);
    const __m512i  cospim16 = _mm512_set1_epi32(-cospi[16]);
    const __m512i  cospi56  = _mm512_set1_epi32(cospi[56]);
    const __m512i  cospi8   = _mm512_set1_epi32(cospi[8]);
    const __m512i  cospi24  = _mm512_set1_epi32(cospi[24]);
    const __m512i  cospi40  = _mm512_set1_epi32(cospi[40]);
    const __m512i  cospi60  = _mm512_set1_epi32(cospi[60]);
    const __m512i  cospi4   = _mm512_set1_epi32(cospi[4]);
    const __m512i  cospi28  = _mm512_set1_epi32(cospi[28]);
    const __m512i  cospi36  = _mm512_set1_epi32(cospi[36]);
    const __m512i  cospi44  = _mm512_set1_epi32(cospi[44]);
    const __m512i  cospi20  = _mm512_set1_epi32(cospi[20]);
    const __m512i  cospi12  = _mm512_set1_epi32(cospi[12]);
    const __m512i  cospi52  = _mm512_set1_epi32(cospi[52]);
    const __m512i  rnding   = _mm512_set1_epi32(1 << (bit - 1));
    __m512i        u[16], v[16], x;
    int32_t        col;

    for (col = 0; col < col_num; ++col) {
        // stage 0
        // stage 1: fold the 16 inputs into 8 sums (u[0..7]) and 8 diffs (u[8..15])
        u[0]  = _mm512_add_epi32(in[0 * col_num + col], in[15 * col_num + col]);
        u[15] = _mm512_sub_epi32(in[0 * col_num + col], in[15 * col_num + col]);
        u[1]  = _mm512_add_epi32(in[1 * col_num + col], in[14 * col_num + col]);
        u[14] = _mm512_sub_epi32(in[1 * col_num + col], in[14 * col_num + col]);
        u[2]  = _mm512_add_epi32(in[2 * col_num + col], in[13 * col_num + col]);
        u[13] = _mm512_sub_epi32(in[2 * col_num + col], in[13 * col_num + col]);
        u[3]  = _mm512_add_epi32(in[3 * col_num + col], in[12 * col_num + col]);
        u[12] = _mm512_sub_epi32(in[3 * col_num + col], in[12 * col_num + col]);
        u[4]  = _mm512_add_epi32(in[4 * col_num + col], in[11 * col_num + col]);
        u[11] = _mm512_sub_epi32(in[4 * col_num + col], in[11 * col_num + col]);
        u[5]  = _mm512_add_epi32(in[5 * col_num + col], in[10 * col_num + col]);
        u[10] = _mm512_sub_epi32(in[5 * col_num + col], in[10 * col_num + col]);
        u[6]  = _mm512_add_epi32(in[6 * col_num + col], in[9 * col_num + col]);
        u[9]  = _mm512_sub_epi32(in[6 * col_num + col], in[9 * col_num + col]);
        u[7]  = _mm512_add_epi32(in[7 * col_num + col], in[8 * col_num + col]);
        u[8]  = _mm512_sub_epi32(in[7 * col_num + col], in[8 * col_num + col]);

        // stage 2: butterfly the even half; cospi32 rotations on u[10..13]
        v[0] = _mm512_add_epi32(u[0], u[7]);
        v[7] = _mm512_sub_epi32(u[0], u[7]);
        v[1] = _mm512_add_epi32(u[1], u[6]);
        v[6] = _mm512_sub_epi32(u[1], u[6]);
        v[2] = _mm512_add_epi32(u[2], u[5]);
        v[5] = _mm512_sub_epi32(u[2], u[5]);
        v[3] = _mm512_add_epi32(u[3], u[4]);
        v[4] = _mm512_sub_epi32(u[3], u[4]);

        v[10] = _mm512_mullo_epi32(u[10], cospim32);
        x     = _mm512_mullo_epi32(u[13], cospi32);
        v[10] = _mm512_add_epi32(v[10], x);
        v[10] = _mm512_add_epi32(v[10], rnding);
        v[10] = _mm512_srai_epi32(v[10], (uint8_t)bit);

        v[13] = _mm512_mullo_epi32(u[10], cospi32);
        x     = _mm512_mullo_epi32(u[13], cospim32);
        v[13] = _mm512_sub_epi32(v[13], x);
        v[13] = _mm512_add_epi32(v[13], rnding);
        v[13] = _mm512_srai_epi32(v[13], (uint8_t)bit);

        v[11] = _mm512_mullo_epi32(u[11], cospim32);
        x     = _mm512_mullo_epi32(u[12], cospi32);
        v[11] = _mm512_add_epi32(v[11], x);
        v[11] = _mm512_add_epi32(v[11], rnding);
        v[11] = _mm512_srai_epi32(v[11], (uint8_t)bit);

        v[12] = _mm512_mullo_epi32(u[11], cospi32);
        x     = _mm512_mullo_epi32(u[12], cospim32);
        v[12] = _mm512_sub_epi32(v[12], x);
        v[12] = _mm512_add_epi32(v[12], rnding);
        v[12] = _mm512_srai_epi32(v[12], (uint8_t)bit);

        // stage 3: 4-point fold of the even half; odd-half butterflies
        u[0] = _mm512_add_epi32(v[0], v[3]);
        u[3] = _mm512_sub_epi32(v[0], v[3]);
        u[1] = _mm512_add_epi32(v[1], v[2]);
        u[2] = _mm512_sub_epi32(v[1], v[2]);

        u[5] = _mm512_mullo_epi32(v[5], cospim32);
        x    = _mm512_mullo_epi32(v[6], cospi32);
        u[5] = _mm512_add_epi32(u[5], x);
        u[5] = _mm512_add_epi32(u[5], rnding);
        u[5] = _mm512_srai_epi32(u[5], (uint8_t)bit);

        u[6] = _mm512_mullo_epi32(v[5], cospi32);
        x    = _mm512_mullo_epi32(v[6], cospim32);
        u[6] = _mm512_sub_epi32(u[6], x);
        u[6] = _mm512_add_epi32(u[6], rnding);
        u[6] = _mm512_srai_epi32(u[6], (uint8_t)bit);

        u[11] = _mm512_sub_epi32(u[8], v[11]);
        u[8]  = _mm512_add_epi32(u[8], v[11]);
        u[10] = _mm512_sub_epi32(u[9], v[10]);
        u[9]  = _mm512_add_epi32(u[9], v[10]);
        u[12] = _mm512_sub_epi32(u[15], v[12]);
        u[15] = _mm512_add_epi32(u[15], v[12]);
        u[13] = _mm512_sub_epi32(u[14], v[13]);
        u[14] = _mm512_add_epi32(u[14], v[13]);

        // stage 4: DC/Nyquist and cospi16/48 outputs; odd-half rotations
        u[0] = _mm512_mullo_epi32(u[0], cospi32);
        u[1] = _mm512_mullo_epi32(u[1], cospi32);
        v[0] = _mm512_add_epi32(u[0], u[1]);
        v[0] = _mm512_add_epi32(v[0], rnding);
        out[0 * col_num + col] = _mm512_srai_epi32(v[0], (uint8_t)bit);

        v[1] = _mm512_sub_epi32(u[0], u[1]);
        v[1] = _mm512_add_epi32(v[1], rnding);
        out[8 * col_num + col] = _mm512_srai_epi32(v[1], (uint8_t)bit);

        v[2] = _mm512_mullo_epi32(u[2], cospi48);
        x    = _mm512_mullo_epi32(u[3], cospi16);
        v[2] = _mm512_add_epi32(v[2], x);
        v[2] = _mm512_add_epi32(v[2], rnding);
        out[4 * col_num + col] = _mm512_srai_epi32(v[2], (uint8_t)bit);

        v[3] = _mm512_mullo_epi32(u[2], cospi16);
        x    = _mm512_mullo_epi32(u[3], cospi48);
        v[3] = _mm512_sub_epi32(x, v[3]);
        v[3] = _mm512_add_epi32(v[3], rnding);
        out[12 * col_num + col] = _mm512_srai_epi32(v[3], (uint8_t)bit);

        v[5] = _mm512_sub_epi32(v[4], u[5]);
        v[4] = _mm512_add_epi32(v[4], u[5]);
        v[6] = _mm512_sub_epi32(v[7], u[6]);
        v[7] = _mm512_add_epi32(v[7], u[6]);

        v[9] = _mm512_mullo_epi32(u[9], cospim16);
        x    = _mm512_mullo_epi32(u[14], cospi48);
        v[9] = _mm512_add_epi32(v[9], x);
        v[9] = _mm512_add_epi32(v[9], rnding);
        v[9] = _mm512_srai_epi32(v[9], (uint8_t)bit);

        v[14] = _mm512_mullo_epi32(u[9], cospi48);
        x     = _mm512_mullo_epi32(u[14], cospim16);
        v[14] = _mm512_sub_epi32(v[14], x);
        v[14] = _mm512_add_epi32(v[14], rnding);
        v[14] = _mm512_srai_epi32(v[14], (uint8_t)bit);

        v[10] = _mm512_mullo_epi32(u[10], cospim48);
        x     = _mm512_mullo_epi32(u[13], cospim16);
        v[10] = _mm512_add_epi32(v[10], x);
        v[10] = _mm512_add_epi32(v[10], rnding);
        v[10] = _mm512_srai_epi32(v[10], (uint8_t)bit);

        v[13] = _mm512_mullo_epi32(u[10], cospim16);
        x     = _mm512_mullo_epi32(u[13], cospim48);
        v[13] = _mm512_sub_epi32(v[13], x);
        v[13] = _mm512_add_epi32(v[13], rnding);
        v[13] = _mm512_srai_epi32(v[13], (uint8_t)bit);

        // stage 5: cospi8/24/40/56 outputs for rows 2/6/10/14; odd butterflies
        u[4] = _mm512_mullo_epi32(v[4], cospi56);
        x    = _mm512_mullo_epi32(v[7], cospi8);
        u[4] = _mm512_add_epi32(u[4], x);
        u[4] = _mm512_add_epi32(u[4], rnding);
        out[2 * col_num + col] = _mm512_srai_epi32(u[4], (uint8_t)bit);

        u[7] = _mm512_mullo_epi32(v[4], cospi8);
        x    = _mm512_mullo_epi32(v[7], cospi56);
        u[7] = _mm512_sub_epi32(x, u[7]);
        u[7] = _mm512_add_epi32(u[7], rnding);
        out[14 * col_num + col] = _mm512_srai_epi32(u[7], (uint8_t)bit);

        u[5] = _mm512_mullo_epi32(v[5], cospi24);
        x    = _mm512_mullo_epi32(v[6], cospi40);
        u[5] = _mm512_add_epi32(u[5], x);
        u[5] = _mm512_add_epi32(u[5], rnding);
        out[10 * col_num + col] = _mm512_srai_epi32(u[5], (uint8_t)bit);

        u[6] = _mm512_mullo_epi32(v[5], cospi40);
        x    = _mm512_mullo_epi32(v[6], cospi24);
        u[6] = _mm512_sub_epi32(x, u[6]);
        u[6] = _mm512_add_epi32(u[6], rnding);
        out[6 * col_num + col] = _mm512_srai_epi32(u[6], (uint8_t)bit);

        u[9] = _mm512_sub_epi32(u[8], v[9]);
        u[8]  = _mm512_add_epi32(u[8], v[9]);
        u[10] = _mm512_sub_epi32(u[11], v[10]);
        u[11] = _mm512_add_epi32(u[11], v[10]);
        u[13] = _mm512_sub_epi32(u[12], v[13]);
        u[12] = _mm512_add_epi32(u[12], v[13]);
        u[14] = _mm512_sub_epi32(u[15], v[14]);
        u[15] = _mm512_add_epi32(u[15], v[14]);

        // stage 6: final odd-cospi rotations producing the odd output rows
        v[8] = _mm512_mullo_epi32(u[8], cospi60);
        x    = _mm512_mullo_epi32(u[15], cospi4);
        v[8] = _mm512_add_epi32(v[8], x);
        v[8] = _mm512_add_epi32(v[8], rnding);
        out[1 * col_num + col] = _mm512_srai_epi32(v[8], (uint8_t)bit);

        v[15] = _mm512_mullo_epi32(u[8], cospi4);
        x     = _mm512_mullo_epi32(u[15], cospi60);
        v[15] = _mm512_sub_epi32(x, v[15]);
        v[15] = _mm512_add_epi32(v[15], rnding);
        out[15 * col_num + col] = _mm512_srai_epi32(v[15], (uint8_t)bit);

        v[9] = _mm512_mullo_epi32(u[9], cospi28);
        x    = _mm512_mullo_epi32(u[14], cospi36);
        v[9] = _mm512_add_epi32(v[9], x);
        v[9] = _mm512_add_epi32(v[9], rnding);
        out[9 * col_num + col] = _mm512_srai_epi32(v[9], (uint8_t)bit);

        v[14] = _mm512_mullo_epi32(u[9], cospi36);
        x     = _mm512_mullo_epi32(u[14], cospi28);
        v[14] = _mm512_sub_epi32(x, v[14]);
        v[14] = _mm512_add_epi32(v[14], rnding);
        out[7 * col_num + col] = _mm512_srai_epi32(v[14], (uint8_t)bit);

        v[10] = _mm512_mullo_epi32(u[10], cospi44);
        x     = _mm512_mullo_epi32(u[13], cospi20);
        v[10] = _mm512_add_epi32(v[10], x);
        v[10] = _mm512_add_epi32(v[10], rnding);
        out[5 * col_num + col] = _mm512_srai_epi32(v[10], (uint8_t)bit);

        v[13] = _mm512_mullo_epi32(u[10], cospi20);
        x     = _mm512_mullo_epi32(u[13], cospi44);
        v[13] = _mm512_sub_epi32(x, v[13]);
        v[13] = _mm512_add_epi32(v[13], rnding);
        out[11 * col_num + col] = _mm512_srai_epi32(v[13], (uint8_t)bit);

        v[11] = _mm512_mullo_epi32(u[11], cospi12);
        x     = _mm512_mullo_epi32(u[12], cospi52);
        v[11] = _mm512_add_epi32(v[11], x);
        v[11] = _mm512_add_epi32(v[11], rnding);
        out[13 * col_num + col] = _mm512_srai_epi32(v[11], (uint8_t)bit);

        v[12] = _mm512_mullo_epi32(u[11], cospi52);
        x     = _mm512_mullo_epi32(u[12], cospi12);
        v[12] = _mm512_sub_epi32(x, v[12]);
        v[12] = _mm512_add_epi32(v[12], rnding);
        out[3 * col_num + col] = _mm512_srai_epi32(v[12], (uint8_t)bit);
    }
}
563 
av1_fwd_txfm2d_16x16_avx512(int16_t * input,int32_t * coeff,uint32_t stride,TxType tx_type,uint8_t bd)564 void av1_fwd_txfm2d_16x16_avx512(int16_t *input, int32_t *coeff, uint32_t stride, TxType tx_type,
565                                  uint8_t bd) {
566     __m512i       in[16], out[16];
567     const int8_t *shift   = fwd_txfm_shift_ls[TX_16X16];
568     const int32_t txw_idx = get_txw_idx(TX_16X16);
569     const int32_t txh_idx = get_txh_idx(TX_16X16);
570     const int32_t col_num = 1;
571     switch (tx_type) {
572     case IDTX:
573         load_buffer_16x16_avx512(input, in, stride, 0, 0, shift[0]);
574         fidtx16x16_avx512(in, out, fwd_cos_bit_col[txw_idx][txh_idx], col_num);
575         col_txfm_16x16_rounding_avx512(out, -shift[1]);
576         fidtx16x16_avx512(out, out, fwd_cos_bit_row[txw_idx][txh_idx], col_num);
577         write_buffer_16x16(out, coeff);
578         break;
579     case DCT_DCT:
580         load_buffer_16x16_avx512(input, in, stride, 0, 0, shift[0]);
581         fdct16x16_avx512(in, out, fwd_cos_bit_col[txw_idx][txh_idx], col_num);
582         col_txfm_16x16_rounding_avx512(out, -shift[1]);
583         transpose_16x16_avx512(out, in);
584         fdct16x16_avx512(in, out, fwd_cos_bit_row[txw_idx][txh_idx], col_num);
585         transpose_16x16_avx512(out, in);
586         write_buffer_16x16(in, coeff);
587         break;
588     case ADST_DCT:
589         load_buffer_16x16_avx512(input, in, stride, 0, 0, shift[0]);
590         fadst16x16_avx512(in, out, fwd_cos_bit_col[txw_idx][txh_idx], col_num);
591         col_txfm_16x16_rounding_avx512(out, -shift[1]);
592         transpose_16x16_avx512(out, in);
593         fdct16x16_avx512(in, out, fwd_cos_bit_row[txw_idx][txh_idx], col_num);
594         transpose_16x16_avx512(out, in);
595         write_buffer_16x16(in, coeff);
596         break;
597     case DCT_ADST:
598         load_buffer_16x16_avx512(input, in, stride, 0, 0, shift[0]);
599         fdct16x16_avx512(in, out, fwd_cos_bit_col[txw_idx][txh_idx], col_num);
600         col_txfm_16x16_rounding_avx512(out, -shift[1]);
601         transpose_16x16_avx512(out, in);
602         fadst16x16_avx512(in, out, fwd_cos_bit_row[txw_idx][txh_idx], col_num);
603         transpose_16x16_avx512(out, in);
604         write_buffer_16x16(in, coeff);
605         break;
606     case ADST_ADST:
607         load_buffer_16x16_avx512(input, in, stride, 0, 0, shift[0]);
608         fadst16x16_avx512(in, out, fwd_cos_bit_col[txw_idx][txh_idx], col_num);
609         col_txfm_16x16_rounding_avx512(out, -shift[1]);
610         transpose_16x16_avx512(out, in);
611         fadst16x16_avx512(in, out, fwd_cos_bit_row[txw_idx][txh_idx], col_num);
612         transpose_16x16_avx512(out, in);
613         write_buffer_16x16(in, coeff);
614         break;
615     case DCT_FLIPADST:
616         load_buffer_16x16_avx512(input, in, stride, 0, 1, shift[0]);
617         fdct16x16_avx512(in, out, fwd_cos_bit_col[txw_idx][txh_idx], col_num);
618         col_txfm_16x16_rounding_avx512(out, -shift[1]);
619         transpose_16x16_avx512(out, in);
620         fadst16x16_avx512(in, out, fwd_cos_bit_row[txw_idx][txh_idx], col_num);
621         transpose_16x16_avx512(out, in);
622         write_buffer_16x16(in, coeff);
623         break;
624     case FLIPADST_DCT:
625         load_buffer_16x16_avx512(input, in, stride, 1, 0, shift[0]);
626         fadst16x16_avx512(in, out, fwd_cos_bit_col[txw_idx][txh_idx], col_num);
627         col_txfm_16x16_rounding_avx512(out, -shift[1]);
628         transpose_16x16_avx512(out, in);
629         fdct16x16_avx512(in, out, fwd_cos_bit_row[txw_idx][txh_idx], col_num);
630         transpose_16x16_avx512(out, in);
631         write_buffer_16x16(in, coeff);
632         break;
633     case FLIPADST_FLIPADST:
634         load_buffer_16x16_avx512(input, in, stride, 1, 1, shift[0]);
635         fadst16x16_avx512(in, out, fwd_cos_bit_col[txw_idx][txh_idx], col_num);
636         col_txfm_16x16_rounding_avx512(out, -shift[1]);
637         transpose_16x16_avx512(out, in);
638         fadst16x16_avx512(in, out, fwd_cos_bit_row[txw_idx][txh_idx], col_num);
639         transpose_16x16_avx512(out, in);
640         write_buffer_16x16(in, coeff);
641         break;
642     case ADST_FLIPADST:
643         load_buffer_16x16_avx512(input, in, stride, 0, 1, shift[0]);
644         fadst16x16_avx512(in, out, fwd_cos_bit_col[txw_idx][txh_idx], col_num);
645         col_txfm_16x16_rounding_avx512(out, -shift[1]);
646         transpose_16x16_avx512(out, in);
647         fadst16x16_avx512(in, out, fwd_cos_bit_row[txw_idx][txh_idx], col_num);
648         transpose_16x16_avx512(out, in);
649         write_buffer_16x16(in, coeff);
650         break;
651     case FLIPADST_ADST:
652         load_buffer_16x16_avx512(input, in, stride, 1, 0, shift[0]);
653         fadst16x16_avx512(in, out, fwd_cos_bit_col[txw_idx][txh_idx], col_num);
654         col_txfm_16x16_rounding_avx512(out, -shift[1]);
655         transpose_16x16_avx512(out, in);
656         fadst16x16_avx512(in, out, fwd_cos_bit_row[txw_idx][txh_idx], col_num);
657         transpose_16x16_avx512(out, in);
658         write_buffer_16x16(in, coeff);
659         break;
660     case V_DCT:
661         load_buffer_16x16_avx512(input, in, stride, 0, 0, shift[0]);
662         fdct16x16_avx512(in, out, fwd_cos_bit_col[txw_idx][txh_idx], col_num);
663         col_txfm_16x16_rounding_avx512(out, -shift[1]);
664         fidtx16x16_avx512(out, out, fwd_cos_bit_row[txw_idx][txh_idx], col_num);
665         write_buffer_16x16(out, coeff);
666         break;
667     case H_DCT:
668         load_buffer_16x16_avx512(input, in, stride, 0, 0, shift[0]);
669         fidtx16x16_avx512(in, in, fwd_cos_bit_col[txw_idx][txh_idx], col_num);
670         col_txfm_16x16_rounding_avx512(in, -shift[1]);
671         transpose_16x16_avx512(in, out);
672         fdct16x16_avx512(out, in, fwd_cos_bit_row[txw_idx][txh_idx], col_num);
673         transpose_16x16_avx512(in, out);
674         write_buffer_16x16(out, coeff);
675         break;
676     case V_ADST:
677         load_buffer_16x16_avx512(input, in, stride, 0, 0, shift[0]);
678         fadst16x16_avx512(in, out, fwd_cos_bit_col[txw_idx][txh_idx], col_num);
679         col_txfm_16x16_rounding_avx512(out, -shift[1]);
680         fidtx16x16_avx512(out, out, fwd_cos_bit_row[txw_idx][txh_idx], col_num);
681         write_buffer_16x16(out, coeff);
682         break;
683     case H_ADST:
684         load_buffer_16x16_avx512(input, in, stride, 0, 0, shift[0]);
685         fidtx16x16_avx512(in, in, fwd_cos_bit_col[txw_idx][txh_idx], col_num);
686         col_txfm_16x16_rounding_avx512(in, -shift[1]);
687         transpose_16x16_avx512(in, out);
688         fadst16x16_avx512(out, in, fwd_cos_bit_row[txw_idx][txh_idx], col_num);
689         transpose_16x16_avx512(in, out);
690         write_buffer_16x16(out, coeff);
691         break;
692     case V_FLIPADST:
693         load_buffer_16x16_avx512(input, in, stride, 1, 0, shift[0]);
694         fadst16x16_avx512(in, out, fwd_cos_bit_col[txw_idx][txh_idx], col_num);
695         col_txfm_16x16_rounding_avx512(out, -shift[1]);
696         fidtx16x16_avx512(out, out, fwd_cos_bit_row[txw_idx][txh_idx], col_num);
697         write_buffer_16x16(out, coeff);
698         break;
699     case H_FLIPADST:
700         load_buffer_16x16_avx512(input, in, stride, 0, 1, shift[0]);
701         fidtx16x16_avx512(in, in, fwd_cos_bit_col[txw_idx][txh_idx], col_num);
702         col_txfm_16x16_rounding_avx512(in, -shift[1]);
703         transpose_16x16_avx512(in, out);
704         fadst16x16_avx512(out, in, fwd_cos_bit_row[txw_idx][txh_idx], col_num);
705         transpose_16x16_avx512(in, out);
706         write_buffer_16x16(out, coeff);
707         break;
708     default: assert(0);
709     }
710     (void)bd;
711 }
712 
/*
 * Forward 32-point DCT, vectorized over __m512i registers (16 int32
 * lanes per register, i.e. 16 independent transform columns at a time).
 *
 * input/output : arrays of __m512i; successive transform rows are
 *                'stride' registers apart.
 * cos_bit      : fixed-point precision of the cosine constants; also
 *                the rounding shift applied inside each rotation.
 * col_num      : total lane-column count; the outer loop runs
 *                col_num >> 4 times (16 lanes per register).
 *
 * The eight "stage" sections mirror the scalar av1_fdct32 flow graph:
 * add/sub butterflies interleaved with rotations performed by the
 * btf_32_type0/1_avx512_new macros. buf0 and buf1 ping-pong as
 * per-stage scratch; final stages write straight into 'out' at the
 * bit-reversed output offsets.
 */
static void av1_fdct32_new_avx512(const __m512i *input, __m512i *output, const int8_t cos_bit,
                                  const int32_t col_num, const int32_t stride) {
    const int32_t *cospi      = cospi_arr(cos_bit);
    const __m512i  __rounding = _mm512_set1_epi32(1 << (cos_bit - 1));
    const int32_t  columns    = col_num >> 4; // 16 int32 lanes per __m512i

    // Broadcast every cosine constant used by the rotations below.
    __m512i cospi_m32 = _mm512_set1_epi32(-cospi[32]);
    __m512i cospi_p32 = _mm512_set1_epi32(cospi[32]);
    __m512i cospi_m16 = _mm512_set1_epi32(-cospi[16]);
    __m512i cospi_p48 = _mm512_set1_epi32(cospi[48]);
    __m512i cospi_m48 = _mm512_set1_epi32(-cospi[48]);
    __m512i cospi_m08 = _mm512_set1_epi32(-cospi[8]);
    __m512i cospi_p56 = _mm512_set1_epi32(cospi[56]);
    __m512i cospi_m56 = _mm512_set1_epi32(-cospi[56]);
    __m512i cospi_p40 = _mm512_set1_epi32(cospi[40]);
    __m512i cospi_m40 = _mm512_set1_epi32(-cospi[40]);
    __m512i cospi_p24 = _mm512_set1_epi32(cospi[24]);
    __m512i cospi_m24 = _mm512_set1_epi32(-cospi[24]);
    __m512i cospi_p16 = _mm512_set1_epi32(cospi[16]);
    __m512i cospi_p08 = _mm512_set1_epi32(cospi[8]);
    __m512i cospi_p04 = _mm512_set1_epi32(cospi[4]);
    __m512i cospi_p60 = _mm512_set1_epi32(cospi[60]);
    __m512i cospi_p36 = _mm512_set1_epi32(cospi[36]);
    __m512i cospi_p28 = _mm512_set1_epi32(cospi[28]);
    __m512i cospi_p20 = _mm512_set1_epi32(cospi[20]);
    __m512i cospi_p44 = _mm512_set1_epi32(cospi[44]);
    __m512i cospi_p52 = _mm512_set1_epi32(cospi[52]);
    __m512i cospi_p12 = _mm512_set1_epi32(cospi[12]);
    __m512i cospi_p02 = _mm512_set1_epi32(cospi[2]);
    __m512i cospi_p06 = _mm512_set1_epi32(cospi[6]);
    __m512i cospi_p62 = _mm512_set1_epi32(cospi[62]);
    __m512i cospi_p34 = _mm512_set1_epi32(cospi[34]);
    __m512i cospi_p30 = _mm512_set1_epi32(cospi[30]);
    __m512i cospi_p18 = _mm512_set1_epi32(cospi[18]);
    __m512i cospi_p46 = _mm512_set1_epi32(cospi[46]);
    __m512i cospi_p50 = _mm512_set1_epi32(cospi[50]);
    __m512i cospi_p14 = _mm512_set1_epi32(cospi[14]);
    __m512i cospi_p10 = _mm512_set1_epi32(cospi[10]);
    __m512i cospi_p54 = _mm512_set1_epi32(cospi[54]);
    __m512i cospi_p42 = _mm512_set1_epi32(cospi[42]);
    __m512i cospi_p22 = _mm512_set1_epi32(cospi[22]);
    __m512i cospi_p26 = _mm512_set1_epi32(cospi[26]);
    __m512i cospi_p38 = _mm512_set1_epi32(cospi[38]);
    __m512i cospi_p58 = _mm512_set1_epi32(cospi[58]);

    // Ping-pong scratch buffers for the butterfly stages.
    __m512i buf0[32];
    __m512i buf1[32];

    for (int32_t col = 0; col < columns; col++) {
        const __m512i *in  = &input[col];
        __m512i *      out = &output[col];

        // stage 0
        // stage 1: initial 32-point butterflies, in[k] +/- in[31-k]
        buf1[0]  = _mm512_add_epi32(in[0 * stride], in[31 * stride]);
        buf1[31] = _mm512_sub_epi32(in[0 * stride], in[31 * stride]);
        buf1[1]  = _mm512_add_epi32(in[1 * stride], in[30 * stride]);
        buf1[30] = _mm512_sub_epi32(in[1 * stride], in[30 * stride]);
        buf1[2]  = _mm512_add_epi32(in[2 * stride], in[29 * stride]);
        buf1[29] = _mm512_sub_epi32(in[2 * stride], in[29 * stride]);
        buf1[3]  = _mm512_add_epi32(in[3 * stride], in[28 * stride]);
        buf1[28] = _mm512_sub_epi32(in[3 * stride], in[28 * stride]);
        buf1[4]  = _mm512_add_epi32(in[4 * stride], in[27 * stride]);
        buf1[27] = _mm512_sub_epi32(in[4 * stride], in[27 * stride]);
        buf1[5]  = _mm512_add_epi32(in[5 * stride], in[26 * stride]);
        buf1[26] = _mm512_sub_epi32(in[5 * stride], in[26 * stride]);
        buf1[6]  = _mm512_add_epi32(in[6 * stride], in[25 * stride]);
        buf1[25] = _mm512_sub_epi32(in[6 * stride], in[25 * stride]);
        buf1[7]  = _mm512_add_epi32(in[7 * stride], in[24 * stride]);
        buf1[24] = _mm512_sub_epi32(in[7 * stride], in[24 * stride]);
        buf1[8]  = _mm512_add_epi32(in[8 * stride], in[23 * stride]);
        buf1[23] = _mm512_sub_epi32(in[8 * stride], in[23 * stride]);
        buf1[9]  = _mm512_add_epi32(in[9 * stride], in[22 * stride]);
        buf1[22] = _mm512_sub_epi32(in[9 * stride], in[22 * stride]);
        buf1[10] = _mm512_add_epi32(in[10 * stride], in[21 * stride]);
        buf1[21] = _mm512_sub_epi32(in[10 * stride], in[21 * stride]);
        buf1[11] = _mm512_add_epi32(in[11 * stride], in[20 * stride]);
        buf1[20] = _mm512_sub_epi32(in[11 * stride], in[20 * stride]);
        buf1[12] = _mm512_add_epi32(in[12 * stride], in[19 * stride]);
        buf1[19] = _mm512_sub_epi32(in[12 * stride], in[19 * stride]);
        buf1[13] = _mm512_add_epi32(in[13 * stride], in[18 * stride]);
        buf1[18] = _mm512_sub_epi32(in[13 * stride], in[18 * stride]);
        buf1[14] = _mm512_add_epi32(in[14 * stride], in[17 * stride]);
        buf1[17] = _mm512_sub_epi32(in[14 * stride], in[17 * stride]);
        buf1[15] = _mm512_add_epi32(in[15 * stride], in[16 * stride]);
        buf1[16] = _mm512_sub_epi32(in[15 * stride], in[16 * stride]);

        // stage 2: 16-point butterflies on the even half; +/-cos32
        // rotations on the middle of the odd half
        buf0[0]  = _mm512_add_epi32(buf1[0], buf1[15]);
        buf0[15] = _mm512_sub_epi32(buf1[0], buf1[15]);
        buf0[1]  = _mm512_add_epi32(buf1[1], buf1[14]);
        buf0[14] = _mm512_sub_epi32(buf1[1], buf1[14]);
        buf0[2]  = _mm512_add_epi32(buf1[2], buf1[13]);
        buf0[13] = _mm512_sub_epi32(buf1[2], buf1[13]);
        buf0[3]  = _mm512_add_epi32(buf1[3], buf1[12]);
        buf0[12] = _mm512_sub_epi32(buf1[3], buf1[12]);
        buf0[4]  = _mm512_add_epi32(buf1[4], buf1[11]);
        buf0[11] = _mm512_sub_epi32(buf1[4], buf1[11]);
        buf0[5]  = _mm512_add_epi32(buf1[5], buf1[10]);
        buf0[10] = _mm512_sub_epi32(buf1[5], buf1[10]);
        buf0[6]  = _mm512_add_epi32(buf1[6], buf1[9]);
        buf0[9]  = _mm512_sub_epi32(buf1[6], buf1[9]);
        buf0[7]  = _mm512_add_epi32(buf1[7], buf1[8]);
        buf0[8]  = _mm512_sub_epi32(buf1[7], buf1[8]);
        btf_32_type0_avx512_new(
            cospi_m32, cospi_p32, buf1[20], buf1[27], buf0[20], buf0[27], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m32, cospi_p32, buf1[21], buf1[26], buf0[21], buf0[26], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m32, cospi_p32, buf1[22], buf1[25], buf0[22], buf0[25], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m32, cospi_p32, buf1[23], buf1[24], buf0[23], buf0[24], __rounding, cos_bit);

        // stage 3
        buf1[0] = _mm512_add_epi32(buf0[0], buf0[7]);
        buf1[7] = _mm512_sub_epi32(buf0[0], buf0[7]);
        buf1[1] = _mm512_add_epi32(buf0[1], buf0[6]);
        buf1[6] = _mm512_sub_epi32(buf0[1], buf0[6]);
        buf1[2] = _mm512_add_epi32(buf0[2], buf0[5]);
        buf1[5] = _mm512_sub_epi32(buf0[2], buf0[5]);
        buf1[3] = _mm512_add_epi32(buf0[3], buf0[4]);
        buf1[4] = _mm512_sub_epi32(buf0[3], buf0[4]);
        btf_32_type0_avx512_new(
            cospi_m32, cospi_p32, buf0[10], buf0[13], buf1[10], buf1[13], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m32, cospi_p32, buf0[11], buf0[12], buf1[11], buf1[12], __rounding, cos_bit);
        // In-place butterflies: the sub result is written first so the
        // still-needed buf1 operand is read before being overwritten.
        buf1[23] = _mm512_sub_epi32(buf1[16], buf0[23]);
        buf1[16] = _mm512_add_epi32(buf1[16], buf0[23]);
        buf1[22] = _mm512_sub_epi32(buf1[17], buf0[22]);
        buf1[17] = _mm512_add_epi32(buf1[17], buf0[22]);
        buf1[21] = _mm512_sub_epi32(buf1[18], buf0[21]);
        buf1[18] = _mm512_add_epi32(buf1[18], buf0[21]);
        buf1[20] = _mm512_sub_epi32(buf1[19], buf0[20]);
        buf1[19] = _mm512_add_epi32(buf1[19], buf0[20]);
        buf1[24] = _mm512_sub_epi32(buf1[31], buf0[24]);
        buf1[31] = _mm512_add_epi32(buf1[31], buf0[24]);
        buf1[25] = _mm512_sub_epi32(buf1[30], buf0[25]);
        buf1[30] = _mm512_add_epi32(buf1[30], buf0[25]);
        buf1[26] = _mm512_sub_epi32(buf1[29], buf0[26]);
        buf1[29] = _mm512_add_epi32(buf1[29], buf0[26]);
        buf1[27] = _mm512_sub_epi32(buf1[28], buf0[27]);
        buf1[28] = _mm512_add_epi32(buf1[28], buf0[27]);

        // stage 4
        buf0[0] = _mm512_add_epi32(buf1[0], buf1[3]);
        buf0[3] = _mm512_sub_epi32(buf1[0], buf1[3]);
        buf0[1] = _mm512_add_epi32(buf1[1], buf1[2]);
        buf0[2] = _mm512_sub_epi32(buf1[1], buf1[2]);
        btf_32_type0_avx512_new(
            cospi_m32, cospi_p32, buf1[5], buf1[6], buf0[5], buf0[6], __rounding, cos_bit);
        buf0[11] = _mm512_sub_epi32(buf0[8], buf1[11]);
        buf0[8]  = _mm512_add_epi32(buf0[8], buf1[11]);
        buf0[10] = _mm512_sub_epi32(buf0[9], buf1[10]);
        buf0[9]  = _mm512_add_epi32(buf0[9], buf1[10]);
        buf0[12] = _mm512_sub_epi32(buf0[15], buf1[12]);
        buf0[15] = _mm512_add_epi32(buf0[15], buf1[12]);
        buf0[13] = _mm512_sub_epi32(buf0[14], buf1[13]);
        buf0[14] = _mm512_add_epi32(buf0[14], buf1[13]);
        btf_32_type0_avx512_new(
            cospi_m16, cospi_p48, buf1[18], buf1[29], buf0[18], buf0[29], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m16, cospi_p48, buf1[19], buf1[28], buf0[19], buf0[28], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m48, cospi_m16, buf1[20], buf1[27], buf0[20], buf0[27], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m48, cospi_m16, buf1[21], buf1[26], buf0[21], buf0[26], __rounding, cos_bit);

        // stage 5: first final outputs (DC and coefficients 8/16/24)
        btf_32_type0_avx512_new(cospi_p32,
                                cospi_p32,
                                buf0[0],
                                buf0[1],
                                out[0 * stride],
                                out[16 * stride],
                                __rounding,
                                cos_bit);
        btf_32_type1_avx512_new(cospi_p48,
                                cospi_p16,
                                buf0[2],
                                buf0[3],
                                out[8 * stride],
                                out[24 * stride],
                                __rounding,
                                cos_bit);
        buf1[5] = _mm512_sub_epi32(buf1[4], buf0[5]);
        buf1[4] = _mm512_add_epi32(buf1[4], buf0[5]);
        buf1[6] = _mm512_sub_epi32(buf1[7], buf0[6]);
        buf1[7] = _mm512_add_epi32(buf1[7], buf0[6]);
        btf_32_type0_avx512_new(
            cospi_m16, cospi_p48, buf0[9], buf0[14], buf1[9], buf1[14], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m48, cospi_m16, buf0[10], buf0[13], buf1[10], buf1[13], __rounding, cos_bit);
        buf1[19] = _mm512_sub_epi32(buf1[16], buf0[19]);
        buf1[16] = _mm512_add_epi32(buf1[16], buf0[19]);
        buf1[18] = _mm512_sub_epi32(buf1[17], buf0[18]);
        buf1[17] = _mm512_add_epi32(buf1[17], buf0[18]);
        buf1[20] = _mm512_sub_epi32(buf1[23], buf0[20]);
        buf1[23] = _mm512_add_epi32(buf1[23], buf0[20]);
        buf1[21] = _mm512_sub_epi32(buf1[22], buf0[21]);
        buf1[22] = _mm512_add_epi32(buf1[22], buf0[21]);
        buf1[27] = _mm512_sub_epi32(buf1[24], buf0[27]);
        buf1[24] = _mm512_add_epi32(buf1[24], buf0[27]);
        buf1[26] = _mm512_sub_epi32(buf1[25], buf0[26]);
        buf1[25] = _mm512_add_epi32(buf1[25], buf0[26]);
        buf1[28] = _mm512_sub_epi32(buf1[31], buf0[28]);
        buf1[31] = _mm512_add_epi32(buf1[31], buf0[28]);
        buf1[29] = _mm512_sub_epi32(buf1[30], buf0[29]);
        buf1[30] = _mm512_add_epi32(buf1[30], buf0[29]);

        // stage 6: outputs 4/12/20/28
        btf_32_type1_avx512_new(cospi_p56,
                                cospi_p08,
                                buf1[4],
                                buf1[7],
                                out[4 * stride],
                                out[28 * stride],
                                __rounding,
                                cos_bit);
        btf_32_type1_avx512_new(cospi_p24,
                                cospi_p40,
                                buf1[5],
                                buf1[6],
                                out[20 * stride],
                                out[12 * stride],
                                __rounding,
                                cos_bit);
        buf0[9]  = _mm512_sub_epi32(buf0[8], buf1[9]);
        buf0[8]  = _mm512_add_epi32(buf0[8], buf1[9]);
        buf0[10] = _mm512_sub_epi32(buf0[11], buf1[10]);
        buf0[11] = _mm512_add_epi32(buf0[11], buf1[10]);
        buf0[13] = _mm512_sub_epi32(buf0[12], buf1[13]);
        buf0[12] = _mm512_add_epi32(buf0[12], buf1[13]);
        buf0[14] = _mm512_sub_epi32(buf0[15], buf1[14]);
        buf0[15] = _mm512_add_epi32(buf0[15], buf1[14]);
        btf_32_type0_avx512_new(
            cospi_m08, cospi_p56, buf1[17], buf1[30], buf0[17], buf0[30], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m56, cospi_m08, buf1[18], buf1[29], buf0[18], buf0[29], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m40, cospi_p24, buf1[21], buf1[26], buf0[21], buf0[26], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m24, cospi_m40, buf1[22], buf1[25], buf0[22], buf0[25], __rounding, cos_bit);

        // stage 7: outputs 2/6/10/14/18/22/26/30
        btf_32_type1_avx512_new(cospi_p60,
                                cospi_p04,
                                buf0[8],
                                buf0[15],
                                out[2 * stride],
                                out[30 * stride],
                                __rounding,
                                cos_bit);
        btf_32_type1_avx512_new(cospi_p28,
                                cospi_p36,
                                buf0[9],
                                buf0[14],
                                out[18 * stride],
                                out[14 * stride],
                                __rounding,
                                cos_bit);
        btf_32_type1_avx512_new(cospi_p44,
                                cospi_p20,
                                buf0[10],
                                buf0[13],
                                out[10 * stride],
                                out[22 * stride],
                                __rounding,
                                cos_bit);
        btf_32_type1_avx512_new(cospi_p12,
                                cospi_p52,
                                buf0[11],
                                buf0[12],
                                out[26 * stride],
                                out[6 * stride],
                                __rounding,
                                cos_bit);
        buf1[17] = _mm512_sub_epi32(buf1[16], buf0[17]);
        buf1[16] = _mm512_add_epi32(buf1[16], buf0[17]);
        buf1[18] = _mm512_sub_epi32(buf1[19], buf0[18]);
        buf1[19] = _mm512_add_epi32(buf1[19], buf0[18]);
        buf1[21] = _mm512_sub_epi32(buf1[20], buf0[21]);
        buf1[20] = _mm512_add_epi32(buf1[20], buf0[21]);
        buf1[22] = _mm512_sub_epi32(buf1[23], buf0[22]);
        buf1[23] = _mm512_add_epi32(buf1[23], buf0[22]);
        buf1[25] = _mm512_sub_epi32(buf1[24], buf0[25]);
        buf1[24] = _mm512_add_epi32(buf1[24], buf0[25]);
        buf1[26] = _mm512_sub_epi32(buf1[27], buf0[26]);
        buf1[27] = _mm512_add_epi32(buf1[27], buf0[26]);
        buf1[29] = _mm512_sub_epi32(buf1[28], buf0[29]);
        buf1[28] = _mm512_add_epi32(buf1[28], buf0[29]);
        buf1[30] = _mm512_sub_epi32(buf1[31], buf0[30]);
        buf1[31] = _mm512_add_epi32(buf1[31], buf0[30]);

        // stage 8: remaining odd outputs
        btf_32_type1_avx512_new(cospi_p62,
                                cospi_p02,
                                buf1[16],
                                buf1[31],
                                out[1 * stride],
                                out[31 * stride],
                                __rounding,
                                cos_bit);
        btf_32_type1_avx512_new(cospi_p30,
                                cospi_p34,
                                buf1[17],
                                buf1[30],
                                out[17 * stride],
                                out[15 * stride],
                                __rounding,
                                cos_bit);
        btf_32_type1_avx512_new(cospi_p46,
                                cospi_p18,
                                buf1[18],
                                buf1[29],
                                out[9 * stride],
                                out[23 * stride],
                                __rounding,
                                cos_bit);
        btf_32_type1_avx512_new(cospi_p14,
                                cospi_p50,
                                buf1[19],
                                buf1[28],
                                out[25 * stride],
                                out[7 * stride],
                                __rounding,
                                cos_bit);
        btf_32_type1_avx512_new(cospi_p54,
                                cospi_p10,
                                buf1[20],
                                buf1[27],
                                out[5 * stride],
                                out[27 * stride],
                                __rounding,
                                cos_bit);
        btf_32_type1_avx512_new(cospi_p22,
                                cospi_p42,
                                buf1[21],
                                buf1[26],
                                out[21 * stride],
                                out[11 * stride],
                                __rounding,
                                cos_bit);
        btf_32_type1_avx512_new(cospi_p38,
                                cospi_p26,
                                buf1[22],
                                buf1[25],
                                out[13 * stride],
                                out[19 * stride],
                                __rounding,
                                cos_bit);
        btf_32_type1_avx512_new(cospi_p06,
                                cospi_p58,
                                buf1[23],
                                buf1[24],
                                out[29 * stride],
                                out[3 * stride],
                                __rounding,
                                cos_bit);
    }
}
1073 
fdct32x32_avx512(const __m512i * input,__m512i * output,const int8_t cos_bit,const int8_t * stage_range)1074 static INLINE void fdct32x32_avx512(const __m512i *input, __m512i *output, const int8_t cos_bit,
1075                                     const int8_t *stage_range) {
1076     const int32_t txfm_size   = 32;
1077     const int32_t num_per_512 = 16;
1078     int32_t       col_num     = txfm_size / num_per_512;
1079     (void)stage_range;
1080     av1_fdct32_new_avx512(input, output, cos_bit, txfm_size, col_num);
1081 }
1082 
/*
 * Forward 32-point identity transform: scales each of the 32 rows by 4
 * (left shift by 2), stepping 'col_num' registers between rows.
 * cos_bit is unused — no trigonometric constants are involved.
 */
void av1_idtx32_new_avx512(const __m512i *input, __m512i *output, const int8_t cos_bit,
                           const int32_t col_num) {
    (void)cos_bit;
    int32_t row = 0;
    do {
        const int32_t idx = row * col_num;
        output[idx]       = _mm512_slli_epi32(input[idx], (uint8_t)2);
    } while (++row < 32);
}
1090 
/*
 * 32x32 identity transform adapter matching the TxfmFuncAVX512
 * signature. The 32x32 block occupies 64 registers; they are processed
 * as two contiguous halves of 32 registers each, with unit row step.
 */
static void fidtx32x32_avx512(const __m512i *input, __m512i *output, const int8_t cos_bit,
                              const int8_t *stage_range) {
    (void)stage_range; // unused on the AVX-512 path

    const __m512i *src = input;
    __m512i *      dst = output;
    for (int32_t half = 0; half < 2; ++half, src += 32, dst += 32) {
        av1_idtx32_new_avx512(src, dst, cos_bit, 1);
    }
}
1099 
1100 typedef void (*TxfmFuncAVX512)(const __m512i *input, __m512i *output, const int8_t cos_bit,
1101                                const int8_t *stage_range);
1102 
fwd_txfm_type_to_func(TxfmType txfmtype)1103 static INLINE TxfmFuncAVX512 fwd_txfm_type_to_func(TxfmType txfmtype) {
1104     switch (txfmtype) {
1105     case TXFM_TYPE_DCT32: return fdct32x32_avx512; break;
1106     case TXFM_TYPE_IDENTITY32: return fidtx32x32_avx512; break;
1107     default: assert(0);
1108     }
1109     return NULL;
1110 }
1111 
load_buffer_32x32_avx512(const int16_t * input,__m512i * output,int32_t stride)1112 static INLINE void load_buffer_32x32_avx512(const int16_t *input, __m512i *output, int32_t stride) {
1113     __m256i temp[2];
1114     int32_t i;
1115 
1116     for (i = 0; i < 32; ++i) {
1117         temp[0] = _mm256_loadu_si256((const __m256i *)(input + 0 * 16));
1118         temp[1] = _mm256_loadu_si256((const __m256i *)(input + 1 * 16));
1119 
1120         output[0] = _mm512_cvtepi16_epi32(temp[0]);
1121         output[1] = _mm512_cvtepi16_epi32(temp[1]);
1122 
1123         input += stride;
1124         output += 2;
1125     }
1126 }
1127 
av1_round_shift_array_avx512(__m512i * input,__m512i * output,const int32_t size,const int8_t bit)1128 static INLINE void av1_round_shift_array_avx512(__m512i *input, __m512i *output, const int32_t size,
1129                                                 const int8_t bit) {
1130     if (bit > 0) {
1131         __m512i round = _mm512_set1_epi32(1 << (bit - 1));
1132         int32_t i;
1133         for (i = 0; i < size; i++) {
1134             output[i] = _mm512_srai_epi32(_mm512_add_epi32(input[i], round), (uint8_t)bit);
1135         }
1136     } else {
1137         int32_t i;
1138         for (i = 0; i < size; i++) { output[i] = _mm512_slli_epi32(input[i], (uint8_t)(-bit)); }
1139     }
1140 }
1141 
/*
 * 2-D 32x32 forward transform driver.
 * Pipeline: load/widen pixels -> pre-shift -> column transform ->
 * mid-shift -> transpose -> row transform -> post-shift -> transpose.
 * buf_512 (caller-provided scratch) and out_512 (the output array,
 * reused as scratch) ping-pong between steps, so the final transpose
 * lands the coefficients in 'output' in row-major order.
 * shift[0..2] are the three rounding shifts from the transform config
 * (negated here because av1_round_shift_array_avx512 treats positive
 * as a right shift).
 */
static INLINE void fwd_txfm2d_32x32_avx512(const int16_t *input, int32_t *output,
                                           const int32_t stride, const Txfm2dFlipCfg *cfg,
                                           int32_t *txfm_buf) {
    assert(cfg->tx_size < TX_SIZES);
    const int32_t        txfm_size       = tx_size_wide[cfg->tx_size];
    const int8_t *       shift           = cfg->shift;
    const int8_t *       stage_range_col = cfg->stage_range_col;
    const int8_t *       stage_range_row = cfg->stage_range_row;
    const int8_t         cos_bit_col     = cfg->cos_bit_col;
    const int8_t         cos_bit_row     = cfg->cos_bit_row;
    const TxfmFuncAVX512 txfm_func_col   = fwd_txfm_type_to_func(cfg->txfm_type_col);
    const TxfmFuncAVX512 txfm_func_row   = fwd_txfm_type_to_func(cfg->txfm_type_row);
    ASSERT(txfm_func_col);
    ASSERT(txfm_func_row);
    __m512i *buf_512         = (__m512i *)txfm_buf;
    __m512i *out_512         = (__m512i *)output;
    int32_t  num_per_512     = 16; // int32 lanes per register
    int32_t  txfm2d_size_512 = txfm_size * txfm_size / num_per_512;

    load_buffer_32x32_avx512(input, buf_512, stride);
    av1_round_shift_array_avx512(buf_512, out_512, txfm2d_size_512, -shift[0]);
    txfm_func_col(out_512, buf_512, cos_bit_col, stage_range_col);
    av1_round_shift_array_avx512(buf_512, out_512, txfm2d_size_512, -shift[1]);
    transpose_16nx16n_avx512(txfm_size, out_512, buf_512);
    txfm_func_row(buf_512, out_512, cos_bit_row, stage_range_row);
    av1_round_shift_array_avx512(out_512, buf_512, txfm2d_size_512, -shift[2]);
    transpose_16nx16n_avx512(txfm_size, buf_512, out_512);
}
1170 
/*
 * Public AVX-512 entry point for the 32x32 forward transform.
 * Looks up the flip/shift configuration for tx_type and runs the 2-D
 * driver with a 64-byte-aligned on-stack scratch buffer (32*32 int32).
 * bd (bit depth) does not affect this fixed-point path.
 */
void av1_fwd_txfm2d_32x32_avx512(int16_t *input, int32_t *output, uint32_t stride, TxType tx_type,
                                 uint8_t bd) {
    DECLARE_ALIGNED(64, int32_t, txfm_buf[1024]);
    Txfm2dFlipCfg cfg;
    (void)bd;
    av1_transform_config(tx_type, TX_32X32, &cfg);
    fwd_txfm2d_32x32_avx512(input, output, stride, &cfg, txfm_buf);
}
1179 
/*
 * Forward 64-point identity transform over a full 64x64 block:
 * every coefficient is scaled by 2*sqrt(2) in Q12 fixed point
 * (4 * new_sqrt2 = 4 * 5793, then rounded right shift by 12).
 */
static void fidtx64x64_avx512(const __m512i *input, __m512i *output) {
    const uint8_t shift_bits = 12; // new_sqrt2_bits
    const __m512i scale      = _mm512_set1_epi32(4 * 5793); // 4 * new_sqrt2
    const __m512i offset     = _mm512_set1_epi32(1 << (shift_bits - 1));
    const int32_t total      = 64 * 4; // 64 rows, 4 registers per row

    for (int32_t i = 0; i < total; i++) {
        const __m512i scaled = _mm512_mullo_epi32(input[i], scale);
        output[i]            = _mm512_srai_epi32(_mm512_add_epi32(scaled, offset), shift_bits);
    }
}
1195 
load_buffer_64x64_avx512(const int16_t * input,int32_t stride,__m512i * output)1196 static INLINE void load_buffer_64x64_avx512(const int16_t *input, int32_t stride, __m512i *output) {
1197     __m256i x0, x1, x2, x3;
1198     __m512i v0, v1, v2, v3;
1199     int32_t i;
1200 
1201     for (i = 0; i < 64; ++i) {
1202         x0 = _mm256_loadu_si256((const __m256i *)(input + 0 * 16));
1203         x1 = _mm256_loadu_si256((const __m256i *)(input + 1 * 16));
1204         x2 = _mm256_loadu_si256((const __m256i *)(input + 2 * 16));
1205         x3 = _mm256_loadu_si256((const __m256i *)(input + 3 * 16));
1206 
1207         v0 = _mm512_cvtepi16_epi32(x0);
1208         v1 = _mm512_cvtepi16_epi32(x1);
1209         v2 = _mm512_cvtepi16_epi32(x2);
1210         v3 = _mm512_cvtepi16_epi32(x3);
1211 
1212         _mm512_storeu_si512(output + 0, v0);
1213         _mm512_storeu_si512(output + 1, v1);
1214         _mm512_storeu_si512(output + 2, v2);
1215         _mm512_storeu_si512(output + 3, v3);
1216 
1217         input += stride;
1218         output += 4;
1219     }
1220 }
1221 
av1_fdct64_new_avx512(const __m512i * input,__m512i * output,const int8_t cos_bit,const int32_t col_num,const int32_t stride)1222 static void av1_fdct64_new_avx512(const __m512i *input, __m512i *output, const int8_t cos_bit,
1223                                   const int32_t col_num, const int32_t stride) {
1224     const int32_t *cospi      = cospi_arr(cos_bit);
1225     const __m512i  __rounding = _mm512_set1_epi32(1 << (cos_bit - 1));
1226     const int32_t  columns    = col_num >> 4;
1227 
1228     __m512i cospi_m32 = _mm512_set1_epi32(-cospi[32]);
1229     __m512i cospi_p32 = _mm512_set1_epi32(cospi[32]);
1230     __m512i cospi_m16 = _mm512_set1_epi32(-cospi[16]);
1231     __m512i cospi_p48 = _mm512_set1_epi32(cospi[48]);
1232     __m512i cospi_m48 = _mm512_set1_epi32(-cospi[48]);
1233     __m512i cospi_p16 = _mm512_set1_epi32(cospi[16]);
1234     __m512i cospi_m08 = _mm512_set1_epi32(-cospi[8]);
1235     __m512i cospi_p56 = _mm512_set1_epi32(cospi[56]);
1236     __m512i cospi_m56 = _mm512_set1_epi32(-cospi[56]);
1237     __m512i cospi_m40 = _mm512_set1_epi32(-cospi[40]);
1238     __m512i cospi_p24 = _mm512_set1_epi32(cospi[24]);
1239     __m512i cospi_m24 = _mm512_set1_epi32(-cospi[24]);
1240     __m512i cospi_p08 = _mm512_set1_epi32(cospi[8]);
1241     __m512i cospi_p40 = _mm512_set1_epi32(cospi[40]);
1242     __m512i cospi_p60 = _mm512_set1_epi32(cospi[60]);
1243     __m512i cospi_p04 = _mm512_set1_epi32(cospi[4]);
1244     __m512i cospi_p28 = _mm512_set1_epi32(cospi[28]);
1245     __m512i cospi_p36 = _mm512_set1_epi32(cospi[36]);
1246     __m512i cospi_p44 = _mm512_set1_epi32(cospi[44]);
1247     __m512i cospi_p20 = _mm512_set1_epi32(cospi[20]);
1248     __m512i cospi_p12 = _mm512_set1_epi32(cospi[12]);
1249     __m512i cospi_p52 = _mm512_set1_epi32(cospi[52]);
1250     __m512i cospi_m04 = _mm512_set1_epi32(-cospi[4]);
1251     __m512i cospi_m60 = _mm512_set1_epi32(-cospi[60]);
1252     __m512i cospi_m36 = _mm512_set1_epi32(-cospi[36]);
1253     __m512i cospi_m28 = _mm512_set1_epi32(-cospi[28]);
1254     __m512i cospi_m20 = _mm512_set1_epi32(-cospi[20]);
1255     __m512i cospi_m44 = _mm512_set1_epi32(-cospi[44]);
1256     __m512i cospi_m52 = _mm512_set1_epi32(-cospi[52]);
1257     __m512i cospi_m12 = _mm512_set1_epi32(-cospi[12]);
1258     __m512i cospi_p62 = _mm512_set1_epi32(cospi[62]);
1259     __m512i cospi_p02 = _mm512_set1_epi32(cospi[2]);
1260     __m512i cospi_p30 = _mm512_set1_epi32(cospi[30]);
1261     __m512i cospi_p34 = _mm512_set1_epi32(cospi[34]);
1262     __m512i cospi_p46 = _mm512_set1_epi32(cospi[46]);
1263     __m512i cospi_p18 = _mm512_set1_epi32(cospi[18]);
1264     __m512i cospi_p14 = _mm512_set1_epi32(cospi[14]);
1265     __m512i cospi_p50 = _mm512_set1_epi32(cospi[50]);
1266     __m512i cospi_p54 = _mm512_set1_epi32(cospi[54]);
1267     __m512i cospi_p10 = _mm512_set1_epi32(cospi[10]);
1268     __m512i cospi_p22 = _mm512_set1_epi32(cospi[22]);
1269     __m512i cospi_p42 = _mm512_set1_epi32(cospi[42]);
1270     __m512i cospi_p38 = _mm512_set1_epi32(cospi[38]);
1271     __m512i cospi_p26 = _mm512_set1_epi32(cospi[26]);
1272     __m512i cospi_p06 = _mm512_set1_epi32(cospi[6]);
1273     __m512i cospi_p58 = _mm512_set1_epi32(cospi[58]);
1274     __m512i cospi_p63 = _mm512_set1_epi32(cospi[63]);
1275     __m512i cospi_p01 = _mm512_set1_epi32(cospi[1]);
1276     __m512i cospi_p31 = _mm512_set1_epi32(cospi[31]);
1277     __m512i cospi_p33 = _mm512_set1_epi32(cospi[33]);
1278     __m512i cospi_p47 = _mm512_set1_epi32(cospi[47]);
1279     __m512i cospi_p17 = _mm512_set1_epi32(cospi[17]);
1280     __m512i cospi_p15 = _mm512_set1_epi32(cospi[15]);
1281     __m512i cospi_p49 = _mm512_set1_epi32(cospi[49]);
1282     __m512i cospi_p55 = _mm512_set1_epi32(cospi[55]);
1283     __m512i cospi_p09 = _mm512_set1_epi32(cospi[9]);
1284     __m512i cospi_p23 = _mm512_set1_epi32(cospi[23]);
1285     __m512i cospi_p41 = _mm512_set1_epi32(cospi[41]);
1286     __m512i cospi_p39 = _mm512_set1_epi32(cospi[39]);
1287     __m512i cospi_p25 = _mm512_set1_epi32(cospi[25]);
1288     __m512i cospi_p07 = _mm512_set1_epi32(cospi[7]);
1289     __m512i cospi_p57 = _mm512_set1_epi32(cospi[57]);
1290     __m512i cospi_p59 = _mm512_set1_epi32(cospi[59]);
1291     __m512i cospi_p05 = _mm512_set1_epi32(cospi[5]);
1292     __m512i cospi_p27 = _mm512_set1_epi32(cospi[27]);
1293     __m512i cospi_p37 = _mm512_set1_epi32(cospi[37]);
1294     __m512i cospi_p43 = _mm512_set1_epi32(cospi[43]);
1295     __m512i cospi_p21 = _mm512_set1_epi32(cospi[21]);
1296     __m512i cospi_p11 = _mm512_set1_epi32(cospi[11]);
1297     __m512i cospi_p53 = _mm512_set1_epi32(cospi[53]);
1298     __m512i cospi_p51 = _mm512_set1_epi32(cospi[51]);
1299     __m512i cospi_p13 = _mm512_set1_epi32(cospi[13]);
1300     __m512i cospi_p19 = _mm512_set1_epi32(cospi[19]);
1301     __m512i cospi_p45 = _mm512_set1_epi32(cospi[45]);
1302     __m512i cospi_p35 = _mm512_set1_epi32(cospi[35]);
1303     __m512i cospi_p29 = _mm512_set1_epi32(cospi[29]);
1304     __m512i cospi_p03 = _mm512_set1_epi32(cospi[3]);
1305     __m512i cospi_p61 = _mm512_set1_epi32(cospi[61]);
1306 
1307     for (int32_t col = 0; col < columns; col++) {
1308         const __m512i *in  = &input[col];
1309         __m512i *      out = &output[col];
1310 
1311         // stage 1
1312         __m512i x1[64];
1313         x1[0]  = _mm512_add_epi32(in[0 * stride], in[63 * stride]);
1314         x1[63] = _mm512_sub_epi32(in[0 * stride], in[63 * stride]);
1315         x1[1]  = _mm512_add_epi32(in[1 * stride], in[62 * stride]);
1316         x1[62] = _mm512_sub_epi32(in[1 * stride], in[62 * stride]);
1317         x1[2]  = _mm512_add_epi32(in[2 * stride], in[61 * stride]);
1318         x1[61] = _mm512_sub_epi32(in[2 * stride], in[61 * stride]);
1319         x1[3]  = _mm512_add_epi32(in[3 * stride], in[60 * stride]);
1320         x1[60] = _mm512_sub_epi32(in[3 * stride], in[60 * stride]);
1321         x1[4]  = _mm512_add_epi32(in[4 * stride], in[59 * stride]);
1322         x1[59] = _mm512_sub_epi32(in[4 * stride], in[59 * stride]);
1323         x1[5]  = _mm512_add_epi32(in[5 * stride], in[58 * stride]);
1324         x1[58] = _mm512_sub_epi32(in[5 * stride], in[58 * stride]);
1325         x1[6]  = _mm512_add_epi32(in[6 * stride], in[57 * stride]);
1326         x1[57] = _mm512_sub_epi32(in[6 * stride], in[57 * stride]);
1327         x1[7]  = _mm512_add_epi32(in[7 * stride], in[56 * stride]);
1328         x1[56] = _mm512_sub_epi32(in[7 * stride], in[56 * stride]);
1329         x1[8]  = _mm512_add_epi32(in[8 * stride], in[55 * stride]);
1330         x1[55] = _mm512_sub_epi32(in[8 * stride], in[55 * stride]);
1331         x1[9]  = _mm512_add_epi32(in[9 * stride], in[54 * stride]);
1332         x1[54] = _mm512_sub_epi32(in[9 * stride], in[54 * stride]);
1333         x1[10] = _mm512_add_epi32(in[10 * stride], in[53 * stride]);
1334         x1[53] = _mm512_sub_epi32(in[10 * stride], in[53 * stride]);
1335         x1[11] = _mm512_add_epi32(in[11 * stride], in[52 * stride]);
1336         x1[52] = _mm512_sub_epi32(in[11 * stride], in[52 * stride]);
1337         x1[12] = _mm512_add_epi32(in[12 * stride], in[51 * stride]);
1338         x1[51] = _mm512_sub_epi32(in[12 * stride], in[51 * stride]);
1339         x1[13] = _mm512_add_epi32(in[13 * stride], in[50 * stride]);
1340         x1[50] = _mm512_sub_epi32(in[13 * stride], in[50 * stride]);
1341         x1[14] = _mm512_add_epi32(in[14 * stride], in[49 * stride]);
1342         x1[49] = _mm512_sub_epi32(in[14 * stride], in[49 * stride]);
1343         x1[15] = _mm512_add_epi32(in[15 * stride], in[48 * stride]);
1344         x1[48] = _mm512_sub_epi32(in[15 * stride], in[48 * stride]);
1345         x1[16] = _mm512_add_epi32(in[16 * stride], in[47 * stride]);
1346         x1[47] = _mm512_sub_epi32(in[16 * stride], in[47 * stride]);
1347         x1[17] = _mm512_add_epi32(in[17 * stride], in[46 * stride]);
1348         x1[46] = _mm512_sub_epi32(in[17 * stride], in[46 * stride]);
1349         x1[18] = _mm512_add_epi32(in[18 * stride], in[45 * stride]);
1350         x1[45] = _mm512_sub_epi32(in[18 * stride], in[45 * stride]);
1351         x1[19] = _mm512_add_epi32(in[19 * stride], in[44 * stride]);
1352         x1[44] = _mm512_sub_epi32(in[19 * stride], in[44 * stride]);
1353         x1[20] = _mm512_add_epi32(in[20 * stride], in[43 * stride]);
1354         x1[43] = _mm512_sub_epi32(in[20 * stride], in[43 * stride]);
1355         x1[21] = _mm512_add_epi32(in[21 * stride], in[42 * stride]);
1356         x1[42] = _mm512_sub_epi32(in[21 * stride], in[42 * stride]);
1357         x1[22] = _mm512_add_epi32(in[22 * stride], in[41 * stride]);
1358         x1[41] = _mm512_sub_epi32(in[22 * stride], in[41 * stride]);
1359         x1[23] = _mm512_add_epi32(in[23 * stride], in[40 * stride]);
1360         x1[40] = _mm512_sub_epi32(in[23 * stride], in[40 * stride]);
1361         x1[24] = _mm512_add_epi32(in[24 * stride], in[39 * stride]);
1362         x1[39] = _mm512_sub_epi32(in[24 * stride], in[39 * stride]);
1363         x1[25] = _mm512_add_epi32(in[25 * stride], in[38 * stride]);
1364         x1[38] = _mm512_sub_epi32(in[25 * stride], in[38 * stride]);
1365         x1[26] = _mm512_add_epi32(in[26 * stride], in[37 * stride]);
1366         x1[37] = _mm512_sub_epi32(in[26 * stride], in[37 * stride]);
1367         x1[27] = _mm512_add_epi32(in[27 * stride], in[36 * stride]);
1368         x1[36] = _mm512_sub_epi32(in[27 * stride], in[36 * stride]);
1369         x1[28] = _mm512_add_epi32(in[28 * stride], in[35 * stride]);
1370         x1[35] = _mm512_sub_epi32(in[28 * stride], in[35 * stride]);
1371         x1[29] = _mm512_add_epi32(in[29 * stride], in[34 * stride]);
1372         x1[34] = _mm512_sub_epi32(in[29 * stride], in[34 * stride]);
1373         x1[30] = _mm512_add_epi32(in[30 * stride], in[33 * stride]);
1374         x1[33] = _mm512_sub_epi32(in[30 * stride], in[33 * stride]);
1375         x1[31] = _mm512_add_epi32(in[31 * stride], in[32 * stride]);
1376         x1[32] = _mm512_sub_epi32(in[31 * stride], in[32 * stride]);
1377 
1378         // stage 2
1379         __m512i x2[54];
1380         x2[0]  = _mm512_add_epi32(x1[0], x1[31]);
1381         x2[31] = _mm512_sub_epi32(x1[0], x1[31]);
1382         x2[1]  = _mm512_add_epi32(x1[1], x1[30]);
1383         x2[30] = _mm512_sub_epi32(x1[1], x1[30]);
1384         x2[2]  = _mm512_add_epi32(x1[2], x1[29]);
1385         x2[29] = _mm512_sub_epi32(x1[2], x1[29]);
1386         x2[3]  = _mm512_add_epi32(x1[3], x1[28]);
1387         x2[28] = _mm512_sub_epi32(x1[3], x1[28]);
1388         x2[4]  = _mm512_add_epi32(x1[4], x1[27]);
1389         x2[27] = _mm512_sub_epi32(x1[4], x1[27]);
1390         x2[5]  = _mm512_add_epi32(x1[5], x1[26]);
1391         x2[26] = _mm512_sub_epi32(x1[5], x1[26]);
1392         x2[6]  = _mm512_add_epi32(x1[6], x1[25]);
1393         x2[25] = _mm512_sub_epi32(x1[6], x1[25]);
1394         x2[7]  = _mm512_add_epi32(x1[7], x1[24]);
1395         x2[24] = _mm512_sub_epi32(x1[7], x1[24]);
1396         x2[8]  = _mm512_add_epi32(x1[8], x1[23]);
1397         x2[23] = _mm512_sub_epi32(x1[8], x1[23]);
1398         x2[9]  = _mm512_add_epi32(x1[9], x1[22]);
1399         x2[22] = _mm512_sub_epi32(x1[9], x1[22]);
1400         x2[10] = _mm512_add_epi32(x1[10], x1[21]);
1401         x2[21] = _mm512_sub_epi32(x1[10], x1[21]);
1402         x2[11] = _mm512_add_epi32(x1[11], x1[20]);
1403         x2[20] = _mm512_sub_epi32(x1[11], x1[20]);
1404         x2[12] = _mm512_add_epi32(x1[12], x1[19]);
1405         x2[19] = _mm512_sub_epi32(x1[12], x1[19]);
1406         x2[13] = _mm512_add_epi32(x1[13], x1[18]);
1407         x2[18] = _mm512_sub_epi32(x1[13], x1[18]);
1408         x2[14] = _mm512_add_epi32(x1[14], x1[17]);
1409         x2[17] = _mm512_sub_epi32(x1[14], x1[17]);
1410         x2[15] = _mm512_add_epi32(x1[15], x1[16]);
1411         x2[16] = _mm512_sub_epi32(x1[15], x1[16]);
1412         btf_32_type0_avx512_new(
1413             cospi_m32, cospi_p32, x1[40], x1[55], x2[32], x2[47], __rounding, cos_bit);
1414         btf_32_type0_avx512_new(
1415             cospi_m32, cospi_p32, x1[41], x1[54], x2[33], x2[46], __rounding, cos_bit);
1416         btf_32_type0_avx512_new(
1417             cospi_m32, cospi_p32, x1[42], x1[53], x2[34], x2[45], __rounding, cos_bit);
1418         btf_32_type0_avx512_new(
1419             cospi_m32, cospi_p32, x1[43], x1[52], x2[35], x2[44], __rounding, cos_bit);
1420         btf_32_type0_avx512_new(
1421             cospi_m32, cospi_p32, x1[44], x1[51], x2[36], x2[43], __rounding, cos_bit);
1422         btf_32_type0_avx512_new(
1423             cospi_m32, cospi_p32, x1[45], x1[50], x2[37], x2[42], __rounding, cos_bit);
1424         btf_32_type0_avx512_new(
1425             cospi_m32, cospi_p32, x1[46], x1[49], x2[38], x2[41], __rounding, cos_bit);
1426         btf_32_type0_avx512_new(
1427             cospi_m32, cospi_p32, x1[47], x1[48], x2[39], x2[40], __rounding, cos_bit);
1428 
1429         // stage 3
1430         __m512i x3[56];
1431         x3[0]  = _mm512_add_epi32(x2[0], x2[15]);
1432         x3[15] = _mm512_sub_epi32(x2[0], x2[15]);
1433         x3[1]  = _mm512_add_epi32(x2[1], x2[14]);
1434         x3[14] = _mm512_sub_epi32(x2[1], x2[14]);
1435         x3[2]  = _mm512_add_epi32(x2[2], x2[13]);
1436         x3[13] = _mm512_sub_epi32(x2[2], x2[13]);
1437         x3[3]  = _mm512_add_epi32(x2[3], x2[12]);
1438         x3[12] = _mm512_sub_epi32(x2[3], x2[12]);
1439         x3[4]  = _mm512_add_epi32(x2[4], x2[11]);
1440         x3[11] = _mm512_sub_epi32(x2[4], x2[11]);
1441         x3[5]  = _mm512_add_epi32(x2[5], x2[10]);
1442         x3[10] = _mm512_sub_epi32(x2[5], x2[10]);
1443         x3[6]  = _mm512_add_epi32(x2[6], x2[9]);
1444         x3[9]  = _mm512_sub_epi32(x2[6], x2[9]);
1445         x3[7]  = _mm512_add_epi32(x2[7], x2[8]);
1446         x3[8]  = _mm512_sub_epi32(x2[7], x2[8]);
1447         btf_32_type0_avx512_new(
1448             cospi_m32, cospi_p32, x2[20], x2[27], x3[16], x3[23], __rounding, cos_bit);
1449         btf_32_type0_avx512_new(
1450             cospi_m32, cospi_p32, x2[21], x2[26], x3[17], x3[22], __rounding, cos_bit);
1451         btf_32_type0_avx512_new(
1452             cospi_m32, cospi_p32, x2[22], x2[25], x3[18], x3[21], __rounding, cos_bit);
1453         btf_32_type0_avx512_new(
1454             cospi_m32, cospi_p32, x2[23], x2[24], x3[19], x3[20], __rounding, cos_bit);
1455         x3[32] = _mm512_add_epi32(x1[32], x2[39]);
1456         x3[47] = _mm512_sub_epi32(x1[32], x2[39]);
1457         x3[33] = _mm512_add_epi32(x1[33], x2[38]);
1458         x3[46] = _mm512_sub_epi32(x1[33], x2[38]);
1459         x3[34] = _mm512_add_epi32(x1[34], x2[37]);
1460         x3[45] = _mm512_sub_epi32(x1[34], x2[37]);
1461         x3[35] = _mm512_add_epi32(x1[35], x2[36]);
1462         x3[44] = _mm512_sub_epi32(x1[35], x2[36]);
1463         x3[36] = _mm512_add_epi32(x1[36], x2[35]);
1464         x3[43] = _mm512_sub_epi32(x1[36], x2[35]);
1465         x3[37] = _mm512_add_epi32(x1[37], x2[34]);
1466         x3[42] = _mm512_sub_epi32(x1[37], x2[34]);
1467         x3[38] = _mm512_add_epi32(x1[38], x2[33]);
1468         x3[41] = _mm512_sub_epi32(x1[38], x2[33]);
1469         x3[39] = _mm512_add_epi32(x1[39], x2[32]);
1470         x3[40] = _mm512_sub_epi32(x1[39], x2[32]);
1471         x3[48] = _mm512_sub_epi32(x1[63], x2[40]);
1472         x3[24] = _mm512_add_epi32(x1[63], x2[40]);
1473         x3[49] = _mm512_sub_epi32(x1[62], x2[41]);
1474         x3[25] = _mm512_add_epi32(x1[62], x2[41]);
1475         x3[50] = _mm512_sub_epi32(x1[61], x2[42]);
1476         x3[26] = _mm512_add_epi32(x1[61], x2[42]);
1477         x3[51] = _mm512_sub_epi32(x1[60], x2[43]);
1478         x3[27] = _mm512_add_epi32(x1[60], x2[43]);
1479         x3[52] = _mm512_sub_epi32(x1[59], x2[44]);
1480         x3[28] = _mm512_add_epi32(x1[59], x2[44]);
1481         x3[53] = _mm512_sub_epi32(x1[58], x2[45]);
1482         x3[29] = _mm512_add_epi32(x1[58], x2[45]);
1483         x3[54] = _mm512_sub_epi32(x1[57], x2[46]);
1484         x3[30] = _mm512_add_epi32(x1[57], x2[46]);
1485         x3[55] = _mm512_sub_epi32(x1[56], x2[47]);
1486         x3[31] = _mm512_add_epi32(x1[56], x2[47]);
1487 
1488         // stage 4
1489         //__m512i x4[44]; replace with x1
1490         x1[0] = _mm512_add_epi32(x3[0], x3[7]);
1491         x1[7] = _mm512_sub_epi32(x3[0], x3[7]);
1492         x1[1] = _mm512_add_epi32(x3[1], x3[6]);
1493         x1[6] = _mm512_sub_epi32(x3[1], x3[6]);
1494         x1[2] = _mm512_add_epi32(x3[2], x3[5]);
1495         x1[5] = _mm512_sub_epi32(x3[2], x3[5]);
1496         x1[3] = _mm512_add_epi32(x3[3], x3[4]);
1497         x1[4] = _mm512_sub_epi32(x3[3], x3[4]);
1498         btf_32_type0_avx512_new(
1499             cospi_m32, cospi_p32, x3[10], x3[13], x1[8], x1[11], __rounding, cos_bit);
1500         btf_32_type0_avx512_new(
1501             cospi_m32, cospi_p32, x3[11], x3[12], x1[9], x1[10], __rounding, cos_bit);
1502         x1[12] = _mm512_add_epi32(x2[16], x3[19]);
1503         x1[19] = _mm512_sub_epi32(x2[16], x3[19]);
1504         x1[13] = _mm512_add_epi32(x2[17], x3[18]);
1505         x1[18] = _mm512_sub_epi32(x2[17], x3[18]);
1506         x1[14] = _mm512_add_epi32(x2[18], x3[17]);
1507         x1[17] = _mm512_sub_epi32(x2[18], x3[17]);
1508         x1[15] = _mm512_add_epi32(x2[19], x3[16]);
1509         x1[16] = _mm512_sub_epi32(x2[19], x3[16]);
1510         x1[20] = _mm512_sub_epi32(x2[31], x3[20]);
1511         x1[27] = _mm512_add_epi32(x2[31], x3[20]);
1512         x1[21] = _mm512_sub_epi32(x2[30], x3[21]);
1513         x1[26] = _mm512_add_epi32(x2[30], x3[21]);
1514         x1[22] = _mm512_sub_epi32(x2[29], x3[22]);
1515         x1[25] = _mm512_add_epi32(x2[29], x3[22]);
1516         x1[23] = _mm512_sub_epi32(x2[28], x3[23]);
1517         x1[24] = _mm512_add_epi32(x2[28], x3[23]);
1518         btf_32_type0_avx512_new(
1519             cospi_m16, cospi_p48, x3[36], x3[28], x1[28], x1[43], __rounding, cos_bit);
1520         btf_32_type0_avx512_new(
1521             cospi_m16, cospi_p48, x3[37], x3[29], x1[29], x1[42], __rounding, cos_bit);
1522         btf_32_type0_avx512_new(
1523             cospi_m16, cospi_p48, x3[38], x3[30], x1[30], x1[41], __rounding, cos_bit);
1524         btf_32_type0_avx512_new(
1525             cospi_m16, cospi_p48, x3[39], x3[31], x1[31], x1[40], __rounding, cos_bit);
1526         btf_32_type0_avx512_new(
1527             cospi_m48, cospi_m16, x3[40], x3[55], x1[32], x1[39], __rounding, cos_bit);
1528         btf_32_type0_avx512_new(
1529             cospi_m48, cospi_m16, x3[41], x3[54], x1[33], x1[38], __rounding, cos_bit);
1530         btf_32_type0_avx512_new(
1531             cospi_m48, cospi_m16, x3[42], x3[53], x1[34], x1[37], __rounding, cos_bit);
1532         btf_32_type0_avx512_new(
1533             cospi_m48, cospi_m16, x3[43], x3[52], x1[35], x1[36], __rounding, cos_bit);
1534 
1535         // stage 5
1536         //__m512i x5[54]; replace with x2
1537         x2[0] = _mm512_add_epi32(x1[0], x1[3]);
1538         x2[3] = _mm512_sub_epi32(x1[0], x1[3]);
1539         x2[1] = _mm512_add_epi32(x1[1], x1[2]);
1540         x2[2] = _mm512_sub_epi32(x1[1], x1[2]);
1541         btf_32_type0_avx512_new(
1542             cospi_m32, cospi_p32, x1[5], x1[6], x2[4], x2[5], __rounding, cos_bit);
1543         x2[6]  = _mm512_add_epi32(x3[8], x1[9]);
1544         x2[9]  = _mm512_sub_epi32(x3[8], x1[9]);
1545         x2[7]  = _mm512_add_epi32(x3[9], x1[8]);
1546         x2[8]  = _mm512_sub_epi32(x3[9], x1[8]);
1547         x2[10] = _mm512_sub_epi32(x3[15], x1[10]);
1548         x2[13] = _mm512_add_epi32(x3[15], x1[10]);
1549         x2[11] = _mm512_sub_epi32(x3[14], x1[11]);
1550         x2[12] = _mm512_add_epi32(x3[14], x1[11]);
1551         btf_32_type0_avx512_new(
1552             cospi_m16, cospi_p48, x1[14], x1[25], x2[14], x2[21], __rounding, cos_bit);
1553         btf_32_type0_avx512_new(
1554             cospi_m16, cospi_p48, x1[15], x1[24], x2[15], x2[20], __rounding, cos_bit);
1555         btf_32_type0_avx512_new(
1556             cospi_m48, cospi_m16, x1[16], x1[23], x2[16], x2[19], __rounding, cos_bit);
1557         btf_32_type0_avx512_new(
1558             cospi_m48, cospi_m16, x1[17], x1[22], x2[17], x2[18], __rounding, cos_bit);
1559         x2[22] = _mm512_add_epi32(x3[32], x1[31]);
1560         x2[29] = _mm512_sub_epi32(x3[32], x1[31]);
1561         x2[23] = _mm512_add_epi32(x3[33], x1[30]);
1562         x2[28] = _mm512_sub_epi32(x3[33], x1[30]);
1563         x2[24] = _mm512_add_epi32(x3[34], x1[29]);
1564         x2[27] = _mm512_sub_epi32(x3[34], x1[29]);
1565         x2[25] = _mm512_add_epi32(x3[35], x1[28]);
1566         x2[26] = _mm512_sub_epi32(x3[35], x1[28]);
1567         x2[30] = _mm512_sub_epi32(x3[47], x1[32]);
1568         x2[37] = _mm512_add_epi32(x3[47], x1[32]);
1569         x2[31] = _mm512_sub_epi32(x3[46], x1[33]);
1570         x2[36] = _mm512_add_epi32(x3[46], x1[33]);
1571         x2[32] = _mm512_sub_epi32(x3[45], x1[34]);
1572         x2[35] = _mm512_add_epi32(x3[45], x1[34]);
1573         x2[33] = _mm512_sub_epi32(x3[44], x1[35]);
1574         x2[34] = _mm512_add_epi32(x3[44], x1[35]);
1575         x2[38] = _mm512_add_epi32(x3[48], x1[39]);
1576         x2[45] = _mm512_sub_epi32(x3[48], x1[39]);
1577         x2[39] = _mm512_add_epi32(x3[49], x1[38]);
1578         x2[44] = _mm512_sub_epi32(x3[49], x1[38]);
1579         x2[40] = _mm512_add_epi32(x3[50], x1[37]);
1580         x2[43] = _mm512_sub_epi32(x3[50], x1[37]);
1581         x2[41] = _mm512_add_epi32(x3[51], x1[36]);
1582         x2[42] = _mm512_sub_epi32(x3[51], x1[36]);
1583         x2[46] = _mm512_sub_epi32(x3[24], x1[40]);
1584         x2[53] = _mm512_add_epi32(x3[24], x1[40]);
1585         x2[47] = _mm512_sub_epi32(x3[25], x1[41]);
1586         x2[52] = _mm512_add_epi32(x3[25], x1[41]);
1587         x2[48] = _mm512_sub_epi32(x3[26], x1[42]);
1588         x2[51] = _mm512_add_epi32(x3[26], x1[42]);
1589         x2[49] = _mm512_sub_epi32(x3[27], x1[43]);
1590         x2[50] = _mm512_add_epi32(x3[27], x1[43]);
1591 
1592         // stage 6
1593         //__m512i x6[40]; replace with x3
1594         btf_32_type0_avx512_new(cospi_p32,
1595                                 cospi_p32,
1596                                 x2[0],
1597                                 x2[1],
1598                                 out[0 * stride],
1599                                 out[32 * stride],
1600                                 __rounding,
1601                                 cos_bit);
1602         btf_32_type1_avx512_new(cospi_p48,
1603                                 cospi_p16,
1604                                 x2[2],
1605                                 x2[3],
1606                                 out[16 * stride],
1607                                 out[48 * stride],
1608                                 __rounding,
1609                                 cos_bit);
1610         x3[0] = _mm512_add_epi32(x1[4], x2[4]);
1611         x3[1] = _mm512_sub_epi32(x1[4], x2[4]);
1612         x3[2] = _mm512_sub_epi32(x1[7], x2[5]);
1613         x3[3] = _mm512_add_epi32(x1[7], x2[5]);
1614         btf_32_type0_avx512_new(
1615             cospi_m16, cospi_p48, x2[7], x2[12], x3[4], x3[7], __rounding, cos_bit);
1616         btf_32_type0_avx512_new(
1617             cospi_m48, cospi_m16, x2[8], x2[11], x3[5], x3[6], __rounding, cos_bit);
1618         x3[8]  = _mm512_add_epi32(x1[12], x2[15]);
1619         x3[11] = _mm512_sub_epi32(x1[12], x2[15]);
1620         x3[9]  = _mm512_add_epi32(x1[13], x2[14]);
1621         x3[10] = _mm512_sub_epi32(x1[13], x2[14]);
1622         x3[12] = _mm512_sub_epi32(x1[19], x2[16]);
1623         x3[15] = _mm512_add_epi32(x1[19], x2[16]);
1624         x3[13] = _mm512_sub_epi32(x1[18], x2[17]);
1625         x3[14] = _mm512_add_epi32(x1[18], x2[17]);
1626         x3[16] = _mm512_add_epi32(x1[20], x2[19]);
1627         x3[19] = _mm512_sub_epi32(x1[20], x2[19]);
1628         x3[17] = _mm512_add_epi32(x1[21], x2[18]);
1629         x3[18] = _mm512_sub_epi32(x1[21], x2[18]);
1630         x3[20] = _mm512_sub_epi32(x1[27], x2[20]);
1631         x3[23] = _mm512_add_epi32(x1[27], x2[20]);
1632         x3[21] = _mm512_sub_epi32(x1[26], x2[21]);
1633         x3[22] = _mm512_add_epi32(x1[26], x2[21]);
1634         btf_32_type0_avx512_new(
1635             cospi_m08, cospi_p56, x2[24], x2[51], x3[24], x3[39], __rounding, cos_bit);
1636         btf_32_type0_avx512_new(
1637             cospi_m08, cospi_p56, x2[25], x2[50], x3[25], x3[38], __rounding, cos_bit);
1638         btf_32_type0_avx512_new(
1639             cospi_m56, cospi_m08, x2[26], x2[49], x3[26], x3[37], __rounding, cos_bit);
1640         btf_32_type0_avx512_new(
1641             cospi_m56, cospi_m08, x2[27], x2[48], x3[27], x3[36], __rounding, cos_bit);
1642         btf_32_type0_avx512_new(
1643             cospi_m40, cospi_p24, x2[32], x2[43], x3[28], x3[35], __rounding, cos_bit);
1644         btf_32_type0_avx512_new(
1645             cospi_m40, cospi_p24, x2[33], x2[42], x3[29], x3[34], __rounding, cos_bit);
1646         btf_32_type0_avx512_new(
1647             cospi_m24, cospi_m40, x2[34], x2[41], x3[30], x3[33], __rounding, cos_bit);
1648         btf_32_type0_avx512_new(
1649             cospi_m24, cospi_m40, x2[35], x2[40], x3[31], x3[32], __rounding, cos_bit);
1650 
1651         // stage 7
1652         //__m512i x7[48]; replace with x1
1653         btf_32_type1_avx512_new(cospi_p56,
1654                                 cospi_p08,
1655                                 x3[0],
1656                                 x3[3],
1657                                 out[8 * stride],
1658                                 out[56 * stride],
1659                                 __rounding,
1660                                 cos_bit);
1661         btf_32_type1_avx512_new(cospi_p24,
1662                                 cospi_p40,
1663                                 x3[1],
1664                                 x3[2],
1665                                 out[40 * stride],
1666                                 out[24 * stride],
1667                                 __rounding,
1668                                 cos_bit);
1669         x1[0] = _mm512_add_epi32(x2[6], x3[4]);
1670         x1[1] = _mm512_sub_epi32(x2[6], x3[4]);
1671         x1[2] = _mm512_sub_epi32(x2[9], x3[5]);
1672         x1[3] = _mm512_add_epi32(x2[9], x3[5]);
1673         x1[4] = _mm512_add_epi32(x2[10], x3[6]);
1674         x1[5] = _mm512_sub_epi32(x2[10], x3[6]);
1675         x1[6] = _mm512_sub_epi32(x2[13], x3[7]);
1676         x1[7] = _mm512_add_epi32(x2[13], x3[7]);
1677         btf_32_type0_avx512_new(
1678             cospi_m08, cospi_p56, x3[9], x3[22], x1[8], x1[15], __rounding, cos_bit);
1679         btf_32_type0_avx512_new(
1680             cospi_m56, cospi_m08, x3[10], x3[21], x1[9], x1[14], __rounding, cos_bit);
1681         btf_32_type0_avx512_new(
1682             cospi_m40, cospi_p24, x3[13], x3[18], x1[10], x1[13], __rounding, cos_bit);
1683         btf_32_type0_avx512_new(
1684             cospi_m24, cospi_m40, x3[14], x3[17], x1[11], x1[12], __rounding, cos_bit);
1685         x1[16] = _mm512_add_epi32(x2[22], x3[25]);
1686         x1[17] = _mm512_sub_epi32(x2[22], x3[25]);
1687         x1[19] = _mm512_add_epi32(x2[23], x3[24]);
1688         x1[20] = _mm512_sub_epi32(x2[23], x3[24]);
1689         x1[18] = _mm512_sub_epi32(x2[29], x3[26]);
1690         x1[21] = _mm512_add_epi32(x2[29], x3[26]);
1691         x1[22] = _mm512_sub_epi32(x2[28], x3[27]);
1692         x1[23] = _mm512_add_epi32(x2[28], x3[27]);
1693         x1[24] = _mm512_add_epi32(x2[30], x3[29]);
1694         x1[25] = _mm512_sub_epi32(x2[30], x3[29]);
1695         x1[26] = _mm512_add_epi32(x2[31], x3[28]);
1696         x1[27] = _mm512_sub_epi32(x2[31], x3[28]);
1697         x1[28] = _mm512_sub_epi32(x2[37], x3[30]);
1698         x1[29] = _mm512_add_epi32(x2[37], x3[30]);
1699         x1[30] = _mm512_sub_epi32(x2[36], x3[31]);
1700         x1[31] = _mm512_add_epi32(x2[36], x3[31]);
1701         x1[32] = _mm512_add_epi32(x2[38], x3[33]);
1702         x1[33] = _mm512_sub_epi32(x2[38], x3[33]);
1703         x1[34] = _mm512_add_epi32(x2[39], x3[32]);
1704         x1[35] = _mm512_sub_epi32(x2[39], x3[32]);
1705         x1[36] = _mm512_sub_epi32(x2[45], x3[34]);
1706         x1[37] = _mm512_add_epi32(x2[45], x3[34]);
1707         x1[38] = _mm512_sub_epi32(x2[44], x3[35]);
1708         x1[39] = _mm512_add_epi32(x2[44], x3[35]);
1709         x1[40] = _mm512_add_epi32(x2[46], x3[37]);
1710         x1[41] = _mm512_sub_epi32(x2[46], x3[37]);
1711         x1[42] = _mm512_add_epi32(x2[47], x3[36]);
1712         x1[43] = _mm512_sub_epi32(x2[47], x3[36]);
1713         x1[44] = _mm512_sub_epi32(x2[53], x3[38]);
1714         x1[45] = _mm512_add_epi32(x2[53], x3[38]);
1715         x1[46] = _mm512_sub_epi32(x2[52], x3[39]);
1716         x1[47] = _mm512_add_epi32(x2[52], x3[39]);
1717 
1718         // stage 8
1719         //__m512i x8[32]; replace with x2
1720         btf_32_type1_avx512_new(cospi_p60,
1721                                 cospi_p04,
1722                                 x1[0],
1723                                 x1[7],
1724                                 out[4 * stride],
1725                                 out[60 * stride],
1726                                 __rounding,
1727                                 cos_bit);
1728         btf_32_type1_avx512_new(cospi_p28,
1729                                 cospi_p36,
1730                                 x1[1],
1731                                 x1[6],
1732                                 out[36 * stride],
1733                                 out[28 * stride],
1734                                 __rounding,
1735                                 cos_bit);
1736         btf_32_type1_avx512_new(cospi_p44,
1737                                 cospi_p20,
1738                                 x1[2],
1739                                 x1[5],
1740                                 out[20 * stride],
1741                                 out[44 * stride],
1742                                 __rounding,
1743                                 cos_bit);
1744         btf_32_type1_avx512_new(cospi_p12,
1745                                 cospi_p52,
1746                                 x1[3],
1747                                 x1[4],
1748                                 out[52 * stride],
1749                                 out[12 * stride],
1750                                 __rounding,
1751                                 cos_bit);
1752         x2[0]  = _mm512_add_epi32(x3[8], x1[8]);
1753         x2[1]  = _mm512_sub_epi32(x3[8], x1[8]);
1754         x2[2]  = _mm512_sub_epi32(x3[11], x1[9]);
1755         x2[3]  = _mm512_add_epi32(x3[11], x1[9]);
1756         x2[4]  = _mm512_add_epi32(x3[12], x1[10]);
1757         x2[5]  = _mm512_sub_epi32(x3[12], x1[10]);
1758         x2[6]  = _mm512_sub_epi32(x3[15], x1[11]);
1759         x2[7]  = _mm512_add_epi32(x3[15], x1[11]);
1760         x2[8]  = _mm512_add_epi32(x3[16], x1[12]);
1761         x2[9]  = _mm512_sub_epi32(x3[16], x1[12]);
1762         x2[10] = _mm512_sub_epi32(x3[19], x1[13]);
1763         x2[11] = _mm512_add_epi32(x3[19], x1[13]);
1764         x2[12] = _mm512_add_epi32(x3[20], x1[14]);
1765         x2[13] = _mm512_sub_epi32(x3[20], x1[14]);
1766         x2[14] = _mm512_sub_epi32(x3[23], x1[15]);
1767         x2[15] = _mm512_add_epi32(x3[23], x1[15]);
1768         btf_32_type0_avx512_new(
1769             cospi_m04, cospi_p60, x1[19], x1[47], x2[16], x2[31], __rounding, cos_bit);
1770         btf_32_type0_avx512_new(
1771             cospi_m60, cospi_m04, x1[20], x1[46], x2[17], x2[30], __rounding, cos_bit);
1772         btf_32_type0_avx512_new(
1773             cospi_m36, cospi_p28, x1[22], x1[43], x2[18], x2[29], __rounding, cos_bit);
1774         btf_32_type0_avx512_new(
1775             cospi_m28, cospi_m36, x1[23], x1[42], x2[19], x2[28], __rounding, cos_bit);
1776         btf_32_type0_avx512_new(
1777             cospi_m20, cospi_p44, x1[26], x1[39], x2[20], x2[27], __rounding, cos_bit);
1778         btf_32_type0_avx512_new(
1779             cospi_m44, cospi_m20, x1[27], x1[38], x2[21], x2[26], __rounding, cos_bit);
1780         btf_32_type0_avx512_new(
1781             cospi_m52, cospi_p12, x1[30], x1[35], x2[22], x2[25], __rounding, cos_bit);
1782         btf_32_type0_avx512_new(
1783             cospi_m12, cospi_m52, x1[31], x1[34], x2[23], x2[24], __rounding, cos_bit);
1784 
1785         // stage 9
1786         //__m512i x9[32]; replace with x3
1787         btf_32_type1_avx512_new(cospi_p62,
1788                                 cospi_p02,
1789                                 x2[0],
1790                                 x2[15],
1791                                 out[2 * stride],
1792                                 out[62 * stride],
1793                                 __rounding,
1794                                 cos_bit);
1795         btf_32_type1_avx512_new(cospi_p30,
1796                                 cospi_p34,
1797                                 x2[1],
1798                                 x2[14],
1799                                 out[34 * stride],
1800                                 out[30 * stride],
1801                                 __rounding,
1802                                 cos_bit);
1803         btf_32_type1_avx512_new(cospi_p46,
1804                                 cospi_p18,
1805                                 x2[2],
1806                                 x2[13],
1807                                 out[18 * stride],
1808                                 out[46 * stride],
1809                                 __rounding,
1810                                 cos_bit);
1811         btf_32_type1_avx512_new(cospi_p14,
1812                                 cospi_p50,
1813                                 x2[3],
1814                                 x2[12],
1815                                 out[50 * stride],
1816                                 out[14 * stride],
1817                                 __rounding,
1818                                 cos_bit);
1819         btf_32_type1_avx512_new(cospi_p54,
1820                                 cospi_p10,
1821                                 x2[4],
1822                                 x2[11],
1823                                 out[10 * stride],
1824                                 out[54 * stride],
1825                                 __rounding,
1826                                 cos_bit);
1827         btf_32_type1_avx512_new(cospi_p22,
1828                                 cospi_p42,
1829                                 x2[5],
1830                                 x2[10],
1831                                 out[42 * stride],
1832                                 out[22 * stride],
1833                                 __rounding,
1834                                 cos_bit);
1835         btf_32_type1_avx512_new(cospi_p38,
1836                                 cospi_p26,
1837                                 x2[6],
1838                                 x2[9],
1839                                 out[26 * stride],
1840                                 out[38 * stride],
1841                                 __rounding,
1842                                 cos_bit);
1843         btf_32_type1_avx512_new(cospi_p06,
1844                                 cospi_p58,
1845                                 x2[7],
1846                                 x2[8],
1847                                 out[58 * stride],
1848                                 out[6 * stride],
1849                                 __rounding,
1850                                 cos_bit);
1851         x3[0]  = _mm512_add_epi32(x1[16], x2[16]);
1852         x3[1]  = _mm512_sub_epi32(x1[16], x2[16]);
1853         x3[2]  = _mm512_sub_epi32(x1[17], x2[17]);
1854         x3[3]  = _mm512_add_epi32(x1[17], x2[17]);
1855         x3[4]  = _mm512_add_epi32(x1[18], x2[18]);
1856         x3[5]  = _mm512_sub_epi32(x1[18], x2[18]);
1857         x3[6]  = _mm512_sub_epi32(x1[21], x2[19]);
1858         x3[7]  = _mm512_add_epi32(x1[21], x2[19]);
1859         x3[8]  = _mm512_add_epi32(x1[24], x2[20]);
1860         x3[9]  = _mm512_sub_epi32(x1[24], x2[20]);
1861         x3[10] = _mm512_sub_epi32(x1[25], x2[21]);
1862         x3[11] = _mm512_add_epi32(x1[25], x2[21]);
1863         x3[12] = _mm512_add_epi32(x1[28], x2[22]);
1864         x3[13] = _mm512_sub_epi32(x1[28], x2[22]);
1865         x3[14] = _mm512_sub_epi32(x1[29], x2[23]);
1866         x3[15] = _mm512_add_epi32(x1[29], x2[23]);
1867         x3[16] = _mm512_add_epi32(x1[32], x2[24]);
1868         x3[17] = _mm512_sub_epi32(x1[32], x2[24]);
1869         x3[18] = _mm512_sub_epi32(x1[33], x2[25]);
1870         x3[19] = _mm512_add_epi32(x1[33], x2[25]);
1871         x3[20] = _mm512_add_epi32(x1[36], x2[26]);
1872         x3[21] = _mm512_sub_epi32(x1[36], x2[26]);
1873         x3[22] = _mm512_sub_epi32(x1[37], x2[27]);
1874         x3[23] = _mm512_add_epi32(x1[37], x2[27]);
1875         x3[24] = _mm512_add_epi32(x1[40], x2[28]);
1876         x3[25] = _mm512_sub_epi32(x1[40], x2[28]);
1877         x3[26] = _mm512_sub_epi32(x1[41], x2[29]);
1878         x3[27] = _mm512_add_epi32(x1[41], x2[29]);
1879         x3[28] = _mm512_add_epi32(x1[44], x2[30]);
1880         x3[29] = _mm512_sub_epi32(x1[44], x2[30]);
1881         x3[30] = _mm512_sub_epi32(x1[45], x2[31]);
1882         x3[31] = _mm512_add_epi32(x1[45], x2[31]);
1883 
1884         // stage 10
1885         btf_32_type1_avx512_new(cospi_p63,
1886                                 cospi_p01,
1887                                 x3[0],
1888                                 x3[31],
1889                                 out[1 * stride],
1890                                 out[63 * stride],
1891                                 __rounding,
1892                                 cos_bit);
1893         btf_32_type1_avx512_new(cospi_p31,
1894                                 cospi_p33,
1895                                 x3[1],
1896                                 x3[30],
1897                                 out[33 * stride],
1898                                 out[31 * stride],
1899                                 __rounding,
1900                                 cos_bit);
1901         btf_32_type1_avx512_new(cospi_p47,
1902                                 cospi_p17,
1903                                 x3[2],
1904                                 x3[29],
1905                                 out[17 * stride],
1906                                 out[47 * stride],
1907                                 __rounding,
1908                                 cos_bit);
1909         btf_32_type1_avx512_new(cospi_p15,
1910                                 cospi_p49,
1911                                 x3[3],
1912                                 x3[28],
1913                                 out[49 * stride],
1914                                 out[15 * stride],
1915                                 __rounding,
1916                                 cos_bit);
1917         btf_32_type1_avx512_new(cospi_p55,
1918                                 cospi_p09,
1919                                 x3[4],
1920                                 x3[27],
1921                                 out[9 * stride],
1922                                 out[55 * stride],
1923                                 __rounding,
1924                                 cos_bit);
1925         btf_32_type1_avx512_new(cospi_p23,
1926                                 cospi_p41,
1927                                 x3[5],
1928                                 x3[26],
1929                                 out[41 * stride],
1930                                 out[23 * stride],
1931                                 __rounding,
1932                                 cos_bit);
1933         btf_32_type1_avx512_new(cospi_p39,
1934                                 cospi_p25,
1935                                 x3[6],
1936                                 x3[25],
1937                                 out[25 * stride],
1938                                 out[39 * stride],
1939                                 __rounding,
1940                                 cos_bit);
1941         btf_32_type1_avx512_new(cospi_p07,
1942                                 cospi_p57,
1943                                 x3[7],
1944                                 x3[24],
1945                                 out[57 * stride],
1946                                 out[7 * stride],
1947                                 __rounding,
1948                                 cos_bit);
1949         btf_32_type1_avx512_new(cospi_p59,
1950                                 cospi_p05,
1951                                 x3[8],
1952                                 x3[23],
1953                                 out[5 * stride],
1954                                 out[59 * stride],
1955                                 __rounding,
1956                                 cos_bit);
1957         btf_32_type1_avx512_new(cospi_p27,
1958                                 cospi_p37,
1959                                 x3[9],
1960                                 x3[22],
1961                                 out[37 * stride],
1962                                 out[27 * stride],
1963                                 __rounding,
1964                                 cos_bit);
1965         btf_32_type1_avx512_new(cospi_p43,
1966                                 cospi_p21,
1967                                 x3[10],
1968                                 x3[21],
1969                                 out[21 * stride],
1970                                 out[43 * stride],
1971                                 __rounding,
1972                                 cos_bit);
1973         btf_32_type1_avx512_new(cospi_p11,
1974                                 cospi_p53,
1975                                 x3[11],
1976                                 x3[20],
1977                                 out[53 * stride],
1978                                 out[11 * stride],
1979                                 __rounding,
1980                                 cos_bit);
1981         btf_32_type1_avx512_new(cospi_p51,
1982                                 cospi_p13,
1983                                 x3[12],
1984                                 x3[19],
1985                                 out[13 * stride],
1986                                 out[51 * stride],
1987                                 __rounding,
1988                                 cos_bit);
1989         btf_32_type1_avx512_new(cospi_p19,
1990                                 cospi_p45,
1991                                 x3[13],
1992                                 x3[18],
1993                                 out[45 * stride],
1994                                 out[19 * stride],
1995                                 __rounding,
1996                                 cos_bit);
1997         btf_32_type1_avx512_new(cospi_p35,
1998                                 cospi_p29,
1999                                 x3[14],
2000                                 x3[17],
2001                                 out[29 * stride],
2002                                 out[35 * stride],
2003                                 __rounding,
2004                                 cos_bit);
2005         btf_32_type1_avx512_new(cospi_p03,
2006                                 cospi_p61,
2007                                 x3[15],
2008                                 x3[16],
2009                                 out[61 * stride],
2010                                 out[3 * stride],
2011                                 __rounding,
2012                                 cos_bit);
2013     }
2014 }
2015 
fdct64x64_avx512(const __m512i * input,__m512i * output,const int8_t cos_bit)2016 static INLINE void fdct64x64_avx512(const __m512i *input, __m512i *output, const int8_t cos_bit) {
2017     const int32_t txfm_size   = 64;
2018     const int32_t num_per_512 = 16;
2019     int32_t       col_num     = txfm_size / num_per_512;
2020     av1_fdct64_new_avx512(input, output, cos_bit, txfm_size, col_num);
2021 }
2022 
/* 2-D 64x64 forward transform (AVX-512 path).
 * Only DCT_DCT and IDTX are implemented; any other tx_type hits assert(0).
 *   input  : 64x64 residual block of int16 samples, row pitch `stride`
 *   output : 64x64 int32 coefficients, viewed as 256 __m512i vectors
 *   bd     : bit depth, unused by this kernel
 * Flow for both branches: load -> column 1-D transform -> round-shift ->
 * transpose -> row 1-D transform -> round-shift -> transpose back. */
void av1_fwd_txfm2d_64x64_avx512(int16_t *input, int32_t *output, uint32_t stride, TxType tx_type,
                                 uint8_t bd) {
    (void)bd;
    __m512i       in[256];
    __m512i *     out     = (__m512i *)output;
    const int32_t txw_idx = tx_size_wide_log2[TX_64X64] - tx_size_wide_log2[0];
    const int32_t txh_idx = tx_size_high_log2[TX_64X64] - tx_size_high_log2[0];
    const int8_t *shift   = fwd_txfm_shift_ls[TX_64X64];

    switch (tx_type) {
    case IDTX:
        load_buffer_64x64_avx512(input, stride, out);
        fidtx64x64_avx512(out, in);
        av1_round_shift_array_avx512(in, out, 256, -shift[1]);
        transpose_16nx16n_avx512(64, out, in);

        /*row wise transform*/
        fidtx64x64_avx512(in, out);
        av1_round_shift_array_avx512(out, in, 256, -shift[2]);
        transpose_16nx16n_avx512(64, in, out);
        break;
    case DCT_DCT:
        load_buffer_64x64_avx512(input, stride, out);
        fdct64x64_avx512(out, in, fwd_cos_bit_col[txw_idx][txh_idx]);
        av1_round_shift_array_avx512(in, out, 256, -shift[1]);
        transpose_16nx16n_avx512(64, out, in);

        /*row wise transform*/
        fdct64x64_avx512(in, out, fwd_cos_bit_row[txw_idx][txh_idx]);
        av1_round_shift_array_avx512(out, in, 256, -shift[2]);
        transpose_16nx16n_avx512(64, in, out);
        break;
    default: assert(0);
    }
}
2058 
load_buffer_16_avx512(const int16_t * input,__m512i * in,int32_t stride,int32_t flipud,int32_t fliplr,const int8_t shift)2059 static INLINE void load_buffer_16_avx512(const int16_t *input, __m512i *in, int32_t stride,
2060                                          int32_t flipud, int32_t fliplr, const int8_t shift) {
2061     (void)flipud;
2062     (void)fliplr;
2063     __m256i temp;
2064     uint8_t ushift = (uint8_t)shift;
2065     temp           = _mm256_loadu_si256((const __m256i *)(input + 0 * stride));
2066 
2067     in[0] = _mm512_cvtepi16_epi32(temp);
2068     in[0] = _mm512_slli_epi32(in[0], ushift);
2069 }
2070 
load_buffer_32_avx512(const int16_t * input,__m512i * in,int32_t stride,int32_t flipud,int32_t fliplr,const int8_t shift)2071 static INLINE void load_buffer_32_avx512(const int16_t *input, __m512i *in, int32_t stride,
2072                                          int32_t flipud, int32_t fliplr, const int8_t shift) {
2073     (void)flipud;
2074     (void)fliplr;
2075     __m256i temp[2];
2076     uint8_t ushift = (uint8_t)shift;
2077     temp[0]        = _mm256_loadu_si256((const __m256i *)(input + 0 * stride));
2078     temp[1]        = _mm256_loadu_si256((const __m256i *)(input + 1 * stride));
2079 
2080     in[0] = _mm512_cvtepi16_epi32(temp[0]);
2081     in[1] = _mm512_cvtepi16_epi32(temp[1]);
2082 
2083     in[0] = _mm512_slli_epi32(in[0], ushift);
2084     in[1] = _mm512_slli_epi32(in[1], ushift);
2085 }
2086 
load_buffer_32x16n(const int16_t * input,__m512i * out,int32_t stride,int32_t flipud,int32_t fliplr,const int8_t shift,const int32_t height)2087 static INLINE void load_buffer_32x16n(const int16_t *input, __m512i *out, int32_t stride,
2088                                       int32_t flipud, int32_t fliplr, const int8_t shift,
2089                                       const int32_t height) {
2090     const int16_t *in     = input;
2091     __m512i *      output = out;
2092     for (int32_t col = 0; col < height; col++) {
2093         in     = input + col * stride;
2094         output = out + col * 2;
2095         load_buffer_32_avx512(in, output, 16, flipud, fliplr, shift);
2096     }
2097 }
2098 
av1_round_shift_rect_array_32_avx512(__m512i * input,__m512i * output,const int32_t size,const int8_t bit,const int32_t val)2099 static INLINE void av1_round_shift_rect_array_32_avx512(__m512i *input, __m512i *output,
2100                                                         const int32_t size, const int8_t bit,
2101                                                         const int32_t val) {
2102     const __m512i sqrt2  = _mm512_set1_epi32(val);
2103     const __m512i round2 = _mm512_set1_epi32(1 << (12 - 1));
2104     int32_t       i;
2105     if (bit > 0) {
2106         const __m512i round1 = _mm512_set1_epi32(1 << (bit - 1));
2107         __m512i       r0, r1, r2, r3;
2108         for (i = 0; i < size; i++) {
2109             r0        = _mm512_add_epi32(input[i], round1);
2110             r1        = _mm512_srai_epi32(r0, (uint8_t)bit);
2111             r2        = _mm512_mullo_epi32(sqrt2, r1);
2112             r3        = _mm512_add_epi32(r2, round2);
2113             output[i] = _mm512_srai_epi32(r3, (uint8_t)12);
2114         }
2115     } else {
2116         __m512i r0, r1, r2;
2117         for (i = 0; i < size; i++) {
2118             r0        = _mm512_slli_epi32(input[i], (uint8_t)-bit);
2119             r1        = _mm512_mullo_epi32(sqrt2, r0);
2120             r2        = _mm512_add_epi32(r1, round2);
2121             output[i] = _mm512_srai_epi32(r2, (uint8_t)12);
2122         }
2123     }
2124 }
2125 
/* 2-D 32x64 forward transform (AVX-512). Only the DCT_DCT path is
 * implemented; tx_type and bd are ignored. Column pass is a 64-point DCT,
 * row pass a 32-point DCT, followed by the rectangular rescale by
 * 5793 (sqrt(2) in Q12) to compensate the non-square transform gain. */
void av1_fwd_txfm2d_32x64_avx512(int16_t *input, int32_t *output, uint32_t stride, TxType tx_type,
                                 uint8_t bd) {
    (void)tx_type;
    __m512i       in[128];
    __m512i *     outcoef512    = (__m512i *)output;
    const int8_t *shift         = fwd_txfm_shift_ls[TX_32X64];
    const int32_t txw_idx       = get_txw_idx(TX_32X64);
    const int32_t txh_idx       = get_txh_idx(TX_32X64);
    const int32_t txfm_size_col = tx_size_wide[TX_32X64];
    const int32_t txfm_size_row = tx_size_high[TX_32X64];
    const int8_t  bitcol        = fwd_cos_bit_col[txw_idx][txh_idx];
    const int8_t  bitrow        = fwd_cos_bit_row[txw_idx][txh_idx];
    const int32_t num_row       = txfm_size_row >> 4;
    const int32_t num_col       = txfm_size_col >> 4;

    // column transform
    load_buffer_32x16n(input, in, stride, 0, 0, shift[0], txfm_size_row);
    av1_fdct64_new_avx512(in, in, bitcol, txfm_size_col, num_col);

    // round all 128 intermediate vectors in 16-vector groups
    for (int32_t i = 0; i < 8; i++) { col_txfm_16x16_rounding_avx512((in + i * 16), -shift[1]); }
    transpose_16nx16m_avx512(in, outcoef512, txfm_size_col, txfm_size_row);

    // row transform
    av1_fdct32_new_avx512(outcoef512, in, bitrow, txfm_size_row, num_row);
    transpose_16nx16m_avx512(in, outcoef512, txfm_size_row, txfm_size_col);
    av1_round_shift_rect_array_32_avx512(outcoef512, outcoef512, 128, -shift[2], 5793);
    (void)bd;
}
2154 
/* 2-D 64x32 forward transform (AVX-512). Only the DCT_DCT path is
 * implemented; tx_type and bd are ignored. Column pass is a 32-point DCT,
 * row pass a 64-point DCT, followed by the rectangular rescale by
 * 5793 (sqrt(2) in Q12). */
void av1_fwd_txfm2d_64x32_avx512(int16_t *input, int32_t *output, uint32_t stride, TxType tx_type,
                                 uint8_t bd) {
    (void)tx_type;
    __m512i       in[128];
    __m512i *     outcoef512    = (__m512i *)output;
    const int8_t *shift         = fwd_txfm_shift_ls[TX_64X32];
    const int32_t txw_idx       = get_txw_idx(TX_64X32);
    const int32_t txh_idx       = get_txh_idx(TX_64X32);
    const int32_t txfm_size_col = tx_size_wide[TX_64X32];
    const int32_t txfm_size_row = tx_size_high[TX_64X32];
    const int8_t  bitcol        = fwd_cos_bit_col[txw_idx][txh_idx];
    const int8_t  bitrow        = fwd_cos_bit_row[txw_idx][txh_idx];
    const int32_t num_row       = txfm_size_row >> 4;
    const int32_t num_col       = txfm_size_col >> 4;

    // column transform: each 64-wide source row is loaded as two
    // 32-sample halves -> four 16-lane vectors per row
    for (int32_t i = 0; i < 32; i++) {
        load_buffer_32_avx512(input + 0 + i * stride, in + 0 + i * 4, 16, 0, 0, shift[0]);
        load_buffer_32_avx512(input + 32 + i * stride, in + 2 + i * 4, 16, 0, 0, shift[0]);
    }

    av1_fdct32_new_avx512(in, in, bitcol, txfm_size_col, num_col);

    // round all 128 intermediate vectors in 16-vector groups
    for (int32_t i = 0; i < 8; i++) { col_txfm_16x16_rounding_avx512((in + i * 16), -shift[1]); }
    transpose_16nx16m_avx512(in, outcoef512, txfm_size_col, txfm_size_row);

    // row transform
    av1_fdct64_new_avx512(outcoef512, in, bitrow, txfm_size_row, num_row);
    transpose_16nx16m_avx512(in, outcoef512, txfm_size_row, txfm_size_col);
    av1_round_shift_rect_array_32_avx512(outcoef512, outcoef512, 128, -shift[2], 5793);
    (void)bd;
}
2187 
/* 2-D 16x64 forward transform (AVX-512). Only the DCT_DCT path is
 * implemented; tx_type and bd are ignored. Column pass: 64-point DCT over
 * one 16-lane vector per source row; row pass: 16-point DCT. Unlike the
 * 32x64/64x32 kernels, no rectangular rescale step is applied here. */
void av1_fwd_txfm2d_16x64_avx512(int16_t *input, int32_t *output, uint32_t stride, TxType tx_type,
                                 uint8_t bd) {
    __m512i       in[64];
    __m512i *     outcoeff512   = (__m512i *)output;
    const int8_t *shift         = fwd_txfm_shift_ls[TX_16X64];
    const int32_t txw_idx       = get_txw_idx(TX_16X64);
    const int32_t txh_idx       = get_txh_idx(TX_16X64);
    const int32_t txfm_size_col = tx_size_wide[TX_16X64];
    const int32_t txfm_size_row = tx_size_high[TX_16X64];
    const int8_t  bitcol        = fwd_cos_bit_col[txw_idx][txh_idx];
    const int8_t  bitrow        = fwd_cos_bit_row[txw_idx][txh_idx];
    const int32_t num_row       = txfm_size_row >> 4;
    const int32_t num_col       = txfm_size_col >> 4;

    // column transform
    for (int32_t i = 0; i < txfm_size_row; i++) {
        load_buffer_16_avx512(input + i * stride, in + i, 16, 0, 0, shift[0]);
    }

    av1_fdct64_new_avx512(in, outcoeff512, bitcol, txfm_size_col, num_col);

    // round all 64 result vectors (4 groups of 16)
    col_txfm_16x16_rounding_avx512(outcoeff512, -shift[1]);
    col_txfm_16x16_rounding_avx512(outcoeff512 + 16, -shift[1]);
    col_txfm_16x16_rounding_avx512(outcoeff512 + 32, -shift[1]);
    col_txfm_16x16_rounding_avx512(outcoeff512 + 48, -shift[1]);
    transpose_16nx16m_avx512(outcoeff512, in, txfm_size_col, txfm_size_row);
    // row transform
    fdct16x16_avx512(in, in, bitrow, num_row);
    transpose_16nx16m_avx512(in, outcoeff512, txfm_size_row, txfm_size_col);
    (void)bd;
    (void)tx_type;
}
2220 
/* 2-D 64x16 forward transform (AVX-512). Only the DCT_DCT path is
 * implemented; tx_type and bd are ignored. Column pass: 16-point DCT;
 * row pass: 64-point DCT. As with 16x64, no rectangular rescale step
 * is applied for this shape. */
void av1_fwd_txfm2d_64x16_avx512(int16_t *input, int32_t *output, uint32_t stride, TxType tx_type,
                                 uint8_t bd) {
    __m512i       in[64];
    __m512i *     outcoeff512   = (__m512i *)output;
    const int8_t *shift         = fwd_txfm_shift_ls[TX_64X16];
    const int32_t txw_idx       = get_txw_idx(TX_64X16);
    const int32_t txh_idx       = get_txh_idx(TX_64X16);
    const int32_t txfm_size_col = tx_size_wide[TX_64X16];
    const int32_t txfm_size_row = tx_size_high[TX_64X16];
    const int8_t  bitcol        = fwd_cos_bit_col[txw_idx][txh_idx];
    const int8_t  bitrow        = fwd_cos_bit_row[txw_idx][txh_idx];
    const int32_t num_row       = txfm_size_row >> 4;
    const int32_t num_col       = txfm_size_col >> 4;
    // column transform: each 64-wide source row is loaded as four
    // 16-sample chunks -> four 16-lane vectors per row
    for (int32_t i = 0; i < txfm_size_row; i++) {
        load_buffer_16_avx512(input + 0 + i * stride, in + 0 + i * 4, 16, 0, 0, shift[0]);
        load_buffer_16_avx512(input + 16 + i * stride, in + 1 + i * 4, 16, 0, 0, shift[0]);
        load_buffer_16_avx512(input + 32 + i * stride, in + 2 + i * 4, 16, 0, 0, shift[0]);
        load_buffer_16_avx512(input + 48 + i * stride, in + 3 + i * 4, 16, 0, 0, shift[0]);
    }

    fdct16x16_avx512(in, outcoeff512, bitcol, num_col);
    // round all 64 result vectors (4 groups of 16)
    col_txfm_16x16_rounding_avx512(outcoeff512, -shift[1]);
    col_txfm_16x16_rounding_avx512(outcoeff512 + 16, -shift[1]);
    col_txfm_16x16_rounding_avx512(outcoeff512 + 32, -shift[1]);
    col_txfm_16x16_rounding_avx512(outcoeff512 + 48, -shift[1]);
    transpose_16nx16m_avx512(outcoeff512, in, txfm_size_col, txfm_size_row);
    // row transform
    av1_fdct64_new_avx512(in, in, bitrow, txfm_size_row, num_row);
    transpose_16nx16m_avx512(in, outcoeff512, txfm_size_row, txfm_size_col);
    (void)bd;
    (void)tx_type;
}
2254 
/* Adapter giving the 32-point DCT the fwd_transform_1d_avx512 signature
 * used by the dispatch tables below: binds the extra size argument of
 * av1_fdct32_new_avx512 to 16. (Name keeps the historical "wraper"
 * spelling; renaming would touch the table that references it.) */
static void av1_fdct32_new_line_wraper_avx512(const __m512i *input, __m512i *output,
                                              const int8_t cos_bit, const int32_t stride) {
    av1_fdct32_new_avx512(input, output, cos_bit, 16, stride);
}
2259 
/* Per-tx-type 1-D kernels for the 32-point pass of the 16x32/32x16
 * transforms. Only DCT_DCT and IDTX are wired up; every other entry is
 * NULL and must not be selected by callers (see the "call this function
 * only for DCT_DCT, IDTX" contract below). */
static const fwd_transform_1d_avx512 col_fwdtxfm_16x32_arr[TX_TYPES] = {
    av1_fdct32_new_line_wraper_avx512, // DCT_DCT
    NULL, // ADST_DCT
    NULL, // DCT_ADST
    NULL, // ADST_ADST
    NULL, // FLIPADST_DCT
    NULL, // DCT_FLIPADST
    NULL, // FLIPADST_FLIPADST
    NULL, // ADST_FLIPADST
    NULL, // FLIPADST_ADST
    av1_idtx32_new_avx512, // IDTX
    NULL, // V_DCT
    NULL, // H_DCT
    NULL, // V_ADST
    NULL, // H_ADST
    NULL, // V_FLIPADST
    NULL // H_FLIPADST
};
2278 
/* Per-tx-type 1-D kernels for the 16-point pass of the 16x32/32x16
 * transforms. Only DCT_DCT and IDTX are wired up; every other entry is
 * NULL and must not be selected by callers. */
static const fwd_transform_1d_avx512 row_fwdtxfm_16x32_arr[TX_TYPES] = {
    fdct16x16_avx512, // DCT_DCT
    NULL, // ADST_DCT
    NULL, // DCT_ADST
    NULL, // ADST_ADST
    NULL, // FLIPADST_DCT
    NULL, // DCT_FLIPADST
    NULL, // FLIPADST_FLIPADST
    NULL, // ADST_FLIPADST
    NULL, // FLIPADST_ADST
    fidtx16x16_avx512, // IDTX
    NULL, // V_DCT
    NULL, // H_DCT
    NULL, // V_ADST
    NULL, // H_ADST
    NULL, // V_FLIPADST
    NULL // H_FLIPADST
};
2297 
/* 2-D 16x32 forward transform (AVX-512).
 * Call this function only for DCT_DCT or IDTX: the dispatch tables hold
 * NULL for all other tx_types, so anything else dereferences NULL.
 * Column pass (32-point) then row pass (16-point), finishing with the
 * rectangular rescale by 5793 (sqrt(2) in Q12). bd is unused. */
void av1_fwd_txfm2d_16x32_avx512(int16_t *input, int32_t *output, uint32_t stride, TxType tx_type,
                                 uint8_t bd) {
    __m512i                       in[32];
    __m512i *                     outcoef512    = (__m512i *)output;
    const int8_t *                shift         = fwd_txfm_shift_ls[TX_16X32];
    const int32_t                 txw_idx       = get_txw_idx(TX_16X32);
    const int32_t                 txh_idx       = get_txh_idx(TX_16X32);
    const fwd_transform_1d_avx512 col_txfm      = col_fwdtxfm_16x32_arr[tx_type];
    const fwd_transform_1d_avx512 row_txfm      = row_fwdtxfm_16x32_arr[tx_type];
    const int8_t                  bitcol        = fwd_cos_bit_col[txw_idx][txh_idx];
    const int8_t                  bitrow        = fwd_cos_bit_row[txw_idx][txh_idx];
    const int32_t                 txfm_size_col = tx_size_wide[TX_16X32];
    const int32_t                 txfm_size_row = tx_size_high[TX_16X32];
    const int32_t                 num_row       = txfm_size_row >> 4;
    const int32_t                 num_col       = txfm_size_col >> 4;

    // column transform: load the two 16x16 halves of the 16x32 block
    load_buffer_16x16_avx512(input, in, stride, 0, 0, shift[0]);
    load_buffer_16x16_avx512(input + 16 * stride, in + 16, stride, 0, 0, shift[0]);

    col_txfm(in, in, bitcol, num_col);
    col_txfm_16x16_rounding_avx512(&in[0], -shift[1]);
    col_txfm_16x16_rounding_avx512(&in[16], -shift[1]);
    transpose_16nx16m_avx512(in, outcoef512, txfm_size_col, txfm_size_row);

    // row transform
    row_txfm(outcoef512, in, bitrow, num_row);
    transpose_16nx16m_avx512(in, outcoef512, txfm_size_row, txfm_size_col);
    av1_round_shift_rect_array_32_avx512(outcoef512, outcoef512, 32, -shift[2], 5793);
    (void)bd;
}
2330 
/* 2-D 32x16 forward transform (AVX-512).
 * Call this function only for DCT_DCT or IDTX (other tx_types index NULL
 * table entries). Note the deliberate table swap versus the 16x32 kernel:
 * a 32x16 block has 16-sample columns, so its column pass uses the
 * 16-point ("row") table and its row pass uses the 32-point ("col")
 * table. Ends with the rectangular rescale by 5793 (sqrt(2) in Q12).
 * bd is unused. */
void av1_fwd_txfm2d_32x16_avx512(int16_t *input, int32_t *output, uint32_t stride, TxType tx_type,
                                 uint8_t bd) {
    __m512i                       in[32];
    __m512i *                     outcoef512    = (__m512i *)output;
    const int8_t *                shift         = fwd_txfm_shift_ls[TX_32X16];
    const int32_t                 txw_idx       = get_txw_idx(TX_32X16);
    const int32_t                 txh_idx       = get_txh_idx(TX_32X16);
    const fwd_transform_1d_avx512 col_txfm      = row_fwdtxfm_16x32_arr[tx_type];
    const fwd_transform_1d_avx512 row_txfm      = col_fwdtxfm_16x32_arr[tx_type];
    const int8_t                  bitcol        = fwd_cos_bit_col[txw_idx][txh_idx];
    const int8_t                  bitrow        = fwd_cos_bit_row[txw_idx][txh_idx];
    const int32_t                 txfm_size_col = tx_size_wide[TX_32X16];
    const int32_t                 txfm_size_row = tx_size_high[TX_32X16];
    const int32_t                 num_row       = txfm_size_row >> 4;
    const int32_t                 num_col       = txfm_size_col >> 4;

    // column transform
    load_buffer_32x16n(input, in, stride, 0, 0, shift[0], txfm_size_row);
    col_txfm(in, in, bitcol, num_col);
    col_txfm_16x16_rounding_avx512(&in[0], -shift[1]);
    col_txfm_16x16_rounding_avx512(&in[16], -shift[1]);
    transpose_16nx16m_avx512(in, outcoef512, txfm_size_col, txfm_size_row);

    // row transform
    row_txfm(outcoef512, in, bitrow, num_row);

    transpose_16nx16m_avx512(in, outcoef512, txfm_size_row, txfm_size_col);
    av1_round_shift_rect_array_32_avx512(outcoef512, outcoef512, 32, -shift[2], 5793);
    (void)bd;
}
2362 
av1_round_shift_array_wxh_N2_avx512(__m512i * input,__m512i * output,const int32_t col_num,const int32_t height,const int8_t bit)2363 static AOM_FORCE_INLINE void av1_round_shift_array_wxh_N2_avx512(__m512i *input, __m512i *output,
2364                                                                  const int32_t col_num,
2365                                                                  const int32_t height,
2366                                                                  const int8_t  bit) {
2367     if (bit > 0) {
2368         __m512i round = _mm512_set1_epi32(1 << (bit - 1));
2369         int32_t i, j;
2370         for (i = 0; i < height; i++) {
2371             for (j = 0; j < col_num / 2; j++) {
2372                 output[i * col_num + j] = _mm512_srai_epi32(
2373                     _mm512_add_epi32(input[i * col_num + j], round), (uint8_t)bit);
2374             }
2375         }
2376     } else {
2377         int32_t i, j;
2378         for (i = 0; i < height; i++) {
2379             for (j = 0; j < col_num / 2; j++) {
2380                 output[i * col_num + j] = _mm512_slli_epi32(input[i * col_num + j],
2381                                                             (uint8_t)(-bit));
2382             }
2383         }
2384     }
2385 }
2386 
av1_round_shift_rect_array_wxh_N2_avx512(__m512i * input,__m512i * output,const int32_t col_num,const int32_t height,const int8_t bit,const int32_t val)2387 static AOM_FORCE_INLINE void av1_round_shift_rect_array_wxh_N2_avx512(
2388     __m512i *input, __m512i *output, const int32_t col_num, const int32_t height, const int8_t bit,
2389     const int32_t val) {
2390     const __m512i sqrt2  = _mm512_set1_epi32(val);
2391     const __m512i round2 = _mm512_set1_epi32(1 << (12 - 1));
2392     int32_t       i, j;
2393     if (bit > 0) {
2394         const __m512i round1 = _mm512_set1_epi32(1 << (bit - 1));
2395         __m512i       r0, r1, r2, r3;
2396         for (i = 0; i < height; i++) {
2397             for (j = 0; j < col_num / 2; j++) {
2398                 r0                      = _mm512_add_epi32(input[i * col_num + j], round1);
2399                 r1                      = _mm512_srai_epi32(r0, (uint8_t)bit);
2400                 r2                      = _mm512_mullo_epi32(sqrt2, r1);
2401                 r3                      = _mm512_add_epi32(r2, round2);
2402                 output[i * col_num + j] = _mm512_srai_epi32(r3, (uint8_t)12);
2403             }
2404         }
2405     } else {
2406         __m512i r0, r1, r2;
2407         for (i = 0; i < height; i++) {
2408             for (j = 0; j < col_num / 2; j++) {
2409                 r0                      = _mm512_slli_epi32(input[i * col_num + j], (uint8_t)-bit);
2410                 r1                      = _mm512_mullo_epi32(sqrt2, r0);
2411                 r2                      = _mm512_add_epi32(r1, round2);
2412                 output[i * col_num + j] = _mm512_srai_epi32(r2, (uint8_t)12);
2413             }
2414         }
2415     }
2416 }
2417 
clear_buffer_wxh_N2_avx512(__m512i * buff,int32_t width,int32_t height)2418 static AOM_FORCE_INLINE void clear_buffer_wxh_N2_avx512(__m512i *buff, int32_t width,
2419                                                         int32_t height) {
2420     const __m512i zero512 = _mm512_setzero_si512();
2421     int32_t       i, j;
2422     assert(width > 16);
2423 
2424     int32_t num_col = width >> 4;
2425 
2426     //top-right quarter
2427     for (i = 0; i < height / 2; i++)
2428         for (j = num_col / 2; j < num_col; j++) buff[i * num_col + j] = zero512;
2429 
2430     //bottom half
2431     for (i = height / 2; i < height; i++)
2432         for (j = 0; j < num_col; j++) buff[i * num_col + j] = zero512;
2433 }
2434 
load_buffer_32x32_in_64x64_avx512(const int16_t * input,int32_t stride,__m512i * output)2435 static AOM_FORCE_INLINE void load_buffer_32x32_in_64x64_avx512(const int16_t *input, int32_t stride,
2436                                                                __m512i *output) {
2437     __m256i x0, x1;
2438     __m512i v0, v1;
2439     int32_t i;
2440 
2441     for (i = 0; i < 32; ++i) {
2442         x0 = _mm256_loadu_si256((const __m256i *)(input + 0 * 16));
2443         x1 = _mm256_loadu_si256((const __m256i *)(input + 1 * 16));
2444 
2445         v0 = _mm512_cvtepi16_epi32(x0);
2446         v1 = _mm512_cvtepi16_epi32(x1);
2447 
2448         _mm512_storeu_si512(output + 0, v0);
2449         _mm512_storeu_si512(output + 1, v1);
2450 
2451         input += stride;
2452         output += 4;
2453     }
2454 }
2455 
load_buffer_16xh_in_32x32_avx512(const int16_t * input,__m512i * output,int32_t stride,int32_t height)2456 static AOM_FORCE_INLINE void load_buffer_16xh_in_32x32_avx512(const int16_t *input, __m512i *output,
2457                                                               int32_t stride, int32_t height) {
2458     __m256i temp;
2459     int32_t i;
2460 
2461     for (i = 0; i < height; ++i) {
2462         temp      = _mm256_loadu_si256((const __m256i *)(input + 0 * 16));
2463         output[0] = _mm512_cvtepi16_epi32(temp);
2464 
2465         input += stride;
2466         output += 2;
2467     }
2468 }
2469 
load_buffer_32xh_in_32x32_avx512(const int16_t * input,__m512i * output,int32_t stride,int32_t height)2470 static AOM_FORCE_INLINE void load_buffer_32xh_in_32x32_avx512(const int16_t *input, __m512i *output,
2471                                                               int32_t stride, int32_t height) {
2472     __m256i temp[2];
2473     int32_t i;
2474 
2475     for (i = 0; i < height; ++i) {
2476         temp[0] = _mm256_loadu_si256((const __m256i *)(input + 0 * 16));
2477         temp[1] = _mm256_loadu_si256((const __m256i *)(input + 1 * 16));
2478 
2479         output[0] = _mm512_cvtepi16_epi32(temp[0]);
2480         output[1] = _mm512_cvtepi16_epi32(temp[1]);
2481 
2482         input += stride;
2483         output += 2;
2484     }
2485 }
2486 
av1_fdct64_new_N2_avx512(const __m512i * input,__m512i * output,const int8_t cos_bit,const int32_t col_num,const int32_t stride)2487 static void av1_fdct64_new_N2_avx512(const __m512i *input, __m512i *output, const int8_t cos_bit,
2488                                      const int32_t col_num, const int32_t stride) {
2489     const int32_t *cospi      = cospi_arr(cos_bit);
2490     const __m512i  __rounding = _mm512_set1_epi32(1 << (cos_bit - 1));
2491     const int32_t  columns    = col_num >> 4;
2492 
2493     __m512i cospi_m32 = _mm512_set1_epi32(-cospi[32]);
2494     __m512i cospi_p32 = _mm512_set1_epi32(cospi[32]);
2495     __m512i cospi_m16 = _mm512_set1_epi32(-cospi[16]);
2496     __m512i cospi_p48 = _mm512_set1_epi32(cospi[48]);
2497     __m512i cospi_m48 = _mm512_set1_epi32(-cospi[48]);
2498     __m512i cospi_p16 = _mm512_set1_epi32(cospi[16]);
2499     __m512i cospi_m08 = _mm512_set1_epi32(-cospi[8]);
2500     __m512i cospi_p56 = _mm512_set1_epi32(cospi[56]);
2501     __m512i cospi_m56 = _mm512_set1_epi32(-cospi[56]);
2502     __m512i cospi_m40 = _mm512_set1_epi32(-cospi[40]);
2503     __m512i cospi_p24 = _mm512_set1_epi32(cospi[24]);
2504     __m512i cospi_m24 = _mm512_set1_epi32(-cospi[24]);
2505     __m512i cospi_p08 = _mm512_set1_epi32(cospi[8]);
2506     __m512i cospi_p60 = _mm512_set1_epi32(cospi[60]);
2507     __m512i cospi_p04 = _mm512_set1_epi32(cospi[4]);
2508     __m512i cospi_p28 = _mm512_set1_epi32(cospi[28]);
2509     __m512i cospi_p44 = _mm512_set1_epi32(cospi[44]);
2510     __m512i cospi_p20 = _mm512_set1_epi32(cospi[20]);
2511     __m512i cospi_p12 = _mm512_set1_epi32(cospi[12]);
2512     __m512i cospi_m04 = _mm512_set1_epi32(-cospi[4]);
2513     __m512i cospi_m60 = _mm512_set1_epi32(-cospi[60]);
2514     __m512i cospi_m36 = _mm512_set1_epi32(-cospi[36]);
2515     __m512i cospi_m28 = _mm512_set1_epi32(-cospi[28]);
2516     __m512i cospi_m20 = _mm512_set1_epi32(-cospi[20]);
2517     __m512i cospi_m44 = _mm512_set1_epi32(-cospi[44]);
2518     __m512i cospi_m52 = _mm512_set1_epi32(-cospi[52]);
2519     __m512i cospi_m12 = _mm512_set1_epi32(-cospi[12]);
2520     __m512i cospi_p62 = _mm512_set1_epi32(cospi[62]);
2521     __m512i cospi_p02 = _mm512_set1_epi32(cospi[2]);
2522     __m512i cospi_p30 = _mm512_set1_epi32(cospi[30]);
2523     __m512i cospi_m34 = _mm512_set1_epi32(-cospi[34]);
2524     __m512i cospi_p46 = _mm512_set1_epi32(cospi[46]);
2525     __m512i cospi_p18 = _mm512_set1_epi32(cospi[18]);
2526     __m512i cospi_p14 = _mm512_set1_epi32(cospi[14]);
2527     __m512i cospi_m50 = _mm512_set1_epi32(-cospi[50]);
2528     __m512i cospi_p54 = _mm512_set1_epi32(cospi[54]);
2529     __m512i cospi_p10 = _mm512_set1_epi32(cospi[10]);
2530     __m512i cospi_p22 = _mm512_set1_epi32(cospi[22]);
2531     __m512i cospi_m42 = _mm512_set1_epi32(-cospi[42]);
2532     __m512i cospi_p38 = _mm512_set1_epi32(cospi[38]);
2533     __m512i cospi_p26 = _mm512_set1_epi32(cospi[26]);
2534     __m512i cospi_p06 = _mm512_set1_epi32(cospi[6]);
2535     __m512i cospi_m58 = _mm512_set1_epi32(-cospi[58]);
2536     __m512i cospi_p63 = _mm512_set1_epi32(cospi[63]);
2537     __m512i cospi_p01 = _mm512_set1_epi32(cospi[1]);
2538     __m512i cospi_p31 = _mm512_set1_epi32(cospi[31]);
2539     __m512i cospi_m33 = _mm512_set1_epi32(-cospi[33]);
2540     __m512i cospi_p47 = _mm512_set1_epi32(cospi[47]);
2541     __m512i cospi_p17 = _mm512_set1_epi32(cospi[17]);
2542     __m512i cospi_p15 = _mm512_set1_epi32(cospi[15]);
2543     __m512i cospi_m49 = _mm512_set1_epi32(-cospi[49]);
2544     __m512i cospi_p55 = _mm512_set1_epi32(cospi[55]);
2545     __m512i cospi_p09 = _mm512_set1_epi32(cospi[9]);
2546     __m512i cospi_p23 = _mm512_set1_epi32(cospi[23]);
2547     __m512i cospi_m41 = _mm512_set1_epi32(-cospi[41]);
2548     __m512i cospi_p39 = _mm512_set1_epi32(cospi[39]);
2549     __m512i cospi_p25 = _mm512_set1_epi32(cospi[25]);
2550     __m512i cospi_p07 = _mm512_set1_epi32(cospi[7]);
2551     __m512i cospi_m57 = _mm512_set1_epi32(-cospi[57]);
2552     __m512i cospi_p59 = _mm512_set1_epi32(cospi[59]);
2553     __m512i cospi_p05 = _mm512_set1_epi32(cospi[5]);
2554     __m512i cospi_p27 = _mm512_set1_epi32(cospi[27]);
2555     __m512i cospi_m37 = _mm512_set1_epi32(-cospi[37]);
2556     __m512i cospi_p43 = _mm512_set1_epi32(cospi[43]);
2557     __m512i cospi_p21 = _mm512_set1_epi32(cospi[21]);
2558     __m512i cospi_p11 = _mm512_set1_epi32(cospi[11]);
2559     __m512i cospi_m53 = _mm512_set1_epi32(-cospi[53]);
2560     __m512i cospi_p51 = _mm512_set1_epi32(cospi[51]);
2561     __m512i cospi_p13 = _mm512_set1_epi32(cospi[13]);
2562     __m512i cospi_p19 = _mm512_set1_epi32(cospi[19]);
2563     __m512i cospi_m45 = _mm512_set1_epi32(-cospi[45]);
2564     __m512i cospi_p35 = _mm512_set1_epi32(cospi[35]);
2565     __m512i cospi_p29 = _mm512_set1_epi32(cospi[29]);
2566     __m512i cospi_p03 = _mm512_set1_epi32(cospi[3]);
2567     __m512i cospi_m61 = _mm512_set1_epi32(-cospi[61]);
2568 
2569     for (int32_t col = 0; col < columns; col++) {
2570         const __m512i *in  = &input[col];
2571         __m512i *      out = &output[col];
2572 
2573         // stage 1
2574         __m512i x1[64];
2575         x1[0]  = _mm512_add_epi32(in[0 * stride], in[63 * stride]);
2576         x1[63] = _mm512_sub_epi32(in[0 * stride], in[63 * stride]);
2577         x1[1]  = _mm512_add_epi32(in[1 * stride], in[62 * stride]);
2578         x1[62] = _mm512_sub_epi32(in[1 * stride], in[62 * stride]);
2579         x1[2]  = _mm512_add_epi32(in[2 * stride], in[61 * stride]);
2580         x1[61] = _mm512_sub_epi32(in[2 * stride], in[61 * stride]);
2581         x1[3]  = _mm512_add_epi32(in[3 * stride], in[60 * stride]);
2582         x1[60] = _mm512_sub_epi32(in[3 * stride], in[60 * stride]);
2583         x1[4]  = _mm512_add_epi32(in[4 * stride], in[59 * stride]);
2584         x1[59] = _mm512_sub_epi32(in[4 * stride], in[59 * stride]);
2585         x1[5]  = _mm512_add_epi32(in[5 * stride], in[58 * stride]);
2586         x1[58] = _mm512_sub_epi32(in[5 * stride], in[58 * stride]);
2587         x1[6]  = _mm512_add_epi32(in[6 * stride], in[57 * stride]);
2588         x1[57] = _mm512_sub_epi32(in[6 * stride], in[57 * stride]);
2589         x1[7]  = _mm512_add_epi32(in[7 * stride], in[56 * stride]);
2590         x1[56] = _mm512_sub_epi32(in[7 * stride], in[56 * stride]);
2591         x1[8]  = _mm512_add_epi32(in[8 * stride], in[55 * stride]);
2592         x1[55] = _mm512_sub_epi32(in[8 * stride], in[55 * stride]);
2593         x1[9]  = _mm512_add_epi32(in[9 * stride], in[54 * stride]);
2594         x1[54] = _mm512_sub_epi32(in[9 * stride], in[54 * stride]);
2595         x1[10] = _mm512_add_epi32(in[10 * stride], in[53 * stride]);
2596         x1[53] = _mm512_sub_epi32(in[10 * stride], in[53 * stride]);
2597         x1[11] = _mm512_add_epi32(in[11 * stride], in[52 * stride]);
2598         x1[52] = _mm512_sub_epi32(in[11 * stride], in[52 * stride]);
2599         x1[12] = _mm512_add_epi32(in[12 * stride], in[51 * stride]);
2600         x1[51] = _mm512_sub_epi32(in[12 * stride], in[51 * stride]);
2601         x1[13] = _mm512_add_epi32(in[13 * stride], in[50 * stride]);
2602         x1[50] = _mm512_sub_epi32(in[13 * stride], in[50 * stride]);
2603         x1[14] = _mm512_add_epi32(in[14 * stride], in[49 * stride]);
2604         x1[49] = _mm512_sub_epi32(in[14 * stride], in[49 * stride]);
2605         x1[15] = _mm512_add_epi32(in[15 * stride], in[48 * stride]);
2606         x1[48] = _mm512_sub_epi32(in[15 * stride], in[48 * stride]);
2607         x1[16] = _mm512_add_epi32(in[16 * stride], in[47 * stride]);
2608         x1[47] = _mm512_sub_epi32(in[16 * stride], in[47 * stride]);
2609         x1[17] = _mm512_add_epi32(in[17 * stride], in[46 * stride]);
2610         x1[46] = _mm512_sub_epi32(in[17 * stride], in[46 * stride]);
2611         x1[18] = _mm512_add_epi32(in[18 * stride], in[45 * stride]);
2612         x1[45] = _mm512_sub_epi32(in[18 * stride], in[45 * stride]);
2613         x1[19] = _mm512_add_epi32(in[19 * stride], in[44 * stride]);
2614         x1[44] = _mm512_sub_epi32(in[19 * stride], in[44 * stride]);
2615         x1[20] = _mm512_add_epi32(in[20 * stride], in[43 * stride]);
2616         x1[43] = _mm512_sub_epi32(in[20 * stride], in[43 * stride]);
2617         x1[21] = _mm512_add_epi32(in[21 * stride], in[42 * stride]);
2618         x1[42] = _mm512_sub_epi32(in[21 * stride], in[42 * stride]);
2619         x1[22] = _mm512_add_epi32(in[22 * stride], in[41 * stride]);
2620         x1[41] = _mm512_sub_epi32(in[22 * stride], in[41 * stride]);
2621         x1[23] = _mm512_add_epi32(in[23 * stride], in[40 * stride]);
2622         x1[40] = _mm512_sub_epi32(in[23 * stride], in[40 * stride]);
2623         x1[24] = _mm512_add_epi32(in[24 * stride], in[39 * stride]);
2624         x1[39] = _mm512_sub_epi32(in[24 * stride], in[39 * stride]);
2625         x1[25] = _mm512_add_epi32(in[25 * stride], in[38 * stride]);
2626         x1[38] = _mm512_sub_epi32(in[25 * stride], in[38 * stride]);
2627         x1[26] = _mm512_add_epi32(in[26 * stride], in[37 * stride]);
2628         x1[37] = _mm512_sub_epi32(in[26 * stride], in[37 * stride]);
2629         x1[27] = _mm512_add_epi32(in[27 * stride], in[36 * stride]);
2630         x1[36] = _mm512_sub_epi32(in[27 * stride], in[36 * stride]);
2631         x1[28] = _mm512_add_epi32(in[28 * stride], in[35 * stride]);
2632         x1[35] = _mm512_sub_epi32(in[28 * stride], in[35 * stride]);
2633         x1[29] = _mm512_add_epi32(in[29 * stride], in[34 * stride]);
2634         x1[34] = _mm512_sub_epi32(in[29 * stride], in[34 * stride]);
2635         x1[30] = _mm512_add_epi32(in[30 * stride], in[33 * stride]);
2636         x1[33] = _mm512_sub_epi32(in[30 * stride], in[33 * stride]);
2637         x1[31] = _mm512_add_epi32(in[31 * stride], in[32 * stride]);
2638         x1[32] = _mm512_sub_epi32(in[31 * stride], in[32 * stride]);
2639 
2640         // stage 2
2641         __m512i x2[64];
2642         x2[0]  = _mm512_add_epi32(x1[0], x1[31]);
2643         x2[31] = _mm512_sub_epi32(x1[0], x1[31]);
2644         x2[1]  = _mm512_add_epi32(x1[1], x1[30]);
2645         x2[30] = _mm512_sub_epi32(x1[1], x1[30]);
2646         x2[2]  = _mm512_add_epi32(x1[2], x1[29]);
2647         x2[29] = _mm512_sub_epi32(x1[2], x1[29]);
2648         x2[3]  = _mm512_add_epi32(x1[3], x1[28]);
2649         x2[28] = _mm512_sub_epi32(x1[3], x1[28]);
2650         x2[4]  = _mm512_add_epi32(x1[4], x1[27]);
2651         x2[27] = _mm512_sub_epi32(x1[4], x1[27]);
2652         x2[5]  = _mm512_add_epi32(x1[5], x1[26]);
2653         x2[26] = _mm512_sub_epi32(x1[5], x1[26]);
2654         x2[6]  = _mm512_add_epi32(x1[6], x1[25]);
2655         x2[25] = _mm512_sub_epi32(x1[6], x1[25]);
2656         x2[7]  = _mm512_add_epi32(x1[7], x1[24]);
2657         x2[24] = _mm512_sub_epi32(x1[7], x1[24]);
2658         x2[8]  = _mm512_add_epi32(x1[8], x1[23]);
2659         x2[23] = _mm512_sub_epi32(x1[8], x1[23]);
2660         x2[9]  = _mm512_add_epi32(x1[9], x1[22]);
2661         x2[22] = _mm512_sub_epi32(x1[9], x1[22]);
2662         x2[10] = _mm512_add_epi32(x1[10], x1[21]);
2663         x2[21] = _mm512_sub_epi32(x1[10], x1[21]);
2664         x2[11] = _mm512_add_epi32(x1[11], x1[20]);
2665         x2[20] = _mm512_sub_epi32(x1[11], x1[20]);
2666         x2[12] = _mm512_add_epi32(x1[12], x1[19]);
2667         x2[19] = _mm512_sub_epi32(x1[12], x1[19]);
2668         x2[13] = _mm512_add_epi32(x1[13], x1[18]);
2669         x2[18] = _mm512_sub_epi32(x1[13], x1[18]);
2670         x2[14] = _mm512_add_epi32(x1[14], x1[17]);
2671         x2[17] = _mm512_sub_epi32(x1[14], x1[17]);
2672         x2[15] = _mm512_add_epi32(x1[15], x1[16]);
2673         x2[16] = _mm512_sub_epi32(x1[15], x1[16]);
2674         x2[32] = x1[32];
2675         x2[33] = x1[33];
2676         x2[34] = x1[34];
2677         x2[35] = x1[35];
2678         x2[36] = x1[36];
2679         x2[37] = x1[37];
2680         x2[38] = x1[38];
2681         x2[39] = x1[39];
2682         btf_32_type0_avx512_new(
2683             cospi_m32, cospi_p32, x1[40], x1[55], x2[40], x2[55], __rounding, cos_bit);
2684         btf_32_type0_avx512_new(
2685             cospi_m32, cospi_p32, x1[41], x1[54], x2[41], x2[54], __rounding, cos_bit);
2686         btf_32_type0_avx512_new(
2687             cospi_m32, cospi_p32, x1[42], x1[53], x2[42], x2[53], __rounding, cos_bit);
2688         btf_32_type0_avx512_new(
2689             cospi_m32, cospi_p32, x1[43], x1[52], x2[43], x2[52], __rounding, cos_bit);
2690         btf_32_type0_avx512_new(
2691             cospi_m32, cospi_p32, x1[44], x1[51], x2[44], x2[51], __rounding, cos_bit);
2692         btf_32_type0_avx512_new(
2693             cospi_m32, cospi_p32, x1[45], x1[50], x2[45], x2[50], __rounding, cos_bit);
2694         btf_32_type0_avx512_new(
2695             cospi_m32, cospi_p32, x1[46], x1[49], x2[46], x2[49], __rounding, cos_bit);
2696         btf_32_type0_avx512_new(
2697             cospi_m32, cospi_p32, x1[47], x1[48], x2[47], x2[48], __rounding, cos_bit);
2698         x2[56] = x1[56];
2699         x2[57] = x1[57];
2700         x2[58] = x1[58];
2701         x2[59] = x1[59];
2702         x2[60] = x1[60];
2703         x2[61] = x1[61];
2704         x2[62] = x1[62];
2705         x2[63] = x1[63];
2706 
2707         // stage 3
2708         __m512i x3[64];
2709         x3[0]  = _mm512_add_epi32(x2[0], x2[15]);
2710         x3[15] = _mm512_sub_epi32(x2[0], x2[15]);
2711         x3[1]  = _mm512_add_epi32(x2[1], x2[14]);
2712         x3[14] = _mm512_sub_epi32(x2[1], x2[14]);
2713         x3[2]  = _mm512_add_epi32(x2[2], x2[13]);
2714         x3[13] = _mm512_sub_epi32(x2[2], x2[13]);
2715         x3[3]  = _mm512_add_epi32(x2[3], x2[12]);
2716         x3[12] = _mm512_sub_epi32(x2[3], x2[12]);
2717         x3[4]  = _mm512_add_epi32(x2[4], x2[11]);
2718         x3[11] = _mm512_sub_epi32(x2[4], x2[11]);
2719         x3[5]  = _mm512_add_epi32(x2[5], x2[10]);
2720         x3[10] = _mm512_sub_epi32(x2[5], x2[10]);
2721         x3[6]  = _mm512_add_epi32(x2[6], x2[9]);
2722         x3[9]  = _mm512_sub_epi32(x2[6], x2[9]);
2723         x3[7]  = _mm512_add_epi32(x2[7], x2[8]);
2724         x3[8]  = _mm512_sub_epi32(x2[7], x2[8]);
2725         x3[16] = x2[16];
2726         x3[17] = x2[17];
2727         x3[18] = x2[18];
2728         x3[19] = x2[19];
2729         btf_32_type0_avx512_new(
2730             cospi_m32, cospi_p32, x2[20], x2[27], x3[20], x3[27], __rounding, cos_bit);
2731         btf_32_type0_avx512_new(
2732             cospi_m32, cospi_p32, x2[21], x2[26], x3[21], x3[26], __rounding, cos_bit);
2733         btf_32_type0_avx512_new(
2734             cospi_m32, cospi_p32, x2[22], x2[25], x3[22], x3[25], __rounding, cos_bit);
2735         btf_32_type0_avx512_new(
2736             cospi_m32, cospi_p32, x2[23], x2[24], x3[23], x3[24], __rounding, cos_bit);
2737         x3[28] = x2[28];
2738         x3[29] = x2[29];
2739         x3[30] = x2[30];
2740         x3[31] = x2[31];
2741         x3[32] = _mm512_add_epi32(x2[32], x2[47]);
2742         x3[47] = _mm512_sub_epi32(x2[32], x2[47]);
2743         x3[33] = _mm512_add_epi32(x2[33], x2[46]);
2744         x3[46] = _mm512_sub_epi32(x2[33], x2[46]);
2745         x3[34] = _mm512_add_epi32(x2[34], x2[45]);
2746         x3[45] = _mm512_sub_epi32(x2[34], x2[45]);
2747         x3[35] = _mm512_add_epi32(x2[35], x2[44]);
2748         x3[44] = _mm512_sub_epi32(x2[35], x2[44]);
2749         x3[36] = _mm512_add_epi32(x2[36], x2[43]);
2750         x3[43] = _mm512_sub_epi32(x2[36], x2[43]);
2751         x3[37] = _mm512_add_epi32(x2[37], x2[42]);
2752         x3[42] = _mm512_sub_epi32(x2[37], x2[42]);
2753         x3[38] = _mm512_add_epi32(x2[38], x2[41]);
2754         x3[41] = _mm512_sub_epi32(x2[38], x2[41]);
2755         x3[39] = _mm512_add_epi32(x2[39], x2[40]);
2756         x3[40] = _mm512_sub_epi32(x2[39], x2[40]);
2757         x3[48] = _mm512_sub_epi32(x2[63], x2[48]);
2758         x3[63] = _mm512_add_epi32(x2[63], x2[48]);
2759         x3[49] = _mm512_sub_epi32(x2[62], x2[49]);
2760         x3[62] = _mm512_add_epi32(x2[62], x2[49]);
2761         x3[50] = _mm512_sub_epi32(x2[61], x2[50]);
2762         x3[61] = _mm512_add_epi32(x2[61], x2[50]);
2763         x3[51] = _mm512_sub_epi32(x2[60], x2[51]);
2764         x3[60] = _mm512_add_epi32(x2[60], x2[51]);
2765         x3[52] = _mm512_sub_epi32(x2[59], x2[52]);
2766         x3[59] = _mm512_add_epi32(x2[59], x2[52]);
2767         x3[53] = _mm512_sub_epi32(x2[58], x2[53]);
2768         x3[58] = _mm512_add_epi32(x2[58], x2[53]);
2769         x3[54] = _mm512_sub_epi32(x2[57], x2[54]);
2770         x3[57] = _mm512_add_epi32(x2[57], x2[54]);
2771         x3[55] = _mm512_sub_epi32(x2[56], x2[55]);
2772         x3[56] = _mm512_add_epi32(x2[56], x2[55]);
2773 
2774         // stage 4
2775         __m512i x4[64];
2776         x4[0] = _mm512_add_epi32(x3[0], x3[7]);
2777         x4[7] = _mm512_sub_epi32(x3[0], x3[7]);
2778         x4[1] = _mm512_add_epi32(x3[1], x3[6]);
2779         x4[6] = _mm512_sub_epi32(x3[1], x3[6]);
2780         x4[2] = _mm512_add_epi32(x3[2], x3[5]);
2781         x4[5] = _mm512_sub_epi32(x3[2], x3[5]);
2782         x4[3] = _mm512_add_epi32(x3[3], x3[4]);
2783         x4[4] = _mm512_sub_epi32(x3[3], x3[4]);
2784         x4[8] = x3[8];
2785         x4[9] = x3[9];
2786         btf_32_type0_avx512_new(
2787             cospi_m32, cospi_p32, x3[10], x3[13], x4[10], x4[13], __rounding, cos_bit);
2788         btf_32_type0_avx512_new(
2789             cospi_m32, cospi_p32, x3[11], x3[12], x4[11], x4[12], __rounding, cos_bit);
2790         x4[14] = x3[14];
2791         x4[15] = x3[15];
2792         x4[16] = _mm512_add_epi32(x3[16], x3[23]);
2793         x4[23] = _mm512_sub_epi32(x3[16], x3[23]);
2794         x4[17] = _mm512_add_epi32(x3[17], x3[22]);
2795         x4[22] = _mm512_sub_epi32(x3[17], x3[22]);
2796         x4[18] = _mm512_add_epi32(x3[18], x3[21]);
2797         x4[21] = _mm512_sub_epi32(x3[18], x3[21]);
2798         x4[19] = _mm512_add_epi32(x3[19], x3[20]);
2799         x4[20] = _mm512_sub_epi32(x3[19], x3[20]);
2800         x4[24] = _mm512_sub_epi32(x3[31], x3[24]);
2801         x4[31] = _mm512_add_epi32(x3[31], x3[24]);
2802         x4[25] = _mm512_sub_epi32(x3[30], x3[25]);
2803         x4[30] = _mm512_add_epi32(x3[30], x3[25]);
2804         x4[26] = _mm512_sub_epi32(x3[29], x3[26]);
2805         x4[29] = _mm512_add_epi32(x3[29], x3[26]);
2806         x4[27] = _mm512_sub_epi32(x3[28], x3[27]);
2807         x4[28] = _mm512_add_epi32(x3[28], x3[27]);
2808         x4[32] = x3[32];
2809         x4[33] = x3[33];
2810         x4[34] = x3[34];
2811         x4[35] = x3[35];
2812         btf_32_type0_avx512_new(
2813             cospi_m16, cospi_p48, x3[36], x3[59], x4[36], x4[59], __rounding, cos_bit);
2814         btf_32_type0_avx512_new(
2815             cospi_m16, cospi_p48, x3[37], x3[58], x4[37], x4[58], __rounding, cos_bit);
2816         btf_32_type0_avx512_new(
2817             cospi_m16, cospi_p48, x3[38], x3[57], x4[38], x4[57], __rounding, cos_bit);
2818         btf_32_type0_avx512_new(
2819             cospi_m16, cospi_p48, x3[39], x3[56], x4[39], x4[56], __rounding, cos_bit);
2820         btf_32_type0_avx512_new(
2821             cospi_m48, cospi_m16, x3[40], x3[55], x4[40], x4[55], __rounding, cos_bit);
2822         btf_32_type0_avx512_new(
2823             cospi_m48, cospi_m16, x3[41], x3[54], x4[41], x4[54], __rounding, cos_bit);
2824         btf_32_type0_avx512_new(
2825             cospi_m48, cospi_m16, x3[42], x3[53], x4[42], x4[53], __rounding, cos_bit);
2826         btf_32_type0_avx512_new(
2827             cospi_m48, cospi_m16, x3[43], x3[52], x4[43], x4[52], __rounding, cos_bit);
2828         x4[44] = x3[44];
2829         x4[45] = x3[45];
2830         x4[46] = x3[46];
2831         x4[47] = x3[47];
2832         x4[48] = x3[48];
2833         x4[49] = x3[49];
2834         x4[50] = x3[50];
2835         x4[51] = x3[51];
2836         x4[60] = x3[60];
2837         x4[61] = x3[61];
2838         x4[62] = x3[62];
2839         x4[63] = x3[63];
2840 
2841         // stage 5
2842         __m512i x5[64];
2843         x5[0] = _mm512_add_epi32(x4[0], x4[3]);
2844         x5[3] = _mm512_sub_epi32(x4[0], x4[3]);
2845         x5[1] = _mm512_add_epi32(x4[1], x4[2]);
2846         x5[2] = _mm512_sub_epi32(x4[1], x4[2]);
2847         x5[4] = x4[4];
2848         btf_32_type0_avx512_new(
2849             cospi_m32, cospi_p32, x4[5], x4[6], x5[5], x5[6], __rounding, cos_bit);
2850         x5[7]  = x4[7];
2851         x5[8]  = _mm512_add_epi32(x4[8], x4[11]);
2852         x5[11] = _mm512_sub_epi32(x4[8], x4[11]);
2853         x5[9]  = _mm512_add_epi32(x4[9], x4[10]);
2854         x5[10] = _mm512_sub_epi32(x4[9], x4[10]);
2855         x5[12] = _mm512_sub_epi32(x4[15], x4[12]);
2856         x5[15] = _mm512_add_epi32(x4[15], x4[12]);
2857         x5[13] = _mm512_sub_epi32(x4[14], x4[13]);
2858         x5[14] = _mm512_add_epi32(x4[14], x4[13]);
2859         x5[16] = x4[16];
2860         x5[17] = x4[17];
2861         btf_32_type0_avx512_new(
2862             cospi_m16, cospi_p48, x4[18], x4[29], x5[18], x5[29], __rounding, cos_bit);
2863         btf_32_type0_avx512_new(
2864             cospi_m16, cospi_p48, x4[19], x4[28], x5[19], x5[28], __rounding, cos_bit);
2865         btf_32_type0_avx512_new(
2866             cospi_m48, cospi_m16, x4[20], x4[27], x5[20], x5[27], __rounding, cos_bit);
2867         btf_32_type0_avx512_new(
2868             cospi_m48, cospi_m16, x4[21], x4[26], x5[21], x5[26], __rounding, cos_bit);
2869         x5[22] = x4[22];
2870         x5[23] = x4[23];
2871         x5[24] = x4[24];
2872         x5[25] = x4[25];
2873         x5[30] = x4[30];
2874         x5[31] = x4[31];
2875         x5[32] = _mm512_add_epi32(x4[32], x4[39]);
2876         x5[39] = _mm512_sub_epi32(x4[32], x4[39]);
2877         x5[33] = _mm512_add_epi32(x4[33], x4[38]);
2878         x5[38] = _mm512_sub_epi32(x4[33], x4[38]);
2879         x5[34] = _mm512_add_epi32(x4[34], x4[37]);
2880         x5[37] = _mm512_sub_epi32(x4[34], x4[37]);
2881         x5[35] = _mm512_add_epi32(x4[35], x4[36]);
2882         x5[36] = _mm512_sub_epi32(x4[35], x4[36]);
2883         x5[40] = _mm512_sub_epi32(x4[47], x4[40]);
2884         x5[47] = _mm512_add_epi32(x4[47], x4[40]);
2885         x5[41] = _mm512_sub_epi32(x4[46], x4[41]);
2886         x5[46] = _mm512_add_epi32(x4[46], x4[41]);
2887         x5[42] = _mm512_sub_epi32(x4[45], x4[42]);
2888         x5[45] = _mm512_add_epi32(x4[45], x4[42]);
2889         x5[43] = _mm512_sub_epi32(x4[44], x4[43]);
2890         x5[44] = _mm512_add_epi32(x4[44], x4[43]);
2891         x5[48] = _mm512_add_epi32(x4[48], x4[55]);
2892         x5[55] = _mm512_sub_epi32(x4[48], x4[55]);
2893         x5[49] = _mm512_add_epi32(x4[49], x4[54]);
2894         x5[54] = _mm512_sub_epi32(x4[49], x4[54]);
2895         x5[50] = _mm512_add_epi32(x4[50], x4[53]);
2896         x5[53] = _mm512_sub_epi32(x4[50], x4[53]);
2897         x5[51] = _mm512_add_epi32(x4[51], x4[52]);
2898         x5[52] = _mm512_sub_epi32(x4[51], x4[52]);
2899         x5[56] = _mm512_sub_epi32(x4[63], x4[56]);
2900         x5[63] = _mm512_add_epi32(x4[63], x4[56]);
2901         x5[57] = _mm512_sub_epi32(x4[62], x4[57]);
2902         x5[62] = _mm512_add_epi32(x4[62], x4[57]);
2903         x5[58] = _mm512_sub_epi32(x4[61], x4[58]);
2904         x5[61] = _mm512_add_epi32(x4[61], x4[58]);
2905         x5[59] = _mm512_sub_epi32(x4[60], x4[59]);
2906         x5[60] = _mm512_add_epi32(x4[60], x4[59]);
2907 
2908         // stage 6
2909         __m512i x6[64];
2910         x6[0] = half_btf_avx512(&cospi_p32, &x5[0], &cospi_p32, &x5[1], &__rounding, cos_bit);
2911         x6[2] = half_btf_avx512(&cospi_p48, &x5[2], &cospi_p16, &x5[3], &__rounding, cos_bit);
2912         x6[4] = _mm512_add_epi32(x5[4], x5[5]);
2913         x6[5] = _mm512_sub_epi32(x5[4], x5[5]);
2914         x6[6] = _mm512_sub_epi32(x5[7], x5[6]);
2915         x6[7] = _mm512_add_epi32(x5[7], x5[6]);
2916         x6[8] = x5[8];
2917         btf_32_type0_avx512_new(
2918             cospi_m16, cospi_p48, x5[9], x5[14], x6[9], x6[14], __rounding, cos_bit);
2919         btf_32_type0_avx512_new(
2920             cospi_m48, cospi_m16, x5[10], x5[13], x6[10], x6[13], __rounding, cos_bit);
2921         x6[11] = x5[11];
2922         x6[12] = x5[12];
2923         x6[15] = x5[15];
2924         x6[16] = _mm512_add_epi32(x5[16], x5[19]);
2925         x6[19] = _mm512_sub_epi32(x5[16], x5[19]);
2926         x6[17] = _mm512_add_epi32(x5[17], x5[18]);
2927         x6[18] = _mm512_sub_epi32(x5[17], x5[18]);
2928         x6[20] = _mm512_sub_epi32(x5[23], x5[20]);
2929         x6[23] = _mm512_add_epi32(x5[23], x5[20]);
2930         x6[21] = _mm512_sub_epi32(x5[22], x5[21]);
2931         x6[22] = _mm512_add_epi32(x5[22], x5[21]);
2932         x6[24] = _mm512_add_epi32(x5[24], x5[27]);
2933         x6[27] = _mm512_sub_epi32(x5[24], x5[27]);
2934         x6[25] = _mm512_add_epi32(x5[25], x5[26]);
2935         x6[26] = _mm512_sub_epi32(x5[25], x5[26]);
2936         x6[28] = _mm512_sub_epi32(x5[31], x5[28]);
2937         x6[31] = _mm512_add_epi32(x5[31], x5[28]);
2938         x6[29] = _mm512_sub_epi32(x5[30], x5[29]);
2939         x6[30] = _mm512_add_epi32(x5[30], x5[29]);
2940         x6[32] = x5[32];
2941         x6[33] = x5[33];
2942         btf_32_type0_avx512_new(
2943             cospi_m08, cospi_p56, x5[34], x5[61], x6[34], x6[61], __rounding, cos_bit);
2944         btf_32_type0_avx512_new(
2945             cospi_m08, cospi_p56, x5[35], x5[60], x6[35], x6[60], __rounding, cos_bit);
2946         btf_32_type0_avx512_new(
2947             cospi_m56, cospi_m08, x5[36], x5[59], x6[36], x6[59], __rounding, cos_bit);
2948         btf_32_type0_avx512_new(
2949             cospi_m56, cospi_m08, x5[37], x5[58], x6[37], x6[58], __rounding, cos_bit);
2950         x6[38] = x5[38];
2951         x6[39] = x5[39];
2952         x6[40] = x5[40];
2953         x6[41] = x5[41];
2954         btf_32_type0_avx512_new(
2955             cospi_m40, cospi_p24, x5[42], x5[53], x6[42], x6[53], __rounding, cos_bit);
2956         btf_32_type0_avx512_new(
2957             cospi_m40, cospi_p24, x5[43], x5[52], x6[43], x6[52], __rounding, cos_bit);
2958         btf_32_type0_avx512_new(
2959             cospi_m24, cospi_m40, x5[44], x5[51], x6[44], x6[51], __rounding, cos_bit);
2960         btf_32_type0_avx512_new(
2961             cospi_m24, cospi_m40, x5[45], x5[50], x6[45], x6[50], __rounding, cos_bit);
2962         x6[46] = x5[46];
2963         x6[47] = x5[47];
2964         x6[48] = x5[48];
2965         x6[49] = x5[49];
2966         x6[54] = x5[54];
2967         x6[55] = x5[55];
2968         x6[56] = x5[56];
2969         x6[57] = x5[57];
2970         x6[62] = x5[62];
2971         x6[63] = x5[63];
2972 
2973         // stage 7
2974         __m512i x7[64];
2975         x7[0]  = x6[0];
2976         x7[2]  = x6[2];
2977         x7[4]  = half_btf_avx512(&cospi_p56, &x6[4], &cospi_p08, &x6[7], &__rounding, cos_bit);
2978         x7[6]  = half_btf_avx512(&cospi_p24, &x6[6], &cospi_m40, &x6[5], &__rounding, cos_bit);
2979         x7[8]  = _mm512_add_epi32(x6[8], x6[9]);
2980         x7[9]  = _mm512_sub_epi32(x6[8], x6[9]);
2981         x7[10] = _mm512_sub_epi32(x6[11], x6[10]);
2982         x7[11] = _mm512_add_epi32(x6[11], x6[10]);
2983         x7[12] = _mm512_add_epi32(x6[12], x6[13]);
2984         x7[13] = _mm512_sub_epi32(x6[12], x6[13]);
2985         x7[14] = _mm512_sub_epi32(x6[15], x6[14]);
2986         x7[15] = _mm512_add_epi32(x6[15], x6[14]);
2987         x7[16] = x6[16];
2988         btf_32_type0_avx512_new(
2989             cospi_m08, cospi_p56, x6[17], x6[30], x7[17], x7[30], __rounding, cos_bit);
2990         btf_32_type0_avx512_new(
2991             cospi_m56, cospi_m08, x6[18], x6[29], x7[18], x7[29], __rounding, cos_bit);
2992         x7[19] = x6[19];
2993         x7[20] = x6[20];
2994         btf_32_type0_avx512_new(
2995             cospi_m40, cospi_p24, x6[21], x6[26], x7[21], x7[26], __rounding, cos_bit);
2996         btf_32_type0_avx512_new(
2997             cospi_m24, cospi_m40, x6[22], x6[25], x7[22], x7[25], __rounding, cos_bit);
2998         x7[23] = x6[23];
2999         x7[24] = x6[24];
3000         x7[27] = x6[27];
3001         x7[28] = x6[28];
3002         x7[31] = x6[31];
3003         x7[32] = _mm512_add_epi32(x6[32], x6[35]);
3004         x7[35] = _mm512_sub_epi32(x6[32], x6[35]);
3005         x7[33] = _mm512_add_epi32(x6[33], x6[34]);
3006         x7[34] = _mm512_sub_epi32(x6[33], x6[34]);
3007         x7[36] = _mm512_sub_epi32(x6[39], x6[36]);
3008         x7[39] = _mm512_add_epi32(x6[39], x6[36]);
3009         x7[37] = _mm512_sub_epi32(x6[38], x6[37]);
3010         x7[38] = _mm512_add_epi32(x6[38], x6[37]);
3011         x7[40] = _mm512_add_epi32(x6[40], x6[43]);
3012         x7[43] = _mm512_sub_epi32(x6[40], x6[43]);
3013         x7[41] = _mm512_add_epi32(x6[41], x6[42]);
3014         x7[42] = _mm512_sub_epi32(x6[41], x6[42]);
3015         x7[44] = _mm512_sub_epi32(x6[47], x6[44]);
3016         x7[47] = _mm512_add_epi32(x6[47], x6[44]);
3017         x7[45] = _mm512_sub_epi32(x6[46], x6[45]);
3018         x7[46] = _mm512_add_epi32(x6[46], x6[45]);
3019         x7[48] = _mm512_add_epi32(x6[48], x6[51]);
3020         x7[51] = _mm512_sub_epi32(x6[48], x6[51]);
3021         x7[49] = _mm512_add_epi32(x6[49], x6[50]);
3022         x7[50] = _mm512_sub_epi32(x6[49], x6[50]);
3023         x7[52] = _mm512_sub_epi32(x6[55], x6[52]);
3024         x7[55] = _mm512_add_epi32(x6[55], x6[52]);
3025         x7[53] = _mm512_sub_epi32(x6[54], x6[53]);
3026         x7[54] = _mm512_add_epi32(x6[54], x6[53]);
3027         x7[56] = _mm512_add_epi32(x6[56], x6[59]);
3028         x7[59] = _mm512_sub_epi32(x6[56], x6[59]);
3029         x7[57] = _mm512_add_epi32(x6[57], x6[58]);
3030         x7[58] = _mm512_sub_epi32(x6[57], x6[58]);
3031         x7[60] = _mm512_sub_epi32(x6[63], x6[60]);
3032         x7[63] = _mm512_add_epi32(x6[63], x6[60]);
3033         x7[61] = _mm512_sub_epi32(x6[62], x6[61]);
3034         x7[62] = _mm512_add_epi32(x6[62], x6[61]);
3035 
3036         // stage 8
3037         __m512i x8[64];
3038         x8[0] = x7[0];
3039         x8[2] = x7[2];
3040         x8[4] = x7[4];
3041         x8[6] = x7[6];
3042 
3043         x8[8]  = half_btf_avx512(&cospi_p60, &x7[8], &cospi_p04, &x7[15], &__rounding, cos_bit);
3044         x8[14] = half_btf_avx512(&cospi_p28, &x7[14], &cospi_m36, &x7[9], &__rounding, cos_bit);
3045         x8[10] = half_btf_avx512(&cospi_p44, &x7[10], &cospi_p20, &x7[13], &__rounding, cos_bit);
3046         x8[12] = half_btf_avx512(&cospi_p12, &x7[12], &cospi_m52, &x7[11], &__rounding, cos_bit);
3047         x8[16] = _mm512_add_epi32(x7[16], x7[17]);
3048         x8[17] = _mm512_sub_epi32(x7[16], x7[17]);
3049         x8[18] = _mm512_sub_epi32(x7[19], x7[18]);
3050         x8[19] = _mm512_add_epi32(x7[19], x7[18]);
3051         x8[20] = _mm512_add_epi32(x7[20], x7[21]);
3052         x8[21] = _mm512_sub_epi32(x7[20], x7[21]);
3053         x8[22] = _mm512_sub_epi32(x7[23], x7[22]);
3054         x8[23] = _mm512_add_epi32(x7[23], x7[22]);
3055         x8[24] = _mm512_add_epi32(x7[24], x7[25]);
3056         x8[25] = _mm512_sub_epi32(x7[24], x7[25]);
3057         x8[26] = _mm512_sub_epi32(x7[27], x7[26]);
3058         x8[27] = _mm512_add_epi32(x7[27], x7[26]);
3059         x8[28] = _mm512_add_epi32(x7[28], x7[29]);
3060         x8[29] = _mm512_sub_epi32(x7[28], x7[29]);
3061         x8[30] = _mm512_sub_epi32(x7[31], x7[30]);
3062         x8[31] = _mm512_add_epi32(x7[31], x7[30]);
3063         x8[32] = x7[32];
3064         btf_32_type0_avx512_new(
3065             cospi_m04, cospi_p60, x7[33], x7[62], x8[33], x8[62], __rounding, cos_bit);
3066         btf_32_type0_avx512_new(
3067             cospi_m60, cospi_m04, x7[34], x7[61], x8[34], x8[61], __rounding, cos_bit);
3068         x8[35] = x7[35];
3069         x8[36] = x7[36];
3070         btf_32_type0_avx512_new(
3071             cospi_m36, cospi_p28, x7[37], x7[58], x8[37], x8[58], __rounding, cos_bit);
3072         btf_32_type0_avx512_new(
3073             cospi_m28, cospi_m36, x7[38], x7[57], x8[38], x8[57], __rounding, cos_bit);
3074         x8[39] = x7[39];
3075         x8[40] = x7[40];
3076         btf_32_type0_avx512_new(
3077             cospi_m20, cospi_p44, x7[41], x7[54], x8[41], x8[54], __rounding, cos_bit);
3078         btf_32_type0_avx512_new(
3079             cospi_m44, cospi_m20, x7[42], x7[53], x8[42], x8[53], __rounding, cos_bit);
3080         x8[43] = x7[43];
3081         x8[44] = x7[44];
3082         btf_32_type0_avx512_new(
3083             cospi_m52, cospi_p12, x7[45], x7[50], x8[45], x8[50], __rounding, cos_bit);
3084         btf_32_type0_avx512_new(
3085             cospi_m12, cospi_m52, x7[46], x7[49], x8[46], x8[49], __rounding, cos_bit);
3086         x8[47] = x7[47];
3087         x8[48] = x7[48];
3088         x8[51] = x7[51];
3089         x8[52] = x7[52];
3090         x8[55] = x7[55];
3091         x8[56] = x7[56];
3092         x8[59] = x7[59];
3093         x8[60] = x7[60];
3094         x8[63] = x7[63];
3095 
3096         // stage 9
3097         __m512i x9[64];
3098         x9[0]  = x8[0];
3099         x9[2]  = x8[2];
3100         x9[4]  = x8[4];
3101         x9[6]  = x8[6];
3102         x9[8]  = x8[8];
3103         x9[10] = x8[10];
3104         x9[12] = x8[12];
3105         x9[14] = x8[14];
3106         x9[16] = half_btf_avx512(&cospi_p62, &x8[16], &cospi_p02, &x8[31], &__rounding, cos_bit);
3107         x9[30] = half_btf_avx512(&cospi_p30, &x8[30], &cospi_m34, &x8[17], &__rounding, cos_bit);
3108         x9[18] = half_btf_avx512(&cospi_p46, &x8[18], &cospi_p18, &x8[29], &__rounding, cos_bit);
3109         x9[28] = half_btf_avx512(&cospi_p14, &x8[28], &cospi_m50, &x8[19], &__rounding, cos_bit);
3110         x9[20] = half_btf_avx512(&cospi_p54, &x8[20], &cospi_p10, &x8[27], &__rounding, cos_bit);
3111         x9[26] = half_btf_avx512(&cospi_p22, &x8[26], &cospi_m42, &x8[21], &__rounding, cos_bit);
3112         x9[22] = half_btf_avx512(&cospi_p38, &x8[22], &cospi_p26, &x8[25], &__rounding, cos_bit);
3113         x9[24] = half_btf_avx512(&cospi_p06, &x8[24], &cospi_m58, &x8[23], &__rounding, cos_bit);
3114         x9[32] = _mm512_add_epi32(x8[32], x8[33]);
3115         x9[33] = _mm512_sub_epi32(x8[32], x8[33]);
3116         x9[34] = _mm512_sub_epi32(x8[35], x8[34]);
3117         x9[35] = _mm512_add_epi32(x8[35], x8[34]);
3118         x9[36] = _mm512_add_epi32(x8[36], x8[37]);
3119         x9[37] = _mm512_sub_epi32(x8[36], x8[37]);
3120         x9[38] = _mm512_sub_epi32(x8[39], x8[38]);
3121         x9[39] = _mm512_add_epi32(x8[39], x8[38]);
3122         x9[40] = _mm512_add_epi32(x8[40], x8[41]);
3123         x9[41] = _mm512_sub_epi32(x8[40], x8[41]);
3124         x9[42] = _mm512_sub_epi32(x8[43], x8[42]);
3125         x9[43] = _mm512_add_epi32(x8[43], x8[42]);
3126         x9[44] = _mm512_add_epi32(x8[44], x8[45]);
3127         x9[45] = _mm512_sub_epi32(x8[44], x8[45]);
3128         x9[46] = _mm512_sub_epi32(x8[47], x8[46]);
3129         x9[47] = _mm512_add_epi32(x8[47], x8[46]);
3130         x9[48] = _mm512_add_epi32(x8[48], x8[49]);
3131         x9[49] = _mm512_sub_epi32(x8[48], x8[49]);
3132         x9[50] = _mm512_sub_epi32(x8[51], x8[50]);
3133         x9[51] = _mm512_add_epi32(x8[51], x8[50]);
3134         x9[52] = _mm512_add_epi32(x8[52], x8[53]);
3135         x9[53] = _mm512_sub_epi32(x8[52], x8[53]);
3136         x9[54] = _mm512_sub_epi32(x8[55], x8[54]);
3137         x9[55] = _mm512_add_epi32(x8[55], x8[54]);
3138         x9[56] = _mm512_add_epi32(x8[56], x8[57]);
3139         x9[57] = _mm512_sub_epi32(x8[56], x8[57]);
3140         x9[58] = _mm512_sub_epi32(x8[59], x8[58]);
3141         x9[59] = _mm512_add_epi32(x8[59], x8[58]);
3142         x9[60] = _mm512_add_epi32(x8[60], x8[61]);
3143         x9[61] = _mm512_sub_epi32(x8[60], x8[61]);
3144         x9[62] = _mm512_sub_epi32(x8[63], x8[62]);
3145         x9[63] = _mm512_add_epi32(x8[63], x8[62]);
3146 
3147         // stage 10
3148         __m512i x10[64];
3149         out[0 * stride]  = x9[0];
3150         out[16 * stride] = x9[2];
3151         out[8 * stride]  = x9[4];
3152         out[24 * stride] = x9[6];
3153         out[4 * stride]  = x9[8];
3154         out[20 * stride] = x9[10];
3155         out[12 * stride] = x9[12];
3156         out[28 * stride] = x9[14];
3157         out[2 * stride]  = x9[16];
3158         out[18 * stride] = x9[18];
3159         out[10 * stride] = x9[20];
3160         out[26 * stride] = x9[22];
3161         out[6 * stride]  = x9[24];
3162         out[22 * stride] = x9[26];
3163         out[14 * stride] = x9[28];
3164         out[30 * stride] = x9[30];
3165         x10[32] = half_btf_avx512(&cospi_p63, &x9[32], &cospi_p01, &x9[63], &__rounding, cos_bit);
3166         x10[62] = half_btf_avx512(&cospi_p31, &x9[62], &cospi_m33, &x9[33], &__rounding, cos_bit);
3167         x10[34] = half_btf_avx512(&cospi_p47, &x9[34], &cospi_p17, &x9[61], &__rounding, cos_bit);
3168         x10[60] = half_btf_avx512(&cospi_p15, &x9[60], &cospi_m49, &x9[35], &__rounding, cos_bit);
3169         x10[36] = half_btf_avx512(&cospi_p55, &x9[36], &cospi_p09, &x9[59], &__rounding, cos_bit);
3170         x10[58] = half_btf_avx512(&cospi_p23, &x9[58], &cospi_m41, &x9[37], &__rounding, cos_bit);
3171         x10[38] = half_btf_avx512(&cospi_p39, &x9[38], &cospi_p25, &x9[57], &__rounding, cos_bit);
3172         x10[56] = half_btf_avx512(&cospi_p07, &x9[56], &cospi_m57, &x9[39], &__rounding, cos_bit);
3173         x10[40] = half_btf_avx512(&cospi_p59, &x9[40], &cospi_p05, &x9[55], &__rounding, cos_bit);
3174         x10[54] = half_btf_avx512(&cospi_p27, &x9[54], &cospi_m37, &x9[41], &__rounding, cos_bit);
3175         x10[42] = half_btf_avx512(&cospi_p43, &x9[42], &cospi_p21, &x9[53], &__rounding, cos_bit);
3176         x10[52] = half_btf_avx512(&cospi_p11, &x9[52], &cospi_m53, &x9[43], &__rounding, cos_bit);
3177         x10[44] = half_btf_avx512(&cospi_p51, &x9[44], &cospi_p13, &x9[51], &__rounding, cos_bit);
3178         x10[50] = half_btf_avx512(&cospi_p19, &x9[50], &cospi_m45, &x9[45], &__rounding, cos_bit);
3179         x10[46] = half_btf_avx512(&cospi_p35, &x9[46], &cospi_p29, &x9[49], &__rounding, cos_bit);
3180         x10[48] = half_btf_avx512(&cospi_p03, &x9[48], &cospi_m61, &x9[47], &__rounding, cos_bit);
3181 
3182         // stage 11
3183         out[1 * stride]  = x10[32];
3184         out[3 * stride]  = x10[48];
3185         out[5 * stride]  = x10[40];
3186         out[7 * stride]  = x10[56];
3187         out[9 * stride]  = x10[36];
3188         out[11 * stride] = x10[52];
3189         out[13 * stride] = x10[44];
3190         out[15 * stride] = x10[60];
3191         out[17 * stride] = x10[34];
3192         out[19 * stride] = x10[50];
3193         out[21 * stride] = x10[42];
3194         out[23 * stride] = x10[58];
3195         out[25 * stride] = x10[38];
3196         out[27 * stride] = x10[54];
3197         out[29 * stride] = x10[46];
3198         out[31 * stride] = x10[62];
3199     }
3200 }
3201 
// 64x64 identity transform, N2 variant: each coefficient is scaled by
// 4 * new_sqrt2 with rounding. Only the top-left quadrant of the output is
// needed, so we visit the first half of the rows (total / 2 register slots)
// and, within each row's group of 4 registers, only the first two (left half).
static void fidtx64x64_N2_avx512(const __m512i *input, __m512i *output) {
    const uint8_t shift_bits = 12; // new_sqrt2_bits = 12
    const __m512i scale      = _mm512_set1_epi32(4 * 5793); // 4 * new_sqrt2
    const __m512i round      = _mm512_set1_epi32(1 << (shift_bits - 1));

    const int32_t total = 64 * 4; // 64 rows x 4 registers per row
    for (int32_t i = 0; i < total / 2; i += 4) {
        for (int32_t j = 0; j < 2; j++) {
            __m512i v     = _mm512_mullo_epi32(input[i + j], scale);
            v             = _mm512_add_epi32(v, round);
            output[i + j] = _mm512_srai_epi32(v, shift_bits);
        }
    }
}
3220 
// Forward 32-point DCT, AVX-512, N2 variant: runs the full 9-stage butterfly
// network but only materializes/stores the first 16 of the 32 outputs per
// column (stage 9 writes out[0..15*stride]); callers discard the
// high-frequency half.
//
// input/output : arrays of 16-lane registers laid out with `stride` registers
//                between successive transform rows.
// cos_bit      : fixed-point precision of the cosine constants.
// col_num      : number of 16-bit-lane columns; `col_num >> 4` gives the
//                number of register columns processed by the outer loop.
// stride       : register distance between consecutive rows in input/output.
static void av1_fdct32_new_N2_avx512(const __m512i *input, __m512i *output, const int8_t cos_bit,
                                     const int32_t col_num, const int32_t stride) {
    const int32_t *cospi      = cospi_arr(cos_bit);
    const __m512i  __rounding = _mm512_set1_epi32(1 << (cos_bit - 1));
    const int32_t  columns    = col_num >> 4; // 16 lanes per __m512i register

    // Broadcast the signed cosine constants used by the butterfly stages.
    __m512i cospi_m32 = _mm512_set1_epi32(-cospi[32]);
    __m512i cospi_p32 = _mm512_set1_epi32(cospi[32]);
    __m512i cospi_m16 = _mm512_set1_epi32(-cospi[16]);
    __m512i cospi_p48 = _mm512_set1_epi32(cospi[48]);
    __m512i cospi_m48 = _mm512_set1_epi32(-cospi[48]);
    __m512i cospi_m08 = _mm512_set1_epi32(-cospi[8]);
    __m512i cospi_p56 = _mm512_set1_epi32(cospi[56]);
    __m512i cospi_m56 = _mm512_set1_epi32(-cospi[56]);
    __m512i cospi_m40 = _mm512_set1_epi32(-cospi[40]);
    __m512i cospi_p24 = _mm512_set1_epi32(cospi[24]);
    __m512i cospi_m24 = _mm512_set1_epi32(-cospi[24]);
    __m512i cospi_p16 = _mm512_set1_epi32(cospi[16]);
    __m512i cospi_p08 = _mm512_set1_epi32(cospi[8]);
    __m512i cospi_p04 = _mm512_set1_epi32(cospi[4]);
    __m512i cospi_p60 = _mm512_set1_epi32(cospi[60]);
    __m512i cospi_m36 = _mm512_set1_epi32(-cospi[36]);
    __m512i cospi_p28 = _mm512_set1_epi32(cospi[28]);
    __m512i cospi_p20 = _mm512_set1_epi32(cospi[20]);
    __m512i cospi_p44 = _mm512_set1_epi32(cospi[44]);
    __m512i cospi_m52 = _mm512_set1_epi32(-cospi[52]);
    __m512i cospi_p12 = _mm512_set1_epi32(cospi[12]);
    __m512i cospi_p02 = _mm512_set1_epi32(cospi[2]);
    __m512i cospi_p06 = _mm512_set1_epi32(cospi[6]);
    __m512i cospi_p62 = _mm512_set1_epi32(cospi[62]);
    __m512i cospi_m34 = _mm512_set1_epi32(-cospi[34]);
    __m512i cospi_p30 = _mm512_set1_epi32(cospi[30]);
    __m512i cospi_p18 = _mm512_set1_epi32(cospi[18]);
    __m512i cospi_p46 = _mm512_set1_epi32(cospi[46]);
    __m512i cospi_m50 = _mm512_set1_epi32(-cospi[50]);
    __m512i cospi_p14 = _mm512_set1_epi32(cospi[14]);
    __m512i cospi_p10 = _mm512_set1_epi32(cospi[10]);
    __m512i cospi_p54 = _mm512_set1_epi32(cospi[54]);
    __m512i cospi_m42 = _mm512_set1_epi32(-cospi[42]);
    __m512i cospi_p22 = _mm512_set1_epi32(cospi[22]);
    __m512i cospi_p26 = _mm512_set1_epi32(cospi[26]);
    __m512i cospi_p38 = _mm512_set1_epi32(cospi[38]);
    __m512i cospi_m58 = _mm512_set1_epi32(-cospi[58]);

    // Two scratch banks; stages ping-pong between them (buf1 -> buf0 -> buf1 ...).
    __m512i buf0[32];
    __m512i buf1[32];

    for (int32_t col = 0; col < columns; col++) {
        const __m512i *in  = &input[col];
        __m512i *      out = &output[col];

        // stage 0
        // stage 1: input butterflies, pairing element k with element 31-k.
        buf1[0]  = _mm512_add_epi32(in[0 * stride], in[31 * stride]);
        buf1[31] = _mm512_sub_epi32(in[0 * stride], in[31 * stride]);
        buf1[1]  = _mm512_add_epi32(in[1 * stride], in[30 * stride]);
        buf1[30] = _mm512_sub_epi32(in[1 * stride], in[30 * stride]);
        buf1[2]  = _mm512_add_epi32(in[2 * stride], in[29 * stride]);
        buf1[29] = _mm512_sub_epi32(in[2 * stride], in[29 * stride]);
        buf1[3]  = _mm512_add_epi32(in[3 * stride], in[28 * stride]);
        buf1[28] = _mm512_sub_epi32(in[3 * stride], in[28 * stride]);
        buf1[4]  = _mm512_add_epi32(in[4 * stride], in[27 * stride]);
        buf1[27] = _mm512_sub_epi32(in[4 * stride], in[27 * stride]);
        buf1[5]  = _mm512_add_epi32(in[5 * stride], in[26 * stride]);
        buf1[26] = _mm512_sub_epi32(in[5 * stride], in[26 * stride]);
        buf1[6]  = _mm512_add_epi32(in[6 * stride], in[25 * stride]);
        buf1[25] = _mm512_sub_epi32(in[6 * stride], in[25 * stride]);
        buf1[7]  = _mm512_add_epi32(in[7 * stride], in[24 * stride]);
        buf1[24] = _mm512_sub_epi32(in[7 * stride], in[24 * stride]);
        buf1[8]  = _mm512_add_epi32(in[8 * stride], in[23 * stride]);
        buf1[23] = _mm512_sub_epi32(in[8 * stride], in[23 * stride]);
        buf1[9]  = _mm512_add_epi32(in[9 * stride], in[22 * stride]);
        buf1[22] = _mm512_sub_epi32(in[9 * stride], in[22 * stride]);
        buf1[10] = _mm512_add_epi32(in[10 * stride], in[21 * stride]);
        buf1[21] = _mm512_sub_epi32(in[10 * stride], in[21 * stride]);
        buf1[11] = _mm512_add_epi32(in[11 * stride], in[20 * stride]);
        buf1[20] = _mm512_sub_epi32(in[11 * stride], in[20 * stride]);
        buf1[12] = _mm512_add_epi32(in[12 * stride], in[19 * stride]);
        buf1[19] = _mm512_sub_epi32(in[12 * stride], in[19 * stride]);
        buf1[13] = _mm512_add_epi32(in[13 * stride], in[18 * stride]);
        buf1[18] = _mm512_sub_epi32(in[13 * stride], in[18 * stride]);
        buf1[14] = _mm512_add_epi32(in[14 * stride], in[17 * stride]);
        buf1[17] = _mm512_sub_epi32(in[14 * stride], in[17 * stride]);
        buf1[15] = _mm512_add_epi32(in[15 * stride], in[16 * stride]);
        buf1[16] = _mm512_sub_epi32(in[15 * stride], in[16 * stride]);

        // stage 2: 16-point butterflies on the low half; rotations on 20..27.
        buf0[0]  = _mm512_add_epi32(buf1[0], buf1[15]);
        buf0[15] = _mm512_sub_epi32(buf1[0], buf1[15]);
        buf0[1]  = _mm512_add_epi32(buf1[1], buf1[14]);
        buf0[14] = _mm512_sub_epi32(buf1[1], buf1[14]);
        buf0[2]  = _mm512_add_epi32(buf1[2], buf1[13]);
        buf0[13] = _mm512_sub_epi32(buf1[2], buf1[13]);
        buf0[3]  = _mm512_add_epi32(buf1[3], buf1[12]);
        buf0[12] = _mm512_sub_epi32(buf1[3], buf1[12]);
        buf0[4]  = _mm512_add_epi32(buf1[4], buf1[11]);
        buf0[11] = _mm512_sub_epi32(buf1[4], buf1[11]);
        buf0[5]  = _mm512_add_epi32(buf1[5], buf1[10]);
        buf0[10] = _mm512_sub_epi32(buf1[5], buf1[10]);
        buf0[6]  = _mm512_add_epi32(buf1[6], buf1[9]);
        buf0[9]  = _mm512_sub_epi32(buf1[6], buf1[9]);
        buf0[7]  = _mm512_add_epi32(buf1[7], buf1[8]);
        buf0[8]  = _mm512_sub_epi32(buf1[7], buf1[8]);
        buf0[16] = buf1[16];
        buf0[17] = buf1[17];
        buf0[18] = buf1[18];
        buf0[19] = buf1[19];
        btf_32_type0_avx512_new(
            cospi_m32, cospi_p32, buf1[20], buf1[27], buf0[20], buf0[27], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m32, cospi_p32, buf1[21], buf1[26], buf0[21], buf0[26], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m32, cospi_p32, buf1[22], buf1[25], buf0[22], buf0[25], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m32, cospi_p32, buf1[23], buf1[24], buf0[23], buf0[24], __rounding, cos_bit);
        buf0[28] = buf1[28];
        buf0[29] = buf1[29];
        buf0[30] = buf1[30];
        buf0[31] = buf1[31];

        // stage 3
        buf1[0] = _mm512_add_epi32(buf0[0], buf0[7]);
        buf1[7] = _mm512_sub_epi32(buf0[0], buf0[7]);
        buf1[1] = _mm512_add_epi32(buf0[1], buf0[6]);
        buf1[6] = _mm512_sub_epi32(buf0[1], buf0[6]);
        buf1[2] = _mm512_add_epi32(buf0[2], buf0[5]);
        buf1[5] = _mm512_sub_epi32(buf0[2], buf0[5]);
        buf1[3] = _mm512_add_epi32(buf0[3], buf0[4]);
        buf1[4] = _mm512_sub_epi32(buf0[3], buf0[4]);
        buf1[8] = buf0[8];
        buf1[9] = buf0[9];
        btf_32_type0_avx512_new(
            cospi_m32, cospi_p32, buf0[10], buf0[13], buf1[10], buf1[13], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m32, cospi_p32, buf0[11], buf0[12], buf1[11], buf1[12], __rounding, cos_bit);
        buf1[14] = buf0[14];
        buf1[15] = buf0[15];
        buf1[16] = _mm512_add_epi32(buf0[16], buf0[23]);
        buf1[23] = _mm512_sub_epi32(buf0[16], buf0[23]);
        buf1[17] = _mm512_add_epi32(buf0[17], buf0[22]);
        buf1[22] = _mm512_sub_epi32(buf0[17], buf0[22]);
        buf1[18] = _mm512_add_epi32(buf0[18], buf0[21]);
        buf1[21] = _mm512_sub_epi32(buf0[18], buf0[21]);
        buf1[19] = _mm512_add_epi32(buf0[19], buf0[20]);
        buf1[20] = _mm512_sub_epi32(buf0[19], buf0[20]);
        buf1[24] = _mm512_sub_epi32(buf0[31], buf0[24]);
        buf1[31] = _mm512_add_epi32(buf0[31], buf0[24]);
        buf1[25] = _mm512_sub_epi32(buf0[30], buf0[25]);
        buf1[30] = _mm512_add_epi32(buf0[30], buf0[25]);
        buf1[26] = _mm512_sub_epi32(buf0[29], buf0[26]);
        buf1[29] = _mm512_add_epi32(buf0[29], buf0[26]);
        buf1[27] = _mm512_sub_epi32(buf0[28], buf0[27]);
        buf1[28] = _mm512_add_epi32(buf0[28], buf0[27]);

        // stage 4
        buf0[0] = _mm512_add_epi32(buf1[0], buf1[3]);
        buf0[3] = _mm512_sub_epi32(buf1[0], buf1[3]);
        buf0[1] = _mm512_add_epi32(buf1[1], buf1[2]);
        buf0[2] = _mm512_sub_epi32(buf1[1], buf1[2]);
        buf0[4] = buf1[4];
        btf_32_type0_avx512_new(
            cospi_m32, cospi_p32, buf1[5], buf1[6], buf0[5], buf0[6], __rounding, cos_bit);
        buf0[7]  = buf1[7];
        buf0[8]  = _mm512_add_epi32(buf1[8], buf1[11]);
        buf0[11] = _mm512_sub_epi32(buf1[8], buf1[11]);
        buf0[9]  = _mm512_add_epi32(buf1[9], buf1[10]);
        buf0[10] = _mm512_sub_epi32(buf1[9], buf1[10]);
        buf0[12] = _mm512_sub_epi32(buf1[15], buf1[12]);
        buf0[15] = _mm512_add_epi32(buf1[15], buf1[12]);
        buf0[13] = _mm512_sub_epi32(buf1[14], buf1[13]);
        buf0[14] = _mm512_add_epi32(buf1[14], buf1[13]);
        buf0[16] = buf1[16];
        buf0[17] = buf1[17];
        btf_32_type0_avx512_new(
            cospi_m16, cospi_p48, buf1[18], buf1[29], buf0[18], buf0[29], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m16, cospi_p48, buf1[19], buf1[28], buf0[19], buf0[28], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m48, cospi_m16, buf1[20], buf1[27], buf0[20], buf0[27], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m48, cospi_m16, buf1[21], buf1[26], buf0[21], buf0[26], __rounding, cos_bit);
        buf0[22] = buf1[22];
        buf0[23] = buf1[23];
        buf0[24] = buf1[24];
        buf0[25] = buf1[25];
        buf0[30] = buf1[30];
        buf0[31] = buf1[31];

        // stage 5: from here on, odd-index results that would only feed the
        // discarded high outputs are skipped (N2 pruning).
        buf1[0] = half_btf_avx512(&cospi_p32, &buf0[0], &cospi_p32, &buf0[1], &__rounding, cos_bit);
        buf1[2] = half_btf_avx512(&cospi_p48, &buf0[2], &cospi_p16, &buf0[3], &__rounding, cos_bit);
        buf1[4] = _mm512_add_epi32(buf0[4], buf0[5]);
        buf1[5] = _mm512_sub_epi32(buf0[4], buf0[5]);
        buf1[6] = _mm512_sub_epi32(buf0[7], buf0[6]);
        buf1[7] = _mm512_add_epi32(buf0[7], buf0[6]);
        buf1[8] = buf0[8];
        btf_32_type0_avx512_new(
            cospi_m16, cospi_p48, buf0[9], buf0[14], buf1[9], buf1[14], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m48, cospi_m16, buf0[10], buf0[13], buf1[10], buf1[13], __rounding, cos_bit);
        buf1[11] = buf0[11];
        buf1[12] = buf0[12];
        buf1[15] = buf0[15];
        buf1[16] = _mm512_add_epi32(buf0[16], buf0[19]);
        buf1[19] = _mm512_sub_epi32(buf0[16], buf0[19]);
        buf1[17] = _mm512_add_epi32(buf0[17], buf0[18]);
        buf1[18] = _mm512_sub_epi32(buf0[17], buf0[18]);
        buf1[20] = _mm512_sub_epi32(buf0[23], buf0[20]);
        buf1[23] = _mm512_add_epi32(buf0[23], buf0[20]);
        buf1[21] = _mm512_sub_epi32(buf0[22], buf0[21]);
        buf1[22] = _mm512_add_epi32(buf0[22], buf0[21]);
        buf1[24] = _mm512_add_epi32(buf0[24], buf0[27]);
        buf1[27] = _mm512_sub_epi32(buf0[24], buf0[27]);
        buf1[25] = _mm512_add_epi32(buf0[25], buf0[26]);
        buf1[26] = _mm512_sub_epi32(buf0[25], buf0[26]);
        buf1[28] = _mm512_sub_epi32(buf0[31], buf0[28]);
        buf1[31] = _mm512_add_epi32(buf0[31], buf0[28]);
        buf1[29] = _mm512_sub_epi32(buf0[30], buf0[29]);
        buf1[30] = _mm512_add_epi32(buf0[30], buf0[29]);

        // stage 6
        buf0[0] = buf1[0];
        buf0[2] = buf1[2];
        buf0[4] = half_btf_avx512(&cospi_p56, &buf1[4], &cospi_p08, &buf1[7], &__rounding, cos_bit);
        buf0[6] = half_btf_avx512(&cospi_p24, &buf1[6], &cospi_m40, &buf1[5], &__rounding, cos_bit);
        buf0[8] = _mm512_add_epi32(buf1[8], buf1[9]);
        buf0[9] = _mm512_sub_epi32(buf1[8], buf1[9]);
        buf0[10] = _mm512_sub_epi32(buf1[11], buf1[10]);
        buf0[11] = _mm512_add_epi32(buf1[11], buf1[10]);
        buf0[12] = _mm512_add_epi32(buf1[12], buf1[13]);
        buf0[13] = _mm512_sub_epi32(buf1[12], buf1[13]);
        buf0[14] = _mm512_sub_epi32(buf1[15], buf1[14]);
        buf0[15] = _mm512_add_epi32(buf1[15], buf1[14]);
        buf0[16] = buf1[16];
        btf_32_type0_avx512_new(
            cospi_m08, cospi_p56, buf1[17], buf1[30], buf0[17], buf0[30], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m56, cospi_m08, buf1[18], buf1[29], buf0[18], buf0[29], __rounding, cos_bit);
        buf0[19] = buf1[19];
        buf0[20] = buf1[20];
        btf_32_type0_avx512_new(
            cospi_m40, cospi_p24, buf1[21], buf1[26], buf0[21], buf0[26], __rounding, cos_bit);
        btf_32_type0_avx512_new(
            cospi_m24, cospi_m40, buf1[22], buf1[25], buf0[22], buf0[25], __rounding, cos_bit);
        buf0[23] = buf1[23];
        buf0[24] = buf1[24];
        buf0[27] = buf1[27];
        buf0[28] = buf1[28];
        buf0[31] = buf1[31];

        // stage 7
        buf1[0] = buf0[0];
        buf1[2] = buf0[2];
        buf1[4] = buf0[4];
        buf1[6] = buf0[6];
        buf1[8] = half_btf_avx512(
            &cospi_p60, &buf0[8], &cospi_p04, &buf0[15], &__rounding, cos_bit);
        buf1[14] = half_btf_avx512(
            &cospi_p28, &buf0[14], &cospi_m36, &buf0[9], &__rounding, cos_bit);
        buf1[10] = half_btf_avx512(
            &cospi_p44, &buf0[10], &cospi_p20, &buf0[13], &__rounding, cos_bit);
        buf1[12] = half_btf_avx512(
            &cospi_p12, &buf0[12], &cospi_m52, &buf0[11], &__rounding, cos_bit);
        buf1[16] = _mm512_add_epi32(buf0[16], buf0[17]);
        buf1[17] = _mm512_sub_epi32(buf0[16], buf0[17]);
        buf1[18] = _mm512_sub_epi32(buf0[19], buf0[18]);
        buf1[19] = _mm512_add_epi32(buf0[19], buf0[18]);
        buf1[20] = _mm512_add_epi32(buf0[20], buf0[21]);
        buf1[21] = _mm512_sub_epi32(buf0[20], buf0[21]);
        buf1[22] = _mm512_sub_epi32(buf0[23], buf0[22]);
        buf1[23] = _mm512_add_epi32(buf0[23], buf0[22]);
        buf1[24] = _mm512_add_epi32(buf0[24], buf0[25]);
        buf1[25] = _mm512_sub_epi32(buf0[24], buf0[25]);
        buf1[26] = _mm512_sub_epi32(buf0[27], buf0[26]);
        buf1[27] = _mm512_add_epi32(buf0[27], buf0[26]);
        buf1[28] = _mm512_add_epi32(buf0[28], buf0[29]);
        buf1[29] = _mm512_sub_epi32(buf0[28], buf0[29]);
        buf1[30] = _mm512_sub_epi32(buf0[31], buf0[30]);
        buf1[31] = _mm512_add_epi32(buf0[31], buf0[30]);

        // stage 8
        buf0[0]  = buf1[0];
        buf0[2]  = buf1[2];
        buf0[4]  = buf1[4];
        buf0[6]  = buf1[6];
        buf0[8]  = buf1[8];
        buf0[10] = buf1[10];
        buf0[12] = buf1[12];
        buf0[14] = buf1[14];
        buf0[16] = half_btf_avx512(
            &cospi_p62, &buf1[16], &cospi_p02, &buf1[31], &__rounding, cos_bit);
        buf0[30] = half_btf_avx512(
            &cospi_p30, &buf1[30], &cospi_m34, &buf1[17], &__rounding, cos_bit);
        buf0[18] = half_btf_avx512(
            &cospi_p46, &buf1[18], &cospi_p18, &buf1[29], &__rounding, cos_bit);
        buf0[28] = half_btf_avx512(
            &cospi_p14, &buf1[28], &cospi_m50, &buf1[19], &__rounding, cos_bit);
        buf0[20] = half_btf_avx512(
            &cospi_p54, &buf1[20], &cospi_p10, &buf1[27], &__rounding, cos_bit);
        buf0[26] = half_btf_avx512(
            &cospi_p22, &buf1[26], &cospi_m42, &buf1[21], &__rounding, cos_bit);
        buf0[22] = half_btf_avx512(
            &cospi_p38, &buf1[22], &cospi_p26, &buf1[25], &__rounding, cos_bit);
        buf0[24] = half_btf_avx512(
            &cospi_p06, &buf1[24], &cospi_m58, &buf1[23], &__rounding, cos_bit);

        // stage 9: bit-reversal style store of the 16 retained coefficients.
        out[0 * stride]  = buf0[0];
        out[1 * stride]  = buf0[16];
        out[2 * stride]  = buf0[8];
        out[3 * stride]  = buf0[24];
        out[4 * stride]  = buf0[4];
        out[5 * stride]  = buf0[20];
        out[6 * stride]  = buf0[12];
        out[7 * stride]  = buf0[28];
        out[8 * stride]  = buf0[2];
        out[9 * stride]  = buf0[18];
        out[10 * stride] = buf0[10];
        out[11 * stride] = buf0[26];
        out[12 * stride] = buf0[6];
        out[13 * stride] = buf0[22];
        out[14 * stride] = buf0[14];
        out[15 * stride] = buf0[30];
    }
}
3546 
av1_idtx32_wxh_N2_avx512(const __m512i * input,__m512i * output,int32_t col_num,int32_t row_num)3547 static AOM_FORCE_INLINE void av1_idtx32_wxh_N2_avx512(const __m512i *input, __m512i *output,
3548                                                       int32_t col_num, int32_t row_num) {
3549     for (int32_t i = 0; i < row_num; i++) {
3550         for (int32_t j = 0; j < col_num / 2; j++) {
3551             output[i * col_num + j] = _mm512_slli_epi32(input[i * col_num + j], (uint8_t)2);
3552         }
3553     }
3554 }
3555 
/*
 * 2D forward transform of a 32x32 block, N2 fast path: only the top-left
 * 16x16 quarter of coefficients is computed; the rest of the output buffer
 * is explicitly zeroed by clear_buffer_wxh_N2_avx512().
 *
 * input   - 16-bit residual samples, row pitch of 'stride' samples.
 * output  - 32x32 int32 coefficient buffer, written in full.
 * tx_type - only IDTX, DCT_DCT, V_DCT and H_DCT are supported here;
 *           anything else asserts.
 * bd      - bit depth (unused).
 */
void av1_fwd_txfm2d_32x32_N2_avx512(int16_t *input, int32_t *output, uint32_t stride,
                                    TxType tx_type, uint8_t bd) {
    DECLARE_ALIGNED(64, int32_t, txfm_buf[1024]); /* 32x32 intermediate buffer */
    Txfm2dFlipCfg cfg;
    av1_transform_config(tx_type, TX_32X32, &cfg);
    (void)bd;
    const int32_t txfm_size       = 32;
    const int8_t *shift           = cfg.shift; /* per-stage rounding shifts */
    const int8_t  cos_bit_col     = cfg.cos_bit_col;
    const int8_t  cos_bit_row     = cfg.cos_bit_row;
    __m512i *     buf_512         = (__m512i *)txfm_buf;
    __m512i *     out_512         = (__m512i *)output;
    int32_t       num_per_512     = 16; /* int32 lanes per 512-bit vector */
    int32_t       txfm2d_size_512 = txfm_size * txfm_size / num_per_512;
    int32_t       col_num         = txfm_size / num_per_512;

    switch (tx_type) {
    case IDTX:
        /* Identity both ways: only the top-left quarter contributes, so load
         * 16 rows and process half the columns of each. */
        load_buffer_16xh_in_32x32_avx512(input, buf_512, stride, txfm_size / 2);
        av1_round_shift_array_wxh_N2_avx512(buf_512, out_512, col_num, txfm_size / 2, -shift[0]);
        av1_idtx32_wxh_N2_avx512(out_512, buf_512, col_num, txfm_size / 2);
        av1_round_shift_array_wxh_N2_avx512(buf_512, out_512, col_num, txfm_size / 2, -shift[1]);
        av1_idtx32_wxh_N2_avx512(out_512, buf_512, col_num, txfm_size / 2);
        av1_round_shift_array_wxh_N2_avx512(buf_512, out_512, col_num, txfm_size / 2, -shift[2]);
        clear_buffer_wxh_N2_avx512(out_512, 32, 32);
        break;
    case DCT_DCT:
        /* Full column DCT, then a half-width row DCT on the transposed
         * half output; final transpose keeps only the 16x16 quadrant. */
        load_buffer_32x32_avx512(input, buf_512, stride);
        av1_round_shift_array_avx512(buf_512, out_512, txfm2d_size_512, -shift[0]);
        av1_fdct32_new_N2_avx512(out_512, buf_512, cos_bit_col, txfm_size, col_num);
        av1_round_shift_array_avx512(buf_512, out_512, txfm2d_size_512 / 2, -shift[1]);
        transpose_16nx16n_N2_half_avx512(txfm_size, out_512, buf_512);
        av1_fdct32_new_N2_avx512(buf_512, out_512, cos_bit_row, txfm_size / 2, col_num);
        av1_round_shift_array_wxh_N2_avx512(out_512, buf_512, col_num, txfm_size / 2, -shift[2]);
        transpose_16nx16n_N2_quad_avx512(txfm_size, buf_512, out_512);
        clear_buffer_wxh_N2_avx512(out_512, 32, 32);
        break;
    case V_DCT:
        /* DCT down the columns (half of them), identity across the rows. */
        load_buffer_16xh_in_32x32_avx512(input, buf_512, stride, txfm_size);
        av1_round_shift_array_avx512(buf_512, out_512, txfm2d_size_512, -shift[0]);
        av1_fdct32_new_N2_avx512(out_512, buf_512, cos_bit_col, txfm_size / 2, col_num);
        av1_round_shift_array_wxh_N2_avx512(buf_512, out_512, col_num, txfm_size / 2, -shift[1]);
        av1_idtx32_wxh_N2_avx512(out_512, buf_512, col_num, txfm_size / 2);
        av1_round_shift_array_wxh_N2_avx512(buf_512, out_512, col_num, txfm_size / 2, -shift[2]);
        clear_buffer_wxh_N2_avx512(out_512, 32, 32);
        break;
    case H_DCT:
        /* Identity down the columns, DCT across the rows (after transpose). */
        load_buffer_32xh_in_32x32_avx512(input, buf_512, stride, txfm_size / 2);
        av1_round_shift_array_avx512(buf_512, out_512, txfm2d_size_512 / 2, -shift[0]);
        av1_idtx32_new_avx512(out_512, buf_512, cos_bit_col, 1);
        av1_round_shift_array_avx512(buf_512, out_512, txfm2d_size_512 / 2, -shift[1]);
        transpose_16nx16n_N2_half_avx512(txfm_size, out_512, buf_512);
        av1_fdct32_new_N2_avx512(buf_512, out_512, cos_bit_row, txfm_size / 2, col_num);
        av1_round_shift_array_wxh_N2_avx512(out_512, buf_512, col_num, txfm_size / 2, -shift[2]);
        transpose_16nx16n_N2_quad_avx512(txfm_size, buf_512, out_512);
        clear_buffer_wxh_N2_avx512(out_512, 32, 32);
        break;
    default: assert(0);
    }
}
3616 
/*
 * 2D forward transform of a 64x64 block, N2 fast path: only the top-left
 * 32x32 quarter of coefficients is computed; everything else in the output
 * buffer is zeroed by clear_buffer_wxh_N2_avx512().
 *
 * input   - 16-bit residual samples, row pitch of 'stride' samples.
 * output  - 64x64 int32 coefficient buffer, written in full.
 * tx_type - only IDTX and DCT_DCT are supported; anything else asserts.
 * bd      - bit depth (unused).
 */
void av1_fwd_txfm2d_64x64_N2_avx512(int16_t *input, int32_t *output, uint32_t stride,
                                    TxType tx_type, uint8_t bd) {
    (void)bd;
    __m512i       in[256]; /* 64x64 int32 intermediate (64 rows x 4 vectors) */
    __m512i *     out         = (__m512i *)output;
    const int32_t txw_idx     = tx_size_wide_log2[TX_64X64] - tx_size_wide_log2[0];
    const int32_t txh_idx     = tx_size_high_log2[TX_64X64] - tx_size_high_log2[0];
    const int8_t *shift       = fwd_txfm_shift_ls[TX_64X64]; /* per-stage shifts */
    const int32_t txfm_size   = 64;
    const int32_t num_per_512 = 16; /* int32 lanes per 512-bit vector */
    int32_t       col_num     = txfm_size / num_per_512;

    switch (tx_type) {
    case IDTX:
        /* Identity both ways: only the 32x32 quarter is loaded/processed.
         * Note there is no shift[0] stage; fidtx64x64_N2_avx512 is applied
         * directly to the loaded samples. */
        load_buffer_32x32_in_64x64_avx512(input, stride, out);
        fidtx64x64_N2_avx512(out, in);
        av1_round_shift_array_wxh_N2_avx512(in, out, col_num, txfm_size / 2, -shift[1]);
        /*row wise transform*/
        fidtx64x64_N2_avx512(out, in);
        av1_round_shift_array_wxh_N2_avx512(in, out, col_num, txfm_size / 2, -shift[2]);
        clear_buffer_wxh_N2_avx512(out, 64, 64);
        break;
    case DCT_DCT:
        /* Full column DCT, then a half-height row DCT on the transposed
         * half output; the final quad transpose keeps the 32x32 quadrant. */
        load_buffer_64x64_avx512(input, stride, out);
        av1_fdct64_new_N2_avx512(out, in, fwd_cos_bit_col[txw_idx][txh_idx], txfm_size, col_num);
        av1_round_shift_array_avx512(in, out, 256 / 2, -shift[1]);
        transpose_16nx16n_N2_half_avx512(64, out, in);

        /*row wise transform*/
        av1_fdct64_new_N2_avx512(
            in, out, fwd_cos_bit_row[txw_idx][txh_idx], txfm_size / 2, col_num);
        av1_round_shift_array_wxh_N2_avx512(out, in, col_num, txfm_size / 2, -shift[2]);
        transpose_16nx16n_N2_quad_avx512(64, in, out);
        clear_buffer_wxh_N2_avx512(out, 64, 64);
        break;
    default: assert(0);
    }
}
3655 
/*
 * 2D forward transform of a rectangular 32x64 block (32 wide, 64 high),
 * N2 fast path: only the top-left 16x32 quarter of coefficients is kept;
 * the rest of the buffer is zeroed. tx_type is ignored (DCT_DCT only for
 * this size).
 *
 * input  - 16-bit residual samples, row pitch of 'stride' samples.
 * output - 32x64 int32 coefficient buffer, written in full.
 * bd     - bit depth (unused).
 */
void av1_fwd_txfm2d_32x64_N2_avx512(int16_t *input, int32_t *output, uint32_t stride,
                                    TxType tx_type, uint8_t bd) {
    (void)tx_type;
    __m512i       in[128]; /* 64 rows x 2 vectors of int32 intermediates */
    __m512i *     outcoef512    = (__m512i *)output;
    const int8_t *shift         = fwd_txfm_shift_ls[TX_32X64];
    const int32_t txw_idx       = get_txw_idx(TX_32X64);
    const int32_t txh_idx       = get_txh_idx(TX_32X64);
    const int32_t txfm_size_col = tx_size_wide[TX_32X64];
    const int32_t txfm_size_row = tx_size_high[TX_32X64];
    const int8_t  bitcol        = fwd_cos_bit_col[txw_idx][txh_idx];
    const int8_t  bitrow        = fwd_cos_bit_row[txw_idx][txh_idx];
    const int32_t num_row       = txfm_size_row >> 4; /* 512-bit vectors per column pass */
    const int32_t num_col       = txfm_size_col >> 4; /* 512-bit vectors per row pass */

    // column transform: 64-point DCT on all 32 columns
    load_buffer_32x16n(input, in, stride, 0, 0, shift[0], txfm_size_row);
    av1_fdct64_new_N2_avx512(in, in, bitcol, txfm_size_col, num_col);
    for (int32_t i = 0; i < 4; i++) { col_txfm_16x16_rounding_avx512((in + i * 16), -shift[1]); }
    transpose_16nx16m_N2_half_avx512(in, outcoef512, txfm_size_col, txfm_size_row);

    // row transform: 32-point DCT on the top half only, then rectangular
    // rescale by 5793 (NewSqrt2, ~sqrt(2) in Q12 — standard AV1 rect scale).
    av1_fdct32_new_N2_avx512(outcoef512, in, bitrow, txfm_size_row / 2, num_row);
    transpose_16nx16m_N2_quad_avx512(in, outcoef512, txfm_size_row, txfm_size_col);
    av1_round_shift_rect_array_wxh_N2_avx512(
        outcoef512, outcoef512, num_col, txfm_size_row / 2, -shift[2], 5793);
    clear_buffer_wxh_N2_avx512(outcoef512, 32, 64);
    (void)bd;
}
3685 
/*
 * 2D forward transform of a rectangular 64x32 block (64 wide, 32 high),
 * N2 fast path: only the top-left 32x16 quarter of coefficients is kept;
 * the rest of the buffer is zeroed. tx_type is ignored (DCT_DCT only for
 * this size).
 *
 * input  - 16-bit residual samples, row pitch of 'stride' samples.
 * output - 64x32 int32 coefficient buffer, written in full.
 * bd     - bit depth (unused).
 */
void av1_fwd_txfm2d_64x32_N2_avx512(int16_t *input, int32_t *output, uint32_t stride,
                                    TxType tx_type, uint8_t bd) {
    (void)tx_type;
    __m512i       in[128]; /* 32 rows x 4 vectors of int32 intermediates */
    __m512i *     outcoef512    = (__m512i *)output;
    const int8_t *shift         = fwd_txfm_shift_ls[TX_64X32];
    const int32_t txw_idx       = get_txw_idx(TX_64X32);
    const int32_t txh_idx       = get_txh_idx(TX_64X32);
    const int32_t txfm_size_col = tx_size_wide[TX_64X32];
    const int32_t txfm_size_row = tx_size_high[TX_64X32];
    const int8_t  bitcol        = fwd_cos_bit_col[txw_idx][txh_idx];
    const int8_t  bitrow        = fwd_cos_bit_row[txw_idx][txh_idx];
    const int32_t num_row       = txfm_size_row >> 4; /* 512-bit vectors per column pass */
    const int32_t num_col       = txfm_size_col >> 4; /* 512-bit vectors per row pass */

    // column transform: each 64-wide row is loaded as two 32-sample halves
    // (left at vector offset 0, right at offset 2), then 32-point DCT.
    for (int32_t i = 0; i < 32; i++) {
        load_buffer_32_avx512(input + 0 + i * stride, in + 0 + i * 4, 16, 0, 0, shift[0]);
        load_buffer_32_avx512(input + 32 + i * stride, in + 2 + i * 4, 16, 0, 0, shift[0]);
    }
    av1_fdct32_new_N2_avx512(in, in, bitcol, txfm_size_col, num_col);
    for (int32_t i = 0; i < 4; i++) { col_txfm_16x16_rounding_avx512((in + i * 16), -shift[1]); }
    transpose_16nx16m_N2_half_avx512(in, outcoef512, txfm_size_col, txfm_size_row);

    // row transform: 64-point DCT on the top half only, then rectangular
    // rescale by 5793 (NewSqrt2, ~sqrt(2) in Q12 — standard AV1 rect scale).
    av1_fdct64_new_N2_avx512(outcoef512, in, bitrow, txfm_size_row / 2, num_row);
    transpose_16nx16m_N2_quad_avx512(in, outcoef512, txfm_size_row, txfm_size_col);
    av1_round_shift_rect_array_wxh_N2_avx512(
        outcoef512, outcoef512, num_col, txfm_size_row / 2, -shift[2], 5793);
    clear_buffer_wxh_N2_avx512(outcoef512, 64, 32);
    (void)bd;
}
3718 /******************************END*N2***************************************/
3719 
3720 /********************************N4****************************************/
3721 
clear_buffer_64x64_N4_avx512(__m512i * buff)3722 static AOM_FORCE_INLINE void clear_buffer_64x64_N4_avx512(__m512i *buff) {
3723     const __m512i zero512 = _mm512_setzero_si512();
3724     int32_t       i;
3725 
3726     for (i = 0; i < 16; i++) {
3727         buff[i * 4 + 1] = zero512;
3728         buff[i * 4 + 2] = zero512;
3729         buff[i * 4 + 3] = zero512;
3730     }
3731 
3732     for (i = 16; i < 64; i++) {
3733         buff[i * 4 + 0] = zero512;
3734         buff[i * 4 + 1] = zero512;
3735         buff[i * 4 + 2] = zero512;
3736         buff[i * 4 + 3] = zero512;
3737     }
3738 }
3739 
av1_round_shift_array_64x64_N4_avx512(__m512i * input,__m512i * output,const int8_t bit)3740 static AOM_FORCE_INLINE void av1_round_shift_array_64x64_N4_avx512(__m512i *input, __m512i *output,
3741                                                                    const int8_t bit) {
3742     if (bit > 0) {
3743         __m512i round = _mm512_set1_epi32(1 << (bit - 1));
3744         int32_t i;
3745         for (i = 0; i < 16; i++) {
3746             output[i * 4] = _mm512_srai_epi32(_mm512_add_epi32(input[i * 4], round), (uint8_t)bit);
3747         }
3748     } else {
3749         int32_t i;
3750         for (i = 0; i < 16; i++) {
3751             output[i * 4] = _mm512_slli_epi32(input[i * 4], (uint8_t)(-bit));
3752         }
3753     }
3754 }
3755 
load_buffer_16x16_in_64x64_avx512(const int16_t * input,int32_t stride,__m512i * output)3756 static AOM_FORCE_INLINE void load_buffer_16x16_in_64x64_avx512(const int16_t *input, int32_t stride,
3757                                                                __m512i *output) {
3758     __m256i x0;
3759     __m512i v0;
3760     int32_t i;
3761 
3762     for (i = 0; i < 16; ++i) {
3763         x0 = _mm256_loadu_si256((const __m256i *)(input));
3764         v0 = _mm512_cvtepi16_epi32(x0);
3765         _mm512_storeu_si512(output, v0);
3766 
3767         input += stride;
3768         output += 4;
3769     }
3770 }
3771 
av1_fdct64_new_N4_avx512(const __m512i * input,__m512i * output,const int8_t cos_bit,const int32_t col_num,const int32_t stride)3772 static void av1_fdct64_new_N4_avx512(const __m512i *input, __m512i *output, const int8_t cos_bit,
3773                                      const int32_t col_num, const int32_t stride) {
3774     const int32_t *cospi      = cospi_arr(cos_bit);
3775     const __m512i  __rounding = _mm512_set1_epi32(1 << (cos_bit - 1));
3776     const int32_t  columns    = col_num >> 4;
3777 
3778     __m512i cospi_m32 = _mm512_set1_epi32(-cospi[32]);
3779     __m512i cospi_p32 = _mm512_set1_epi32(cospi[32]);
3780     __m512i cospi_m16 = _mm512_set1_epi32(-cospi[16]);
3781     __m512i cospi_p48 = _mm512_set1_epi32(cospi[48]);
3782     __m512i cospi_m48 = _mm512_set1_epi32(-cospi[48]);
3783     __m512i cospi_m08 = _mm512_set1_epi32(-cospi[8]);
3784     __m512i cospi_p56 = _mm512_set1_epi32(cospi[56]);
3785     __m512i cospi_m56 = _mm512_set1_epi32(-cospi[56]);
3786     __m512i cospi_m40 = _mm512_set1_epi32(-cospi[40]);
3787     __m512i cospi_p24 = _mm512_set1_epi32(cospi[24]);
3788     __m512i cospi_m24 = _mm512_set1_epi32(-cospi[24]);
3789     __m512i cospi_p08 = _mm512_set1_epi32(cospi[8]);
3790     __m512i cospi_p60 = _mm512_set1_epi32(cospi[60]);
3791     __m512i cospi_p04 = _mm512_set1_epi32(cospi[4]);
3792     __m512i cospi_p28 = _mm512_set1_epi32(cospi[28]);
3793     __m512i cospi_p44 = _mm512_set1_epi32(cospi[44]);
3794     __m512i cospi_p12 = _mm512_set1_epi32(cospi[12]);
3795     __m512i cospi_m04 = _mm512_set1_epi32(-cospi[4]);
3796     __m512i cospi_m60 = _mm512_set1_epi32(-cospi[60]);
3797     __m512i cospi_m36 = _mm512_set1_epi32(-cospi[36]);
3798     __m512i cospi_m28 = _mm512_set1_epi32(-cospi[28]);
3799     __m512i cospi_m20 = _mm512_set1_epi32(-cospi[20]);
3800     __m512i cospi_m44 = _mm512_set1_epi32(-cospi[44]);
3801     __m512i cospi_m52 = _mm512_set1_epi32(-cospi[52]);
3802     __m512i cospi_m12 = _mm512_set1_epi32(-cospi[12]);
3803     __m512i cospi_p62 = _mm512_set1_epi32(cospi[62]);
3804     __m512i cospi_p02 = _mm512_set1_epi32(cospi[2]);
3805     __m512i cospi_p14 = _mm512_set1_epi32(cospi[14]);
3806     __m512i cospi_m50 = _mm512_set1_epi32(-cospi[50]);
3807     __m512i cospi_p54 = _mm512_set1_epi32(cospi[54]);
3808     __m512i cospi_p10 = _mm512_set1_epi32(cospi[10]);
3809     __m512i cospi_p06 = _mm512_set1_epi32(cospi[6]);
3810     __m512i cospi_m58 = _mm512_set1_epi32(-cospi[58]);
3811     __m512i cospi_p63 = _mm512_set1_epi32(cospi[63]);
3812     __m512i cospi_p01 = _mm512_set1_epi32(cospi[1]);
3813     __m512i cospi_p15 = _mm512_set1_epi32(cospi[15]);
3814     __m512i cospi_m49 = _mm512_set1_epi32(-cospi[49]);
3815     __m512i cospi_p55 = _mm512_set1_epi32(cospi[55]);
3816     __m512i cospi_p09 = _mm512_set1_epi32(cospi[9]);
3817     __m512i cospi_p07 = _mm512_set1_epi32(cospi[7]);
3818     __m512i cospi_m57 = _mm512_set1_epi32(-cospi[57]);
3819     __m512i cospi_p59 = _mm512_set1_epi32(cospi[59]);
3820     __m512i cospi_p05 = _mm512_set1_epi32(cospi[5]);
3821     __m512i cospi_p11 = _mm512_set1_epi32(cospi[11]);
3822     __m512i cospi_m53 = _mm512_set1_epi32(-cospi[53]);
3823     __m512i cospi_p51 = _mm512_set1_epi32(cospi[51]);
3824     __m512i cospi_p13 = _mm512_set1_epi32(cospi[13]);
3825     __m512i cospi_p03 = _mm512_set1_epi32(cospi[3]);
3826     __m512i cospi_m61 = _mm512_set1_epi32(-cospi[61]);
3827 
3828     for (int32_t col = 0; col < columns; col++) {
3829         const __m512i *in  = &input[col];
3830         __m512i *      out = &output[col];
3831 
3832         // stage 1
3833         __m512i x1[64];
3834         x1[0]  = _mm512_add_epi32(in[0 * stride], in[63 * stride]);
3835         x1[63] = _mm512_sub_epi32(in[0 * stride], in[63 * stride]);
3836         x1[1]  = _mm512_add_epi32(in[1 * stride], in[62 * stride]);
3837         x1[62] = _mm512_sub_epi32(in[1 * stride], in[62 * stride]);
3838         x1[2]  = _mm512_add_epi32(in[2 * stride], in[61 * stride]);
3839         x1[61] = _mm512_sub_epi32(in[2 * stride], in[61 * stride]);
3840         x1[3]  = _mm512_add_epi32(in[3 * stride], in[60 * stride]);
3841         x1[60] = _mm512_sub_epi32(in[3 * stride], in[60 * stride]);
3842         x1[4]  = _mm512_add_epi32(in[4 * stride], in[59 * stride]);
3843         x1[59] = _mm512_sub_epi32(in[4 * stride], in[59 * stride]);
3844         x1[5]  = _mm512_add_epi32(in[5 * stride], in[58 * stride]);
3845         x1[58] = _mm512_sub_epi32(in[5 * stride], in[58 * stride]);
3846         x1[6]  = _mm512_add_epi32(in[6 * stride], in[57 * stride]);
3847         x1[57] = _mm512_sub_epi32(in[6 * stride], in[57 * stride]);
3848         x1[7]  = _mm512_add_epi32(in[7 * stride], in[56 * stride]);
3849         x1[56] = _mm512_sub_epi32(in[7 * stride], in[56 * stride]);
3850         x1[8]  = _mm512_add_epi32(in[8 * stride], in[55 * stride]);
3851         x1[55] = _mm512_sub_epi32(in[8 * stride], in[55 * stride]);
3852         x1[9]  = _mm512_add_epi32(in[9 * stride], in[54 * stride]);
3853         x1[54] = _mm512_sub_epi32(in[9 * stride], in[54 * stride]);
3854         x1[10] = _mm512_add_epi32(in[10 * stride], in[53 * stride]);
3855         x1[53] = _mm512_sub_epi32(in[10 * stride], in[53 * stride]);
3856         x1[11] = _mm512_add_epi32(in[11 * stride], in[52 * stride]);
3857         x1[52] = _mm512_sub_epi32(in[11 * stride], in[52 * stride]);
3858         x1[12] = _mm512_add_epi32(in[12 * stride], in[51 * stride]);
3859         x1[51] = _mm512_sub_epi32(in[12 * stride], in[51 * stride]);
3860         x1[13] = _mm512_add_epi32(in[13 * stride], in[50 * stride]);
3861         x1[50] = _mm512_sub_epi32(in[13 * stride], in[50 * stride]);
3862         x1[14] = _mm512_add_epi32(in[14 * stride], in[49 * stride]);
3863         x1[49] = _mm512_sub_epi32(in[14 * stride], in[49 * stride]);
3864         x1[15] = _mm512_add_epi32(in[15 * stride], in[48 * stride]);
3865         x1[48] = _mm512_sub_epi32(in[15 * stride], in[48 * stride]);
3866         x1[16] = _mm512_add_epi32(in[16 * stride], in[47 * stride]);
3867         x1[47] = _mm512_sub_epi32(in[16 * stride], in[47 * stride]);
3868         x1[17] = _mm512_add_epi32(in[17 * stride], in[46 * stride]);
3869         x1[46] = _mm512_sub_epi32(in[17 * stride], in[46 * stride]);
3870         x1[18] = _mm512_add_epi32(in[18 * stride], in[45 * stride]);
3871         x1[45] = _mm512_sub_epi32(in[18 * stride], in[45 * stride]);
3872         x1[19] = _mm512_add_epi32(in[19 * stride], in[44 * stride]);
3873         x1[44] = _mm512_sub_epi32(in[19 * stride], in[44 * stride]);
3874         x1[20] = _mm512_add_epi32(in[20 * stride], in[43 * stride]);
3875         x1[43] = _mm512_sub_epi32(in[20 * stride], in[43 * stride]);
3876         x1[21] = _mm512_add_epi32(in[21 * stride], in[42 * stride]);
3877         x1[42] = _mm512_sub_epi32(in[21 * stride], in[42 * stride]);
3878         x1[22] = _mm512_add_epi32(in[22 * stride], in[41 * stride]);
3879         x1[41] = _mm512_sub_epi32(in[22 * stride], in[41 * stride]);
3880         x1[23] = _mm512_add_epi32(in[23 * stride], in[40 * stride]);
3881         x1[40] = _mm512_sub_epi32(in[23 * stride], in[40 * stride]);
3882         x1[24] = _mm512_add_epi32(in[24 * stride], in[39 * stride]);
3883         x1[39] = _mm512_sub_epi32(in[24 * stride], in[39 * stride]);
3884         x1[25] = _mm512_add_epi32(in[25 * stride], in[38 * stride]);
3885         x1[38] = _mm512_sub_epi32(in[25 * stride], in[38 * stride]);
3886         x1[26] = _mm512_add_epi32(in[26 * stride], in[37 * stride]);
3887         x1[37] = _mm512_sub_epi32(in[26 * stride], in[37 * stride]);
3888         x1[27] = _mm512_add_epi32(in[27 * stride], in[36 * stride]);
3889         x1[36] = _mm512_sub_epi32(in[27 * stride], in[36 * stride]);
3890         x1[28] = _mm512_add_epi32(in[28 * stride], in[35 * stride]);
3891         x1[35] = _mm512_sub_epi32(in[28 * stride], in[35 * stride]);
3892         x1[29] = _mm512_add_epi32(in[29 * stride], in[34 * stride]);
3893         x1[34] = _mm512_sub_epi32(in[29 * stride], in[34 * stride]);
3894         x1[30] = _mm512_add_epi32(in[30 * stride], in[33 * stride]);
3895         x1[33] = _mm512_sub_epi32(in[30 * stride], in[33 * stride]);
3896         x1[31] = _mm512_add_epi32(in[31 * stride], in[32 * stride]);
3897         x1[32] = _mm512_sub_epi32(in[31 * stride], in[32 * stride]);
3898 
3899         // stage 2
3900         __m512i x2[64];
3901         x2[0]  = _mm512_add_epi32(x1[0], x1[31]);
3902         x2[31] = _mm512_sub_epi32(x1[0], x1[31]);
3903         x2[1]  = _mm512_add_epi32(x1[1], x1[30]);
3904         x2[30] = _mm512_sub_epi32(x1[1], x1[30]);
3905         x2[2]  = _mm512_add_epi32(x1[2], x1[29]);
3906         x2[29] = _mm512_sub_epi32(x1[2], x1[29]);
3907         x2[3]  = _mm512_add_epi32(x1[3], x1[28]);
3908         x2[28] = _mm512_sub_epi32(x1[3], x1[28]);
3909         x2[4]  = _mm512_add_epi32(x1[4], x1[27]);
3910         x2[27] = _mm512_sub_epi32(x1[4], x1[27]);
3911         x2[5]  = _mm512_add_epi32(x1[5], x1[26]);
3912         x2[26] = _mm512_sub_epi32(x1[5], x1[26]);
3913         x2[6]  = _mm512_add_epi32(x1[6], x1[25]);
3914         x2[25] = _mm512_sub_epi32(x1[6], x1[25]);
3915         x2[7]  = _mm512_add_epi32(x1[7], x1[24]);
3916         x2[24] = _mm512_sub_epi32(x1[7], x1[24]);
3917         x2[8]  = _mm512_add_epi32(x1[8], x1[23]);
3918         x2[23] = _mm512_sub_epi32(x1[8], x1[23]);
3919         x2[9]  = _mm512_add_epi32(x1[9], x1[22]);
3920         x2[22] = _mm512_sub_epi32(x1[9], x1[22]);
3921         x2[10] = _mm512_add_epi32(x1[10], x1[21]);
3922         x2[21] = _mm512_sub_epi32(x1[10], x1[21]);
3923         x2[11] = _mm512_add_epi32(x1[11], x1[20]);
3924         x2[20] = _mm512_sub_epi32(x1[11], x1[20]);
3925         x2[12] = _mm512_add_epi32(x1[12], x1[19]);
3926         x2[19] = _mm512_sub_epi32(x1[12], x1[19]);
3927         x2[13] = _mm512_add_epi32(x1[13], x1[18]);
3928         x2[18] = _mm512_sub_epi32(x1[13], x1[18]);
3929         x2[14] = _mm512_add_epi32(x1[14], x1[17]);
3930         x2[17] = _mm512_sub_epi32(x1[14], x1[17]);
3931         x2[15] = _mm512_add_epi32(x1[15], x1[16]);
3932         x2[16] = _mm512_sub_epi32(x1[15], x1[16]);
3933         x2[32] = x1[32];
3934         x2[33] = x1[33];
3935         x2[34] = x1[34];
3936         x2[35] = x1[35];
3937         x2[36] = x1[36];
3938         x2[37] = x1[37];
3939         x2[38] = x1[38];
3940         x2[39] = x1[39];
3941         btf_32_type0_avx512_new(
3942             cospi_m32, cospi_p32, x1[40], x1[55], x2[40], x2[55], __rounding, cos_bit);
3943         btf_32_type0_avx512_new(
3944             cospi_m32, cospi_p32, x1[41], x1[54], x2[41], x2[54], __rounding, cos_bit);
3945         btf_32_type0_avx512_new(
3946             cospi_m32, cospi_p32, x1[42], x1[53], x2[42], x2[53], __rounding, cos_bit);
3947         btf_32_type0_avx512_new(
3948             cospi_m32, cospi_p32, x1[43], x1[52], x2[43], x2[52], __rounding, cos_bit);
3949         btf_32_type0_avx512_new(
3950             cospi_m32, cospi_p32, x1[44], x1[51], x2[44], x2[51], __rounding, cos_bit);
3951         btf_32_type0_avx512_new(
3952             cospi_m32, cospi_p32, x1[45], x1[50], x2[45], x2[50], __rounding, cos_bit);
3953         btf_32_type0_avx512_new(
3954             cospi_m32, cospi_p32, x1[46], x1[49], x2[46], x2[49], __rounding, cos_bit);
3955         btf_32_type0_avx512_new(
3956             cospi_m32, cospi_p32, x1[47], x1[48], x2[47], x2[48], __rounding, cos_bit);
3957         x2[56] = x1[56];
3958         x2[57] = x1[57];
3959         x2[58] = x1[58];
3960         x2[59] = x1[59];
3961         x2[60] = x1[60];
3962         x2[61] = x1[61];
3963         x2[62] = x1[62];
3964         x2[63] = x1[63];
3965 
3966         // stage 3
3967         __m512i x3[64];
3968         x3[0]  = _mm512_add_epi32(x2[0], x2[15]);
3969         x3[15] = _mm512_sub_epi32(x2[0], x2[15]);
3970         x3[1]  = _mm512_add_epi32(x2[1], x2[14]);
3971         x3[14] = _mm512_sub_epi32(x2[1], x2[14]);
3972         x3[2]  = _mm512_add_epi32(x2[2], x2[13]);
3973         x3[13] = _mm512_sub_epi32(x2[2], x2[13]);
3974         x3[3]  = _mm512_add_epi32(x2[3], x2[12]);
3975         x3[12] = _mm512_sub_epi32(x2[3], x2[12]);
3976         x3[4]  = _mm512_add_epi32(x2[4], x2[11]);
3977         x3[11] = _mm512_sub_epi32(x2[4], x2[11]);
3978         x3[5]  = _mm512_add_epi32(x2[5], x2[10]);
3979         x3[10] = _mm512_sub_epi32(x2[5], x2[10]);
3980         x3[6]  = _mm512_add_epi32(x2[6], x2[9]);
3981         x3[9]  = _mm512_sub_epi32(x2[6], x2[9]);
3982         x3[7]  = _mm512_add_epi32(x2[7], x2[8]);
3983         x3[8]  = _mm512_sub_epi32(x2[7], x2[8]);
3984         x3[16] = x2[16];
3985         x3[17] = x2[17];
3986         x3[18] = x2[18];
3987         x3[19] = x2[19];
3988         btf_32_type0_avx512_new(
3989             cospi_m32, cospi_p32, x2[20], x2[27], x3[20], x3[27], __rounding, cos_bit);
3990         btf_32_type0_avx512_new(
3991             cospi_m32, cospi_p32, x2[21], x2[26], x3[21], x3[26], __rounding, cos_bit);
3992         btf_32_type0_avx512_new(
3993             cospi_m32, cospi_p32, x2[22], x2[25], x3[22], x3[25], __rounding, cos_bit);
3994         btf_32_type0_avx512_new(
3995             cospi_m32, cospi_p32, x2[23], x2[24], x3[23], x3[24], __rounding, cos_bit);
3996         x3[28] = x2[28];
3997         x3[29] = x2[29];
3998         x3[30] = x2[30];
3999         x3[31] = x2[31];
4000         x3[32] = _mm512_add_epi32(x2[32], x2[47]);
4001         x3[47] = _mm512_sub_epi32(x2[32], x2[47]);
4002         x3[33] = _mm512_add_epi32(x2[33], x2[46]);
4003         x3[46] = _mm512_sub_epi32(x2[33], x2[46]);
4004         x3[34] = _mm512_add_epi32(x2[34], x2[45]);
4005         x3[45] = _mm512_sub_epi32(x2[34], x2[45]);
4006         x3[35] = _mm512_add_epi32(x2[35], x2[44]);
4007         x3[44] = _mm512_sub_epi32(x2[35], x2[44]);
4008         x3[36] = _mm512_add_epi32(x2[36], x2[43]);
4009         x3[43] = _mm512_sub_epi32(x2[36], x2[43]);
4010         x3[37] = _mm512_add_epi32(x2[37], x2[42]);
4011         x3[42] = _mm512_sub_epi32(x2[37], x2[42]);
4012         x3[38] = _mm512_add_epi32(x2[38], x2[41]);
4013         x3[41] = _mm512_sub_epi32(x2[38], x2[41]);
4014         x3[39] = _mm512_add_epi32(x2[39], x2[40]);
4015         x3[40] = _mm512_sub_epi32(x2[39], x2[40]);
4016         x3[48] = _mm512_sub_epi32(x2[63], x2[48]);
4017         x3[63] = _mm512_add_epi32(x2[63], x2[48]);
4018         x3[49] = _mm512_sub_epi32(x2[62], x2[49]);
4019         x3[62] = _mm512_add_epi32(x2[62], x2[49]);
4020         x3[50] = _mm512_sub_epi32(x2[61], x2[50]);
4021         x3[61] = _mm512_add_epi32(x2[61], x2[50]);
4022         x3[51] = _mm512_sub_epi32(x2[60], x2[51]);
4023         x3[60] = _mm512_add_epi32(x2[60], x2[51]);
4024         x3[52] = _mm512_sub_epi32(x2[59], x2[52]);
4025         x3[59] = _mm512_add_epi32(x2[59], x2[52]);
4026         x3[53] = _mm512_sub_epi32(x2[58], x2[53]);
4027         x3[58] = _mm512_add_epi32(x2[58], x2[53]);
4028         x3[54] = _mm512_sub_epi32(x2[57], x2[54]);
4029         x3[57] = _mm512_add_epi32(x2[57], x2[54]);
4030         x3[55] = _mm512_sub_epi32(x2[56], x2[55]);
4031         x3[56] = _mm512_add_epi32(x2[56], x2[55]);
4032 
4033         // stage 4
4034         __m512i x4[64];
4035         x4[0] = _mm512_add_epi32(x3[0], x3[7]);
4036         x4[7] = _mm512_sub_epi32(x3[0], x3[7]);
4037         x4[1] = _mm512_add_epi32(x3[1], x3[6]);
4038         x4[6] = _mm512_sub_epi32(x3[1], x3[6]);
4039         x4[2] = _mm512_add_epi32(x3[2], x3[5]);
4040         x4[5] = _mm512_sub_epi32(x3[2], x3[5]);
4041         x4[3] = _mm512_add_epi32(x3[3], x3[4]);
4042         x4[4] = _mm512_sub_epi32(x3[3], x3[4]);
4043         x4[8] = x3[8];
4044         x4[9] = x3[9];
4045         btf_32_type0_avx512_new(
4046             cospi_m32, cospi_p32, x3[10], x3[13], x4[10], x4[13], __rounding, cos_bit);
4047         btf_32_type0_avx512_new(
4048             cospi_m32, cospi_p32, x3[11], x3[12], x4[11], x4[12], __rounding, cos_bit);
4049         x4[14] = x3[14];
4050         x4[15] = x3[15];
4051         x4[16] = _mm512_add_epi32(x3[16], x3[23]);
4052         x4[23] = _mm512_sub_epi32(x3[16], x3[23]);
4053         x4[17] = _mm512_add_epi32(x3[17], x3[22]);
4054         x4[22] = _mm512_sub_epi32(x3[17], x3[22]);
4055         x4[18] = _mm512_add_epi32(x3[18], x3[21]);
4056         x4[21] = _mm512_sub_epi32(x3[18], x3[21]);
4057         x4[19] = _mm512_add_epi32(x3[19], x3[20]);
4058         x4[20] = _mm512_sub_epi32(x3[19], x3[20]);
4059         x4[24] = _mm512_sub_epi32(x3[31], x3[24]);
4060         x4[31] = _mm512_add_epi32(x3[31], x3[24]);
4061         x4[25] = _mm512_sub_epi32(x3[30], x3[25]);
4062         x4[30] = _mm512_add_epi32(x3[30], x3[25]);
4063         x4[26] = _mm512_sub_epi32(x3[29], x3[26]);
4064         x4[29] = _mm512_add_epi32(x3[29], x3[26]);
4065         x4[27] = _mm512_sub_epi32(x3[28], x3[27]);
4066         x4[28] = _mm512_add_epi32(x3[28], x3[27]);
4067         x4[32] = x3[32];
4068         x4[33] = x3[33];
4069         x4[34] = x3[34];
4070         x4[35] = x3[35];
4071         btf_32_type0_avx512_new(
4072             cospi_m16, cospi_p48, x3[36], x3[59], x4[36], x4[59], __rounding, cos_bit);
4073         btf_32_type0_avx512_new(
4074             cospi_m16, cospi_p48, x3[37], x3[58], x4[37], x4[58], __rounding, cos_bit);
4075         btf_32_type0_avx512_new(
4076             cospi_m16, cospi_p48, x3[38], x3[57], x4[38], x4[57], __rounding, cos_bit);
4077         btf_32_type0_avx512_new(
4078             cospi_m16, cospi_p48, x3[39], x3[56], x4[39], x4[56], __rounding, cos_bit);
4079         btf_32_type0_avx512_new(
4080             cospi_m48, cospi_m16, x3[40], x3[55], x4[40], x4[55], __rounding, cos_bit);
4081         btf_32_type0_avx512_new(
4082             cospi_m48, cospi_m16, x3[41], x3[54], x4[41], x4[54], __rounding, cos_bit);
4083         btf_32_type0_avx512_new(
4084             cospi_m48, cospi_m16, x3[42], x3[53], x4[42], x4[53], __rounding, cos_bit);
4085         btf_32_type0_avx512_new(
4086             cospi_m48, cospi_m16, x3[43], x3[52], x4[43], x4[52], __rounding, cos_bit);
4087         x4[44] = x3[44];
4088         x4[45] = x3[45];
4089         x4[46] = x3[46];
4090         x4[47] = x3[47];
4091         x4[48] = x3[48];
4092         x4[49] = x3[49];
4093         x4[50] = x3[50];
4094         x4[51] = x3[51];
4095         x4[60] = x3[60];
4096         x4[61] = x3[61];
4097         x4[62] = x3[62];
4098         x4[63] = x3[63];
4099 
4100         // stage 5
4101         __m512i x5[64];
4102         x5[0] = _mm512_add_epi32(x4[0], x4[3]);
4103         x5[1] = _mm512_add_epi32(x4[1], x4[2]);
4104         x5[4] = x4[4];
4105         btf_32_type0_avx512_new(
4106             cospi_m32, cospi_p32, x4[5], x4[6], x5[5], x5[6], __rounding, cos_bit);
4107         x5[7]  = x4[7];
4108         x5[8]  = _mm512_add_epi32(x4[8], x4[11]);
4109         x5[11] = _mm512_sub_epi32(x4[8], x4[11]);
4110         x5[9]  = _mm512_add_epi32(x4[9], x4[10]);
4111         x5[10] = _mm512_sub_epi32(x4[9], x4[10]);
4112         x5[12] = _mm512_sub_epi32(x4[15], x4[12]);
4113         x5[15] = _mm512_add_epi32(x4[15], x4[12]);
4114         x5[13] = _mm512_sub_epi32(x4[14], x4[13]);
4115         x5[14] = _mm512_add_epi32(x4[14], x4[13]);
4116         x5[16] = x4[16];
4117         x5[17] = x4[17];
4118         btf_32_type0_avx512_new(
4119             cospi_m16, cospi_p48, x4[18], x4[29], x5[18], x5[29], __rounding, cos_bit);
4120         btf_32_type0_avx512_new(
4121             cospi_m16, cospi_p48, x4[19], x4[28], x5[19], x5[28], __rounding, cos_bit);
4122         btf_32_type0_avx512_new(
4123             cospi_m48, cospi_m16, x4[20], x4[27], x5[20], x5[27], __rounding, cos_bit);
4124         btf_32_type0_avx512_new(
4125             cospi_m48, cospi_m16, x4[21], x4[26], x5[21], x5[26], __rounding, cos_bit);
4126         x5[22] = x4[22];
4127         x5[23] = x4[23];
4128         x5[24] = x4[24];
4129         x5[25] = x4[25];
4130         x5[30] = x4[30];
4131         x5[31] = x4[31];
4132         x5[32] = _mm512_add_epi32(x4[32], x4[39]);
4133         x5[39] = _mm512_sub_epi32(x4[32], x4[39]);
4134         x5[33] = _mm512_add_epi32(x4[33], x4[38]);
4135         x5[38] = _mm512_sub_epi32(x4[33], x4[38]);
4136         x5[34] = _mm512_add_epi32(x4[34], x4[37]);
4137         x5[37] = _mm512_sub_epi32(x4[34], x4[37]);
4138         x5[35] = _mm512_add_epi32(x4[35], x4[36]);
4139         x5[36] = _mm512_sub_epi32(x4[35], x4[36]);
4140         x5[40] = _mm512_sub_epi32(x4[47], x4[40]);
4141         x5[47] = _mm512_add_epi32(x4[47], x4[40]);
4142         x5[41] = _mm512_sub_epi32(x4[46], x4[41]);
4143         x5[46] = _mm512_add_epi32(x4[46], x4[41]);
4144         x5[42] = _mm512_sub_epi32(x4[45], x4[42]);
4145         x5[45] = _mm512_add_epi32(x4[45], x4[42]);
4146         x5[43] = _mm512_sub_epi32(x4[44], x4[43]);
4147         x5[44] = _mm512_add_epi32(x4[44], x4[43]);
4148         x5[48] = _mm512_add_epi32(x4[48], x4[55]);
4149         x5[55] = _mm512_sub_epi32(x4[48], x4[55]);
4150         x5[49] = _mm512_add_epi32(x4[49], x4[54]);
4151         x5[54] = _mm512_sub_epi32(x4[49], x4[54]);
4152         x5[50] = _mm512_add_epi32(x4[50], x4[53]);
4153         x5[53] = _mm512_sub_epi32(x4[50], x4[53]);
4154         x5[51] = _mm512_add_epi32(x4[51], x4[52]);
4155         x5[52] = _mm512_sub_epi32(x4[51], x4[52]);
4156         x5[56] = _mm512_sub_epi32(x4[63], x4[56]);
4157         x5[63] = _mm512_add_epi32(x4[63], x4[56]);
4158         x5[57] = _mm512_sub_epi32(x4[62], x4[57]);
4159         x5[62] = _mm512_add_epi32(x4[62], x4[57]);
4160         x5[58] = _mm512_sub_epi32(x4[61], x4[58]);
4161         x5[61] = _mm512_add_epi32(x4[61], x4[58]);
4162         x5[59] = _mm512_sub_epi32(x4[60], x4[59]);
4163         x5[60] = _mm512_add_epi32(x4[60], x4[59]);
4164 
4165         // stage 6
4166         __m512i x6[64];
4167         x6[0] = half_btf_avx512(&cospi_p32, &x5[0], &cospi_p32, &x5[1], &__rounding, cos_bit);
4168         x6[4] = _mm512_add_epi32(x5[4], x5[5]);
4169         x6[7] = _mm512_add_epi32(x5[7], x5[6]);
4170         x6[8] = x5[8];
4171         btf_32_type0_avx512_new(
4172             cospi_m16, cospi_p48, x5[9], x5[14], x6[9], x6[14], __rounding, cos_bit);
4173         btf_32_type0_avx512_new(
4174             cospi_m48, cospi_m16, x5[10], x5[13], x6[10], x6[13], __rounding, cos_bit);
4175         x6[11] = x5[11];
4176         x6[12] = x5[12];
4177         x6[15] = x5[15];
4178         x6[16] = _mm512_add_epi32(x5[16], x5[19]);
4179         x6[19] = _mm512_sub_epi32(x5[16], x5[19]);
4180         x6[17] = _mm512_add_epi32(x5[17], x5[18]);
4181         x6[18] = _mm512_sub_epi32(x5[17], x5[18]);
4182         x6[20] = _mm512_sub_epi32(x5[23], x5[20]);
4183         x6[23] = _mm512_add_epi32(x5[23], x5[20]);
4184         x6[21] = _mm512_sub_epi32(x5[22], x5[21]);
4185         x6[22] = _mm512_add_epi32(x5[22], x5[21]);
4186         x6[24] = _mm512_add_epi32(x5[24], x5[27]);
4187         x6[27] = _mm512_sub_epi32(x5[24], x5[27]);
4188         x6[25] = _mm512_add_epi32(x5[25], x5[26]);
4189         x6[26] = _mm512_sub_epi32(x5[25], x5[26]);
4190         x6[28] = _mm512_sub_epi32(x5[31], x5[28]);
4191         x6[31] = _mm512_add_epi32(x5[31], x5[28]);
4192         x6[29] = _mm512_sub_epi32(x5[30], x5[29]);
4193         x6[30] = _mm512_add_epi32(x5[30], x5[29]);
4194         x6[32] = x5[32];
4195         x6[33] = x5[33];
4196         btf_32_type0_avx512_new(
4197             cospi_m08, cospi_p56, x5[34], x5[61], x6[34], x6[61], __rounding, cos_bit);
4198         btf_32_type0_avx512_new(
4199             cospi_m08, cospi_p56, x5[35], x5[60], x6[35], x6[60], __rounding, cos_bit);
4200         btf_32_type0_avx512_new(
4201             cospi_m56, cospi_m08, x5[36], x5[59], x6[36], x6[59], __rounding, cos_bit);
4202         btf_32_type0_avx512_new(
4203             cospi_m56, cospi_m08, x5[37], x5[58], x6[37], x6[58], __rounding, cos_bit);
4204         x6[38] = x5[38];
4205         x6[39] = x5[39];
4206         x6[40] = x5[40];
4207         x6[41] = x5[41];
4208         btf_32_type0_avx512_new(
4209             cospi_m40, cospi_p24, x5[42], x5[53], x6[42], x6[53], __rounding, cos_bit);
4210         btf_32_type0_avx512_new(
4211             cospi_m40, cospi_p24, x5[43], x5[52], x6[43], x6[52], __rounding, cos_bit);
4212         btf_32_type0_avx512_new(
4213             cospi_m24, cospi_m40, x5[44], x5[51], x6[44], x6[51], __rounding, cos_bit);
4214         btf_32_type0_avx512_new(
4215             cospi_m24, cospi_m40, x5[45], x5[50], x6[45], x6[50], __rounding, cos_bit);
4216         x6[46] = x5[46];
4217         x6[47] = x5[47];
4218         x6[48] = x5[48];
4219         x6[49] = x5[49];
4220         x6[54] = x5[54];
4221         x6[55] = x5[55];
4222         x6[56] = x5[56];
4223         x6[57] = x5[57];
4224         x6[62] = x5[62];
4225         x6[63] = x5[63];
4226 
4227         // stage 7
4228         __m512i x7[64];
4229         x7[0]  = x6[0];
4230         x7[4]  = half_btf_avx512(&cospi_p56, &x6[4], &cospi_p08, &x6[7], &__rounding, cos_bit);
4231         x7[8]  = _mm512_add_epi32(x6[8], x6[9]);
4232         x7[11] = _mm512_add_epi32(x6[11], x6[10]);
4233         x7[12] = _mm512_add_epi32(x6[12], x6[13]);
4234         x7[15] = _mm512_add_epi32(x6[15], x6[14]);
4235         x7[16] = x6[16];
4236         btf_32_type0_avx512_new(
4237             cospi_m08, cospi_p56, x6[17], x6[30], x7[17], x7[30], __rounding, cos_bit);
4238         btf_32_type0_avx512_new(
4239             cospi_m56, cospi_m08, x6[18], x6[29], x7[18], x7[29], __rounding, cos_bit);
4240         x7[19] = x6[19];
4241         x7[20] = x6[20];
4242         btf_32_type0_avx512_new(
4243             cospi_m40, cospi_p24, x6[21], x6[26], x7[21], x7[26], __rounding, cos_bit);
4244         btf_32_type0_avx512_new(
4245             cospi_m24, cospi_m40, x6[22], x6[25], x7[22], x7[25], __rounding, cos_bit);
4246         x7[23] = x6[23];
4247         x7[24] = x6[24];
4248         x7[27] = x6[27];
4249         x7[28] = x6[28];
4250         x7[31] = x6[31];
4251         x7[32] = _mm512_add_epi32(x6[32], x6[35]);
4252         x7[35] = _mm512_sub_epi32(x6[32], x6[35]);
4253         x7[33] = _mm512_add_epi32(x6[33], x6[34]);
4254         x7[34] = _mm512_sub_epi32(x6[33], x6[34]);
4255         x7[36] = _mm512_sub_epi32(x6[39], x6[36]);
4256         x7[39] = _mm512_add_epi32(x6[39], x6[36]);
4257         x7[37] = _mm512_sub_epi32(x6[38], x6[37]);
4258         x7[38] = _mm512_add_epi32(x6[38], x6[37]);
4259         x7[40] = _mm512_add_epi32(x6[40], x6[43]);
4260         x7[43] = _mm512_sub_epi32(x6[40], x6[43]);
4261         x7[41] = _mm512_add_epi32(x6[41], x6[42]);
4262         x7[42] = _mm512_sub_epi32(x6[41], x6[42]);
4263         x7[44] = _mm512_sub_epi32(x6[47], x6[44]);
4264         x7[47] = _mm512_add_epi32(x6[47], x6[44]);
4265         x7[45] = _mm512_sub_epi32(x6[46], x6[45]);
4266         x7[46] = _mm512_add_epi32(x6[46], x6[45]);
4267         x7[48] = _mm512_add_epi32(x6[48], x6[51]);
4268         x7[51] = _mm512_sub_epi32(x6[48], x6[51]);
4269         x7[49] = _mm512_add_epi32(x6[49], x6[50]);
4270         x7[50] = _mm512_sub_epi32(x6[49], x6[50]);
4271         x7[52] = _mm512_sub_epi32(x6[55], x6[52]);
4272         x7[55] = _mm512_add_epi32(x6[55], x6[52]);
4273         x7[53] = _mm512_sub_epi32(x6[54], x6[53]);
4274         x7[54] = _mm512_add_epi32(x6[54], x6[53]);
4275         x7[56] = _mm512_add_epi32(x6[56], x6[59]);
4276         x7[59] = _mm512_sub_epi32(x6[56], x6[59]);
4277         x7[57] = _mm512_add_epi32(x6[57], x6[58]);
4278         x7[58] = _mm512_sub_epi32(x6[57], x6[58]);
4279         x7[60] = _mm512_sub_epi32(x6[63], x6[60]);
4280         x7[63] = _mm512_add_epi32(x6[63], x6[60]);
4281         x7[61] = _mm512_sub_epi32(x6[62], x6[61]);
4282         x7[62] = _mm512_add_epi32(x6[62], x6[61]);
4283 
4284         // stage 8
4285         __m512i x8[64];
4286         out[0 * stride] = x7[0];
4287         out[8 * stride] = x7[4];
4288         out[4 * stride] = half_btf_avx512(
4289             &cospi_p60, &x7[8], &cospi_p04, &x7[15], &__rounding, cos_bit);
4290         out[12 * stride] = half_btf_avx512(
4291             &cospi_p12, &x7[12], &cospi_m52, &x7[11], &__rounding, cos_bit);
4292         x8[16] = _mm512_add_epi32(x7[16], x7[17]);
4293         x8[19] = _mm512_add_epi32(x7[19], x7[18]);
4294         x8[20] = _mm512_add_epi32(x7[20], x7[21]);
4295         x8[23] = _mm512_add_epi32(x7[23], x7[22]);
4296         x8[24] = _mm512_add_epi32(x7[24], x7[25]);
4297         x8[27] = _mm512_add_epi32(x7[27], x7[26]);
4298         x8[28] = _mm512_add_epi32(x7[28], x7[29]);
4299         x8[31] = _mm512_add_epi32(x7[31], x7[30]);
4300         x8[32] = x7[32];
4301         btf_32_type0_avx512_new(
4302             cospi_m04, cospi_p60, x7[33], x7[62], x8[33], x8[62], __rounding, cos_bit);
4303         btf_32_type0_avx512_new(
4304             cospi_m60, cospi_m04, x7[34], x7[61], x8[34], x8[61], __rounding, cos_bit);
4305         x8[35] = x7[35];
4306         x8[36] = x7[36];
4307         btf_32_type0_avx512_new(
4308             cospi_m36, cospi_p28, x7[37], x7[58], x8[37], x8[58], __rounding, cos_bit);
4309         btf_32_type0_avx512_new(
4310             cospi_m28, cospi_m36, x7[38], x7[57], x8[38], x8[57], __rounding, cos_bit);
4311         x8[39] = x7[39];
4312         x8[40] = x7[40];
4313         btf_32_type0_avx512_new(
4314             cospi_m20, cospi_p44, x7[41], x7[54], x8[41], x8[54], __rounding, cos_bit);
4315         btf_32_type0_avx512_new(
4316             cospi_m44, cospi_m20, x7[42], x7[53], x8[42], x8[53], __rounding, cos_bit);
4317         x8[43] = x7[43];
4318         x8[44] = x7[44];
4319         btf_32_type0_avx512_new(
4320             cospi_m52, cospi_p12, x7[45], x7[50], x8[45], x8[50], __rounding, cos_bit);
4321         btf_32_type0_avx512_new(
4322             cospi_m12, cospi_m52, x7[46], x7[49], x8[46], x8[49], __rounding, cos_bit);
4323         x8[47] = x7[47];
4324         x8[48] = x7[48];
4325         x8[51] = x7[51];
4326         x8[52] = x7[52];
4327         x8[55] = x7[55];
4328         x8[56] = x7[56];
4329         x8[59] = x7[59];
4330         x8[60] = x7[60];
4331         x8[63] = x7[63];
4332 
4333         // stage 9
4334         __m512i x9[16];
4335         out[2 * stride] = half_btf_avx512(
4336             &cospi_p62, &x8[16], &cospi_p02, &x8[31], &__rounding, cos_bit);
4337         out[14 * stride] = half_btf_avx512(
4338             &cospi_p14, &x8[28], &cospi_m50, &x8[19], &__rounding, cos_bit);
4339         out[10 * stride] = half_btf_avx512(
4340             &cospi_p54, &x8[20], &cospi_p10, &x8[27], &__rounding, cos_bit);
4341         out[6 * stride] = half_btf_avx512(
4342             &cospi_p06, &x8[24], &cospi_m58, &x8[23], &__rounding, cos_bit);
4343         x9[0]  = _mm512_add_epi32(x8[32], x8[33]);
4344         x9[1]  = _mm512_add_epi32(x8[35], x8[34]);
4345         x9[2]  = _mm512_add_epi32(x8[36], x8[37]);
4346         x9[3]  = _mm512_add_epi32(x8[39], x8[38]);
4347         x9[4]  = _mm512_add_epi32(x8[40], x8[41]);
4348         x9[5]  = _mm512_add_epi32(x8[43], x8[42]);
4349         x9[6]  = _mm512_add_epi32(x8[44], x8[45]);
4350         x9[7]  = _mm512_add_epi32(x8[47], x8[46]);
4351         x9[8]  = _mm512_add_epi32(x8[48], x8[49]);
4352         x9[9]  = _mm512_add_epi32(x8[51], x8[50]);
4353         x9[10] = _mm512_add_epi32(x8[63], x8[62]);
4354         x9[11] = _mm512_add_epi32(x8[52], x8[53]);
4355         x9[12] = _mm512_add_epi32(x8[55], x8[54]);
4356         x9[13] = _mm512_add_epi32(x8[56], x8[57]);
4357         x9[14] = _mm512_add_epi32(x8[59], x8[58]);
4358         x9[15] = _mm512_add_epi32(x8[60], x8[61]);
4359 
4360         // stage 10
4361         out[1 * stride] = half_btf_avx512(
4362             &cospi_p63, &x9[0], &cospi_p01, &x9[10], &__rounding, cos_bit);
4363         out[15 * stride] = half_btf_avx512(
4364             &cospi_p15, &x9[15], &cospi_m49, &x9[1], &__rounding, cos_bit);
4365         out[9 * stride] = half_btf_avx512(
4366             &cospi_p55, &x9[2], &cospi_p09, &x9[14], &__rounding, cos_bit);
4367         out[7 * stride] = half_btf_avx512(
4368             &cospi_p07, &x9[13], &cospi_m57, &x9[3], &__rounding, cos_bit);
4369         out[5 * stride] = half_btf_avx512(
4370             &cospi_p59, &x9[4], &cospi_p05, &x9[12], &__rounding, cos_bit);
4371         out[11 * stride] = half_btf_avx512(
4372             &cospi_p11, &x9[11], &cospi_m53, &x9[5], &__rounding, cos_bit);
4373         out[13 * stride] = half_btf_avx512(
4374             &cospi_p51, &x9[6], &cospi_p13, &x9[9], &__rounding, cos_bit);
4375         out[3 * stride] = half_btf_avx512(
4376             &cospi_p03, &x9[8], &cospi_m61, &x9[7], &__rounding, cos_bit);
4377     }
4378 }
4379 
/* Identity (IDTX) transform for the 64x64 N4 path.
 *
 * Only the top-left 16x16 quarter of the 64x64 block is ever consumed by
 * the N4 variant, so just the first 16-lane register of each of the first
 * 16 rows is scaled (each 64-sample row spans 4 ZMM registers, hence the
 * stride-4 walk over indices 0, 4, 8, ..., 60).  Each lane is multiplied
 * by 4*new_sqrt2 in Q12 and rounded back down.
 */
static void fidtx64x64_N4_avx512(const __m512i *input, __m512i *output) {
    const uint8_t shift_bits = 12; /* new_sqrt2_bits = 12 */
    const __m512i scale      = _mm512_set1_epi32(4 * 5793); /* 4 * new_sqrt2 */
    const __m512i offset     = _mm512_set1_epi32(1 << (shift_bits - 1));

    for (int32_t idx = 0; idx < 64; idx += 4) {
        __m512i v   = _mm512_mullo_epi32(input[idx], scale);
        v           = _mm512_add_epi32(v, offset);
        output[idx] = _mm512_srai_epi32(v, shift_bits);
    }
}
4393 
/* Forward 2-D 64x64 transform, N4 variant: only the top-left 16x16
 * quadrant of the output is actually computed; the rest of the 64x64
 * coefficient buffer is zeroed at the end of each path.
 *
 * input  : source pixels (16-bit), `stride` samples per row
 * output : 64*64 int32 coefficient buffer, reused in-place as ZMM storage
 * tx_type: only IDTX and DCT_DCT are supported here (others assert)
 * bd     : unused (bit depth handled by the shift tables)
 */
void av1_fwd_txfm2d_64x64_N4_avx512(int16_t *input, int32_t *output, uint32_t stride,
                                    TxType tx_type, uint8_t bd) {
    (void)bd;
    const int32_t txw_idx     = tx_size_wide_log2[TX_64X64] - tx_size_wide_log2[0];
    const int32_t txh_idx     = tx_size_high_log2[TX_64X64] - tx_size_high_log2[0];
    const int8_t *shift       = fwd_txfm_shift_ls[TX_64X64];
    const int32_t txfm_size   = 64;
    const int32_t num_per_512 = 16; /* 32-bit lanes per ZMM register */
    int32_t       col_num     = txfm_size / num_per_512;
    __m512i       buf[256];  /* scratch: 64 rows x 4 registers */
    __m512i *     io = (__m512i *)output;

    switch (tx_type) {
    case IDTX:
        /* Identity in both directions: only the 16x16 corner is loaded. */
        load_buffer_16x16_in_64x64_avx512(input, stride, io);
        fidtx64x64_N4_avx512(io, buf);
        av1_round_shift_array_64x64_N4_avx512(buf, io, -shift[1]);
        /*row wise transform*/
        fidtx64x64_N4_avx512(io, buf);
        av1_round_shift_array_64x64_N4_avx512(buf, io, -shift[2]);
        clear_buffer_64x64_N4_avx512(io);
        break;
    case DCT_DCT:
        /* Column transform over the full block, then a quarter-size row
         * transform (txfm_size / 4) since only 16 output rows survive. */
        load_buffer_64x64_avx512(input, stride, io);
        av1_fdct64_new_N4_avx512(io, buf, fwd_cos_bit_col[txw_idx][txh_idx], txfm_size, col_num);
        av1_round_shift_array_avx512(buf, io, 256 / 4, -shift[1]);
        transpose_16nx16n_N4_half_avx512(64, io, buf);
        /*row wise transform*/
        av1_fdct64_new_N4_avx512(
            buf, io, fwd_cos_bit_row[txw_idx][txh_idx], txfm_size / 4, col_num);
        av1_round_shift_array_64x64_N4_avx512(io, buf, -shift[2]);
        transpose_16nx16n_N4_quad_avx512(64, buf, io);
        clear_buffer_64x64_N4_avx512(io);
        break;
    default: assert(0);
    }
}
4431 
4432 #endif
4433