/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at https://www.aomedia.org/license/software-license. If the
 * Alliance for Open Media Patent License 1.0 was not distributed with this
 * source code in the PATENTS file, you can obtain it at
 * https://www.aomedia.org/license/patent-license.
 */

#include "EbDefinitions.h"
#include <immintrin.h>
#include "aom_dsp_rtcd.h"
#include "EbVariance_SSE2.h"

// Alpha blending with alpha values from the range [0, 256], where 256
// means use the first input and 0 means use the second input.
#define AOM_BLEND_A256_ROUND_BITS 8
#define AOM_BLEND_A256_MAX_ALPHA (1 << AOM_BLEND_A256_ROUND_BITS) // 256

#define AOM_BLEND_A256(a, v0, v1)                                            \
    ROUND_POWER_OF_TWO((a) * (v0) + (AOM_BLEND_A256_MAX_ALPHA - (a)) * (v1), \
                       AOM_BLEND_A256_ROUND_BITS)
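// For example, AOM_BLEND_A256(64, v0, v1) == ROUND_POWER_OF_TWO(64 * (v0) + 192 * (v1), 8),
// i.e. a 25%/75% mix of the two inputs.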

static INLINE __m128i mm256_add_hi_lo_epi16(const __m256i val) {
    return _mm_add_epi16(_mm256_castsi256_si128(val), _mm256_extractf128_si256(val, 1));
}

static INLINE __m128i mm256_add_hi_lo_epi32(const __m256i val) {
    return _mm_add_epi32(_mm256_castsi256_si128(val), _mm256_extractf128_si256(val, 1));
}

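// Accumulate the sum of squared (src - ref) differences for 32 byte pairs into *sse;
// unlike variance_kernel_avx2() below, the signed sum of differences is not tracked.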
static INLINE void variance_kernel_no_sum_avx2(const __m256i src, const __m256i ref,
                                               __m256i *const sse) {
    const __m256i adj_sub = _mm256_set1_epi16((short)0xff01); // (1,-1)

    // unpack into pairs of source and reference values
    const __m256i src_ref0 = _mm256_unpacklo_epi8(src, ref);
    const __m256i src_ref1 = _mm256_unpackhi_epi8(src, ref);

    // subtract adjacent elements using src*1 + ref*-1
    const __m256i diff0 = _mm256_maddubs_epi16(src_ref0, adj_sub);
    const __m256i diff1 = _mm256_maddubs_epi16(src_ref1, adj_sub);
    const __m256i madd0 = _mm256_madd_epi16(diff0, diff0);
    const __m256i madd1 = _mm256_madd_epi16(diff1, diff1);

    // add to the running totals
    *sse = _mm256_add_epi32(*sse, _mm256_add_epi32(madd0, madd1));
}

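// Horizontally reduce the eight 32-bit SSE lanes in vsse and store the total in *sse.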
static INLINE void variance_final_from_32bit_no_sum_avx2(__m256i vsse, uint32_t *const sse) {
    // extract the low lane and add it to the high lane
    const __m128i sse_reg_128 = mm256_add_hi_lo_epi32(vsse);
    const __m128i zero        = _mm_setzero_si128();

    // unpack the sse register against zero and add
    const __m128i sse_sum_lo = _mm_unpacklo_epi32(sse_reg_128, zero);
    const __m128i sse_sum_hi = _mm_unpackhi_epi32(sse_reg_128, zero);
    const __m128i sse_sum    = _mm_add_epi32(sse_sum_lo, sse_sum_hi);

    // perform the final summation and extract the results
    const __m128i res = _mm_add_epi32(sse_sum, _mm_srli_si128(sse_sum, 8));
    *((int32_t *)sse) = _mm_cvtsi128_si32(res);
}

// handle blocks with at most 512 pixels
static INLINE void variance_final_512_no_sum_avx2(__m256i vsse, uint32_t *const sse) {
    // extract the low lane and add it to the high lane
    variance_final_from_32bit_no_sum_avx2(vsse, sse);
}

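// Sign-extend the sixteen 16-bit partial sums to 32 bits and fold them into eight 32-bit lanes.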
static INLINE __m256i sum_to_32bit_avx2(const __m256i sum) {
    const __m256i sum_lo = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(sum));
    const __m256i sum_hi = _mm256_cvtepi16_epi32(_mm256_extractf128_si256(sum, 1));
    return _mm256_add_epi32(sum_lo, sum_hi);
}

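// Load and process two rows of 16 pixels per call.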
static INLINE void variance16_kernel_no_sum_avx2(const uint8_t *const src, const int32_t src_stride,
                                                 const uint8_t *const ref, const int32_t ref_stride,
                                                 __m256i *const sse) {
    const __m128i s0 = _mm_loadu_si128((__m128i const *)(src + 0 * src_stride));
    const __m128i s1 = _mm_loadu_si128((__m128i const *)(src + 1 * src_stride));
    const __m128i r0 = _mm_loadu_si128((__m128i const *)(ref + 0 * ref_stride));
    const __m128i r1 = _mm_loadu_si128((__m128i const *)(ref + 1 * ref_stride));
    const __m256i s  = _mm256_inserti128_si256(_mm256_castsi128_si256(s0), s1, 1);
    const __m256i r  = _mm256_inserti128_si256(_mm256_castsi128_si256(r0), r1, 1);
    variance_kernel_no_sum_avx2(s, r, sse);
}

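// Accumulate the SSE over a 16-wide block of h rows, two rows per iteration.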
static INLINE void variance16_no_sum_avx2(const uint8_t *src, const int32_t src_stride,
                                          const uint8_t *ref, const int32_t ref_stride,
                                          const int32_t h, __m256i *const vsse) {
    for (int32_t i = 0; i < h; i += 2) {
        variance16_kernel_no_sum_avx2(src, src_stride, ref, ref_stride, vsse);
        src += 2 * src_stride;
        ref += 2 * ref_stride;
    }
}

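// Defines svt_aom_variance{bw}x{bh}_no_sum_avx2(), which computes only the SSE term;
// the block sum is not needed (e.g. for MSE).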
#define AOM_VAR_NO_LOOP_NO_SUM_AVX2(bw, bh, bits, max_pixel)                     \
    void svt_aom_variance##bw##x##bh##_no_sum_avx2(const uint8_t *src,           \
                                                   int32_t        src_stride,    \
                                                   const uint8_t *ref,           \
                                                   int32_t        ref_stride,    \
                                                   uint32_t *     sse) {         \
        __m256i vsse = _mm256_setzero_si256();                                   \
        variance##bw##_no_sum_avx2(src, src_stride, ref, ref_stride, bh, &vsse); \
        variance_final_##max_pixel##_no_sum_avx2(vsse, sse);                     \
    }

AOM_VAR_NO_LOOP_NO_SUM_AVX2(16, 16, 8, 512);

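// 16x16 MSE: only the sum of squared differences is needed, so the no-sum path suffices.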
uint32_t svt_aom_mse16x16_avx2(const uint8_t *src, int32_t src_stride, const uint8_t *ref,
                               int32_t ref_stride, uint32_t *sse) {
    svt_aom_variance16x16_no_sum_avx2(src, src_stride, ref, ref_stride, sse);
    return *sse;
}

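// Reduce the 32-bit SSE and sum lanes: store the SSE total in *sse and return the
// total sum of differences.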
static INLINE int variance_final_from_32bit_sum_avx2(__m256i vsse, __m128i vsum,
                                                     unsigned int *const sse) {
    // extract the low lane and add it to the high lane
    const __m128i sse_reg_128 = mm256_add_hi_lo_epi32(vsse);

    // unpack sse and sum registers and add
    const __m128i sse_sum_lo = _mm_unpacklo_epi32(sse_reg_128, vsum);
    const __m128i sse_sum_hi = _mm_unpackhi_epi32(sse_reg_128, vsum);
    const __m128i sse_sum    = _mm_add_epi32(sse_sum_lo, sse_sum_hi);

    // perform the final summation and extract the results
    const __m128i res = _mm_add_epi32(sse_sum, _mm_srli_si128(sse_sum, 8));
    *((int *)sse)     = _mm_cvtsi128_si32(res);
    return _mm_extract_epi32(res, 1);
}

// handle blocks with at most 512 pixels
static INLINE int variance_final_512_avx2(__m256i vsse, __m256i vsum, unsigned int *const sse) {
    // extract the low lane and add it to the high lane
    const __m128i vsum_128  = mm256_add_hi_lo_epi16(vsum);
    const __m128i vsum_64   = _mm_add_epi16(vsum_128, _mm_srli_si128(vsum_128, 8));
    const __m128i sum_int32 = _mm_cvtepi16_epi32(vsum_64);
    return variance_final_from_32bit_sum_avx2(vsse, sum_int32, sse);
}

// handle 1024 pixels (32x32, 16x64, 64x16)
static INLINE int variance_final_1024_avx2(__m256i vsse, __m256i vsum, unsigned int *const sse) {
    // extract the low lane and add it to the high lane
    const __m128i vsum_128 = mm256_add_hi_lo_epi16(vsum);
    const __m128i vsum_64  = _mm_add_epi32(_mm_cvtepi16_epi32(vsum_128),
                                           _mm_cvtepi16_epi32(_mm_srli_si128(vsum_128, 8)));
    return variance_final_from_32bit_sum_avx2(vsse, vsum_64, sse);
}

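// handle 2048 pixels (32x64, 64x32)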
static INLINE int variance_final_2048_avx2(__m256i vsse, __m256i vsum, unsigned int *const sse) {
    vsum                   = sum_to_32bit_avx2(vsum);
    const __m128i vsum_128 = mm256_add_hi_lo_epi32(vsum);
    return variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse);
}

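// Accumulate both the signed sum of differences and the sum of squared differences
// for 32 byte pairs.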
static INLINE void variance_kernel_avx2(const __m256i src, const __m256i ref, __m256i *const sse,
                                        __m256i *const sum) {
    const __m256i adj_sub = _mm256_set1_epi16((short)0xff01); // (1,-1)

    // unpack into pairs of source and reference values
    const __m256i src_ref0 = _mm256_unpacklo_epi8(src, ref);
    const __m256i src_ref1 = _mm256_unpackhi_epi8(src, ref);

    // subtract adjacent elements using src*1 + ref*-1
    const __m256i diff0 = _mm256_maddubs_epi16(src_ref0, adj_sub);
    const __m256i diff1 = _mm256_maddubs_epi16(src_ref1, adj_sub);
    const __m256i madd0 = _mm256_madd_epi16(diff0, diff0);
    const __m256i madd1 = _mm256_madd_epi16(diff1, diff1);

    // add to the running totals
    *sum = _mm256_add_epi16(*sum, _mm256_add_epi16(diff0, diff1));
    *sse = _mm256_add_epi32(*sse, _mm256_add_epi32(madd0, madd1));
}

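// Load and process two rows of 16 pixels per call.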
static INLINE void variance16_kernel_avx2(const uint8_t *const src, const int src_stride,
                                          const uint8_t *const ref, const int ref_stride,
                                          __m256i *const sse, __m256i *const sum) {
    const __m128i s0 = _mm_loadu_si128((__m128i const *)(src + 0 * src_stride));
    const __m128i s1 = _mm_loadu_si128((__m128i const *)(src + 1 * src_stride));
    const __m128i r0 = _mm_loadu_si128((__m128i const *)(ref + 0 * ref_stride));
    const __m128i r1 = _mm_loadu_si128((__m128i const *)(ref + 1 * ref_stride));
    const __m256i s  = _mm256_inserti128_si256(_mm256_castsi128_si256(s0), s1, 1);
    const __m256i r  = _mm256_inserti128_si256(_mm256_castsi128_si256(r0), r1, 1);
    variance_kernel_avx2(s, r, sse, sum);
}

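// Process one row of 32 pixels.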
static INLINE void variance32_kernel_avx2(const uint8_t *const src, const uint8_t *const ref,
                                          __m256i *const sse, __m256i *const sum) {
    const __m256i s = _mm256_loadu_si256((__m256i const *)(src));
    const __m256i r = _mm256_loadu_si256((__m256i const *)(ref));
    variance_kernel_avx2(s, r, sse, sum);
}

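// Row loops: reset the sum accumulator and walk h rows for 16-, 32-, 64- and
// 128-pixel-wide blocks.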
static INLINE void variance16_avx2(const uint8_t *src, const int src_stride, const uint8_t *ref,
                                   const int ref_stride, const int h, __m256i *const vsse,
                                   __m256i *const vsum) {
    *vsum = _mm256_setzero_si256();

    for (int i = 0; i < h; i += 2) {
        variance16_kernel_avx2(src, src_stride, ref, ref_stride, vsse, vsum);
        src += 2 * src_stride;
        ref += 2 * ref_stride;
    }
}

static INLINE void variance32_avx2(const uint8_t *src, const int src_stride, const uint8_t *ref,
                                   const int ref_stride, const int h, __m256i *const vsse,
                                   __m256i *const vsum) {
    *vsum = _mm256_setzero_si256();

    for (int i = 0; i < h; i++) {
        variance32_kernel_avx2(src, ref, vsse, vsum);
        src += src_stride;
        ref += ref_stride;
    }
}

static INLINE void variance64_avx2(const uint8_t *src, const int src_stride, const uint8_t *ref,
                                   const int ref_stride, const int h, __m256i *const vsse,
                                   __m256i *const vsum) {
    *vsum = _mm256_setzero_si256();

    for (int i = 0; i < h; i++) {
        variance32_kernel_avx2(src + 0, ref + 0, vsse, vsum);
        variance32_kernel_avx2(src + 32, ref + 32, vsse, vsum);
        src += src_stride;
        ref += ref_stride;
    }
}

static INLINE void variance128_avx2(const uint8_t *src, const int src_stride, const uint8_t *ref,
                                    const int ref_stride, const int h, __m256i *const vsse,
                                    __m256i *const vsum) {
    *vsum = _mm256_setzero_si256();

    for (int i = 0; i < h; i++) {
        variance32_kernel_avx2(src + 0, ref + 0, vsse, vsum);
        variance32_kernel_avx2(src + 32, ref + 32, vsse, vsum);
        variance32_kernel_avx2(src + 64, ref + 64, vsse, vsum);
        variance32_kernel_avx2(src + 96, ref + 96, vsse, vsum);
        src += src_stride;
        ref += ref_stride;
    }
}

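// Defines svt_aom_variance{bw}x{bh}_avx2() for block sizes whose sum fits the 16-bit
// accumulator in a single pass: variance = sse - sum^2 / (bw * bh), with
// bits == log2(bw * bh).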
#define AOM_VAR_NO_LOOP_AVX2(bw, bh, bits, max_pixel)                            \
    unsigned int svt_aom_variance##bw##x##bh##_avx2(const uint8_t *src,          \
                                                    int            src_stride,   \
                                                    const uint8_t *ref,          \
                                                    int            ref_stride,   \
                                                    unsigned int * sse) {        \
        __m256i vsse = _mm256_setzero_si256();                                   \
        __m256i vsum;                                                            \
        variance##bw##_avx2(src, src_stride, ref, ref_stride, bh, &vsse, &vsum); \
        const int sum = variance_final_##max_pixel##_avx2(vsse, vsum, sse);      \
        return *sse - (uint32_t)(((int64_t)sum * sum) >> bits);                  \
    }

AOM_VAR_NO_LOOP_AVX2(16, 4, 6, 512);
AOM_VAR_NO_LOOP_AVX2(16, 8, 7, 512);
AOM_VAR_NO_LOOP_AVX2(16, 16, 8, 512);
AOM_VAR_NO_LOOP_AVX2(16, 32, 9, 512);
AOM_VAR_NO_LOOP_AVX2(16, 64, 10, 1024);

AOM_VAR_NO_LOOP_AVX2(32, 8, 8, 512);
AOM_VAR_NO_LOOP_AVX2(32, 16, 9, 512);
AOM_VAR_NO_LOOP_AVX2(32, 32, 10, 1024);
AOM_VAR_NO_LOOP_AVX2(32, 64, 11, 2048);

AOM_VAR_NO_LOOP_AVX2(64, 16, 10, 1024);
AOM_VAR_NO_LOOP_AVX2(64, 32, 11, 2048);

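// Defines svt_aom_variance{bw}x{bh}_avx2() for tall blocks: processes uh-row strips so
// the 16-bit sum accumulator cannot overflow, widening the partial sums to 32 bits
// between strips.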
#define AOM_VAR_LOOP_AVX2(bw, bh, bits, uh)                                               \
    unsigned int svt_aom_variance##bw##x##bh##_avx2(const uint8_t *src,                   \
                                                    int            src_stride,            \
                                                    const uint8_t *ref,                   \
                                                    int            ref_stride,            \
                                                    unsigned int * sse) {                 \
        __m256i vsse = _mm256_setzero_si256();                                            \
        __m256i vsum = _mm256_setzero_si256();                                            \
        for (int i = 0; i < (bh / uh); i++) {                                             \
            __m256i vsum16;                                                               \
            variance##bw##_avx2(src, src_stride, ref, ref_stride, uh, &vsse, &vsum16);    \
            vsum = _mm256_add_epi32(vsum, sum_to_32bit_avx2(vsum16));                     \
            src += uh * src_stride;                                                       \
            ref += uh * ref_stride;                                                       \
        }                                                                                 \
        const __m128i vsum_128 = mm256_add_hi_lo_epi32(vsum);                             \
        const int     sum      = variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse); \
        return *sse - (unsigned int)(((int64_t)sum * sum) >> bits);                       \
    }

AOM_VAR_LOOP_AVX2(64, 64, 12, 32); // 64x32 * ( 64/32)
AOM_VAR_LOOP_AVX2(64, 128, 13, 32); // 64x32 * (128/32)
AOM_VAR_LOOP_AVX2(128, 64, 13, 16); // 128x16 * ( 64/16)
AOM_VAR_LOOP_AVX2(128, 128, 14, 16); // 128x16 * (128/16)

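// Sub-pixel variance helpers for 16- and 32-wide strips (implemented elsewhere); as used
// below, each returns the strip's sum of differences and writes its SSE to *sse.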
unsigned int svt_aom_sub_pixel_variance32xh_avx2(const uint8_t *src, int src_stride,
                                                 int x_offset, int y_offset,
                                                 const uint8_t *dst, int dst_stride,
                                                 int height, unsigned int *sse);
unsigned int svt_aom_sub_pixel_variance16xh_avx2(const uint8_t *src, int src_stride,
                                                 int x_offset, int y_offset,
                                                 const uint8_t *dst, int dst_stride,
                                                 int height, unsigned int *sse);

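// Tiles the block into wf-wide, hf-tall strips and accumulates se/sse across them;
// variance = sse - se^2 / (w * h), with wlog2 + hlog2 == log2(w * h).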
#define AOM_SUB_PIXEL_VAR_AVX2(w, h, wf, wlog2, hlog2)                        \
  unsigned int svt_aom_sub_pixel_variance##w##x##h##_avx2(                    \
      const uint8_t *src, int src_stride, int x_offset, int y_offset,         \
      const uint8_t *dst, int dst_stride, unsigned int *sse_ptr) {            \
    /*Avoid overflow in helper by capping height.*/                           \
    const int hf = AOMMIN(h, 64);                                             \
    const int wf2 = AOMMIN(wf, 128);                                          \
    unsigned int sse = 0;                                                     \
    int se = 0;                                                               \
    for (int i = 0; i < (w / wf2); ++i) {                                     \
      const uint8_t *src_ptr = src;                                           \
      const uint8_t *dst_ptr = dst;                                           \
      for (int j = 0; j < (h / hf); ++j) {                                    \
        unsigned int sse2;                                                    \
        const int se2 = svt_aom_sub_pixel_variance##wf##xh_avx2(              \
            src_ptr, src_stride, x_offset, y_offset, dst_ptr, dst_stride, hf, \
            &sse2);                                                           \
        dst_ptr += hf * dst_stride;                                           \
        src_ptr += hf * src_stride;                                           \
        se += se2;                                                            \
        sse += sse2;                                                          \
      }                                                                       \
      src += wf;                                                              \
      dst += wf;                                                              \
    }                                                                         \
    *sse_ptr = sse;                                                           \
    return sse - (unsigned int)(((int64_t)se * se) >> (wlog2 + hlog2));       \
  }


AOM_SUB_PIXEL_VAR_AVX2(128, 128, 32, 7, 7);
AOM_SUB_PIXEL_VAR_AVX2(128, 64, 32, 7, 6);
AOM_SUB_PIXEL_VAR_AVX2(64, 128, 32, 6, 7);
AOM_SUB_PIXEL_VAR_AVX2(64, 64, 32, 6, 6);
AOM_SUB_PIXEL_VAR_AVX2(64, 32, 32, 6, 5);
AOM_SUB_PIXEL_VAR_AVX2(32, 64, 32, 5, 6);
AOM_SUB_PIXEL_VAR_AVX2(32, 32, 32, 5, 5);
AOM_SUB_PIXEL_VAR_AVX2(32, 16, 32, 5, 4);
AOM_SUB_PIXEL_VAR_AVX2(16, 64, 16, 4, 6);
AOM_SUB_PIXEL_VAR_AVX2(16, 32, 16, 4, 5);
AOM_SUB_PIXEL_VAR_AVX2(16, 16, 16, 4, 4);
AOM_SUB_PIXEL_VAR_AVX2(16, 8, 16, 4, 3);
AOM_SUB_PIXEL_VAR_AVX2(16, 4, 16, 4, 2);