1 /*
2 * Copyright(c) 2019 Intel Corporation
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at https://www.aomedia.org/license/software-license. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at https://www.aomedia.org/license/patent-license.
10 */
11 #include "EbDefinitions.h"
12
13 #if EN_AVX512_SUPPORT
14
15 #include <immintrin.h>
16 #include "EbPictureOperators_Inline_AVX2.h"
17 #include "EbPictureOperators_SSE2.h"
18 #include "EbMemory_AVX2.h"
19
20 /*******************************************************************************
21 * Helper function that add 32bit values from sum32 to 64bit values in sum64
22 * sum32 is also zeroed
23 *******************************************************************************/
sum32_to64_avx512(__m512i * const sum32,__m512i * const sum64)24 static INLINE void sum32_to64_avx512(__m512i *const sum32, __m512i *const sum64) {
25 //Save partial sum into large 64bit register instead of 32 bit (which could overflow)
26 *sum64 = _mm512_add_epi64(*sum64, _mm512_unpacklo_epi32(*sum32, _mm512_setzero_si512()));
27 *sum64 = _mm512_add_epi64(*sum64, _mm512_unpackhi_epi32(*sum32, _mm512_setzero_si512()));
28 *sum32 = _mm512_setzero_si512();
29 }
30
residual32x2_avx512(const uint8_t * input,const uint32_t input_stride,const uint8_t * pred,const uint32_t pred_stride,int16_t * residual,const uint32_t residual_stride)31 static INLINE void residual32x2_avx512(const uint8_t *input, const uint32_t input_stride,
32 const uint8_t *pred, const uint32_t pred_stride,
33 int16_t *residual, const uint32_t residual_stride) {
34 const __m512i zero = _mm512_setzero_si512();
35 const __m512i idx = _mm512_setr_epi64(0, 4, 1, 5, 2, 6, 3, 7);
36 const __m256i in0 = _mm256_loadu_si256((__m256i *)input);
37 const __m256i in1 = _mm256_loadu_si256((__m256i *)(input + input_stride));
38 const __m256i pr0 = _mm256_loadu_si256((__m256i *)pred);
39 const __m256i pr1 = _mm256_loadu_si256((__m256i *)(pred + pred_stride));
40 const __m512i in2 = _mm512_inserti64x4(_mm512_castsi256_si512(in0), in1, 1);
41 const __m512i pr2 = _mm512_inserti64x4(_mm512_castsi256_si512(pr0), pr1, 1);
42 const __m512i in3 = _mm512_permutexvar_epi64(idx, in2);
43 const __m512i pr3 = _mm512_permutexvar_epi64(idx, pr2);
44 const __m512i in_lo = _mm512_unpacklo_epi8(in3, zero);
45 const __m512i in_hi = _mm512_unpackhi_epi8(in3, zero);
46 const __m512i pr_lo = _mm512_unpacklo_epi8(pr3, zero);
47 const __m512i pr_hi = _mm512_unpackhi_epi8(pr3, zero);
48 const __m512i re_lo = _mm512_sub_epi16(in_lo, pr_lo);
49 const __m512i re_hi = _mm512_sub_epi16(in_hi, pr_hi);
50 _mm512_storeu_si512((__m512i *)(residual + 0 * residual_stride), re_lo);
51 _mm512_storeu_si512((__m512i *)(residual + 1 * residual_stride), re_hi);
52 }
53
residual_kernel32_avx2(const uint8_t * input,const uint32_t input_stride,const uint8_t * pred,const uint32_t pred_stride,int16_t * residual,const uint32_t residual_stride,const uint32_t area_height)54 SIMD_INLINE void residual_kernel32_avx2(const uint8_t *input, const uint32_t input_stride,
55 const uint8_t *pred, const uint32_t pred_stride,
56 int16_t *residual, const uint32_t residual_stride,
57 const uint32_t area_height) {
58 uint32_t y = area_height;
59
60 do {
61 residual32x2_avx512(input, input_stride, pred, pred_stride, residual, residual_stride);
62 input += 2 * input_stride;
63 pred += 2 * pred_stride;
64 residual += 2 * residual_stride;
65 y -= 2;
66 } while (y);
67 }
68
residual64_avx512(const uint8_t * const input,const uint8_t * const pred,int16_t * const residual)69 static INLINE void residual64_avx512(const uint8_t *const input, const uint8_t *const pred,
70 int16_t *const residual) {
71 const __m512i zero = _mm512_setzero_si512();
72 const __m512i idx = _mm512_setr_epi64(0, 4, 1, 5, 2, 6, 3, 7);
73 const __m512i in0 = _mm512_loadu_si512((__m512i *)input);
74 const __m512i pr0 = _mm512_loadu_si512((__m512i *)pred);
75 const __m512i in1 = _mm512_permutexvar_epi64(idx, in0);
76 const __m512i pr1 = _mm512_permutexvar_epi64(idx, pr0);
77 const __m512i in_lo = _mm512_unpacklo_epi8(in1, zero);
78 const __m512i in_hi = _mm512_unpackhi_epi8(in1, zero);
79 const __m512i pr_lo = _mm512_unpacklo_epi8(pr1, zero);
80 const __m512i pr_hi = _mm512_unpackhi_epi8(pr1, zero);
81 const __m512i re_lo = _mm512_sub_epi16(in_lo, pr_lo);
82 const __m512i re_hi = _mm512_sub_epi16(in_hi, pr_hi);
83 _mm512_storeu_si512((__m512i *)(residual + 0 * 32), re_lo);
84 _mm512_storeu_si512((__m512i *)(residual + 1 * 32), re_hi);
85 }
86
residual_kernel64_avx2(const uint8_t * input,const uint32_t input_stride,const uint8_t * pred,const uint32_t pred_stride,int16_t * residual,const uint32_t residual_stride,const uint32_t area_height)87 SIMD_INLINE void residual_kernel64_avx2(const uint8_t *input, const uint32_t input_stride,
88 const uint8_t *pred, const uint32_t pred_stride,
89 int16_t *residual, const uint32_t residual_stride,
90 const uint32_t area_height) {
91 uint32_t y = area_height;
92
93 do {
94 residual64_avx512(input, pred, residual);
95 input += input_stride;
96 pred += pred_stride;
97 residual += residual_stride;
98 } while (--y);
99 }
100
residual_kernel128_avx2(const uint8_t * input,const uint32_t input_stride,const uint8_t * pred,const uint32_t pred_stride,int16_t * residual,const uint32_t residual_stride,const uint32_t area_height)101 SIMD_INLINE void residual_kernel128_avx2(const uint8_t *input, const uint32_t input_stride,
102 const uint8_t *pred, const uint32_t pred_stride,
103 int16_t *residual, const uint32_t residual_stride,
104 const uint32_t area_height) {
105 uint32_t y = area_height;
106
107 do {
108 residual64_avx512(input + 0 * 64, pred + 0 * 64, residual + 0 * 64);
109 residual64_avx512(input + 1 * 64, pred + 1 * 64, residual + 1 * 64);
110 input += input_stride;
111 pred += pred_stride;
112 residual += residual_stride;
113 } while (--y);
114 }
115
void svt_residual_kernel8bit_avx512(uint8_t *input, uint32_t input_stride, uint8_t *pred,
                                    uint32_t pred_stride, int16_t *residual,
                                    uint32_t residual_stride, uint32_t area_width,
                                    uint32_t area_height) {
    // Width-based dispatch: every kernel computes residual = input - pred over
    // an area_width x area_height block of 8-bit samples.
    if (area_width == 4) {
        residual_kernel4_avx2(
            input, input_stride, pred, pred_stride, residual, residual_stride, area_height);
    } else if (area_width == 8) {
        residual_kernel8_avx2(
            input, input_stride, pred, pred_stride, residual, residual_stride, area_height);
    } else if (area_width == 16) {
        residual_kernel16_avx2(
            input, input_stride, pred, pred_stride, residual, residual_stride, area_height);
    } else if (area_width == 32) {
        residual_kernel32_avx2(
            input, input_stride, pred, pred_stride, residual, residual_stride, area_height);
    } else if (area_width == 64) {
        residual_kernel64_avx2(
            input, input_stride, pred, pred_stride, residual, residual_stride, area_height);
    } else {
        // area_width == 128 (any other width also falls through here,
        // matching the original switch default).
        residual_kernel128_avx2(
            input, input_stride, pred, pred_stride, residual, residual_stride, area_height);
    }
}
152
Hadd32_AVX512_INTRIN(const __m512i src)153 static INLINE int32_t Hadd32_AVX512_INTRIN(const __m512i src) {
154 const __m256i src_l = _mm512_castsi512_si256(src);
155 const __m256i src_h = _mm512_extracti64x4_epi64(src, 1);
156 const __m256i sum = _mm256_add_epi32(src_l, src_h);
157
158 return hadd32_avx2_intrin(sum);
159 }
160
Distortion_AVX512_INTRIN(const __m256i input,const __m256i recon,__m512i * const sum)161 static INLINE void Distortion_AVX512_INTRIN(const __m256i input, const __m256i recon,
162 __m512i *const sum) {
163 const __m512i in = _mm512_cvtepu8_epi16(input);
164 const __m512i re = _mm512_cvtepu8_epi16(recon);
165 const __m512i diff = _mm512_sub_epi16(in, re);
166 const __m512i dist = _mm512_madd_epi16(diff, diff);
167 *sum = _mm512_add_epi32(*sum, dist);
168 }
169
SpatialFullDistortionKernel32_AVX512_INTRIN(const uint8_t * const input,const uint8_t * const recon,__m512i * const sum)170 static INLINE void SpatialFullDistortionKernel32_AVX512_INTRIN(const uint8_t *const input,
171 const uint8_t *const recon,
172 __m512i *const sum) {
173 const __m256i in = _mm256_loadu_si256((__m256i *)input);
174 const __m256i re = _mm256_loadu_si256((__m256i *)recon);
175 Distortion_AVX512_INTRIN(in, re, sum);
176 }
177
SpatialFullDistortionKernel64_AVX512_INTRIN(const uint8_t * const input,const uint8_t * const recon,__m512i * const sum)178 static INLINE void SpatialFullDistortionKernel64_AVX512_INTRIN(const uint8_t *const input,
179 const uint8_t *const recon,
180 __m512i *const sum) {
181 const __m512i in = _mm512_loadu_si512((__m512i *)input);
182 const __m512i re = _mm512_loadu_si512((__m512i *)recon);
183 const __m512i max = _mm512_max_epu8(in, re);
184 const __m512i min = _mm512_min_epu8(in, re);
185 const __m512i diff = _mm512_sub_epi8(max, min);
186 const __m512i diff_l = _mm512_unpacklo_epi8(diff, _mm512_setzero_si512());
187 const __m512i diff_h = _mm512_unpackhi_epi8(diff, _mm512_setzero_si512());
188 const __m512i dist_l = _mm512_madd_epi16(diff_l, diff_l);
189 const __m512i dist_h = _mm512_madd_epi16(diff_h, diff_h);
190 const __m512i dist = _mm512_add_epi32(dist_l, dist_h);
191 *sum = _mm512_add_epi32(*sum, dist);
192 }
193
uint64_t svt_spatial_full_distortion_kernel_avx512(uint8_t *input, uint32_t input_offset,
                                                   uint32_t input_stride, uint8_t *recon,
                                                   int32_t recon_offset, uint32_t recon_stride,
                                                   uint32_t area_width, uint32_t area_height) {
    /* Sum of squared differences (SSD) between `input` and `recon` over an
     * area_width x area_height region.
     *
     * Strategy: the width is split into a "leftover" strip (area_width % 32,
     * processed first with SSE2/AVX2 helpers that accumulate into the 256-bit
     * `sum`) and the remaining multiple-of-32 part (processed with AVX2 or
     * AVX-512 kernels). The two contributions are merged before the final
     * horizontal reduction.
     *
     * NOTE(review): the 4- and 8-wide leftover loops advance two rows per
     * iteration, so they assume an even area_height — confirm with callers. */
    const uint32_t leftover = area_width & 31; // width remainder not covered by the 32-wide kernels
    int32_t h;
    __m256i sum = _mm256_setzero_si256(); // 32-bit-lane accumulator used by the narrow paths
    __m128i sum_l, sum_h, s;
    uint64_t spatial_distortion = 0;
    input += input_offset;
    recon += recon_offset;

    if (leftover) {
        /* Process the right-most `leftover` columns of the area first. */
        const uint8_t *inp = input + area_width - leftover;
        const uint8_t *rec = recon + area_width - leftover;

        if (leftover == 4) {
            h = area_height;
            do {
                /* Pack two 4-pixel rows into one 256-bit register pair. */
                const __m128i in0 = _mm_cvtsi32_si128(*(uint32_t *)inp);
                const __m128i in1 = _mm_cvtsi32_si128(*(uint32_t *)(inp + input_stride));
                const __m128i re0 = _mm_cvtsi32_si128(*(uint32_t *)rec);
                const __m128i re1 = _mm_cvtsi32_si128(*(uint32_t *)(rec + recon_stride));
                const __m256i in = _mm256_setr_m128i(in0, in1);
                const __m256i re = _mm256_setr_m128i(re0, re1);
                distortion_avx2_intrin(in, re, &sum);
                inp += 2 * input_stride;
                rec += 2 * recon_stride;
                h -= 2;
            } while (h);

            if (area_width == 4) {
                /* The whole area was 4 wide: reduce `sum` and return early. */
                sum_l = _mm256_extracti128_si256(sum, 0);
                sum_h = _mm256_extracti128_si256(sum, 1);
                s = _mm_add_epi32(sum_l, sum_h);
                s = _mm_add_epi32(s, _mm_srli_si128(s, 4));
                spatial_distortion = _mm_cvtsi128_si32(s);
                return spatial_distortion;
            }
        } else if (leftover == 8) {
            h = area_height;
            do {
                /* Pack two 8-pixel rows into one 256-bit register pair. */
                const __m128i in0 = _mm_loadl_epi64((__m128i *)inp);
                const __m128i in1 = _mm_loadl_epi64((__m128i *)(inp + input_stride));
                const __m128i re0 = _mm_loadl_epi64((__m128i *)rec);
                const __m128i re1 = _mm_loadl_epi64((__m128i *)(rec + recon_stride));
                const __m256i in = _mm256_setr_m128i(in0, in1);
                const __m256i re = _mm256_setr_m128i(re0, re1);
                distortion_avx2_intrin(in, re, &sum);
                inp += 2 * input_stride;
                rec += 2 * recon_stride;
                h -= 2;
            } while (h);
        } else if (leftover <= 16) {
            /* leftover is 12 or 16: run the 16-wide kernel, then mask off the
             * lanes that came from columns beyond the leftover width. */
            h = area_height;
            do {
                spatial_full_distortion_kernel16_avx2_intrin(inp, rec, &sum);
                inp += input_stride;
                rec += recon_stride;
            } while (--h);

            if (leftover == 12) {
                const __m256i mask = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, 0, 0);
                sum = _mm256_and_si256(sum, mask);
            }
        } else {
            /* leftover is 20, 24, or 28: the 32-wide leftover kernel fills two
             * 256-bit accumulators; mask away the over-read columns. */
            __m256i sum1 = _mm256_setzero_si256();
            h = area_height;
            do {
                spatial_full_distortion_kernel32_leftover_avx2_intrin(inp, rec, &sum, &sum1);
                inp += input_stride;
                rec += recon_stride;
            } while (--h);

            __m256i mask[2];
            if (leftover == 20) {
                mask[0] = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, 0, 0);
                mask[1] = _mm256_setr_epi32(-1, -1, -1, -1, 0, 0, 0, 0);
            } else if (leftover == 24) {
                mask[0] = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
                mask[1] = _mm256_setr_epi32(-1, -1, -1, -1, 0, 0, 0, 0);
            } else { // leftover = 28
                mask[0] = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
                mask[1] = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, 0, 0);
            }

            sum = _mm256_and_si256(sum, mask[0]);
            sum1 = _mm256_and_si256(sum1, mask[1]);
            sum = _mm256_add_epi32(sum, sum1);
        }
    }

    area_width -= leftover; // remaining width is a multiple of 32

    if (area_width) {
        uint8_t *inp = input;
        uint8_t *rec = recon;
        h = area_height;

        if (area_width == 32) {
            do {
                spatial_full_distortion_kernel32_avx2_intrin(inp, rec, &sum);
                inp += input_stride;
                rec += recon_stride;
            } while (--h);
        } else {
            /* Widths >= 64 use the AVX-512 kernels with a separate 512-bit
             * 32-bit-lane accumulator. */
            __m512i sum512 = _mm512_setzero_si512();

            if (area_width == 64) {
                do {
                    SpatialFullDistortionKernel64_AVX512_INTRIN(inp, rec, &sum512);
                    inp += input_stride;
                    rec += recon_stride;
                } while (--h);
            } else if (area_width == 96) {
                do {
                    SpatialFullDistortionKernel64_AVX512_INTRIN(inp, rec, &sum512);
                    SpatialFullDistortionKernel32_AVX512_INTRIN(inp + 64, rec + 64, &sum512);
                    inp += input_stride;
                    rec += recon_stride;
                } while (--h);
            } else if (area_width == 128) {
                do {
                    SpatialFullDistortionKernel64_AVX512_INTRIN(inp, rec, &sum512);
                    SpatialFullDistortionKernel64_AVX512_INTRIN(inp + 64, rec + 64, &sum512);
                    inp += input_stride;
                    rec += recon_stride;
                } while (--h);
            } else {
                /* General wide path (e.g. widths > 128): flush the 32-bit
                 * accumulator into 64-bit lanes after every row so a long
                 * accumulation cannot overflow 32 bits. */
                __m512i sum64 = _mm512_setzero_si512();

                if (area_width & 32) {
                    /* Handle a stray 32-wide column strip first, then fall
                     * through to the 64-wide loop over the remaining width. */
                    do {
                        SpatialFullDistortionKernel32_AVX512_INTRIN(inp, rec, &sum512);
                        inp += input_stride;
                        rec += recon_stride;
                    } while (--h);
                    inp = input + 32;
                    rec = recon + 32;
                    h = area_height;
                    area_width -= 32;
                }

                do {
                    for (uint32_t w = 0; w < area_width; w += 64) {
                        SpatialFullDistortionKernel64_AVX512_INTRIN(inp + w, rec + w, &sum512);
                    }
                    sum32_to64_avx512(&sum512, &sum64); // also zeroes sum512 for the next row
                    inp += input_stride;
                    rec += recon_stride;
                } while (--h);

                /* Final 64-bit reduction: fold sum64's halves, widen the
                 * leftover 32-bit `sum` to 64-bit lanes, combine, and
                 * horizontally add the four 64-bit lanes. */
                const __m256i sum_L = _mm512_castsi512_si256(sum64);
                const __m256i sum_H = _mm512_extracti64x4_epi64(sum64, 1);
                __m256i leftover_sum64 = _mm256_unpacklo_epi32(sum, _mm256_setzero_si256());
                leftover_sum64 = _mm256_add_epi64(
                    leftover_sum64, _mm256_unpackhi_epi32(sum, _mm256_setzero_si256()));
                sum = _mm256_add_epi64(sum_L, sum_H);
                sum = _mm256_add_epi64(sum, leftover_sum64);
                s = _mm_add_epi64(_mm256_castsi256_si128(sum),
                                  _mm256_extracti128_si256(sum, 1));
                return _mm_extract_epi64(s, 0) + _mm_extract_epi64(s, 1);
            }

            /* Merge the AVX-512 accumulator into the 256-bit `sum` for the
             * common 32-bit reduction below. */
            const __m256i sum512_L = _mm512_castsi512_si256(sum512);
            const __m256i sum512_H = _mm512_extracti64x4_epi64(sum512, 1);
            sum = _mm256_add_epi32(sum, sum512_L);
            sum = _mm256_add_epi32(sum, sum512_H);
        }
    }

    return hadd32_avx2_intrin(sum);
}
367
368 #endif // EN_AVX512_SUPPORT
369