1 /*
2  * Copyright (c) 2018, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <immintrin.h>
13 
14 #include "config/av1_rtcd.h"
15 
16 #include "aom/aom_integer.h"
17 #include "aom_dsp/blend.h"
18 #include "aom_dsp/x86/synonyms.h"
19 #include "aom_dsp/x86/synonyms_avx2.h"
20 #include "av1/common/blockd.h"
21 
calc_mask_avx2(const __m256i mask_base,const __m256i s0,const __m256i s1)22 static INLINE __m256i calc_mask_avx2(const __m256i mask_base, const __m256i s0,
23                                      const __m256i s1) {
24   const __m256i diff = _mm256_abs_epi16(_mm256_sub_epi16(s0, s1));
25   return _mm256_abs_epi16(
26       _mm256_add_epi16(mask_base, _mm256_srli_epi16(diff, 4)));
27   // clamp(diff, 0, 64) can be skiped for diff is always in the range ( 38, 54)
28 }
// Builds the difference-weighted compound mask from two 8-bit predictions.
//
// Every output byte is |base + (|src0 - src1| >> 4)|, where base is 38 for
// the regular mask and 38 - AOM_BLEND_A64_MAX_ALPHA when mask_type is
// DIFFWTD_38_INV (see calc_mask_avx2()).
//
//   mask       Output, written densely: row r starts at mask + r * w.
//   mask_type  DIFFWTD_38_INV selects the inverted mask; any other value
//              produces the regular one.
//   src0/src1  Source rows, stride0 / stride1 bytes apart.
//   h, w       Block height and width in pixels.
//
// Rows are consumed in groups (4 rows for w == 4 and w == 8, 2 rows for
// w == 16) and every other width is processed 32 pixels at a time, so h is
// expected to be a multiple of the group size and widths other than 4/8/16
// a multiple of 32. All loops are do/while: at least one group is always
// processed, even if h <= 0.
void av1_build_compound_diffwtd_mask_avx2(uint8_t *mask,
                                          DIFFWTD_MASK_TYPE mask_type,
                                          const uint8_t *src0, int stride0,
                                          const uint8_t *src1, int stride1,
                                          int h, int w) {
  const int mb = (mask_type == DIFFWTD_38_INV) ? AOM_BLEND_A64_MAX_ALPHA : 0;
  const __m256i y_mask_base = _mm256_set1_epi16(38 - mb);
  int i = 0;
  if (4 == w) {
    do {
      // Gather four 4-pixel rows into one 16-byte vector: A | B | C | D.
      const __m128i s0A = xx_loadl_32(src0);
      const __m128i s0B = xx_loadl_32(src0 + stride0);
      const __m128i s0C = xx_loadl_32(src0 + stride0 * 2);
      const __m128i s0D = xx_loadl_32(src0 + stride0 * 3);
      const __m128i s0AB = _mm_unpacklo_epi32(s0A, s0B);
      const __m128i s0CD = _mm_unpacklo_epi32(s0C, s0D);
      const __m128i s0ABCD = _mm_unpacklo_epi64(s0AB, s0CD);
      const __m256i s0ABCD_w = _mm256_cvtepu8_epi16(s0ABCD);

      const __m128i s1A = xx_loadl_32(src1);
      const __m128i s1B = xx_loadl_32(src1 + stride1);
      const __m128i s1C = xx_loadl_32(src1 + stride1 * 2);
      const __m128i s1D = xx_loadl_32(src1 + stride1 * 3);
      const __m128i s1AB = _mm_unpacklo_epi32(s1A, s1B);
      const __m128i s1CD = _mm_unpacklo_epi32(s1C, s1D);
      const __m128i s1ABCD = _mm_unpacklo_epi64(s1AB, s1CD);
      const __m256i s1ABCD_w = _mm256_cvtepu8_epi16(s1ABCD);
      const __m256i m16 = calc_mask_avx2(y_mask_base, s0ABCD_w, s1ABCD_w);
      // packus works per 128-bit lane, leaving AB in quadword 0 and CD in
      // quadword 2; the 0xd8 permute moves CD next to AB in the low half.
      const __m256i m8 = _mm256_packus_epi16(m16, _mm256_setzero_si256());
      const __m128i x_m8 =
          _mm256_castsi256_si128(_mm256_permute4x64_epi64(m8, 0xd8));
      xx_storeu_128(mask, x_m8);
      src0 += (stride0 << 2);
      src1 += (stride1 << 2);
      mask += 16;
      i += 4;
    } while (i < h);
  } else if (8 == w) {
    do {
      const __m128i s0A = xx_loadl_64(src0);
      const __m128i s0B = xx_loadl_64(src0 + stride0);
      const __m128i s0C = xx_loadl_64(src0 + stride0 * 2);
      const __m128i s0D = xx_loadl_64(src0 + stride0 * 3);
      // Pair rows as A|C and B|D so that the per-lane packus below emits
      // the bytes directly in A, B, C, D order without a permute.
      const __m256i s0AC_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(s0A, s0C));
      const __m256i s0BD_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(s0B, s0D));
      const __m128i s1A = xx_loadl_64(src1);
      const __m128i s1B = xx_loadl_64(src1 + stride1);
      const __m128i s1C = xx_loadl_64(src1 + stride1 * 2);
      const __m128i s1D = xx_loadl_64(src1 + stride1 * 3);
      const __m256i s1AC_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(s1A, s1C));
      const __m256i s1BD_w = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(s1B, s1D));
      const __m256i m16AC = calc_mask_avx2(y_mask_base, s0AC_w, s1AC_w);
      const __m256i m16BD = calc_mask_avx2(y_mask_base, s0BD_w, s1BD_w);
      const __m256i m8 = _mm256_packus_epi16(m16AC, m16BD);
      yy_storeu_256(mask, m8);
      src0 += stride0 << 2;
      src1 += stride1 << 2;
      mask += 32;
      i += 4;
    } while (i < h);
  } else if (16 == w) {
    do {
      // Unaligned loads: nothing in this function's interface guarantees
      // that source rows are 16-byte aligned, and every other width path
      // already reads its input without an alignment requirement.
      const __m128i s0A = _mm_loadu_si128((const __m128i *)src0);
      const __m128i s0B = _mm_loadu_si128((const __m128i *)(src0 + stride0));
      const __m128i s1A = _mm_loadu_si128((const __m128i *)src1);
      const __m128i s1B = _mm_loadu_si128((const __m128i *)(src1 + stride1));
      const __m256i s0AL = _mm256_cvtepu8_epi16(s0A);
      const __m256i s0BL = _mm256_cvtepu8_epi16(s0B);
      const __m256i s1AL = _mm256_cvtepu8_epi16(s1A);
      const __m256i s1BL = _mm256_cvtepu8_epi16(s1B);

      const __m256i m16AL = calc_mask_avx2(y_mask_base, s0AL, s1AL);
      const __m256i m16BL = calc_mask_avx2(y_mask_base, s0BL, s1BL);

      // Per-lane packus yields A0-7 | B0-7 | A8-15 | B8-15; the 0xd8 permute
      // swaps the middle quadwords to restore row order.
      const __m256i m8 =
          _mm256_permute4x64_epi64(_mm256_packus_epi16(m16AL, m16BL), 0xd8);
      yy_storeu_256(mask, m8);
      src0 += stride0 << 1;
      src1 += stride1 << 1;
      mask += 32;
      i += 2;
    } while (i < h);
  } else {
    // Remaining widths: one row at a time, 32 pixels per inner iteration.
    do {
      int j = 0;
      do {
        const __m256i s0 = yy_loadu_256(src0 + j);
        const __m256i s1 = yy_loadu_256(src1 + j);
        const __m256i s0L = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(s0));
        const __m256i s1L = _mm256_cvtepu8_epi16(_mm256_castsi256_si128(s1));
        const __m256i s0H =
            _mm256_cvtepu8_epi16(_mm256_extracti128_si256(s0, 1));
        const __m256i s1H =
            _mm256_cvtepu8_epi16(_mm256_extracti128_si256(s1, 1));
        const __m256i m16L = calc_mask_avx2(y_mask_base, s0L, s1L);
        const __m256i m16H = calc_mask_avx2(y_mask_base, s0H, s1H);
        const __m256i m8 =
            _mm256_permute4x64_epi64(_mm256_packus_epi16(m16L, m16H), 0xd8);
        yy_storeu_256(mask + j, m8);
        j += 32;
      } while (j < w);
      src0 += stride0;
      src1 += stride1;
      mask += w;
      i += 1;
    } while (i < h);
  }
}
137 
calc_mask_d16_avx2(const __m256i * data_src0,const __m256i * data_src1,const __m256i * round_const,const __m256i * mask_base_16,const __m256i * clip_diff,int round)138 static INLINE __m256i calc_mask_d16_avx2(const __m256i *data_src0,
139                                          const __m256i *data_src1,
140                                          const __m256i *round_const,
141                                          const __m256i *mask_base_16,
142                                          const __m256i *clip_diff, int round) {
143   const __m256i diffa = _mm256_subs_epu16(*data_src0, *data_src1);
144   const __m256i diffb = _mm256_subs_epu16(*data_src1, *data_src0);
145   const __m256i diff = _mm256_max_epu16(diffa, diffb);
146   const __m256i diff_round =
147       _mm256_srli_epi16(_mm256_adds_epu16(diff, *round_const), round);
148   const __m256i diff_factor = _mm256_srli_epi16(diff_round, DIFF_FACTOR_LOG2);
149   const __m256i diff_mask = _mm256_adds_epi16(diff_factor, *mask_base_16);
150   const __m256i diff_clamp = _mm256_min_epi16(diff_mask, *clip_diff);
151   return diff_clamp;
152 }
153 
calc_mask_d16_inv_avx2(const __m256i * data_src0,const __m256i * data_src1,const __m256i * round_const,const __m256i * mask_base_16,const __m256i * clip_diff,int round)154 static INLINE __m256i calc_mask_d16_inv_avx2(const __m256i *data_src0,
155                                              const __m256i *data_src1,
156                                              const __m256i *round_const,
157                                              const __m256i *mask_base_16,
158                                              const __m256i *clip_diff,
159                                              int round) {
160   const __m256i diffa = _mm256_subs_epu16(*data_src0, *data_src1);
161   const __m256i diffb = _mm256_subs_epu16(*data_src1, *data_src0);
162   const __m256i diff = _mm256_max_epu16(diffa, diffb);
163   const __m256i diff_round =
164       _mm256_srli_epi16(_mm256_adds_epu16(diff, *round_const), round);
165   const __m256i diff_factor = _mm256_srli_epi16(diff_round, DIFF_FACTOR_LOG2);
166   const __m256i diff_mask = _mm256_adds_epi16(diff_factor, *mask_base_16);
167   const __m256i diff_clamp = _mm256_min_epi16(diff_mask, *clip_diff);
168   const __m256i diff_const_16 = _mm256_sub_epi16(*clip_diff, diff_clamp);
169   return diff_const_16;
170 }
171 
build_compound_diffwtd_mask_d16_avx2(uint8_t * mask,const CONV_BUF_TYPE * src0,int src0_stride,const CONV_BUF_TYPE * src1,int src1_stride,int h,int w,int shift)172 static INLINE void build_compound_diffwtd_mask_d16_avx2(
173     uint8_t *mask, const CONV_BUF_TYPE *src0, int src0_stride,
174     const CONV_BUF_TYPE *src1, int src1_stride, int h, int w, int shift) {
175   const int mask_base = 38;
176   const __m256i _r = _mm256_set1_epi16((1 << shift) >> 1);
177   const __m256i y38 = _mm256_set1_epi16(mask_base);
178   const __m256i y64 = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
179   int i = 0;
180   if (w == 4) {
181     do {
182       const __m128i s0A = xx_loadl_64(src0);
183       const __m128i s0B = xx_loadl_64(src0 + src0_stride);
184       const __m128i s0C = xx_loadl_64(src0 + src0_stride * 2);
185       const __m128i s0D = xx_loadl_64(src0 + src0_stride * 3);
186       const __m128i s1A = xx_loadl_64(src1);
187       const __m128i s1B = xx_loadl_64(src1 + src1_stride);
188       const __m128i s1C = xx_loadl_64(src1 + src1_stride * 2);
189       const __m128i s1D = xx_loadl_64(src1 + src1_stride * 3);
190       const __m256i s0 = yy_set_m128i(_mm_unpacklo_epi64(s0C, s0D),
191                                       _mm_unpacklo_epi64(s0A, s0B));
192       const __m256i s1 = yy_set_m128i(_mm_unpacklo_epi64(s1C, s1D),
193                                       _mm_unpacklo_epi64(s1A, s1B));
194       const __m256i m16 = calc_mask_d16_avx2(&s0, &s1, &_r, &y38, &y64, shift);
195       const __m256i m8 = _mm256_packus_epi16(m16, _mm256_setzero_si256());
196       xx_storeu_128(mask,
197                     _mm256_castsi256_si128(_mm256_permute4x64_epi64(m8, 0xd8)));
198       src0 += src0_stride << 2;
199       src1 += src1_stride << 2;
200       mask += 16;
201       i += 4;
202     } while (i < h);
203   } else if (w == 8) {
204     do {
205       const __m256i s0AB = yy_loadu2_128(src0 + src0_stride, src0);
206       const __m256i s0CD =
207           yy_loadu2_128(src0 + src0_stride * 3, src0 + src0_stride * 2);
208       const __m256i s1AB = yy_loadu2_128(src1 + src1_stride, src1);
209       const __m256i s1CD =
210           yy_loadu2_128(src1 + src1_stride * 3, src1 + src1_stride * 2);
211       const __m256i m16AB =
212           calc_mask_d16_avx2(&s0AB, &s1AB, &_r, &y38, &y64, shift);
213       const __m256i m16CD =
214           calc_mask_d16_avx2(&s0CD, &s1CD, &_r, &y38, &y64, shift);
215       const __m256i m8 = _mm256_packus_epi16(m16AB, m16CD);
216       yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8));
217       src0 += src0_stride << 2;
218       src1 += src1_stride << 2;
219       mask += 32;
220       i += 4;
221     } while (i < h);
222   } else if (w == 16) {
223     do {
224       const __m256i s0A = yy_loadu_256(src0);
225       const __m256i s0B = yy_loadu_256(src0 + src0_stride);
226       const __m256i s1A = yy_loadu_256(src1);
227       const __m256i s1B = yy_loadu_256(src1 + src1_stride);
228       const __m256i m16A =
229           calc_mask_d16_avx2(&s0A, &s1A, &_r, &y38, &y64, shift);
230       const __m256i m16B =
231           calc_mask_d16_avx2(&s0B, &s1B, &_r, &y38, &y64, shift);
232       const __m256i m8 = _mm256_packus_epi16(m16A, m16B);
233       yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8));
234       src0 += src0_stride << 1;
235       src1 += src1_stride << 1;
236       mask += 32;
237       i += 2;
238     } while (i < h);
239   } else if (w == 32) {
240     do {
241       const __m256i s0A = yy_loadu_256(src0);
242       const __m256i s0B = yy_loadu_256(src0 + 16);
243       const __m256i s1A = yy_loadu_256(src1);
244       const __m256i s1B = yy_loadu_256(src1 + 16);
245       const __m256i m16A =
246           calc_mask_d16_avx2(&s0A, &s1A, &_r, &y38, &y64, shift);
247       const __m256i m16B =
248           calc_mask_d16_avx2(&s0B, &s1B, &_r, &y38, &y64, shift);
249       const __m256i m8 = _mm256_packus_epi16(m16A, m16B);
250       yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8));
251       src0 += src0_stride;
252       src1 += src1_stride;
253       mask += 32;
254       i += 1;
255     } while (i < h);
256   } else if (w == 64) {
257     do {
258       const __m256i s0A = yy_loadu_256(src0);
259       const __m256i s0B = yy_loadu_256(src0 + 16);
260       const __m256i s0C = yy_loadu_256(src0 + 32);
261       const __m256i s0D = yy_loadu_256(src0 + 48);
262       const __m256i s1A = yy_loadu_256(src1);
263       const __m256i s1B = yy_loadu_256(src1 + 16);
264       const __m256i s1C = yy_loadu_256(src1 + 32);
265       const __m256i s1D = yy_loadu_256(src1 + 48);
266       const __m256i m16A =
267           calc_mask_d16_avx2(&s0A, &s1A, &_r, &y38, &y64, shift);
268       const __m256i m16B =
269           calc_mask_d16_avx2(&s0B, &s1B, &_r, &y38, &y64, shift);
270       const __m256i m16C =
271           calc_mask_d16_avx2(&s0C, &s1C, &_r, &y38, &y64, shift);
272       const __m256i m16D =
273           calc_mask_d16_avx2(&s0D, &s1D, &_r, &y38, &y64, shift);
274       const __m256i m8AB = _mm256_packus_epi16(m16A, m16B);
275       const __m256i m8CD = _mm256_packus_epi16(m16C, m16D);
276       yy_storeu_256(mask, _mm256_permute4x64_epi64(m8AB, 0xd8));
277       yy_storeu_256(mask + 32, _mm256_permute4x64_epi64(m8CD, 0xd8));
278       src0 += src0_stride;
279       src1 += src1_stride;
280       mask += 64;
281       i += 1;
282     } while (i < h);
283   } else {
284     do {
285       const __m256i s0A = yy_loadu_256(src0);
286       const __m256i s0B = yy_loadu_256(src0 + 16);
287       const __m256i s0C = yy_loadu_256(src0 + 32);
288       const __m256i s0D = yy_loadu_256(src0 + 48);
289       const __m256i s0E = yy_loadu_256(src0 + 64);
290       const __m256i s0F = yy_loadu_256(src0 + 80);
291       const __m256i s0G = yy_loadu_256(src0 + 96);
292       const __m256i s0H = yy_loadu_256(src0 + 112);
293       const __m256i s1A = yy_loadu_256(src1);
294       const __m256i s1B = yy_loadu_256(src1 + 16);
295       const __m256i s1C = yy_loadu_256(src1 + 32);
296       const __m256i s1D = yy_loadu_256(src1 + 48);
297       const __m256i s1E = yy_loadu_256(src1 + 64);
298       const __m256i s1F = yy_loadu_256(src1 + 80);
299       const __m256i s1G = yy_loadu_256(src1 + 96);
300       const __m256i s1H = yy_loadu_256(src1 + 112);
301       const __m256i m16A =
302           calc_mask_d16_avx2(&s0A, &s1A, &_r, &y38, &y64, shift);
303       const __m256i m16B =
304           calc_mask_d16_avx2(&s0B, &s1B, &_r, &y38, &y64, shift);
305       const __m256i m16C =
306           calc_mask_d16_avx2(&s0C, &s1C, &_r, &y38, &y64, shift);
307       const __m256i m16D =
308           calc_mask_d16_avx2(&s0D, &s1D, &_r, &y38, &y64, shift);
309       const __m256i m16E =
310           calc_mask_d16_avx2(&s0E, &s1E, &_r, &y38, &y64, shift);
311       const __m256i m16F =
312           calc_mask_d16_avx2(&s0F, &s1F, &_r, &y38, &y64, shift);
313       const __m256i m16G =
314           calc_mask_d16_avx2(&s0G, &s1G, &_r, &y38, &y64, shift);
315       const __m256i m16H =
316           calc_mask_d16_avx2(&s0H, &s1H, &_r, &y38, &y64, shift);
317       const __m256i m8AB = _mm256_packus_epi16(m16A, m16B);
318       const __m256i m8CD = _mm256_packus_epi16(m16C, m16D);
319       const __m256i m8EF = _mm256_packus_epi16(m16E, m16F);
320       const __m256i m8GH = _mm256_packus_epi16(m16G, m16H);
321       yy_storeu_256(mask, _mm256_permute4x64_epi64(m8AB, 0xd8));
322       yy_storeu_256(mask + 32, _mm256_permute4x64_epi64(m8CD, 0xd8));
323       yy_storeu_256(mask + 64, _mm256_permute4x64_epi64(m8EF, 0xd8));
324       yy_storeu_256(mask + 96, _mm256_permute4x64_epi64(m8GH, 0xd8));
325       src0 += src0_stride;
326       src1 += src1_stride;
327       mask += 128;
328       i += 1;
329     } while (i < h);
330   }
331 }
332 
build_compound_diffwtd_mask_d16_inv_avx2(uint8_t * mask,const CONV_BUF_TYPE * src0,int src0_stride,const CONV_BUF_TYPE * src1,int src1_stride,int h,int w,int shift)333 static INLINE void build_compound_diffwtd_mask_d16_inv_avx2(
334     uint8_t *mask, const CONV_BUF_TYPE *src0, int src0_stride,
335     const CONV_BUF_TYPE *src1, int src1_stride, int h, int w, int shift) {
336   const int mask_base = 38;
337   const __m256i _r = _mm256_set1_epi16((1 << shift) >> 1);
338   const __m256i y38 = _mm256_set1_epi16(mask_base);
339   const __m256i y64 = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
340   int i = 0;
341   if (w == 4) {
342     do {
343       const __m128i s0A = xx_loadl_64(src0);
344       const __m128i s0B = xx_loadl_64(src0 + src0_stride);
345       const __m128i s0C = xx_loadl_64(src0 + src0_stride * 2);
346       const __m128i s0D = xx_loadl_64(src0 + src0_stride * 3);
347       const __m128i s1A = xx_loadl_64(src1);
348       const __m128i s1B = xx_loadl_64(src1 + src1_stride);
349       const __m128i s1C = xx_loadl_64(src1 + src1_stride * 2);
350       const __m128i s1D = xx_loadl_64(src1 + src1_stride * 3);
351       const __m256i s0 = yy_set_m128i(_mm_unpacklo_epi64(s0C, s0D),
352                                       _mm_unpacklo_epi64(s0A, s0B));
353       const __m256i s1 = yy_set_m128i(_mm_unpacklo_epi64(s1C, s1D),
354                                       _mm_unpacklo_epi64(s1A, s1B));
355       const __m256i m16 =
356           calc_mask_d16_inv_avx2(&s0, &s1, &_r, &y38, &y64, shift);
357       const __m256i m8 = _mm256_packus_epi16(m16, _mm256_setzero_si256());
358       xx_storeu_128(mask,
359                     _mm256_castsi256_si128(_mm256_permute4x64_epi64(m8, 0xd8)));
360       src0 += src0_stride << 2;
361       src1 += src1_stride << 2;
362       mask += 16;
363       i += 4;
364     } while (i < h);
365   } else if (w == 8) {
366     do {
367       const __m256i s0AB = yy_loadu2_128(src0 + src0_stride, src0);
368       const __m256i s0CD =
369           yy_loadu2_128(src0 + src0_stride * 3, src0 + src0_stride * 2);
370       const __m256i s1AB = yy_loadu2_128(src1 + src1_stride, src1);
371       const __m256i s1CD =
372           yy_loadu2_128(src1 + src1_stride * 3, src1 + src1_stride * 2);
373       const __m256i m16AB =
374           calc_mask_d16_inv_avx2(&s0AB, &s1AB, &_r, &y38, &y64, shift);
375       const __m256i m16CD =
376           calc_mask_d16_inv_avx2(&s0CD, &s1CD, &_r, &y38, &y64, shift);
377       const __m256i m8 = _mm256_packus_epi16(m16AB, m16CD);
378       yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8));
379       src0 += src0_stride << 2;
380       src1 += src1_stride << 2;
381       mask += 32;
382       i += 4;
383     } while (i < h);
384   } else if (w == 16) {
385     do {
386       const __m256i s0A = yy_loadu_256(src0);
387       const __m256i s0B = yy_loadu_256(src0 + src0_stride);
388       const __m256i s1A = yy_loadu_256(src1);
389       const __m256i s1B = yy_loadu_256(src1 + src1_stride);
390       const __m256i m16A =
391           calc_mask_d16_inv_avx2(&s0A, &s1A, &_r, &y38, &y64, shift);
392       const __m256i m16B =
393           calc_mask_d16_inv_avx2(&s0B, &s1B, &_r, &y38, &y64, shift);
394       const __m256i m8 = _mm256_packus_epi16(m16A, m16B);
395       yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8));
396       src0 += src0_stride << 1;
397       src1 += src1_stride << 1;
398       mask += 32;
399       i += 2;
400     } while (i < h);
401   } else if (w == 32) {
402     do {
403       const __m256i s0A = yy_loadu_256(src0);
404       const __m256i s0B = yy_loadu_256(src0 + 16);
405       const __m256i s1A = yy_loadu_256(src1);
406       const __m256i s1B = yy_loadu_256(src1 + 16);
407       const __m256i m16A =
408           calc_mask_d16_inv_avx2(&s0A, &s1A, &_r, &y38, &y64, shift);
409       const __m256i m16B =
410           calc_mask_d16_inv_avx2(&s0B, &s1B, &_r, &y38, &y64, shift);
411       const __m256i m8 = _mm256_packus_epi16(m16A, m16B);
412       yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8));
413       src0 += src0_stride;
414       src1 += src1_stride;
415       mask += 32;
416       i += 1;
417     } while (i < h);
418   } else if (w == 64) {
419     do {
420       const __m256i s0A = yy_loadu_256(src0);
421       const __m256i s0B = yy_loadu_256(src0 + 16);
422       const __m256i s0C = yy_loadu_256(src0 + 32);
423       const __m256i s0D = yy_loadu_256(src0 + 48);
424       const __m256i s1A = yy_loadu_256(src1);
425       const __m256i s1B = yy_loadu_256(src1 + 16);
426       const __m256i s1C = yy_loadu_256(src1 + 32);
427       const __m256i s1D = yy_loadu_256(src1 + 48);
428       const __m256i m16A =
429           calc_mask_d16_inv_avx2(&s0A, &s1A, &_r, &y38, &y64, shift);
430       const __m256i m16B =
431           calc_mask_d16_inv_avx2(&s0B, &s1B, &_r, &y38, &y64, shift);
432       const __m256i m16C =
433           calc_mask_d16_inv_avx2(&s0C, &s1C, &_r, &y38, &y64, shift);
434       const __m256i m16D =
435           calc_mask_d16_inv_avx2(&s0D, &s1D, &_r, &y38, &y64, shift);
436       const __m256i m8AB = _mm256_packus_epi16(m16A, m16B);
437       const __m256i m8CD = _mm256_packus_epi16(m16C, m16D);
438       yy_storeu_256(mask, _mm256_permute4x64_epi64(m8AB, 0xd8));
439       yy_storeu_256(mask + 32, _mm256_permute4x64_epi64(m8CD, 0xd8));
440       src0 += src0_stride;
441       src1 += src1_stride;
442       mask += 64;
443       i += 1;
444     } while (i < h);
445   } else {
446     do {
447       const __m256i s0A = yy_loadu_256(src0);
448       const __m256i s0B = yy_loadu_256(src0 + 16);
449       const __m256i s0C = yy_loadu_256(src0 + 32);
450       const __m256i s0D = yy_loadu_256(src0 + 48);
451       const __m256i s0E = yy_loadu_256(src0 + 64);
452       const __m256i s0F = yy_loadu_256(src0 + 80);
453       const __m256i s0G = yy_loadu_256(src0 + 96);
454       const __m256i s0H = yy_loadu_256(src0 + 112);
455       const __m256i s1A = yy_loadu_256(src1);
456       const __m256i s1B = yy_loadu_256(src1 + 16);
457       const __m256i s1C = yy_loadu_256(src1 + 32);
458       const __m256i s1D = yy_loadu_256(src1 + 48);
459       const __m256i s1E = yy_loadu_256(src1 + 64);
460       const __m256i s1F = yy_loadu_256(src1 + 80);
461       const __m256i s1G = yy_loadu_256(src1 + 96);
462       const __m256i s1H = yy_loadu_256(src1 + 112);
463       const __m256i m16A =
464           calc_mask_d16_inv_avx2(&s0A, &s1A, &_r, &y38, &y64, shift);
465       const __m256i m16B =
466           calc_mask_d16_inv_avx2(&s0B, &s1B, &_r, &y38, &y64, shift);
467       const __m256i m16C =
468           calc_mask_d16_inv_avx2(&s0C, &s1C, &_r, &y38, &y64, shift);
469       const __m256i m16D =
470           calc_mask_d16_inv_avx2(&s0D, &s1D, &_r, &y38, &y64, shift);
471       const __m256i m16E =
472           calc_mask_d16_inv_avx2(&s0E, &s1E, &_r, &y38, &y64, shift);
473       const __m256i m16F =
474           calc_mask_d16_inv_avx2(&s0F, &s1F, &_r, &y38, &y64, shift);
475       const __m256i m16G =
476           calc_mask_d16_inv_avx2(&s0G, &s1G, &_r, &y38, &y64, shift);
477       const __m256i m16H =
478           calc_mask_d16_inv_avx2(&s0H, &s1H, &_r, &y38, &y64, shift);
479       const __m256i m8AB = _mm256_packus_epi16(m16A, m16B);
480       const __m256i m8CD = _mm256_packus_epi16(m16C, m16D);
481       const __m256i m8EF = _mm256_packus_epi16(m16E, m16F);
482       const __m256i m8GH = _mm256_packus_epi16(m16G, m16H);
483       yy_storeu_256(mask, _mm256_permute4x64_epi64(m8AB, 0xd8));
484       yy_storeu_256(mask + 32, _mm256_permute4x64_epi64(m8CD, 0xd8));
485       yy_storeu_256(mask + 64, _mm256_permute4x64_epi64(m8EF, 0xd8));
486       yy_storeu_256(mask + 96, _mm256_permute4x64_epi64(m8GH, 0xd8));
487       src0 += src0_stride;
488       src1 += src1_stride;
489       mask += 128;
490       i += 1;
491     } while (i < h);
492   }
493 }
494 
// Entry point for building a difference-weighted mask from two
// CONV_BUF_TYPE prediction buffers.
//
// The rounding shift handed to the kernels is
//   2 * FILTER_BITS - round_0 - round_1 + (bd - 8),
// with round_0 / round_1 read from conv_params. DIFFWTD_38 selects the
// regular mask; every other mask_type selects the inverted one.
void av1_build_compound_diffwtd_mask_d16_avx2(
    uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const CONV_BUF_TYPE *src0,
    int src0_stride, const CONV_BUF_TYPE *src1, int src1_stride, int h, int w,
    ConvolveParams *conv_params, int bd) {
  const int round_bits =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1 + (bd - 8);
  // Adding the rounding constant can overflow, but that much precision is
  // not needed. The kernels should also work for other values of
  // DIFF_FACTOR_LOG2 and AOM_BLEND_A64_MAX_ALPHA, though corner-case bugs
  // are possible, hence the checks below.
  assert(DIFF_FACTOR_LOG2 == 4);
  assert(AOM_BLEND_A64_MAX_ALPHA == 64);

  if (mask_type != DIFFWTD_38) {
    build_compound_diffwtd_mask_d16_inv_avx2(mask, src0, src0_stride, src1,
                                             src1_stride, h, w, round_bits);
    return;
  }
  build_compound_diffwtd_mask_d16_avx2(mask, src0, src0_stride, src1,
                                       src1_stride, h, w, round_bits);
}
516 
// Builds the difference-weighted compound mask for high-bitdepth sources.
//
// Widths below 16 are forwarded to the SSSE3 implementation. Otherwise each
// output byte is
//   m = clamp(38 + (|s0 - s1| >> (bd - 8 + DIFF_FACTOR_LOG2)),
//             0, AOM_BLEND_A64_MAX_ALPHA)
// and DIFFWTD_38_INV stores AOM_BLEND_A64_MAX_ALPHA - m instead of m.
//
//   mask       Output, written densely: row r starts at mask + r * w.
//   src0/src1  Passed through CONVERT_TO_SHORTPTR and read as uint16_t
//              samples; strides are applied in uint16_t elements.
//   h, w       Block height and width; w must be a multiple of 16 here.
//   bd         Bit depth, expected to be at least 8.
void av1_build_compound_diffwtd_mask_highbd_avx2(
    uint8_t *mask, DIFFWTD_MASK_TYPE mask_type, const uint8_t *src0,
    int src0_stride, const uint8_t *src1, int src1_stride, int h, int w,
    int bd) {
  if (w < 16) {
    av1_build_compound_diffwtd_mask_highbd_ssse3(
        mask, mask_type, src0, src0_stride, src1, src1_stride, h, w, bd);
    return;
  }

  assert(mask_type == DIFFWTD_38 || mask_type == DIFFWTD_38_INV);
  assert(bd >= 8);
  assert((w % 16) == 0);

  const __m256i zero = _mm256_setzero_si256();
  const __m256i max_alpha = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
  const __m256i base = _mm256_set1_epi16(38);
  // For bd == 8 this count is exactly DIFF_FACTOR_LOG2.
  const __m128i shift_count = xx_set1_64_from_32i(bd - 8 + DIFF_FACTOR_LOG2);
  const int invert = (mask_type == DIFFWTD_38_INV);
  const uint16_t *row0 = CONVERT_TO_SHORTPTR(src0);
  const uint16_t *row1 = CONVERT_TO_SHORTPTR(src1);

  for (int y = 0; y < h; ++y) {
    for (int x = 0; x < w; x += 16) {
      const __m256i a = _mm256_loadu_si256((const __m256i *)(row0 + x));
      const __m256i b = _mm256_loadu_si256((const __m256i *)(row1 + x));
      const __m256i abs_diff = _mm256_abs_epi16(_mm256_sub_epi16(a, b));
      const __m256i scaled = _mm256_sra_epi16(abs_diff, shift_count);
      const __m256i biased = _mm256_add_epi16(scaled, base);
      __m256i wt = _mm256_min_epi16(_mm256_max_epi16(zero, biased), max_alpha);
      if (invert) {
        wt = _mm256_sub_epi16(max_alpha, wt);
      }
      // The per-lane pack leaves bytes 0-7 in quadword 0 and bytes 8-15 in
      // quadword 2; the permute brings both into the low 128 bits.
      const __m256i packed = _mm256_packus_epi16(wt, wt);
      const __m256i ordered =
          _mm256_permute4x64_epi64(packed, _MM_SHUFFLE(0, 0, 2, 0));
      _mm_storeu_si128((__m128i *)(mask + x), _mm256_castsi256_si128(ordered));
    }
    row0 += src0_stride;
    row1 += src1_stride;
    mask += w;
  }
}
621