1 /*
2 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include <immintrin.h>
13
14 #include "config/aom_config.h"
15 #include "config/aom_dsp_rtcd.h"
16
17 #include "aom/aom_integer.h"
18 #include "aom_dsp/x86/synonyms_avx2.h"
19 #include "aom_ports/mem.h"
20
21 // SAD
get_sad_from_mm256_epi32(const __m256i * v)22 static INLINE unsigned int get_sad_from_mm256_epi32(const __m256i *v) {
23 // input 8 32-bit summation
24 __m128i lo128, hi128;
25 __m256i u = _mm256_srli_si256(*v, 8);
26 u = _mm256_add_epi32(u, *v);
27
28 // 4 32-bit summation
29 hi128 = _mm256_extracti128_si256(u, 1);
30 lo128 = _mm256_castsi256_si128(u);
31 lo128 = _mm_add_epi32(hi128, lo128);
32
33 // 2 32-bit summation
34 hi128 = _mm_srli_si128(lo128, 4);
35 lo128 = _mm_add_epi32(lo128, hi128);
36
37 return (unsigned int)_mm_cvtsi128_si32(lo128);
38 }
39
// SAD of a 16x8 high-bitdepth block.
unsigned int aom_highbd_sad16x8_avx2(const uint8_t *src, int src_stride,
                                     const uint8_t *ref, int ref_stride) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
  const uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref);
  const __m256i zero = _mm256_setzero_si256();
  __m256i total = _mm256_setzero_si256();
  int half;

  // Process the 8 rows as two groups of 4. Each group keeps a 16-bit
  // accumulator that is widened to 32 bits before joining the total.
  for (half = 0; half < 2; ++half) {
    const __m256i s0 = _mm256_loadu_si256((const __m256i *)src_ptr);
    const __m256i s1 =
        _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
    const __m256i s2 =
        _mm256_loadu_si256((const __m256i *)(src_ptr + 2 * src_stride));
    const __m256i s3 =
        _mm256_loadu_si256((const __m256i *)(src_ptr + 3 * src_stride));

    const __m256i r0 = _mm256_loadu_si256((const __m256i *)ref_ptr);
    const __m256i r1 =
        _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
    const __m256i r2 =
        _mm256_loadu_si256((const __m256i *)(ref_ptr + 2 * ref_stride));
    const __m256i r3 =
        _mm256_loadu_si256((const __m256i *)(ref_ptr + 3 * ref_stride));

    // 16-bit absolute differences summed across the four rows.
    __m256i acc = _mm256_abs_epi16(_mm256_sub_epi16(s0, r0));
    acc = _mm256_add_epi16(acc, _mm256_abs_epi16(_mm256_sub_epi16(s1, r1)));
    acc = _mm256_add_epi16(acc, _mm256_abs_epi16(_mm256_sub_epi16(s2, r2)));
    acc = _mm256_add_epi16(acc, _mm256_abs_epi16(_mm256_sub_epi16(s3, r3)));

    // Widen to 32 bits and fold into the running total.
    total = _mm256_add_epi32(total, _mm256_unpacklo_epi16(acc, zero));
    total = _mm256_add_epi32(total, _mm256_unpackhi_epi16(acc, zero));

    src_ptr += src_stride << 2;
    ref_ptr += ref_stride << 2;
  }
  return get_sad_from_mm256_epi32(&total);
}
111
// SAD of a 16x16 high-bitdepth block.
unsigned int aom_highbd_sad16x16_avx2(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
  const uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref);
  const __m256i zero = _mm256_setzero_si256();
  __m256i sum = _mm256_setzero_si256();
  int row;

  // Four rows per iteration: abs-diff in 16 bits, widen, accumulate.
  for (row = 0; row < 16; row += 4) {
    const __m256i s0 = _mm256_loadu_si256((const __m256i *)src_ptr);
    const __m256i s1 =
        _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
    const __m256i s2 =
        _mm256_loadu_si256((const __m256i *)(src_ptr + 2 * src_stride));
    const __m256i s3 =
        _mm256_loadu_si256((const __m256i *)(src_ptr + 3 * src_stride));

    const __m256i r0 = _mm256_loadu_si256((const __m256i *)ref_ptr);
    const __m256i r1 =
        _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
    const __m256i r2 =
        _mm256_loadu_si256((const __m256i *)(ref_ptr + 2 * ref_stride));
    const __m256i r3 =
        _mm256_loadu_si256((const __m256i *)(ref_ptr + 3 * ref_stride));

    __m256i acc = _mm256_abs_epi16(_mm256_sub_epi16(s0, r0));
    acc = _mm256_add_epi16(acc, _mm256_abs_epi16(_mm256_sub_epi16(s1, r1)));
    acc = _mm256_add_epi16(acc, _mm256_abs_epi16(_mm256_sub_epi16(s2, r2)));
    acc = _mm256_add_epi16(acc, _mm256_abs_epi16(_mm256_sub_epi16(s3, r3)));

    sum = _mm256_add_epi32(sum, _mm256_unpacklo_epi16(acc, zero));
    sum = _mm256_add_epi32(sum, _mm256_unpackhi_epi16(acc, zero));

    src_ptr += src_stride << 2;
    ref_ptr += ref_stride << 2;
  }
  return get_sad_from_mm256_epi32(&sum);
}
160
// Accumulates the SAD of a 32x4 region into *sad_acc. When sec_ptr is
// non-NULL the reference is first averaged with the second predictor
// (compound prediction); the predictor is contiguous, 32 samples per row.
static void sad32x4(const uint16_t *src_ptr, int src_stride,
                    const uint16_t *ref_ptr, int ref_stride,
                    const uint16_t *sec_ptr, __m256i *sad_acc) {
  const __m256i zero = _mm256_setzero_si256();
  int section;

  // Each iteration covers two rows of 32 pixels (four 16-lane vectors).
  for (section = 0; section < 2; ++section) {
    __m256i a0 = _mm256_loadu_si256((const __m256i *)src_ptr);
    __m256i a1 = _mm256_loadu_si256((const __m256i *)(src_ptr + 16));
    __m256i a2 = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
    __m256i a3 =
        _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride + 16));

    __m256i b0 = _mm256_loadu_si256((const __m256i *)ref_ptr);
    __m256i b1 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16));
    __m256i b2 = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
    __m256i b3 =
        _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride + 16));

    __m256i diff;

    if (sec_ptr) {
      b0 = _mm256_avg_epu16(b0, _mm256_loadu_si256((const __m256i *)sec_ptr));
      b1 = _mm256_avg_epu16(
          b1, _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
      b2 = _mm256_avg_epu16(
          b2, _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
      b3 = _mm256_avg_epu16(
          b3, _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
      sec_ptr += 64;  // two rows of 32 predictor samples consumed
    }

    // 16-bit absolute differences, summed across the four vectors.
    diff = _mm256_abs_epi16(_mm256_sub_epi16(a0, b0));
    diff = _mm256_add_epi16(diff, _mm256_abs_epi16(_mm256_sub_epi16(a1, b1)));
    diff = _mm256_add_epi16(diff, _mm256_abs_epi16(_mm256_sub_epi16(a2, b2)));
    diff = _mm256_add_epi16(diff, _mm256_abs_epi16(_mm256_sub_epi16(a3, b3)));

    // Widen to 32 bits and fold into the running accumulator.
    *sad_acc = _mm256_add_epi32(
        *sad_acc, _mm256_add_epi32(_mm256_unpacklo_epi16(diff, zero),
                                   _mm256_unpackhi_epi16(diff, zero)));

    src_ptr += src_stride << 1;
    ref_ptr += ref_stride << 1;
  }
}
214
// SAD of a 32x16 high-bitdepth block: four stacked 32x4 sections.
unsigned int aom_highbd_sad32x16_avx2(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
  uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref);
  int i;

  for (i = 0; i < 4; ++i) {
    sad32x4(src_ptr, src_stride, ref_ptr, ref_stride, NULL, &sad);
    src_ptr += src_stride << 2;
    ref_ptr += ref_stride << 2;
  }
  return get_sad_from_mm256_epi32(&sad);
}
231
// SAD of a 16x32 block: two vertically stacked 16x16 blocks.
unsigned int aom_highbd_sad16x32_avx2(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride) {
  const unsigned int top =
      aom_highbd_sad16x16_avx2(src, src_stride, ref, ref_stride);
  const unsigned int bottom = aom_highbd_sad16x16_avx2(
      src + (src_stride << 4), src_stride, ref + (ref_stride << 4),
      ref_stride);
  return top + bottom;
}
240
// SAD of a 32x32 block: two vertically stacked 32x16 blocks.
unsigned int aom_highbd_sad32x32_avx2(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride) {
  const unsigned int top =
      aom_highbd_sad32x16_avx2(src, src_stride, ref, ref_stride);
  const unsigned int bottom = aom_highbd_sad32x16_avx2(
      src + (src_stride << 4), src_stride, ref + (ref_stride << 4),
      ref_stride);
  return top + bottom;
}
249
// SAD of a 32x64 block: two vertically stacked 32x32 blocks.
unsigned int aom_highbd_sad32x64_avx2(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride) {
  const unsigned int top =
      aom_highbd_sad32x32_avx2(src, src_stride, ref, ref_stride);
  const unsigned int bottom = aom_highbd_sad32x32_avx2(
      src + (src_stride << 5), src_stride, ref + (ref_stride << 5),
      ref_stride);
  return top + bottom;
}
258
// Accumulates the SAD of a 64x2 region into *sad_acc. When sec_ptr is
// non-NULL the reference is first averaged with the second predictor
// (compound prediction); the predictor is contiguous, 64 samples per row.
static void sad64x2(const uint16_t *src_ptr, int src_stride,
                    const uint16_t *ref_ptr, int ref_stride,
                    const uint16_t *sec_ptr, __m256i *sad_acc) {
  const __m256i zero = _mm256_setzero_si256();
  __m256i row16[2];  // per-row 16-bit abs-diff sums
  __m256i wide;
  int row, col;

  for (row = 0; row < 2; ++row) {
    const uint16_t *s = src_ptr + row * src_stride;
    const uint16_t *r = ref_ptr + row * ref_stride;
    row16[row] = _mm256_setzero_si256();
    // Four 16-lane vectors cover the 64 pixels of one row.
    for (col = 0; col < 4; ++col) {
      __m256i sv = _mm256_loadu_si256((const __m256i *)(s + 16 * col));
      __m256i rv = _mm256_loadu_si256((const __m256i *)(r + 16 * col));
      if (sec_ptr) {
        rv = _mm256_avg_epu16(
            rv, _mm256_loadu_si256(
                    (const __m256i *)(sec_ptr + 64 * row + 16 * col)));
      }
      row16[row] = _mm256_add_epi16(
          row16[row], _mm256_abs_epi16(_mm256_sub_epi16(sv, rv)));
    }
  }

  // Widen both rows' 16-bit sums to 32 bits and fold into the accumulator.
  wide = _mm256_add_epi32(_mm256_unpacklo_epi16(row16[0], zero),
                          _mm256_unpackhi_epi16(row16[0], zero));
  wide = _mm256_add_epi32(
      wide, _mm256_add_epi32(_mm256_unpacklo_epi16(row16[1], zero),
                             _mm256_unpackhi_epi16(row16[1], zero)));
  *sad_acc = _mm256_add_epi32(*sad_acc, wide);
}
337
// SAD of a 64x32 high-bitdepth block: sixteen stacked 64x2 sections.
unsigned int aom_highbd_sad64x32_avx2(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
  uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref);
  int i;

  for (i = 0; i < 16; ++i) {
    sad64x2(src_ptr, src_stride, ref_ptr, ref_stride, NULL, &sad);
    src_ptr += src_stride << 1;
    ref_ptr += ref_stride << 1;
  }
  return get_sad_from_mm256_epi32(&sad);
}
354
// SAD of a 64x64 block: two vertically stacked 64x32 blocks.
unsigned int aom_highbd_sad64x64_avx2(const uint8_t *src, int src_stride,
                                      const uint8_t *ref, int ref_stride) {
  const unsigned int top =
      aom_highbd_sad64x32_avx2(src, src_stride, ref, ref_stride);
  const unsigned int bottom = aom_highbd_sad64x32_avx2(
      src + (src_stride << 5), src_stride, ref + (ref_stride << 5),
      ref_stride);
  return top + bottom;
}
363
// Accumulates the SAD of one 128-pixel row into *sad_acc. When sec_ptr is
// non-NULL the reference is first averaged with the second predictor
// (compound prediction).
static void sad128x1(const uint16_t *src_ptr, const uint16_t *ref_ptr,
                     const uint16_t *sec_ptr, __m256i *sad_acc) {
  const __m256i zero = _mm256_setzero_si256();
  __m256i left16 = _mm256_setzero_si256();   // 16-bit sums, columns 0..63
  __m256i right16 = _mm256_setzero_si256();  // 16-bit sums, columns 64..127
  __m256i diff, wide;
  int i;

  // Eight 16-lane vectors cover the 128 pixels.
  for (i = 0; i < 8; ++i) {
    __m256i sv = _mm256_loadu_si256((const __m256i *)(src_ptr + 16 * i));
    __m256i rv = _mm256_loadu_si256((const __m256i *)(ref_ptr + 16 * i));
    if (sec_ptr) {
      rv = _mm256_avg_epu16(
          rv, _mm256_loadu_si256((const __m256i *)(sec_ptr + 16 * i)));
    }
    diff = _mm256_abs_epi16(_mm256_sub_epi16(sv, rv));
    if (i < 4) {
      left16 = _mm256_add_epi16(left16, diff);
    } else {
      right16 = _mm256_add_epi16(right16, diff);
    }
  }

  // Widen both halves' 16-bit sums to 32 bits and fold into the accumulator.
  wide = _mm256_add_epi32(_mm256_unpacklo_epi16(left16, zero),
                          _mm256_unpackhi_epi16(left16, zero));
  wide = _mm256_add_epi32(
      wide, _mm256_add_epi32(_mm256_unpacklo_epi16(right16, zero),
                             _mm256_unpackhi_epi16(right16, zero)));
  *sad_acc = _mm256_add_epi32(*sad_acc, wide);
}
441
// SAD of a 128x64 high-bitdepth block, one full row per iteration.
unsigned int aom_highbd_sad128x64_avx2(const uint8_t *src, int src_stride,
                                       const uint8_t *ref, int ref_stride) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
  uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref);
  int row;

  for (row = 0; row < 64; ++row) {
    sad128x1(src_ptr, ref_ptr, NULL, &sad);
    src_ptr += src_stride;
    ref_ptr += ref_stride;
  }
  return get_sad_from_mm256_epi32(&sad);
}
456
// SAD of a 64x128 block: two vertically stacked 64x64 blocks.
unsigned int aom_highbd_sad64x128_avx2(const uint8_t *src, int src_stride,
                                       const uint8_t *ref, int ref_stride) {
  const unsigned int top =
      aom_highbd_sad64x64_avx2(src, src_stride, ref, ref_stride);
  const unsigned int bottom = aom_highbd_sad64x64_avx2(
      src + (src_stride << 6), src_stride, ref + (ref_stride << 6),
      ref_stride);
  return top + bottom;
}
465
// SAD of a 128x128 block: two vertically stacked 128x64 blocks.
unsigned int aom_highbd_sad128x128_avx2(const uint8_t *src, int src_stride,
                                        const uint8_t *ref, int ref_stride) {
  const unsigned int top =
      aom_highbd_sad128x64_avx2(src, src_stride, ref, ref_stride);
  const unsigned int bottom = aom_highbd_sad128x64_avx2(
      src + (src_stride << 6), src_stride, ref + (ref_stride << 6),
      ref_stride);
  return top + bottom;
}
474
475 // If sec_ptr = 0, calculate regular SAD. Otherwise, calculate average SAD.
sad16x4(const uint16_t * src_ptr,int src_stride,const uint16_t * ref_ptr,int ref_stride,const uint16_t * sec_ptr,__m256i * sad_acc)476 static INLINE void sad16x4(const uint16_t *src_ptr, int src_stride,
477 const uint16_t *ref_ptr, int ref_stride,
478 const uint16_t *sec_ptr, __m256i *sad_acc) {
479 __m256i s0, s1, s2, s3, r0, r1, r2, r3;
480 const __m256i zero = _mm256_setzero_si256();
481
482 s0 = _mm256_loadu_si256((const __m256i *)src_ptr);
483 s1 = _mm256_loadu_si256((const __m256i *)(src_ptr + src_stride));
484 s2 = _mm256_loadu_si256((const __m256i *)(src_ptr + 2 * src_stride));
485 s3 = _mm256_loadu_si256((const __m256i *)(src_ptr + 3 * src_stride));
486
487 r0 = _mm256_loadu_si256((const __m256i *)ref_ptr);
488 r1 = _mm256_loadu_si256((const __m256i *)(ref_ptr + ref_stride));
489 r2 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 2 * ref_stride));
490 r3 = _mm256_loadu_si256((const __m256i *)(ref_ptr + 3 * ref_stride));
491
492 if (sec_ptr) {
493 r0 = _mm256_avg_epu16(r0, _mm256_loadu_si256((const __m256i *)sec_ptr));
494 r1 = _mm256_avg_epu16(r1,
495 _mm256_loadu_si256((const __m256i *)(sec_ptr + 16)));
496 r2 = _mm256_avg_epu16(r2,
497 _mm256_loadu_si256((const __m256i *)(sec_ptr + 32)));
498 r3 = _mm256_avg_epu16(r3,
499 _mm256_loadu_si256((const __m256i *)(sec_ptr + 48)));
500 }
501
502 s0 = _mm256_sub_epi16(s0, r0);
503 s1 = _mm256_sub_epi16(s1, r1);
504 s2 = _mm256_sub_epi16(s2, r2);
505 s3 = _mm256_sub_epi16(s3, r3);
506
507 s0 = _mm256_abs_epi16(s0);
508 s1 = _mm256_abs_epi16(s1);
509 s2 = _mm256_abs_epi16(s2);
510 s3 = _mm256_abs_epi16(s3);
511
512 s0 = _mm256_add_epi16(s0, s1);
513 s0 = _mm256_add_epi16(s0, s2);
514 s0 = _mm256_add_epi16(s0, s3);
515
516 r0 = _mm256_unpacklo_epi16(s0, zero);
517 r1 = _mm256_unpackhi_epi16(s0, zero);
518
519 r0 = _mm256_add_epi32(r0, r1);
520 *sad_acc = _mm256_add_epi32(*sad_acc, r0);
521 }
522
// Compound (averaged-prediction) SAD of a 16x8 block.
unsigned int aom_highbd_sad16x8_avg_avx2(const uint8_t *src, int src_stride,
                                         const uint8_t *ref, int ref_stride,
                                         const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
  uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref);
  uint16_t *sec_ptr = CONVERT_TO_SHORTPTR(second_pred);
  int half;

  // Two 16x4 sections; the predictor advances 16 * 4 samples per section.
  for (half = 0; half < 2; ++half) {
    sad16x4(src_ptr, src_stride, ref_ptr, ref_stride, sec_ptr, &sad);
    src_ptr += src_stride << 2;
    ref_ptr += ref_stride << 2;
    sec_ptr += 64;
  }
  return get_sad_from_mm256_epi32(&sad);
}
540
// Compound SAD of a 16x16 block: two stacked 16x8 halves. second_pred is
// contiguous (16 samples per row), so the lower half starts 16 * 8 in.
unsigned int aom_highbd_sad16x16_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  unsigned int sum = aom_highbd_sad16x8_avg_avx2(src, src_stride, ref,
                                                 ref_stride, second_pred);
  sum += aom_highbd_sad16x8_avg_avx2(src + (src_stride << 3), src_stride,
                                     ref + (ref_stride << 3), ref_stride,
                                     second_pred + (16 << 3));
  return sum;
}
554
// Compound SAD of a 16x32 block: two stacked 16x16 halves. second_pred is
// contiguous (16 samples per row), so the lower half starts 16 * 16 in.
unsigned int aom_highbd_sad16x32_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  unsigned int sum = aom_highbd_sad16x16_avg_avx2(src, src_stride, ref,
                                                  ref_stride, second_pred);
  sum += aom_highbd_sad16x16_avg_avx2(src + (src_stride << 4), src_stride,
                                      ref + (ref_stride << 4), ref_stride,
                                      second_pred + (16 << 4));
  return sum;
}
568
// Compound SAD of a 32x16 block: four stacked 32x4 sections.
unsigned int aom_highbd_sad32x16_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
  uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref);
  uint16_t *sec_ptr = CONVERT_TO_SHORTPTR(second_pred);
  int i;

  for (i = 0; i < 4; ++i) {
    sad32x4(src_ptr, src_stride, ref_ptr, ref_stride, sec_ptr, &sad);
    src_ptr += src_stride << 2;
    ref_ptr += ref_stride << 2;
    sec_ptr += 32 * 4;  // 32 predictor samples per row, 4 rows consumed
  }
  return get_sad_from_mm256_epi32(&sad);
}
588
// Compound SAD of a 32x32 block: two stacked 32x16 halves. second_pred is
// contiguous (32 samples per row), so the lower half starts 32 * 16 in.
unsigned int aom_highbd_sad32x32_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  unsigned int sum = aom_highbd_sad32x16_avg_avx2(src, src_stride, ref,
                                                  ref_stride, second_pred);
  sum += aom_highbd_sad32x16_avg_avx2(src + (src_stride << 4), src_stride,
                                      ref + (ref_stride << 4), ref_stride,
                                      second_pred + (32 << 4));
  return sum;
}
602
// Compound SAD of a 32x64 block: two stacked 32x32 halves. second_pred is
// contiguous (32 samples per row), so the lower half starts 32 * 32 in.
unsigned int aom_highbd_sad32x64_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  unsigned int sum = aom_highbd_sad32x32_avg_avx2(src, src_stride, ref,
                                                  ref_stride, second_pred);
  sum += aom_highbd_sad32x32_avg_avx2(src + (src_stride << 5), src_stride,
                                      ref + (ref_stride << 5), ref_stride,
                                      second_pred + (32 << 5));
  return sum;
}
616
// Compound SAD of a 64x32 block: sixteen stacked 64x2 sections.
unsigned int aom_highbd_sad64x32_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
  uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref);
  uint16_t *sec_ptr = CONVERT_TO_SHORTPTR(second_pred);
  int i;

  for (i = 0; i < 16; ++i) {
    sad64x2(src_ptr, src_stride, ref_ptr, ref_stride, sec_ptr, &sad);
    src_ptr += src_stride << 1;
    ref_ptr += ref_stride << 1;
    sec_ptr += 64 * 2;  // 64 predictor samples per row, 2 rows consumed
  }
  return get_sad_from_mm256_epi32(&sad);
}
636
// Compound SAD of a 64x64 block: two stacked 64x32 halves. second_pred is
// contiguous (64 samples per row), so the lower half starts 64 * 32 in.
unsigned int aom_highbd_sad64x64_avg_avx2(const uint8_t *src, int src_stride,
                                          const uint8_t *ref, int ref_stride,
                                          const uint8_t *second_pred) {
  unsigned int sum = aom_highbd_sad64x32_avg_avx2(src, src_stride, ref,
                                                  ref_stride, second_pred);
  sum += aom_highbd_sad64x32_avg_avx2(src + (src_stride << 5), src_stride,
                                      ref + (ref_stride << 5), ref_stride,
                                      second_pred + (64 << 5));
  return sum;
}
650
// Compound SAD of a 64x128 block: two stacked 64x64 halves. second_pred is
// contiguous (64 samples per row), so the lower half starts 64 * 64 in.
unsigned int aom_highbd_sad64x128_avg_avx2(const uint8_t *src, int src_stride,
                                           const uint8_t *ref, int ref_stride,
                                           const uint8_t *second_pred) {
  unsigned int sum = aom_highbd_sad64x64_avg_avx2(src, src_stride, ref,
                                                  ref_stride, second_pred);
  sum += aom_highbd_sad64x64_avg_avx2(src + (src_stride << 6), src_stride,
                                      ref + (ref_stride << 6), ref_stride,
                                      second_pred + (64 << 6));
  return sum;
}
664
// Compound SAD of a 128x64 block, one full row per iteration.
unsigned int aom_highbd_sad128x64_avg_avx2(const uint8_t *src, int src_stride,
                                           const uint8_t *ref, int ref_stride,
                                           const uint8_t *second_pred) {
  __m256i sad = _mm256_setzero_si256();
  uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src);
  uint16_t *ref_ptr = CONVERT_TO_SHORTPTR(ref);
  uint16_t *sec_ptr = CONVERT_TO_SHORTPTR(second_pred);
  int row;

  for (row = 0; row < 64; ++row) {
    sad128x1(src_ptr, ref_ptr, sec_ptr, &sad);
    src_ptr += src_stride;
    ref_ptr += ref_stride;
    sec_ptr += 128;  // one 128-sample predictor row consumed
  }
  return get_sad_from_mm256_epi32(&sad);
}
682
// Compound SAD of a 128x128 block: two stacked 128x64 halves. second_pred
// is contiguous (128 samples per row), so the lower half starts 128 * 64 in.
unsigned int aom_highbd_sad128x128_avg_avx2(const uint8_t *src, int src_stride,
                                            const uint8_t *ref, int ref_stride,
                                            const uint8_t *second_pred) {
  unsigned int sum = aom_highbd_sad128x64_avg_avx2(src, src_stride, ref,
                                                   ref_stride, second_pred);
  sum += aom_highbd_sad128x64_avg_avx2(src + (src_stride << 6), src_stride,
                                       ref + (ref_stride << 6), ref_stride,
                                       second_pred + (128 << 6));
  return sum;
}
698
699 // SAD 4D
700 // Combine 4 __m256i vectors to uint32_t result[4]
get_4d_sad_from_mm256_epi32(const __m256i * v,uint32_t * res)701 static INLINE void get_4d_sad_from_mm256_epi32(const __m256i *v,
702 uint32_t *res) {
703 __m256i u0, u1, u2, u3;
704 const __m256i mask = yy_set1_64_from_32i(UINT32_MAX);
705 __m128i sad;
706
707 // 8 32-bit summation
708 u0 = _mm256_srli_si256(v[0], 4);
709 u1 = _mm256_srli_si256(v[1], 4);
710 u2 = _mm256_srli_si256(v[2], 4);
711 u3 = _mm256_srli_si256(v[3], 4);
712
713 u0 = _mm256_add_epi32(u0, v[0]);
714 u1 = _mm256_add_epi32(u1, v[1]);
715 u2 = _mm256_add_epi32(u2, v[2]);
716 u3 = _mm256_add_epi32(u3, v[3]);
717
718 u0 = _mm256_and_si256(u0, mask);
719 u1 = _mm256_and_si256(u1, mask);
720 u2 = _mm256_and_si256(u2, mask);
721 u3 = _mm256_and_si256(u3, mask);
722 // 4 32-bit summation, evenly positioned
723
724 u1 = _mm256_slli_si256(u1, 4);
725 u3 = _mm256_slli_si256(u3, 4);
726
727 u0 = _mm256_or_si256(u0, u1);
728 u2 = _mm256_or_si256(u2, u3);
729 // 8 32-bit summation, interleaved
730
731 u1 = _mm256_unpacklo_epi64(u0, u2);
732 u3 = _mm256_unpackhi_epi64(u0, u2);
733
734 u0 = _mm256_add_epi32(u1, u3);
735 sad = _mm_add_epi32(_mm256_extractf128_si256(u0, 1),
736 _mm256_castsi256_si128(u0));
737 _mm_storeu_si128((__m128i *)res, sad);
738 }
739
// Unwraps the four packed highbd reference pointers into uint16_t pointers.
static void convert_pointers(const uint8_t *const ref8[],
                             const uint16_t *ref[]) {
  int i;
  for (i = 0; i < 4; ++i) ref[i] = CONVERT_TO_SHORTPTR(ref8[i]);
}
747
// Zeroes the four SAD accumulators used by the 4D (x4d) variants.
static void init_sad(__m256i *s) {
  int i;
  for (i = 0; i < 4; ++i) s[i] = _mm256_setzero_si256();
}
754
// SAD of a 16x8 block against four candidate references at once.
void aom_highbd_sad16x8x4d_avx2(const uint8_t *src, int src_stride,
                                const uint8_t *const ref_array[],
                                int ref_stride, uint32_t *sad_array) {
  __m256i sad_vec[4];
  const uint16_t *refp[4];
  const uint16_t *src_base = CONVERT_TO_SHORTPTR(src);
  int i;

  init_sad(sad_vec);
  convert_pointers(ref_array, refp);

  // Per candidate: SAD of rows 0-3, then rows 4-7.
  for (i = 0; i < 4; ++i) {
    sad16x4(src_base, src_stride, refp[i], ref_stride, NULL, &sad_vec[i]);
    sad16x4(src_base + (src_stride << 2), src_stride,
            refp[i] + (ref_stride << 2), ref_stride, NULL, &sad_vec[i]);
  }
  get_4d_sad_from_mm256_epi32(sad_vec, sad_array);
}
777
// SAD of a 16x16 block against four candidates: two stacked 16x8 passes.
void aom_highbd_sad16x16x4d_avx2(const uint8_t *src, int src_stride,
                                 const uint8_t *const ref_array[],
                                 int ref_stride, uint32_t *sad_array) {
  uint32_t top[4], bottom[4];
  const uint8_t *ref[4];
  int i;

  for (i = 0; i < 4; ++i) ref[i] = ref_array[i];

  // Upper 16x8 half.
  aom_highbd_sad16x8x4d_avx2(src, src_stride, ref, ref_stride, top);

  // Lower 16x8 half: advance src and every reference by 8 rows.
  for (i = 0; i < 4; ++i) ref[i] += ref_stride << 3;
  aom_highbd_sad16x8x4d_avx2(src + (src_stride << 3), src_stride, ref,
                             ref_stride, bottom);

  for (i = 0; i < 4; ++i) sad_array[i] = top[i] + bottom[i];
}
803
// 16x32 high-bitdepth 4D SAD: sum of two vertically stacked 16x16 SADs.
void aom_highbd_sad16x32x4d_avx2(const uint8_t *src, int src_stride,
                                 const uint8_t *const ref_array[],
                                 int ref_stride, uint32_t *sad_array) {
  uint32_t top[4];
  uint32_t bottom[4];
  const uint8_t *ref[4];
  const int half_rows = 16;
  int k;

  for (k = 0; k < 4; ++k) ref[k] = ref_array[k];

  aom_highbd_sad16x16x4d_avx2(src, src_stride, ref, ref_stride, top);
  // Advance into the lower half of the block.
  for (k = 0; k < 4; ++k) ref[k] += half_rows * ref_stride;
  aom_highbd_sad16x16x4d_avx2(src + half_rows * src_stride, src_stride, ref,
                              ref_stride, bottom);

  for (k = 0; k < 4; ++k) sad_array[k] = top[k] + bottom[k];
}
829
// 32x16 high-bitdepth SAD against 4 reference candidates.
// The 16 rows of each candidate are accumulated in four 4-row sad32x4
// sections, then reduced into sad_array[0..3].
void aom_highbd_sad32x16x4d_avx2(const uint8_t *src, int src_stride,
                                 const uint8_t *const ref_array[],
                                 int ref_stride, uint32_t *sad_array) {
  __m256i acc[4];
  const uint16_t *ref16[4];
  const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
  const int rows_per_section = 4;
  int k, section;

  init_sad(acc);
  convert_pointers(ref_array, ref16);

  for (k = 0; k < 4; ++k) {
    const uint16_t *s = src16;
    const uint16_t *r = ref16[k];
    for (section = 0; section < 4; ++section) {
      sad32x4(s, src_stride, r, ref_stride, 0, &acc[k]);
      s += rows_per_section * src_stride;
      r += rows_per_section * ref_stride;
    }
  }
  get_4d_sad_from_mm256_epi32(acc, sad_array);
}
856
// 32x32 high-bitdepth 4D SAD: sum of two vertically stacked 32x16 SADs.
void aom_highbd_sad32x32x4d_avx2(const uint8_t *src, int src_stride,
                                 const uint8_t *const ref_array[],
                                 int ref_stride, uint32_t *sad_array) {
  uint32_t top[4];
  uint32_t bottom[4];
  const uint8_t *ref[4];
  const int half_rows = 16;
  int k;

  for (k = 0; k < 4; ++k) ref[k] = ref_array[k];

  aom_highbd_sad32x16x4d_avx2(src, src_stride, ref, ref_stride, top);
  // Advance into the lower half of the block.
  for (k = 0; k < 4; ++k) ref[k] += half_rows * ref_stride;
  aom_highbd_sad32x16x4d_avx2(src + half_rows * src_stride, src_stride, ref,
                              ref_stride, bottom);

  for (k = 0; k < 4; ++k) sad_array[k] = top[k] + bottom[k];
}
882
// 32x64 high-bitdepth 4D SAD: sum of two vertically stacked 32x32 SADs.
void aom_highbd_sad32x64x4d_avx2(const uint8_t *src, int src_stride,
                                 const uint8_t *const ref_array[],
                                 int ref_stride, uint32_t *sad_array) {
  uint32_t top[4];
  uint32_t bottom[4];
  const uint8_t *ref[4];
  const int half_rows = 32;
  int k;

  for (k = 0; k < 4; ++k) ref[k] = ref_array[k];

  aom_highbd_sad32x32x4d_avx2(src, src_stride, ref, ref_stride, top);
  // Advance into the lower half of the block.
  for (k = 0; k < 4; ++k) ref[k] += half_rows * ref_stride;
  aom_highbd_sad32x32x4d_avx2(src + half_rows * src_stride, src_stride, ref,
                              ref_stride, bottom);

  for (k = 0; k < 4; ++k) sad_array[k] = top[k] + bottom[k];
}
908
// 64x32 high-bitdepth SAD against 4 reference candidates.
// The 32 rows of each candidate are accumulated in sixteen 2-row sad64x2
// sections, then reduced into sad_array[0..3].
void aom_highbd_sad64x32x4d_avx2(const uint8_t *src, int src_stride,
                                 const uint8_t *const ref_array[],
                                 int ref_stride, uint32_t *sad_array) {
  __m256i acc[4];
  const uint16_t *ref16[4];
  const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
  const int rows_per_section = 2;
  int k, section;

  init_sad(acc);
  convert_pointers(ref_array, ref16);

  for (k = 0; k < 4; ++k) {
    const uint16_t *s = src16;
    const uint16_t *r = ref16[k];
    for (section = 0; section < 16; ++section) {
      sad64x2(s, src_stride, r, ref_stride, NULL, &acc[k]);
      s += rows_per_section * src_stride;
      r += rows_per_section * ref_stride;
    }
  }
  get_4d_sad_from_mm256_epi32(acc, sad_array);
}
935
// 64x64 high-bitdepth 4D SAD: sum of two vertically stacked 64x32 SADs.
void aom_highbd_sad64x64x4d_avx2(const uint8_t *src, int src_stride,
                                 const uint8_t *const ref_array[],
                                 int ref_stride, uint32_t *sad_array) {
  uint32_t top[4];
  uint32_t bottom[4];
  const uint8_t *ref[4];
  const int half_rows = 32;
  int k;

  for (k = 0; k < 4; ++k) ref[k] = ref_array[k];

  aom_highbd_sad64x32x4d_avx2(src, src_stride, ref, ref_stride, top);
  // Advance into the lower half of the block.
  for (k = 0; k < 4; ++k) ref[k] += half_rows * ref_stride;
  aom_highbd_sad64x32x4d_avx2(src + half_rows * src_stride, src_stride, ref,
                              ref_stride, bottom);

  for (k = 0; k < 4; ++k) sad_array[k] = top[k] + bottom[k];
}
961
// 64x128 high-bitdepth 4D SAD: sum of two vertically stacked 64x64 SADs.
void aom_highbd_sad64x128x4d_avx2(const uint8_t *src, int src_stride,
                                  const uint8_t *const ref_array[],
                                  int ref_stride, uint32_t *sad_array) {
  uint32_t top[4];
  uint32_t bottom[4];
  const uint8_t *ref[4];
  const int half_rows = 64;
  int k;

  for (k = 0; k < 4; ++k) ref[k] = ref_array[k];

  aom_highbd_sad64x64x4d_avx2(src, src_stride, ref, ref_stride, top);
  // Advance into the lower half of the block.
  for (k = 0; k < 4; ++k) ref[k] += half_rows * ref_stride;
  aom_highbd_sad64x64x4d_avx2(src + half_rows * src_stride, src_stride, ref,
                              ref_stride, bottom);

  for (k = 0; k < 4; ++k) sad_array[k] = top[k] + bottom[k];
}
987
// 128x64 high-bitdepth SAD against 4 reference candidates.
// Each candidate's 64 rows are accumulated one row at a time via sad128x1,
// then the vector accumulators are reduced into sad_array[0..3].
void aom_highbd_sad128x64x4d_avx2(const uint8_t *src, int src_stride,
                                  const uint8_t *const ref_array[],
                                  int ref_stride, uint32_t *sad_array) {
  __m256i acc[4];
  const uint16_t *ref16[4];
  const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
  int k, row;

  init_sad(acc);
  convert_pointers(ref_array, ref16);

  for (k = 0; k < 4; ++k) {
    const uint16_t *s = src16;
    const uint16_t *r = ref16[k];
    for (row = 0; row < 64; ++row) {
      sad128x1(s, r, NULL, &acc[k]);
      s += src_stride;
      r += ref_stride;
    }
  }
  get_4d_sad_from_mm256_epi32(acc, sad_array);
}
1013
// 128x128 high-bitdepth 4D SAD: sum of two vertically stacked 128x64 SADs.
void aom_highbd_sad128x128x4d_avx2(const uint8_t *src, int src_stride,
                                   const uint8_t *const ref_array[],
                                   int ref_stride, uint32_t *sad_array) {
  uint32_t top[4];
  uint32_t bottom[4];
  const uint8_t *ref[4];
  const int half_rows = 64;
  int k;

  for (k = 0; k < 4; ++k) ref[k] = ref_array[k];

  aom_highbd_sad128x64x4d_avx2(src, src_stride, ref, ref_stride, top);
  // Advance into the lower half of the block.
  for (k = 0; k < 4; ++k) ref[k] += half_rows * ref_stride;
  aom_highbd_sad128x64x4d_avx2(src + half_rows * src_stride, src_stride, ref,
                               ref_stride, bottom);

  for (k = 0; k < 4; ++k) sad_array[k] = top[k] + bottom[k];
}
1039