/*
 *  Copyright (c) 2018, Alliance for Open Media. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#ifndef AOM_AV1_COMMON_ARM_CONVOLVE_NEON_H_
#define AOM_AV1_COMMON_ARM_CONVOLVE_NEON_H_

#include <arm_neon.h>

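/* Number of extra source rows required by the horizontal filter stage:
 * SUBPEL_TAPS rounded up to the next multiple of 8. */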
#define HORIZ_EXTRA_ROWS ((SUBPEL_TAPS + 7) & ~0x07)

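/* Apply an 8-tap filter to four 16-bit samples. The outer taps use plain
 * multiply-accumulates; the two center taps (filter3, filter4), which are
 * typically the largest coefficients, are accumulated with saturating adds to
 * guard against intermediate overflow. */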
static INLINE int16x4_t convolve8_4(const int16x4_t s0, const int16x4_t s1,
                                    const int16x4_t s2, const int16x4_t s3,
                                    const int16x4_t s4, const int16x4_t s5,
                                    const int16x4_t s6, const int16x4_t s7,
                                    const int16x8_t filters,
                                    const int16x4_t filter3,
                                    const int16x4_t filter4) {
  const int16x4_t filters_lo = vget_low_s16(filters);
  const int16x4_t filters_hi = vget_high_s16(filters);
  int16x4_t sum;

  sum = vmul_lane_s16(s0, filters_lo, 0);
  sum = vmla_lane_s16(sum, s1, filters_lo, 1);
  sum = vmla_lane_s16(sum, s2, filters_lo, 2);
  sum = vmla_lane_s16(sum, s5, filters_hi, 1);
  sum = vmla_lane_s16(sum, s6, filters_hi, 2);
  sum = vmla_lane_s16(sum, s7, filters_hi, 3);
  sum = vqadd_s16(sum, vmul_s16(s3, filter3));
  sum = vqadd_s16(sum, vmul_s16(s4, filter4));
  return sum;
}

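/* 8-lane variant of convolve8_4 that also narrows the result to 8-bit pixels
 * with rounding, a right shift by 7 and unsigned saturation. */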
static INLINE uint8x8_t convolve8_8(const int16x8_t s0, const int16x8_t s1,
                                    const int16x8_t s2, const int16x8_t s3,
                                    const int16x8_t s4, const int16x8_t s5,
                                    const int16x8_t s6, const int16x8_t s7,
                                    const int16x8_t filters,
                                    const int16x8_t filter3,
                                    const int16x8_t filter4) {
  const int16x4_t filters_lo = vget_low_s16(filters);
  const int16x4_t filters_hi = vget_high_s16(filters);
  int16x8_t sum;

  sum = vmulq_lane_s16(s0, filters_lo, 0);
  sum = vmlaq_lane_s16(sum, s1, filters_lo, 1);
  sum = vmlaq_lane_s16(sum, s2, filters_lo, 2);
  sum = vmlaq_lane_s16(sum, s5, filters_hi, 1);
  sum = vmlaq_lane_s16(sum, s6, filters_hi, 2);
  sum = vmlaq_lane_s16(sum, s7, filters_hi, 3);
  sum = vqaddq_s16(sum, vmulq_s16(s3, filter3));
  sum = vqaddq_s16(sum, vmulq_s16(s4, filter4));
  return vqrshrun_n_s16(sum, 7);
}

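/* Widen eight uint8x8_t source vectors to 16 bits and apply the 8-tap filter
 * via convolve8_8, duplicating the two center coefficients for the saturating
 * accumulation path. */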
static INLINE uint8x8_t scale_filter_8(const uint8x8_t *const s,
                                       const int16x8_t filters) {
  const int16x8_t filter3 = vdupq_lane_s16(vget_low_s16(filters), 3);
  const int16x8_t filter4 = vdupq_lane_s16(vget_high_s16(filters), 0);
  int16x8_t ss[8];

  ss[0] = vreinterpretq_s16_u16(vmovl_u8(s[0]));
  ss[1] = vreinterpretq_s16_u16(vmovl_u8(s[1]));
  ss[2] = vreinterpretq_s16_u16(vmovl_u8(s[2]));
  ss[3] = vreinterpretq_s16_u16(vmovl_u8(s[3]));
  ss[4] = vreinterpretq_s16_u16(vmovl_u8(s[4]));
  ss[5] = vreinterpretq_s16_u16(vmovl_u8(s[5]));
  ss[6] = vreinterpretq_s16_u16(vmovl_u8(s[6]));
  ss[7] = vreinterpretq_s16_u16(vmovl_u8(s[7]));

  return convolve8_8(ss[0], ss[1], ss[2], ss[3], ss[4], ss[5], ss[6], ss[7],
                     filters, filter3, filter4);
}

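/* Vertical pass of the symmetric 7-tap Wiener filter producing 8 pixels.
 * The symmetry lets s0 + s6, s1 + s5 and s2 + s4 be pre-added so that only
 * filter_y[0..3] need to be multiplied. The 32-bit sums have
 * (1 << (bd + round1_bits - 1)) subtracted, are rounded and shifted right by
 * round1_bits, clamped below at zero and narrowed with saturation to 8-bit
 * pixels. */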
static INLINE uint8x8_t wiener_convolve8_vert_4x8(
    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
    const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
    const int16x8_t s6, int16_t *filter_y, const int bd,
    const int round1_bits) {
  int16x8_t ss0, ss1, ss2;
  int32x4_t sum0, sum1;
  uint16x4_t tmp0, tmp1;
  uint16x8_t tmp;
  uint8x8_t res;

  const int32_t round_const = (1 << (bd + round1_bits - 1));
  const int32x4_t round_bits = vdupq_n_s32(-round1_bits);
  const int32x4_t zero = vdupq_n_s32(0);
  const int32x4_t round_vec = vdupq_n_s32(round_const);

  ss0 = vaddq_s16(s0, s6);
  ss1 = vaddq_s16(s1, s5);
  ss2 = vaddq_s16(s2, s4);

  sum0 = vmull_n_s16(vget_low_s16(ss0), filter_y[0]);
  sum0 = vmlal_n_s16(sum0, vget_low_s16(ss1), filter_y[1]);
  sum0 = vmlal_n_s16(sum0, vget_low_s16(ss2), filter_y[2]);
  sum0 = vmlal_n_s16(sum0, vget_low_s16(s3), filter_y[3]);

  sum1 = vmull_n_s16(vget_high_s16(ss0), filter_y[0]);
  sum1 = vmlal_n_s16(sum1, vget_high_s16(ss1), filter_y[1]);
  sum1 = vmlal_n_s16(sum1, vget_high_s16(ss2), filter_y[2]);
  sum1 = vmlal_n_s16(sum1, vget_high_s16(s3), filter_y[3]);

  sum0 = vsubq_s32(sum0, round_vec);
  sum1 = vsubq_s32(sum1, round_vec);

  /* Round and right shift by round1_bits. */
  sum0 = vrshlq_s32(sum0, round_bits);
  sum1 = vrshlq_s32(sum1, round_bits);

  sum0 = vmaxq_s32(sum0, zero);
  sum1 = vmaxq_s32(sum1, zero);

  /* Narrow from int32x4_t to uint8x8_t. */
  tmp0 = vqmovn_u32(vreinterpretq_u32_s32(sum0));
  tmp1 = vqmovn_u32(vreinterpretq_u32_s32(sum1));
  tmp = vcombine_u16(tmp0, tmp1);
  res = vqmovn_u16(tmp);

  return res;
}

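/* Horizontal Wiener pass for 8 pixels. s0..s2 are expected to already hold
 * the pre-added symmetric sample pairs, with s3 the center sample. The sum is
 * widened to 32 bits before the s3 term is accumulated, offset by
 * (1 << (bd + FILTER_BITS - 1)), rounded and shifted right by round0_bits,
 * then clipped to the maximum intermediate value. */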
static INLINE uint16x8_t wiener_convolve8_horiz_8x8(
    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
    const int16x8_t s3, int16_t *filter_x, const int bd,
    const int round0_bits) {
  int16x8_t sum;
  uint16x8_t res;
  int32x4_t sum_0, sum_1;
  int32x4_t s3_0, s3_1;
  const int32_t round_const_0 = (1 << (bd + FILTER_BITS - 1));
  const int32_t round_const_1 = (1 << (bd + 1 + FILTER_BITS - round0_bits)) - 1;

  /* Vector of -round0_bits for the rounding right shift by
   * conv_params->round_0. */
  const int32x4_t round_bits = vdupq_n_s32(-round0_bits);

  const int32x4_t round_vec_0 = vdupq_n_s32(round_const_0);
  const int32x4_t round_vec_1 = vdupq_n_s32(round_const_1);

  sum = vmulq_n_s16(s0, filter_x[0]);
  sum = vmlaq_n_s16(sum, s1, filter_x[1]);
  sum = vmlaq_n_s16(sum, s2, filter_x[2]);

  /* Widen the 16x8 sum into two 32x4 registers. */
  sum_0 = vmovl_s16(vget_low_s16(sum));
  sum_1 = vmovl_s16(vget_high_s16(sum));

  /* The coefficient applied to s3 can be as large as 128, so the maximum
   * possible value is 128 * 128 * 255, which exceeds 16 bits. Multiply and
   * accumulate the s3 term at 32-bit precision. */
  s3_0 = vmull_n_s16(vget_low_s16(s3), filter_x[3]);
  s3_1 = vmull_n_s16(vget_high_s16(s3), filter_x[3]);
  sum_0 = vaddq_s32(sum_0, s3_0);
  sum_1 = vaddq_s32(sum_1, s3_1);

  /* Add the offset constant. */
  sum_0 = vaddq_s32(sum_0, round_vec_0);
  sum_1 = vaddq_s32(sum_1, round_vec_0);

  /* Rounding right shift with saturation. */
  sum_0 = vqrshlq_s32(sum_0, round_bits);
  sum_1 = vqrshlq_s32(sum_1, round_bits);

  /* Clip to the maximum intermediate value. */
  sum_0 = vminq_s32(sum_0, round_vec_1);
  sum_1 = vminq_s32(sum_1, round_vec_1);

  res = vcombine_u16(vqmovun_s32(sum_0), vqmovun_s32(sum_1));
  return res;
}

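/* Horizontal Wiener pass for 4 pixels. Unlike the 8-pixel version, this takes
 * the seven individual samples s0..s6 and forms the symmetric pairs itself
 * before filtering, rounding, shifting by round0_bits and clamping to
 * [0, (1 << (bd + 1 + FILTER_BITS - round0_bits)) - 1]. */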
static INLINE uint16x4_t wiener_convolve8_horiz_4x8(
    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
    const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
    const int16x4_t s6, int16_t *filter_x, const int bd,
    const int round0_bits) {
  uint16x4_t res;
  int32x4_t sum_0, s3_0;
  int16x4_t sum, temp0, temp1, temp2;

  const int32_t round_const_0 = (1 << (bd + FILTER_BITS - 1));
  const int32_t round_const_1 = (1 << (bd + 1 + FILTER_BITS - round0_bits)) - 1;
  const int32x4_t round_bits = vdupq_n_s32(-round0_bits);
  const int32x4_t zero = vdupq_n_s32(0);
  const int32x4_t round_vec_0 = vdupq_n_s32(round_const_0);
  const int32x4_t round_vec_1 = vdupq_n_s32(round_const_1);

  temp0 = vadd_s16(s0, s6);
  temp1 = vadd_s16(s1, s5);
  temp2 = vadd_s16(s2, s4);

  sum = vmul_n_s16(temp0, filter_x[0]);
  sum = vmla_n_s16(sum, temp1, filter_x[1]);
  sum = vmla_n_s16(sum, temp2, filter_x[2]);
  sum_0 = vmovl_s16(sum);

  /* The coefficient applied to s3 can be as large as 128, so the maximum
   * possible value is 128 * 128 * 255; 32 bits are therefore required to
   * hold the result. */
  s3_0 = vmull_n_s16(s3, filter_x[3]);
  sum_0 = vaddq_s32(sum_0, s3_0);

  sum_0 = vaddq_s32(sum_0, round_vec_0);
  sum_0 = vrshlq_s32(sum_0, round_bits);

  sum_0 = vmaxq_s32(sum_0, zero);
  sum_0 = vminq_s32(sum_0, round_vec_1);
  res = vqmovun_s32(sum_0);
  return res;
}

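/* 8-tap filter over eight 16-bit lanes: accumulate onto the offset
 * horiz_const, then apply the saturating rounding shift shift_round_0 (a
 * vector of negative values performs a rounding right shift). */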
static INLINE int16x8_t
convolve8_8x8_s16(const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
                  const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
                  const int16x8_t s6, const int16x8_t s7, const int16_t *filter,
                  const int16x8_t horiz_const, const int16x8_t shift_round_0) {
  int16x8_t sum;
  int16x8_t res;

  sum = horiz_const;
  sum = vmlaq_n_s16(sum, s0, filter[0]);
  sum = vmlaq_n_s16(sum, s1, filter[1]);
  sum = vmlaq_n_s16(sum, s2, filter[2]);
  sum = vmlaq_n_s16(sum, s3, filter[3]);
  sum = vmlaq_n_s16(sum, s4, filter[4]);
  sum = vmlaq_n_s16(sum, s5, filter[5]);
  sum = vmlaq_n_s16(sum, s6, filter[6]);
  sum = vmlaq_n_s16(sum, s7, filter[7]);

  res = vqrshlq_s16(sum, shift_round_0);

  return res;
}

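/* 4-lane counterpart of convolve8_8x8_s16. */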
static INLINE int16x4_t
convolve8_4x4_s16(const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
                  const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
                  const int16x4_t s6, const int16x4_t s7, const int16_t *filter,
                  const int16x4_t horiz_const, const int16x4_t shift_round_0) {
  int16x4_t sum;
  sum = horiz_const;
  sum = vmla_n_s16(sum, s0, filter[0]);
  sum = vmla_n_s16(sum, s1, filter[1]);
  sum = vmla_n_s16(sum, s2, filter[2]);
  sum = vmla_n_s16(sum, s3, filter[3]);
  sum = vmla_n_s16(sum, s4, filter[4]);
  sum = vmla_n_s16(sum, s5, filter[5]);
  sum = vmla_n_s16(sum, s6, filter[6]);
  sum = vmla_n_s16(sum, s7, filter[7]);

  sum = vqrshl_s16(sum, shift_round_0);

  return sum;
}

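/* 8-tap filter over four lanes with 32-bit accumulation: add offset_const,
 * apply the saturating rounding shift round_shift_vec, clamp below at zero
 * and narrow (without saturation) to unsigned 16 bits. */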
static INLINE uint16x4_t convolve8_4x4_s32(
    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
    const int16x4_t s3, const int16x4_t s4, const int16x4_t s5,
    const int16x4_t s6, const int16x4_t s7, const int16_t *y_filter,
    const int32x4_t round_shift_vec, const int32x4_t offset_const) {
  int32x4_t sum0;
  uint16x4_t res;
  const int32x4_t zero = vdupq_n_s32(0);

  sum0 = vmull_n_s16(s0, y_filter[0]);
  sum0 = vmlal_n_s16(sum0, s1, y_filter[1]);
  sum0 = vmlal_n_s16(sum0, s2, y_filter[2]);
  sum0 = vmlal_n_s16(sum0, s3, y_filter[3]);
  sum0 = vmlal_n_s16(sum0, s4, y_filter[4]);
  sum0 = vmlal_n_s16(sum0, s5, y_filter[5]);
  sum0 = vmlal_n_s16(sum0, s6, y_filter[6]);
  sum0 = vmlal_n_s16(sum0, s7, y_filter[7]);

  sum0 = vaddq_s32(sum0, offset_const);
  sum0 = vqrshlq_s32(sum0, round_shift_vec);
  sum0 = vmaxq_s32(sum0, zero);
  res = vmovn_u32(vreinterpretq_u32_s32(sum0));

  return res;
}

#endif  // AOM_AV1_COMMON_ARM_CONVOLVE_NEON_H_