/*
 *  Copyright (c) 2012 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include <immintrin.h>  // AVX2

#include "./vpx_dsp_rtcd.h"

/* clang-format off */
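// 8-bit bilinear filter taps, one 32-byte block (two rows below) per
// eighth-pel offset 0..7. Each byte pair is (16 - 2 * offset, 2 * offset),
// replicated across the block; a block is selected below with (offset << 5).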
DECLARE_ALIGNED(32, static const uint8_t, bilinear_filters_avx2[512]) = {
  16, 0,  16, 0,  16, 0,  16, 0,  16, 0,  16, 0,  16, 0,  16, 0,
  16, 0,  16, 0,  16, 0,  16, 0,  16, 0,  16, 0,  16, 0,  16, 0,
  14, 2,  14, 2,  14, 2,  14, 2,  14, 2,  14, 2,  14, 2,  14, 2,
  14, 2,  14, 2,  14, 2,  14, 2,  14, 2,  14, 2,  14, 2,  14, 2,
  12, 4,  12, 4,  12, 4,  12, 4,  12, 4,  12, 4,  12, 4,  12, 4,
  12, 4,  12, 4,  12, 4,  12, 4,  12, 4,  12, 4,  12, 4,  12, 4,
  10, 6,  10, 6,  10, 6,  10, 6,  10, 6,  10, 6,  10, 6,  10, 6,
  10, 6,  10, 6,  10, 6,  10, 6,  10, 6,  10, 6,  10, 6,  10, 6,
  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,  8,
  6,  10, 6,  10, 6,  10, 6,  10, 6,  10, 6,  10, 6,  10, 6,  10,
  6,  10, 6,  10, 6,  10, 6,  10, 6,  10, 6,  10, 6,  10, 6,  10,
  4,  12, 4,  12, 4,  12, 4,  12, 4,  12, 4,  12, 4,  12, 4,  12,
  4,  12, 4,  12, 4,  12, 4,  12, 4,  12, 4,  12, 4,  12, 4,  12,
  2,  14, 2,  14, 2,  14, 2,  14, 2,  14, 2,  14, 2,  14, 2,  14,
  2,  14, 2,  14, 2,  14, 2,  14, 2,  14, 2,  14, 2,  14, 2,  14,
};

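// Interleaved +1/-1 multipliers: together with _mm256_maddubs_epi16 they turn
// interleaved (src, ref) byte pairs into signed 16-bit (src - ref) differences.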
DECLARE_ALIGNED(32, static const int8_t, adjacent_sub_avx2[32]) = {
  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,
  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1,  1, -1
};
/* clang-format on */

static INLINE void variance_kernel_avx2(const __m256i src, const __m256i ref,
                                        __m256i *const sse,
                                        __m256i *const sum) {
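  // Accumulates the signed differences in the 16-bit lanes of *sum and the
  // squared differences in the 32-bit lanes of *sse; callers widen *sum to
  // 32 bits (sum_to_32bit_avx2) when a block is too large for 16-bit totals.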
  const __m256i adj_sub = _mm256_load_si256((__m256i const *)adjacent_sub_avx2);

  // unpack into pairs of source and reference values
  const __m256i src_ref0 = _mm256_unpacklo_epi8(src, ref);
  const __m256i src_ref1 = _mm256_unpackhi_epi8(src, ref);

  // subtract adjacent elements using src*1 + ref*-1
  const __m256i diff0 = _mm256_maddubs_epi16(src_ref0, adj_sub);
  const __m256i diff1 = _mm256_maddubs_epi16(src_ref1, adj_sub);
  const __m256i madd0 = _mm256_madd_epi16(diff0, diff0);
  const __m256i madd1 = _mm256_madd_epi16(diff1, diff1);

  // add to the running totals
  *sum = _mm256_add_epi16(*sum, _mm256_add_epi16(diff0, diff1));
  *sse = _mm256_add_epi32(*sse, _mm256_add_epi32(madd0, madd1));
}

static INLINE void variance_final_from_32bit_sum_avx2(__m256i vsse,
                                                      __m128i vsum,
                                                      unsigned int *const sse,
                                                      int *const sum) {
  // extract the low lane and add it to the high lane
  const __m128i sse_reg_128 = _mm_add_epi32(_mm256_castsi256_si128(vsse),
                                            _mm256_extractf128_si256(vsse, 1));

  // unpack sse and sum registers and add
  const __m128i sse_sum_lo = _mm_unpacklo_epi32(sse_reg_128, vsum);
  const __m128i sse_sum_hi = _mm_unpackhi_epi32(sse_reg_128, vsum);
  const __m128i sse_sum = _mm_add_epi32(sse_sum_lo, sse_sum_hi);

  // perform the final summation and extract the results
  const __m128i res = _mm_add_epi32(sse_sum, _mm_srli_si128(sse_sum, 8));
  *((int *)sse) = _mm_cvtsi128_si32(res);
  *((int *)sum) = _mm_extract_epi32(res, 1);
}

static INLINE void variance_final_from_16bit_sum_avx2(__m256i vsse,
                                                      __m256i vsum,
                                                      unsigned int *const sse,
                                                      int *const sum) {
  // extract the low lane and add it to the high lane
  const __m128i sum_reg_128 = _mm_add_epi16(_mm256_castsi256_si128(vsum),
                                            _mm256_extractf128_si256(vsum, 1));
  const __m128i sum_reg_64 =
      _mm_add_epi16(sum_reg_128, _mm_srli_si128(sum_reg_128, 8));
  const __m128i sum_int32 = _mm_cvtepi16_epi32(sum_reg_64);

  variance_final_from_32bit_sum_avx2(vsse, sum_int32, sse, sum);
}

static INLINE __m256i sum_to_32bit_avx2(const __m256i sum) {
  const __m256i sum_lo = _mm256_cvtepi16_epi32(_mm256_castsi256_si128(sum));
  const __m256i sum_hi =
      _mm256_cvtepi16_epi32(_mm256_extractf128_si256(sum, 1));
  return _mm256_add_epi32(sum_lo, sum_hi);
}

static INLINE void variance16_kernel_avx2(
    const uint8_t *const src, const int src_stride, const uint8_t *const ref,
    const int ref_stride, __m256i *const sse, __m256i *const sum) {
  const __m128i s0 = _mm_loadu_si128((__m128i const *)(src + 0 * src_stride));
  const __m128i s1 = _mm_loadu_si128((__m128i const *)(src + 1 * src_stride));
  const __m128i r0 = _mm_loadu_si128((__m128i const *)(ref + 0 * ref_stride));
  const __m128i r1 = _mm_loadu_si128((__m128i const *)(ref + 1 * ref_stride));
  const __m256i s = _mm256_inserti128_si256(_mm256_castsi128_si256(s0), s1, 1);
  const __m256i r = _mm256_inserti128_si256(_mm256_castsi128_si256(r0), r1, 1);
  variance_kernel_avx2(s, r, sse, sum);
}

static INLINE void variance32_kernel_avx2(const uint8_t *const src,
                                          const uint8_t *const ref,
                                          __m256i *const sse,
                                          __m256i *const sum) {
  const __m256i s = _mm256_loadu_si256((__m256i const *)(src));
  const __m256i r = _mm256_loadu_si256((__m256i const *)(ref));
  variance_kernel_avx2(s, r, sse, sum);
}

static INLINE void variance16_avx2(const uint8_t *src, const int src_stride,
                                   const uint8_t *ref, const int ref_stride,
                                   const int h, __m256i *const vsse,
                                   __m256i *const vsum) {
  int i;
  *vsum = _mm256_setzero_si256();
  *vsse = _mm256_setzero_si256();

  for (i = 0; i < h; i += 2) {
    variance16_kernel_avx2(src, src_stride, ref, ref_stride, vsse, vsum);
    src += 2 * src_stride;
    ref += 2 * ref_stride;
  }
}

static INLINE void variance32_avx2(const uint8_t *src, const int src_stride,
                                   const uint8_t *ref, const int ref_stride,
                                   const int h, __m256i *const vsse,
                                   __m256i *const vsum) {
  int i;
  *vsum = _mm256_setzero_si256();
  *vsse = _mm256_setzero_si256();

  for (i = 0; i < h; i++) {
    variance32_kernel_avx2(src, ref, vsse, vsum);
    src += src_stride;
    ref += ref_stride;
  }
}

static INLINE void variance64_avx2(const uint8_t *src, const int src_stride,
                                   const uint8_t *ref, const int ref_stride,
                                   const int h, __m256i *const vsse,
                                   __m256i *const vsum) {
  int i;
  *vsum = _mm256_setzero_si256();

  for (i = 0; i < h; i++) {
    variance32_kernel_avx2(src + 0, ref + 0, vsse, vsum);
    variance32_kernel_avx2(src + 32, ref + 32, vsse, vsum);
    src += src_stride;
    ref += ref_stride;
  }
}

void vpx_get16x16var_avx2(const uint8_t *src_ptr, int src_stride,
                          const uint8_t *ref_ptr, int ref_stride,
                          unsigned int *sse, int *sum) {
  __m256i vsse, vsum;
  variance16_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 16, &vsse, &vsum);
  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, sum);
}

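// FILTER_SRC applies the bilinear taps (16 - w, w) selected from
// bilinear_filters_avx2: out = (a * (16 - w) + b * w + 8) >> 4.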
#define FILTER_SRC(filter)                               \
  /* filter the source */                                \
  exp_src_lo = _mm256_maddubs_epi16(exp_src_lo, filter); \
  exp_src_hi = _mm256_maddubs_epi16(exp_src_hi, filter); \
                                                         \
  /* add 8 to source */                                  \
  exp_src_lo = _mm256_add_epi16(exp_src_lo, pw8);        \
  exp_src_hi = _mm256_add_epi16(exp_src_hi, pw8);        \
                                                         \
  /* divide source by 16 */                              \
  exp_src_lo = _mm256_srai_epi16(exp_src_lo, 4);         \
  exp_src_hi = _mm256_srai_epi16(exp_src_hi, 4);

#define CALC_SUM_SSE_INSIDE_LOOP                          \
  /* expand each byte to 2 bytes */                       \
  exp_dst_lo = _mm256_unpacklo_epi8(dst_reg, zero_reg);   \
  exp_dst_hi = _mm256_unpackhi_epi8(dst_reg, zero_reg);   \
  /* source - dest */                                     \
  exp_src_lo = _mm256_sub_epi16(exp_src_lo, exp_dst_lo);  \
  exp_src_hi = _mm256_sub_epi16(exp_src_hi, exp_dst_hi);  \
  /* calculate sum */                                     \
  *sum_reg = _mm256_add_epi16(*sum_reg, exp_src_lo);      \
  exp_src_lo = _mm256_madd_epi16(exp_src_lo, exp_src_lo); \
  *sum_reg = _mm256_add_epi16(*sum_reg, exp_src_hi);      \
  exp_src_hi = _mm256_madd_epi16(exp_src_hi, exp_src_hi); \
  /* calculate sse */                                     \
  *sse_reg = _mm256_add_epi32(*sse_reg, exp_src_lo);      \
  *sse_reg = _mm256_add_epi32(*sse_reg, exp_src_hi);

// final calculation to sum and sse
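// The cmpgt mask supplies the sign bits, so the 16-bit sums are sign-extended
// to 32 bits before the horizontal reduction across both 128-bit lanes.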
#define CALC_SUM_AND_SSE                                                   \
  res_cmp = _mm256_cmpgt_epi16(zero_reg, sum_reg);                         \
  sse_reg_hi = _mm256_srli_si256(sse_reg, 8);                              \
  sum_reg_lo = _mm256_unpacklo_epi16(sum_reg, res_cmp);                    \
  sum_reg_hi = _mm256_unpackhi_epi16(sum_reg, res_cmp);                    \
  sse_reg = _mm256_add_epi32(sse_reg, sse_reg_hi);                         \
  sum_reg = _mm256_add_epi32(sum_reg_lo, sum_reg_hi);                      \
                                                                           \
  sse_reg_hi = _mm256_srli_si256(sse_reg, 4);                              \
  sum_reg_hi = _mm256_srli_si256(sum_reg, 8);                              \
                                                                           \
  sse_reg = _mm256_add_epi32(sse_reg, sse_reg_hi);                         \
  sum_reg = _mm256_add_epi32(sum_reg, sum_reg_hi);                         \
  *((int *)sse) = _mm_cvtsi128_si32(_mm256_castsi256_si128(sse_reg)) +     \
                  _mm_cvtsi128_si32(_mm256_extractf128_si256(sse_reg, 1)); \
  sum_reg_hi = _mm256_srli_si256(sum_reg, 4);                              \
  sum_reg = _mm256_add_epi32(sum_reg, sum_reg_hi);                         \
  sum = _mm_cvtsi128_si32(_mm256_castsi256_si128(sum_reg)) +               \
        _mm_cvtsi128_si32(_mm256_extractf128_si256(sum_reg, 1));

static INLINE void spv32_x0_y0(const uint8_t *src, int src_stride,
                               const uint8_t *dst, int dst_stride,
                               const uint8_t *second_pred, int second_stride,
                               int do_sec, int height, __m256i *sum_reg,
                               __m256i *sse_reg) {
  const __m256i zero_reg = _mm256_setzero_si256();
  __m256i exp_src_lo, exp_src_hi, exp_dst_lo, exp_dst_hi;
  int i;
  for (i = 0; i < height; i++) {
    const __m256i dst_reg = _mm256_loadu_si256((__m256i const *)dst);
    const __m256i src_reg = _mm256_loadu_si256((__m256i const *)src);
    if (do_sec) {
      const __m256i sec_reg = _mm256_loadu_si256((__m256i const *)second_pred);
      const __m256i avg_reg = _mm256_avg_epu8(src_reg, sec_reg);
      exp_src_lo = _mm256_unpacklo_epi8(avg_reg, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(avg_reg, zero_reg);
      second_pred += second_stride;
    } else {
      exp_src_lo = _mm256_unpacklo_epi8(src_reg, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(src_reg, zero_reg);
    }
    CALC_SUM_SSE_INSIDE_LOOP
    src += src_stride;
    dst += dst_stride;
  }
}

// (x == 0, y == 4) or (x == 4, y == 0).  sstep determines the direction.
static INLINE void spv32_half_zero(const uint8_t *src, int src_stride,
                                   const uint8_t *dst, int dst_stride,
                                   const uint8_t *second_pred,
                                   int second_stride, int do_sec, int height,
                                   __m256i *sum_reg, __m256i *sse_reg,
                                   int sstep) {
  const __m256i zero_reg = _mm256_setzero_si256();
  __m256i exp_src_lo, exp_src_hi, exp_dst_lo, exp_dst_hi;
  int i;
  for (i = 0; i < height; i++) {
    const __m256i dst_reg = _mm256_loadu_si256((__m256i const *)dst);
    const __m256i src_0 = _mm256_loadu_si256((__m256i const *)src);
    const __m256i src_1 = _mm256_loadu_si256((__m256i const *)(src + sstep));
    const __m256i src_avg = _mm256_avg_epu8(src_0, src_1);
    if (do_sec) {
      const __m256i sec_reg = _mm256_loadu_si256((__m256i const *)second_pred);
      const __m256i avg_reg = _mm256_avg_epu8(src_avg, sec_reg);
      exp_src_lo = _mm256_unpacklo_epi8(avg_reg, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(avg_reg, zero_reg);
      second_pred += second_stride;
    } else {
      exp_src_lo = _mm256_unpacklo_epi8(src_avg, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(src_avg, zero_reg);
    }
    CALC_SUM_SSE_INSIDE_LOOP
    src += src_stride;
    dst += dst_stride;
  }
}

static INLINE void spv32_x0_y4(const uint8_t *src, int src_stride,
                               const uint8_t *dst, int dst_stride,
                               const uint8_t *second_pred, int second_stride,
                               int do_sec, int height, __m256i *sum_reg,
                               __m256i *sse_reg) {
  spv32_half_zero(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, sum_reg, sse_reg, src_stride);
}

static INLINE void spv32_x4_y0(const uint8_t *src, int src_stride,
                               const uint8_t *dst, int dst_stride,
                               const uint8_t *second_pred, int second_stride,
                               int do_sec, int height, __m256i *sum_reg,
                               __m256i *sse_reg) {
  spv32_half_zero(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, sum_reg, sse_reg, 1);
}

static INLINE void spv32_x4_y4(const uint8_t *src, int src_stride,
                               const uint8_t *dst, int dst_stride,
                               const uint8_t *second_pred, int second_stride,
                               int do_sec, int height, __m256i *sum_reg,
                               __m256i *sse_reg) {
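  // Half-pel in both directions: average each row with its right neighbour,
  // then average that result with the previous row's horizontal average.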
  const __m256i zero_reg = _mm256_setzero_si256();
  const __m256i src_a = _mm256_loadu_si256((__m256i const *)src);
  const __m256i src_b = _mm256_loadu_si256((__m256i const *)(src + 1));
  __m256i prev_src_avg = _mm256_avg_epu8(src_a, src_b);
  __m256i exp_src_lo, exp_src_hi, exp_dst_lo, exp_dst_hi;
  int i;
  src += src_stride;
  for (i = 0; i < height; i++) {
    const __m256i dst_reg = _mm256_loadu_si256((__m256i const *)dst);
    const __m256i src_0 = _mm256_loadu_si256((__m256i const *)(src));
    const __m256i src_1 = _mm256_loadu_si256((__m256i const *)(src + 1));
    const __m256i src_avg = _mm256_avg_epu8(src_0, src_1);
    const __m256i current_avg = _mm256_avg_epu8(prev_src_avg, src_avg);
    // save the current source average for the next row
    prev_src_avg = src_avg;

    if (do_sec) {
      const __m256i sec_reg = _mm256_loadu_si256((__m256i const *)second_pred);
      const __m256i avg_reg = _mm256_avg_epu8(current_avg, sec_reg);
      exp_src_lo = _mm256_unpacklo_epi8(avg_reg, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(avg_reg, zero_reg);
      second_pred += second_stride;
    } else {
      exp_src_lo = _mm256_unpacklo_epi8(current_avg, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(current_avg, zero_reg);
    }
    CALC_SUM_SSE_INSIDE_LOOP
    dst += dst_stride;
    src += src_stride;
  }
}

// (x == 0, y == bil) or (x == 4, y == bil).  sstep determines the direction.
static INLINE void spv32_bilin_zero(const uint8_t *src, int src_stride,
                                    const uint8_t *dst, int dst_stride,
                                    const uint8_t *second_pred,
                                    int second_stride, int do_sec, int height,
                                    __m256i *sum_reg, __m256i *sse_reg,
                                    int offset, int sstep) {
  const __m256i zero_reg = _mm256_setzero_si256();
  const __m256i pw8 = _mm256_set1_epi16(8);
  const __m256i filter = _mm256_load_si256(
      (__m256i const *)(bilinear_filters_avx2 + (offset << 5)));
  __m256i exp_src_lo, exp_src_hi, exp_dst_lo, exp_dst_hi;
  int i;
  for (i = 0; i < height; i++) {
    const __m256i dst_reg = _mm256_loadu_si256((__m256i const *)dst);
    const __m256i src_0 = _mm256_loadu_si256((__m256i const *)src);
    const __m256i src_1 = _mm256_loadu_si256((__m256i const *)(src + sstep));
    exp_src_lo = _mm256_unpacklo_epi8(src_0, src_1);
    exp_src_hi = _mm256_unpackhi_epi8(src_0, src_1);

    FILTER_SRC(filter)
    if (do_sec) {
      const __m256i sec_reg = _mm256_loadu_si256((__m256i const *)second_pred);
      const __m256i exp_src = _mm256_packus_epi16(exp_src_lo, exp_src_hi);
      const __m256i avg_reg = _mm256_avg_epu8(exp_src, sec_reg);
      second_pred += second_stride;
      exp_src_lo = _mm256_unpacklo_epi8(avg_reg, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(avg_reg, zero_reg);
    }
    CALC_SUM_SSE_INSIDE_LOOP
    src += src_stride;
    dst += dst_stride;
  }
}

static INLINE void spv32_x0_yb(const uint8_t *src, int src_stride,
                               const uint8_t *dst, int dst_stride,
                               const uint8_t *second_pred, int second_stride,
                               int do_sec, int height, __m256i *sum_reg,
                               __m256i *sse_reg, int y_offset) {
  spv32_bilin_zero(src, src_stride, dst, dst_stride, second_pred, second_stride,
                   do_sec, height, sum_reg, sse_reg, y_offset, src_stride);
}

static INLINE void spv32_xb_y0(const uint8_t *src, int src_stride,
                               const uint8_t *dst, int dst_stride,
                               const uint8_t *second_pred, int second_stride,
                               int do_sec, int height, __m256i *sum_reg,
                               __m256i *sse_reg, int x_offset) {
  spv32_bilin_zero(src, src_stride, dst, dst_stride, second_pred, second_stride,
                   do_sec, height, sum_reg, sse_reg, x_offset, 1);
}

static INLINE void spv32_x4_yb(const uint8_t *src, int src_stride,
                               const uint8_t *dst, int dst_stride,
                               const uint8_t *second_pred, int second_stride,
                               int do_sec, int height, __m256i *sum_reg,
                               __m256i *sse_reg, int y_offset) {
  const __m256i zero_reg = _mm256_setzero_si256();
  const __m256i pw8 = _mm256_set1_epi16(8);
  const __m256i filter = _mm256_load_si256(
      (__m256i const *)(bilinear_filters_avx2 + (y_offset << 5)));
  const __m256i src_a = _mm256_loadu_si256((__m256i const *)src);
  const __m256i src_b = _mm256_loadu_si256((__m256i const *)(src + 1));
  __m256i prev_src_avg = _mm256_avg_epu8(src_a, src_b);
  __m256i exp_src_lo, exp_src_hi, exp_dst_lo, exp_dst_hi;
  int i;
  src += src_stride;
  for (i = 0; i < height; i++) {
    const __m256i dst_reg = _mm256_loadu_si256((__m256i const *)dst);
    const __m256i src_0 = _mm256_loadu_si256((__m256i const *)src);
    const __m256i src_1 = _mm256_loadu_si256((__m256i const *)(src + 1));
    const __m256i src_avg = _mm256_avg_epu8(src_0, src_1);
    exp_src_lo = _mm256_unpacklo_epi8(prev_src_avg, src_avg);
    exp_src_hi = _mm256_unpackhi_epi8(prev_src_avg, src_avg);
    prev_src_avg = src_avg;

    FILTER_SRC(filter)
    if (do_sec) {
      const __m256i sec_reg = _mm256_loadu_si256((__m256i const *)second_pred);
      const __m256i exp_src_avg = _mm256_packus_epi16(exp_src_lo, exp_src_hi);
      const __m256i avg_reg = _mm256_avg_epu8(exp_src_avg, sec_reg);
      exp_src_lo = _mm256_unpacklo_epi8(avg_reg, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(avg_reg, zero_reg);
      second_pred += second_stride;
    }
    CALC_SUM_SSE_INSIDE_LOOP
    dst += dst_stride;
    src += src_stride;
  }
}

static INLINE void spv32_xb_y4(const uint8_t *src, int src_stride,
                               const uint8_t *dst, int dst_stride,
                               const uint8_t *second_pred, int second_stride,
                               int do_sec, int height, __m256i *sum_reg,
                               __m256i *sse_reg, int x_offset) {
  const __m256i zero_reg = _mm256_setzero_si256();
  const __m256i pw8 = _mm256_set1_epi16(8);
  const __m256i filter = _mm256_load_si256(
      (__m256i const *)(bilinear_filters_avx2 + (x_offset << 5)));
  const __m256i src_a = _mm256_loadu_si256((__m256i const *)src);
  const __m256i src_b = _mm256_loadu_si256((__m256i const *)(src + 1));
  __m256i exp_src_lo, exp_src_hi, exp_dst_lo, exp_dst_hi;
  __m256i src_reg, src_pack;
  int i;
  exp_src_lo = _mm256_unpacklo_epi8(src_a, src_b);
  exp_src_hi = _mm256_unpackhi_epi8(src_a, src_b);
  FILTER_SRC(filter)
  // pack the 16-bit filter output back to 8 bits, within each 128-bit lane
  src_pack = _mm256_packus_epi16(exp_src_lo, exp_src_hi);

  src += src_stride;
  for (i = 0; i < height; i++) {
    const __m256i dst_reg = _mm256_loadu_si256((__m256i const *)dst);
    const __m256i src_0 = _mm256_loadu_si256((__m256i const *)src);
    const __m256i src_1 = _mm256_loadu_si256((__m256i const *)(src + 1));
    exp_src_lo = _mm256_unpacklo_epi8(src_0, src_1);
    exp_src_hi = _mm256_unpackhi_epi8(src_0, src_1);

    FILTER_SRC(filter)

    src_reg = _mm256_packus_epi16(exp_src_lo, exp_src_hi);
    // average the previous packed row with the current one
    src_pack = _mm256_avg_epu8(src_pack, src_reg);

    if (do_sec) {
      const __m256i sec_reg = _mm256_loadu_si256((__m256i const *)second_pred);
      const __m256i avg_pack = _mm256_avg_epu8(src_pack, sec_reg);
      exp_src_lo = _mm256_unpacklo_epi8(avg_pack, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(avg_pack, zero_reg);
      second_pred += second_stride;
    } else {
      exp_src_lo = _mm256_unpacklo_epi8(src_pack, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(src_pack, zero_reg);
    }
    CALC_SUM_SSE_INSIDE_LOOP
    src_pack = src_reg;
    dst += dst_stride;
    src += src_stride;
  }
}

static INLINE void spv32_xb_yb(const uint8_t *src, int src_stride,
                               const uint8_t *dst, int dst_stride,
                               const uint8_t *second_pred, int second_stride,
                               int do_sec, int height, __m256i *sum_reg,
                               __m256i *sse_reg, int x_offset, int y_offset) {
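  // General case: bilinearly filter each row horizontally, then filter the
  // packed result of the current row against the previous row vertically.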
  const __m256i zero_reg = _mm256_setzero_si256();
  const __m256i pw8 = _mm256_set1_epi16(8);
  const __m256i xfilter = _mm256_load_si256(
      (__m256i const *)(bilinear_filters_avx2 + (x_offset << 5)));
  const __m256i yfilter = _mm256_load_si256(
      (__m256i const *)(bilinear_filters_avx2 + (y_offset << 5)));
  const __m256i src_a = _mm256_loadu_si256((__m256i const *)src);
  const __m256i src_b = _mm256_loadu_si256((__m256i const *)(src + 1));
  __m256i exp_src_lo, exp_src_hi, exp_dst_lo, exp_dst_hi;
  __m256i prev_src_pack, src_pack;
  int i;
  exp_src_lo = _mm256_unpacklo_epi8(src_a, src_b);
  exp_src_hi = _mm256_unpackhi_epi8(src_a, src_b);
  FILTER_SRC(xfilter)
  // pack the 16-bit filter output back to 8 bits, within each 128-bit lane
  prev_src_pack = _mm256_packus_epi16(exp_src_lo, exp_src_hi);
  src += src_stride;

  for (i = 0; i < height; i++) {
    const __m256i dst_reg = _mm256_loadu_si256((__m256i const *)dst);
    const __m256i src_0 = _mm256_loadu_si256((__m256i const *)src);
    const __m256i src_1 = _mm256_loadu_si256((__m256i const *)(src + 1));
    exp_src_lo = _mm256_unpacklo_epi8(src_0, src_1);
    exp_src_hi = _mm256_unpackhi_epi8(src_0, src_1);

    FILTER_SRC(xfilter)
    src_pack = _mm256_packus_epi16(exp_src_lo, exp_src_hi);

    // interleave the previous packed row with the current one
    exp_src_lo = _mm256_unpacklo_epi8(prev_src_pack, src_pack);
    exp_src_hi = _mm256_unpackhi_epi8(prev_src_pack, src_pack);

    FILTER_SRC(yfilter)
    if (do_sec) {
      const __m256i sec_reg = _mm256_loadu_si256((__m256i const *)second_pred);
      const __m256i exp_src = _mm256_packus_epi16(exp_src_lo, exp_src_hi);
      const __m256i avg_reg = _mm256_avg_epu8(exp_src, sec_reg);
      exp_src_lo = _mm256_unpacklo_epi8(avg_reg, zero_reg);
      exp_src_hi = _mm256_unpackhi_epi8(avg_reg, zero_reg);
      second_pred += second_stride;
    }

    prev_src_pack = src_pack;

    CALC_SUM_SSE_INSIDE_LOOP
    dst += dst_stride;
    src += src_stride;
  }
}

static INLINE int sub_pix_var32xh(const uint8_t *src, int src_stride,
                                  int x_offset, int y_offset,
                                  const uint8_t *dst, int dst_stride,
                                  const uint8_t *second_pred, int second_stride,
                                  int do_sec, int height, unsigned int *sse) {
  const __m256i zero_reg = _mm256_setzero_si256();
  __m256i sum_reg = _mm256_setzero_si256();
  __m256i sse_reg = _mm256_setzero_si256();
  __m256i sse_reg_hi, res_cmp, sum_reg_lo, sum_reg_hi;
  int sum;
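  // Offsets of 0 (copy) and 4 (half-pel average) use the cheaper kernels
  // below; any other eighth-pel offset falls back to the bilinear filter.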
  // x_offset = 0 and y_offset = 0
  if (x_offset == 0) {
    if (y_offset == 0) {
      spv32_x0_y0(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, &sum_reg, &sse_reg);
      // x_offset = 0 and y_offset = 4
    } else if (y_offset == 4) {
      spv32_x0_y4(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, &sum_reg, &sse_reg);
      // x_offset = 0 and y_offset = bilin interpolation
    } else {
      spv32_x0_yb(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, &sum_reg, &sse_reg, y_offset);
    }
    // x_offset = 4  and y_offset = 0
  } else if (x_offset == 4) {
    if (y_offset == 0) {
      spv32_x4_y0(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, &sum_reg, &sse_reg);
      // x_offset = 4  and y_offset = 4
    } else if (y_offset == 4) {
      spv32_x4_y4(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, &sum_reg, &sse_reg);
      // x_offset = 4  and y_offset = bilin interpolation
    } else {
      spv32_x4_yb(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, &sum_reg, &sse_reg, y_offset);
    }
    // x_offset = bilin interpolation and y_offset = 0
  } else {
    if (y_offset == 0) {
      spv32_xb_y0(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, &sum_reg, &sse_reg, x_offset);
      // x_offset = bilin interpolation and y_offset = 4
    } else if (y_offset == 4) {
      spv32_xb_y4(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, &sum_reg, &sse_reg, x_offset);
      // x_offset = bilin interpolation and y_offset = bilin interpolation
    } else {
      spv32_xb_yb(src, src_stride, dst, dst_stride, second_pred, second_stride,
                  do_sec, height, &sum_reg, &sse_reg, x_offset, y_offset);
    }
  }
  CALC_SUM_AND_SSE
  return sum;
}

static unsigned int sub_pixel_variance32xh_avx2(
    const uint8_t *src, int src_stride, int x_offset, int y_offset,
    const uint8_t *dst, int dst_stride, int height, unsigned int *sse) {
  return sub_pix_var32xh(src, src_stride, x_offset, y_offset, dst, dst_stride,
                         NULL, 0, 0, height, sse);
}

static unsigned int sub_pixel_avg_variance32xh_avx2(
    const uint8_t *src, int src_stride, int x_offset, int y_offset,
    const uint8_t *dst, int dst_stride, const uint8_t *second_pred,
    int second_stride, int height, unsigned int *sse) {
  return sub_pix_var32xh(src, src_stride, x_offset, y_offset, dst, dst_stride,
                         second_pred, second_stride, 1, height, sse);
}

typedef void (*get_var_avx2)(const uint8_t *src_ptr, int src_stride,
                             const uint8_t *ref_ptr, int ref_stride,
                             unsigned int *sse, int *sum);

unsigned int vpx_variance16x8_avx2(const uint8_t *src_ptr, int src_stride,
                                   const uint8_t *ref_ptr, int ref_stride,
                                   unsigned int *sse) {
  int sum;
  __m256i vsse, vsum;
  variance16_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 8, &vsse, &vsum);
  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, &sum);
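  // variance = sse - sum^2 / (16 * 8); 16 * 8 == 1 << 7.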
  return *sse - (uint32_t)(((int64_t)sum * sum) >> 7);
}

unsigned int vpx_variance16x16_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    unsigned int *sse) {
  int sum;
  __m256i vsse, vsum;
  variance16_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 16, &vsse, &vsum);
  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, &sum);
  return *sse - (uint32_t)(((int64_t)sum * sum) >> 8);
}

unsigned int vpx_variance16x32_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    unsigned int *sse) {
  int sum;
  __m256i vsse, vsum;
  variance16_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 32, &vsse, &vsum);
  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, &sum);
  return *sse - (uint32_t)(((int64_t)sum * sum) >> 9);
}

unsigned int vpx_variance32x16_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    unsigned int *sse) {
  int sum;
  __m256i vsse, vsum;
  variance32_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 16, &vsse, &vsum);
  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, &sum);
  return *sse - (uint32_t)(((int64_t)sum * sum) >> 9);
}

unsigned int vpx_variance32x32_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    unsigned int *sse) {
  int sum;
  __m256i vsse, vsum;
  __m128i vsum_128;
  variance32_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 32, &vsse, &vsum);
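  // Reduce the sixteen 16-bit row sums to four 32-bit lanes before the final
  // reduction; 32 rows of 32 pixels still fit the 16-bit accumulators.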
  vsum_128 = _mm_add_epi16(_mm256_castsi256_si128(vsum),
                           _mm256_extractf128_si256(vsum, 1));
  vsum_128 = _mm_add_epi32(_mm_cvtepi16_epi32(vsum_128),
                           _mm_cvtepi16_epi32(_mm_srli_si128(vsum_128, 8)));
  variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse, &sum);
  return *sse - (uint32_t)(((int64_t)sum * sum) >> 10);
}

unsigned int vpx_variance32x64_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    unsigned int *sse) {
  int sum;
  __m256i vsse, vsum;
  __m128i vsum_128;
  variance32_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 64, &vsse, &vsum);
  vsum = sum_to_32bit_avx2(vsum);
  vsum_128 = _mm_add_epi32(_mm256_castsi256_si128(vsum),
                           _mm256_extractf128_si256(vsum, 1));
  variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse, &sum);
  return *sse - (uint32_t)(((int64_t)sum * sum) >> 11);
}

unsigned int vpx_variance64x32_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    unsigned int *sse) {
  __m256i vsse = _mm256_setzero_si256();
  __m256i vsum = _mm256_setzero_si256();
  __m128i vsum_128;
  int sum;
  variance64_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 32, &vsse, &vsum);
  vsum = sum_to_32bit_avx2(vsum);
  vsum_128 = _mm_add_epi32(_mm256_castsi256_si128(vsum),
                           _mm256_extractf128_si256(vsum, 1));
  variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse, &sum);
  return *sse - (uint32_t)(((int64_t)sum * sum) >> 11);
}

unsigned int vpx_variance64x64_avx2(const uint8_t *src_ptr, int src_stride,
                                    const uint8_t *ref_ptr, int ref_stride,
                                    unsigned int *sse) {
  __m256i vsse = _mm256_setzero_si256();
  __m256i vsum = _mm256_setzero_si256();
  __m128i vsum_128;
  int sum;
  int i = 0;

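  // Process the 64x64 block as two 64x32 halves so the 16-bit row sums are
  // widened to 32 bits before they can overflow.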
  for (i = 0; i < 2; i++) {
    __m256i vsum16;
    variance64_avx2(src_ptr + 32 * i * src_stride, src_stride,
                    ref_ptr + 32 * i * ref_stride, ref_stride, 32, &vsse,
                    &vsum16);
    vsum = _mm256_add_epi32(vsum, sum_to_32bit_avx2(vsum16));
  }
  vsum_128 = _mm_add_epi32(_mm256_castsi256_si128(vsum),
                           _mm256_extractf128_si256(vsum, 1));
  variance_final_from_32bit_sum_avx2(vsse, vsum_128, sse, &sum);
  return *sse - (unsigned int)(((int64_t)sum * sum) >> 12);
}

unsigned int vpx_mse16x8_avx2(const uint8_t *src_ptr, int src_stride,
                              const uint8_t *ref_ptr, int ref_stride,
                              unsigned int *sse) {
  int sum;
  __m256i vsse, vsum;
  variance16_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 8, &vsse, &vsum);
  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, &sum);
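  // MSE does not remove the mean, so the accumulated SSE is returned directly.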
  return *sse;
}

unsigned int vpx_mse16x16_avx2(const uint8_t *src_ptr, int src_stride,
                               const uint8_t *ref_ptr, int ref_stride,
                               unsigned int *sse) {
  int sum;
  __m256i vsse, vsum;
  variance16_avx2(src_ptr, src_stride, ref_ptr, ref_stride, 16, &vsse, &vsum);
  variance_final_from_16bit_sum_avx2(vsse, vsum, sse, &sum);
  return *sse;
}

unsigned int vpx_sub_pixel_variance64x64_avx2(
    const uint8_t *src_ptr, int src_stride, int x_offset, int y_offset,
    const uint8_t *ref_ptr, int ref_stride, unsigned int *sse) {
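  // 64-wide blocks are computed as two independent 32xh halves whose SSE and
  // sums are combined before the final variance formula.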
  unsigned int sse1;
  const int se1 = sub_pixel_variance32xh_avx2(
      src_ptr, src_stride, x_offset, y_offset, ref_ptr, ref_stride, 64, &sse1);
  unsigned int sse2;
  const int se2 =
      sub_pixel_variance32xh_avx2(src_ptr + 32, src_stride, x_offset, y_offset,
                                  ref_ptr + 32, ref_stride, 64, &sse2);
  const int se = se1 + se2;
  *sse = sse1 + sse2;
  return *sse - (uint32_t)(((int64_t)se * se) >> 12);
}

unsigned int vpx_sub_pixel_variance32x32_avx2(
    const uint8_t *src_ptr, int src_stride, int x_offset, int y_offset,
    const uint8_t *ref_ptr, int ref_stride, unsigned int *sse) {
  const int se = sub_pixel_variance32xh_avx2(
      src_ptr, src_stride, x_offset, y_offset, ref_ptr, ref_stride, 32, sse);
  return *sse - (uint32_t)(((int64_t)se * se) >> 10);
}

unsigned int vpx_sub_pixel_avg_variance64x64_avx2(
    const uint8_t *src_ptr, int src_stride, int x_offset, int y_offset,
    const uint8_t *ref_ptr, int ref_stride, unsigned int *sse,
    const uint8_t *second_pred) {
  unsigned int sse1;
  const int se1 = sub_pixel_avg_variance32xh_avx2(src_ptr, src_stride, x_offset,
                                                  y_offset, ref_ptr, ref_stride,
                                                  second_pred, 64, 64, &sse1);
  unsigned int sse2;
  const int se2 = sub_pixel_avg_variance32xh_avx2(
      src_ptr + 32, src_stride, x_offset, y_offset, ref_ptr + 32, ref_stride,
      second_pred + 32, 64, 64, &sse2);
  const int se = se1 + se2;

  *sse = sse1 + sse2;

  return *sse - (uint32_t)(((int64_t)se * se) >> 12);
}

unsigned int vpx_sub_pixel_avg_variance32x32_avx2(
    const uint8_t *src_ptr, int src_stride, int x_offset, int y_offset,
    const uint8_t *ref_ptr, int ref_stride, unsigned int *sse,
    const uint8_t *second_pred) {
  // Process 32 elements in parallel.
  const int se = sub_pixel_avg_variance32xh_avx2(src_ptr, src_stride, x_offset,
                                                 y_offset, ref_ptr, ref_stride,
                                                 second_pred, 32, 32, sse);
  return *sse - (uint32_t)(((int64_t)se * se) >> 10);
}