/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>

#include "config/aom_dsp_rtcd.h"
#include "config/aom_config.h"

#include "aom/aom_integer.h"
#include "aom_ports/mem.h"

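// Sums the 8 int16 lanes of v_16x8 into a single int, widening at each step
// so no intermediate lane can overflow. The final add is performed on the
// low 32 bits of each 64-bit half, which is safe because the true total
// always fits in 32 bits for the block sizes used here.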
static INLINE int horizontal_add_s16x8(const int16x8_t v_16x8) {
  const int32x4_t a = vpaddlq_s16(v_16x8);
  const int64x2_t b = vpaddlq_s32(a);
  const int32x2_t c = vadd_s32(vreinterpret_s32_s64(vget_low_s64(b)),
                               vreinterpret_s32_s64(vget_high_s64(b)));
  return vget_lane_s32(c, 0);
}

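// Sums the 4 int32 lanes of v_32x4 into a single int using the same
// widen-then-add reduction.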
static INLINE int horizontal_add_s32x4(const int32x4_t v_32x4) {
  const int64x2_t b = vpaddlq_s32(v_32x4);
  const int32x2_t c = vadd_s32(vreinterpret_s32_s64(vget_low_s64(b)),
                               vreinterpret_s32_s64(vget_high_s64(b)));
  return vget_lane_s32(c, 0);
}

// w * h must be less than 2048 or local variable v_sum may overflow.
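// (Each of the 8 int16 lanes of v_sum accumulates w * h / 8 differences,
// each in [-255, 255]; at w * h == 2048 a lane could reach 256 * 255 = 65280,
// which exceeds INT16_MAX.)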
static void variance_neon_w8(const uint8_t *a, int a_stride, const uint8_t *b,
                             int b_stride, int w, int h, uint32_t *sse,
                             int *sum) {
  int i, j;
  int16x8_t v_sum = vdupq_n_s16(0);
  int32x4_t v_sse_lo = vdupq_n_s32(0);
  int32x4_t v_sse_hi = vdupq_n_s32(0);

  for (i = 0; i < h; ++i) {
    for (j = 0; j < w; j += 8) {
      const uint8x8_t v_a = vld1_u8(&a[j]);
      const uint8x8_t v_b = vld1_u8(&b[j]);
      const uint16x8_t v_diff = vsubl_u8(v_a, v_b);
      const int16x8_t sv_diff = vreinterpretq_s16_u16(v_diff);
      v_sum = vaddq_s16(v_sum, sv_diff);
      v_sse_lo =
          vmlal_s16(v_sse_lo, vget_low_s16(sv_diff), vget_low_s16(sv_diff));
      v_sse_hi =
          vmlal_s16(v_sse_hi, vget_high_s16(sv_diff), vget_high_s16(sv_diff));
    }
    a += a_stride;
    b += b_stride;
  }

  *sum = horizontal_add_s16x8(v_sum);
  *sse = (unsigned int)horizontal_add_s32x4(vaddq_s32(v_sse_lo, v_sse_hi));
}

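// Block-size-specific entry points that report both the sum of differences
// and the sum of squared differences for a single block.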
void aom_get8x8var_neon(const uint8_t *a, int a_stride, const uint8_t *b,
                        int b_stride, unsigned int *sse, int *sum) {
  variance_neon_w8(a, a_stride, b, b_stride, 8, 8, sse, sum);
}

void aom_get16x16var_neon(const uint8_t *a, int a_stride, const uint8_t *b,
                          int b_stride, unsigned int *sse, int *sum) {
  variance_neon_w8(a, a_stride, b, b_stride, 16, 16, sse, sum);
}

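// Each variance kernel below returns SSE - (sum * sum) / (w * h), with the
// division implemented as a right shift by log2(w * h). The 8x8 case can
// square in 32-bit arithmetic (|sum| <= 255 * 64 = 16320, so sum * sum <=
// 266342400 < INT_MAX); the larger blocks widen to int64_t before squaring.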
unsigned int aom_variance8x8_neon(const uint8_t *a, int a_stride,
                                  const uint8_t *b, int b_stride,
                                  unsigned int *sse) {
  int sum;
  variance_neon_w8(a, a_stride, b, b_stride, 8, 8, sse, &sum);
  return *sse - ((sum * sum) >> 6);
}

unsigned int aom_variance16x16_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum;
  variance_neon_w8(a, a_stride, b, b_stride, 16, 16, sse, &sum);
  return *sse - (unsigned int)(((int64_t)sum * sum) >> 8);
}

unsigned int aom_variance32x32_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum;
  variance_neon_w8(a, a_stride, b, b_stride, 32, 32, sse, &sum);
  return *sse - (unsigned int)(((int64_t)sum * sum) >> 10);
}

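// The 32x64 and 64x32 blocks are split into two passes so that each call to
// variance_neon_w8() keeps w * h at 1024, below the 2048 overflow limit; the
// shift of 11 is log2(32 * 64) = log2(2048).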
unsigned int aom_variance32x64_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum1, sum2;
  uint32_t sse1, sse2;
  variance_neon_w8(a, a_stride, b, b_stride, 32, 32, &sse1, &sum1);
  variance_neon_w8(a + (32 * a_stride), a_stride, b + (32 * b_stride),
                   b_stride, 32, 32, &sse2, &sum2);
  *sse = sse1 + sse2;
  sum1 += sum2;
  return *sse - (unsigned int)(((int64_t)sum1 * sum1) >> 11);
}

unsigned int aom_variance64x32_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum1, sum2;
  uint32_t sse1, sse2;
  variance_neon_w8(a, a_stride, b, b_stride, 64, 16, &sse1, &sum1);
  variance_neon_w8(a + (16 * a_stride), a_stride, b + (16 * b_stride),
                   b_stride, 64, 16, &sse2, &sum2);
  *sse = sse1 + sse2;
  sum1 += sum2;
  return *sse - (unsigned int)(((int64_t)sum1 * sum1) >> 11);
}

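// 64x64 accumulates four 64x16 passes for the same reason; the shift of 12 is
// log2(64 * 64) = log2(4096).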
unsigned int aom_variance64x64_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum1, sum2;
  uint32_t sse1, sse2;

  variance_neon_w8(a, a_stride, b, b_stride, 64, 16, &sse1, &sum1);
  variance_neon_w8(a + (16 * a_stride), a_stride, b + (16 * b_stride),
                   b_stride, 64, 16, &sse2, &sum2);
  sse1 += sse2;
  sum1 += sum2;

  variance_neon_w8(a + (16 * 2 * a_stride), a_stride, b + (16 * 2 * b_stride),
                   b_stride, 64, 16, &sse2, &sum2);
  sse1 += sse2;
  sum1 += sum2;

  variance_neon_w8(a + (16 * 3 * a_stride), a_stride, b + (16 * 3 * b_stride),
                   b_stride, 64, 16, &sse2, &sum2);
  *sse = sse1 + sse2;
  sum1 += sum2;
  return *sse - (unsigned int)(((int64_t)sum1 * sum1) >> 12);
}

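// Hand-unrolled 16x8 kernel using assembly-style register naming
// (dNsM / qNuM). It processes two rows of src and ref per iteration,
// prefetches the rows that follow, and keeps the final variance arithmetic
// in NEON as well: *sse is stored from d1s64 and the return value is
// sse - ((sum * sum) >> 7), with 7 = log2(16 * 8).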
unsigned int aom_variance16x8_neon(const unsigned char *src_ptr,
                                   int source_stride,
                                   const unsigned char *ref_ptr,
                                   int recon_stride, unsigned int *sse) {
  int i;
  int16x4_t d22s16, d23s16, d24s16, d25s16, d26s16, d27s16, d28s16, d29s16;
  uint32x2_t d0u32, d10u32;
  int64x1_t d0s64, d1s64;
  uint8x16_t q0u8, q1u8, q2u8, q3u8;
  uint16x8_t q11u16, q12u16, q13u16, q14u16;
  int32x4_t q8s32, q9s32, q10s32;
  int64x2_t q0s64, q1s64, q5s64;

  q8s32 = vdupq_n_s32(0);
  q9s32 = vdupq_n_s32(0);
  q10s32 = vdupq_n_s32(0);

  for (i = 0; i < 4; i++) {
    q0u8 = vld1q_u8(src_ptr);
    src_ptr += source_stride;
    q1u8 = vld1q_u8(src_ptr);
    src_ptr += source_stride;
    __builtin_prefetch(src_ptr);

    q2u8 = vld1q_u8(ref_ptr);
    ref_ptr += recon_stride;
    q3u8 = vld1q_u8(ref_ptr);
    ref_ptr += recon_stride;
    __builtin_prefetch(ref_ptr);

    q11u16 = vsubl_u8(vget_low_u8(q0u8), vget_low_u8(q2u8));
    q12u16 = vsubl_u8(vget_high_u8(q0u8), vget_high_u8(q2u8));
    q13u16 = vsubl_u8(vget_low_u8(q1u8), vget_low_u8(q3u8));
    q14u16 = vsubl_u8(vget_high_u8(q1u8), vget_high_u8(q3u8));

    d22s16 = vreinterpret_s16_u16(vget_low_u16(q11u16));
    d23s16 = vreinterpret_s16_u16(vget_high_u16(q11u16));
    q8s32 = vpadalq_s16(q8s32, vreinterpretq_s16_u16(q11u16));
    q9s32 = vmlal_s16(q9s32, d22s16, d22s16);
    q10s32 = vmlal_s16(q10s32, d23s16, d23s16);

    d24s16 = vreinterpret_s16_u16(vget_low_u16(q12u16));
    d25s16 = vreinterpret_s16_u16(vget_high_u16(q12u16));
    q8s32 = vpadalq_s16(q8s32, vreinterpretq_s16_u16(q12u16));
    q9s32 = vmlal_s16(q9s32, d24s16, d24s16);
    q10s32 = vmlal_s16(q10s32, d25s16, d25s16);

    d26s16 = vreinterpret_s16_u16(vget_low_u16(q13u16));
    d27s16 = vreinterpret_s16_u16(vget_high_u16(q13u16));
    q8s32 = vpadalq_s16(q8s32, vreinterpretq_s16_u16(q13u16));
    q9s32 = vmlal_s16(q9s32, d26s16, d26s16);
    q10s32 = vmlal_s16(q10s32, d27s16, d27s16);

    d28s16 = vreinterpret_s16_u16(vget_low_u16(q14u16));
    d29s16 = vreinterpret_s16_u16(vget_high_u16(q14u16));
    q8s32 = vpadalq_s16(q8s32, vreinterpretq_s16_u16(q14u16));
    q9s32 = vmlal_s16(q9s32, d28s16, d28s16);
    q10s32 = vmlal_s16(q10s32, d29s16, d29s16);
  }

  q10s32 = vaddq_s32(q10s32, q9s32);
  q0s64 = vpaddlq_s32(q8s32);
  q1s64 = vpaddlq_s32(q10s32);

  d0s64 = vadd_s64(vget_low_s64(q0s64), vget_high_s64(q0s64));
  d1s64 = vadd_s64(vget_low_s64(q1s64), vget_high_s64(q1s64));

  q5s64 = vmull_s32(vreinterpret_s32_s64(d0s64), vreinterpret_s32_s64(d0s64));
  vst1_lane_u32((uint32_t *)sse, vreinterpret_u32_s64(d1s64), 0);

  d10u32 = vshr_n_u32(vreinterpret_u32_s64(vget_low_s64(q5s64)), 7);
  d0u32 = vsub_u32(vreinterpret_u32_s64(d1s64), d10u32);

  return vget_lane_u32(d0u32, 0);
}

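// 8x16 variant of the same kernel: one uint8x8_t load per row, two rows per
// iteration, eight iterations; the shift of 7 is log2(8 * 16).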
unsigned int aom_variance8x16_neon(const unsigned char *src_ptr,
                                   int source_stride,
                                   const unsigned char *ref_ptr,
                                   int recon_stride, unsigned int *sse) {
  int i;
  uint8x8_t d0u8, d2u8, d4u8, d6u8;
  int16x4_t d22s16, d23s16, d24s16, d25s16;
  uint32x2_t d0u32, d10u32;
  int64x1_t d0s64, d1s64;
  uint16x8_t q11u16, q12u16;
  int32x4_t q8s32, q9s32, q10s32;
  int64x2_t q0s64, q1s64, q5s64;

  q8s32 = vdupq_n_s32(0);
  q9s32 = vdupq_n_s32(0);
  q10s32 = vdupq_n_s32(0);

  for (i = 0; i < 8; i++) {
    d0u8 = vld1_u8(src_ptr);
    src_ptr += source_stride;
    d2u8 = vld1_u8(src_ptr);
    src_ptr += source_stride;
    __builtin_prefetch(src_ptr);

    d4u8 = vld1_u8(ref_ptr);
    ref_ptr += recon_stride;
    d6u8 = vld1_u8(ref_ptr);
    ref_ptr += recon_stride;
    __builtin_prefetch(ref_ptr);

    q11u16 = vsubl_u8(d0u8, d4u8);
    q12u16 = vsubl_u8(d2u8, d6u8);

    d22s16 = vreinterpret_s16_u16(vget_low_u16(q11u16));
    d23s16 = vreinterpret_s16_u16(vget_high_u16(q11u16));
    q8s32 = vpadalq_s16(q8s32, vreinterpretq_s16_u16(q11u16));
    q9s32 = vmlal_s16(q9s32, d22s16, d22s16);
    q10s32 = vmlal_s16(q10s32, d23s16, d23s16);

    d24s16 = vreinterpret_s16_u16(vget_low_u16(q12u16));
    d25s16 = vreinterpret_s16_u16(vget_high_u16(q12u16));
    q8s32 = vpadalq_s16(q8s32, vreinterpretq_s16_u16(q12u16));
    q9s32 = vmlal_s16(q9s32, d24s16, d24s16);
    q10s32 = vmlal_s16(q10s32, d25s16, d25s16);
  }

  q10s32 = vaddq_s32(q10s32, q9s32);
  q0s64 = vpaddlq_s32(q8s32);
  q1s64 = vpaddlq_s32(q10s32);

  d0s64 = vadd_s64(vget_low_s64(q0s64), vget_high_s64(q0s64));
  d1s64 = vadd_s64(vget_low_s64(q1s64), vget_high_s64(q1s64));

  q5s64 = vmull_s32(vreinterpret_s32_s64(d0s64), vreinterpret_s32_s64(d0s64));
  vst1_lane_u32((uint32_t *)sse, vreinterpret_u32_s64(d1s64), 0);

  d10u32 = vshr_n_u32(vreinterpret_u32_s64(vget_low_s64(q5s64)), 7);
  d0u32 = vsub_u32(vreinterpret_u32_s64(d1s64), d10u32);

  return vget_lane_u32(d0u32, 0);
}

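// 16x16 MSE: like the variance kernels above but without the
// sum-of-differences accumulator, so the SSE stored to *sse is also the
// return value.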
unsigned int aom_mse16x16_neon(const unsigned char *src_ptr, int source_stride,
                               const unsigned char *ref_ptr, int recon_stride,
                               unsigned int *sse) {
  int i;
  int16x4_t d22s16, d23s16, d24s16, d25s16, d26s16, d27s16, d28s16, d29s16;
  int64x1_t d0s64;
  uint8x16_t q0u8, q1u8, q2u8, q3u8;
  int32x4_t q7s32, q8s32, q9s32, q10s32;
  uint16x8_t q11u16, q12u16, q13u16, q14u16;
  int64x2_t q1s64;

  q7s32 = vdupq_n_s32(0);
  q8s32 = vdupq_n_s32(0);
  q9s32 = vdupq_n_s32(0);
  q10s32 = vdupq_n_s32(0);

  for (i = 0; i < 8; i++) {  // mse16x16_neon_loop
    q0u8 = vld1q_u8(src_ptr);
    src_ptr += source_stride;
    q1u8 = vld1q_u8(src_ptr);
    src_ptr += source_stride;
    q2u8 = vld1q_u8(ref_ptr);
    ref_ptr += recon_stride;
    q3u8 = vld1q_u8(ref_ptr);
    ref_ptr += recon_stride;

    q11u16 = vsubl_u8(vget_low_u8(q0u8), vget_low_u8(q2u8));
    q12u16 = vsubl_u8(vget_high_u8(q0u8), vget_high_u8(q2u8));
    q13u16 = vsubl_u8(vget_low_u8(q1u8), vget_low_u8(q3u8));
    q14u16 = vsubl_u8(vget_high_u8(q1u8), vget_high_u8(q3u8));

    d22s16 = vreinterpret_s16_u16(vget_low_u16(q11u16));
    d23s16 = vreinterpret_s16_u16(vget_high_u16(q11u16));
    q7s32 = vmlal_s16(q7s32, d22s16, d22s16);
    q8s32 = vmlal_s16(q8s32, d23s16, d23s16);

    d24s16 = vreinterpret_s16_u16(vget_low_u16(q12u16));
    d25s16 = vreinterpret_s16_u16(vget_high_u16(q12u16));
    q9s32 = vmlal_s16(q9s32, d24s16, d24s16);
    q10s32 = vmlal_s16(q10s32, d25s16, d25s16);

    d26s16 = vreinterpret_s16_u16(vget_low_u16(q13u16));
    d27s16 = vreinterpret_s16_u16(vget_high_u16(q13u16));
    q7s32 = vmlal_s16(q7s32, d26s16, d26s16);
    q8s32 = vmlal_s16(q8s32, d27s16, d27s16);

    d28s16 = vreinterpret_s16_u16(vget_low_u16(q14u16));
    d29s16 = vreinterpret_s16_u16(vget_high_u16(q14u16));
    q9s32 = vmlal_s16(q9s32, d28s16, d28s16);
    q10s32 = vmlal_s16(q10s32, d29s16, d29s16);
  }

  q7s32 = vaddq_s32(q7s32, q8s32);
  q9s32 = vaddq_s32(q9s32, q10s32);
  q10s32 = vaddq_s32(q7s32, q9s32);

  q1s64 = vpaddlq_s32(q10s32);
  d0s64 = vadd_s64(vget_low_s64(q1s64), vget_high_s64(q1s64));

  vst1_lane_u32((uint32_t *)sse, vreinterpret_u32_s64(d0s64), 0);
  return vget_lane_u32(vreinterpret_u32_s64(d0s64), 0);
}

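// SSE (no variance) of a 4x4 block. Each vld1_u8() reads a full eight pixels
// per row, but only the low four 16-bit differences of each row are squared
// and accumulated.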
unsigned int aom_get4x4sse_cs_neon(const unsigned char *src_ptr,
                                   int source_stride,
                                   const unsigned char *ref_ptr,
                                   int recon_stride) {
  int16x4_t d22s16, d24s16, d26s16, d28s16;
  int64x1_t d0s64;
  uint8x8_t d0u8, d1u8, d2u8, d3u8, d4u8, d5u8, d6u8, d7u8;
  int32x4_t q7s32, q8s32, q9s32, q10s32;
  uint16x8_t q11u16, q12u16, q13u16, q14u16;
  int64x2_t q1s64;

  d0u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d4u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;
  d1u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d5u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;
  d2u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d6u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;
  d3u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d7u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;

  q11u16 = vsubl_u8(d0u8, d4u8);
  q12u16 = vsubl_u8(d1u8, d5u8);
  q13u16 = vsubl_u8(d2u8, d6u8);
  q14u16 = vsubl_u8(d3u8, d7u8);

  d22s16 = vget_low_s16(vreinterpretq_s16_u16(q11u16));
  d24s16 = vget_low_s16(vreinterpretq_s16_u16(q12u16));
  d26s16 = vget_low_s16(vreinterpretq_s16_u16(q13u16));
  d28s16 = vget_low_s16(vreinterpretq_s16_u16(q14u16));

  q7s32 = vmull_s16(d22s16, d22s16);
  q8s32 = vmull_s16(d24s16, d24s16);
  q9s32 = vmull_s16(d26s16, d26s16);
  q10s32 = vmull_s16(d28s16, d28s16);

  q7s32 = vaddq_s32(q7s32, q8s32);
  q9s32 = vaddq_s32(q9s32, q10s32);
  q9s32 = vaddq_s32(q7s32, q9s32);

  q1s64 = vpaddlq_s32(q9s32);
  d0s64 = vadd_s64(vget_low_s64(q1s64), vget_high_s64(q1s64));

  return vget_lane_u32(vreinterpret_u32_s64(d0s64), 0);
}