/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <immintrin.h>
#include <assert.h>

#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/x86/convolve_avx2.h"
#include "aom_dsp/x86/convolve_common_intrin.h"
#include "aom_dsp/x86/convolve_sse4_1.h"
#include "aom_dsp/x86/synonyms.h"
#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/aom_filter.h"
#include "av1/common/convolve.h"

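// Compound ("dist_wtd") 2D-copy path for high bit-depth input: no filtering is
// needed, so each source pixel is only shifted up into compound-buffer
// precision. With do_average the shifted pixels are blended with the values
// already in conv_params->dst (weighted by fwd_offset/bck_offset when
// use_dist_wtd_comp_avg is set), rounded, clipped to the bit depth and written
// to dst0; otherwise the offset result is stored into the compound buffer.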
void av1_highbd_dist_wtd_convolve_2d_copy_avx2(const uint16_t *src,
                                               int src_stride, uint16_t *dst0,
                                               int dst_stride0, int w, int h,
                                               ConvolveParams *conv_params,
                                               int bd) {
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;

  const int bits =
      FILTER_BITS * 2 - conv_params->round_1 - conv_params->round_0;
  const __m128i left_shift = _mm_cvtsi32_si128(bits);
  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
  const int w0 = conv_params->fwd_offset;
  const int w1 = conv_params->bck_offset;
  const __m256i wt0 = _mm256_set1_epi32(w0);
  const __m256i wt1 = _mm256_set1_epi32(w1);
  const __m256i zero = _mm256_setzero_si256();
  int i, j;

  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi32(offset);
  const __m256i offset_const_16b = _mm256_set1_epi16(offset);
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi32((1 << rounding_shift) >> 1);
  const __m256i clip_pixel_to_bd =
      _mm256_set1_epi16(bd == 10 ? 1023 : (bd == 12 ? 4095 : 255));

  assert(bits <= 4);

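  // Widths that are a multiple of 16: one row per iteration, 16 pixels at a
  // time.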
  if (!(w % 16)) {
    for (i = 0; i < h; i += 1) {
      for (j = 0; j < w; j += 16) {
        const __m256i src_16bit =
            _mm256_loadu_si256((__m256i *)(&src[i * src_stride + j]));

        const __m256i res = _mm256_sll_epi16(src_16bit, left_shift);

        if (do_average) {
          const __m256i data_0 =
              _mm256_loadu_si256((__m256i *)(&dst[i * dst_stride + j]));

          const __m256i data_ref_0_lo = _mm256_unpacklo_epi16(data_0, zero);
          const __m256i data_ref_0_hi = _mm256_unpackhi_epi16(data_0, zero);

          const __m256i res_32b_lo = _mm256_unpacklo_epi16(res, zero);
          const __m256i res_unsigned_lo =
              _mm256_add_epi32(res_32b_lo, offset_const);

          const __m256i comp_avg_res_lo =
              highbd_comp_avg(&data_ref_0_lo, &res_unsigned_lo, &wt0, &wt1,
                              use_dist_wtd_comp_avg);

          const __m256i res_32b_hi = _mm256_unpackhi_epi16(res, zero);
          const __m256i res_unsigned_hi =
              _mm256_add_epi32(res_32b_hi, offset_const);

          const __m256i comp_avg_res_hi =
              highbd_comp_avg(&data_ref_0_hi, &res_unsigned_hi, &wt0, &wt1,
                              use_dist_wtd_comp_avg);

          const __m256i round_result_lo = highbd_convolve_rounding(
              &comp_avg_res_lo, &offset_const, &rounding_const, rounding_shift);
          const __m256i round_result_hi = highbd_convolve_rounding(
              &comp_avg_res_hi, &offset_const, &rounding_const, rounding_shift);

          const __m256i res_16b =
              _mm256_packus_epi32(round_result_lo, round_result_hi);
          const __m256i res_clip = _mm256_min_epi16(res_16b, clip_pixel_to_bd);

          _mm256_store_si256((__m256i *)(&dst0[i * dst_stride0 + j]), res_clip);
        } else {
          const __m256i res_unsigned_16b =
              _mm256_adds_epu16(res, offset_const_16b);

          _mm256_store_si256((__m256i *)(&dst[i * dst_stride + j]),
                             res_unsigned_16b);
        }
      }
    }
  } else if (!(w % 4)) {
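    // Widths of 4 or 8: two rows are packed into one 256-bit register and
    // processed together; the w == 4 case stores only the low 64 bits of each
    // lane.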
    for (i = 0; i < h; i += 2) {
      for (j = 0; j < w; j += 8) {
        const __m128i src_row_0 =
            _mm_loadu_si128((__m128i *)(&src[i * src_stride + j]));
        const __m128i src_row_1 =
            _mm_loadu_si128((__m128i *)(&src[i * src_stride + j + src_stride]));
        // since not all compilers yet support _mm256_set_m128i()
        const __m256i src_10 = _mm256_insertf128_si256(
            _mm256_castsi128_si256(src_row_0), src_row_1, 1);

        const __m256i res = _mm256_sll_epi16(src_10, left_shift);

        if (w - j < 8) {
          if (do_average) {
            const __m256i data_0 = _mm256_castsi128_si256(
                _mm_loadl_epi64((__m128i *)(&dst[i * dst_stride + j])));
            const __m256i data_1 = _mm256_castsi128_si256(_mm_loadl_epi64(
                (__m128i *)(&dst[i * dst_stride + j + dst_stride])));
            const __m256i data_01 =
                _mm256_permute2x128_si256(data_0, data_1, 0x20);

            const __m256i data_ref_0 = _mm256_unpacklo_epi16(data_01, zero);

            const __m256i res_32b = _mm256_unpacklo_epi16(res, zero);
            const __m256i res_unsigned_lo =
                _mm256_add_epi32(res_32b, offset_const);

            const __m256i comp_avg_res =
                highbd_comp_avg(&data_ref_0, &res_unsigned_lo, &wt0, &wt1,
                                use_dist_wtd_comp_avg);

            const __m256i round_result = highbd_convolve_rounding(
                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

            const __m256i res_16b =
                _mm256_packus_epi32(round_result, round_result);
            const __m256i res_clip =
                _mm256_min_epi16(res_16b, clip_pixel_to_bd);

            const __m128i res_0 = _mm256_castsi256_si128(res_clip);
            const __m128i res_1 = _mm256_extracti128_si256(res_clip, 1);

            _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_storel_epi64(
                (__m128i *)(&dst0[i * dst_stride0 + j + dst_stride0]), res_1);
          } else {
            const __m256i res_unsigned_16b =
                _mm256_adds_epu16(res, offset_const_16b);

            const __m128i res_0 = _mm256_castsi256_si128(res_unsigned_16b);
            const __m128i res_1 = _mm256_extracti128_si256(res_unsigned_16b, 1);

            _mm_storel_epi64((__m128i *)(&dst[i * dst_stride + j]), res_0);
            _mm_storel_epi64((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                             res_1);
          }
        } else {
          if (do_average) {
            const __m256i data_0 = _mm256_castsi128_si256(
                _mm_loadu_si128((__m128i *)(&dst[i * dst_stride + j])));
            const __m256i data_1 = _mm256_castsi128_si256(_mm_loadu_si128(
                (__m128i *)(&dst[i * dst_stride + j + dst_stride])));
            const __m256i data_01 =
                _mm256_permute2x128_si256(data_0, data_1, 0x20);

            const __m256i data_ref_0_lo = _mm256_unpacklo_epi16(data_01, zero);
            const __m256i data_ref_0_hi = _mm256_unpackhi_epi16(data_01, zero);

            const __m256i res_32b_lo = _mm256_unpacklo_epi16(res, zero);
            const __m256i res_unsigned_lo =
                _mm256_add_epi32(res_32b_lo, offset_const);

            const __m256i comp_avg_res_lo =
                highbd_comp_avg(&data_ref_0_lo, &res_unsigned_lo, &wt0, &wt1,
                                use_dist_wtd_comp_avg);

            const __m256i res_32b_hi = _mm256_unpackhi_epi16(res, zero);
            const __m256i res_unsigned_hi =
                _mm256_add_epi32(res_32b_hi, offset_const);

            const __m256i comp_avg_res_hi =
                highbd_comp_avg(&data_ref_0_hi, &res_unsigned_hi, &wt0, &wt1,
                                use_dist_wtd_comp_avg);

            const __m256i round_result_lo =
                highbd_convolve_rounding(&comp_avg_res_lo, &offset_const,
                                         &rounding_const, rounding_shift);
            const __m256i round_result_hi =
                highbd_convolve_rounding(&comp_avg_res_hi, &offset_const,
                                         &rounding_const, rounding_shift);

            const __m256i res_16b =
                _mm256_packus_epi32(round_result_lo, round_result_hi);
            const __m256i res_clip =
                _mm256_min_epi16(res_16b, clip_pixel_to_bd);

            const __m128i res_0 = _mm256_castsi256_si128(res_clip);
            const __m128i res_1 = _mm256_extracti128_si256(res_clip, 1);

            _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_store_si128(
                (__m128i *)(&dst0[i * dst_stride0 + j + dst_stride0]), res_1);
          } else {
            const __m256i res_unsigned_16b =
                _mm256_adds_epu16(res, offset_const_16b);
            const __m128i res_0 = _mm256_castsi256_si128(res_unsigned_16b);
            const __m128i res_1 = _mm256_extracti128_si256(res_unsigned_16b, 1);

            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_1);
          }
        }
      }
    }
  }
}

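// Compound 2D (horizontal then vertical) convolution for high bit-depth input.
// Each 8-pixel-wide column strip is filtered horizontally into a 16-bit
// intermediate block (im_block) and then vertically; the result is either
// blended with the compound buffer into dst0 (do_average) or written to the
// compound buffer with the compound offset applied.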
void av1_highbd_dist_wtd_convolve_2d_avx2(
    const uint16_t *src, int src_stride, uint16_t *dst0, int dst_stride0, int w,
    int h, const InterpFilterParams *filter_params_x,
    const InterpFilterParams *filter_params_y, const int subpel_x_qn,
    const int subpel_y_qn, ConvolveParams *conv_params, int bd) {
  DECLARE_ALIGNED(32, int16_t, im_block[(MAX_SB_SIZE + MAX_FILTER_TAP) * 8]);
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  int im_h = h + filter_params_y->taps - 1;
  int im_stride = 8;
  int i, j;
  const int fo_vert = filter_params_y->taps / 2 - 1;
  const int fo_horiz = filter_params_x->taps / 2 - 1;
  const uint16_t *const src_ptr = src - fo_vert * src_stride - fo_horiz;

  // Check that, even with 12-bit input, the intermediate values will fit
  // into an unsigned 16-bit intermediate array.
  assert(bd + FILTER_BITS + 2 - conv_params->round_0 <= 16);

  __m256i s[8], coeffs_y[4], coeffs_x[4];
  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;

  const int w0 = conv_params->fwd_offset;
  const int w1 = conv_params->bck_offset;
  const __m256i wt0 = _mm256_set1_epi32(w0);
  const __m256i wt1 = _mm256_set1_epi32(w1);
  const __m256i zero = _mm256_setzero_si256();

  const __m256i round_const_x = _mm256_set1_epi32(
      ((1 << conv_params->round_0) >> 1) + (1 << (bd + FILTER_BITS - 1)));
  const __m128i round_shift_x = _mm_cvtsi32_si128(conv_params->round_0);

  const __m256i round_const_y = _mm256_set1_epi32(
      ((1 << conv_params->round_1) >> 1) -
      (1 << (bd + 2 * FILTER_BITS - conv_params->round_0 - 1)));
  const __m128i round_shift_y = _mm_cvtsi32_si128(conv_params->round_1);

  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi32(offset);
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi32((1 << rounding_shift) >> 1);

  const __m256i clip_pixel_to_bd =
      _mm256_set1_epi16(bd == 10 ? 1023 : (bd == 12 ? 4095 : 255));

  prepare_coeffs(filter_params_x, subpel_x_qn, coeffs_x);
  prepare_coeffs(filter_params_y, subpel_y_qn, coeffs_y);

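  // Process the block in 8-pixel-wide column strips. The horizontal pass
  // produces im_h = h + taps - 1 intermediate rows so the vertical pass has
  // full filter support.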
  for (j = 0; j < w; j += 8) {
    /* Horizontal filter */
    {
      for (i = 0; i < im_h; i += 2) {
        const __m256i row0 =
            _mm256_loadu_si256((__m256i *)&src_ptr[i * src_stride + j]);
        __m256i row1 = _mm256_set1_epi16(0);
        if (i + 1 < im_h)
          row1 =
              _mm256_loadu_si256((__m256i *)&src_ptr[(i + 1) * src_stride + j]);

        const __m256i r0 = _mm256_permute2x128_si256(row0, row1, 0x20);
        const __m256i r1 = _mm256_permute2x128_si256(row0, row1, 0x31);

        // even pixels
        s[0] = _mm256_alignr_epi8(r1, r0, 0);
        s[1] = _mm256_alignr_epi8(r1, r0, 4);
        s[2] = _mm256_alignr_epi8(r1, r0, 8);
        s[3] = _mm256_alignr_epi8(r1, r0, 12);

        __m256i res_even = convolve(s, coeffs_x);
        res_even = _mm256_sra_epi32(_mm256_add_epi32(res_even, round_const_x),
                                    round_shift_x);

        // odd pixels
        s[0] = _mm256_alignr_epi8(r1, r0, 2);
        s[1] = _mm256_alignr_epi8(r1, r0, 6);
        s[2] = _mm256_alignr_epi8(r1, r0, 10);
        s[3] = _mm256_alignr_epi8(r1, r0, 14);

        __m256i res_odd = convolve(s, coeffs_x);
        res_odd = _mm256_sra_epi32(_mm256_add_epi32(res_odd, round_const_x),
                                   round_shift_x);

        __m256i res_even1 = _mm256_packs_epi32(res_even, res_even);
        __m256i res_odd1 = _mm256_packs_epi32(res_odd, res_odd);
        __m256i res = _mm256_unpacklo_epi16(res_even1, res_odd1);

        _mm256_store_si256((__m256i *)&im_block[i * im_stride], res);
      }
    }

    /* Vertical filter */
    {
      __m256i s0 = _mm256_loadu_si256((__m256i *)(im_block + 0 * im_stride));
      __m256i s1 = _mm256_loadu_si256((__m256i *)(im_block + 1 * im_stride));
      __m256i s2 = _mm256_loadu_si256((__m256i *)(im_block + 2 * im_stride));
      __m256i s3 = _mm256_loadu_si256((__m256i *)(im_block + 3 * im_stride));
      __m256i s4 = _mm256_loadu_si256((__m256i *)(im_block + 4 * im_stride));
      __m256i s5 = _mm256_loadu_si256((__m256i *)(im_block + 5 * im_stride));

      s[0] = _mm256_unpacklo_epi16(s0, s1);
      s[1] = _mm256_unpacklo_epi16(s2, s3);
      s[2] = _mm256_unpacklo_epi16(s4, s5);

      s[4] = _mm256_unpackhi_epi16(s0, s1);
      s[5] = _mm256_unpackhi_epi16(s2, s3);
      s[6] = _mm256_unpackhi_epi16(s4, s5);

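      // Two output rows are produced per iteration: s[0..3] hold the
      // interleaved low halves (pixels 0-3) and s[4..7] the high halves
      // (pixels 4-7) of the intermediate rows; the window advances two rows
      // each iteration.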
      for (i = 0; i < h; i += 2) {
        const int16_t *data = &im_block[i * im_stride];

        const __m256i s6 =
            _mm256_loadu_si256((__m256i *)(data + 6 * im_stride));
        const __m256i s7 =
            _mm256_loadu_si256((__m256i *)(data + 7 * im_stride));

        s[3] = _mm256_unpacklo_epi16(s6, s7);
        s[7] = _mm256_unpackhi_epi16(s6, s7);

        const __m256i res_a = convolve(s, coeffs_y);

        const __m256i res_a_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_a, round_const_y), round_shift_y);

        const __m256i res_unsigned_lo =
            _mm256_add_epi32(res_a_round, offset_const);

        if (w - j < 8) {
          if (do_average) {
            const __m256i data_0 = _mm256_castsi128_si256(
                _mm_loadl_epi64((__m128i *)(&dst[i * dst_stride + j])));
            const __m256i data_1 = _mm256_castsi128_si256(_mm_loadl_epi64(
                (__m128i *)(&dst[i * dst_stride + j + dst_stride])));
            const __m256i data_01 =
                _mm256_permute2x128_si256(data_0, data_1, 0x20);

            const __m256i data_ref_0 = _mm256_unpacklo_epi16(data_01, zero);

            const __m256i comp_avg_res =
                highbd_comp_avg(&data_ref_0, &res_unsigned_lo, &wt0, &wt1,
                                use_dist_wtd_comp_avg);

            const __m256i round_result = highbd_convolve_rounding(
                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

            const __m256i res_16b =
                _mm256_packus_epi32(round_result, round_result);
            const __m256i res_clip =
                _mm256_min_epi16(res_16b, clip_pixel_to_bd);

            const __m128i res_0 = _mm256_castsi256_si128(res_clip);
            const __m128i res_1 = _mm256_extracti128_si256(res_clip, 1);

            _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_storel_epi64(
                (__m128i *)(&dst0[i * dst_stride0 + j + dst_stride0]), res_1);
          } else {
            __m256i res_16b =
                _mm256_packus_epi32(res_unsigned_lo, res_unsigned_lo);
            const __m128i res_0 = _mm256_castsi256_si128(res_16b);
            const __m128i res_1 = _mm256_extracti128_si256(res_16b, 1);

            _mm_storel_epi64((__m128i *)(&dst[i * dst_stride + j]), res_0);
            _mm_storel_epi64((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                             res_1);
          }
        } else {
          const __m256i res_b = convolve(s + 4, coeffs_y);
          const __m256i res_b_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_b, round_const_y), round_shift_y);

          __m256i res_unsigned_hi = _mm256_add_epi32(res_b_round, offset_const);

          if (do_average) {
            const __m256i data_0 = _mm256_castsi128_si256(
                _mm_loadu_si128((__m128i *)(&dst[i * dst_stride + j])));
            const __m256i data_1 = _mm256_castsi128_si256(_mm_loadu_si128(
                (__m128i *)(&dst[i * dst_stride + j + dst_stride])));
            const __m256i data_01 =
                _mm256_permute2x128_si256(data_0, data_1, 0x20);

            const __m256i data_ref_0_lo = _mm256_unpacklo_epi16(data_01, zero);
            const __m256i data_ref_0_hi = _mm256_unpackhi_epi16(data_01, zero);

            const __m256i comp_avg_res_lo =
                highbd_comp_avg(&data_ref_0_lo, &res_unsigned_lo, &wt0, &wt1,
                                use_dist_wtd_comp_avg);
            const __m256i comp_avg_res_hi =
                highbd_comp_avg(&data_ref_0_hi, &res_unsigned_hi, &wt0, &wt1,
                                use_dist_wtd_comp_avg);

            const __m256i round_result_lo =
                highbd_convolve_rounding(&comp_avg_res_lo, &offset_const,
                                         &rounding_const, rounding_shift);
            const __m256i round_result_hi =
                highbd_convolve_rounding(&comp_avg_res_hi, &offset_const,
                                         &rounding_const, rounding_shift);

            const __m256i res_16b =
                _mm256_packus_epi32(round_result_lo, round_result_hi);
            const __m256i res_clip =
                _mm256_min_epi16(res_16b, clip_pixel_to_bd);

            const __m128i res_0 = _mm256_castsi256_si128(res_clip);
            const __m128i res_1 = _mm256_extracti128_si256(res_clip, 1);

            _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_store_si128(
                (__m128i *)(&dst0[i * dst_stride0 + j + dst_stride0]), res_1);
          } else {
            __m256i res_16b =
                _mm256_packus_epi32(res_unsigned_lo, res_unsigned_hi);
            const __m128i res_0 = _mm256_castsi256_si128(res_16b);
            const __m128i res_1 = _mm256_extracti128_si256(res_16b, 1);

            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_1);
          }
        }

        s[0] = s[1];
        s[1] = s[2];
        s[2] = s[3];

        s[4] = s[5];
        s[5] = s[6];
        s[6] = s[7];
      }
    }
  }
}

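// Compound horizontal-only convolution for high bit-depth input: each row is
// filtered with filter_params_x, shifted into compound-buffer precision, and
// then either averaged against the compound buffer into dst0 (do_average) or
// stored into the compound buffer.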
void av1_highbd_dist_wtd_convolve_x_avx2(
    const uint16_t *src, int src_stride, uint16_t *dst0, int dst_stride0, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params, int bd) {
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  const int fo_horiz = filter_params_x->taps / 2 - 1;
  const uint16_t *const src_ptr = src - fo_horiz;
  const int bits = FILTER_BITS - conv_params->round_1;

  int i, j;
  __m256i s[4], coeffs_x[4];

  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
  const int w0 = conv_params->fwd_offset;
  const int w1 = conv_params->bck_offset;
  const __m256i wt0 = _mm256_set1_epi32(w0);
  const __m256i wt1 = _mm256_set1_epi32(w1);
  const __m256i zero = _mm256_setzero_si256();

  const __m256i round_const_x =
      _mm256_set1_epi32(((1 << conv_params->round_0) >> 1));
  const __m128i round_shift_x = _mm_cvtsi32_si128(conv_params->round_0);
  const __m128i round_shift_bits = _mm_cvtsi32_si128(bits);

  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi32(offset);
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi32((1 << rounding_shift) >> 1);
  const __m256i clip_pixel_to_bd =
      _mm256_set1_epi16(bd == 10 ? 1023 : (bd == 12 ? 4095 : 255));

  assert(bits >= 0);
  prepare_coeffs(filter_params_x, subpel_x_qn, coeffs_x);

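  // Filter 8-pixel column strips, two rows at a time; the filtered result is
  // shifted up by (FILTER_BITS - round_1) before the compound offset is added.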
  for (j = 0; j < w; j += 8) {
    /* Horizontal filter */
    for (i = 0; i < h; i += 2) {
      const __m256i row0 =
          _mm256_loadu_si256((__m256i *)&src_ptr[i * src_stride + j]);
      __m256i row1 =
          _mm256_loadu_si256((__m256i *)&src_ptr[(i + 1) * src_stride + j]);

      const __m256i r0 = _mm256_permute2x128_si256(row0, row1, 0x20);
      const __m256i r1 = _mm256_permute2x128_si256(row0, row1, 0x31);

      // even pixels
      s[0] = _mm256_alignr_epi8(r1, r0, 0);
      s[1] = _mm256_alignr_epi8(r1, r0, 4);
      s[2] = _mm256_alignr_epi8(r1, r0, 8);
      s[3] = _mm256_alignr_epi8(r1, r0, 12);

      __m256i res_even = convolve(s, coeffs_x);
      res_even = _mm256_sra_epi32(_mm256_add_epi32(res_even, round_const_x),
                                  round_shift_x);

      // odd pixels
      s[0] = _mm256_alignr_epi8(r1, r0, 2);
      s[1] = _mm256_alignr_epi8(r1, r0, 6);
      s[2] = _mm256_alignr_epi8(r1, r0, 10);
      s[3] = _mm256_alignr_epi8(r1, r0, 14);

      __m256i res_odd = convolve(s, coeffs_x);
      res_odd = _mm256_sra_epi32(_mm256_add_epi32(res_odd, round_const_x),
                                 round_shift_x);

      res_even = _mm256_sll_epi32(res_even, round_shift_bits);
      res_odd = _mm256_sll_epi32(res_odd, round_shift_bits);

      __m256i res1 = _mm256_unpacklo_epi32(res_even, res_odd);

      __m256i res_unsigned_lo = _mm256_add_epi32(res1, offset_const);

      if (w - j < 8) {
        if (do_average) {
          const __m256i data_0 = _mm256_castsi128_si256(
              _mm_loadl_epi64((__m128i *)(&dst[i * dst_stride + j])));
          const __m256i data_1 = _mm256_castsi128_si256(_mm_loadl_epi64(
              (__m128i *)(&dst[i * dst_stride + j + dst_stride])));
          const __m256i data_01 =
              _mm256_permute2x128_si256(data_0, data_1, 0x20);

          const __m256i data_ref_0 = _mm256_unpacklo_epi16(data_01, zero);

          const __m256i comp_avg_res = highbd_comp_avg(
              &data_ref_0, &res_unsigned_lo, &wt0, &wt1, use_dist_wtd_comp_avg);

          const __m256i round_result = highbd_convolve_rounding(
              &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

          const __m256i res_16b =
              _mm256_packus_epi32(round_result, round_result);
          const __m256i res_clip = _mm256_min_epi16(res_16b, clip_pixel_to_bd);

          const __m128i res_0 = _mm256_castsi256_si128(res_clip);
          const __m128i res_1 = _mm256_extracti128_si256(res_clip, 1);

          _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
          _mm_storel_epi64(
              (__m128i *)(&dst0[i * dst_stride0 + j + dst_stride0]), res_1);
        } else {
          __m256i res_16b =
              _mm256_packus_epi32(res_unsigned_lo, res_unsigned_lo);
          const __m128i res_0 = _mm256_castsi256_si128(res_16b);
          const __m128i res_1 = _mm256_extracti128_si256(res_16b, 1);

          _mm_storel_epi64((__m128i *)(&dst[i * dst_stride + j]), res_0);
          _mm_storel_epi64((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                           res_1);
        }
      } else {
        __m256i res2 = _mm256_unpackhi_epi32(res_even, res_odd);
        __m256i res_unsigned_hi = _mm256_add_epi32(res2, offset_const);

        if (do_average) {
          const __m256i data_0 = _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(&dst[i * dst_stride + j])));
          const __m256i data_1 = _mm256_castsi128_si256(_mm_loadu_si128(
              (__m128i *)(&dst[i * dst_stride + j + dst_stride])));
          const __m256i data_01 =
              _mm256_permute2x128_si256(data_0, data_1, 0x20);

          const __m256i data_ref_0_lo = _mm256_unpacklo_epi16(data_01, zero);
          const __m256i data_ref_0_hi = _mm256_unpackhi_epi16(data_01, zero);

          const __m256i comp_avg_res_lo =
              highbd_comp_avg(&data_ref_0_lo, &res_unsigned_lo, &wt0, &wt1,
                              use_dist_wtd_comp_avg);
          const __m256i comp_avg_res_hi =
              highbd_comp_avg(&data_ref_0_hi, &res_unsigned_hi, &wt0, &wt1,
                              use_dist_wtd_comp_avg);

          const __m256i round_result_lo = highbd_convolve_rounding(
              &comp_avg_res_lo, &offset_const, &rounding_const, rounding_shift);
          const __m256i round_result_hi = highbd_convolve_rounding(
              &comp_avg_res_hi, &offset_const, &rounding_const, rounding_shift);

          const __m256i res_16b =
              _mm256_packus_epi32(round_result_lo, round_result_hi);
          const __m256i res_clip = _mm256_min_epi16(res_16b, clip_pixel_to_bd);

          const __m128i res_0 = _mm256_castsi256_si128(res_clip);
          const __m128i res_1 = _mm256_extracti128_si256(res_clip, 1);

          _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
          _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j + dst_stride0]),
                          res_1);
        } else {
          __m256i res_16b =
              _mm256_packus_epi32(res_unsigned_lo, res_unsigned_hi);
          const __m128i res_0 = _mm256_castsi256_si128(res_16b);
          const __m128i res_1 = _mm256_extracti128_si256(res_16b, 1);

          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                          res_1);
        }
      }
    }
  }
}

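// Compound vertical-only convolution for high bit-depth input: each 8-pixel
// column strip is filtered with filter_params_y, shifted into compound-buffer
// precision, and then either averaged against the compound buffer into dst0
// (do_average) or stored into the compound buffer.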
void av1_highbd_dist_wtd_convolve_y_avx2(
    const uint16_t *src, int src_stride, uint16_t *dst0, int dst_stride0, int w,
    int h, const InterpFilterParams *filter_params_y, const int subpel_y_qn,
    ConvolveParams *conv_params, int bd) {
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  const int fo_vert = filter_params_y->taps / 2 - 1;
  const uint16_t *const src_ptr = src - fo_vert * src_stride;
  const int bits = FILTER_BITS - conv_params->round_0;

  assert(bits >= 0);
  int i, j;
  __m256i s[8], coeffs_y[4];
  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;

  const int w0 = conv_params->fwd_offset;
  const int w1 = conv_params->bck_offset;
  const __m256i wt0 = _mm256_set1_epi32(w0);
  const __m256i wt1 = _mm256_set1_epi32(w1);
  const __m256i round_const_y =
      _mm256_set1_epi32(((1 << conv_params->round_1) >> 1));
  const __m128i round_shift_y = _mm_cvtsi32_si128(conv_params->round_1);
  const __m128i round_shift_bits = _mm_cvtsi32_si128(bits);

  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi32(offset);
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi32((1 << rounding_shift) >> 1);
  const __m256i clip_pixel_to_bd =
      _mm256_set1_epi16(bd == 10 ? 1023 : (bd == 12 ? 4095 : 255));
  const __m256i zero = _mm256_setzero_si256();

  prepare_coeffs(filter_params_y, subpel_y_qn, coeffs_y);

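  // For each 8-pixel column strip, the first seven source rows are pre-loaded
  // and interleaved into a sliding window; the loop then produces two output
  // rows per iteration, advancing the window two source rows at a time.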
  for (j = 0; j < w; j += 8) {
    const uint16_t *data = &src_ptr[j];
    /* Vertical filter */
    {
      __m256i src6;
      __m256i s01 = _mm256_permute2x128_si256(
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 0 * src_stride))),
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 1 * src_stride))),
          0x20);
      __m256i s12 = _mm256_permute2x128_si256(
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 1 * src_stride))),
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 2 * src_stride))),
          0x20);
      __m256i s23 = _mm256_permute2x128_si256(
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 2 * src_stride))),
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 3 * src_stride))),
          0x20);
      __m256i s34 = _mm256_permute2x128_si256(
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 3 * src_stride))),
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 4 * src_stride))),
          0x20);
      __m256i s45 = _mm256_permute2x128_si256(
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 4 * src_stride))),
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 5 * src_stride))),
          0x20);
      src6 = _mm256_castsi128_si256(
          _mm_loadu_si128((__m128i *)(data + 6 * src_stride)));
      __m256i s56 = _mm256_permute2x128_si256(
          _mm256_castsi128_si256(
              _mm_loadu_si128((__m128i *)(data + 5 * src_stride))),
          src6, 0x20);

      s[0] = _mm256_unpacklo_epi16(s01, s12);
      s[1] = _mm256_unpacklo_epi16(s23, s34);
      s[2] = _mm256_unpacklo_epi16(s45, s56);

      s[4] = _mm256_unpackhi_epi16(s01, s12);
      s[5] = _mm256_unpackhi_epi16(s23, s34);
      s[6] = _mm256_unpackhi_epi16(s45, s56);

      for (i = 0; i < h; i += 2) {
        data = &src_ptr[i * src_stride + j];

        const __m256i s67 = _mm256_permute2x128_si256(
            src6,
            _mm256_castsi128_si256(
                _mm_loadu_si128((__m128i *)(data + 7 * src_stride))),
            0x20);

        src6 = _mm256_castsi128_si256(
            _mm_loadu_si128((__m128i *)(data + 8 * src_stride)));

        const __m256i s78 = _mm256_permute2x128_si256(
            _mm256_castsi128_si256(
                _mm_loadu_si128((__m128i *)(data + 7 * src_stride))),
            src6, 0x20);

        s[3] = _mm256_unpacklo_epi16(s67, s78);
        s[7] = _mm256_unpackhi_epi16(s67, s78);

        const __m256i res_a = convolve(s, coeffs_y);

        __m256i res_a_round = _mm256_sll_epi32(res_a, round_shift_bits);
        res_a_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_a_round, round_const_y), round_shift_y);

        __m256i res_unsigned_lo = _mm256_add_epi32(res_a_round, offset_const);

        if (w - j < 8) {
          if (do_average) {
            const __m256i data_0 = _mm256_castsi128_si256(
                _mm_loadl_epi64((__m128i *)(&dst[i * dst_stride + j])));
            const __m256i data_1 = _mm256_castsi128_si256(_mm_loadl_epi64(
                (__m128i *)(&dst[i * dst_stride + j + dst_stride])));
            const __m256i data_01 =
                _mm256_permute2x128_si256(data_0, data_1, 0x20);

            const __m256i data_ref_0 = _mm256_unpacklo_epi16(data_01, zero);

            const __m256i comp_avg_res =
                highbd_comp_avg(&data_ref_0, &res_unsigned_lo, &wt0, &wt1,
                                use_dist_wtd_comp_avg);

            const __m256i round_result = highbd_convolve_rounding(
                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

            const __m256i res_16b =
                _mm256_packus_epi32(round_result, round_result);
            const __m256i res_clip =
                _mm256_min_epi16(res_16b, clip_pixel_to_bd);

            const __m128i res_0 = _mm256_castsi256_si128(res_clip);
            const __m128i res_1 = _mm256_extracti128_si256(res_clip, 1);

            _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_storel_epi64(
                (__m128i *)(&dst0[i * dst_stride0 + j + dst_stride0]), res_1);
          } else {
            __m256i res_16b =
                _mm256_packus_epi32(res_unsigned_lo, res_unsigned_lo);
            const __m128i res_0 = _mm256_castsi256_si128(res_16b);
            const __m128i res_1 = _mm256_extracti128_si256(res_16b, 1);

            _mm_storel_epi64((__m128i *)(&dst[i * dst_stride + j]), res_0);
            _mm_storel_epi64((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                             res_1);
          }
        } else {
          const __m256i res_b = convolve(s + 4, coeffs_y);
          __m256i res_b_round = _mm256_sll_epi32(res_b, round_shift_bits);
          res_b_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_b_round, round_const_y), round_shift_y);

          __m256i res_unsigned_hi = _mm256_add_epi32(res_b_round, offset_const);

          if (do_average) {
            const __m256i data_0 = _mm256_castsi128_si256(
                _mm_loadu_si128((__m128i *)(&dst[i * dst_stride + j])));
            const __m256i data_1 = _mm256_castsi128_si256(_mm_loadu_si128(
                (__m128i *)(&dst[i * dst_stride + j + dst_stride])));
            const __m256i data_01 =
                _mm256_permute2x128_si256(data_0, data_1, 0x20);

            const __m256i data_ref_0_lo = _mm256_unpacklo_epi16(data_01, zero);
            const __m256i data_ref_0_hi = _mm256_unpackhi_epi16(data_01, zero);

            const __m256i comp_avg_res_lo =
                highbd_comp_avg(&data_ref_0_lo, &res_unsigned_lo, &wt0, &wt1,
                                use_dist_wtd_comp_avg);
            const __m256i comp_avg_res_hi =
                highbd_comp_avg(&data_ref_0_hi, &res_unsigned_hi, &wt0, &wt1,
                                use_dist_wtd_comp_avg);

            const __m256i round_result_lo =
                highbd_convolve_rounding(&comp_avg_res_lo, &offset_const,
                                         &rounding_const, rounding_shift);
            const __m256i round_result_hi =
                highbd_convolve_rounding(&comp_avg_res_hi, &offset_const,
                                         &rounding_const, rounding_shift);

            const __m256i res_16b =
                _mm256_packus_epi32(round_result_lo, round_result_hi);
            const __m256i res_clip =
                _mm256_min_epi16(res_16b, clip_pixel_to_bd);

            const __m128i res_0 = _mm256_castsi256_si128(res_clip);
            const __m128i res_1 = _mm256_extracti128_si256(res_clip, 1);

            _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_store_si128(
                (__m128i *)(&dst0[i * dst_stride0 + j + dst_stride0]), res_1);
          } else {
            __m256i res_16b =
                _mm256_packus_epi32(res_unsigned_lo, res_unsigned_hi);
            const __m128i res_0 = _mm256_castsi256_si128(res_16b);
            const __m128i res_1 = _mm256_extracti128_si256(res_16b, 1);

            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_1);
          }
        }
        s[0] = s[1];
        s[1] = s[2];
        s[2] = s[3];

        s[4] = s[5];
        s[5] = s[6];
        s[6] = s[7];
      }
    }
  }
}