/*
* Copyright(c) 2019 Intel Corporation
*
* This source code is subject to the terms of the BSD 2 Clause License and
* the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
* was not distributed with this source code in the LICENSE file, you can
* obtain it at https://www.aomedia.org/license/software-license. If the Alliance for Open
* Media Patent License 1.0 was not distributed with this source code in the
* PATENTS file, you can obtain it at https://www.aomedia.org/license/patent-license.
*/
#include "EbDefinitions.h"

#if EN_AVX512_SUPPORT

#include <immintrin.h>
#include "EbPictureOperators_Inline_AVX2.h"
#include "EbPictureOperators_SSE2.h"
#include "EbMemory_AVX2.h"

/*******************************************************************************
 * Helper function that adds the 32-bit values in sum32 to the 64-bit values in
 * sum64; sum32 is zeroed afterwards.
*******************************************************************************/
static INLINE void sum32_to64_avx512(__m512i *const sum32, __m512i *const sum64) {
    // Save the partial sums into 64-bit lanes instead of 32-bit ones (which could overflow)
    *sum64 = _mm512_add_epi64(*sum64, _mm512_unpacklo_epi32(*sum32, _mm512_setzero_si512()));
    *sum64 = _mm512_add_epi64(*sum64, _mm512_unpackhi_epi32(*sum32, _mm512_setzero_si512()));
    *sum32 = _mm512_setzero_si512();
}

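// Compute the 16-bit residual (input - pred) for two rows of a 32-sample block;
// the 64-bit permute arranges the data so that, after zero-extension, unpacklo
// holds row 0 and unpackhi holds row 1.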
static INLINE void residual32x2_avx512(const uint8_t *input, const uint32_t input_stride,
                                       const uint8_t *pred, const uint32_t pred_stride,
                                       int16_t *residual, const uint32_t residual_stride) {
    const __m512i zero  = _mm512_setzero_si512();
    const __m512i idx   = _mm512_setr_epi64(0, 4, 1, 5, 2, 6, 3, 7);
    const __m256i in0   = _mm256_loadu_si256((__m256i *)input);
    const __m256i in1   = _mm256_loadu_si256((__m256i *)(input + input_stride));
    const __m256i pr0   = _mm256_loadu_si256((__m256i *)pred);
    const __m256i pr1   = _mm256_loadu_si256((__m256i *)(pred + pred_stride));
    const __m512i in2   = _mm512_inserti64x4(_mm512_castsi256_si512(in0), in1, 1);
    const __m512i pr2   = _mm512_inserti64x4(_mm512_castsi256_si512(pr0), pr1, 1);
    const __m512i in3   = _mm512_permutexvar_epi64(idx, in2);
    const __m512i pr3   = _mm512_permutexvar_epi64(idx, pr2);
    const __m512i in_lo = _mm512_unpacklo_epi8(in3, zero);
    const __m512i in_hi = _mm512_unpackhi_epi8(in3, zero);
    const __m512i pr_lo = _mm512_unpacklo_epi8(pr3, zero);
    const __m512i pr_hi = _mm512_unpackhi_epi8(pr3, zero);
    const __m512i re_lo = _mm512_sub_epi16(in_lo, pr_lo);
    const __m512i re_hi = _mm512_sub_epi16(in_hi, pr_hi);
    _mm512_storeu_si512((__m512i *)(residual + 0 * residual_stride), re_lo);
    _mm512_storeu_si512((__m512i *)(residual + 1 * residual_stride), re_hi);
}

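// Residual kernel for 32-sample-wide blocks: two rows per iteration
// (area_height is expected to be even).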
SIMD_INLINE void residual_kernel32_avx2(const uint8_t *input, const uint32_t input_stride,
                                        const uint8_t *pred, const uint32_t pred_stride,
                                        int16_t *residual, const uint32_t residual_stride,
                                        const uint32_t area_height) {
    uint32_t y = area_height;

    do {
        residual32x2_avx512(input, input_stride, pred, pred_stride, residual, residual_stride);
        input += 2 * input_stride;
        pred += 2 * pred_stride;
        residual += 2 * residual_stride;
        y -= 2;
    } while (y);
}

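// Compute the 16-bit residual for one row of 64 8-bit samples; the permute keeps
// the first and second groups of 32 samples contiguous after unpacking.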
static INLINE void residual64_avx512(const uint8_t *const input, const uint8_t *const pred,
                                     int16_t *const residual) {
    const __m512i zero  = _mm512_setzero_si512();
    const __m512i idx   = _mm512_setr_epi64(0, 4, 1, 5, 2, 6, 3, 7);
    const __m512i in0   = _mm512_loadu_si512((__m512i *)input);
    const __m512i pr0   = _mm512_loadu_si512((__m512i *)pred);
    const __m512i in1   = _mm512_permutexvar_epi64(idx, in0);
    const __m512i pr1   = _mm512_permutexvar_epi64(idx, pr0);
    const __m512i in_lo = _mm512_unpacklo_epi8(in1, zero);
    const __m512i in_hi = _mm512_unpackhi_epi8(in1, zero);
    const __m512i pr_lo = _mm512_unpacklo_epi8(pr1, zero);
    const __m512i pr_hi = _mm512_unpackhi_epi8(pr1, zero);
    const __m512i re_lo = _mm512_sub_epi16(in_lo, pr_lo);
    const __m512i re_hi = _mm512_sub_epi16(in_hi, pr_hi);
    _mm512_storeu_si512((__m512i *)(residual + 0 * 32), re_lo);
    _mm512_storeu_si512((__m512i *)(residual + 1 * 32), re_hi);
}

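// Residual kernel for 64-sample-wide blocks: one row per iteration.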
SIMD_INLINE void residual_kernel64_avx2(const uint8_t *input, const uint32_t input_stride,
                                        const uint8_t *pred, const uint32_t pred_stride,
                                        int16_t *residual, const uint32_t residual_stride,
                                        const uint32_t area_height) {
    uint32_t y = area_height;

    do {
        residual64_avx512(input, pred, residual);
        input += input_stride;
        pred += pred_stride;
        residual += residual_stride;
    } while (--y);
}

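// Residual kernel for 128-sample-wide blocks: two 64-sample halves per row.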
SIMD_INLINE void residual_kernel128_avx2(const uint8_t *input, const uint32_t input_stride,
                                         const uint8_t *pred, const uint32_t pred_stride,
                                         int16_t *residual, const uint32_t residual_stride,
                                         const uint32_t area_height) {
    uint32_t y = area_height;

    do {
        residual64_avx512(input + 0 * 64, pred + 0 * 64, residual + 0 * 64);
        residual64_avx512(input + 1 * 64, pred + 1 * 64, residual + 1 * 64);
        input += input_stride;
        pred += pred_stride;
        residual += residual_stride;
    } while (--y);
}

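/*******************************************************************************
 * Compute residual = input - pred for an 8-bit block, widened to int16_t.
 * Dispatches on area_width: narrow widths (4/8/16) reuse the AVX2 kernels,
 * wider widths (32/64/128) use the AVX-512 helpers above.
*******************************************************************************/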
void svt_residual_kernel8bit_avx512(uint8_t *input, uint32_t input_stride, uint8_t *pred,
                                    uint32_t pred_stride, int16_t *residual,
                                    uint32_t residual_stride, uint32_t area_width,
                                    uint32_t area_height) {
    switch (area_width) {
    case 4:
        residual_kernel4_avx2(
            input, input_stride, pred, pred_stride, residual, residual_stride, area_height);
        break;

    case 8:
        residual_kernel8_avx2(
            input, input_stride, pred, pred_stride, residual, residual_stride, area_height);
        break;

    case 16:
        residual_kernel16_avx2(
            input, input_stride, pred, pred_stride, residual, residual_stride, area_height);
        break;

    case 32:
        residual_kernel32_avx2(
            input, input_stride, pred, pred_stride, residual, residual_stride, area_height);
        break;

    case 64:
        residual_kernel64_avx2(
            input, input_stride, pred, pred_stride, residual, residual_stride, area_height);
        break;

    default: // 128
        residual_kernel128_avx2(
            input, input_stride, pred, pred_stride, residual, residual_stride, area_height);
        break;
    }
}

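// Horizontally add the sixteen 32-bit lanes of src into a single int32_t.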
static INLINE int32_t Hadd32_AVX512_INTRIN(const __m512i src) {
    const __m256i src_l = _mm512_castsi512_si256(src);
    const __m256i src_h = _mm512_extracti64x4_epi64(src, 1);
    const __m256i sum   = _mm256_add_epi32(src_l, src_h);

    return hadd32_avx2_intrin(sum);
}

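// Accumulate the sum of squared differences between 32 input and 32 recon
// samples into sixteen 32-bit partial sums.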
static INLINE void Distortion_AVX512_INTRIN(const __m256i input, const __m256i recon,
                                            __m512i *const sum) {
    const __m512i in   = _mm512_cvtepu8_epi16(input);
    const __m512i re   = _mm512_cvtepu8_epi16(recon);
    const __m512i diff = _mm512_sub_epi16(in, re);
    const __m512i dist = _mm512_madd_epi16(diff, diff);
    *sum               = _mm512_add_epi32(*sum, dist);
}

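// Load 32 input/recon samples and accumulate their squared differences into *sum.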
static INLINE void SpatialFullDistortionKernel32_AVX512_INTRIN(const uint8_t *const input,
                                                               const uint8_t *const recon,
                                                               __m512i *const       sum) {
    const __m256i in = _mm256_loadu_si256((__m256i *)input);
    const __m256i re = _mm256_loadu_si256((__m256i *)recon);
    Distortion_AVX512_INTRIN(in, re, sum);
}

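// Accumulate squared differences for 64 samples; |input - recon| is computed
// via max/min so the difference stays within unsigned 8 bits before widening.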
static INLINE void SpatialFullDistortionKernel64_AVX512_INTRIN(const uint8_t *const input,
                                                               const uint8_t *const recon,
                                                               __m512i *const       sum) {
    const __m512i in     = _mm512_loadu_si512((__m512i *)input);
    const __m512i re     = _mm512_loadu_si512((__m512i *)recon);
    const __m512i max    = _mm512_max_epu8(in, re);
    const __m512i min    = _mm512_min_epu8(in, re);
    const __m512i diff   = _mm512_sub_epi8(max, min);
    const __m512i diff_l = _mm512_unpacklo_epi8(diff, _mm512_setzero_si512());
    const __m512i diff_h = _mm512_unpackhi_epi8(diff, _mm512_setzero_si512());
    const __m512i dist_l = _mm512_madd_epi16(diff_l, diff_l);
    const __m512i dist_h = _mm512_madd_epi16(diff_h, diff_h);
    const __m512i dist   = _mm512_add_epi32(dist_l, dist_h);
    *sum                 = _mm512_add_epi32(*sum, dist);
}

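/*******************************************************************************
 * Sum of squared differences (spatial distortion) between input and recon.
 * Columns beyond the largest multiple of 32 in area_width are handled first
 * with AVX2 leftover kernels; the remaining columns are processed 32/64 at a
 * time, and the generic-width path widens its per-row partial sums to 64 bits
 * to avoid overflow.
*******************************************************************************/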
uint64_t svt_spatial_full_distortion_kernel_avx512(uint8_t *input, uint32_t input_offset,
                                                   uint32_t input_stride, uint8_t *recon,
                                                   int32_t recon_offset, uint32_t recon_stride,
                                                   uint32_t area_width, uint32_t area_height) {
    const uint32_t leftover = area_width & 31;
    int32_t        h;
    __m256i        sum = _mm256_setzero_si256();
    __m128i        sum_l, sum_h, s;
    uint64_t       spatial_distortion = 0;
    input += input_offset;
    recon += recon_offset;

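    // Handle the rightmost (area_width & 31) columns with the AVX2 leftover kernels.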
    if (leftover) {
        const uint8_t *inp = input + area_width - leftover;
        const uint8_t *rec = recon + area_width - leftover;

        if (leftover == 4) {
            h = area_height;
            do {
                const __m128i in0 = _mm_cvtsi32_si128(*(uint32_t *)inp);
                const __m128i in1 = _mm_cvtsi32_si128(*(uint32_t *)(inp + input_stride));
                const __m128i re0 = _mm_cvtsi32_si128(*(uint32_t *)rec);
                const __m128i re1 = _mm_cvtsi32_si128(*(uint32_t *)(rec + recon_stride));
                const __m256i in  = _mm256_setr_m128i(in0, in1);
                const __m256i re  = _mm256_setr_m128i(re0, re1);
                distortion_avx2_intrin(in, re, &sum);
                inp += 2 * input_stride;
                rec += 2 * recon_stride;
                h -= 2;
            } while (h);

            if (area_width == 4) {
                sum_l              = _mm256_extracti128_si256(sum, 0);
                sum_h              = _mm256_extracti128_si256(sum, 1);
                s                  = _mm_add_epi32(sum_l, sum_h);
                s                  = _mm_add_epi32(s, _mm_srli_si128(s, 4));
                spatial_distortion = _mm_cvtsi128_si32(s);
                return spatial_distortion;
            }
        } else if (leftover == 8) {
            h = area_height;
            do {
                const __m128i in0 = _mm_loadl_epi64((__m128i *)inp);
                const __m128i in1 = _mm_loadl_epi64((__m128i *)(inp + input_stride));
                const __m128i re0 = _mm_loadl_epi64((__m128i *)rec);
                const __m128i re1 = _mm_loadl_epi64((__m128i *)(rec + recon_stride));
                const __m256i in  = _mm256_setr_m128i(in0, in1);
                const __m256i re  = _mm256_setr_m128i(re0, re1);
                distortion_avx2_intrin(in, re, &sum);
                inp += 2 * input_stride;
                rec += 2 * recon_stride;
                h -= 2;
            } while (h);
        } else if (leftover <= 16) {
            h = area_height;
            do {
                spatial_full_distortion_kernel16_avx2_intrin(inp, rec, &sum);
                inp += input_stride;
                rec += recon_stride;
            } while (--h);

            if (leftover == 12) {
                const __m256i mask = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, 0, 0);
                sum                = _mm256_and_si256(sum, mask);
            }
        } else {
            __m256i sum1 = _mm256_setzero_si256();
            h            = area_height;
            do {
                spatial_full_distortion_kernel32_leftover_avx2_intrin(inp, rec, &sum, &sum1);
                inp += input_stride;
                rec += recon_stride;
            } while (--h);

            __m256i mask[2];
            if (leftover == 20) {
                mask[0] = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, 0, 0);
                mask[1] = _mm256_setr_epi32(-1, -1, -1, -1, 0, 0, 0, 0);
            } else if (leftover == 24) {
                mask[0] = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
                mask[1] = _mm256_setr_epi32(-1, -1, -1, -1, 0, 0, 0, 0);
            } else { // leftover = 28
                mask[0] = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, -1, -1);
                mask[1] = _mm256_setr_epi32(-1, -1, -1, -1, -1, -1, 0, 0);
            }

            sum  = _mm256_and_si256(sum, mask[0]);
            sum1 = _mm256_and_si256(sum1, mask[1]);
            sum  = _mm256_add_epi32(sum, sum1);
        }
    }

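    // Process the remaining width (a multiple of 32): 32 columns stay on the AVX2
    // kernel, wider widths use the AVX-512 kernels.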
    area_width -= leftover;

    if (area_width) {
        uint8_t *inp = input;
        uint8_t *rec = recon;
        h            = area_height;

        if (area_width == 32) {
            do {
                spatial_full_distortion_kernel32_avx2_intrin(inp, rec, &sum);
                inp += input_stride;
                rec += recon_stride;
            } while (--h);
        } else {
            __m512i sum512 = _mm512_setzero_si512();

            if (area_width == 64) {
                do {
                    SpatialFullDistortionKernel64_AVX512_INTRIN(inp, rec, &sum512);
                    inp += input_stride;
                    rec += recon_stride;
                } while (--h);
            } else if (area_width == 96) {
                do {
                    SpatialFullDistortionKernel64_AVX512_INTRIN(inp, rec, &sum512);
                    SpatialFullDistortionKernel32_AVX512_INTRIN(inp + 64, rec + 64, &sum512);
                    inp += input_stride;
                    rec += recon_stride;
                } while (--h);
            } else if (area_width == 128) {
                do {
                    SpatialFullDistortionKernel64_AVX512_INTRIN(inp, rec, &sum512);
                    SpatialFullDistortionKernel64_AVX512_INTRIN(inp + 64, rec + 64, &sum512);
                    inp += input_stride;
                    rec += recon_stride;
                } while (--h);
            } else {
                __m512i sum64 = _mm512_setzero_si512();

                if (area_width & 32) {
                    do {
                        SpatialFullDistortionKernel32_AVX512_INTRIN(inp, rec, &sum512);
                        inp += input_stride;
                        rec += recon_stride;
                    } while (--h);
                    inp = input + 32;
                    rec = recon + 32;
                    h   = area_height;
                    area_width -= 32;
                }

                do {
                    for (uint32_t w = 0; w < area_width; w += 64) {
                        SpatialFullDistortionKernel64_AVX512_INTRIN(inp + w, rec + w, &sum512);
                    }
                    sum32_to64_avx512(&sum512, &sum64);
                    inp += input_stride;
                    rec += recon_stride;
                } while (--h);

                const __m256i sum_L          = _mm512_castsi512_si256(sum64);
                const __m256i sum_H          = _mm512_extracti64x4_epi64(sum64, 1);
                __m256i       leftover_sum64 = _mm256_unpacklo_epi32(sum, _mm256_setzero_si256());
                leftover_sum64               = _mm256_add_epi64(
                    leftover_sum64, _mm256_unpackhi_epi32(sum, _mm256_setzero_si256()));
                sum       = _mm256_add_epi64(sum_L, sum_H);
                sum       = _mm256_add_epi64(sum, leftover_sum64);
                s         = _mm_add_epi64(_mm256_castsi256_si128(sum),
                                          _mm256_extracti128_si256(sum, 1));
                return _mm_extract_epi64(s, 0) + _mm_extract_epi64(s, 1);
            }

            const __m256i sum512_L = _mm512_castsi512_si256(sum512);
            const __m256i sum512_H = _mm512_extracti64x4_epi64(sum512, 1);
            sum                    = _mm256_add_epi32(sum, sum512_L);
            sum                    = _mm256_add_epi32(sum, sum512_H);
        }
    }

    return hadd32_avx2_intrin(sum);
}

#endif // EN_AVX512_SUPPORT