/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>

#include "config/aom_dsp_rtcd.h"
#include "config/aom_config.h"

#include "aom/aom_integer.h"
#include "aom_ports/mem.h"
// Sums all eight lanes of a signed 16-bit vector into a single int.
static INLINE int horizontal_add_s16x8(const int16x8_t v_16x8) {
  const int32x4_t a = vpaddlq_s16(v_16x8);
  const int64x2_t b = vpaddlq_s32(a);
  const int32x2_t c = vadd_s32(vreinterpret_s32_s64(vget_low_s64(b)),
                               vreinterpret_s32_s64(vget_high_s64(b)));
  return vget_lane_s32(c, 0);
}

// Sums all four lanes of a signed 32-bit vector into a single int.
static INLINE int horizontal_add_s32x4(const int32x4_t v_32x4) {
  const int64x2_t b = vpaddlq_s32(v_32x4);
  const int32x2_t c = vadd_s32(vreinterpret_s32_s64(vget_low_s64(b)),
                               vreinterpret_s32_s64(vget_high_s64(b)));
  return vget_lane_s32(c, 0);
}

// w * h must be less than 2048 or local variable v_sum may overflow.
static void variance_neon_w8(const uint8_t *a, int a_stride, const uint8_t *b,
                             int b_stride, int w, int h, uint32_t *sse,
                             int *sum) {
  int i, j;
  int16x8_t v_sum = vdupq_n_s16(0);
  int32x4_t v_sse_lo = vdupq_n_s32(0);
  int32x4_t v_sse_hi = vdupq_n_s32(0);

  for (i = 0; i < h; ++i) {
    for (j = 0; j < w; j += 8) {
      const uint8x8_t v_a = vld1_u8(&a[j]);
      const uint8x8_t v_b = vld1_u8(&b[j]);
      const uint16x8_t v_diff = vsubl_u8(v_a, v_b);
      const int16x8_t sv_diff = vreinterpretq_s16_u16(v_diff);
      v_sum = vaddq_s16(v_sum, sv_diff);
      v_sse_lo =
          vmlal_s16(v_sse_lo, vget_low_s16(sv_diff), vget_low_s16(sv_diff));
      v_sse_hi =
          vmlal_s16(v_sse_hi, vget_high_s16(sv_diff), vget_high_s16(sv_diff));
    }
    a += a_stride;
    b += b_stride;
  }

  *sum = horizontal_add_s16x8(v_sum);
  *sse = (unsigned int)horizontal_add_s32x4(vaddq_s32(v_sse_lo, v_sse_hi));
}
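
// For reference, a minimal scalar sketch of the sums computed by
// variance_neon_w8() above. This helper is illustrative only; it is not part
// of the original file and is not called anywhere in it.
static INLINE void variance_scalar_ref(const uint8_t *a, int a_stride,
                                       const uint8_t *b, int b_stride, int w,
                                       int h, uint32_t *sse, int *sum) {
  int i, j;
  *sse = 0;
  *sum = 0;
  for (i = 0; i < h; ++i) {
    for (j = 0; j < w; ++j) {
      const int diff = a[j] - b[j];     // Difference of corresponding pixels.
      *sum += diff;                     // Signed sum of differences.
      *sse += (uint32_t)(diff * diff);  // Sum of squared differences.
    }
    a += a_stride;
    b += b_stride;
  }
}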

void aom_get8x8var_neon(const uint8_t *a, int a_stride, const uint8_t *b,
                        int b_stride, unsigned int *sse, int *sum) {
  variance_neon_w8(a, a_stride, b, b_stride, 8, 8, sse, sum);
}

void aom_get16x16var_neon(const uint8_t *a, int a_stride, const uint8_t *b,
                          int b_stride, unsigned int *sse, int *sum) {
  variance_neon_w8(a, a_stride, b, b_stride, 16, 16, sse, sum);
}

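// Each aom_varianceWxH_neon() below returns sse - sum^2 / (W * H); the final
// right shift is log2(W * H), e.g. >> 6 for 8x8 and >> 8 for 16x16.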
unsigned int aom_variance8x8_neon(const uint8_t *a, int a_stride,
                                  const uint8_t *b, int b_stride,
                                  unsigned int *sse) {
  int sum;
  variance_neon_w8(a, a_stride, b, b_stride, 8, 8, sse, &sum);
  return *sse - ((sum * sum) >> 6);
}

unsigned int aom_variance16x16_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum;
  variance_neon_w8(a, a_stride, b, b_stride, 16, 16, sse, &sum);
  return *sse - (((unsigned int)((int64_t)sum * sum)) >> 8);
}

unsigned int aom_variance32x32_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum;
  variance_neon_w8(a, a_stride, b, b_stride, 32, 32, sse, &sum);
  return *sse - (unsigned int)(((int64_t)sum * sum) >> 10);
}

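// Larger blocks are split into 32x32 or 64x16 slices so that each call to
// variance_neon_w8() stays within the w * h overflow limit noted above.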
unsigned int aom_variance32x64_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum1, sum2;
  uint32_t sse1, sse2;
  variance_neon_w8(a, a_stride, b, b_stride, 32, 32, &sse1, &sum1);
  variance_neon_w8(a + (32 * a_stride), a_stride, b + (32 * b_stride), b_stride,
                   32, 32, &sse2, &sum2);
  *sse = sse1 + sse2;
  sum1 += sum2;
  return *sse - (unsigned int)(((int64_t)sum1 * sum1) >> 11);
}

unsigned int aom_variance64x32_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum1, sum2;
  uint32_t sse1, sse2;
  variance_neon_w8(a, a_stride, b, b_stride, 64, 16, &sse1, &sum1);
  variance_neon_w8(a + (16 * a_stride), a_stride, b + (16 * b_stride), b_stride,
                   64, 16, &sse2, &sum2);
  *sse = sse1 + sse2;
  sum1 += sum2;
  return *sse - (unsigned int)(((int64_t)sum1 * sum1) >> 11);
}

unsigned int aom_variance64x64_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum1, sum2;
  uint32_t sse1, sse2;

  variance_neon_w8(a, a_stride, b, b_stride, 64, 16, &sse1, &sum1);
  variance_neon_w8(a + (16 * a_stride), a_stride, b + (16 * b_stride), b_stride,
                   64, 16, &sse2, &sum2);
  sse1 += sse2;
  sum1 += sum2;

  variance_neon_w8(a + (16 * 2 * a_stride), a_stride, b + (16 * 2 * b_stride),
                   b_stride, 64, 16, &sse2, &sum2);
  sse1 += sse2;
  sum1 += sum2;

  variance_neon_w8(a + (16 * 3 * a_stride), a_stride, b + (16 * 3 * b_stride),
                   b_stride, 64, 16, &sse2, &sum2);
  *sse = sse1 + sse2;
  sum1 += sum2;
  return *sse - (unsigned int)(((int64_t)sum1 * sum1) >> 12);
}

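// 16x8 variance with hand-interleaved NEON; each loop iteration processes
// two 16-pixel rows of the source and reference blocks.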
unsigned int aom_variance16x8_neon(const unsigned char *src_ptr,
                                   int source_stride,
                                   const unsigned char *ref_ptr,
                                   int recon_stride, unsigned int *sse) {
  int i;
  int16x4_t d22s16, d23s16, d24s16, d25s16, d26s16, d27s16, d28s16, d29s16;
  uint32x2_t d0u32, d10u32;
  int64x1_t d0s64, d1s64;
  uint8x16_t q0u8, q1u8, q2u8, q3u8;
  uint16x8_t q11u16, q12u16, q13u16, q14u16;
  int32x4_t q8s32, q9s32, q10s32;
  int64x2_t q0s64, q1s64, q5s64;

  q8s32 = vdupq_n_s32(0);
  q9s32 = vdupq_n_s32(0);
  q10s32 = vdupq_n_s32(0);

  for (i = 0; i < 4; i++) {
    q0u8 = vld1q_u8(src_ptr);
    src_ptr += source_stride;
    q1u8 = vld1q_u8(src_ptr);
    src_ptr += source_stride;
    __builtin_prefetch(src_ptr);

    q2u8 = vld1q_u8(ref_ptr);
    ref_ptr += recon_stride;
    q3u8 = vld1q_u8(ref_ptr);
    ref_ptr += recon_stride;
    __builtin_prefetch(ref_ptr);

    q11u16 = vsubl_u8(vget_low_u8(q0u8), vget_low_u8(q2u8));
    q12u16 = vsubl_u8(vget_high_u8(q0u8), vget_high_u8(q2u8));
    q13u16 = vsubl_u8(vget_low_u8(q1u8), vget_low_u8(q3u8));
    q14u16 = vsubl_u8(vget_high_u8(q1u8), vget_high_u8(q3u8));

    d22s16 = vreinterpret_s16_u16(vget_low_u16(q11u16));
    d23s16 = vreinterpret_s16_u16(vget_high_u16(q11u16));
    q8s32 = vpadalq_s16(q8s32, vreinterpretq_s16_u16(q11u16));
    q9s32 = vmlal_s16(q9s32, d22s16, d22s16);
    q10s32 = vmlal_s16(q10s32, d23s16, d23s16);

    d24s16 = vreinterpret_s16_u16(vget_low_u16(q12u16));
    d25s16 = vreinterpret_s16_u16(vget_high_u16(q12u16));
    q8s32 = vpadalq_s16(q8s32, vreinterpretq_s16_u16(q12u16));
    q9s32 = vmlal_s16(q9s32, d24s16, d24s16);
    q10s32 = vmlal_s16(q10s32, d25s16, d25s16);

    d26s16 = vreinterpret_s16_u16(vget_low_u16(q13u16));
    d27s16 = vreinterpret_s16_u16(vget_high_u16(q13u16));
    q8s32 = vpadalq_s16(q8s32, vreinterpretq_s16_u16(q13u16));
    q9s32 = vmlal_s16(q9s32, d26s16, d26s16);
    q10s32 = vmlal_s16(q10s32, d27s16, d27s16);

    d28s16 = vreinterpret_s16_u16(vget_low_u16(q14u16));
    d29s16 = vreinterpret_s16_u16(vget_high_u16(q14u16));
    q8s32 = vpadalq_s16(q8s32, vreinterpretq_s16_u16(q14u16));
    q9s32 = vmlal_s16(q9s32, d28s16, d28s16);
    q10s32 = vmlal_s16(q10s32, d29s16, d29s16);
  }

  q10s32 = vaddq_s32(q10s32, q9s32);
  q0s64 = vpaddlq_s32(q8s32);
  q1s64 = vpaddlq_s32(q10s32);

  d0s64 = vadd_s64(vget_low_s64(q0s64), vget_high_s64(q0s64));
  d1s64 = vadd_s64(vget_low_s64(q1s64), vget_high_s64(q1s64));

  q5s64 = vmull_s32(vreinterpret_s32_s64(d0s64), vreinterpret_s32_s64(d0s64));
  vst1_lane_u32((uint32_t *)sse, vreinterpret_u32_s64(d1s64), 0);

  d10u32 = vshr_n_u32(vreinterpret_u32_s64(vget_low_s64(q5s64)), 7);
  d0u32 = vsub_u32(vreinterpret_u32_s64(d1s64), d10u32);

  return vget_lane_u32(d0u32, 0);
}

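// 8x16 variance; same structure as aom_variance16x8_neon above, but with
// 8-wide rows, so each loop iteration loads two 8-pixel rows from each buffer.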
unsigned int aom_variance8x16_neon(const unsigned char *src_ptr,
                                   int source_stride,
                                   const unsigned char *ref_ptr,
                                   int recon_stride, unsigned int *sse) {
  int i;
  uint8x8_t d0u8, d2u8, d4u8, d6u8;
  int16x4_t d22s16, d23s16, d24s16, d25s16;
  uint32x2_t d0u32, d10u32;
  int64x1_t d0s64, d1s64;
  uint16x8_t q11u16, q12u16;
  int32x4_t q8s32, q9s32, q10s32;
  int64x2_t q0s64, q1s64, q5s64;

  q8s32 = vdupq_n_s32(0);
  q9s32 = vdupq_n_s32(0);
  q10s32 = vdupq_n_s32(0);

  for (i = 0; i < 8; i++) {
    d0u8 = vld1_u8(src_ptr);
    src_ptr += source_stride;
    d2u8 = vld1_u8(src_ptr);
    src_ptr += source_stride;
    __builtin_prefetch(src_ptr);

    d4u8 = vld1_u8(ref_ptr);
    ref_ptr += recon_stride;
    d6u8 = vld1_u8(ref_ptr);
    ref_ptr += recon_stride;
    __builtin_prefetch(ref_ptr);

    q11u16 = vsubl_u8(d0u8, d4u8);
    q12u16 = vsubl_u8(d2u8, d6u8);

    d22s16 = vreinterpret_s16_u16(vget_low_u16(q11u16));
    d23s16 = vreinterpret_s16_u16(vget_high_u16(q11u16));
    q8s32 = vpadalq_s16(q8s32, vreinterpretq_s16_u16(q11u16));
    q9s32 = vmlal_s16(q9s32, d22s16, d22s16);
    q10s32 = vmlal_s16(q10s32, d23s16, d23s16);

    d24s16 = vreinterpret_s16_u16(vget_low_u16(q12u16));
    d25s16 = vreinterpret_s16_u16(vget_high_u16(q12u16));
    q8s32 = vpadalq_s16(q8s32, vreinterpretq_s16_u16(q12u16));
    q9s32 = vmlal_s16(q9s32, d24s16, d24s16);
    q10s32 = vmlal_s16(q10s32, d25s16, d25s16);
  }

  q10s32 = vaddq_s32(q10s32, q9s32);
  q0s64 = vpaddlq_s32(q8s32);
  q1s64 = vpaddlq_s32(q10s32);

  d0s64 = vadd_s64(vget_low_s64(q0s64), vget_high_s64(q0s64));
  d1s64 = vadd_s64(vget_low_s64(q1s64), vget_high_s64(q1s64));

  q5s64 = vmull_s32(vreinterpret_s32_s64(d0s64), vreinterpret_s32_s64(d0s64));
  vst1_lane_u32((uint32_t *)sse, vreinterpret_u32_s64(d1s64), 0);

  d10u32 = vshr_n_u32(vreinterpret_u32_s64(vget_low_s64(q5s64)), 7);
  d0u32 = vsub_u32(vreinterpret_u32_s64(d1s64), d10u32);

  return vget_lane_u32(d0u32, 0);
}

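// 16x16 MSE: accumulates only the sum of squared differences, which is both
// stored to *sse and returned; the sum of differences is not needed here.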
unsigned int aom_mse16x16_neon(const unsigned char *src_ptr, int source_stride,
                               const unsigned char *ref_ptr, int recon_stride,
                               unsigned int *sse) {
  int i;
  int16x4_t d22s16, d23s16, d24s16, d25s16, d26s16, d27s16, d28s16, d29s16;
  int64x1_t d0s64;
  uint8x16_t q0u8, q1u8, q2u8, q3u8;
  int32x4_t q7s32, q8s32, q9s32, q10s32;
  uint16x8_t q11u16, q12u16, q13u16, q14u16;
  int64x2_t q1s64;

  q7s32 = vdupq_n_s32(0);
  q8s32 = vdupq_n_s32(0);
  q9s32 = vdupq_n_s32(0);
  q10s32 = vdupq_n_s32(0);

  for (i = 0; i < 8; i++) {  // mse16x16_neon_loop
    q0u8 = vld1q_u8(src_ptr);
    src_ptr += source_stride;
    q1u8 = vld1q_u8(src_ptr);
    src_ptr += source_stride;
    q2u8 = vld1q_u8(ref_ptr);
    ref_ptr += recon_stride;
    q3u8 = vld1q_u8(ref_ptr);
    ref_ptr += recon_stride;

    q11u16 = vsubl_u8(vget_low_u8(q0u8), vget_low_u8(q2u8));
    q12u16 = vsubl_u8(vget_high_u8(q0u8), vget_high_u8(q2u8));
    q13u16 = vsubl_u8(vget_low_u8(q1u8), vget_low_u8(q3u8));
    q14u16 = vsubl_u8(vget_high_u8(q1u8), vget_high_u8(q3u8));

    d22s16 = vreinterpret_s16_u16(vget_low_u16(q11u16));
    d23s16 = vreinterpret_s16_u16(vget_high_u16(q11u16));
    q7s32 = vmlal_s16(q7s32, d22s16, d22s16);
    q8s32 = vmlal_s16(q8s32, d23s16, d23s16);

    d24s16 = vreinterpret_s16_u16(vget_low_u16(q12u16));
    d25s16 = vreinterpret_s16_u16(vget_high_u16(q12u16));
    q9s32 = vmlal_s16(q9s32, d24s16, d24s16);
    q10s32 = vmlal_s16(q10s32, d25s16, d25s16);

    d26s16 = vreinterpret_s16_u16(vget_low_u16(q13u16));
    d27s16 = vreinterpret_s16_u16(vget_high_u16(q13u16));
    q7s32 = vmlal_s16(q7s32, d26s16, d26s16);
    q8s32 = vmlal_s16(q8s32, d27s16, d27s16);

    d28s16 = vreinterpret_s16_u16(vget_low_u16(q14u16));
    d29s16 = vreinterpret_s16_u16(vget_high_u16(q14u16));
    q9s32 = vmlal_s16(q9s32, d28s16, d28s16);
    q10s32 = vmlal_s16(q10s32, d29s16, d29s16);
  }

  q7s32 = vaddq_s32(q7s32, q8s32);
  q9s32 = vaddq_s32(q9s32, q10s32);
  q10s32 = vaddq_s32(q7s32, q9s32);

  q1s64 = vpaddlq_s32(q10s32);
  d0s64 = vadd_s64(vget_low_s64(q1s64), vget_high_s64(q1s64));

  vst1_lane_u32((uint32_t *)sse, vreinterpret_u32_s64(d0s64), 0);
  return vget_lane_u32(vreinterpret_u32_s64(d0s64), 0);
}

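// Sum of squared differences for a 4x4 block. Each vld1_u8() fetches eight
// bytes, but only the low four lanes of each difference row are squared.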
unsigned int aom_get4x4sse_cs_neon(const unsigned char *src_ptr,
                                   int source_stride,
                                   const unsigned char *ref_ptr,
                                   int recon_stride) {
  int16x4_t d22s16, d24s16, d26s16, d28s16;
  int64x1_t d0s64;
  uint8x8_t d0u8, d1u8, d2u8, d3u8, d4u8, d5u8, d6u8, d7u8;
  int32x4_t q7s32, q8s32, q9s32, q10s32;
  uint16x8_t q11u16, q12u16, q13u16, q14u16;
  int64x2_t q1s64;

  d0u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d4u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;
  d1u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d5u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;
  d2u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d6u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;
  d3u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d7u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;

  q11u16 = vsubl_u8(d0u8, d4u8);
  q12u16 = vsubl_u8(d1u8, d5u8);
  q13u16 = vsubl_u8(d2u8, d6u8);
  q14u16 = vsubl_u8(d3u8, d7u8);

  d22s16 = vget_low_s16(vreinterpretq_s16_u16(q11u16));
  d24s16 = vget_low_s16(vreinterpretq_s16_u16(q12u16));
  d26s16 = vget_low_s16(vreinterpretq_s16_u16(q13u16));
  d28s16 = vget_low_s16(vreinterpretq_s16_u16(q14u16));

  q7s32 = vmull_s16(d22s16, d22s16);
  q8s32 = vmull_s16(d24s16, d24s16);
  q9s32 = vmull_s16(d26s16, d26s16);
  q10s32 = vmull_s16(d28s16, d28s16);

  q7s32 = vaddq_s32(q7s32, q8s32);
  q9s32 = vaddq_s32(q9s32, q10s32);
  q9s32 = vaddq_s32(q7s32, q9s32);

  q1s64 = vpaddlq_s32(q9s32);
  d0s64 = vadd_s64(vget_low_s64(q1s64), vget_high_s64(q1s64));

  return vget_lane_u32(vreinterpret_u32_s64(d0s64), 0);
}