/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <immintrin.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/x86/synonyms_avx2.h"
#include "aom_ports/mem.h"

// SAD
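// Horizontally add the eight 32-bit partial sums held in *v and return the
// total as a scalar.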
static INLINE unsigned int get_sad_from_mm256_epi32(const __m256i *v) {
  // input: 8 32-bit partial sums
  __m128i lo128, hi128;
  __m256i u = _mm256_srli_si256(*v, 8);
  u = _mm256_add_epi32(u, *v);

  // 4 32-bit partial sums
  hi128 = _mm256_extracti128_si256(u, 1);
  lo128 = _mm256_castsi256_si128(u);
  lo128 = _mm_add_epi32(hi128, lo128);

  // 2 32-bit partial sums
  hi128 = _mm_srli_si128(lo128, 4);
  lo128 = _mm_add_epi32(lo128, hi128);

  return (unsigned int)_mm_cvtsi128_si32(lo128);
}

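// High bit-depth blocks are stored as uint16_t samples; the uint8_t pointers
// passed in are CONVERT_TO_BYTEPTR aliases, so CONVERT_TO_SHORTPTR recovers
// the underlying 16-bit pointers. The 16x8 SAD is accumulated 4 rows at a
// time.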
unsigned int aom_highbd_sad16x8_avx2(const uint8_t *src, int src_stride,
                                     const uint8_t *ref, int ref_stride) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
  const uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref);

  // first 4 rows
  __m256i s0 = _mm256_loadu_si256((const __m256i *)src_ptr);
  __m256i s1 = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
  __m256i s2 = _mm256_loadu_si256((const __m256i *)(src_ptr + 2 * src_stride));
  __m256i s3 = _mm256_loadu_si256((const __m256i *)(src_ptr + 3 * src_stride));

  __m256i r0 = _mm256_loadu_si256((const __m256i *)ref_ptr);
  __m256i r1 = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
  __m256i r2 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 2 * ref_stride));
  __m256i r3 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 3 * ref_stride));

  __m256i u0 = _mm256_sub_epi16(s0, r0);
  __m256i u1 = _mm256_sub_epi16(s1, r1);
  __m256i u2 = _mm256_sub_epi16(s2, r2);
  __m256i u3 = _mm256_sub_epi16(s3, r3);
  __m256i zero = _mm256_setzero_si256();
  __m256i sum0, sum1;

  u0 = _mm256_abs_epi16(u0);
  u1 = _mm256_abs_epi16(u1);
  u2 = _mm256_abs_epi16(u2);
  u3 = _mm256_abs_epi16(u3);

  sum0 = _mm256_add_epi16(u0, u1);
  sum0 = _mm256_add_epi16(sum0, u2);
  sum0 = _mm256_add_epi16(sum0, u3);

  // second 4 rows
  src_ptr += src_stride << 2;
  ref_ptr += ref_stride << 2;
  s0 = _mm256_loadu_si256((const __m256i *)src_ptr);
  s1 = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
  s2 = _mm256_loadu_si256((const __m256i *)(src_ptr + 2 * src_stride));
  s3 = _mm256_loadu_si256((const __m256i *)(src_ptr + 3 * src_stride));

  r0 = _mm256_loadu_si256((const __m256i *)ref_ptr);
  r1 = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
  r2 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 2 * ref_stride));
  r3 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 3 * ref_stride));

  u0 = _mm256_sub_epi16(s0, r0);
  u1 = _mm256_sub_epi16(s1, r1);
  u2 = _mm256_sub_epi16(s2, r2);
  u3 = _mm256_sub_epi16(s3, r3);

  u0 = _mm256_abs_epi16(u0);
  u1 = _mm256_abs_epi16(u1);
  u2 = _mm256_abs_epi16(u2);
  u3 = _mm256_abs_epi16(u3);

  sum1 = _mm256_add_epi16(u0, u1);
  sum1 = _mm256_add_epi16(sum1, u2);
  sum1 = _mm256_add_epi16(sum1, u3);

  // Widen to 32 bits and accumulate the SAD
  s0 = _mm256_unpacklo_epi16(sum0, zero);
  s1 = _mm256_unpackhi_epi16(sum0, zero);
  r0 = _mm256_unpacklo_epi16(sum1, zero);
  r1 = _mm256_unpackhi_epi16(sum1, zero);
  s0 = _mm256_add_epi32(s0, s1);
  r0 = _mm256_add_epi32(r0, r1);
  sum0 = _mm256_add_epi32(s0, r0);
  // 8 32-bit partial sums

  return (unsigned int)get_sad_from_mm256_epi32(&sum0);
}

unsigned int aom_highbd_sad16x16_avx2(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
  const uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref);
  __m256i s0, s1, s2, s3, r0, r1, r2, r3, u0, u1, u2, u3;
  __m256i sum0;
  __m256i sum = _mm256_setzero_si256();
  const __m256i zero = _mm256_setzero_si256();
  int row = 0;

  // Process 4 rows per iteration
  while (row < 16) {
    s0 = _mm256_loadu_si256((const __m256i *)src_ptr);
    s1 = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
    s2 = _mm256_loadu_si256((const __m256i *)(src_ptr + 2 * src_stride));
    s3 = _mm256_loadu_si256((const __m256i *)(src_ptr + 3 * src_stride));

    r0 = _mm256_loadu_si256((const __m256i *)ref_ptr);
    r1 = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
    r2 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 2 * ref_stride));
    r3 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 3 * ref_stride));

    u0 = _mm256_sub_epi16(s0, r0);
    u1 = _mm256_sub_epi16(s1, r1);
    u2 = _mm256_sub_epi16(s2, r2);
    u3 = _mm256_sub_epi16(s3, r3);

    u0 = _mm256_abs_epi16(u0);
    u1 = _mm256_abs_epi16(u1);
    u2 = _mm256_abs_epi16(u2);
    u3 = _mm256_abs_epi16(u3);

    sum0 = _mm256_add_epi16(u0, u1);
    sum0 = _mm256_add_epi16(sum0, u2);
    sum0 = _mm256_add_epi16(sum0, u3);

    s0 = _mm256_unpacklo_epi16(sum0, zero);
    s1 = _mm256_unpackhi_epi16(sum0, zero);
    sum = _mm256_add_epi32(sum, s0);
    sum = _mm256_add_epi32(sum, s1);
    // 8 32-bit partial sums

    row += 4;
    src_ptr += src_stride << 2;
    ref_ptr += ref_stride << 2;
  }
  return get_sad_from_mm256_epi32(&sum);
}

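// Accumulate the SAD of a 32-wide, 4-row block into *sad_acc. When sec_ptr is
// non-NULL, each reference row is first averaged with the corresponding
// 32-sample row of the second predictor (compound-average SAD).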
static void sad32x4(const uint16_t *src_ptr, int src_stride,
                    const uint16_t *ref_ptr, int ref_stride,
                    const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s0, s1, s2, s3, r0, r1, r2, r3;
  const __m256i zero = _mm256_setzero_si256();
  int row_sections = 0;

  while (row_sections < 2) {
    s0 = _mm256_loadu_si256((const __m256i *)src_ptr);
    s1 = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
    s2 = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
    s3 = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride + 16));

    r0 = _mm256_loadu_si256((const __m256i *)ref_ptr);
    r1 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
    r2 = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
    r3 = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride + 16));

    if (sec_ptr) {
      r0 = _mm256_avg_epu16(r0, _mm256_loadu_si256((const __m256i *)sec_ptr));
      r1 = _mm256_avg_epu16(
          r1, _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
      r2 = _mm256_avg_epu16(
          r2, _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
      r3 = _mm256_avg_epu16(
          r3, _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
    }
    s0 = _mm256_sub_epi16(s0, r0);
    s1 = _mm256_sub_epi16(s1, r1);
    s2 = _mm256_sub_epi16(s2, r2);
    s3 = _mm256_sub_epi16(s3, r3);

    s0 = _mm256_abs_epi16(s0);
    s1 = _mm256_abs_epi16(s1);
    s2 = _mm256_abs_epi16(s2);
    s3 = _mm256_abs_epi16(s3);

    s0 = _mm256_add_epi16(s0, s1);
    s0 = _mm256_add_epi16(s0, s2);
    s0 = _mm256_add_epi16(s0, s3);

    r0 = _mm256_unpacklo_epi16(s0, zero);
    r1 = _mm256_unpackhi_epi16(s0, zero);

    r0 = _mm256_add_epi32(r0, r1);
    *sad_acc = _mm256_add_epi32(*sad_acc, r0);

    row_sections += 1;
    src_ptr += src_stride << 1;
    ref_ptr += ref_stride << 1;
    if (sec_ptr) sec_ptr += 32 << 1;
  }
}

unsigned int aom_highbd_sad32x16_avx2(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  const int left_shift = 2;
  int row_section = 0;

  while (row_section < 4) {
    sad32x4(srcp, src_stride, refp, ref_stride, NULL, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

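// Taller block sizes are computed as two vertically stacked halves: take the
// SAD of the top half, advance src/ref by half the block height, and add the
// SAD of the bottom half.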
unsigned int aom_highbd_sad16x32_avx2(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride) {
  uint32_t sum = aom_highbd_sad16x16_avx2(src, src_stride, ref, ref_stride);
  src += src_stride << 4;
  ref += ref_stride << 4;
  sum += aom_highbd_sad16x16_avx2(src, src_stride, ref, ref_stride);
  return sum;
}

unsigned int aom_highbd_sad32x32_avx2(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride) {
  uint32_t sum = aom_highbd_sad32x16_avx2(src, src_stride, ref, ref_stride);
  src += src_stride << 4;
  ref += ref_stride << 4;
  sum += aom_highbd_sad32x16_avx2(src, src_stride, ref, ref_stride);
  return sum;
}

unsigned int aom_highbd_sad32x64_avx2(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride) {
  uint32_t sum = aom_highbd_sad32x32_avx2(src, src_stride, ref, ref_stride);
  src += src_stride << 5;
  ref += ref_stride << 5;
  sum += aom_highbd_sad32x32_avx2(src, src_stride, ref, ref_stride);
  return sum;
}

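// Accumulate the SAD of a 64-wide, 2-row block into *sad_acc, optionally
// averaging the reference with the second predictor first (sec_ptr != NULL).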
static void sad64x2(const uint16_t *src_ptr, int src_stride,
                    const uint16_t *ref_ptr, int ref_stride,
                    const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s[8], r[8];
  const __m256i zero = _mm256_setzero_si256();

  s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
  s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
  s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 32));
  s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 48));
  s[4] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
  s[5] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride + 16));
  s[6] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride + 32));
  s[7] = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride + 48));

  r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
  r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
  r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 32));
  r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 48));
  r[4] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
  r[5] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride + 16));
  r[6] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride + 32));
  r[7] = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride + 48));

  if (sec_ptr) {
    r[0] = _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
    r[1] = _mm256_avg_epu16(
        r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
    r[2] = _mm256_avg_epu16(
        r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
    r[3] = _mm256_avg_epu16(
        r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
    r[4] = _mm256_avg_epu16(
        r[4], _mm256_loadu_si256((const __m256i *)(sec_ptr + 64)));
    r[5] = _mm256_avg_epu16(
        r[5], _mm256_loadu_si256((const __m256i *)(sec_ptr + 80)));
    r[6] = _mm256_avg_epu16(
        r[6], _mm256_loadu_si256((const __m256i *)(sec_ptr + 96)));
    r[7] = _mm256_avg_epu16(
        r[7], _mm256_loadu_si256((const __m256i *)(sec_ptr + 112)));
  }

  s[0] = _mm256_sub_epi16(s[0], r[0]);
  s[1] = _mm256_sub_epi16(s[1], r[1]);
  s[2] = _mm256_sub_epi16(s[2], r[2]);
  s[3] = _mm256_sub_epi16(s[3], r[3]);
  s[4] = _mm256_sub_epi16(s[4], r[4]);
  s[5] = _mm256_sub_epi16(s[5], r[5]);
  s[6] = _mm256_sub_epi16(s[6], r[6]);
  s[7] = _mm256_sub_epi16(s[7], r[7]);

  s[0] = _mm256_abs_epi16(s[0]);
  s[1] = _mm256_abs_epi16(s[1]);
  s[2] = _mm256_abs_epi16(s[2]);
  s[3] = _mm256_abs_epi16(s[3]);
  s[4] = _mm256_abs_epi16(s[4]);
  s[5] = _mm256_abs_epi16(s[5]);
  s[6] = _mm256_abs_epi16(s[6]);
  s[7] = _mm256_abs_epi16(s[7]);

  s[0] = _mm256_add_epi16(s[0], s[1]);
  s[0] = _mm256_add_epi16(s[0], s[2]);
  s[0] = _mm256_add_epi16(s[0], s[3]);

  s[4] = _mm256_add_epi16(s[4], s[5]);
  s[4] = _mm256_add_epi16(s[4], s[6]);
  s[4] = _mm256_add_epi16(s[4], s[7]);

  r[0] = _mm256_unpacklo_epi16(s[0], zero);
  r[1] = _mm256_unpackhi_epi16(s[0], zero);
  r[2] = _mm256_unpacklo_epi16(s[4], zero);
  r[3] = _mm256_unpackhi_epi16(s[4], zero);

  r[0] = _mm256_add_epi32(r[0], r[1]);
  r[0] = _mm256_add_epi32(r[0], r[2]);
  r[0] = _mm256_add_epi32(r[0], r[3]);
  *sad_acc = _mm256_add_epi32(*sad_acc, r[0]);
}

unsigned int aom_highbd_sad64x32_avx2(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  const int left_shift = 1;
  int row_section = 0;

  while (row_section < 16) {
    sad64x2(srcp, src_stride, refp, ref_stride, NULL, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad64x64_avx2(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride) {
  uint32_t sum = aom_highbd_sad64x32_avx2(src, src_stride, ref, ref_stride);
  src += src_stride << 5;
  ref += ref_stride << 5;
  sum += aom_highbd_sad64x32_avx2(src, src_stride, ref, ref_stride);
  return sum;
}

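// Accumulate the SAD of a single 128-sample row into *sad_acc, optionally
// averaging the reference row with the second predictor first.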
static void sad128x1(const uint16_t *src_ptr, const uint16_t *ref_ptr,
                     const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s[8], r[8];
  const __m256i zero = _mm256_setzero_si256();

  s[0] = _mm256_loadu_si256((const __m256i *)src_ptr);
  s[1] = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
  s[2] = _mm256_loadu_si256((const __m256i *)(src_ptr + 32));
  s[3] = _mm256_loadu_si256((const __m256i *)(src_ptr + 48));
  s[4] = _mm256_loadu_si256((const __m256i *)(src_ptr + 64));
  s[5] = _mm256_loadu_si256((const __m256i *)(src_ptr + 80));
  s[6] = _mm256_loadu_si256((const __m256i *)(src_ptr + 96));
  s[7] = _mm256_loadu_si256((const __m256i *)(src_ptr + 112));

  r[0] = _mm256_loadu_si256((const __m256i *)ref_ptr);
  r[1] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
  r[2] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 32));
  r[3] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 48));
  r[4] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 64));
  r[5] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 80));
  r[6] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 96));
  r[7] = _mm256_loadu_si256((const __m256i *)(ref_ptr + 112));

  if (sec_ptr) {
    r[0] = _mm256_avg_epu16(r[0], _mm256_loadu_si256((const __m256i *)sec_ptr));
    r[1] = _mm256_avg_epu16(
        r[1], _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
    r[2] = _mm256_avg_epu16(
        r[2], _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
    r[3] = _mm256_avg_epu16(
        r[3], _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
    r[4] = _mm256_avg_epu16(
        r[4], _mm256_loadu_si256((const __m256i *)(sec_ptr + 64)));
    r[5] = _mm256_avg_epu16(
        r[5], _mm256_loadu_si256((const __m256i *)(sec_ptr + 80)));
    r[6] = _mm256_avg_epu16(
        r[6], _mm256_loadu_si256((const __m256i *)(sec_ptr + 96)));
    r[7] = _mm256_avg_epu16(
        r[7], _mm256_loadu_si256((const __m256i *)(sec_ptr + 112)));
  }

  s[0] = _mm256_sub_epi16(s[0], r[0]);
  s[1] = _mm256_sub_epi16(s[1], r[1]);
  s[2] = _mm256_sub_epi16(s[2], r[2]);
  s[3] = _mm256_sub_epi16(s[3], r[3]);
  s[4] = _mm256_sub_epi16(s[4], r[4]);
  s[5] = _mm256_sub_epi16(s[5], r[5]);
  s[6] = _mm256_sub_epi16(s[6], r[6]);
  s[7] = _mm256_sub_epi16(s[7], r[7]);

  s[0] = _mm256_abs_epi16(s[0]);
  s[1] = _mm256_abs_epi16(s[1]);
  s[2] = _mm256_abs_epi16(s[2]);
  s[3] = _mm256_abs_epi16(s[3]);
  s[4] = _mm256_abs_epi16(s[4]);
  s[5] = _mm256_abs_epi16(s[5]);
  s[6] = _mm256_abs_epi16(s[6]);
  s[7] = _mm256_abs_epi16(s[7]);

  s[0] = _mm256_add_epi16(s[0], s[1]);
  s[0] = _mm256_add_epi16(s[0], s[2]);
  s[0] = _mm256_add_epi16(s[0], s[3]);

  s[4] = _mm256_add_epi16(s[4], s[5]);
  s[4] = _mm256_add_epi16(s[4], s[6]);
  s[4] = _mm256_add_epi16(s[4], s[7]);

  r[0] = _mm256_unpacklo_epi16(s[0], zero);
  r[1] = _mm256_unpackhi_epi16(s[0], zero);
  r[2] = _mm256_unpacklo_epi16(s[4], zero);
  r[3] = _mm256_unpackhi_epi16(s[4], zero);

  r[0] = _mm256_add_epi32(r[0], r[1]);
  r[0] = _mm256_add_epi32(r[0], r[2]);
  r[0] = _mm256_add_epi32(r[0], r[3]);
  *sad_acc = _mm256_add_epi32(*sad_acc, r[0]);
}

unsigned int aom_highbd_sad128x64_avx2(const uint8_t *src, int src_stride,
                                       const uint8_t *ref, int ref_stride) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  int row = 0;
  while (row < 64) {
    sad128x1(srcp, refp, NULL, &sad);
    srcp += src_stride;
    refp += ref_stride;
    row += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad64x128_avx2(const uint8_t *src, int src_stride,
                                       const uint8_t *ref, int ref_stride) {
  uint32_t sum = aom_highbd_sad64x64_avx2(src, src_stride, ref, ref_stride);
  src += src_stride << 6;
  ref += ref_stride << 6;
  sum += aom_highbd_sad64x64_avx2(src, src_stride, ref, ref_stride);
  return sum;
}

unsigned int aom_highbd_sad128x128_avx2(const uint8_t *src, int src_stride,
                                        const uint8_t *ref, int ref_stride) {
  uint32_t sum = aom_highbd_sad128x64_avx2(src, src_stride, ref, ref_stride);
  src += src_stride << 6;
  ref += ref_stride << 6;
  sum += aom_highbd_sad128x64_avx2(src, src_stride, ref, ref_stride);
  return sum;
}

// If sec_ptr is NULL, calculate the regular SAD. Otherwise, calculate the SAD
// between src and the average of ref and the second predictor.
static INLINE void sad16x4(const uint16_t *src_ptr, int src_stride,
                           const uint16_t *ref_ptr, int ref_stride,
                           const uint16_t *sec_ptr, __m256i *sad_acc) {
  __m256i s0, s1, s2, s3, r0, r1, r2, r3;
  const __m256i zero = _mm256_setzero_si256();

  s0 = _mm256_loadu_si256((const __m256i *)src_ptr);
  s1 = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
  s2 = _mm256_loadu_si256((const __m256i *)(src_ptr + 2 * src_stride));
  s3 = _mm256_loadu_si256((const __m256i *)(src_ptr + 3 * src_stride));

  r0 = _mm256_loadu_si256((const __m256i *)ref_ptr);
  r1 = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
  r2 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 2 * ref_stride));
  r3 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 3 * ref_stride));

  if (sec_ptr) {
    r0 = _mm256_avg_epu16(r0, _mm256_loadu_si256((const __m256i *)sec_ptr));
    r1 = _mm256_avg_epu16(r1,
                          _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
    r2 = _mm256_avg_epu16(r2,
                          _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
    r3 = _mm256_avg_epu16(r3,
                          _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
  }

  s0 = _mm256_sub_epi16(s0, r0);
  s1 = _mm256_sub_epi16(s1, r1);
  s2 = _mm256_sub_epi16(s2, r2);
  s3 = _mm256_sub_epi16(s3, r3);

  s0 = _mm256_abs_epi16(s0);
  s1 = _mm256_abs_epi16(s1);
  s2 = _mm256_abs_epi16(s2);
  s3 = _mm256_abs_epi16(s3);

  s0 = _mm256_add_epi16(s0, s1);
  s0 = _mm256_add_epi16(s0, s2);
  s0 = _mm256_add_epi16(s0, s3);

  r0 = _mm256_unpacklo_epi16(s0, zero);
  r1 = _mm256_unpackhi_epi16(s0, zero);

  r0 = _mm256_add_epi32(r0, r1);
  *sad_acc = _mm256_add_epi32(*sad_acc, r0);
}

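// The *_avg_ variants compute the SAD between the source block and the
// average of the reference block and second_pred (compound prediction).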
unsigned int aom_highbd_sad16x8_avg_avx2(const uint8_t *src, int src_stride,
                                         const uint8_t *ref, int ref_stride,
                                         const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);

  sad16x4(srcp, src_stride, refp, ref_stride, secp, &sad);

  // Next 4 rows
  srcp += src_stride << 2;
  refp += ref_stride << 2;
  secp += 64;
  sad16x4(srcp, src_stride, refp, ref_stride, secp, &sad);
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad16x16_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 3;
  uint32_t sum = aom_highbd_sad16x8_avg_avx2(src, src_stride, ref, ref_stride,
                                             second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 16 << left_shift;
  sum += aom_highbd_sad16x8_avg_avx2(src, src_stride, ref, ref_stride,
                                     second_pred);
  return sum;
}

unsigned int aom_highbd_sad16x32_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 4;
  uint32_t sum = aom_highbd_sad16x16_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 16 << left_shift;
  sum += aom_highbd_sad16x16_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad32x16_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  const int left_shift = 2;
  int row_section = 0;

  while (row_section < 4) {
    sad32x4(srcp, src_stride, refp, ref_stride, secp, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    secp += 32 << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad32x32_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 4;
  uint32_t sum = aom_highbd_sad32x16_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 32 << left_shift;
  sum += aom_highbd_sad32x16_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad32x64_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 5;
  uint32_t sum = aom_highbd_sad32x32_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 32 << left_shift;
  sum += aom_highbd_sad32x32_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad64x32_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  const int left_shift = 1;
  int row_section = 0;

  while (row_section < 16) {
    sad64x2(srcp, src_stride, refp, ref_stride, secp, &sad);
    srcp += src_stride << left_shift;
    refp += ref_stride << left_shift;
    secp += 64 << left_shift;
    row_section += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad64x64_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  const int left_shift = 5;
  uint32_t sum = aom_highbd_sad64x32_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 64 << left_shift;
  sum += aom_highbd_sad64x32_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad64x128_avg_avx2(const uint8_t *src, int src_stride,
                                           const uint8_t *ref, int ref_stride,
                                           const uint8_t *second_pred) {
  const int left_shift = 6;
  uint32_t sum = aom_highbd_sad64x64_avg_avx2(src, src_stride, ref, ref_stride,
                                              second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 64 << left_shift;
  sum += aom_highbd_sad64x64_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  return sum;
}

unsigned int aom_highbd_sad128x64_avg_avx2(const uint8_t *src, int src_stride,
                                           const uint8_t *ref, int ref_stride,
                                           const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *srcp = CONVERT_TO_SHORTPTR(src);
  uint16_t *refp = CONVERT_TO_SHORTPTR(ref);
  uint16_t *secp = CONVERT_TO_SHORTPTR(second_pred);
  int row = 0;
  while (row < 64) {
    sad128x1(srcp, refp, secp, &sad);
    srcp += src_stride;
    refp += ref_stride;
    secp += 16 << 3;
    row += 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}

unsigned int aom_highbd_sad128x128_avg_avx2(const uint8_t *src, int src_stride,
                                            const uint8_t *ref, int ref_stride,
                                            const uint8_t *second_pred) {
  unsigned int sum;
  const int left_shift = 6;

  sum = aom_highbd_sad128x64_avg_avx2(src, src_stride, ref, ref_stride,
                                      second_pred);
  src += src_stride << left_shift;
  ref += ref_stride << left_shift;
  second_pred += 128 << left_shift;
  sum += aom_highbd_sad128x64_avg_avx2(src, src_stride, ref, ref_stride,
                                       second_pred);
  return sum;
}

// SAD 4D
// Combine 4 __m256i vectors to uint32_t result[4]
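// Each input vector holds eight 32-bit partial sums for one reference block;
// the lanes are folded pairwise, interleaved across the four vectors, and the
// four totals are written to res[0..3] with a single 128-bit store.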
static INLINE void get_4d_sad_from_mm256_epi32(const __m256i *v,
                                               uint32_t *res) {
  __m256i u0, u1, u2, u3;
  const __m256i mask = yy_set1_64_from_32i(UINT32_MAX);
  __m128i sad;

  // 8 32-bit partial sums
  u0 = _mm256_srli_si256(v[0], 4);
  u1 = _mm256_srli_si256(v[1], 4);
  u2 = _mm256_srli_si256(v[2], 4);
  u3 = _mm256_srli_si256(v[3], 4);

  u0 = _mm256_add_epi32(u0, v[0]);
  u1 = _mm256_add_epi32(u1, v[1]);
  u2 = _mm256_add_epi32(u2, v[2]);
  u3 = _mm256_add_epi32(u3, v[3]);

  u0 = _mm256_and_si256(u0, mask);
  u1 = _mm256_and_si256(u1, mask);
  u2 = _mm256_and_si256(u2, mask);
  u3 = _mm256_and_si256(u3, mask);
  // 4 32-bit partial sums, evenly positioned

  u1 = _mm256_slli_si256(u1, 4);
  u3 = _mm256_slli_si256(u3, 4);

  u0 = _mm256_or_si256(u0, u1);
  u2 = _mm256_or_si256(u2, u3);
  // 8 32-bit partial sums, interleaved

  u1 = _mm256_unpacklo_epi64(u0, u2);
  u3 = _mm256_unpackhi_epi64(u0, u2);

  u0 = _mm256_add_epi32(u1, u3);
  sad = _mm_add_epi32(_mm256_extractf128_si256(u0, 1),
                      _mm256_castsi256_si128(u0));
  _mm_storeu_si128((__m128i *)res, sad);
}

static void convert_pointers(const uint8_t *const ref8[],
                             const uint16_t *ref[]) {
  ref[0] = CONVERT_TO_SHORTPTR(ref8[0]);
  ref[1] = CONVERT_TO_SHORTPTR(ref8[1]);
  ref[2] = CONVERT_TO_SHORTPTR(ref8[2]);
  ref[3] = CONVERT_TO_SHORTPTR(ref8[3]);
}

static void init_sad(__m256i *s) {
  s[0] = _mm256_setzero_si256();
  s[1] = _mm256_setzero_si256();
  s[2] = _mm256_setzero_si256();
  s[3] = _mm256_setzero_si256();
}

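// The *_x4d variants compute the SAD of the same source block against four
// reference blocks (ref_array[0..3]) and write the results to sad_array.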
void aom_highbd_sad16x8x4d_avx2(const uint8_t *src, int src_stride,
                                const uint8_t *const ref_array[],
                                int ref_stride, uint32_t *sad_array) {
  __m256i sad_vec[4];
  const uint16_t *refp[4];
  const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
  const uint16_t *srcp;
  const int shift_for_4_rows = 2;
  int i;

  init_sad(sad_vec);
  convert_pointers(ref_array, refp);

  for (i = 0; i < 4; ++i) {
    srcp = keep;
    sad16x4(srcp, src_stride, refp[i], ref_stride, 0, &sad_vec[i]);
    srcp += src_stride << shift_for_4_rows;
    refp[i] += ref_stride << shift_for_4_rows;
    sad16x4(srcp, src_stride, refp[i], ref_stride, 0, &sad_vec[i]);
  }
  get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
}

void aom_highbd_sad16x16x4d_avx2(const uint8_t *src, int src_stride,
                                 const uint8_t *const ref_array[],
                                 int ref_stride, uint32_t *sad_array) {
  uint32_t first8rows[4];
  uint32_t second8rows[4];
  const uint8_t *ref[4];
  const int shift_for_8_rows = 3;

  ref[0] = ref_array[0];
  ref[1] = ref_array[1];
  ref[2] = ref_array[2];
  ref[3] = ref_array[3];

  aom_highbd_sad16x8x4d_avx2(src, src_stride, ref, ref_stride, first8rows);
  src += src_stride << shift_for_8_rows;
  ref[0] += ref_stride << shift_for_8_rows;
  ref[1] += ref_stride << shift_for_8_rows;
  ref[2] += ref_stride << shift_for_8_rows;
  ref[3] += ref_stride << shift_for_8_rows;
  aom_highbd_sad16x8x4d_avx2(src, src_stride, ref, ref_stride, second8rows);
  sad_array[0] = first8rows[0] + second8rows[0];
  sad_array[1] = first8rows[1] + second8rows[1];
  sad_array[2] = first8rows[2] + second8rows[2];
  sad_array[3] = first8rows[3] + second8rows[3];
}

void aom_highbd_sad16x32x4d_avx2(const uint8_t *src, int src_stride,
                                 const uint8_t *const ref_array[],
                                 int ref_stride, uint32_t *sad_array) {
  uint32_t first_half[4];
  uint32_t second_half[4];
  const uint8_t *ref[4];
  const int shift_for_rows = 4;

  ref[0] = ref_array[0];
  ref[1] = ref_array[1];
  ref[2] = ref_array[2];
  ref[3] = ref_array[3];

  aom_highbd_sad16x16x4d_avx2(src, src_stride, ref, ref_stride, first_half);
  src += src_stride << shift_for_rows;
  ref[0] += ref_stride << shift_for_rows;
  ref[1] += ref_stride << shift_for_rows;
  ref[2] += ref_stride << shift_for_rows;
  ref[3] += ref_stride << shift_for_rows;
  aom_highbd_sad16x16x4d_avx2(src, src_stride, ref, ref_stride, second_half);
  sad_array[0] = first_half[0] + second_half[0];
  sad_array[1] = first_half[1] + second_half[1];
  sad_array[2] = first_half[2] + second_half[2];
  sad_array[3] = first_half[3] + second_half[3];
}

void aom_highbd_sad32x16x4d_avx2(const uint8_t *src, int src_stride,
                                 const uint8_t *const ref_array[],
                                 int ref_stride, uint32_t *sad_array) {
  __m256i sad_vec[4];
  const uint16_t *refp[4];
  const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
  const uint16_t *srcp;
  const int shift_for_4_rows = 2;
  int i;
  int rows_section;

  init_sad(sad_vec);
  convert_pointers(ref_array, refp);

  for (i = 0; i < 4; ++i) {
    srcp = keep;
    rows_section = 0;
    while (rows_section < 4) {
      sad32x4(srcp, src_stride, refp[i], ref_stride, 0, &sad_vec[i]);
      srcp += src_stride << shift_for_4_rows;
      refp[i] += ref_stride << shift_for_4_rows;
      rows_section++;
    }
  }
  get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
}

void aom_highbd_sad32x32x4d_avx2(const uint8_t *src, int src_stride,
                                 const uint8_t *const ref_array[],
                                 int ref_stride, uint32_t *sad_array) {
  uint32_t first_half[4];
  uint32_t second_half[4];
  const uint8_t *ref[4];
  const int shift_for_rows = 4;

  ref[0] = ref_array[0];
  ref[1] = ref_array[1];
  ref[2] = ref_array[2];
  ref[3] = ref_array[3];

  aom_highbd_sad32x16x4d_avx2(src, src_stride, ref, ref_stride, first_half);
  src += src_stride << shift_for_rows;
  ref[0] += ref_stride << shift_for_rows;
  ref[1] += ref_stride << shift_for_rows;
  ref[2] += ref_stride << shift_for_rows;
  ref[3] += ref_stride << shift_for_rows;
  aom_highbd_sad32x16x4d_avx2(src, src_stride, ref, ref_stride, second_half);
  sad_array[0] = first_half[0] + second_half[0];
  sad_array[1] = first_half[1] + second_half[1];
  sad_array[2] = first_half[2] + second_half[2];
  sad_array[3] = first_half[3] + second_half[3];
}

void aom_highbd_sad32x64x4d_avx2(const uint8_t *src, int src_stride,
                                 const uint8_t *const ref_array[],
                                 int ref_stride, uint32_t *sad_array) {
  uint32_t first_half[4];
  uint32_t second_half[4];
  const uint8_t *ref[4];
  const int shift_for_rows = 5;

  ref[0] = ref_array[0];
  ref[1] = ref_array[1];
  ref[2] = ref_array[2];
  ref[3] = ref_array[3];

  aom_highbd_sad32x32x4d_avx2(src, src_stride, ref, ref_stride, first_half);
  src += src_stride << shift_for_rows;
  ref[0] += ref_stride << shift_for_rows;
  ref[1] += ref_stride << shift_for_rows;
  ref[2] += ref_stride << shift_for_rows;
  ref[3] += ref_stride << shift_for_rows;
  aom_highbd_sad32x32x4d_avx2(src, src_stride, ref, ref_stride, second_half);
  sad_array[0] = first_half[0] + second_half[0];
  sad_array[1] = first_half[1] + second_half[1];
  sad_array[2] = first_half[2] + second_half[2];
  sad_array[3] = first_half[3] + second_half[3];
}

void aom_highbd_sad64x32x4d_avx2(const uint8_t *src, int src_stride,
                                 const uint8_t *const ref_array[],
                                 int ref_stride, uint32_t *sad_array) {
  __m256i sad_vec[4];
  const uint16_t *refp[4];
  const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
  const uint16_t *srcp;
  const int shift_for_rows = 1;
  int i;
  int rows_section;

  init_sad(sad_vec);
  convert_pointers(ref_array, refp);

  for (i = 0; i < 4; ++i) {
    srcp = keep;
    rows_section = 0;
    while (rows_section < 16) {
      sad64x2(srcp, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
      srcp += src_stride << shift_for_rows;
      refp[i] += ref_stride << shift_for_rows;
      rows_section++;
    }
  }
  get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
}

void aom_highbd_sad64x64x4d_avx2(const uint8_t *src, int src_stride,
                                 const uint8_t *const ref_array[],
                                 int ref_stride, uint32_t *sad_array) {
  uint32_t first_half[4];
  uint32_t second_half[4];
  const uint8_t *ref[4];
  const int shift_for_rows = 5;

  ref[0] = ref_array[0];
  ref[1] = ref_array[1];
  ref[2] = ref_array[2];
  ref[3] = ref_array[3];

  aom_highbd_sad64x32x4d_avx2(src, src_stride, ref, ref_stride, first_half);
  src += src_stride << shift_for_rows;
  ref[0] += ref_stride << shift_for_rows;
  ref[1] += ref_stride << shift_for_rows;
  ref[2] += ref_stride << shift_for_rows;
  ref[3] += ref_stride << shift_for_rows;
  aom_highbd_sad64x32x4d_avx2(src, src_stride, ref, ref_stride, second_half);
  sad_array[0] = first_half[0] + second_half[0];
  sad_array[1] = first_half[1] + second_half[1];
  sad_array[2] = first_half[2] + second_half[2];
  sad_array[3] = first_half[3] + second_half[3];
}

void aom_highbd_sad64x128x4d_avx2(const uint8_t *src, int src_stride,
                                  const uint8_t *const ref_array[],
                                  int ref_stride, uint32_t *sad_array) {
  uint32_t first_half[4];
  uint32_t second_half[4];
  const uint8_t *ref[4];
  const int shift_for_rows = 6;

  ref[0] = ref_array[0];
  ref[1] = ref_array[1];
  ref[2] = ref_array[2];
  ref[3] = ref_array[3];

  aom_highbd_sad64x64x4d_avx2(src, src_stride, ref, ref_stride, first_half);
  src += src_stride << shift_for_rows;
  ref[0] += ref_stride << shift_for_rows;
  ref[1] += ref_stride << shift_for_rows;
  ref[2] += ref_stride << shift_for_rows;
  ref[3] += ref_stride << shift_for_rows;
  aom_highbd_sad64x64x4d_avx2(src, src_stride, ref, ref_stride, second_half);
  sad_array[0] = first_half[0] + second_half[0];
  sad_array[1] = first_half[1] + second_half[1];
  sad_array[2] = first_half[2] + second_half[2];
  sad_array[3] = first_half[3] + second_half[3];
}

void aom_highbd_sad128x64x4d_avx2(const uint8_t *src, int src_stride,
                                  const uint8_t *const ref_array[],
                                  int ref_stride, uint32_t *sad_array) {
  __m256i sad_vec[4];
  const uint16_t *refp[4];
  const uint16_t *keep = CONVERT_TO_SHORTPTR(src);
  const uint16_t *srcp;
  int i;
  int rows_section;

  init_sad(sad_vec);
  convert_pointers(ref_array, refp);

  for (i = 0; i < 4; ++i) {
    srcp = keep;
    rows_section = 0;
    while (rows_section < 64) {
      sad128x1(srcp, refp[i], NULL, &sad_vec[i]);
      srcp += src_stride;
      refp[i] += ref_stride;
      rows_section++;
    }
  }
  get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
}

void aom_highbd_sad128x128x4d_avx2(const uint8_t *src, int src_stride,
                                   const uint8_t *const ref_array[],
                                   int ref_stride, uint32_t *sad_array) {
  uint32_t first_half[4];
  uint32_t second_half[4];
  const uint8_t *ref[4];
  const int shift_for_rows = 6;

  ref[0] = ref_array[0];
  ref[1] = ref_array[1];
  ref[2] = ref_array[2];
  ref[3] = ref_array[3];

  aom_highbd_sad128x64x4d_avx2(src, src_stride, ref, ref_stride, first_half);
  src += src_stride << shift_for_rows;
  ref[0] += ref_stride << shift_for_rows;
  ref[1] += ref_stride << shift_for_rows;
  ref[2] += ref_stride << shift_for_rows;
  ref[3] += ref_stride << shift_for_rows;
  aom_highbd_sad128x64x4d_avx2(src, src_stride, ref, ref_stride, second_half);
  sad_array[0] = first_half[0] + second_half[0];
  sad_array[1] = first_half[1] + second_half[1];
  sad_array[2] = first_half[2] + second_half[2];
  sad_array[3] = first_half[3] + second_half[3];
}