/*****************************************************************************
 * This file is part of Kvazaar HEVC encoder.
 *
 * Copyright (c) 2021, Tampere University, ITU/ISO/IEC, project contributors
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 * * Redistributions of source code must retain the above copyright notice, this
 *   list of conditions and the following disclaimer.
 *
 * * Redistributions in binary form must reproduce the above copyright notice, this
 *   list of conditions and the following disclaimer in the documentation and/or
 *   other materials provided with the distribution.
 *
 * * Neither the name of the Tampere University or ITU/ISO/IEC nor the names of its
 *   contributors may be used to endorse or promote products derived from
 *   this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 ****************************************************************************/

/*
* \file
*/

#include "strategies/avx2/dct-avx2.h"

#if COMPILE_INTEL_AVX2
#include "kvazaar.h"
#if KVZ_BIT_DEPTH == 8
#include <immintrin.h>

#include "strategyselector.h"
#include "tables.h"

extern const int16_t kvz_g_dst_4[4][4];
extern const int16_t kvz_g_dct_4[4][4];
extern const int16_t kvz_g_dct_8[8][8];
extern const int16_t kvz_g_dct_16[16][16];
extern const int16_t kvz_g_dct_32[32][32];

extern const int16_t kvz_g_dst_4_t[4][4];
extern const int16_t kvz_g_dct_4_t[4][4];
extern const int16_t kvz_g_dct_8_t[8][8];
extern const int16_t kvz_g_dct_16_t[16][16];
extern const int16_t kvz_g_dct_32_t[32][32];

/*
* \file
* \brief AVX2 transformations.
*/

static INLINE __m256i swap_lanes(__m256i v)
{
  return _mm256_permute4x64_epi64(v, _MM_SHUFFLE(1, 0, 3, 2));
}

static INLINE __m256i truncate_avx2(__m256i v, __m256i debias, int32_t shift)
{
  __m256i truncable = _mm256_add_epi32 (v,         debias);
  return              _mm256_srai_epi32(truncable, shift);
}
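
/*
 * For reference, a scalar sketch (illustrative only, not part of the build)
 * of what truncate_avx2 does to each 32-bit lane: add the rounding offset
 * (debias == 1 << (shift - 1)) and shift right arithmetically.
 *
 *   static INLINE int32_t truncate_scalar(int32_t v, int32_t shift)
 *   {
 *     return (v + (1 << (shift - 1))) >> shift;
 *   }
 */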

// 4x4 matrix multiplication with value clipping.
// Parameters: Two 4x4 matrices containing 16-bit values in consecutive addresses,
//             destination for the result and the shift value for clipping.
static __m256i mul_clip_matrix_4x4_avx2(const __m256i left, const __m256i right, int shift)
{
  const int32_t add    = 1 << (shift - 1);
  const __m256i debias = _mm256_set1_epi32(add);

  __m256i right_los = _mm256_permute4x64_epi64(right, _MM_SHUFFLE(2, 0, 2, 0));
  __m256i right_his = _mm256_permute4x64_epi64(right, _MM_SHUFFLE(3, 1, 3, 1));

  __m256i right_cols_up = _mm256_unpacklo_epi16(right_los, right_his);
  __m256i right_cols_dn = _mm256_unpackhi_epi16(right_los, right_his);

  __m256i left_slice1 = _mm256_shuffle_epi32(left, _MM_SHUFFLE(0, 0, 0, 0));
  __m256i left_slice2 = _mm256_shuffle_epi32(left, _MM_SHUFFLE(1, 1, 1, 1));
  __m256i left_slice3 = _mm256_shuffle_epi32(left, _MM_SHUFFLE(2, 2, 2, 2));
  __m256i left_slice4 = _mm256_shuffle_epi32(left, _MM_SHUFFLE(3, 3, 3, 3));

  __m256i prod1 = _mm256_madd_epi16(left_slice1, right_cols_up);
  __m256i prod2 = _mm256_madd_epi16(left_slice2, right_cols_dn);
  __m256i prod3 = _mm256_madd_epi16(left_slice3, right_cols_up);
  __m256i prod4 = _mm256_madd_epi16(left_slice4, right_cols_dn);

  __m256i rows_up = _mm256_add_epi32(prod1, prod2);
  __m256i rows_dn = _mm256_add_epi32(prod3, prod4);

  __m256i rows_up_tr = truncate_avx2(rows_up, debias, shift);
  __m256i rows_dn_tr = truncate_avx2(rows_dn, debias, shift);

  __m256i result = _mm256_packs_epi32(rows_up_tr, rows_dn_tr);
  return result;
}
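
/*
 * Scalar reference (an illustrative sketch, not compiled) of the operation
 * vectorized above: a plain 4x4 matrix product where each result is rounded,
 * shifted and saturated to 16 bits, which is what madd + truncate + packs do.
 *
 *   static void mul_clip_matrix_4x4_scalar(const int16_t *left, const int16_t *right,
 *                                          int16_t *dst, int shift)
 *   {
 *     const int32_t add = 1 << (shift - 1);
 *     for (int y = 0; y < 4; y++) {
 *       for (int x = 0; x < 4; x++) {
 *         int32_t sum = 0;
 *         for (int k = 0; k < 4; k++) {
 *           sum += left[y * 4 + k] * right[k * 4 + x];
 *         }
 *         sum = (sum + add) >> shift;
 *         if (sum >  32767) sum =  32767;   // saturation done by _mm256_packs_epi32
 *         if (sum < -32768) sum = -32768;
 *         dst[y * 4 + x] = (int16_t)sum;
 *       }
 *     }
 *   }
 */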

static void matrix_dst_4x4_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
  int32_t shift_1st = kvz_g_convert_to_bit[4] + 1 + (bitdepth - 8);
  int32_t shift_2nd = kvz_g_convert_to_bit[4] + 8;
  const int16_t *tdst = &kvz_g_dst_4_t[0][0];
  const int16_t *dst  = &kvz_g_dst_4  [0][0];

  __m256i tdst_v = _mm256_load_si256((const __m256i *) tdst);
  __m256i  dst_v = _mm256_load_si256((const __m256i *)  dst);
  __m256i   in_v = _mm256_load_si256((const __m256i *)input);

  __m256i tmp    = mul_clip_matrix_4x4_avx2(in_v,  tdst_v, shift_1st);
  __m256i result = mul_clip_matrix_4x4_avx2(dst_v, tmp,    shift_2nd);

  _mm256_store_si256((__m256i *)output, result);
}

static void matrix_idst_4x4_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
  int32_t shift_1st = 7;
  int32_t shift_2nd = 12 - (bitdepth - 8);

  const int16_t *tdst = &kvz_g_dst_4_t[0][0];
  const int16_t *dst  = &kvz_g_dst_4  [0][0];

  __m256i tdst_v = _mm256_load_si256((const __m256i *)tdst);
  __m256i  dst_v = _mm256_load_si256((const __m256i *) dst);
  __m256i   in_v = _mm256_load_si256((const __m256i *)input);

  __m256i tmp    = mul_clip_matrix_4x4_avx2(tdst_v, in_v,  shift_1st);
  __m256i result = mul_clip_matrix_4x4_avx2(tmp,    dst_v, shift_2nd);

  _mm256_store_si256((__m256i *)output, result);
}

static void matrix_dct_4x4_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
  int32_t shift_1st = kvz_g_convert_to_bit[4] + 1 + (bitdepth - 8);
  int32_t shift_2nd = kvz_g_convert_to_bit[4] + 8;
  const int16_t *tdct = &kvz_g_dct_4_t[0][0];
  const int16_t *dct  = &kvz_g_dct_4  [0][0];

  __m256i tdct_v = _mm256_load_si256((const __m256i *) tdct);
  __m256i  dct_v = _mm256_load_si256((const __m256i *)  dct);
  __m256i   in_v = _mm256_load_si256((const __m256i *)input);

  __m256i tmp    = mul_clip_matrix_4x4_avx2(in_v,  tdct_v, shift_1st);
  __m256i result = mul_clip_matrix_4x4_avx2(dct_v, tmp,    shift_2nd);

  _mm256_store_si256((__m256i *)output, result);
}

static void matrix_idct_4x4_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
  int32_t shift_1st = 7;
  int32_t shift_2nd = 12 - (bitdepth - 8);

  const int16_t *tdct = &kvz_g_dct_4_t[0][0];
  const int16_t *dct  = &kvz_g_dct_4  [0][0];

  __m256i tdct_v = _mm256_load_si256((const __m256i *)tdct);
  __m256i  dct_v = _mm256_load_si256((const __m256i *) dct);
  __m256i   in_v = _mm256_load_si256((const __m256i *)input);

  __m256i tmp    = mul_clip_matrix_4x4_avx2(tdct_v, in_v,  shift_1st);
  __m256i result = mul_clip_matrix_4x4_avx2(tmp,    dct_v, shift_2nd);

  _mm256_store_si256((__m256i *)output, result);
}
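
/*
 * Composition of the two clipped multiplies above, written out (C is the
 * 4x4 transform matrix, C_T its transpose, X the input block):
 *
 *   forward:  tmp = (X * C_T)  >> shift_1st,   result = (C * tmp)  >> shift_2nd
 *             i.e. result ~= C * X * C_T with per-stage rounding,
 *   inverse:  tmp = (C_T * X)  >> shift_1st,   result = (tmp * C)  >> shift_2nd
 *             i.e. result ~= C_T * X * C.
 *
 * The per-stage truncation and clipping is why the two stages cannot simply
 * be fused into one multiply.
 */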

static void mul_clip_matrix_8x8_avx2(const int16_t *left, const int16_t *right, int16_t *dst, const int32_t shift)
{
  const __m256i transp_mask = _mm256_broadcastsi128_si256(_mm_setr_epi8(0, 1, 8, 9, 2, 3, 10, 11, 4, 5, 12, 13, 6, 7, 14, 15));

  const int32_t add    = 1 << (shift - 1);
  const __m256i debias = _mm256_set1_epi32(add);

  __m256i left_dr[4] = {
    _mm256_load_si256((const __m256i *)left + 0),
    _mm256_load_si256((const __m256i *)left + 1),
    _mm256_load_si256((const __m256i *)left + 2),
    _mm256_load_si256((const __m256i *)left + 3),
  };
  __m256i right_dr[4] = {
    _mm256_load_si256((const __m256i *)right + 0),
    _mm256_load_si256((const __m256i *)right + 1),
    _mm256_load_si256((const __m256i *)right + 2),
    _mm256_load_si256((const __m256i *)right + 3),
  };

  __m256i rdrs_rearr[8];

  // Rearrange right matrix
  for (int32_t dry = 0; dry < 4; dry++) {
    __m256i rdr = right_dr[dry];
    __m256i rdr_los = _mm256_permute4x64_epi64(rdr, _MM_SHUFFLE(2, 0, 2, 0));
    __m256i rdr_his = _mm256_permute4x64_epi64(rdr, _MM_SHUFFLE(3, 1, 3, 1));

    __m256i rdr_lo_rearr = _mm256_shuffle_epi8(rdr_los, transp_mask);
    __m256i rdr_hi_rearr = _mm256_shuffle_epi8(rdr_his, transp_mask);

    rdrs_rearr[dry * 2 + 0] = rdr_lo_rearr;
    rdrs_rearr[dry * 2 + 1] = rdr_hi_rearr;
  }

  // Double-Row Y for destination matrix
  for (int32_t dry = 0; dry < 4; dry++) {
    __m256i ldr = left_dr[dry];

    __m256i ldr_slice12 = _mm256_shuffle_epi32(ldr, _MM_SHUFFLE(0, 0, 0, 0));
    __m256i ldr_slice34 = _mm256_shuffle_epi32(ldr, _MM_SHUFFLE(1, 1, 1, 1));
    __m256i ldr_slice56 = _mm256_shuffle_epi32(ldr, _MM_SHUFFLE(2, 2, 2, 2));
    __m256i ldr_slice78 = _mm256_shuffle_epi32(ldr, _MM_SHUFFLE(3, 3, 3, 3));

    __m256i prod1 = _mm256_madd_epi16(ldr_slice12, rdrs_rearr[0]);
    __m256i prod2 = _mm256_madd_epi16(ldr_slice12, rdrs_rearr[1]);
    __m256i prod3 = _mm256_madd_epi16(ldr_slice34, rdrs_rearr[2]);
    __m256i prod4 = _mm256_madd_epi16(ldr_slice34, rdrs_rearr[3]);
    __m256i prod5 = _mm256_madd_epi16(ldr_slice56, rdrs_rearr[4]);
    __m256i prod6 = _mm256_madd_epi16(ldr_slice56, rdrs_rearr[5]);
    __m256i prod7 = _mm256_madd_epi16(ldr_slice78, rdrs_rearr[6]);
    __m256i prod8 = _mm256_madd_epi16(ldr_slice78, rdrs_rearr[7]);

    __m256i lo_1 = _mm256_add_epi32(prod1, prod3);
    __m256i hi_1 = _mm256_add_epi32(prod2, prod4);
    __m256i lo_2 = _mm256_add_epi32(prod5, prod7);
    __m256i hi_2 = _mm256_add_epi32(prod6, prod8);

    __m256i lo   = _mm256_add_epi32(lo_1,  lo_2);
    __m256i hi   = _mm256_add_epi32(hi_1,  hi_2);

    __m256i lo_tr = truncate_avx2(lo, debias, shift);
    __m256i hi_tr = truncate_avx2(hi, debias, shift);

    __m256i final_dr = _mm256_packs_epi32(lo_tr, hi_tr);

    _mm256_store_si256((__m256i *)dst + dry, final_dr);
  }
}

// Multiplies A by B_T's transpose and stores result's transpose in output,
// which should be an array of 4 __m256i's
static void matmul_8x8_a_bt_t(const int16_t *a, const int16_t *b_t,
    __m256i *output, const int8_t shift)
{
  const int32_t add    = 1 << (shift - 1);
  const __m256i debias = _mm256_set1_epi32(add);

  // Keep upper row intact and swap neighboring 16-bit words in lower row
  const __m256i shuf_lorow_mask =
      _mm256_setr_epi8(0,  1,  2,  3,  4,  5,  6,  7,
                       8,  9,  10, 11, 12, 13, 14, 15,
                       18, 19, 16, 17, 22, 23, 20, 21,
                       26, 27, 24, 25, 30, 31, 28, 29);

  const __m256i *b_t_256 = (const __m256i *)b_t;

  // Dual Rows, because two 8x16b words fit in one YMM
  __m256i a_dr_0      = _mm256_load_si256((__m256i *)a + 0);
  __m256i a_dr_1      = _mm256_load_si256((__m256i *)a + 1);
  __m256i a_dr_2      = _mm256_load_si256((__m256i *)a + 2);
  __m256i a_dr_3      = _mm256_load_si256((__m256i *)a + 3);

  __m256i a_dr_0_swp  = swap_lanes(a_dr_0);
  __m256i a_dr_1_swp  = swap_lanes(a_dr_1);
  __m256i a_dr_2_swp  = swap_lanes(a_dr_2);
  __m256i a_dr_3_swp  = swap_lanes(a_dr_3);

  for (int dry = 0; dry < 4; dry++) {

    // Read dual columns of B matrix by reading rows of its transpose
    __m256i b_dc        = _mm256_load_si256(b_t_256 + dry);

    __m256i prod0       = _mm256_madd_epi16(b_dc,     a_dr_0);
    __m256i prod0_swp   = _mm256_madd_epi16(b_dc,     a_dr_0_swp);
    __m256i prod1       = _mm256_madd_epi16(b_dc,     a_dr_1);
    __m256i prod1_swp   = _mm256_madd_epi16(b_dc,     a_dr_1_swp);
    __m256i prod2       = _mm256_madd_epi16(b_dc,     a_dr_2);
    __m256i prod2_swp   = _mm256_madd_epi16(b_dc,     a_dr_2_swp);
    __m256i prod3       = _mm256_madd_epi16(b_dc,     a_dr_3);
    __m256i prod3_swp   = _mm256_madd_epi16(b_dc,     a_dr_3_swp);

    __m256i hsum0       = _mm256_hadd_epi32(prod0,    prod0_swp);
    __m256i hsum1       = _mm256_hadd_epi32(prod1,    prod1_swp);
    __m256i hsum2       = _mm256_hadd_epi32(prod2,    prod2_swp);
    __m256i hsum3       = _mm256_hadd_epi32(prod3,    prod3_swp);

    __m256i hsum2c_0    = _mm256_hadd_epi32(hsum0,    hsum1);
    __m256i hsum2c_1    = _mm256_hadd_epi32(hsum2,    hsum3);

    __m256i hsum2c_0_tr = truncate_avx2(hsum2c_0, debias, shift);
    __m256i hsum2c_1_tr = truncate_avx2(hsum2c_1, debias, shift);

    __m256i tmp_dc      = _mm256_packs_epi32(hsum2c_0_tr, hsum2c_1_tr);

    output[dry]         = _mm256_shuffle_epi8(tmp_dc, shuf_lorow_mask);
  }
}

// Multiplies A by B_T's transpose and stores the result in output, a
// 64-element int16_t array; B_T is given as an array of 4 __m256i's.
static void matmul_8x8_a_bt(const int16_t *a, const __m256i *b_t,
    int16_t *output, const int8_t shift)
{
  const int32_t add    = 1 << (shift - 1);
  const __m256i debias = _mm256_set1_epi32(add);

  const __m256i shuf_lorow_mask =
      _mm256_setr_epi8(0,  1,  2,  3,  4,  5,  6,  7,
                       8,  9,  10, 11, 12, 13, 14, 15,
                       18, 19, 16, 17, 22, 23, 20, 21,
                       26, 27, 24, 25, 30, 31, 28, 29);

  const __m256i *a_256 = (const __m256i *)a;

  __m256i b_dc_0      = b_t[0];
  __m256i b_dc_1      = b_t[1];
  __m256i b_dc_2      = b_t[2];
  __m256i b_dc_3      = b_t[3];

  __m256i b_dc_0_swp  = swap_lanes(b_dc_0);
  __m256i b_dc_1_swp  = swap_lanes(b_dc_1);
  __m256i b_dc_2_swp  = swap_lanes(b_dc_2);
  __m256i b_dc_3_swp  = swap_lanes(b_dc_3);

  for (int dry = 0; dry < 4; dry++) {
    __m256i a_dr        = _mm256_load_si256(a_256 + dry);

    __m256i prod0       = _mm256_madd_epi16(a_dr,     b_dc_0);
    __m256i prod0_swp   = _mm256_madd_epi16(a_dr,     b_dc_0_swp);
    __m256i prod1       = _mm256_madd_epi16(a_dr,     b_dc_1);
    __m256i prod1_swp   = _mm256_madd_epi16(a_dr,     b_dc_1_swp);
    __m256i prod2       = _mm256_madd_epi16(a_dr,     b_dc_2);
    __m256i prod2_swp   = _mm256_madd_epi16(a_dr,     b_dc_2_swp);
    __m256i prod3       = _mm256_madd_epi16(a_dr,     b_dc_3);
    __m256i prod3_swp   = _mm256_madd_epi16(a_dr,     b_dc_3_swp);

    __m256i hsum0       = _mm256_hadd_epi32(prod0,    prod0_swp);
    __m256i hsum1       = _mm256_hadd_epi32(prod1,    prod1_swp);
    __m256i hsum2       = _mm256_hadd_epi32(prod2,    prod2_swp);
    __m256i hsum3       = _mm256_hadd_epi32(prod3,    prod3_swp);

    __m256i hsum2c_0    = _mm256_hadd_epi32(hsum0,    hsum1);
    __m256i hsum2c_1    = _mm256_hadd_epi32(hsum2,    hsum3);

    __m256i hsum2c_0_tr = truncate_avx2(hsum2c_0, debias, shift);
    __m256i hsum2c_1_tr = truncate_avx2(hsum2c_1, debias, shift);

    __m256i tmp_dr      = _mm256_packs_epi32(hsum2c_0_tr, hsum2c_1_tr);

    __m256i final_dr    = _mm256_shuffle_epi8(tmp_dr, shuf_lorow_mask);

    _mm256_store_si256((__m256i *)output + dry, final_dr);
  }
}
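
/*
 * Per-element meaning of the madd/hadd chains above (illustrative scalar
 * sketch, not compiled; CLIP here is shorthand for saturating to the int16
 * range). With A and B_T stored row-major, each output entry is a rounded,
 * clipped dot product of a row of A with a row of B_T, i.e.
 * output = A * transpose(B_T):
 *
 *   for (int y = 0; y < 8; y++) {
 *     for (int x = 0; x < 8; x++) {
 *       int32_t sum = 0;
 *       for (int k = 0; k < 8; k++) {
 *         sum += a[y * 8 + k] * b_t[x * 8 + k];
 *       }
 *       output[y * 8 + x] = CLIP(-32768, 32767, (sum + add) >> shift);
 *     }
 *   }
 *
 * matmul_8x8_a_bt_t computes the same values but stores them transposed.
 */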

static void matrix_dct_8x8_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
  int32_t shift_1st = kvz_g_convert_to_bit[8] + 1 + (bitdepth - 8);
  int32_t shift_2nd = kvz_g_convert_to_bit[8] + 8;

  const int16_t *dct  = &kvz_g_dct_8[0][0];

  /*
   * Multiply input by the transpose of DCT matrix into tmpres, and DCT matrix
   * by tmpres - this is then our output matrix
   *
   * It's easier to implement an AVX2 matrix multiplication if you can multiply
   * the left term with the transpose of the right term. Here things are stored
   * row-wise, not column-wise, so we can effectively read DCT_T column-wise
   * into YMM registers by reading DCT row-wise. Also because of this, the
   * first multiplication is hacked to produce the transpose of the result
   * instead, since it will be used in similar fashion as the right operand
   * in the second multiplication.
   */

  __m256i tmpres[4];

  matmul_8x8_a_bt_t(input,  dct, tmpres, shift_1st);
  matmul_8x8_a_bt  (dct, tmpres, output, shift_2nd);
}
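
/*
 * Spelled out, with D the 8x8 DCT matrix and X the input block:
 *
 *   matmul_8x8_a_bt_t(X, D, tmpres, ...)   computes  tmpres = (X * D^T)^T,
 *   matmul_8x8_a_bt  (D, tmpres, out, ...) computes  out    = D * tmpres^T
 *                                                           = D * (X * D^T)
 *                                                           = D * X * D^T,
 *
 * which is the separable 2D forward DCT (with per-stage rounding/clipping).
 */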

static void matrix_idct_8x8_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
  int32_t shift_1st = 7;
  int32_t shift_2nd = 12 - (bitdepth - 8);
  ALIGNED(64) int16_t tmp[8 * 8];

  const int16_t *tdct = &kvz_g_dct_8_t[0][0];
  const int16_t *dct  = &kvz_g_dct_8  [0][0];

  mul_clip_matrix_8x8_avx2(tdct, input, tmp,    shift_1st);
  mul_clip_matrix_8x8_avx2(tmp,  dct,   output, shift_2nd);

  /*
   * Because:
   * out = tdct * input * dct = tdct * (input * dct) = tdct * (input * transpose(tdct))
   * This could almost be done this way:
   *
   * matmul_8x8_a_bt_t(input, tdct, debias1, shift_1st, tmp);
   * matmul_8x8_a_bt  (tdct,  tmp,  debias2, shift_2nd, output);
   *
   * But not really, since it will fall victim to some very occasional
   * rounding errors. Sadly.
   */
}

static void matmul_16x16_a_bt(const __m256i *a,
                              const __m256i *b_t,
                                    __m256i *output,
                              const int32_t  shift)
{
  const int32_t add    = 1 << (shift - 1);
  const __m256i debias = _mm256_set1_epi32(add);

  for (int32_t y = 0; y < 16; y++) {
    __m256i a_r = a[y];
    __m256i results_32[2];

    for (int32_t fco = 0; fco < 2; fco++) {
      // First iteration reads columns 0..3 and 8..11 of B (rows of B_T);
      // the second reads 4..7 and 12..15
      __m256i bt_c0  = b_t[fco * 4 + 0];
      __m256i bt_c1  = b_t[fco * 4 + 1];
      __m256i bt_c2  = b_t[fco * 4 + 2];
      __m256i bt_c3  = b_t[fco * 4 + 3];
      __m256i bt_c8  = b_t[fco * 4 + 8];
      __m256i bt_c9  = b_t[fco * 4 + 9];
      __m256i bt_c10 = b_t[fco * 4 + 10];
      __m256i bt_c11 = b_t[fco * 4 + 11];

      __m256i p0  = _mm256_madd_epi16(a_r, bt_c0);
      __m256i p1  = _mm256_madd_epi16(a_r, bt_c1);
      __m256i p2  = _mm256_madd_epi16(a_r, bt_c2);
      __m256i p3  = _mm256_madd_epi16(a_r, bt_c3);
      __m256i p8  = _mm256_madd_epi16(a_r, bt_c8);
      __m256i p9  = _mm256_madd_epi16(a_r, bt_c9);
      __m256i p10 = _mm256_madd_epi16(a_r, bt_c10);
      __m256i p11 = _mm256_madd_epi16(a_r, bt_c11);

      // Combine low lanes from P0 and P8, high lanes from them, and the same
      // with P1:P9 and so on
      __m256i p0l = _mm256_permute2x128_si256(p0, p8,  0x20);
      __m256i p0h = _mm256_permute2x128_si256(p0, p8,  0x31);
      __m256i p1l = _mm256_permute2x128_si256(p1, p9,  0x20);
      __m256i p1h = _mm256_permute2x128_si256(p1, p9,  0x31);
      __m256i p2l = _mm256_permute2x128_si256(p2, p10, 0x20);
      __m256i p2h = _mm256_permute2x128_si256(p2, p10, 0x31);
      __m256i p3l = _mm256_permute2x128_si256(p3, p11, 0x20);
      __m256i p3h = _mm256_permute2x128_si256(p3, p11, 0x31);

      __m256i s0  = _mm256_add_epi32(p0l, p0h);
      __m256i s1  = _mm256_add_epi32(p1l, p1h);
      __m256i s2  = _mm256_add_epi32(p2l, p2h);
      __m256i s3  = _mm256_add_epi32(p3l, p3h);

      __m256i s4  = _mm256_unpacklo_epi64(s0, s1);
      __m256i s5  = _mm256_unpackhi_epi64(s0, s1);
      __m256i s6  = _mm256_unpacklo_epi64(s2, s3);
      __m256i s7  = _mm256_unpackhi_epi64(s2, s3);

      __m256i s8  = _mm256_add_epi32(s4, s5);
      __m256i s9  = _mm256_add_epi32(s6, s7);

      __m256i res = _mm256_hadd_epi32(s8, s9);
      results_32[fco] = truncate_avx2(res, debias, shift);
    }
    output[y] = _mm256_packs_epi32(results_32[0], results_32[1]);
  }
}

// NOTE: The strides measured by s_stride_log2 and d_stride_log2 are in units
// of 16 coeffs, not 1!
static void transpose_16x16_stride(const int16_t *src,
                                         int16_t *dst,
                                         uint8_t  s_stride_log2,
                                         uint8_t  d_stride_log2)
{
  __m256i tmp_128[16];
  for (uint32_t i = 0; i < 16; i += 8) {

    // After every n-bit unpack, 2n-bit units in the vectors will be in
    // correct order. Pair words first, then dwords, then qwords. After that,
    // whole lanes will be correct.
    __m256i tmp_32[8];
    __m256i tmp_64[8];

    __m256i m[8] = {
      _mm256_load_si256((const __m256i *)src + ((i + 0) << s_stride_log2)),
      _mm256_load_si256((const __m256i *)src + ((i + 1) << s_stride_log2)),
      _mm256_load_si256((const __m256i *)src + ((i + 2) << s_stride_log2)),
      _mm256_load_si256((const __m256i *)src + ((i + 3) << s_stride_log2)),
      _mm256_load_si256((const __m256i *)src + ((i + 4) << s_stride_log2)),
      _mm256_load_si256((const __m256i *)src + ((i + 5) << s_stride_log2)),
      _mm256_load_si256((const __m256i *)src + ((i + 6) << s_stride_log2)),
      _mm256_load_si256((const __m256i *)src + ((i + 7) << s_stride_log2)),
    };

    tmp_32[0]      = _mm256_unpacklo_epi16(     m[0],      m[1]);
    tmp_32[1]      = _mm256_unpacklo_epi16(     m[2],      m[3]);
    tmp_32[2]      = _mm256_unpackhi_epi16(     m[0],      m[1]);
    tmp_32[3]      = _mm256_unpackhi_epi16(     m[2],      m[3]);

    tmp_32[4]      = _mm256_unpacklo_epi16(     m[4],      m[5]);
    tmp_32[5]      = _mm256_unpacklo_epi16(     m[6],      m[7]);
    tmp_32[6]      = _mm256_unpackhi_epi16(     m[4],      m[5]);
    tmp_32[7]      = _mm256_unpackhi_epi16(     m[6],      m[7]);


    tmp_64[0]      = _mm256_unpacklo_epi32(tmp_32[0], tmp_32[1]);
    tmp_64[1]      = _mm256_unpacklo_epi32(tmp_32[2], tmp_32[3]);
    tmp_64[2]      = _mm256_unpackhi_epi32(tmp_32[0], tmp_32[1]);
    tmp_64[3]      = _mm256_unpackhi_epi32(tmp_32[2], tmp_32[3]);

    tmp_64[4]      = _mm256_unpacklo_epi32(tmp_32[4], tmp_32[5]);
    tmp_64[5]      = _mm256_unpacklo_epi32(tmp_32[6], tmp_32[7]);
    tmp_64[6]      = _mm256_unpackhi_epi32(tmp_32[4], tmp_32[5]);
    tmp_64[7]      = _mm256_unpackhi_epi32(tmp_32[6], tmp_32[7]);


    tmp_128[i + 0] = _mm256_unpacklo_epi64(tmp_64[0], tmp_64[4]);
    tmp_128[i + 1] = _mm256_unpackhi_epi64(tmp_64[0], tmp_64[4]);
    tmp_128[i + 2] = _mm256_unpacklo_epi64(tmp_64[2], tmp_64[6]);
    tmp_128[i + 3] = _mm256_unpackhi_epi64(tmp_64[2], tmp_64[6]);

    tmp_128[i + 4] = _mm256_unpacklo_epi64(tmp_64[1], tmp_64[5]);
    tmp_128[i + 5] = _mm256_unpackhi_epi64(tmp_64[1], tmp_64[5]);
    tmp_128[i + 6] = _mm256_unpacklo_epi64(tmp_64[3], tmp_64[7]);
    tmp_128[i + 7] = _mm256_unpackhi_epi64(tmp_64[3], tmp_64[7]);
  }

  for (uint32_t i = 0; i < 8; i++) {
    uint32_t loid     = i + 0;
    uint32_t hiid     = i + 8;

    uint32_t dst_loid = loid << d_stride_log2;
    uint32_t dst_hiid = hiid << d_stride_log2;

    __m256i lo       = tmp_128[loid];
    __m256i hi       = tmp_128[hiid];
    __m256i final_lo = _mm256_permute2x128_si256(lo, hi, 0x20);
    __m256i final_hi = _mm256_permute2x128_si256(lo, hi, 0x31);

    _mm256_store_si256((__m256i *)dst + dst_loid, final_lo);
    _mm256_store_si256((__m256i *)dst + dst_hiid, final_hi);
  }
}
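
/*
 * Scalar reference (sketch, not compiled) for transpose_16x16_stride; the
 * strides are log2 of the row stride in units of 16 coefficients, matching
 * the note above, so the stride in int16 elements is 16 << stride_log2.
 *
 *   static void transpose_16x16_stride_scalar(const int16_t *src, int16_t *dst,
 *                                             uint8_t s_stride_log2, uint8_t d_stride_log2)
 *   {
 *     const uint32_t s_stride = 16 << s_stride_log2;
 *     const uint32_t d_stride = 16 << d_stride_log2;
 *     for (uint32_t y = 0; y < 16; y++) {
 *       for (uint32_t x = 0; x < 16; x++) {
 *         dst[x * d_stride + y] = src[y * s_stride + x];
 *       }
 *     }
 *   }
 */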

static void transpose_16x16(const int16_t *src, int16_t *dst)
{
  transpose_16x16_stride(src, dst, 0, 0);
}

static __m256i truncate_inv(__m256i v, int32_t shift)
{
  int32_t add = 1 << (shift - 1);

  __m256i debias  = _mm256_set1_epi32(add);
  __m256i v2      = _mm256_add_epi32 (v,  debias);
  __m256i trunced = _mm256_srai_epi32(v2, shift);
  return  trunced;
}

static __m256i extract_odds(__m256i v)
{
  // 0 1 2 3 4 5 6 7 | 8 9 a b c d e f => 1 3 5 7 1 3 5 7 | 9 b d f 9 b d f
  const __m256i oddmask = _mm256_setr_epi8( 2,  3,  6,  7, 10, 11, 14, 15,
                                            2,  3,  6,  7, 10, 11, 14, 15,
                                            2,  3,  6,  7, 10, 11, 14, 15,
                                            2,  3,  6,  7, 10, 11, 14, 15);

  __m256i tmp = _mm256_shuffle_epi8 (v,   oddmask);
  return _mm256_permute4x64_epi64   (tmp, _MM_SHUFFLE(3, 1, 2, 0));
}

static __m256i extract_combine_odds(__m256i v0, __m256i v1)
{
  // 0 1 2 3 4 5 6 7 | 8 9 a b c d e f => 1 3 5 7 1 3 5 7 | 9 b d f 9 b d f
  const __m256i oddmask = _mm256_setr_epi8( 2,  3,  6,  7, 10, 11, 14, 15,
                                            2,  3,  6,  7, 10, 11, 14, 15,
                                            2,  3,  6,  7, 10, 11, 14, 15,
                                            2,  3,  6,  7, 10, 11, 14, 15);

  __m256i tmp0 = _mm256_shuffle_epi8(v0,   oddmask);
  __m256i tmp1 = _mm256_shuffle_epi8(v1,   oddmask);

  __m256i tmp2 = _mm256_blend_epi32 (tmp0, tmp1, 0xcc); // 1100 1100

  return _mm256_permute4x64_epi64   (tmp2, _MM_SHUFFLE(3, 1, 2, 0));
}

// Extract items 2, 6, A and E from first four columns of DCT, order them as
// follows:
// D0,2 D0,6 D1,2 D1,6 D1,a D1,e D0,a D0,e | D2,2 D2,6 D3,2 D3,6 D3,a D3,e D2,a D2,e
static __m256i extract_26ae(const __m256i *tdct)
{
  // 02 03 22 23 06 07 26 27 | 0a 0b 2a 2b 0e 0f 2e 2f
  // =>
  // 02 06 22 26 02 06 22 26 | 2a 2e 0a 0e 2a 2e 0a 0e
  const __m256i evens_mask = _mm256_setr_epi8( 0,  1,  8,  9,  4,  5, 12, 13,
                                               0,  1,  8,  9,  4,  5, 12, 13,
                                               4,  5, 12, 13,  0,  1,  8,  9,
                                               4,  5, 12, 13,  0,  1,  8,  9);

  __m256i shufd_0 = _mm256_shuffle_epi32(tdct[0], _MM_SHUFFLE(2, 3, 0, 1));
  __m256i shufd_2 = _mm256_shuffle_epi32(tdct[2], _MM_SHUFFLE(2, 3, 0, 1));

  __m256i cmbd_01 = _mm256_blend_epi32(shufd_0, tdct[1], 0xaa); // 1010 1010
  __m256i cmbd_23 = _mm256_blend_epi32(shufd_2, tdct[3], 0xaa); // 1010 1010

  __m256i evens_01 = _mm256_shuffle_epi8(cmbd_01, evens_mask);
  __m256i evens_23 = _mm256_shuffle_epi8(cmbd_23, evens_mask);

  __m256i evens_0123 = _mm256_unpacklo_epi64(evens_01, evens_23);

  return _mm256_permute4x64_epi64(evens_0123, _MM_SHUFFLE(3, 1, 2, 0));
}

// 2 6 2 6 a e a e | 2 6 2 6 a e a e
static __m256i extract_26ae_vec(__m256i col)
{
  const __m256i mask_26ae = _mm256_set1_epi32(0x0d0c0504);

  // 2 6 2 6 2 6 2 6 | a e a e a e a e
  __m256i reord = _mm256_shuffle_epi8     (col,   mask_26ae);
  __m256i final = _mm256_permute4x64_epi64(reord, _MM_SHUFFLE(3, 1, 2, 0));
  return  final;
}

// D00 D80 D01 D81 D41 Dc1 D40 Dc0 | D40 Dc0 D41 Dc1 D01 D81 D00 D80
static __m256i extract_d048c(const __m256i *tdct)
{
  const __m256i final_shuf = _mm256_setr_epi8( 0,  1,  8,  9,  2,  3, 10, 11,
                                               6,  7, 14, 15,  4,  5, 12, 13,
                                               4,  5, 12, 13,  6,  7, 14, 15,
                                               2,  3, 10, 11,  0,  1,  8,  9);
  __m256i c0 = tdct[0];
  __m256i c1 = tdct[1];

  __m256i c1_2  = _mm256_slli_epi32       (c1,    16);
  __m256i cmbd  = _mm256_blend_epi16      (c0,    c1_2, 0x22); // 0010 0010
  __m256i cmbd2 = _mm256_shuffle_epi32    (cmbd,  _MM_SHUFFLE(2, 0, 2, 0));
  __m256i cmbd3 = _mm256_permute4x64_epi64(cmbd2, _MM_SHUFFLE(3, 1, 2, 0));
  __m256i final = _mm256_shuffle_epi8     (cmbd3, final_shuf);

  return final;
}

// 0 8 0 8 4 c 4 c | 4 c 4 c 0 8 0 8
static __m256i extract_d048c_vec(__m256i col)
{
  const __m256i shufmask = _mm256_setr_epi8( 0,  1,  0,  1,  8,  9,  8,  9,
                                             8,  9,  8,  9,  0,  1,  0,  1,
                                             0,  1,  0,  1,  8,  9,  8,  9,
                                             8,  9,  8,  9,  0,  1,  0,  1);

  __m256i col_db4s = _mm256_shuffle_epi8     (col, shufmask);
  __m256i col_los  = _mm256_permute4x64_epi64(col_db4s, _MM_SHUFFLE(1, 1, 0, 0));
  __m256i col_his  = _mm256_permute4x64_epi64(col_db4s, _MM_SHUFFLE(3, 3, 2, 2));

  __m256i final    = _mm256_unpacklo_epi16   (col_los,  col_his);
  return final;
}

static void partial_butterfly_inverse_16_avx2(const int16_t *src, int16_t *dst, int32_t shift)
{
  __m256i tsrc[16];

  const uint32_t width = 16;

  const int16_t *tdct = &kvz_g_dct_16_t[0][0];

  const __m256i  eo_signmask = _mm256_setr_epi32( 1,  1,  1,  1, -1, -1, -1, -1);
  const __m256i eeo_signmask = _mm256_setr_epi32( 1,  1, -1, -1, -1, -1,  1,  1);
  const __m256i   o_signmask = _mm256_set1_epi32(-1);

  const __m256i final_shufmask = _mm256_setr_epi8( 0,  1,  2,  3,  4,  5,  6,  7,
                                                   8,  9, 10, 11, 12, 13, 14, 15,
                                                   6,  7,  4,  5,  2,  3,  0,  1,
                                                  14, 15, 12, 13, 10, 11,  8,  9);
  transpose_16x16(src, (int16_t *)tsrc);

  const __m256i dct_cols[8] = {
    _mm256_load_si256((const __m256i *)tdct + 0),
    _mm256_load_si256((const __m256i *)tdct + 1),
    _mm256_load_si256((const __m256i *)tdct + 2),
    _mm256_load_si256((const __m256i *)tdct + 3),
    _mm256_load_si256((const __m256i *)tdct + 4),
    _mm256_load_si256((const __m256i *)tdct + 5),
    _mm256_load_si256((const __m256i *)tdct + 6),
    _mm256_load_si256((const __m256i *)tdct + 7),
  };

  // These contain: D1,0 D3,0 D5,0 D7,0 D9,0 Db,0 Dd,0 Df,0 | D1,4 D3,4 D5,4 D7,4 D9,4 Db,4 Dd,4 Df,4
  //                D1,1 D3,1 D5,1 D7,1 D9,1 Db,1 Dd,1 Df,1 | D1,5 D3,5 D5,5 D7,5 D9,5 Db,5 Dd,5 Df,5
  //                D1,2 D3,2 D5,2 D7,2 D9,2 Db,2 Dd,2 Df,2 | D1,6 D3,6 D5,6 D7,6 D9,6 Db,6 Dd,6 Df,6
  //                D1,3 D3,3 D5,3 D7,3 D9,3 Db,3 Dd,3 Df,3 | D1,7 D3,7 D5,7 D7,7 D9,7 Db,7 Dd,7 Df,7
  __m256i dct_col_odds[4];
  for (uint32_t j = 0; j < 4; j++) {
    dct_col_odds[j] = extract_combine_odds(dct_cols[j + 0], dct_cols[j + 4]);
  }
  for (uint32_t j = 0; j < width; j++) {
    __m256i col = tsrc[j];
    __m256i odds = extract_odds(col);

    __m256i o04   = _mm256_madd_epi16           (odds,     dct_col_odds[0]);
    __m256i o15   = _mm256_madd_epi16           (odds,     dct_col_odds[1]);
    __m256i o26   = _mm256_madd_epi16           (odds,     dct_col_odds[2]);
    __m256i o37   = _mm256_madd_epi16           (odds,     dct_col_odds[3]);

    __m256i o0145 = _mm256_hadd_epi32           (o04,      o15);
    __m256i o2367 = _mm256_hadd_epi32           (o26,      o37);

    __m256i o     = _mm256_hadd_epi32           (o0145,    o2367);

    // D0,2 D0,6 D1,2 D1,6 D1,a D1,e D0,a D0,e | D2,2 D2,6 D3,2 D3,6 D3,a D3,e D2,a D2,e
    __m256i d_db2 = extract_26ae(dct_cols);

    // 2 6 2 6 a e a e | 2 6 2 6 a e a e
    __m256i t_db2 = extract_26ae_vec            (col);

    __m256i eo_parts  = _mm256_madd_epi16       (d_db2,    t_db2);
    __m256i eo_parts2 = _mm256_shuffle_epi32    (eo_parts, _MM_SHUFFLE(0, 1, 2, 3));

    // EO0 EO1 EO1 EO0 | EO2 EO3 EO3 EO2
    __m256i eo        = _mm256_add_epi32        (eo_parts, eo_parts2);
    __m256i eo2       = _mm256_permute4x64_epi64(eo,       _MM_SHUFFLE(1, 3, 2, 0));
    __m256i eo3       = _mm256_sign_epi32       (eo2,      eo_signmask);

    __m256i d_db4     = extract_d048c           (dct_cols);
    __m256i t_db4     = extract_d048c_vec       (col);
    __m256i eee_eeo   = _mm256_madd_epi16       (d_db4,   t_db4);

    __m256i eee_eee   = _mm256_permute4x64_epi64(eee_eeo,  _MM_SHUFFLE(3, 0, 3, 0));
    __m256i eeo_eeo1  = _mm256_permute4x64_epi64(eee_eeo,  _MM_SHUFFLE(1, 2, 1, 2));

    __m256i eeo_eeo2  = _mm256_sign_epi32       (eeo_eeo1, eeo_signmask);

    // EE0 EE1 EE2 EE3 | EE3 EE2 EE1 EE0
    __m256i ee        = _mm256_add_epi32        (eee_eee,  eeo_eeo2);
    __m256i e         = _mm256_add_epi32        (ee,       eo3);

    __m256i o_neg     = _mm256_sign_epi32       (o,        o_signmask);
    __m256i o_lo      = _mm256_blend_epi32      (o,        o_neg, 0xf0); // 1111 0000
    __m256i o_hi      = _mm256_blend_epi32      (o,        o_neg, 0x0f); // 0000 1111

    __m256i res_lo    = _mm256_add_epi32        (e,        o_lo);
    __m256i res_hi    = _mm256_add_epi32        (e,        o_hi);
    __m256i res_hi2   = _mm256_permute4x64_epi64(res_hi,   _MM_SHUFFLE(1, 0, 3, 2));

    __m256i res_lo_t  = truncate_inv(res_lo,  shift);
    __m256i res_hi_t  = truncate_inv(res_hi2, shift);

    __m256i res_16_1  = _mm256_packs_epi32      (res_lo_t, res_hi_t);
    __m256i final     = _mm256_shuffle_epi8     (res_16_1, final_shufmask);

    _mm256_store_si256((__m256i *)dst + j, final);
  }
}
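
/*
 * Scalar sketch of the even/odd decomposition vectorized above (illustrative
 * only, following the structure of the HEVC reference partial butterfly; D is
 * the 16x16 forward DCT matrix, col[] one 16-coefficient input column, add is
 * 1 << (shift - 1), and CLIP is shorthand for saturating to the int16 range):
 *
 *   int32_t o[8], eo[4], eee[2], eeo[2], ee[4], e[8];
 *   for (int k = 0; k < 8; k++) {           // odd rows 1, 3, 5, ..., 15
 *     o[k] = 0;
 *     for (int l = 0; l < 8; l++)
 *       o[k] += D[2 * l + 1][k] * col[2 * l + 1];
 *   }
 *   for (int k = 0; k < 4; k++) {           // rows 2, 6, 10, 14
 *     eo[k] = 0;
 *     for (int l = 0; l < 4; l++)
 *       eo[k] += D[4 * l + 2][k] * col[4 * l + 2];
 *   }
 *   eeo[0] = D[4][0] * col[4] + D[12][0] * col[12];   // rows 4 and 12
 *   eeo[1] = D[4][1] * col[4] + D[12][1] * col[12];
 *   eee[0] = D[0][0] * col[0] + D[ 8][0] * col[ 8];   // rows 0 and 8
 *   eee[1] = D[0][1] * col[0] + D[ 8][1] * col[ 8];
 *
 *   ee[0] = eee[0] + eeo[0];  ee[3] = eee[0] - eeo[0];
 *   ee[1] = eee[1] + eeo[1];  ee[2] = eee[1] - eeo[1];
 *   for (int k = 0; k < 4; k++) {
 *     e[k]     = ee[k]     + eo[k];
 *     e[k + 4] = ee[3 - k] - eo[3 - k];
 *   }
 *   for (int k = 0; k < 8; k++) {
 *     out[k]     = CLIP(-32768, 32767, (e[k]     + o[k]     + add) >> shift);
 *     out[k + 8] = CLIP(-32768, 32767, (e[7 - k] - o[7 - k] + add) >> shift);
 *   }
 */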

static void matrix_idct_16x16_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
  int32_t shift_1st = 7;
  int32_t shift_2nd = 12 - (bitdepth - 8);
  ALIGNED(64) int16_t tmp[16 * 16];

  partial_butterfly_inverse_16_avx2(input, tmp,    shift_1st);
  partial_butterfly_inverse_16_avx2(tmp,   output, shift_2nd);
}

static void matrix_dct_16x16_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
{
  int32_t shift_1st = kvz_g_convert_to_bit[16] + 1 + (bitdepth - 8);
  int32_t shift_2nd = kvz_g_convert_to_bit[16] + 8;

  const int16_t *dct  = &kvz_g_dct_16[0][0];

  /*
   * Multiply input by the transpose of DCT matrix into tmpres, and DCT matrix
   * by tmpres - this is then our output matrix
   *
   * It's easier to implement an AVX2 matrix multiplication if you can multiply
   * the left term with the transpose of the right term. Here things are stored
   * row-wise, not column-wise, so we can effectively read DCT_T column-wise
   * into YMM registers by reading DCT row-wise. Also because of this, the
   * first multiplication is hacked to produce the transpose of the result
   * instead, since it will be used in similar fashion as the right operand
   * in the second multiplication.
   */

  const __m256i *d_v = (const __m256i *)dct;
  const __m256i *i_v = (const __m256i *)input;
        __m256i *o_v = (      __m256i *)output;
  __m256i tmp[16];

  // Hack! (A * B^T)^T = B * A^T, so we can dispense with the
  // transpose-producing multiply completely
  matmul_16x16_a_bt(d_v, i_v, tmp, shift_1st);
  matmul_16x16_a_bt(d_v, tmp, o_v, shift_2nd);
}
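
/*
 * Spelled out, with D the 16x16 DCT matrix and X the input block:
 *
 *   matmul_16x16_a_bt(D, X,   tmp, ...) computes  tmp = D * X^T   = (X * D^T)^T,
 *   matmul_16x16_a_bt(D, tmp, out, ...) computes  out = D * tmp^T = D * (X * D^T)
 *                                                                 = D * X * D^T,
 *
 * so the same A * B^T kernel handles both stages without an explicit transpose.
 */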

// 32x32 matrix multiplication with value clipping.
// Parameters: Two 32x32 matrices containing 16-bit values in consecutive addresses,
//             destination for the result and the shift value for clipping.
static void mul_clip_matrix_32x32_avx2(const int16_t *left,
                                       const int16_t *right,
                                             int16_t *dst,
                                       const int32_t  shift)
{
  const int32_t add    = 1 << (shift - 1);
  const __m256i debias = _mm256_set1_epi32(add);

  const uint32_t *l_32  = (const uint32_t *)left;
  const __m256i  *r_v   = (const __m256i *)right;
        __m256i  *dst_v = (      __m256i *)dst;

  __m256i accu[128] = {_mm256_setzero_si256()};
  size_t i, j;

  for (j = 0; j < 64; j += 4) {
    const __m256i r0 = r_v[j + 0];
    const __m256i r1 = r_v[j + 1];
    const __m256i r2 = r_v[j + 2];
    const __m256i r3 = r_v[j + 3];

    __m256i r02l   = _mm256_unpacklo_epi16(r0, r2);
    __m256i r02h   = _mm256_unpackhi_epi16(r0, r2);
    __m256i r13l   = _mm256_unpacklo_epi16(r1, r3);
    __m256i r13h   = _mm256_unpackhi_epi16(r1, r3);

    __m256i r02_07 = _mm256_permute2x128_si256(r02l, r02h, 0x20);
    __m256i r02_8f = _mm256_permute2x128_si256(r02l, r02h, 0x31);

    __m256i r13_07 = _mm256_permute2x128_si256(r13l, r13h, 0x20);
    __m256i r13_8f = _mm256_permute2x128_si256(r13l, r13h, 0x31);

    for (i = 0; i < 32; i += 2) {
      size_t acc_base = i << 2;

      uint32_t curr_e    = l_32[(i + 0) * (32 / 2) + (j >> 2)];
      uint32_t curr_o    = l_32[(i + 1) * (32 / 2) + (j >> 2)];

      __m256i even       = _mm256_set1_epi32(curr_e);
      __m256i odd        = _mm256_set1_epi32(curr_o);

      __m256i p_e0       = _mm256_madd_epi16(even, r02_07);
      __m256i p_e1       = _mm256_madd_epi16(even, r02_8f);
      __m256i p_e2       = _mm256_madd_epi16(even, r13_07);
      __m256i p_e3       = _mm256_madd_epi16(even, r13_8f);

      __m256i p_o0       = _mm256_madd_epi16(odd,  r02_07);
      __m256i p_o1       = _mm256_madd_epi16(odd,  r02_8f);
      __m256i p_o2       = _mm256_madd_epi16(odd,  r13_07);
      __m256i p_o3       = _mm256_madd_epi16(odd,  r13_8f);

      accu[acc_base + 0] = _mm256_add_epi32 (p_e0, accu[acc_base + 0]);
      accu[acc_base + 1] = _mm256_add_epi32 (p_e1, accu[acc_base + 1]);
      accu[acc_base + 2] = _mm256_add_epi32 (p_e2, accu[acc_base + 2]);
      accu[acc_base + 3] = _mm256_add_epi32 (p_e3, accu[acc_base + 3]);

      accu[acc_base + 4] = _mm256_add_epi32 (p_o0, accu[acc_base + 4]);
      accu[acc_base + 5] = _mm256_add_epi32 (p_o1, accu[acc_base + 5]);
      accu[acc_base + 6] = _mm256_add_epi32 (p_o2, accu[acc_base + 6]);
      accu[acc_base + 7] = _mm256_add_epi32 (p_o3, accu[acc_base + 7]);
    }
  }

  for (i = 0; i < 32; i++) {
    size_t acc_base = i << 2;
    size_t dst_base = i << 1;

    __m256i q0  = truncate_avx2(accu[acc_base + 0], debias, shift);
    __m256i q1  = truncate_avx2(accu[acc_base + 1], debias, shift);
    __m256i q2  = truncate_avx2(accu[acc_base + 2], debias, shift);
    __m256i q3  = truncate_avx2(accu[acc_base + 3], debias, shift);

    __m256i h01 = _mm256_packs_epi32(q0, q1);
    __m256i h23 = _mm256_packs_epi32(q2, q3);

            h01 = _mm256_permute4x64_epi64(h01, _MM_SHUFFLE(3, 1, 2, 0));
            h23 = _mm256_permute4x64_epi64(h23, _MM_SHUFFLE(3, 1, 2, 0));

    _mm256_store_si256(dst_v + dst_base + 0, h01);
    _mm256_store_si256(dst_v + dst_base + 1, h23);
  }
}

// Macro that generates 2D transform functions with clipping values.
// Sets correct shift values and matrices according to transform type and
// block size. Performs matrix multiplication horizontally and vertically.
#define TRANSFORM(type, n) static void matrix_ ## type ## _ ## n ## x ## n ## _avx2(int8_t bitdepth, const int16_t *input, int16_t *output)\
{\
  int32_t shift_1st = kvz_g_convert_to_bit[n] + 1 + (bitdepth - 8); \
  int32_t shift_2nd = kvz_g_convert_to_bit[n] + 8; \
  ALIGNED(64) int16_t tmp[n * n];\
  const int16_t *tdct = &kvz_g_ ## type ## _ ## n ## _t[0][0];\
  const int16_t *dct = &kvz_g_ ## type ## _ ## n [0][0];\
\
  mul_clip_matrix_ ## n ## x ## n ## _avx2(input, tdct, tmp, shift_1st);\
  mul_clip_matrix_ ## n ## x ## n ## _avx2(dct, tmp, output, shift_2nd);\
}

// Macro that generates 2D inverse transform functions with clipping values.
// Sets correct shift values and matrices according to transform type and
// block size. Performs matrix multiplication horizontally and vertically.
#define ITRANSFORM(type, n) \
static void matrix_i ## type ## _## n ## x ## n ## _avx2(int8_t bitdepth, const int16_t *input, int16_t *output)\
{\
  int32_t shift_1st = 7; \
  int32_t shift_2nd = 12 - (bitdepth - 8); \
  ALIGNED(64) int16_t tmp[n * n];\
  const int16_t *tdct = &kvz_g_ ## type ## _ ## n ## _t[0][0];\
  const int16_t *dct = &kvz_g_ ## type ## _ ## n [0][0];\
\
  mul_clip_matrix_ ## n ## x ## n ## _avx2(tdct, input, tmp, shift_1st);\
  mul_clip_matrix_ ## n ## x ## n ## _avx2(tmp, dct, output, shift_2nd);\
}

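/*
 * For reference, TRANSFORM(dct, 32) below expands to roughly the following
 * (whitespace adjusted):
 *
 *   static void matrix_dct_32x32_avx2(int8_t bitdepth, const int16_t *input, int16_t *output)
 *   {
 *     int32_t shift_1st = kvz_g_convert_to_bit[32] + 1 + (bitdepth - 8);
 *     int32_t shift_2nd = kvz_g_convert_to_bit[32] + 8;
 *     ALIGNED(64) int16_t tmp[32 * 32];
 *     const int16_t *tdct = &kvz_g_dct_32_t[0][0];
 *     const int16_t *dct  = &kvz_g_dct_32[0][0];
 *
 *     mul_clip_matrix_32x32_avx2(input, tdct, tmp, shift_1st);
 *     mul_clip_matrix_32x32_avx2(dct, tmp, output, shift_2nd);
 *   }
 */
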
// Ha, we've got a tailored implementation for these
// TRANSFORM(dst, 4);
// ITRANSFORM(dst, 4);
// TRANSFORM(dct, 4);
// ITRANSFORM(dct, 4);
// TRANSFORM(dct, 8);
// ITRANSFORM(dct, 8);
// TRANSFORM(dct, 16);
// ITRANSFORM(dct, 16);

// Generate all the transform functions

TRANSFORM(dct, 32);
ITRANSFORM(dct, 32);

#endif // KVZ_BIT_DEPTH == 8
#endif //COMPILE_INTEL_AVX2

int kvz_strategy_register_dct_avx2(void* opaque, uint8_t bitdepth)
{
  bool success = true;
#if COMPILE_INTEL_AVX2
#if KVZ_BIT_DEPTH == 8
  if (bitdepth == 8){
    success &= kvz_strategyselector_register(opaque, "fast_forward_dst_4x4", "avx2", 40, &matrix_dst_4x4_avx2);

    success &= kvz_strategyselector_register(opaque, "dct_4x4", "avx2", 40, &matrix_dct_4x4_avx2);
    success &= kvz_strategyselector_register(opaque, "dct_8x8", "avx2", 40, &matrix_dct_8x8_avx2);
    success &= kvz_strategyselector_register(opaque, "dct_16x16", "avx2", 40, &matrix_dct_16x16_avx2);
    success &= kvz_strategyselector_register(opaque, "dct_32x32", "avx2", 40, &matrix_dct_32x32_avx2);

    success &= kvz_strategyselector_register(opaque, "fast_inverse_dst_4x4", "avx2", 40, &matrix_idst_4x4_avx2);

    success &= kvz_strategyselector_register(opaque, "idct_4x4", "avx2", 40, &matrix_idct_4x4_avx2);
    success &= kvz_strategyselector_register(opaque, "idct_8x8", "avx2", 40, &matrix_idct_8x8_avx2);
    success &= kvz_strategyselector_register(opaque, "idct_16x16", "avx2", 40, &matrix_idct_16x16_avx2);
    success &= kvz_strategyselector_register(opaque, "idct_32x32", "avx2", 40, &matrix_idct_32x32_avx2);
  }
#endif // KVZ_BIT_DEPTH == 8
#endif //COMPILE_INTEL_AVX2
  return success;
}