/*
 *  Copyright (c) 2014 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include <arm_neon.h>

#include "./vpx_dsp_rtcd.h"
#include "./vpx_config.h"

#include "vpx/vpx_integer.h"
#include "vpx_ports/mem.h"

static INLINE int horizontal_add_s16x8(const int16x8_t v_16x8) {
  const int32x4_t a = vpaddlq_s16(v_16x8);
  const int64x2_t b = vpaddlq_s32(a);
  const int32x2_t c = vadd_s32(vreinterpret_s32_s64(vget_low_s64(b)),
                               vreinterpret_s32_s64(vget_high_s64(b)));
  return vget_lane_s32(c, 0);
}

static INLINE int horizontal_add_s32x4(const int32x4_t v_32x4) {
  const int64x2_t b = vpaddlq_s32(v_32x4);
  const int32x2_t c = vadd_s32(vreinterpret_s32_s64(vget_low_s64(b)),
                               vreinterpret_s32_s64(vget_high_s64(b)));
  return vget_lane_s32(c, 0);
}

// w * h must be less than 2048 or local variable v_sum may overflow.
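// Computes the sum of pixel differences (sum) and the sum of squared
// differences (sse) between two w x h blocks, eight pixels at a time.
// w must be a multiple of 8.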
static void variance_neon_w8(const uint8_t *a, int a_stride, const uint8_t *b,
                             int b_stride, int w, int h, uint32_t *sse,
                             int *sum) {
  int i, j;
  int16x8_t v_sum = vdupq_n_s16(0);
  int32x4_t v_sse_lo = vdupq_n_s32(0);
  int32x4_t v_sse_hi = vdupq_n_s32(0);

  for (i = 0; i < h; ++i) {
    for (j = 0; j < w; j += 8) {
      const uint8x8_t v_a = vld1_u8(&a[j]);
      const uint8x8_t v_b = vld1_u8(&b[j]);
      const uint16x8_t v_diff = vsubl_u8(v_a, v_b);
      const int16x8_t sv_diff = vreinterpretq_s16_u16(v_diff);
      v_sum = vaddq_s16(v_sum, sv_diff);
      v_sse_lo =
          vmlal_s16(v_sse_lo, vget_low_s16(sv_diff), vget_low_s16(sv_diff));
      v_sse_hi =
          vmlal_s16(v_sse_hi, vget_high_s16(sv_diff), vget_high_s16(sv_diff));
    }
    a += a_stride;
    b += b_stride;
  }

  *sum = horizontal_add_s16x8(v_sum);
  *sse = (unsigned int)horizontal_add_s32x4(vaddq_s32(v_sse_lo, v_sse_hi));
}

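// Thin wrappers exposing the sum/sse helper for fixed block sizes.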
void vpx_get8x8var_neon(const uint8_t *a, int a_stride, const uint8_t *b,
                        int b_stride, unsigned int *sse, int *sum) {
  variance_neon_w8(a, a_stride, b, b_stride, 8, 8, sse, sum);
}

void vpx_get16x16var_neon(const uint8_t *a, int a_stride, const uint8_t *b,
                          int b_stride, unsigned int *sse, int *sum) {
  variance_neon_w8(a, a_stride, b, b_stride, 16, 16, sse, sum);
}

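// variance = sse - sum * sum / (w * h); the right shift performs the division
// (>> 6 for the 64 pixels of an 8x8 block, >> 8 for 16x16, >> 10 for 32x32,
// and so on).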
unsigned int vpx_variance8x8_neon(const uint8_t *a, int a_stride,
                                  const uint8_t *b, int b_stride,
                                  unsigned int *sse) {
  int sum;
  variance_neon_w8(a, a_stride, b, b_stride, 8, 8, sse, &sum);
  return *sse - ((sum * sum) >> 6);
}

unsigned int vpx_variance16x16_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum;
  variance_neon_w8(a, a_stride, b, b_stride, 16, 16, sse, &sum);
  return *sse - (((uint32_t)((int64_t)sum * sum)) >> 8);
}

unsigned int vpx_variance32x32_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum;
  variance_neon_w8(a, a_stride, b, b_stride, 32, 32, sse, &sum);
  return *sse - (unsigned int)(((int64_t)sum * sum) >> 10);
}

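// The 32x64 and 64x32 blocks are each split into two 1024-pixel halves so
// that every call to variance_neon_w8() stays within the overflow bound noted
// above.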
unsigned int vpx_variance32x64_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum1, sum2;
  uint32_t sse1, sse2;
  variance_neon_w8(a, a_stride, b, b_stride, 32, 32, &sse1, &sum1);
  variance_neon_w8(a + (32 * a_stride), a_stride, b + (32 * b_stride), b_stride,
                   32, 32, &sse2, &sum2);
  *sse = sse1 + sse2;
  sum1 += sum2;
  return *sse - (unsigned int)(((int64_t)sum1 * sum1) >> 11);
}

unsigned int vpx_variance64x32_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum1, sum2;
  uint32_t sse1, sse2;
  variance_neon_w8(a, a_stride, b, b_stride, 64, 16, &sse1, &sum1);
  variance_neon_w8(a + (16 * a_stride), a_stride, b + (16 * b_stride), b_stride,
                   64, 16, &sse2, &sum2);
  *sse = sse1 + sse2;
  sum1 += sum2;
  return *sse - (unsigned int)(((int64_t)sum1 * sum1) >> 11);
}

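// The 64x64 block is accumulated as four 64x16 strips for the same reason.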
unsigned int vpx_variance64x64_neon(const uint8_t *a, int a_stride,
                                    const uint8_t *b, int b_stride,
                                    unsigned int *sse) {
  int sum1, sum2;
  uint32_t sse1, sse2;

  variance_neon_w8(a, a_stride, b, b_stride, 64, 16, &sse1, &sum1);
  variance_neon_w8(a + (16 * a_stride), a_stride, b + (16 * b_stride), b_stride,
                   64, 16, &sse2, &sum2);
  sse1 += sse2;
  sum1 += sum2;

  variance_neon_w8(a + (16 * 2 * a_stride), a_stride, b + (16 * 2 * b_stride),
                   b_stride, 64, 16, &sse2, &sum2);
  sse1 += sse2;
  sum1 += sum2;

  variance_neon_w8(a + (16 * 3 * a_stride), a_stride, b + (16 * 3 * b_stride),
                   b_stride, 64, 16, &sse2, &sum2);
  *sse = sse1 + sse2;
  sum1 += sum2;
  return *sse - (unsigned int)(((int64_t)sum1 * sum1) >> 12);
}

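// Hand-unrolled 16x8 variance: the final shift by 7 divides sum * sum by the
// 128 pixels in the block.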
unsigned int vpx_variance16x8_neon(const unsigned char *src_ptr,
                                   int source_stride,
                                   const unsigned char *ref_ptr,
                                   int recon_stride, unsigned int *sse) {
  int i;
  int16x4_t d22s16, d23s16, d24s16, d25s16, d26s16, d27s16, d28s16, d29s16;
  uint32x2_t d0u32, d10u32;
  int64x1_t d0s64, d1s64;
  uint8x16_t q0u8, q1u8, q2u8, q3u8;
  uint16x8_t q11u16, q12u16, q13u16, q14u16;
  int32x4_t q8s32, q9s32, q10s32;
  int64x2_t q0s64, q1s64, q5s64;

  q8s32 = vdupq_n_s32(0);
  q9s32 = vdupq_n_s32(0);
  q10s32 = vdupq_n_s32(0);

  for (i = 0; i < 4; i++) {
    q0u8 = vld1q_u8(src_ptr);
    src_ptr += source_stride;
    q1u8 = vld1q_u8(src_ptr);
    src_ptr += source_stride;
    __builtin_prefetch(src_ptr);

    q2u8 = vld1q_u8(ref_ptr);
    ref_ptr += recon_stride;
    q3u8 = vld1q_u8(ref_ptr);
    ref_ptr += recon_stride;
    __builtin_prefetch(ref_ptr);

    q11u16 = vsubl_u8(vget_low_u8(q0u8), vget_low_u8(q2u8));
    q12u16 = vsubl_u8(vget_high_u8(q0u8), vget_high_u8(q2u8));
    q13u16 = vsubl_u8(vget_low_u8(q1u8), vget_low_u8(q3u8));
    q14u16 = vsubl_u8(vget_high_u8(q1u8), vget_high_u8(q3u8));

    d22s16 = vreinterpret_s16_u16(vget_low_u16(q11u16));
    d23s16 = vreinterpret_s16_u16(vget_high_u16(q11u16));
    q8s32 = vpadalq_s16(q8s32, vreinterpretq_s16_u16(q11u16));
    q9s32 = vmlal_s16(q9s32, d22s16, d22s16);
    q10s32 = vmlal_s16(q10s32, d23s16, d23s16);

    d24s16 = vreinterpret_s16_u16(vget_low_u16(q12u16));
    d25s16 = vreinterpret_s16_u16(vget_high_u16(q12u16));
    q8s32 = vpadalq_s16(q8s32, vreinterpretq_s16_u16(q12u16));
    q9s32 = vmlal_s16(q9s32, d24s16, d24s16);
    q10s32 = vmlal_s16(q10s32, d25s16, d25s16);

    d26s16 = vreinterpret_s16_u16(vget_low_u16(q13u16));
    d27s16 = vreinterpret_s16_u16(vget_high_u16(q13u16));
    q8s32 = vpadalq_s16(q8s32, vreinterpretq_s16_u16(q13u16));
    q9s32 = vmlal_s16(q9s32, d26s16, d26s16);
    q10s32 = vmlal_s16(q10s32, d27s16, d27s16);

    d28s16 = vreinterpret_s16_u16(vget_low_u16(q14u16));
    d29s16 = vreinterpret_s16_u16(vget_high_u16(q14u16));
    q8s32 = vpadalq_s16(q8s32, vreinterpretq_s16_u16(q14u16));
    q9s32 = vmlal_s16(q9s32, d28s16, d28s16);
    q10s32 = vmlal_s16(q10s32, d29s16, d29s16);
  }

  q10s32 = vaddq_s32(q10s32, q9s32);
  q0s64 = vpaddlq_s32(q8s32);
  q1s64 = vpaddlq_s32(q10s32);

  d0s64 = vadd_s64(vget_low_s64(q0s64), vget_high_s64(q0s64));
  d1s64 = vadd_s64(vget_low_s64(q1s64), vget_high_s64(q1s64));

  q5s64 = vmull_s32(vreinterpret_s32_s64(d0s64), vreinterpret_s32_s64(d0s64));
  vst1_lane_u32((uint32_t *)sse, vreinterpret_u32_s64(d1s64), 0);

  d10u32 = vshr_n_u32(vreinterpret_u32_s64(vget_low_s64(q5s64)), 7);
  d0u32 = vsub_u32(vreinterpret_u32_s64(d1s64), d10u32);

  return vget_lane_u32(d0u32, 0);
}

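// Hand-unrolled 8x16 variance; as above, the shift by 7 divides by 128 pixels.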
unsigned int vpx_variance8x16_neon(const unsigned char *src_ptr,
                                   int source_stride,
                                   const unsigned char *ref_ptr,
                                   int recon_stride, unsigned int *sse) {
  int i;
  uint8x8_t d0u8, d2u8, d4u8, d6u8;
  int16x4_t d22s16, d23s16, d24s16, d25s16;
  uint32x2_t d0u32, d10u32;
  int64x1_t d0s64, d1s64;
  uint16x8_t q11u16, q12u16;
  int32x4_t q8s32, q9s32, q10s32;
  int64x2_t q0s64, q1s64, q5s64;

  q8s32 = vdupq_n_s32(0);
  q9s32 = vdupq_n_s32(0);
  q10s32 = vdupq_n_s32(0);

  for (i = 0; i < 8; i++) {
    d0u8 = vld1_u8(src_ptr);
    src_ptr += source_stride;
    d2u8 = vld1_u8(src_ptr);
    src_ptr += source_stride;
    __builtin_prefetch(src_ptr);

    d4u8 = vld1_u8(ref_ptr);
    ref_ptr += recon_stride;
    d6u8 = vld1_u8(ref_ptr);
    ref_ptr += recon_stride;
    __builtin_prefetch(ref_ptr);

    q11u16 = vsubl_u8(d0u8, d4u8);
    q12u16 = vsubl_u8(d2u8, d6u8);

    d22s16 = vreinterpret_s16_u16(vget_low_u16(q11u16));
    d23s16 = vreinterpret_s16_u16(vget_high_u16(q11u16));
    q8s32 = vpadalq_s16(q8s32, vreinterpretq_s16_u16(q11u16));
    q9s32 = vmlal_s16(q9s32, d22s16, d22s16);
    q10s32 = vmlal_s16(q10s32, d23s16, d23s16);

    d24s16 = vreinterpret_s16_u16(vget_low_u16(q12u16));
    d25s16 = vreinterpret_s16_u16(vget_high_u16(q12u16));
    q8s32 = vpadalq_s16(q8s32, vreinterpretq_s16_u16(q12u16));
    q9s32 = vmlal_s16(q9s32, d24s16, d24s16);
    q10s32 = vmlal_s16(q10s32, d25s16, d25s16);
  }

  q10s32 = vaddq_s32(q10s32, q9s32);
  q0s64 = vpaddlq_s32(q8s32);
  q1s64 = vpaddlq_s32(q10s32);

  d0s64 = vadd_s64(vget_low_s64(q0s64), vget_high_s64(q0s64));
  d1s64 = vadd_s64(vget_low_s64(q1s64), vget_high_s64(q1s64));

  q5s64 = vmull_s32(vreinterpret_s32_s64(d0s64), vreinterpret_s32_s64(d0s64));
  vst1_lane_u32((uint32_t *)sse, vreinterpret_u32_s64(d1s64), 0);

  d10u32 = vshr_n_u32(vreinterpret_u32_s64(vget_low_s64(q5s64)), 7);
  d0u32 = vsub_u32(vreinterpret_u32_s64(d1s64), d10u32);

  return vget_lane_u32(d0u32, 0);
}

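// Mean squared error for a 16x16 block: the sum of squared differences is
// written to *sse and also returned; no mean is subtracted.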
unsigned int vpx_mse16x16_neon(const unsigned char *src_ptr, int source_stride,
                               const unsigned char *ref_ptr, int recon_stride,
                               unsigned int *sse) {
  int i;
  int16x4_t d22s16, d23s16, d24s16, d25s16, d26s16, d27s16, d28s16, d29s16;
  int64x1_t d0s64;
  uint8x16_t q0u8, q1u8, q2u8, q3u8;
  int32x4_t q7s32, q8s32, q9s32, q10s32;
  uint16x8_t q11u16, q12u16, q13u16, q14u16;
  int64x2_t q1s64;

  q7s32 = vdupq_n_s32(0);
  q8s32 = vdupq_n_s32(0);
  q9s32 = vdupq_n_s32(0);
  q10s32 = vdupq_n_s32(0);

  for (i = 0; i < 8; i++) {  // mse16x16_neon_loop
    q0u8 = vld1q_u8(src_ptr);
    src_ptr += source_stride;
    q1u8 = vld1q_u8(src_ptr);
    src_ptr += source_stride;
    q2u8 = vld1q_u8(ref_ptr);
    ref_ptr += recon_stride;
    q3u8 = vld1q_u8(ref_ptr);
    ref_ptr += recon_stride;

    q11u16 = vsubl_u8(vget_low_u8(q0u8), vget_low_u8(q2u8));
    q12u16 = vsubl_u8(vget_high_u8(q0u8), vget_high_u8(q2u8));
    q13u16 = vsubl_u8(vget_low_u8(q1u8), vget_low_u8(q3u8));
    q14u16 = vsubl_u8(vget_high_u8(q1u8), vget_high_u8(q3u8));

    d22s16 = vreinterpret_s16_u16(vget_low_u16(q11u16));
    d23s16 = vreinterpret_s16_u16(vget_high_u16(q11u16));
    q7s32 = vmlal_s16(q7s32, d22s16, d22s16);
    q8s32 = vmlal_s16(q8s32, d23s16, d23s16);

    d24s16 = vreinterpret_s16_u16(vget_low_u16(q12u16));
    d25s16 = vreinterpret_s16_u16(vget_high_u16(q12u16));
    q9s32 = vmlal_s16(q9s32, d24s16, d24s16);
    q10s32 = vmlal_s16(q10s32, d25s16, d25s16);

    d26s16 = vreinterpret_s16_u16(vget_low_u16(q13u16));
    d27s16 = vreinterpret_s16_u16(vget_high_u16(q13u16));
    q7s32 = vmlal_s16(q7s32, d26s16, d26s16);
    q8s32 = vmlal_s16(q8s32, d27s16, d27s16);

    d28s16 = vreinterpret_s16_u16(vget_low_u16(q14u16));
    d29s16 = vreinterpret_s16_u16(vget_high_u16(q14u16));
    q9s32 = vmlal_s16(q9s32, d28s16, d28s16);
    q10s32 = vmlal_s16(q10s32, d29s16, d29s16);
  }

  q7s32 = vaddq_s32(q7s32, q8s32);
  q9s32 = vaddq_s32(q9s32, q10s32);
  q10s32 = vaddq_s32(q7s32, q9s32);

  q1s64 = vpaddlq_s32(q10s32);
  d0s64 = vadd_s64(vget_low_s64(q1s64), vget_high_s64(q1s64));

  vst1_lane_u32((uint32_t *)sse, vreinterpret_u32_s64(d0s64), 0);
  return vget_lane_u32(vreinterpret_u32_s64(d0s64), 0);
}

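// Sum of squared differences for a 4x4 block. Each load fetches eight bytes,
// but only the low four lanes of each difference are squared and accumulated.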
unsigned int vpx_get4x4sse_cs_neon(const unsigned char *src_ptr,
                                   int source_stride,
                                   const unsigned char *ref_ptr,
                                   int recon_stride) {
  int16x4_t d22s16, d24s16, d26s16, d28s16;
  int64x1_t d0s64;
  uint8x8_t d0u8, d1u8, d2u8, d3u8, d4u8, d5u8, d6u8, d7u8;
  int32x4_t q7s32, q8s32, q9s32, q10s32;
  uint16x8_t q11u16, q12u16, q13u16, q14u16;
  int64x2_t q1s64;

  d0u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d4u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;
  d1u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d5u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;
  d2u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d6u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;
  d3u8 = vld1_u8(src_ptr);
  src_ptr += source_stride;
  d7u8 = vld1_u8(ref_ptr);
  ref_ptr += recon_stride;

  q11u16 = vsubl_u8(d0u8, d4u8);
  q12u16 = vsubl_u8(d1u8, d5u8);
  q13u16 = vsubl_u8(d2u8, d6u8);
  q14u16 = vsubl_u8(d3u8, d7u8);

  d22s16 = vget_low_s16(vreinterpretq_s16_u16(q11u16));
  d24s16 = vget_low_s16(vreinterpretq_s16_u16(q12u16));
  d26s16 = vget_low_s16(vreinterpretq_s16_u16(q13u16));
  d28s16 = vget_low_s16(vreinterpretq_s16_u16(q14u16));

  q7s32 = vmull_s16(d22s16, d22s16);
  q8s32 = vmull_s16(d24s16, d24s16);
  q9s32 = vmull_s16(d26s16, d26s16);
  q10s32 = vmull_s16(d28s16, d28s16);

  q7s32 = vaddq_s32(q7s32, q8s32);
  q9s32 = vaddq_s32(q9s32, q10s32);
  q9s32 = vaddq_s32(q7s32, q9s32);

  q1s64 = vpaddlq_s32(q9s32);
  d0s64 = vadd_s64(vget_low_s64(q1s64), vget_high_s64(q1s64));

  return vget_lane_u32(vreinterpret_u32_s64(d0s64), 0);
}