/*
 * Copyright (c) 2017, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at https://www.aomedia.org/license/software-license. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at https://www.aomedia.org/license/patent-license.
 */
#include <assert.h>
#include <immintrin.h>
#include "synonyms.h"
#include "synonyms_avx2.h"

#include "EbDefinitions.h"

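// Computes 16 diffwtd mask values from 16-bit intermediate prediction samples.
// Roughly, ignoring the 16-bit saturating arithmetic, each output value is
//   min(mask_base + ((|src0 - src1| + round_offset) >> round >> DIFF_FACTOR_LOG2), clip)
// where round_offset = (1 << round) >> 1 and clip is AOM_BLEND_A64_MAX_ALPHA (64).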
static INLINE __m256i calc_mask_d16_avx2(const __m256i *data_src0, const __m256i *data_src1,
                                         const __m256i *round_const, const __m256i *mask_base_16,
                                         const __m256i *clip_diff, int round) {
    const __m256i diffa = _mm256_subs_epu16(*data_src0, *data_src1);
    const __m256i diffb = _mm256_subs_epu16(*data_src1, *data_src0);
    const __m256i diff = _mm256_max_epu16(diffa, diffb);
    const __m256i diff_round = _mm256_srli_epi16(_mm256_adds_epu16(diff, *round_const), round);
    const __m256i diff_factor = _mm256_srli_epi16(diff_round, DIFF_FACTOR_LOG2);
    const __m256i diff_mask = _mm256_adds_epi16(diff_factor, *mask_base_16);
    const __m256i diff_clamp = _mm256_min_epi16(diff_mask, *clip_diff);
    return diff_clamp;
}

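// Same computation as calc_mask_d16_avx2(), but returns the inverted mask,
// i.e. AOM_BLEND_A64_MAX_ALPHA - mask (used for the DIFFWTD_38_INV mask type).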
static INLINE __m256i calc_mask_d16_inv_avx2(const __m256i *data_src0, const __m256i *data_src1,
                                             const __m256i *round_const,
                                             const __m256i *mask_base_16, const __m256i *clip_diff,
                                             int round) {
    const __m256i diffa = _mm256_subs_epu16(*data_src0, *data_src1);
    const __m256i diffb = _mm256_subs_epu16(*data_src1, *data_src0);
    const __m256i diff = _mm256_max_epu16(diffa, diffb);
    const __m256i diff_round = _mm256_srli_epi16(_mm256_adds_epu16(diff, *round_const), round);
    const __m256i diff_factor = _mm256_srli_epi16(diff_round, DIFF_FACTOR_LOG2);
    const __m256i diff_mask = _mm256_adds_epi16(diff_factor, *mask_base_16);
    const __m256i diff_clamp = _mm256_min_epi16(diff_mask, *clip_diff);
    const __m256i diff_const_16 = _mm256_sub_epi16(*clip_diff, diff_clamp);
    return diff_const_16;
}

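// Builds the full w x h diffwtd mask from two 16-bit intermediate prediction
// buffers. Each branch is specialized for one block width and processes as
// many rows per iteration as fit into 256-bit registers (4 rows for w == 4
// and w == 8, 2 rows for w == 16, 1 row otherwise). _mm256_packus_epi16
// packs the 16-bit mask values to 8 bits within each 128-bit lane, so the
// following _mm256_permute4x64_epi64(..., 0xd8) reorders the 64-bit groups
// to restore the original sample order before storing.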
static INLINE void build_compound_diffwtd_mask_d16_avx2(uint8_t *mask, const CONV_BUF_TYPE *src0,
                                                        int src0_stride, const CONV_BUF_TYPE *src1,
                                                        int src1_stride, int h, int w, int shift) {
    const int mask_base = 38;
    const __m256i _r = _mm256_set1_epi16((1 << shift) >> 1);
    const __m256i y38 = _mm256_set1_epi16(mask_base);
    const __m256i y64 = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
    int i = 0;
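    // w == 4: gather four 4-sample rows (64 bits each) from both sources into
    // one 256-bit register per source, so a single calc_mask_d16_avx2() call
    // produces the mask for four rows at once.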
    if (w == 4) {
        do {
            const __m128i s0_a = xx_loadl_64(src0);
            const __m128i s0_b = xx_loadl_64(src0 + src0_stride);
            const __m128i s0_c = xx_loadl_64(src0 + src0_stride * 2);
            const __m128i s0_d = xx_loadl_64(src0 + src0_stride * 3);
            const __m128i s1_a = xx_loadl_64(src1);
            const __m128i s1_b = xx_loadl_64(src1 + src1_stride);
            const __m128i s1_c = xx_loadl_64(src1 + src1_stride * 2);
            const __m128i s1_d = xx_loadl_64(src1 + src1_stride * 3);
            const __m256i s0 = yy_set_m128i(_mm_unpacklo_epi64(s0_c, s0_d),
                                            _mm_unpacklo_epi64(s0_a, s0_b));
            const __m256i s1 = yy_set_m128i(_mm_unpacklo_epi64(s1_c, s1_d),
                                            _mm_unpacklo_epi64(s1_a, s1_b));
            const __m256i m16 = calc_mask_d16_avx2(&s0, &s1, &_r, &y38, &y64, shift);
            const __m256i m8 = _mm256_packus_epi16(m16, _mm256_setzero_si256());
            xx_storeu_128(mask, _mm256_castsi256_si128(_mm256_permute4x64_epi64(m8, 0xd8)));
            src0 += src0_stride << 2;
            src1 += src1_stride << 2;
            mask += 16;
            i += 4;
        } while (i < h);
    } else if (w == 8) {
        do {
            const __m256i s0_a_b = yy_loadu2_128(src0 + src0_stride, src0);
            const __m256i s0_c_d = yy_loadu2_128(src0 + src0_stride * 3, src0 + src0_stride * 2);
            const __m256i s1_a_b = yy_loadu2_128(src1 + src1_stride, src1);
            const __m256i s1_c_d = yy_loadu2_128(src1 + src1_stride * 3, src1 + src1_stride * 2);
            const __m256i m16_a_b = calc_mask_d16_avx2(&s0_a_b, &s1_a_b, &_r, &y38, &y64, shift);
            const __m256i m16_c_d = calc_mask_d16_avx2(&s0_c_d, &s1_c_d, &_r, &y38, &y64, shift);
            const __m256i m8 = _mm256_packus_epi16(m16_a_b, m16_c_d);
            yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8));
            src0 += src0_stride << 2;
            src1 += src1_stride << 2;
            mask += 32;
            i += 4;
        } while (i < h);
    } else if (w == 16) {
        do {
            const __m256i s0_a = yy_loadu_256(src0);
            const __m256i s0_b = yy_loadu_256(src0 + src0_stride);
            const __m256i s1_a = yy_loadu_256(src1);
            const __m256i s1_b = yy_loadu_256(src1 + src1_stride);
            const __m256i m16_a = calc_mask_d16_avx2(&s0_a, &s1_a, &_r, &y38, &y64, shift);
            const __m256i m16_b = calc_mask_d16_avx2(&s0_b, &s1_b, &_r, &y38, &y64, shift);
            const __m256i m8 = _mm256_packus_epi16(m16_a, m16_b);
            yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8));
            src0 += src0_stride << 1;
            src1 += src1_stride << 1;
            mask += 32;
            i += 2;
        } while (i < h);
    } else if (w == 32) {
        do {
            const __m256i s0_a = yy_loadu_256(src0);
            const __m256i s0_b = yy_loadu_256(src0 + 16);
            const __m256i s1_a = yy_loadu_256(src1);
            const __m256i s1_b = yy_loadu_256(src1 + 16);
            const __m256i m16_a = calc_mask_d16_avx2(&s0_a, &s1_a, &_r, &y38, &y64, shift);
            const __m256i m16_b = calc_mask_d16_avx2(&s0_b, &s1_b, &_r, &y38, &y64, shift);
            const __m256i m8 = _mm256_packus_epi16(m16_a, m16_b);
            yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8));
            src0 += src0_stride;
            src1 += src1_stride;
            mask += 32;
            i += 1;
        } while (i < h);
    } else if (w == 64) {
        do {
            const __m256i s0_a = yy_loadu_256(src0);
            const __m256i s0_b = yy_loadu_256(src0 + 16);
            const __m256i s0_c = yy_loadu_256(src0 + 32);
            const __m256i s0_d = yy_loadu_256(src0 + 48);
            const __m256i s1_a = yy_loadu_256(src1);
            const __m256i s1_b = yy_loadu_256(src1 + 16);
            const __m256i s1_c = yy_loadu_256(src1 + 32);
            const __m256i s1_d = yy_loadu_256(src1 + 48);
            const __m256i m16_a = calc_mask_d16_avx2(&s0_a, &s1_a, &_r, &y38, &y64, shift);
            const __m256i m16_b = calc_mask_d16_avx2(&s0_b, &s1_b, &_r, &y38, &y64, shift);
            const __m256i m16_c = calc_mask_d16_avx2(&s0_c, &s1_c, &_r, &y38, &y64, shift);
            const __m256i m16_d = calc_mask_d16_avx2(&s0_d, &s1_d, &_r, &y38, &y64, shift);
            const __m256i m8_a_b = _mm256_packus_epi16(m16_a, m16_b);
            const __m256i m8_c_d = _mm256_packus_epi16(m16_c, m16_d);
            yy_storeu_256(mask, _mm256_permute4x64_epi64(m8_a_b, 0xd8));
            yy_storeu_256(mask + 32, _mm256_permute4x64_epi64(m8_c_d, 0xd8));
            src0 += src0_stride;
            src1 += src1_stride;
            mask += 64;
            i += 1;
        } while (i < h);
    } else {
        do {
            const __m256i s0_a = yy_loadu_256(src0);
            const __m256i s0_b = yy_loadu_256(src0 + 16);
            const __m256i s0_c = yy_loadu_256(src0 + 32);
            const __m256i s0_d = yy_loadu_256(src0 + 48);
            const __m256i s0_e = yy_loadu_256(src0 + 64);
            const __m256i s0_f = yy_loadu_256(src0 + 80);
            const __m256i s0_g = yy_loadu_256(src0 + 96);
            const __m256i s0_h = yy_loadu_256(src0 + 112);
            const __m256i s1_a = yy_loadu_256(src1);
            const __m256i s1_b = yy_loadu_256(src1 + 16);
            const __m256i s1_c = yy_loadu_256(src1 + 32);
            const __m256i s1_d = yy_loadu_256(src1 + 48);
            const __m256i s1_e = yy_loadu_256(src1 + 64);
            const __m256i s1_f = yy_loadu_256(src1 + 80);
            const __m256i s1_g = yy_loadu_256(src1 + 96);
            const __m256i s1_h = yy_loadu_256(src1 + 112);
            const __m256i m16_a = calc_mask_d16_avx2(&s0_a, &s1_a, &_r, &y38, &y64, shift);
            const __m256i m16_b = calc_mask_d16_avx2(&s0_b, &s1_b, &_r, &y38, &y64, shift);
            const __m256i m16_c = calc_mask_d16_avx2(&s0_c, &s1_c, &_r, &y38, &y64, shift);
            const __m256i m16_d = calc_mask_d16_avx2(&s0_d, &s1_d, &_r, &y38, &y64, shift);
            const __m256i m16_e = calc_mask_d16_avx2(&s0_e, &s1_e, &_r, &y38, &y64, shift);
            const __m256i m16_f = calc_mask_d16_avx2(&s0_f, &s1_f, &_r, &y38, &y64, shift);
            const __m256i m16_g = calc_mask_d16_avx2(&s0_g, &s1_g, &_r, &y38, &y64, shift);
            const __m256i m16_h = calc_mask_d16_avx2(&s0_h, &s1_h, &_r, &y38, &y64, shift);
            const __m256i m8_a_b = _mm256_packus_epi16(m16_a, m16_b);
            const __m256i m8_c_d = _mm256_packus_epi16(m16_c, m16_d);
            const __m256i m8_e_f = _mm256_packus_epi16(m16_e, m16_f);
            const __m256i m8_g_h = _mm256_packus_epi16(m16_g, m16_h);
            yy_storeu_256(mask, _mm256_permute4x64_epi64(m8_a_b, 0xd8));
            yy_storeu_256(mask + 32, _mm256_permute4x64_epi64(m8_c_d, 0xd8));
            yy_storeu_256(mask + 64, _mm256_permute4x64_epi64(m8_e_f, 0xd8));
            yy_storeu_256(mask + 96, _mm256_permute4x64_epi64(m8_g_h, 0xd8));
            src0 += src0_stride;
            src1 += src1_stride;
            mask += 128;
            i += 1;
        } while (i < h);
    }
}

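// Inverse-mask counterpart of build_compound_diffwtd_mask_d16_avx2(): the
// per-width loops are identical, except that calc_mask_d16_inv_avx2() is used
// so the stored mask is AOM_BLEND_A64_MAX_ALPHA - mask.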
static INLINE void build_compound_diffwtd_mask_d16_inv_avx2(
    uint8_t *mask, const CONV_BUF_TYPE *src0, int src0_stride, const CONV_BUF_TYPE *src1,
    int src1_stride, int h, int w, int shift) {
    const int mask_base = 38;
    const __m256i _r = _mm256_set1_epi16((1 << shift) >> 1);
    const __m256i y38 = _mm256_set1_epi16(mask_base);
    const __m256i y64 = _mm256_set1_epi16(AOM_BLEND_A64_MAX_ALPHA);
    int i = 0;
    if (w == 4) {
        do {
            const __m128i s0_a = xx_loadl_64(src0);
            const __m128i s0_b = xx_loadl_64(src0 + src0_stride);
            const __m128i s0_c = xx_loadl_64(src0 + src0_stride * 2);
            const __m128i s0_d = xx_loadl_64(src0 + src0_stride * 3);
            const __m128i s1_a = xx_loadl_64(src1);
            const __m128i s1_b = xx_loadl_64(src1 + src1_stride);
            const __m128i s1_c = xx_loadl_64(src1 + src1_stride * 2);
            const __m128i s1_d = xx_loadl_64(src1 + src1_stride * 3);
            const __m256i s0 = yy_set_m128i(_mm_unpacklo_epi64(s0_c, s0_d),
                                            _mm_unpacklo_epi64(s0_a, s0_b));
            const __m256i s1 = yy_set_m128i(_mm_unpacklo_epi64(s1_c, s1_d),
                                            _mm_unpacklo_epi64(s1_a, s1_b));
            const __m256i m16 = calc_mask_d16_inv_avx2(&s0, &s1, &_r, &y38, &y64, shift);
            const __m256i m8 = _mm256_packus_epi16(m16, _mm256_setzero_si256());
            xx_storeu_128(mask, _mm256_castsi256_si128(_mm256_permute4x64_epi64(m8, 0xd8)));
            src0 += src0_stride << 2;
            src1 += src1_stride << 2;
            mask += 16;
            i += 4;
        } while (i < h);
    } else if (w == 8) {
        do {
            const __m256i s0_a_b = yy_loadu2_128(src0 + src0_stride, src0);
            const __m256i s0_c_d = yy_loadu2_128(src0 + src0_stride * 3, src0 + src0_stride * 2);
            const __m256i s1_a_b = yy_loadu2_128(src1 + src1_stride, src1);
            const __m256i s1_c_d = yy_loadu2_128(src1 + src1_stride * 3, src1 + src1_stride * 2);
            const __m256i m16_a_b = calc_mask_d16_inv_avx2(
                &s0_a_b, &s1_a_b, &_r, &y38, &y64, shift);
            const __m256i m16_c_d = calc_mask_d16_inv_avx2(
                &s0_c_d, &s1_c_d, &_r, &y38, &y64, shift);
            const __m256i m8 = _mm256_packus_epi16(m16_a_b, m16_c_d);
            yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8));
            src0 += src0_stride << 2;
            src1 += src1_stride << 2;
            mask += 32;
            i += 4;
        } while (i < h);
    } else if (w == 16) {
        do {
            const __m256i s0_a = yy_loadu_256(src0);
            const __m256i s0_b = yy_loadu_256(src0 + src0_stride);
            const __m256i s1_a = yy_loadu_256(src1);
            const __m256i s1_b = yy_loadu_256(src1 + src1_stride);
            const __m256i m16_a = calc_mask_d16_inv_avx2(&s0_a, &s1_a, &_r, &y38, &y64, shift);
            const __m256i m16_b = calc_mask_d16_inv_avx2(&s0_b, &s1_b, &_r, &y38, &y64, shift);
            const __m256i m8 = _mm256_packus_epi16(m16_a, m16_b);
            yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8));
            src0 += src0_stride << 1;
            src1 += src1_stride << 1;
            mask += 32;
            i += 2;
        } while (i < h);
    } else if (w == 32) {
        do {
            const __m256i s0_a = yy_loadu_256(src0);
            const __m256i s0_b = yy_loadu_256(src0 + 16);
            const __m256i s1_a = yy_loadu_256(src1);
            const __m256i s1_b = yy_loadu_256(src1 + 16);
            const __m256i m16_a = calc_mask_d16_inv_avx2(&s0_a, &s1_a, &_r, &y38, &y64, shift);
            const __m256i m16_b = calc_mask_d16_inv_avx2(&s0_b, &s1_b, &_r, &y38, &y64, shift);
            const __m256i m8 = _mm256_packus_epi16(m16_a, m16_b);
            yy_storeu_256(mask, _mm256_permute4x64_epi64(m8, 0xd8));
            src0 += src0_stride;
            src1 += src1_stride;
            mask += 32;
            i += 1;
        } while (i < h);
    } else if (w == 64) {
        do {
            const __m256i s0_a = yy_loadu_256(src0);
            const __m256i s0_b = yy_loadu_256(src0 + 16);
            const __m256i s0_c = yy_loadu_256(src0 + 32);
            const __m256i s0_d = yy_loadu_256(src0 + 48);
            const __m256i s1_a = yy_loadu_256(src1);
            const __m256i s1_b = yy_loadu_256(src1 + 16);
            const __m256i s1_c = yy_loadu_256(src1 + 32);
            const __m256i s1_d = yy_loadu_256(src1 + 48);
            const __m256i m16_a = calc_mask_d16_inv_avx2(&s0_a, &s1_a, &_r, &y38, &y64, shift);
            const __m256i m16_b = calc_mask_d16_inv_avx2(&s0_b, &s1_b, &_r, &y38, &y64, shift);
            const __m256i m16_c = calc_mask_d16_inv_avx2(&s0_c, &s1_c, &_r, &y38, &y64, shift);
            const __m256i m16_d = calc_mask_d16_inv_avx2(&s0_d, &s1_d, &_r, &y38, &y64, shift);
            const __m256i m8_a_b = _mm256_packus_epi16(m16_a, m16_b);
            const __m256i m8_c_d = _mm256_packus_epi16(m16_c, m16_d);
            yy_storeu_256(mask, _mm256_permute4x64_epi64(m8_a_b, 0xd8));
            yy_storeu_256(mask + 32, _mm256_permute4x64_epi64(m8_c_d, 0xd8));
            src0 += src0_stride;
            src1 += src1_stride;
            mask += 64;
            i += 1;
        } while (i < h);
    } else {
        do {
            const __m256i s0_a = yy_loadu_256(src0);
            const __m256i s0_b = yy_loadu_256(src0 + 16);
            const __m256i s0_c = yy_loadu_256(src0 + 32);
            const __m256i s0_d = yy_loadu_256(src0 + 48);
            const __m256i s0_e = yy_loadu_256(src0 + 64);
            const __m256i s0_f = yy_loadu_256(src0 + 80);
            const __m256i s0_g = yy_loadu_256(src0 + 96);
            const __m256i s0_h = yy_loadu_256(src0 + 112);
            const __m256i s1_a = yy_loadu_256(src1);
            const __m256i s1_b = yy_loadu_256(src1 + 16);
            const __m256i s1_c = yy_loadu_256(src1 + 32);
            const __m256i s1_d = yy_loadu_256(src1 + 48);
            const __m256i s1_e = yy_loadu_256(src1 + 64);
            const __m256i s1_f = yy_loadu_256(src1 + 80);
            const __m256i s1_g = yy_loadu_256(src1 + 96);
            const __m256i s1_h = yy_loadu_256(src1 + 112);
            const __m256i m16_a = calc_mask_d16_inv_avx2(&s0_a, &s1_a, &_r, &y38, &y64, shift);
            const __m256i m16_b = calc_mask_d16_inv_avx2(&s0_b, &s1_b, &_r, &y38, &y64, shift);
            const __m256i m16_c = calc_mask_d16_inv_avx2(&s0_c, &s1_c, &_r, &y38, &y64, shift);
            const __m256i m16_d = calc_mask_d16_inv_avx2(&s0_d, &s1_d, &_r, &y38, &y64, shift);
            const __m256i m16_e = calc_mask_d16_inv_avx2(&s0_e, &s1_e, &_r, &y38, &y64, shift);
            const __m256i m16_f = calc_mask_d16_inv_avx2(&s0_f, &s1_f, &_r, &y38, &y64, shift);
            const __m256i m16_g = calc_mask_d16_inv_avx2(&s0_g, &s1_g, &_r, &y38, &y64, shift);
            const __m256i m16_h = calc_mask_d16_inv_avx2(&s0_h, &s1_h, &_r, &y38, &y64, shift);
            const __m256i m8_a_b = _mm256_packus_epi16(m16_a, m16_b);
            const __m256i m8_c_d = _mm256_packus_epi16(m16_c, m16_d);
            const __m256i m8_e_f = _mm256_packus_epi16(m16_e, m16_f);
            const __m256i m8_g_h = _mm256_packus_epi16(m16_g, m16_h);
            yy_storeu_256(mask, _mm256_permute4x64_epi64(m8_a_b, 0xd8));
            yy_storeu_256(mask + 32, _mm256_permute4x64_epi64(m8_c_d, 0xd8));
            yy_storeu_256(mask + 64, _mm256_permute4x64_epi64(m8_e_f, 0xd8));
            yy_storeu_256(mask + 96, _mm256_permute4x64_epi64(m8_g_h, 0xd8));
            src0 += src0_stride;
            src1 += src1_stride;
            mask += 128;
            i += 1;
        } while (i < h);
    }
}

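// AVX2 entry point for building the compound diffwtd mask from 16-bit
// intermediate ("d16") predictions. 'shift' accounts for the extra precision
// left in the convolve buffers (rounds 0/1 and the bit depth), and mask_type
// selects between the DIFFWTD_38 mask and its inverse.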
void svt_av1_build_compound_diffwtd_mask_d16_avx2(uint8_t *mask, DIFFWTD_MASK_TYPE mask_type,
                                                  const CONV_BUF_TYPE *src0, int src0_stride,
                                                  const CONV_BUF_TYPE *src1, int src1_stride, int h,
                                                  int w, ConvolveParams *conv_params, int bd) {
    const int shift = 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1 + (bd - 8);
    // Adding the rounding constant can overflow 16 bits, but that extra
    // precision is not required here. The code should also work for other
    // values of DIFF_FACTOR_LOG2 and AOM_BLEND_A64_MAX_ALPHA, although
    // corner-case bugs are possible for such configurations.
    assert(DIFF_FACTOR_LOG2 == 4);
    assert(AOM_BLEND_A64_MAX_ALPHA == 64);

    if (mask_type == DIFFWTD_38) {
        build_compound_diffwtd_mask_d16_avx2(
            mask, src0, src0_stride, src1, src1_stride, h, w, shift);
    } else {
        build_compound_diffwtd_mask_d16_inv_avx2(
            mask, src0, src0_stride, src1, src1_stride, h, w, shift);
    }
}