1 // Copyright 2021 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/intrapred_filter.h"
16 #include "src/utils/cpu.h"
17 
18 #if LIBGAV1_ENABLE_NEON
19 
20 #include <arm_neon.h>
21 
22 #include <cassert>
23 #include <cstddef>
24 #include <cstdint>
25 
26 #include "src/dsp/arm/common_neon.h"
27 #include "src/dsp/constants.h"
28 #include "src/dsp/dsp.h"
29 #include "src/utils/common.h"
30 
31 namespace libgav1 {
32 namespace dsp {
33 
34 namespace low_bitdepth {
35 namespace {
36 
37 // Transpose kFilterIntraTaps and convert the first row to unsigned values.
38 //
39 // With the previous orientation we were able to multiply all the input values
40 // by a single tap. This required that all the input values be in one vector
41 // which requires expensive set up operations (shifts, vext, vtbl). All the
42 // elements of the result needed to be summed (easy on A64 - vaddvq_s16) but
43 // then the shifting, rounding, and clamping was done in GP registers.
44 //
45 // Switching to unsigned values allows multiplying the 8 bit inputs directly.
46 // When one value was negative we needed to vmovl_u8 first so that the results
47 // maintained the proper sign.
48 //
49 // We take this into account when summing the values by subtracting the product
50 // of the first row.
51 alignas(8) constexpr uint8_t kTransposedTaps[kNumFilterIntraPredictors][7][8] =
52     {{{6, 5, 3, 3, 4, 3, 3, 3},  // Original values are negative.
53       {10, 2, 1, 1, 6, 2, 2, 1},
54       {0, 10, 1, 1, 0, 6, 2, 2},
55       {0, 0, 10, 2, 0, 0, 6, 2},
56       {0, 0, 0, 10, 0, 0, 0, 6},
57       {12, 9, 7, 5, 2, 2, 2, 3},
58       {0, 0, 0, 0, 12, 9, 7, 5}},
59      {{10, 6, 4, 2, 10, 6, 4, 2},  // Original values are negative.
60       {16, 0, 0, 0, 16, 0, 0, 0},
61       {0, 16, 0, 0, 0, 16, 0, 0},
62       {0, 0, 16, 0, 0, 0, 16, 0},
63       {0, 0, 0, 16, 0, 0, 0, 16},
64       {10, 6, 4, 2, 0, 0, 0, 0},
65       {0, 0, 0, 0, 10, 6, 4, 2}},
66      {{8, 8, 8, 8, 4, 4, 4, 4},  // Original values are negative.
67       {8, 0, 0, 0, 4, 0, 0, 0},
68       {0, 8, 0, 0, 0, 4, 0, 0},
69       {0, 0, 8, 0, 0, 0, 4, 0},
70       {0, 0, 0, 8, 0, 0, 0, 4},
71       {16, 16, 16, 16, 0, 0, 0, 0},
72       {0, 0, 0, 0, 16, 16, 16, 16}},
73      {{2, 1, 1, 0, 1, 1, 1, 1},  // Original values are negative.
74       {8, 3, 2, 1, 4, 3, 2, 2},
75       {0, 8, 3, 2, 0, 4, 3, 2},
76       {0, 0, 8, 3, 0, 0, 4, 3},
77       {0, 0, 0, 8, 0, 0, 0, 4},
78       {10, 6, 4, 2, 3, 4, 4, 3},
79       {0, 0, 0, 0, 10, 6, 4, 3}},
80      {{12, 10, 9, 8, 10, 9, 8, 7},  // Original values are negative.
81       {14, 0, 0, 0, 12, 1, 0, 0},
82       {0, 14, 0, 0, 0, 12, 0, 0},
83       {0, 0, 14, 0, 0, 0, 12, 1},
84       {0, 0, 0, 14, 0, 0, 0, 12},
85       {14, 12, 11, 10, 0, 0, 1, 1},
86       {0, 0, 0, 0, 14, 12, 11, 9}}};
87 
FilterIntraPredictor_NEON(void * LIBGAV1_RESTRICT const dest,ptrdiff_t stride,const void * LIBGAV1_RESTRICT const top_row,const void * LIBGAV1_RESTRICT const left_column,FilterIntraPredictor pred,int width,int height)88 void FilterIntraPredictor_NEON(void* LIBGAV1_RESTRICT const dest,
89                                ptrdiff_t stride,
90                                const void* LIBGAV1_RESTRICT const top_row,
91                                const void* LIBGAV1_RESTRICT const left_column,
92                                FilterIntraPredictor pred, int width,
93                                int height) {
94   const auto* const top = static_cast<const uint8_t*>(top_row);
95   const auto* const left = static_cast<const uint8_t*>(left_column);
96 
97   assert(width <= 32 && height <= 32);
98 
99   auto* dst = static_cast<uint8_t*>(dest);
100 
101   uint8x8_t transposed_taps[7];
102   for (int i = 0; i < 7; ++i) {
103     transposed_taps[i] = vld1_u8(kTransposedTaps[pred][i]);
104   }
105 
106   uint8_t relative_top_left = top[-1];
107   const uint8_t* relative_top = top;
108   uint8_t relative_left[2] = {left[0], left[1]};
109 
110   int y = 0;
111   do {
112     uint8_t* row_dst = dst;
113     int x = 0;
114     do {
115       uint16x8_t sum = vdupq_n_u16(0);
116       const uint16x8_t subtrahend =
117           vmull_u8(transposed_taps[0], vdup_n_u8(relative_top_left));
118       for (int i = 1; i < 5; ++i) {
119         sum = vmlal_u8(sum, transposed_taps[i], vdup_n_u8(relative_top[i - 1]));
120       }
121       for (int i = 5; i < 7; ++i) {
122         sum =
123             vmlal_u8(sum, transposed_taps[i], vdup_n_u8(relative_left[i - 5]));
124       }
125 
126       const int16x8_t sum_signed =
127           vreinterpretq_s16_u16(vsubq_u16(sum, subtrahend));
128       const int16x8_t sum_shifted = vrshrq_n_s16(sum_signed, 4);
129 
130       uint8x8_t sum_saturated = vqmovun_s16(sum_shifted);
131 
132       StoreLo4(row_dst, sum_saturated);
133       StoreHi4(row_dst + stride, sum_saturated);
134 
135       // Progress across
136       relative_top_left = relative_top[3];
137       relative_top += 4;
138       relative_left[0] = row_dst[3];
139       relative_left[1] = row_dst[3 + stride];
140       row_dst += 4;
141       x += 4;
142     } while (x < width);
143 
144     // Progress down.
145     relative_top_left = left[y + 1];
146     relative_top = dst + stride;
147     relative_left[0] = left[y + 2];
148     relative_left[1] = left[y + 3];
149 
150     dst += 2 * stride;
151     y += 2;
152   } while (y < height);
153 }
154 
Init8bpp()155 void Init8bpp() {
156   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
157   assert(dsp != nullptr);
158   dsp->filter_intra_predictor = FilterIntraPredictor_NEON;
159 }
160 
161 }  // namespace
162 }  // namespace low_bitdepth
163 
164 //------------------------------------------------------------------------------
165 #if LIBGAV1_MAX_BITDEPTH >= 10
166 namespace high_bitdepth {
167 namespace {
168 
169 alignas(kMaxAlignment) constexpr int16_t
170     kTransposedTaps[kNumFilterIntraPredictors][7][8] = {
171         {{-6, -5, -3, -3, -4, -3, -3, -3},
172          {10, 2, 1, 1, 6, 2, 2, 1},
173          {0, 10, 1, 1, 0, 6, 2, 2},
174          {0, 0, 10, 2, 0, 0, 6, 2},
175          {0, 0, 0, 10, 0, 0, 0, 6},
176          {12, 9, 7, 5, 2, 2, 2, 3},
177          {0, 0, 0, 0, 12, 9, 7, 5}},
178         {{-10, -6, -4, -2, -10, -6, -4, -2},
179          {16, 0, 0, 0, 16, 0, 0, 0},
180          {0, 16, 0, 0, 0, 16, 0, 0},
181          {0, 0, 16, 0, 0, 0, 16, 0},
182          {0, 0, 0, 16, 0, 0, 0, 16},
183          {10, 6, 4, 2, 0, 0, 0, 0},
184          {0, 0, 0, 0, 10, 6, 4, 2}},
185         {{-8, -8, -8, -8, -4, -4, -4, -4},
186          {8, 0, 0, 0, 4, 0, 0, 0},
187          {0, 8, 0, 0, 0, 4, 0, 0},
188          {0, 0, 8, 0, 0, 0, 4, 0},
189          {0, 0, 0, 8, 0, 0, 0, 4},
190          {16, 16, 16, 16, 0, 0, 0, 0},
191          {0, 0, 0, 0, 16, 16, 16, 16}},
192         {{-2, -1, -1, -0, -1, -1, -1, -1},
193          {8, 3, 2, 1, 4, 3, 2, 2},
194          {0, 8, 3, 2, 0, 4, 3, 2},
195          {0, 0, 8, 3, 0, 0, 4, 3},
196          {0, 0, 0, 8, 0, 0, 0, 4},
197          {10, 6, 4, 2, 3, 4, 4, 3},
198          {0, 0, 0, 0, 10, 6, 4, 3}},
199         {{-12, -10, -9, -8, -10, -9, -8, -7},
200          {14, 0, 0, 0, 12, 1, 0, 0},
201          {0, 14, 0, 0, 0, 12, 0, 0},
202          {0, 0, 14, 0, 0, 0, 12, 1},
203          {0, 0, 0, 14, 0, 0, 0, 12},
204          {14, 12, 11, 10, 0, 0, 1, 1},
205          {0, 0, 0, 0, 14, 12, 11, 9}}};
206 
FilterIntraPredictor_NEON(void * LIBGAV1_RESTRICT const dest,ptrdiff_t stride,const void * LIBGAV1_RESTRICT const top_row,const void * LIBGAV1_RESTRICT const left_column,FilterIntraPredictor pred,int width,int height)207 void FilterIntraPredictor_NEON(void* LIBGAV1_RESTRICT const dest,
208                                ptrdiff_t stride,
209                                const void* LIBGAV1_RESTRICT const top_row,
210                                const void* LIBGAV1_RESTRICT const left_column,
211                                FilterIntraPredictor pred, int width,
212                                int height) {
213   const auto* const top = static_cast<const uint16_t*>(top_row);
214   const auto* const left = static_cast<const uint16_t*>(left_column);
215 
216   assert(width <= 32 && height <= 32);
217 
218   auto* dst = static_cast<uint16_t*>(dest);
219 
220   stride >>= 1;
221 
222   int16x8_t transposed_taps[7];
223   for (int i = 0; i < 7; ++i) {
224     transposed_taps[i] = vld1q_s16(kTransposedTaps[pred][i]);
225   }
226 
227   uint16_t relative_top_left = top[-1];
228   const uint16_t* relative_top = top;
229   uint16_t relative_left[2] = {left[0], left[1]};
230 
231   int y = 0;
232   do {
233     uint16_t* row_dst = dst;
234     int x = 0;
235     do {
236       int16x8_t sum =
237           vmulq_s16(transposed_taps[0],
238                     vreinterpretq_s16_u16(vdupq_n_u16(relative_top_left)));
239       for (int i = 1; i < 5; ++i) {
240         sum =
241             vmlaq_s16(sum, transposed_taps[i],
242                       vreinterpretq_s16_u16(vdupq_n_u16(relative_top[i - 1])));
243       }
244       for (int i = 5; i < 7; ++i) {
245         sum =
246             vmlaq_s16(sum, transposed_taps[i],
247                       vreinterpretq_s16_u16(vdupq_n_u16(relative_left[i - 5])));
248       }
249 
250       const int16x8_t sum_shifted = vrshrq_n_s16(sum, 4);
251       const uint16x8_t sum_saturated = vminq_u16(
252           vreinterpretq_u16_s16(vmaxq_s16(sum_shifted, vdupq_n_s16(0))),
253           vdupq_n_u16((1 << kBitdepth10) - 1));
254 
255       vst1_u16(row_dst, vget_low_u16(sum_saturated));
256       vst1_u16(row_dst + stride, vget_high_u16(sum_saturated));
257 
258       // Progress across
259       relative_top_left = relative_top[3];
260       relative_top += 4;
261       relative_left[0] = row_dst[3];
262       relative_left[1] = row_dst[3 + stride];
263       row_dst += 4;
264       x += 4;
265     } while (x < width);
266 
267     // Progress down.
268     relative_top_left = left[y + 1];
269     relative_top = dst + stride;
270     relative_left[0] = left[y + 2];
271     relative_left[1] = left[y + 3];
272 
273     dst += 2 * stride;
274     y += 2;
275   } while (y < height);
276 }
277 
Init10bpp()278 void Init10bpp() {
279   Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
280   assert(dsp != nullptr);
281   dsp->filter_intra_predictor = FilterIntraPredictor_NEON;
282 }
283 
284 }  // namespace
285 }  // namespace high_bitdepth
286 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
287 
IntraPredFilterInit_NEON()288 void IntraPredFilterInit_NEON() {
289   low_bitdepth::Init8bpp();
290 #if LIBGAV1_MAX_BITDEPTH >= 10
291   high_bitdepth::Init10bpp();
292 #endif
293 }
294 
295 }  // namespace dsp
296 }  // namespace libgav1
297 
298 #else   // !LIBGAV1_ENABLE_NEON
299 namespace libgav1 {
300 namespace dsp {
301 
// Built when LIBGAV1_ENABLE_NEON is off: there is nothing to register, but the
// symbol is still defined so callers need no conditional compilation.
void IntraPredFilterInit_NEON() {
  // Intentionally empty.
}
303 
304 }  // namespace dsp
305 }  // namespace libgav1
306 #endif  // LIBGAV1_ENABLE_NEON
307