1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/dsp/loop_restoration.h"
16 #include "src/utils/cpu.h"
17
18 #if LIBGAV1_ENABLE_NEON
19 #include <arm_neon.h>
20
21 #include <algorithm>
22 #include <cassert>
23 #include <cstddef>
24 #include <cstdint>
25 #include <cstring>
26
27 #include "src/dsp/arm/common_neon.h"
28 #include "src/dsp/constants.h"
29 #include "src/dsp/dsp.h"
30 #include "src/utils/common.h"
31 #include "src/utils/constants.h"
32
33 namespace libgav1 {
34 namespace dsp {
35 namespace low_bitdepth {
36 namespace {
37
// Treats |v| as one 128-bit vector and returns the 8 byte lanes that start
// |bytes| bytes into it.
template <int bytes>
inline uint8x8_t VshrU128(const uint8x8x2_t v) {
  return vext_u8(v.val[0], v.val[1], bytes);
}
42
// 16-bit lane variant: |bytes| is still expressed in bytes, hence the
// conversion to element count for vextq_u16.
template <int bytes>
inline uint16x8_t VshrU128(const uint16x8x2_t v) {
  return vextq_u16(v.val[0], v.val[1], bytes / 2);
}
47
48 // Wiener
49
// Must make a local copy of coefficients to help the compiler know that they
// have no overlap with other buffers. Using the 'const' keyword is not enough.
// In practice the compiler does not make a copy, since there are enough
// registers in this case.
PopulateWienerCoefficients(const RestorationUnitInfo & restoration_info,const int direction,int16_t filter[4])53 inline void PopulateWienerCoefficients(
54 const RestorationUnitInfo& restoration_info, const int direction,
55 int16_t filter[4]) {
56 // In order to keep the horizontal pass intermediate values within 16 bits we
57 // offset |filter[3]| by 128. The 128 offset will be added back in the loop.
58 for (int i = 0; i < 4; ++i) {
59 filter[i] = restoration_info.wiener_info.filter[direction][i];
60 }
61 if (direction == WienerInfo::kHorizontal) {
62 filter[3] -= 128;
63 }
64 }
65
// Accumulates (s0 + s1) * filter into |sum|. Used for the symmetric tap
// pairs of the Wiener filter.
inline int16x8_t WienerHorizontal2(const uint8x8_t s0, const uint8x8_t s1,
                                   const int16_t filter, const int16x8_t sum) {
  const int16x8_t pair_sum = vreinterpretq_s16_u16(vaddl_u8(s0, s1));
  return vmlaq_n_s16(sum, pair_sum, filter);
}
71
// 128-bit variant: applies the 64-bit kernel to the low and high halves of
// the inputs independently.
inline int16x8x2_t WienerHorizontal2(const uint8x16_t s0, const uint8x16_t s1,
                                     const int16_t filter,
                                     const int16x8x2_t sum) {
  int16x8x2_t result;
  result.val[0] =
      WienerHorizontal2(vget_low_u8(s0), vget_low_u8(s1), filter, sum.val[0]);
  result.val[1] =
      WienerHorizontal2(vget_high_u8(s0), vget_high_u8(s1), filter, sum.val[1]);
  return result;
}
82
// Adds the innermost three taps (symmetric pair s[0]/s[2] with filter[2] and
// center s[1] with the offset filter[3]) into |sum|, applies rounding and
// clamping to the intermediate range, and stores 8 values to |wiener_buffer|.
inline void WienerHorizontalSum(const uint8x8_t s[3], const int16_t filter[4],
                                int16x8_t sum, int16_t* const wiener_buffer) {
  constexpr int offset =
      1 << (8 + kWienerFilterBits - kInterRoundBitsHorizontal - 1);
  constexpr int limit = (offset << 2) - 1;
  const int16x8_t s_0_2 = vreinterpretq_s16_u16(vaddl_u8(s[0], s[2]));
  const int16x8_t s_1 = ZeroExtend(s[1]);
  sum = vmlaq_n_s16(sum, s_0_2, filter[2]);
  sum = vmlaq_n_s16(sum, s_1, filter[3]);
  // Calculate scaled down offset correction, and add to sum here to prevent
  // signed 16 bit outranging. The shifted center pixel restores the 128 that
  // was subtracted from filter[3] in PopulateWienerCoefficients().
  sum = vrsraq_n_s16(vshlq_n_s16(s_1, 7 - kInterRoundBitsHorizontal), sum,
                     kInterRoundBitsHorizontal);
  // Clamp to the valid range of the vertical pass input.
  sum = vmaxq_s16(sum, vdupq_n_s16(-offset));
  sum = vminq_s16(sum, vdupq_n_s16(limit - offset));
  vst1q_s16(wiener_buffer, sum);
}
100
// 128-bit variant: runs the 8-lane kernel twice, first on the low halves of
// the three windows, then on the high halves.
inline void WienerHorizontalSum(const uint8x16_t src[3],
                                const int16_t filter[4], int16x8x2_t sum,
                                int16_t* const wiener_buffer) {
  uint8x8_t half[3];
  half[0] = vget_low_u8(src[0]);
  half[1] = vget_low_u8(src[1]);
  half[2] = vget_low_u8(src[2]);
  WienerHorizontalSum(half, filter, sum.val[0], wiener_buffer);
  half[0] = vget_high_u8(src[0]);
  half[1] = vget_high_u8(src[1]);
  half[2] = vget_high_u8(src[2]);
  WienerHorizontalSum(half, filter, sum.val[1], wiener_buffer + 8);
}
114
// Horizontal 7-tap Wiener filter over |height| rows. |width| must be a
// multiple of 16; each inner iteration consumes 16 pixels and writes 16
// intermediate values, advancing |*wiener_buffer| past everything written.
inline void WienerHorizontalTap7(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const int16_t filter[4],
                                 int16_t** const wiener_buffer) {
  for (int y = height; y != 0; --y) {
    const uint8_t* src_ptr = src;
    uint8x16_t s[8];
    s[0] = vld1q_u8(src_ptr);
    ptrdiff_t x = width;
    do {
      src_ptr += 16;
      // Load the next 16 pixels; the 6 shifted windows are built with vextq.
      s[7] = vld1q_u8(src_ptr);
      s[1] = vextq_u8(s[0], s[7], 1);
      s[2] = vextq_u8(s[0], s[7], 2);
      s[3] = vextq_u8(s[0], s[7], 3);
      s[4] = vextq_u8(s[0], s[7], 4);
      s[5] = vextq_u8(s[0], s[7], 5);
      s[6] = vextq_u8(s[0], s[7], 6);
      int16x8x2_t sum;
      sum.val[0] = sum.val[1] = vdupq_n_s16(0);
      // The filter is symmetric: taps 0/6 and 1/5 share coefficients.
      sum = WienerHorizontal2(s[0], s[6], filter[0], sum);
      sum = WienerHorizontal2(s[1], s[5], filter[1], sum);
      WienerHorizontalSum(s + 2, filter, sum, *wiener_buffer);
      // Reuse the freshly loaded vector as the next window base.
      s[0] = s[7];
      *wiener_buffer += 16;
      x -= 16;
    } while (x != 0);
    src += src_stride;
  }
}
145
// Horizontal 5-tap Wiener filter (leading coefficient is zero). Same
// structure as WienerHorizontalTap7() but with one fewer symmetric pair.
inline void WienerHorizontalTap5(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const int16_t filter[4],
                                 int16_t** const wiener_buffer) {
  for (int y = height; y != 0; --y) {
    const uint8_t* src_ptr = src;
    uint8x16_t s[6];
    s[0] = vld1q_u8(src_ptr);
    ptrdiff_t x = width;
    do {
      src_ptr += 16;
      // Next 16 pixels; 4 shifted windows cover the remaining taps.
      s[5] = vld1q_u8(src_ptr);
      s[1] = vextq_u8(s[0], s[5], 1);
      s[2] = vextq_u8(s[0], s[5], 2);
      s[3] = vextq_u8(s[0], s[5], 3);
      s[4] = vextq_u8(s[0], s[5], 4);
      int16x8x2_t sum;
      sum.val[0] = sum.val[1] = vdupq_n_s16(0);
      // Outermost symmetric pair uses filter[1]; the inner three taps are
      // handled by WienerHorizontalSum().
      sum = WienerHorizontal2(s[0], s[4], filter[1], sum);
      WienerHorizontalSum(s + 1, filter, sum, *wiener_buffer);
      s[0] = s[5];
      *wiener_buffer += 16;
      x -= 16;
    } while (x != 0);
    src += src_stride;
  }
}
173
// Horizontal 3-tap Wiener filter (two leading coefficients are zero). Only
// the inner three taps remain, so the accumulator starts at zero and
// WienerHorizontalSum() does all the work.
inline void WienerHorizontalTap3(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const int16_t filter[4],
                                 int16_t** const wiener_buffer) {
  for (int y = height; y != 0; --y) {
    const uint8_t* src_ptr = src;
    uint8x16_t s[4];
    s[0] = vld1q_u8(src_ptr);
    ptrdiff_t x = width;
    do {
      src_ptr += 16;
      s[3] = vld1q_u8(src_ptr);
      s[1] = vextq_u8(s[0], s[3], 1);
      s[2] = vextq_u8(s[0], s[3], 2);
      int16x8x2_t sum;
      sum.val[0] = sum.val[1] = vdupq_n_s16(0);
      WienerHorizontalSum(s, filter, sum, *wiener_buffer);
      s[0] = s[3];
      *wiener_buffer += 16;
      x -= 16;
    } while (x != 0);
    src += src_stride;
  }
}
198
// Degenerate horizontal filter: only the center tap remains, so each pixel
// is just widened to 16 bits and scaled by 1 << 4.
inline void WienerHorizontalTap1(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 int16_t** const wiener_buffer) {
  for (int y = 0; y < height; ++y) {
    const uint8_t* row = src;
    for (ptrdiff_t x = width; x != 0; x -= 16) {
      const uint8x16_t pixels = vld1q_u8(row);
      const int16x8_t lo =
          vreinterpretq_s16_u16(vshll_n_u8(vget_low_u8(pixels), 4));
      const int16x8_t hi =
          vreinterpretq_s16_u16(vshll_n_u8(vget_high_u8(pixels), 4));
      vst1q_s16(*wiener_buffer + 0, lo);
      vst1q_s16(*wiener_buffer + 8, hi);
      row += 16;
      *wiener_buffer += 16;
    }
    src += src_stride;
  }
}
220
// Accumulates (a0 + a1) * filter into the 32-bit lanes of |sum|. Used for
// the symmetric tap pairs of the vertical Wiener filter.
inline int32x4x2_t WienerVertical2(const int16x8_t a0, const int16x8_t a1,
                                   const int16_t filter,
                                   const int32x4x2_t sum) {
  const int16x8_t pair = vaddq_s16(a0, a1);
  int32x4x2_t result;
  result.val[0] = vmlal_n_s16(sum.val[0], vget_low_s16(pair), filter);
  result.val[1] = vmlal_n_s16(sum.val[1], vget_high_s16(pair), filter);
  return result;
}
230
// Finishes the vertical filter for 8 pixels: adds the symmetric pair
// a[0]/a[2] (filter[2]) and the center row a[1] (filter[3]) to |sum|, then
// rounds, narrows and saturates to 8-bit pixels.
inline uint8x8_t WienerVertical(const int16x8_t a[3], const int16_t filter[4],
                                const int32x4x2_t sum) {
  int32x4x2_t d = WienerVertical2(a[0], a[2], filter[2], sum);
  d.val[0] = vmlal_n_s16(d.val[0], vget_low_s16(a[1]), filter[3]);
  d.val[1] = vmlal_n_s16(d.val[1], vget_high_s16(a[1]), filter[3]);
  // Rounded narrowing shift by 11 — presumably the vertical rounding bits
  // plus the scaling left over from the horizontal pass; confirm against
  // kInterRoundBitsVertical.
  const uint16x4_t sum_lo_16 = vqrshrun_n_s32(d.val[0], 11);
  const uint16x4_t sum_hi_16 = vqrshrun_n_s32(d.val[1], 11);
  return vqmovn_u16(vcombine_u16(sum_lo_16, sum_hi_16));
}
240
// Applies the symmetric 7-tap vertical filter to 8 columns. All seven loaded
// rows are also returned through |a| so a caller computing the next output
// row can reuse them.
inline uint8x8_t WienerVerticalTap7Kernel(const int16_t* const wiener_buffer,
                                          const ptrdiff_t wiener_stride,
                                          const int16_t filter[4],
                                          int16x8_t a[7]) {
  int32x4x2_t sum;
  a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride);
  a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride);
  a[5] = vld1q_s16(wiener_buffer + 5 * wiener_stride);
  a[6] = vld1q_s16(wiener_buffer + 6 * wiener_stride);
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  // Outer symmetric pairs: rows 0/6 and 1/5.
  sum = WienerVertical2(a[0], a[6], filter[0], sum);
  sum = WienerVertical2(a[1], a[5], filter[1], sum);
  a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride);
  a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
  a[4] = vld1q_s16(wiener_buffer + 4 * wiener_stride);
  // Inner three rows are handled by WienerVertical().
  return WienerVertical(a + 2, filter, sum);
}
258
// Computes two vertically adjacent output rows. The second row reuses rows
// 1..6 already loaded into |a| by the first kernel call and only loads row 7.
inline uint8x8x2_t WienerVerticalTap7Kernel2(const int16_t* const wiener_buffer,
                                             const ptrdiff_t wiener_stride,
                                             const int16_t filter[4]) {
  int16x8_t a[8];
  int32x4x2_t sum;
  uint8x8x2_t d;
  d.val[0] = WienerVerticalTap7Kernel(wiener_buffer, wiener_stride, filter, a);
  a[7] = vld1q_s16(wiener_buffer + 7 * wiener_stride);
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  // Same filter, window shifted down one row: pairs 1/7 and 2/6.
  sum = WienerVertical2(a[1], a[7], filter[0], sum);
  sum = WienerVertical2(a[2], a[6], filter[1], sum);
  d.val[1] = WienerVertical(a + 3, filter, sum);
  return d;
}
273
// Vertical 7-tap pass over the whole unit. |width| (a multiple of 16) is
// used both as the row width and as the |wiener_buffer| stride. Two output
// rows are produced per iteration to share loads; an odd final row is
// handled separately.
inline void WienerVerticalTap7(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t filter[4], uint8_t* dst,
                               const ptrdiff_t dst_stride) {
  for (int y = height >> 1; y != 0; --y) {
    uint8_t* dst_ptr = dst;
    ptrdiff_t x = width;
    do {
      uint8x8x2_t d[2];
      d[0] = WienerVerticalTap7Kernel2(wiener_buffer + 0, width, filter);
      d[1] = WienerVerticalTap7Kernel2(wiener_buffer + 8, width, filter);
      vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0]));
      vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1]));
      wiener_buffer += 16;
      dst_ptr += 16;
      x -= 16;
    } while (x != 0);
    // Advance one extra buffer row: two output rows were consumed.
    wiener_buffer += width;
    dst += 2 * dst_stride;
  }

  if ((height & 1) != 0) {
    // Remaining single row uses the one-row kernel.
    ptrdiff_t x = width;
    do {
      int16x8_t a[7];
      const uint8x8_t d0 =
          WienerVerticalTap7Kernel(wiener_buffer + 0, width, filter, a);
      const uint8x8_t d1 =
          WienerVerticalTap7Kernel(wiener_buffer + 8, width, filter, a);
      vst1q_u8(dst, vcombine_u8(d0, d1));
      wiener_buffer += 16;
      dst += 16;
      x -= 16;
    } while (x != 0);
  }
}
310
// Applies the symmetric 5-tap vertical filter to 8 columns; the five loaded
// rows are returned through |a| for reuse by the two-row kernel.
inline uint8x8_t WienerVerticalTap5Kernel(const int16_t* const wiener_buffer,
                                          const ptrdiff_t wiener_stride,
                                          const int16_t filter[4],
                                          int16x8_t a[5]) {
  a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride);
  a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride);
  a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride);
  a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
  a[4] = vld1q_s16(wiener_buffer + 4 * wiener_stride);
  int32x4x2_t sum;
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  // Outer symmetric pair 0/4; inner three rows via WienerVertical().
  sum = WienerVertical2(a[0], a[4], filter[1], sum);
  return WienerVertical(a + 1, filter, sum);
}
325
// Computes two vertically adjacent output rows; the second reuses rows 1..4
// from |a| and loads only row 5.
inline uint8x8x2_t WienerVerticalTap5Kernel2(const int16_t* const wiener_buffer,
                                             const ptrdiff_t wiener_stride,
                                             const int16_t filter[4]) {
  int16x8_t a[6];
  int32x4x2_t sum;
  uint8x8x2_t d;
  d.val[0] = WienerVerticalTap5Kernel(wiener_buffer, wiener_stride, filter, a);
  a[5] = vld1q_s16(wiener_buffer + 5 * wiener_stride);
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  // Window shifted down one row: outer pair becomes 1/5.
  sum = WienerVertical2(a[1], a[5], filter[1], sum);
  d.val[1] = WienerVertical(a + 2, filter, sum);
  return d;
}
339
// Vertical 5-tap pass, same two-rows-per-iteration structure as
// WienerVerticalTap7(). |width| doubles as the |wiener_buffer| stride.
inline void WienerVerticalTap5(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t filter[4], uint8_t* dst,
                               const ptrdiff_t dst_stride) {
  for (int y = height >> 1; y != 0; --y) {
    uint8_t* dst_ptr = dst;
    ptrdiff_t x = width;
    do {
      uint8x8x2_t d[2];
      d[0] = WienerVerticalTap5Kernel2(wiener_buffer + 0, width, filter);
      d[1] = WienerVerticalTap5Kernel2(wiener_buffer + 8, width, filter);
      vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0]));
      vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1]));
      wiener_buffer += 16;
      dst_ptr += 16;
      x -= 16;
    } while (x != 0);
    // Skip one extra buffer row; two output rows were produced.
    wiener_buffer += width;
    dst += 2 * dst_stride;
  }

  if ((height & 1) != 0) {
    // Odd trailing row.
    ptrdiff_t x = width;
    do {
      int16x8_t a[5];
      const uint8x8_t d0 =
          WienerVerticalTap5Kernel(wiener_buffer + 0, width, filter, a);
      const uint8x8_t d1 =
          WienerVerticalTap5Kernel(wiener_buffer + 8, width, filter, a);
      vst1q_u8(dst, vcombine_u8(d0, d1));
      wiener_buffer += 16;
      dst += 16;
      x -= 16;
    } while (x != 0);
  }
}
376
// Applies the 3-tap vertical filter to 8 columns. Only the inner three rows
// participate, so the accumulator starts at zero and WienerVertical() does
// all the arithmetic. The loaded rows are returned through |a|.
inline uint8x8_t WienerVerticalTap3Kernel(const int16_t* const wiener_buffer,
                                          const ptrdiff_t wiener_stride,
                                          const int16_t filter[4],
                                          int16x8_t a[3]) {
  int32x4x2_t sum;
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  a[0] = vld1q_s16(wiener_buffer);
  a[1] = vld1q_s16(wiener_buffer + wiener_stride);
  a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride);
  return WienerVertical(a, filter, sum);
}
388
// Computes two vertically adjacent output rows; the second reuses rows 1..2
// from |a| and loads only row 3.
inline uint8x8x2_t WienerVerticalTap3Kernel2(const int16_t* const wiener_buffer,
                                             const ptrdiff_t wiener_stride,
                                             const int16_t filter[4]) {
  int16x8_t a[4];
  int32x4x2_t sum;
  uint8x8x2_t d;
  d.val[0] = WienerVerticalTap3Kernel(wiener_buffer, wiener_stride, filter, a);
  a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  d.val[1] = WienerVertical(a + 1, filter, sum);
  return d;
}
401
// Vertical 3-tap pass, same two-rows-per-iteration structure as the 7- and
// 5-tap variants. |width| doubles as the |wiener_buffer| stride.
inline void WienerVerticalTap3(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t filter[4], uint8_t* dst,
                               const ptrdiff_t dst_stride) {
  for (int y = height >> 1; y != 0; --y) {
    uint8_t* dst_ptr = dst;
    ptrdiff_t x = width;
    do {
      uint8x8x2_t d[2];
      d[0] = WienerVerticalTap3Kernel2(wiener_buffer + 0, width, filter);
      d[1] = WienerVerticalTap3Kernel2(wiener_buffer + 8, width, filter);
      vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0]));
      vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1]));
      wiener_buffer += 16;
      dst_ptr += 16;
      x -= 16;
    } while (x != 0);
    // Skip one extra buffer row; two output rows were produced.
    wiener_buffer += width;
    dst += 2 * dst_stride;
  }

  if ((height & 1) != 0) {
    // Odd trailing row.
    ptrdiff_t x = width;
    do {
      int16x8_t a[3];
      const uint8x8_t d0 =
          WienerVerticalTap3Kernel(wiener_buffer + 0, width, filter, a);
      const uint8x8_t d1 =
          WienerVerticalTap3Kernel(wiener_buffer + 8, width, filter, a);
      vst1q_u8(dst, vcombine_u8(d0, d1));
      wiener_buffer += 16;
      dst += 16;
      x -= 16;
    } while (x != 0);
  }
}
438
// Converts 16 intermediate values back to pixels: rounded shift by 4 undoes
// the horizontal pass scaling, with saturation to 8 bits.
inline void WienerVerticalTap1Kernel(const int16_t* const wiener_buffer,
                                     uint8_t* const dst) {
  const int16x8_t lo = vld1q_s16(wiener_buffer);
  const int16x8_t hi = vld1q_s16(wiener_buffer + 8);
  const uint8x16_t pixels =
      vcombine_u8(vqrshrun_n_s16(lo, 4), vqrshrun_n_s16(hi, 4));
  vst1q_u8(dst, pixels);
}
447
// Degenerate vertical filter (center tap only): rescale every intermediate
// value to a pixel. Two rows per iteration, plus an odd trailing row.
inline void WienerVerticalTap1(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               uint8_t* dst, const ptrdiff_t dst_stride) {
  for (int y = height >> 1; y != 0; --y) {
    uint8_t* dst_ptr = dst;
    ptrdiff_t x = width;
    do {
      WienerVerticalTap1Kernel(wiener_buffer, dst_ptr);
      WienerVerticalTap1Kernel(wiener_buffer + width, dst_ptr + dst_stride);
      wiener_buffer += 16;
      dst_ptr += 16;
      x -= 16;
    } while (x != 0);
    // Skip the second row that was just written.
    wiener_buffer += width;
    dst += 2 * dst_stride;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = width;
    do {
      WienerVerticalTap1Kernel(wiener_buffer, dst);
      wiener_buffer += 16;
      dst += 16;
      x -= 16;
    } while (x != 0);
  }
}
475
476 // For width 16 and up, store the horizontal results, and then do the vertical
477 // filter row by row. This is faster than doing it column by column when
478 // considering cache issues.
// Applies the Wiener filter to a |width| x |height| restoration unit:
// a horizontal pass into |restoration_buffer->wiener_buffer| followed by a
// vertical pass into |dest|. The number of leading zero coefficients in each
// direction selects a 7/5/3/1-tap specialization.
void WienerFilter_NEON(const RestorationUnitInfo& restoration_info,
                       const void* const source, const void* const top_border,
                       const void* const bottom_border, const ptrdiff_t stride,
                       const int width, const int height,
                       RestorationBuffer* const restoration_buffer,
                       void* const dest) {
  const int16_t* const number_leading_zero_coefficients =
      restoration_info.wiener_info.number_leading_zero_coefficients;
  // At least one row is always skipped; the top/bottom wiener_buffer rows
  // are duplicated below when the full 7-tap vertical filter is used.
  const int number_rows_to_skip = std::max(
      static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]),
      1);
  const ptrdiff_t wiener_stride = Align(width, 16);
  int16_t* const wiener_buffer_vertical = restoration_buffer->wiener_buffer;
  // The values are saturated to 13 bits before storing.
  int16_t* wiener_buffer_horizontal =
      wiener_buffer_vertical + number_rows_to_skip * wiener_stride;
  int16_t filter_horizontal[(kWienerFilterTaps + 1) / 2];
  int16_t filter_vertical[(kWienerFilterTaps + 1) / 2];
  PopulateWienerCoefficients(restoration_info, WienerInfo::kHorizontal,
                             filter_horizontal);
  PopulateWienerCoefficients(restoration_info, WienerInfo::kVertical,
                             filter_vertical);

  // horizontal filtering.
  // Over-reads up to 15 - |kRestorationHorizontalBorder| values.
  // |height_extra| border rows are filtered above and below the unit so the
  // vertical pass has its full support.
  const int height_horizontal =
      height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip;
  const int height_extra = (height_horizontal - height) >> 1;
  assert(height_extra <= 2);
  const auto* const src = static_cast<const uint8_t*>(source);
  const auto* const top = static_cast<const uint8_t*>(top_border);
  const auto* const bottom = static_cast<const uint8_t*>(bottom_border);
  // Each case shifts the source pointer left by the filter radius so the
  // taps are centered on the output pixel.
  if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) {
    WienerHorizontalTap7(top + (2 - height_extra) * stride - 3, stride,
                         wiener_stride, height_extra, filter_horizontal,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap7(src - 3, stride, wiener_stride, height,
                         filter_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap7(bottom - 3, stride, wiener_stride, height_extra,
                         filter_horizontal, &wiener_buffer_horizontal);
  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) {
    WienerHorizontalTap5(top + (2 - height_extra) * stride - 2, stride,
                         wiener_stride, height_extra, filter_horizontal,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap5(src - 2, stride, wiener_stride, height,
                         filter_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap5(bottom - 2, stride, wiener_stride, height_extra,
                         filter_horizontal, &wiener_buffer_horizontal);
  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) {
    // The maximum over-reads happen here.
    WienerHorizontalTap3(top + (2 - height_extra) * stride - 1, stride,
                         wiener_stride, height_extra, filter_horizontal,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap3(src - 1, stride, wiener_stride, height,
                         filter_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap3(bottom - 1, stride, wiener_stride, height_extra,
                         filter_horizontal, &wiener_buffer_horizontal);
  } else {
    assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3);
    WienerHorizontalTap1(top + (2 - height_extra) * stride, stride,
                         wiener_stride, height_extra,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap1(src, stride, wiener_stride, height,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap1(bottom, stride, wiener_stride, height_extra,
                         &wiener_buffer_horizontal);
  }

  // vertical filtering.
  // Over-writes up to 15 values.
  auto* dst = static_cast<uint8_t*>(dest);
  if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) {
    // Because the top row of |source| is a duplicate of the second row, and the
    // bottom row of |source| is a duplicate of its above row, we can duplicate
    // the top and bottom row of |wiener_buffer| accordingly.
    memcpy(wiener_buffer_horizontal, wiener_buffer_horizontal - wiener_stride,
           sizeof(*wiener_buffer_horizontal) * wiener_stride);
    memcpy(restoration_buffer->wiener_buffer,
           restoration_buffer->wiener_buffer + wiener_stride,
           sizeof(*restoration_buffer->wiener_buffer) * wiener_stride);
    WienerVerticalTap7(wiener_buffer_vertical, wiener_stride, height,
                       filter_vertical, dst, stride);
  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) {
    WienerVerticalTap5(wiener_buffer_vertical + wiener_stride, wiener_stride,
                       height, filter_vertical, dst, stride);
  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) {
    WienerVerticalTap3(wiener_buffer_vertical + 2 * wiener_stride,
                       wiener_stride, height, filter_vertical, dst, stride);
  } else {
    assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3);
    WienerVerticalTap1(wiener_buffer_vertical + 3 * wiener_stride,
                       wiener_stride, height, dst, stride);
  }
}
573
574 //------------------------------------------------------------------------------
575 // SGR
576
// Fills |dst| with three 8-lane windows of |src|, offset by 0, 1 and 2 bytes.
inline void Prepare3_8(const uint8x8x2_t src, uint8x8_t dst[3]) {
  dst[0] = VshrU128<0>(src);
  dst[1] = VshrU128<1>(src);
  dst[2] = VshrU128<2>(src);
}
582
// Three 8-lane windows of |src|, offset by 0, 1 and 2 (16-bit) elements,
// each split into its low and high 4-lane halves.
inline void Prepare3_16(const uint16x8x2_t src, uint16x4_t low[3],
                        uint16x4_t high[3]) {
  const uint16x8_t s0 = VshrU128<0>(src);
  const uint16x8_t s1 = VshrU128<2>(src);
  const uint16x8_t s2 = VshrU128<4>(src);
  low[0] = vget_low_u16(s0);
  high[0] = vget_high_u16(s0);
  low[1] = vget_low_u16(s1);
  high[1] = vget_high_u16(s1);
  low[2] = vget_low_u16(s2);
  high[2] = vget_high_u16(s2);
}
596
// Five byte-offset windows of |src|; the first three are shared with
// Prepare3_8().
inline void Prepare5_8(const uint8x8x2_t src, uint8x8_t dst[5]) {
  Prepare3_8(src, dst);
  dst[3] = VshrU128<3>(src);
  dst[4] = VshrU128<4>(src);
}
604
// Five element-offset windows of |src| (16-bit lanes), split into low and
// high halves; the first three come from Prepare3_16().
inline void Prepare5_16(const uint16x8x2_t src, uint16x4_t low[5],
                        uint16x4_t high[5]) {
  Prepare3_16(src, low, high);
  const uint16x8_t s3 = VshrU128<6>(src);
  low[3] = vget_low_u16(s3);
  high[3] = vget_high_u16(s3);
  const uint16x8_t s4 = VshrU128<8>(src);
  low[4] = vget_low_u16(s4);
  high[4] = vget_high_u16(s4);
}
615
// src0 + src1 + src2 per 16-bit lane.
inline uint16x8_t Sum3_16(const uint16x8_t src0, const uint16x8_t src1,
                          const uint16x8_t src2) {
  return vaddq_u16(vaddq_u16(src0, src1), src2);
}
621
// Array form of the three-way 16-bit lane sum.
inline uint16x8_t Sum3_16(const uint16x8_t src[3]) {
  return vaddq_u16(vaddq_u16(src[0], src[1]), src[2]);
}
625
// src0 + src1 + src2 per 32-bit lane.
inline uint32x4_t Sum3_32(const uint32x4_t src0, const uint32x4_t src1,
                          const uint32x4_t src2) {
  return vaddq_u32(vaddq_u32(src0, src1), src2);
}
631
// Three-way sum applied to both 128-bit halves of each element.
inline uint32x4x2_t Sum3_32(const uint32x4x2_t src[3]) {
  uint32x4x2_t sum;
  for (int i = 0; i < 2; ++i) {
    sum.val[i] = Sum3_32(src[0].val[i], src[1].val[i], src[2].val[i]);
  }
  return sum;
}
638
// Widening three-way sum: 8-bit inputs, 16-bit result per lane.
inline uint16x8_t Sum3W_16(const uint8x8_t src[3]) {
  return vaddw_u8(vaddl_u8(src[0], src[1]), src[2]);
}
643
// Widening three-way sum: 16-bit inputs, 32-bit result per lane.
inline uint32x4_t Sum3W_32(const uint16x4_t src[3]) {
  return vaddw_u16(vaddl_u16(src[0], src[1]), src[2]);
}
648
// Five-way 16-bit lane sum, paired to shorten the dependency chain.
inline uint16x8_t Sum5_16(const uint16x8_t src[5]) {
  const uint16x8_t sum0123 =
      vaddq_u16(vaddq_u16(src[0], src[1]), vaddq_u16(src[2], src[3]));
  return vaddq_u16(sum0123, src[4]);
}
655
// Five-way 32-bit lane sum, paired to shorten the dependency chain.
inline uint32x4_t Sum5_32(const uint32x4_t src0, const uint32x4_t src1,
                          const uint32x4_t src2, const uint32x4_t src3,
                          const uint32x4_t src4) {
  const uint32x4_t sum0123 =
      vaddq_u32(vaddq_u32(src0, src1), vaddq_u32(src2, src3));
  return vaddq_u32(sum0123, src4);
}
664
// Five-way sum applied to both 128-bit halves of each element.
inline uint32x4x2_t Sum5_32(const uint32x4x2_t src[5]) {
  uint32x4x2_t sum;
  for (int i = 0; i < 2; ++i) {
    sum.val[i] = Sum5_32(src[0].val[i], src[1].val[i], src[2].val[i],
                         src[3].val[i], src[4].val[i]);
  }
  return sum;
}
673
// Widening five-way sum: 16-bit inputs, 32-bit result per lane.
inline uint32x4_t Sum5W_32(const uint16x4_t src[5]) {
  const uint32x4_t sum01 = vaddl_u16(src[0], src[1]);
  const uint32x4_t sum23 = vaddl_u16(src[2], src[3]);
  return vaddw_u16(vaddq_u32(sum01, sum23), src[4]);
}
680
// Sum of each 3 horizontally adjacent pixels, widened to 16 bits.
inline uint16x8_t Sum3Horizontal(const uint8x8x2_t src) {
  uint8x8_t window[3];
  Prepare3_8(src, window);
  return Sum3W_16(window);
}
686
// Sum of each 3 horizontally adjacent 16-bit values, widened to 32 bits.
inline uint32x4x2_t Sum3WHorizontal(const uint16x8x2_t src) {
  uint16x4_t low[3], high[3];
  Prepare3_16(src, low, high);
  uint32x4x2_t sum;
  sum.val[0] = Sum3W_32(low);
  sum.val[1] = Sum3W_32(high);
  return sum;
}
695
// Sum of each 5 horizontally adjacent pixels, widened to 16 bits.
inline uint16x8_t Sum5Horizontal(const uint8x8x2_t src) {
  uint8x8_t window[5];
  Prepare5_8(src, window);
  const uint16x8_t sum0123 = vaddq_u16(vaddl_u8(window[0], window[1]),
                                       vaddl_u8(window[2], window[3]));
  return vaddw_u8(sum0123, window[4]);
}
704
// Sum of each 5 horizontally adjacent 16-bit values, widened to 32 bits.
inline uint32x4x2_t Sum5WHorizontal(const uint16x8x2_t src) {
  uint16x4_t low[5], high[5];
  Prepare5_16(src, low, high);
  uint32x4x2_t sum;
  sum.val[0] = Sum5W_32(low);
  sum.val[1] = Sum5W_32(high);
  return sum;
}
713
// Computes both the 3-tap sum (src[1..3]) and the 5-tap sum (src[0..4]) of
// squared values; the 3-tap result is reused inside the 5-tap one.
void SumHorizontal(const uint16x4_t src[5], uint32x4_t* const row_sq3,
                   uint32x4_t* const row_sq5) {
  *row_sq3 = vaddw_u16(vaddl_u16(src[1], src[2]), src[3]);
  *row_sq5 = vaddq_u32(vaddl_u16(src[0], src[4]), *row_sq3);
}
721
// For 8 horizontal positions, computes the 3-tap and 5-tap sums of the
// pixels (|row3|, |row5|) and of their squares (|row_sq3|, |row_sq5|).
// The 3-tap sums are reused when forming the 5-tap sums.
void SumHorizontal(const uint8x8x2_t src, const uint16x8x2_t sq,
                   uint16x8_t* const row3, uint16x8_t* const row5,
                   uint32x4x2_t* const row_sq3, uint32x4x2_t* const row_sq5) {
  uint8x8_t s[5];
  Prepare5_8(src, s);
  const uint16x8_t sum04 = vaddl_u8(s[0], s[4]);
  const uint16x8_t sum12 = vaddl_u8(s[1], s[2]);
  *row3 = vaddw_u8(sum12, s[3]);
  *row5 = vaddq_u16(sum04, *row3);
  uint16x4_t low[5], high[5];
  Prepare5_16(sq, low, high);
  SumHorizontal(low, &row_sq3->val[0], &row_sq5->val[0]);
  SumHorizontal(high, &row_sq3->val[1], &row_sq5->val[1]);
}
736
// Weighted sum with weights {3, 4, 3}: 3 * (s[0] + s[1] + s[2]) + s[1].
inline uint16x8_t Sum343(const uint8x8x2_t src) {
  uint8x8_t s[3];
  Prepare3_8(src, s);
  const uint16x8_t sum = Sum3W_16(s);
  const uint16x8_t sum3 = Sum3_16(sum, sum, sum);  // 3 * sum.
  return vaddw_u8(sum3, s[1]);  // Extra center weight.
}
744
// 32-bit variant of the {3, 4, 3} weighted sum.
inline uint32x4_t Sum343W(const uint16x4_t src[3]) {
  const uint32x4_t sum = Sum3W_32(src);
  const uint32x4_t sum3 = Sum3_32(sum, sum, sum);  // 3 * sum.
  return vaddw_u16(sum3, src[1]);  // Extra center weight.
}
750
// {3, 4, 3} weighted sum over 8 positions, split into low/high halves.
inline uint32x4x2_t Sum343W(const uint16x8x2_t src) {
  uint16x4_t low[3], high[3];
  Prepare3_16(src, low, high);
  uint32x4x2_t d;
  d.val[0] = Sum343W(low);
  d.val[1] = Sum343W(high);
  return d;
}
759
// Weighted sum with weights {5, 6, 5}: 5 * (s[0] + s[1] + s[2]) + s[1].
inline uint16x8_t Sum565(const uint8x8x2_t src) {
  uint8x8_t s[3];
  Prepare3_8(src, s);
  const uint16x8_t sum = Sum3W_16(s);
  const uint16x8_t sum4 = vshlq_n_u16(sum, 2);   // 4 * sum.
  const uint16x8_t sum5 = vaddq_u16(sum4, sum);  // 5 * sum.
  return vaddw_u8(sum5, s[1]);  // Extra center weight.
}
768
// 32-bit (5, 6, 5) weighted sum of three 16-bit columns.
inline uint32x4_t Sum565W(const uint16x4_t src[3]) {
  const uint32x4_t s111 = Sum3W_32(src);
  // 5 * x as (x << 2) + x.
  const uint32x4_t s444 = vshlq_n_u32(s111, 2);
  const uint32x4_t s555 = vaddq_u32(s444, s111);
  return vaddw_u16(s555, src[1]);
}
775
// Applies the 4-lane Sum565W() to the low and high halves of |src|.
inline uint32x4x2_t Sum565W(const uint16x8x2_t src) {
  uint16x4_t lo[3], hi[3];
  Prepare3_16(src, lo, hi);
  uint32x4x2_t result;
  result.val[0] = Sum565W(lo);
  result.val[1] = Sum565W(hi);
  return result;
}
784
// Fills one row of both the 3-tap and 5-tap horizontal box-sum buffers
// (pixels and squared pixels) across the whole width. Used when both SGR
// passes run, so a single sweep over |src| feeds all four output buffers.
inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride,
                   const int height, const ptrdiff_t sum_stride, uint16_t* sum3,
                   uint16_t* sum5, uint32_t* square_sum3,
                   uint32_t* square_sum5) {
  int y = height;
  do {
    uint8x8x2_t s;
    uint16x8x2_t sq;
    s.val[0] = vld1_u8(src);
    sq.val[0] = vmull_u8(s.val[0], s.val[0]);
    ptrdiff_t x = 0;
    do {
      uint16x8_t row3, row5;
      uint32x4x2_t row_sq3, row_sq5;
      // Load the next 8 pixels; together with the previous 8 they form the
      // sliding window the 5-tap horizontal sum needs.
      s.val[1] = vld1_u8(src + x + 8);
      sq.val[1] = vmull_u8(s.val[1], s.val[1]);
      SumHorizontal(s, sq, &row3, &row5, &row_sq3, &row_sq5);
      vst1q_u16(sum3, row3);
      vst1q_u16(sum5, row5);
      vst1q_u32(square_sum3 + 0, row_sq3.val[0]);
      vst1q_u32(square_sum3 + 4, row_sq3.val[1]);
      vst1q_u32(square_sum5 + 0, row_sq5.val[0]);
      vst1q_u32(square_sum5 + 4, row_sq5.val[1]);
      // Slide the window: the high half becomes the new low half.
      s.val[0] = s.val[1];
      sq.val[0] = sq.val[1];
      sum3 += 8;
      sum5 += 8;
      square_sum3 += 8;
      square_sum5 += 8;
      x += 8;
    } while (x < sum_stride);
    src += src_stride;
  } while (--y != 0);
}
819
// Fills one row of horizontal box sums (pixels and squared pixels) for a
// single box size. |size| (3 or 5) selects the horizontal sum width; the
// branch below is resolved at compile time.
template <int size>
inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride,
                   const int height, const ptrdiff_t sum_stride, uint16_t* sums,
                   uint32_t* square_sums) {
  static_assert(size == 3 || size == 5, "");
  int y = height;
  do {
    uint8x8x2_t s;
    uint16x8x2_t sq;
    s.val[0] = vld1_u8(src);
    sq.val[0] = vmull_u8(s.val[0], s.val[0]);
    ptrdiff_t x = 0;
    do {
      uint16x8_t row;
      uint32x4x2_t row_sq;
      // Extend the sliding window by the next 8 pixels.
      s.val[1] = vld1_u8(src + x + 8);
      sq.val[1] = vmull_u8(s.val[1], s.val[1]);
      if (size == 3) {
        row = Sum3Horizontal(s);
        row_sq = Sum3WHorizontal(sq);
      } else {
        row = Sum5Horizontal(s);
        row_sq = Sum5WHorizontal(sq);
      }
      vst1q_u16(sums, row);
      vst1q_u32(square_sums + 0, row_sq.val[0]);
      vst1q_u32(square_sums + 4, row_sq.val[1]);
      // Slide the window: the high half becomes the new low half.
      s.val[0] = s.val[1];
      sq.val[0] = sq.val[1];
      sums += 8;
      square_sums += 8;
      x += 8;
    } while (x < sum_stride);
    src += src_stride;
  } while (--y != 0);
}
856
// Computes the quantized variance index fed to the |ma| lookup table:
//   p = max(sum_sq * n - sum * sum, 0); result = (p * scale) >> 20.
template <int n>
inline uint16x4_t CalculateMa(const uint16x4_t sum, const uint32x4_t sum_sq,
                              const uint32_t scale) {
  // a = |sum_sq|, d = |sum|
  // p = (a * n < d * d) ? 0 : a * n - d * d;
  const uint32x4_t d_squared = vmull_u16(sum, sum);
  const uint32x4_t a_times_n = vmulq_n_u32(sum_sq, n);
  // Saturating subtraction clamps a*n - d*d at zero instead of wrapping.
  const uint32x4_t p = vqsubq_u32(a_times_n, d_squared);
  const uint32x4_t scaled = vmulq_n_u32(p, scale);
  // The narrowing rounded shift vrshrn_n_u32() can shift by at most 16, but
  // kSgrProjScaleBits is 20, so shift in-type first and narrow separately.
  const uint32x4_t rounded = vrshrq_n_u32(scaled, kSgrProjScaleBits);
  return vmovn_u32(rounded);
}
873
// Converts box sums into the SGR intermediates: |ma|, a per-pixel weight
// fetched from kSgrMaLookup using the quantized variance, and
// |b| = ma * sum * one_over_n, the weighted mean. |n| is the box pixel
// count (25 for pass 1, 9 for pass 2).
template <int n>
inline void CalculateIntermediate(const uint16x8_t sum,
                                  const uint32x4x2_t sum_sq,
                                  const uint32_t scale, uint8x8_t* const ma,
                                  uint16x8_t* const b) {
  // Rounded reciprocal of n in kSgrProjReciprocalBits fixed point.
  constexpr uint32_t one_over_n =
      ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n;
  const uint16x4_t z0 = CalculateMa<n>(vget_low_u16(sum), sum_sq.val[0], scale);
  const uint16x4_t z1 =
      CalculateMa<n>(vget_high_u16(sum), sum_sq.val[1], scale);
  const uint16x8_t z01 = vcombine_u16(z0, z1);
  // Clamp the table index to the last entry (255).
  // Using vqmovn_u16() needs an extra sign extension instruction.
  const uint16x8_t z = vminq_u16(z01, vdupq_n_u16(255));
  // Gather is done lane by lane on the general-purpose registers.
  // Using vgetq_lane_s16() can save the sign extension instruction.
  const uint8_t lookup[8] = {
      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 0)],
      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 1)],
      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 2)],
      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 3)],
      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 4)],
      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 5)],
      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 6)],
      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 7)]};
  *ma = vld1_u8(lookup);
  // b = ma * b * one_over_n
  // |ma| = [0, 255]
  // |sum| is a box sum with radius 1 or 2.
  // For the first pass radius is 2. Maximum value is 5x5x255 = 6375.
  // For the second pass radius is 1. Maximum value is 3x3x255 = 2295.
  // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
  // When radius is 2 |n| is 25. |one_over_n| is 164.
  // When radius is 1 |n| is 9. |one_over_n| is 455.
  // |kSgrProjReciprocalBits| is 12.
  // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits).
  // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits).
  const uint16x8_t maq = vmovl_u8(*ma);
  const uint32x4_t m0 = vmull_u16(vget_low_u16(maq), vget_low_u16(sum));
  const uint32x4_t m1 = vmull_u16(vget_high_u16(maq), vget_high_u16(sum));
  const uint32x4_t m2 = vmulq_n_u32(m0, one_over_n);
  const uint32x4_t m3 = vmulq_n_u32(m1, one_over_n);
  const uint16x4_t b_lo = vrshrn_n_u32(m2, kSgrProjReciprocalBits);
  const uint16x4_t b_hi = vrshrn_n_u32(m3, kSgrProjReciprocalBits);
  *b = vcombine_u16(b_lo, b_hi);
}
918
// Derives |ma| and |b| for the first (5x5) pass; the box holds n = 25 pixels.
inline void CalculateIntermediate5(const uint16x8_t s5[5],
                                   const uint32x4x2_t sq5[5],
                                   const uint32_t scale, uint8x8_t* const ma,
                                   uint16x8_t* const b) {
  const uint16x8_t total = Sum5_16(s5);
  const uint32x4x2_t total_sq = Sum5_32(sq5);
  CalculateIntermediate<25>(total, total_sq, scale, ma, b);
}
927
// Derives |ma| and |b| for the second (3x3) pass; the box holds n = 9 pixels.
inline void CalculateIntermediate3(const uint16x8_t s3[3],
                                   const uint32x4x2_t sq3[3],
                                   const uint32_t scale, uint8x8_t* const ma,
                                   uint16x8_t* const b) {
  const uint16x8_t total = Sum3_16(s3);
  const uint32x4x2_t total_sq = Sum3_32(sq3);
  CalculateIntermediate<9>(total, total_sq, scale, ma, b);
}
936
// Computes the (3,4,3) and (4,4,4) weighted vertical sums of |ma3| and |b3|
// and stores them at offset |x| in the corresponding buffers. The sums are
// also returned through the out parameters so callers can reuse them on the
// next row. (4,4,4) is 4 * (a+b+c); (3,4,3) is derived from it as
// 4*sum - sum + center to avoid recomputing.
inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3,
                         const ptrdiff_t x, uint16x8_t* const sum_ma343,
                         uint16x8_t* const sum_ma444,
                         uint32x4x2_t* const sum_b343,
                         uint32x4x2_t* const sum_b444, uint16_t* const ma343,
                         uint16_t* const ma444, uint32_t* const b343,
                         uint32_t* const b444) {
  uint8x8_t s[3];
  Prepare3_8(ma3, s);
  const uint16x8_t sum_ma111 = Sum3W_16(s);
  *sum_ma444 = vshlq_n_u16(sum_ma111, 2);
  const uint16x8_t sum333 = vsubq_u16(*sum_ma444, sum_ma111);
  *sum_ma343 = vaddw_u8(sum333, s[1]);
  uint16x4_t low[3], high[3];
  uint32x4x2_t sum_b111;
  Prepare3_16(b3, low, high);
  // Same (3,4,3)/(4,4,4) derivation on the 32-bit |b| values, per half.
  sum_b111.val[0] = Sum3W_32(low);
  sum_b111.val[1] = Sum3W_32(high);
  sum_b444->val[0] = vshlq_n_u32(sum_b111.val[0], 2);
  sum_b444->val[1] = vshlq_n_u32(sum_b111.val[1], 2);
  sum_b343->val[0] = vsubq_u32(sum_b444->val[0], sum_b111.val[0]);
  sum_b343->val[1] = vsubq_u32(sum_b444->val[1], sum_b111.val[1]);
  sum_b343->val[0] = vaddw_u16(sum_b343->val[0], low[1]);
  sum_b343->val[1] = vaddw_u16(sum_b343->val[1], high[1]);
  vst1q_u16(ma343 + x, *sum_ma343);
  vst1q_u16(ma444 + x, *sum_ma444);
  vst1q_u32(b343 + x + 0, sum_b343->val[0]);
  vst1q_u32(b343 + x + 4, sum_b343->val[1]);
  vst1q_u32(b444 + x + 0, sum_b444->val[0]);
  vst1q_u32(b444 + x + 4, sum_b444->val[1]);
}
968
// Convenience overload: stores the (3,4,3)/(4,4,4) sums and returns only the
// (3,4,3) values; the (4,4,4) results are computed but not propagated.
inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3,
                         const ptrdiff_t x, uint16x8_t* const sum_ma343,
                         uint32x4x2_t* const sum_b343, uint16_t* const ma343,
                         uint16_t* const ma444, uint32_t* const b343,
                         uint32_t* const b444) {
  uint16x8_t unused_ma444;
  uint32x4x2_t unused_b444;
  Store343_444(ma3, b3, x, sum_ma343, &unused_ma444, sum_b343, &unused_b444,
               ma343, ma444, b343, b444);
}
979
// Convenience overload: stores the (3,4,3)/(4,4,4) sums without returning
// any of them to the caller.
inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3,
                         const ptrdiff_t x, uint16_t* const ma343,
                         uint16_t* const ma444, uint32_t* const b343,
                         uint32_t* const b444) {
  uint16x8_t unused_ma343;
  uint32x4x2_t unused_b343;
  Store343_444(ma3, b3, x, &unused_ma343, &unused_b343, ma343, ma444, b343,
               b444);
}
988
// First-pass preprocessing: computes 5-tap horizontal sums for the two new
// rows |src0|/|src1|, stores them into the bottom two rows of the
// |sum5|/|square_sum5| ring buffers, then combines them with the three rows
// already in the buffers to produce |ma| and |b| at column |x|.
LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5(
    const uint8_t* const src0, const uint8_t* const src1, const ptrdiff_t x,
    const uint32_t scale, uint16_t* const sum5[5],
    uint32_t* const square_sum5[5], uint8x8x2_t s[2], uint16x8x2_t sq[2],
    uint8x8_t* const ma, uint16x8_t* const b) {
  uint16x8_t s5[5];
  uint32x4x2_t sq5[5];
  // Extend the sliding windows of both rows by the next 8 pixels.
  s[0].val[1] = vld1_u8(src0 + x + 8);
  s[1].val[1] = vld1_u8(src1 + x + 8);
  sq[0].val[1] = vmull_u8(s[0].val[1], s[0].val[1]);
  sq[1].val[1] = vmull_u8(s[1].val[1], s[1].val[1]);
  s5[3] = Sum5Horizontal(s[0]);
  s5[4] = Sum5Horizontal(s[1]);
  sq5[3] = Sum5WHorizontal(sq[0]);
  sq5[4] = Sum5WHorizontal(sq[1]);
  // Persist the two fresh rows for reuse by later rows.
  vst1q_u16(sum5[3] + x, s5[3]);
  vst1q_u16(sum5[4] + x, s5[4]);
  vst1q_u32(square_sum5[3] + x + 0, sq5[3].val[0]);
  vst1q_u32(square_sum5[3] + x + 4, sq5[3].val[1]);
  vst1q_u32(square_sum5[4] + x + 0, sq5[4].val[0]);
  vst1q_u32(square_sum5[4] + x + 4, sq5[4].val[1]);
  // Reload the three rows above, previously written by earlier iterations.
  s5[0] = vld1q_u16(sum5[0] + x);
  s5[1] = vld1q_u16(sum5[1] + x);
  s5[2] = vld1q_u16(sum5[2] + x);
  sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
  sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
  sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
  sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
  sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
  sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
  CalculateIntermediate5(s5, sq5, scale, ma, b);
}
1021
// First-pass preprocessing for the last row: only one source row remains, so
// it is counted twice (rows 3 and 4 of the 5x5 window), matching the bottom
// border extension. The sum buffers are read-only here.
LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow(
    const uint8_t* const src, const ptrdiff_t x, const uint32_t scale,
    const uint16_t* const sum5[5], const uint32_t* const square_sum5[5],
    uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma,
    uint16x8_t* const b) {
  uint16x8_t s5[5];
  uint32x4x2_t sq5[5];
  s->val[1] = vld1_u8(src + x + 8);
  sq->val[1] = vmull_u8(s->val[1], s->val[1]);
  // The single remaining row fills both of the bottom two window rows.
  s5[3] = s5[4] = Sum5Horizontal(*s);
  sq5[3] = sq5[4] = Sum5WHorizontal(*sq);
  s5[0] = vld1q_u16(sum5[0] + x);
  s5[1] = vld1q_u16(sum5[1] + x);
  s5[2] = vld1q_u16(sum5[2] + x);
  sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
  sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
  sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
  sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
  sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
  sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
  CalculateIntermediate5(s5, sq5, scale, ma, b);
}
1044
// Second-pass preprocessing: computes the 3-tap horizontal sums for the new
// row |src|, stores them into the bottom row of the |sum3|/|square_sum3|
// ring buffers, then combines them with the two rows above to produce |ma|
// and |b| at column |x|.
LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3(
    const uint8_t* const src, const ptrdiff_t x, const uint32_t scale,
    uint16_t* const sum3[3], uint32_t* const square_sum3[3],
    uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma,
    uint16x8_t* const b) {
  uint16x8_t s3[3];
  uint32x4x2_t sq3[3];
  s->val[1] = vld1_u8(src + x + 8);
  sq->val[1] = vmull_u8(s->val[1], s->val[1]);
  s3[2] = Sum3Horizontal(*s);
  sq3[2] = Sum3WHorizontal(*sq);
  // Persist the fresh row for reuse by later rows.
  vst1q_u16(sum3[2] + x, s3[2]);
  vst1q_u32(square_sum3[2] + x + 0, sq3[2].val[0]);
  vst1q_u32(square_sum3[2] + x + 4, sq3[2].val[1]);
  // Reload the two rows above.
  s3[0] = vld1q_u16(sum3[0] + x);
  s3[1] = vld1q_u16(sum3[1] + x);
  sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0);
  sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4);
  sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0);
  sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4);
  CalculateIntermediate3(s3, sq3, scale, ma, b);
}
1067
// Combined preprocessing for both SGR passes: computes 3-tap and 5-tap
// horizontal sums for the two new rows in one sweep, stores them into the
// respective ring buffers, and produces the pass-2 intermediates for each
// new row (|ma3_0|/|b3_0| and |ma3_1|/|b3_1|) plus the pass-1 intermediates
// (|ma5|/|b5|) at column |x|.
LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess(
    const uint8_t* const src0, const uint8_t* const src1, const ptrdiff_t x,
    const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5],
    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
    uint8x8x2_t s[2], uint16x8x2_t sq[2], uint8x8_t* const ma3_0,
    uint8x8_t* const ma3_1, uint16x8_t* const b3_0, uint16x8_t* const b3_1,
    uint8x8_t* const ma5, uint16x8_t* const b5) {
  uint16x8_t s3[4], s5[5];
  uint32x4x2_t sq3[4], sq5[5];
  // Extend the sliding windows of both rows by the next 8 pixels.
  s[0].val[1] = vld1_u8(src0 + x + 8);
  s[1].val[1] = vld1_u8(src1 + x + 8);
  sq[0].val[1] = vmull_u8(s[0].val[1], s[0].val[1]);
  sq[1].val[1] = vmull_u8(s[1].val[1], s[1].val[1]);
  // One sweep computes both 3-tap and 5-tap sums per row.
  SumHorizontal(s[0], sq[0], &s3[2], &s5[3], &sq3[2], &sq5[3]);
  SumHorizontal(s[1], sq[1], &s3[3], &s5[4], &sq3[3], &sq5[4]);
  // Persist the fresh rows of both ring buffers.
  vst1q_u16(sum3[2] + x, s3[2]);
  vst1q_u16(sum3[3] + x, s3[3]);
  vst1q_u32(square_sum3[2] + x + 0, sq3[2].val[0]);
  vst1q_u32(square_sum3[2] + x + 4, sq3[2].val[1]);
  vst1q_u32(square_sum3[3] + x + 0, sq3[3].val[0]);
  vst1q_u32(square_sum3[3] + x + 4, sq3[3].val[1]);
  vst1q_u16(sum5[3] + x, s5[3]);
  vst1q_u16(sum5[4] + x, s5[4]);
  vst1q_u32(square_sum5[3] + x + 0, sq5[3].val[0]);
  vst1q_u32(square_sum5[3] + x + 4, sq5[3].val[1]);
  vst1q_u32(square_sum5[4] + x + 0, sq5[4].val[0]);
  vst1q_u32(square_sum5[4] + x + 4, sq5[4].val[1]);
  // Reload the rows above from both ring buffers.
  s3[0] = vld1q_u16(sum3[0] + x);
  s3[1] = vld1q_u16(sum3[1] + x);
  sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0);
  sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4);
  sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0);
  sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4);
  s5[0] = vld1q_u16(sum5[0] + x);
  s5[1] = vld1q_u16(sum5[1] + x);
  s5[2] = vld1q_u16(sum5[2] + x);
  sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
  sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
  sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
  sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
  sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
  sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
  // Pass 2 uses scales[1]; one intermediate per new row (windows offset by
  // one row). Pass 1 uses scales[0].
  CalculateIntermediate3(s3, sq3, scales[1], ma3_0, b3_0);
  CalculateIntermediate3(s3 + 1, sq3 + 1, scales[1], ma3_1, b3_1);
  CalculateIntermediate5(s5, sq5, scales[0], ma5, b5);
}
1114
// Combined preprocessing for the last row of both passes. The single
// remaining row is counted twice in the 5x5 window (s5[4] = s5[3]) to mirror
// the bottom border extension. The sum buffers are read-only here.
LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow(
    const uint8_t* const src, const ptrdiff_t x, const uint16_t scales[2],
    const uint16_t* const sum3[4], const uint16_t* const sum5[5],
    const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5],
    uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma3,
    uint8x8_t* const ma5, uint16x8_t* const b3, uint16x8_t* const b5) {
  uint16x8_t s3[3], s5[5];
  uint32x4x2_t sq3[3], sq5[5];
  s->val[1] = vld1_u8(src + x + 8);
  sq->val[1] = vmull_u8(s->val[1], s->val[1]);
  SumHorizontal(*s, *sq, &s3[2], &s5[3], &sq3[2], &sq5[3]);
  s5[0] = vld1q_u16(sum5[0] + x);
  s5[1] = vld1q_u16(sum5[1] + x);
  s5[2] = vld1q_u16(sum5[2] + x);
  // Duplicate the last row to fill the 5-row window.
  s5[4] = s5[3];
  sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
  sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
  sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
  sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
  sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
  sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
  sq5[4] = sq5[3];
  CalculateIntermediate5(s5, sq5, scales[0], ma5, b5);
  s3[0] = vld1q_u16(sum3[0] + x);
  s3[1] = vld1q_u16(sum3[1] + x);
  sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0);
  sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4);
  sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0);
  sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4);
  CalculateIntermediate3(s3, sq3, scales[1], ma3, b3);
}
1146
// Pass-1 setup before the main filtering loop: processes rows |src0|/|src1|
// across |width|, filling the first rows of the (5,6,5) vertically-summed
// |ma565|/|b565| buffers for later use by BoxFilterPass1.
inline void BoxSumFilterPreProcess5(const uint8_t* const src0,
                                    const uint8_t* const src1, const int width,
                                    const uint32_t scale,
                                    uint16_t* const sum5[5],
                                    uint32_t* const square_sum5[5],
                                    uint16_t* ma565, uint32_t* b565) {
  uint8x8x2_t s[2], mas;
  uint16x8x2_t sq[2], bs;
  s[0].val[0] = vld1_u8(src0);
  s[1].val[0] = vld1_u8(src1);
  sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
  sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
  // Prime the pipeline with the first 8 columns.
  BoxFilterPreProcess5(src0, src1, 0, scale, sum5, square_sum5, s, sq,
                       &mas.val[0], &bs.val[0]);

  int x = 0;
  do {
    // Slide the windows and compute the next 8 columns.
    s[0].val[0] = s[0].val[1];
    s[1].val[0] = s[1].val[1];
    sq[0].val[0] = sq[0].val[1];
    sq[1].val[0] = sq[1].val[1];
    BoxFilterPreProcess5(src0, src1, x + 8, scale, sum5, square_sum5, s, sq,
                         &mas.val[1], &bs.val[1]);
    const uint16x8_t ma = Sum565(mas);
    const uint32x4x2_t b = Sum565W(bs);
    vst1q_u16(ma565, ma);
    vst1q_u32(b565 + 0, b.val[0]);
    vst1q_u32(b565 + 4, b.val[1]);
    mas.val[0] = mas.val[1];
    bs.val[0] = bs.val[1];
    ma565 += 8;
    b565 += 8;
    x += 8;
  } while (x < width);
}
1182
// Pass-2 setup before the main filtering loop: processes row |src| across
// |width|, filling the (3,4,3) buffers |ma343|/|b343| and, when
// |calculate444| is set, the (4,4,4) buffers |ma444|/|b444| as well.
template <bool calculate444>
LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3(
    const uint8_t* const src, const int width, const uint32_t scale,
    uint16_t* const sum3[3], uint32_t* const square_sum3[3], uint16_t* ma343,
    uint16_t* ma444, uint32_t* b343, uint32_t* b444) {
  uint8x8x2_t s, mas;
  uint16x8x2_t sq, bs;
  s.val[0] = vld1_u8(src);
  sq.val[0] = vmull_u8(s.val[0], s.val[0]);
  // Prime the pipeline with the first 8 columns.
  BoxFilterPreProcess3(src, 0, scale, sum3, square_sum3, &s, &sq, &mas.val[0],
                       &bs.val[0]);

  int x = 0;
  do {
    s.val[0] = s.val[1];
    sq.val[0] = sq.val[1];
    BoxFilterPreProcess3(src, x + 8, scale, sum3, square_sum3, &s, &sq,
                         &mas.val[1], &bs.val[1]);
    if (calculate444) {
      // Compile-time branch: fill both (3,4,3) and (4,4,4) buffers.
      Store343_444(mas, bs, 0, ma343, ma444, b343, b444);
      ma444 += 8;
      b444 += 8;
    } else {
      const uint16x8_t ma = Sum343(mas);
      const uint32x4x2_t b = Sum343W(bs);
      vst1q_u16(ma343, ma);
      vst1q_u32(b343 + 0, b.val[0]);
      vst1q_u32(b343 + 4, b.val[1]);
    }
    mas.val[0] = mas.val[1];
    bs.val[0] = bs.val[1];
    ma343 += 8;
    b343 += 8;
    x += 8;
  } while (x < width);
}
1219
// Combined setup for both passes: processes rows |src0|/|src1| across
// |width|, filling the pass-2 (3,4,3)/(4,4,4) buffers for both rows and the
// pass-1 (5,6,5) buffers, ahead of the main BoxFilter loop.
inline void BoxSumFilterPreProcess(
    const uint8_t* const src0, const uint8_t* const src1, const int width,
    const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5],
    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
    uint16_t* const ma343[4], uint16_t* const ma444[2], uint16_t* ma565,
    uint32_t* const b343[4], uint32_t* const b444[2], uint32_t* b565) {
  uint8x8x2_t s[2];
  uint8x8x2_t ma3[2], ma5;
  uint16x8x2_t sq[2], b3[2], b5;
  s[0].val[0] = vld1_u8(src0);
  s[1].val[0] = vld1_u8(src1);
  sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
  sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
  // Prime the pipeline with the first 8 columns.
  BoxFilterPreProcess(src0, src1, 0, scales, sum3, sum5, square_sum3,
                      square_sum5, s, sq, &ma3[0].val[0], &ma3[1].val[0],
                      &b3[0].val[0], &b3[1].val[0], &ma5.val[0], &b5.val[0]);

  int x = 0;
  do {
    s[0].val[0] = s[0].val[1];
    s[1].val[0] = s[1].val[1];
    sq[0].val[0] = sq[0].val[1];
    sq[1].val[0] = sq[1].val[1];
    BoxFilterPreProcess(src0, src1, x + 8, scales, sum3, sum5, square_sum3,
                        square_sum5, s, sq, &ma3[0].val[1], &ma3[1].val[1],
                        &b3[0].val[1], &b3[1].val[1], &ma5.val[1], &b5.val[1]);
    // Row 0 of pass 2 only needs the (3,4,3) sums.
    uint16x8_t ma = Sum343(ma3[0]);
    uint32x4x2_t b = Sum343W(b3[0]);
    vst1q_u16(ma343[0] + x, ma);
    vst1q_u32(b343[0] + x, b.val[0]);
    vst1q_u32(b343[0] + x + 4, b.val[1]);
    // Row 1 of pass 2 needs both (3,4,3) and (4,4,4) sums.
    Store343_444(ma3[1], b3[1], x, ma343[1], ma444[0], b343[1], b444[0]);
    // Pass 1 (5,6,5) sums.
    ma = Sum565(ma5);
    b = Sum565W(b5);
    vst1q_u16(ma565, ma);
    vst1q_u32(b565 + 0, b.val[0]);
    vst1q_u32(b565 + 4, b.val[1]);
    ma3[0].val[0] = ma3[0].val[1];
    ma3[1].val[0] = ma3[1].val[1];
    b3[0].val[0] = b3[0].val[1];
    b3[1].val[0] = b3[1].val[1];
    ma5.val[0] = ma5.val[1];
    b5.val[0] = b5.val[1];
    ma565 += 8;
    b565 += 8;
    x += 8;
  } while (x < width);
}
1268
// Computes the filter residual (b - ma * src) and scales it down into the
// kSgrProjRestoreBits fixed-point domain.
template <int shift>
inline int16x4_t FilterOutput(const uint16x4_t src, const uint16x4_t ma,
                              const uint32x4_t b) {
  // ma: 255 * 32 = 8160 (13 bits)
  // b: 65088 * 32 = 2082816 (21 bits)
  // The multiply-subtract may go negative, so reinterpret as signed (22 bits).
  const int32x4_t residual = vreinterpretq_s32_u32(vmlsl_u16(b, ma, src));
  // kSgrProjSgrBits = 8, kSgrProjRestoreBits = 4, shift = 4 or 5, giving a
  // rounded shift of 8 or 9 and a 13-bit result.
  return vrshrn_n_s32(residual, kSgrProjSgrBits + shift - kSgrProjRestoreBits);
}
1282
// Applies FilterOutput() to both halves of an 8-pixel vector.
template <int shift>
inline int16x8_t CalculateFilteredOutput(const uint8x8_t src,
                                         const uint16x8_t ma,
                                         const uint32x4x2_t b) {
  const uint16x8_t wide_src = vmovl_u8(src);
  const int16x4_t lo =
      FilterOutput<shift>(vget_low_u16(wide_src), vget_low_u16(ma), b.val[0]);
  const int16x4_t hi = FilterOutput<shift>(vget_high_u16(wide_src),
                                           vget_high_u16(ma), b.val[1]);
  return vcombine_s16(lo, hi);  // 13 bits
}
1294
// Pass-1 filtering: sums the intermediates of two adjacent rows before the
// shared filter step (hence the extra bit in the shift of 5).
inline int16x8_t CalculateFilteredOutputPass1(const uint8x8_t s,
                                              uint16x8_t ma[2],
                                              uint32x4x2_t b[2]) {
  uint32x4x2_t b_total;
  const uint16x8_t ma_total = vaddq_u16(ma[0], ma[1]);
  b_total.val[0] = vaddq_u32(b[0].val[0], b[1].val[0]);
  b_total.val[1] = vaddq_u32(b[0].val[1], b[1].val[1]);
  return CalculateFilteredOutput<5>(s, ma_total, b_total);
}
1304
// Pass-2 filtering: sums the intermediates of three rows before the shared
// filter step.
inline int16x8_t CalculateFilteredOutputPass2(const uint8x8_t s,
                                              uint16x8_t ma[3],
                                              uint32x4x2_t b[3]) {
  const uint16x8_t ma_total = Sum3_16(ma);
  const uint32x4x2_t b_total = Sum3_32(b);
  return CalculateFilteredOutput<5>(s, ma_total, b_total);
}
1312
// Adds the rounded, rescaled correction |v| to the source pixels and stores
// the result saturated to [0, 255].
inline void SelfGuidedFinal(const uint8x8_t src, const int32x4_t v[2],
                            uint8_t* const dst) {
  const int16x4_t lo =
      vrshrn_n_s32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits);
  const int16x4_t hi =
      vrshrn_n_s32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits);
  const int16x8_t correction = vcombine_s16(lo, hi);
  const int16x8_t base = ZeroExtend(src);
  const int16x8_t result = vaddq_s16(base, correction);
  vst1_u8(dst, vqmovun_s16(result));
}
1324
// Blends the outputs of both passes with weights |w0| and |w2| before the
// final add-and-saturate step.
inline void SelfGuidedDoubleMultiplier(const uint8x8_t src,
                                       const int16x8_t filter[2], const int w0,
                                       const int w2, uint8_t* const dst) {
  int32x4_t acc[2];
  acc[0] = vmull_n_s16(vget_low_s16(filter[0]), w0);
  acc[1] = vmull_n_s16(vget_high_s16(filter[0]), w0);
  acc[0] = vmlal_n_s16(acc[0], vget_low_s16(filter[1]), w2);
  acc[1] = vmlal_n_s16(acc[1], vget_high_s16(filter[1]), w2);
  SelfGuidedFinal(src, acc, dst);
}
1335
// Scales a single-pass filter output by |w0| before the final
// add-and-saturate step.
inline void SelfGuidedSingleMultiplier(const uint8x8_t src,
                                       const int16x8_t filter, const int w0,
                                       uint8_t* const dst) {
  // weight: -96 to 96 (Sgrproj_Xqd_Min/Max)
  int32x4_t acc[2];
  acc[0] = vmull_n_s16(vget_low_s16(filter), w0);
  acc[1] = vmull_n_s16(vget_high_s16(filter), w0);
  SelfGuidedFinal(src, acc, dst);
}
1345
// Pass-1 main loop: filters two output rows per call. For each 8-column
// step it computes the (5,6,5) sums for the lower row, combines them with
// the previously stored sums for the upper row (|ma565[0]|/|b565[0]|), and
// writes both filtered rows to |dst|.
LIBGAV1_ALWAYS_INLINE void BoxFilterPass1(
    const uint8_t* const src, const uint8_t* const src0,
    const uint8_t* const src1, const ptrdiff_t stride, uint16_t* const sum5[5],
    uint32_t* const square_sum5[5], const int width, const uint32_t scale,
    const int16_t w0, uint16_t* const ma565[2], uint32_t* const b565[2],
    uint8_t* const dst) {
  uint8x8x2_t s[2], mas;
  uint16x8x2_t sq[2], bs;
  s[0].val[0] = vld1_u8(src0);
  s[1].val[0] = vld1_u8(src1);
  sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
  sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
  // Prime the pipeline with the first 8 columns.
  BoxFilterPreProcess5(src0, src1, 0, scale, sum5, square_sum5, s, sq,
                       &mas.val[0], &bs.val[0]);

  int x = 0;
  do {
    s[0].val[0] = s[0].val[1];
    s[1].val[0] = s[1].val[1];
    sq[0].val[0] = sq[0].val[1];
    sq[1].val[0] = sq[1].val[1];
    BoxFilterPreProcess5(src0, src1, x + 8, scale, sum5, square_sum5, s, sq,
                         &mas.val[1], &bs.val[1]);
    uint16x8_t ma[2];
    uint32x4x2_t b[2];
    ma[1] = Sum565(mas);
    b[1] = Sum565W(bs);
    // Store this row's sums for reuse when the next row pair is filtered.
    vst1q_u16(ma565[1] + x, ma[1]);
    vst1q_u32(b565[1] + x + 0, b[1].val[0]);
    vst1q_u32(b565[1] + x + 4, b[1].val[1]);
    const uint8x8_t sr0 = vld1_u8(src + x);
    const uint8x8_t sr1 = vld1_u8(src + stride + x);
    int16x8_t p0, p1;
    // Reload the sums computed while filtering the previous row pair.
    ma[0] = vld1q_u16(ma565[0] + x);
    b[0].val[0] = vld1q_u32(b565[0] + x + 0);
    b[0].val[1] = vld1q_u32(b565[0] + x + 4);
    // Row 0 combines both rows' sums (shift 5); row 1 uses only this row's
    // sums (shift 4).
    p0 = CalculateFilteredOutputPass1(sr0, ma, b);
    p1 = CalculateFilteredOutput<4>(sr1, ma[1], b[1]);
    SelfGuidedSingleMultiplier(sr0, p0, w0, dst + x);
    SelfGuidedSingleMultiplier(sr1, p1, w0, dst + stride + x);
    mas.val[0] = mas.val[1];
    bs.val[0] = bs.val[1];
    x += 8;
  } while (x < width);
}
1391
// Pass-1 filtering for the final output row: combines the freshly computed
// (5,6,5) sums with the previously stored |ma565|/|b565| row and writes one
// filtered row to |dst|.
inline void BoxFilterPass1LastRow(const uint8_t* const src,
                                  const uint8_t* const src0, const int width,
                                  const uint32_t scale, const int16_t w0,
                                  uint16_t* const sum5[5],
                                  uint32_t* const square_sum5[5],
                                  uint16_t* ma565, uint32_t* b565,
                                  uint8_t* const dst) {
  uint8x8x2_t s, mas;
  uint16x8x2_t sq, bs;
  s.val[0] = vld1_u8(src0);
  sq.val[0] = vmull_u8(s.val[0], s.val[0]);
  // Prime the pipeline with the first 8 columns.
  BoxFilterPreProcess5LastRow(src0, 0, scale, sum5, square_sum5, &s, &sq,
                              &mas.val[0], &bs.val[0]);

  int x = 0;
  do {
    s.val[0] = s.val[1];
    sq.val[0] = sq.val[1];
    BoxFilterPreProcess5LastRow(src0, x + 8, scale, sum5, square_sum5, &s, &sq,
                                &mas.val[1], &bs.val[1]);
    uint16x8_t ma[2];
    uint32x4x2_t b[2];
    ma[1] = Sum565(mas);
    b[1] = Sum565W(bs);
    mas.val[0] = mas.val[1];
    bs.val[0] = bs.val[1];
    // Reload the sums stored while filtering the previous row pair.
    ma[0] = vld1q_u16(ma565);
    b[0].val[0] = vld1q_u32(b565 + 0);
    b[0].val[1] = vld1q_u32(b565 + 4);
    const uint8x8_t sr = vld1_u8(src + x);
    const int16x8_t p = CalculateFilteredOutputPass1(sr, ma, b);
    SelfGuidedSingleMultiplier(sr, p, w0, dst + x);
    ma565 += 8;
    b565 += 8;
    x += 8;
  } while (x < width);
}
1429
// Pass-2 main loop: filters one output row per call, combining the (3,4,3)
// sums of three consecutive rows — two reloaded from the |ma343[0]|/|ma444[0]|
// buffers and one freshly computed and stored via Store343_444().
LIBGAV1_ALWAYS_INLINE void BoxFilterPass2(
    const uint8_t* const src, const uint8_t* const src0, const int width,
    const uint32_t scale, const int16_t w0, uint16_t* const sum3[3],
    uint32_t* const square_sum3[3], uint16_t* const ma343[3],
    uint16_t* const ma444[2], uint32_t* const b343[3], uint32_t* const b444[2],
    uint8_t* const dst) {
  uint8x8x2_t s, mas;
  uint16x8x2_t sq, bs;
  s.val[0] = vld1_u8(src0);
  sq.val[0] = vmull_u8(s.val[0], s.val[0]);
  // Prime the pipeline with the first 8 columns.
  BoxFilterPreProcess3(src0, 0, scale, sum3, square_sum3, &s, &sq, &mas.val[0],
                       &bs.val[0]);

  int x = 0;
  do {
    s.val[0] = s.val[1];
    sq.val[0] = sq.val[1];
    BoxFilterPreProcess3(src0, x + 8, scale, sum3, square_sum3, &s, &sq,
                         &mas.val[1], &bs.val[1]);
    uint16x8_t ma[3];
    uint32x4x2_t b[3];
    // Compute and persist this row's (3,4,3)/(4,4,4) sums; also returned in
    // ma[2]/b[2] for immediate use below.
    Store343_444(mas, bs, x, &ma[2], &b[2], ma343[2], ma444[1], b343[2],
                 b444[1]);
    const uint8x8_t sr = vld1_u8(src + x);
    // Reload the two rows of sums produced by earlier iterations.
    ma[0] = vld1q_u16(ma343[0] + x);
    ma[1] = vld1q_u16(ma444[0] + x);
    b[0].val[0] = vld1q_u32(b343[0] + x + 0);
    b[0].val[1] = vld1q_u32(b343[0] + x + 4);
    b[1].val[0] = vld1q_u32(b444[0] + x + 0);
    b[1].val[1] = vld1q_u32(b444[0] + x + 4);
    const int16x8_t p = CalculateFilteredOutputPass2(sr, ma, b);
    SelfGuidedSingleMultiplier(sr, p, w0, dst + x);
    mas.val[0] = mas.val[1];
    bs.val[0] = bs.val[1];
    x += 8;
  } while (x < width);
}
1467
// Runs both self-guided filter passes over two adjacent rows (|src| and
// |src + stride|) at once. |src0|/|src1| are the two new source rows added to
// the running 5-row and 3-row box sums. Intermediates are carried between
// calls in the |ma*|/|b*| ring buffers. Outputs 8 pixels per row per loop
// iteration via SelfGuidedDoubleMultiplier.
LIBGAV1_ALWAYS_INLINE void BoxFilter(
    const uint8_t* const src, const uint8_t* const src0,
    const uint8_t* const src1, const ptrdiff_t stride, const int width,
    const uint16_t scales[2], const int16_t w0, const int16_t w2,
    uint16_t* const sum3[4], uint16_t* const sum5[5],
    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
    uint16_t* const ma343[4], uint16_t* const ma444[3],
    uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3],
    uint32_t* const b565[2], uint8_t* const dst) {
  uint8x8x2_t s[2], ma3[2], ma5;
  uint16x8x2_t sq[2], b3[2], b5;
  // Prime the pipeline with the first 8 pixels of each new row.
  s[0].val[0] = vld1_u8(src0);
  s[1].val[0] = vld1_u8(src1);
  sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
  sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
  BoxFilterPreProcess(src0, src1, 0, scales, sum3, sum5, square_sum3,
                      square_sum5, s, sq, &ma3[0].val[0], &ma3[1].val[0],
                      &b3[0].val[0], &b3[1].val[0], &ma5.val[0], &b5.val[0]);

  int x = 0;
  do {
    // Rotate previous vectors into slot 0 and compute the next 8 lanes.
    s[0].val[0] = s[0].val[1];
    s[1].val[0] = s[1].val[1];
    sq[0].val[0] = sq[0].val[1];
    sq[1].val[0] = sq[1].val[1];
    BoxFilterPreProcess(src0, src1, x + 8, scales, sum3, sum5, square_sum3,
                        square_sum5, s, sq, &ma3[0].val[1], &ma3[1].val[1],
                        &b3[0].val[1], &b3[1].val[1], &ma5.val[1], &b5.val[1]);
    // ma[0]/b[0]: pass 1 intermediates; ma[1]/b[1]: pass 2 for the first row;
    // ma[2]/b[2]: pass 2 for the second row.
    uint16x8_t ma[3][3];
    uint32x4x2_t b[3][3];
    Store343_444(ma3[0], b3[0], x, &ma[1][2], &ma[2][1], &b[1][2], &b[2][1],
                 ma343[2], ma444[1], b343[2], b444[1]);
    Store343_444(ma3[1], b3[1], x, &ma[2][2], &b[2][2], ma343[3], ma444[2],
                 b343[3], b444[2]);
    // Compute and store the new 565 sums for use by the next row pair.
    ma[0][1] = Sum565(ma5);
    b[0][1] = Sum565W(b5);
    vst1q_u16(ma565[1] + x, ma[0][1]);
    vst1q_u32(b565[1] + x, b[0][1].val[0]);
    vst1q_u32(b565[1] + x + 4, b[0][1].val[1]);
    ma3[0].val[0] = ma3[0].val[1];
    ma3[1].val[0] = ma3[1].val[1];
    b3[0].val[0] = b3[0].val[1];
    b3[1].val[0] = b3[1].val[1];
    ma5.val[0] = ma5.val[1];
    b5.val[0] = b5.val[1];
    int16x8_t p[2][2];
    const uint8x8_t sr0 = vld1_u8(src + x);
    const uint8x8_t sr1 = vld1_u8(src + stride + x);
    // Pass 1 output: the first row combines the previous and current 565 sums;
    // the second row uses only the current sums (hence the <4> shift).
    ma[0][0] = vld1q_u16(ma565[0] + x);
    b[0][0].val[0] = vld1q_u32(b565[0] + x);
    b[0][0].val[1] = vld1q_u32(b565[0] + x + 4);
    p[0][0] = CalculateFilteredOutputPass1(sr0, ma[0], b[0]);
    p[1][0] = CalculateFilteredOutput<4>(sr1, ma[0][1], b[0][1]);
    // Pass 2 output for the first row from earlier rows' intermediates.
    ma[1][0] = vld1q_u16(ma343[0] + x);
    ma[1][1] = vld1q_u16(ma444[0] + x);
    b[1][0].val[0] = vld1q_u32(b343[0] + x);
    b[1][0].val[1] = vld1q_u32(b343[0] + x + 4);
    b[1][1].val[0] = vld1q_u32(b444[0] + x);
    b[1][1].val[1] = vld1q_u32(b444[0] + x + 4);
    p[0][1] = CalculateFilteredOutputPass2(sr0, ma[1], b[1]);
    // Pass 2 output for the second row.
    ma[2][0] = vld1q_u16(ma343[1] + x);
    b[2][0].val[0] = vld1q_u32(b343[1] + x);
    b[2][0].val[1] = vld1q_u32(b343[1] + x + 4);
    p[1][1] = CalculateFilteredOutputPass2(sr1, ma[2], b[2]);
    // Blend both pass outputs with the source pixels using |w0| and |w2|.
    SelfGuidedDoubleMultiplier(sr0, p[0], w0, w2, dst + x);
    SelfGuidedDoubleMultiplier(sr1, p[1], w0, w2, dst + stride + x);
    x += 8;
  } while (x < width);
}
1537
// Runs both self-guided filter passes over the final row when |height| is
// odd. |src| is the row being filtered; |src0| is the last source row fed
// into the running sums. Uses intermediates already stored in the ma*/b*
// buffers by earlier calls, plus fresh last-row sums.
inline void BoxFilterLastRow(
    const uint8_t* const src, const uint8_t* const src0, const int width,
    const uint16_t scales[2], const int16_t w0, const int16_t w2,
    uint16_t* const sum3[4], uint16_t* const sum5[5],
    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
    uint16_t* const ma343[4], uint16_t* const ma444[3],
    uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3],
    uint32_t* const b565[2], uint8_t* const dst) {
  uint8x8x2_t s, ma3, ma5;
  uint16x8x2_t sq, b3, b5;
  uint16x8_t ma[3];
  uint32x4x2_t b[3];
  // Prime the pipeline with the first 8 pixels of the last row.
  s.val[0] = vld1_u8(src0);
  sq.val[0] = vmull_u8(s.val[0], s.val[0]);
  BoxFilterPreProcessLastRow(src0, 0, scales, sum3, sum5, square_sum3,
                             square_sum5, &s, &sq, &ma3.val[0], &ma5.val[0],
                             &b3.val[0], &b5.val[0]);

  int x = 0;
  do {
    // Rotate previous vectors into slot 0 and compute the next 8 lanes.
    s.val[0] = s.val[1];
    sq.val[0] = sq.val[1];
    BoxFilterPreProcessLastRow(src0, x + 8, scales, sum3, sum5, square_sum3,
                               square_sum5, &s, &sq, &ma3.val[1], &ma5.val[1],
                               &b3.val[1], &b5.val[1]);
    // Pass 1 (565) and pass 2 (343) sums for this row.
    ma[1] = Sum565(ma5);
    b[1] = Sum565W(b5);
    ma5.val[0] = ma5.val[1];
    b5.val[0] = b5.val[1];
    ma[2] = Sum343(ma3);
    b[2] = Sum343W(b3);
    ma3.val[0] = ma3.val[1];
    b3.val[0] = b3.val[1];
    const uint8x8_t sr = vld1_u8(src + x);
    int16x8_t p[2];
    // Pass 1 combines the stored previous-row 565 sums with this row's.
    ma[0] = vld1q_u16(ma565[0] + x);
    b[0].val[0] = vld1q_u32(b565[0] + x + 0);
    b[0].val[1] = vld1q_u32(b565[0] + x + 4);
    p[0] = CalculateFilteredOutputPass1(sr, ma, b);
    // Pass 2 reuses ma[0]/b[0] slots for the stored 343/444 intermediates.
    ma[0] = vld1q_u16(ma343[0] + x);
    ma[1] = vld1q_u16(ma444[0] + x);
    b[0].val[0] = vld1q_u32(b343[0] + x + 0);
    b[0].val[1] = vld1q_u32(b343[0] + x + 4);
    b[1].val[0] = vld1q_u32(b444[0] + x + 0);
    b[1].val[1] = vld1q_u32(b444[0] + x + 4);
    p[1] = CalculateFilteredOutputPass2(sr, ma, b);
    SelfGuidedDoubleMultiplier(sr, p, w0, w2, dst + x);
    x += 8;
  } while (x < width);
}
1588
// Top-level driver for the combined two-pass self-guided filter. Sets up the
// ring buffers inside |sgr_buffer|, seeds the running box sums from
// |top_border|, then filters two rows per loop iteration, finishing with
// |bottom_border| rows (and a single BoxFilterLastRow when |height| is odd).
// Rows are advanced by rotating the buffer pointers, never by copying data.
LIBGAV1_ALWAYS_INLINE void BoxFilterProcess(
    const RestorationUnitInfo& restoration_info, const uint8_t* src,
    const uint8_t* const top_border, const uint8_t* bottom_border,
    const ptrdiff_t stride, const int width, const int height,
    SgrBuffer* const sgr_buffer, uint8_t* dst) {
  // |width| rounded up to a multiple of 8 lanes; sum rows get 8 extra
  // entries.
  const auto temp_stride = Align<ptrdiff_t>(width, 8);
  const ptrdiff_t sum_stride = temp_stride + 8;
  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
  const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index];  // < 2^12.
  const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
  const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
  // The three blend weights sum to 1 << kSgrProjPrecisionBits.
  const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1;
  // Carve per-row ring-buffer pointers out of the contiguous SgrBuffer
  // arrays.
  uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2];
  uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2];
  sum3[0] = sgr_buffer->sum3;
  square_sum3[0] = sgr_buffer->square_sum3;
  ma343[0] = sgr_buffer->ma343;
  b343[0] = sgr_buffer->b343;
  for (int i = 1; i <= 3; ++i) {
    sum3[i] = sum3[i - 1] + sum_stride;
    square_sum3[i] = square_sum3[i - 1] + sum_stride;
    ma343[i] = ma343[i - 1] + temp_stride;
    b343[i] = b343[i - 1] + temp_stride;
  }
  sum5[0] = sgr_buffer->sum5;
  square_sum5[0] = sgr_buffer->square_sum5;
  for (int i = 1; i <= 4; ++i) {
    sum5[i] = sum5[i - 1] + sum_stride;
    square_sum5[i] = square_sum5[i - 1] + sum_stride;
  }
  ma444[0] = sgr_buffer->ma444;
  b444[0] = sgr_buffer->b444;
  for (int i = 1; i <= 2; ++i) {
    ma444[i] = ma444[i - 1] + temp_stride;
    b444[i] = b444[i - 1] + temp_stride;
  }
  ma565[0] = sgr_buffer->ma565;
  ma565[1] = ma565[0] + temp_stride;
  b565[0] = sgr_buffer->b565;
  b565[1] = b565[0] + temp_stride;
  assert(scales[0] != 0);
  assert(scales[1] != 0);
  // Seed the column sums from the two rows of top border.
  BoxSum(top_border, stride, 2, sum_stride, sum3[0], sum5[1], square_sum3[0],
         square_sum5[1]);
  // Alias sum5 row 0 to row 1 for the first-row preprocess (only two rows
  // exist above the unit); restored to the real buffer below.
  sum5[0] = sum5[1];
  square_sum5[0] = square_sum5[1];
  // For a single-row unit the "next" source row comes from the bottom border.
  const uint8_t* const s = (height > 1) ? src + stride : bottom_border;
  BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3,
                         square_sum5, ma343, ma444, ma565[0], b343, b444,
                         b565[0]);
  sum5[0] = sgr_buffer->sum5;
  square_sum5[0] = sgr_buffer->square_sum5;

  // Main loop: filter two rows per iteration, rotating the sum and
  // intermediate buffers forward by two rows each time.
  for (int y = (height >> 1) - 1; y > 0; --y) {
    Circulate4PointersBy2<uint16_t>(sum3);
    Circulate4PointersBy2<uint32_t>(square_sum3);
    Circulate5PointersBy2<uint16_t>(sum5);
    Circulate5PointersBy2<uint32_t>(square_sum5);
    BoxFilter(src + 3, src + 2 * stride, src + 3 * stride, stride, width,
              scales, w0, w2, sum3, sum5, square_sum3, square_sum5, ma343,
              ma444, ma565, b343, b444, b565, dst);
    src += 2 * stride;
    dst += 2 * stride;
    Circulate4PointersBy2<uint16_t>(ma343);
    Circulate4PointersBy2<uint32_t>(b343);
    std::swap(ma444[0], ma444[2]);
    std::swap(b444[0], b444[2]);
    std::swap(ma565[0], ma565[1]);
    std::swap(b565[0], b565[1]);
  }

  // Tail: the remaining row pair (skipped only when height == 1) draws its
  // new source rows partly or wholly from the bottom border.
  Circulate4PointersBy2<uint16_t>(sum3);
  Circulate4PointersBy2<uint32_t>(square_sum3);
  Circulate5PointersBy2<uint16_t>(sum5);
  Circulate5PointersBy2<uint32_t>(square_sum5);
  if ((height & 1) == 0 || height > 1) {
    const uint8_t* sr[2];
    if ((height & 1) == 0) {
      sr[0] = bottom_border;
      sr[1] = bottom_border + stride;
    } else {
      sr[0] = src + 2 * stride;
      sr[1] = bottom_border;
    }
    BoxFilter(src + 3, sr[0], sr[1], stride, width, scales, w0, w2, sum3, sum5,
              square_sum3, square_sum5, ma343, ma444, ma565, b343, b444, b565,
              dst);
  }
  // Odd height leaves one final row, filtered against the bottom border.
  if ((height & 1) != 0) {
    if (height > 1) {
      src += 2 * stride;
      dst += 2 * stride;
      Circulate4PointersBy2<uint16_t>(sum3);
      Circulate4PointersBy2<uint32_t>(square_sum3);
      Circulate5PointersBy2<uint16_t>(sum5);
      Circulate5PointersBy2<uint32_t>(square_sum5);
      Circulate4PointersBy2<uint16_t>(ma343);
      Circulate4PointersBy2<uint32_t>(b343);
      std::swap(ma444[0], ma444[2]);
      std::swap(b444[0], b444[2]);
      std::swap(ma565[0], ma565[1]);
      std::swap(b565[0], b565[1]);
    }
    BoxFilterLastRow(src + 3, bottom_border + stride, width, scales, w0, w2,
                     sum3, sum5, square_sum3, square_sum5, ma343, ma444, ma565,
                     b343, b444, b565, dst);
  }
}
1697
// Driver for the pass-1-only (5x5 box) self-guided filter. Mirrors
// BoxFilterProcess but maintains only the 5-row sums and the 565
// intermediates, filtering two rows per iteration with BoxFilterPass1.
inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info,
                                  const uint8_t* src,
                                  const uint8_t* const top_border,
                                  const uint8_t* bottom_border,
                                  const ptrdiff_t stride, const int width,
                                  const int height, SgrBuffer* const sgr_buffer,
                                  uint8_t* dst) {
  // |width| rounded up to a multiple of 8 lanes; sum rows get 8 extra
  // entries.
  const auto temp_stride = Align<ptrdiff_t>(width, 8);
  const ptrdiff_t sum_stride = temp_stride + 8;
  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
  const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0];  // < 2^12.
  const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
  // Carve per-row ring-buffer pointers out of the contiguous SgrBuffer
  // arrays.
  uint16_t *sum5[5], *ma565[2];
  uint32_t *square_sum5[5], *b565[2];
  sum5[0] = sgr_buffer->sum5;
  square_sum5[0] = sgr_buffer->square_sum5;
  for (int i = 1; i <= 4; ++i) {
    sum5[i] = sum5[i - 1] + sum_stride;
    square_sum5[i] = square_sum5[i - 1] + sum_stride;
  }
  ma565[0] = sgr_buffer->ma565;
  ma565[1] = ma565[0] + temp_stride;
  b565[0] = sgr_buffer->b565;
  b565[1] = b565[0] + temp_stride;
  assert(scale != 0);
  // Seed the 5x5 column sums from the two rows of top border.
  BoxSum<5>(top_border, stride, 2, sum_stride, sum5[1], square_sum5[1]);
  // Alias row 0 to row 1 for the first-row preprocess (only two rows exist
  // above the unit); restored to the real buffer below.
  sum5[0] = sum5[1];
  square_sum5[0] = square_sum5[1];
  // For a single-row unit the "next" source row comes from the bottom border.
  const uint8_t* const s = (height > 1) ? src + stride : bottom_border;
  BoxSumFilterPreProcess5(src, s, width, scale, sum5, square_sum5, ma565[0],
                          b565[0]);
  sum5[0] = sgr_buffer->sum5;
  square_sum5[0] = sgr_buffer->square_sum5;

  // Main loop: filter two rows per iteration, rotating the sum buffers by
  // two rows each time.
  for (int y = (height >> 1) - 1; y > 0; --y) {
    Circulate5PointersBy2<uint16_t>(sum5);
    Circulate5PointersBy2<uint32_t>(square_sum5);
    BoxFilterPass1(src + 3, src + 2 * stride, src + 3 * stride, stride, sum5,
                   square_sum5, width, scale, w0, ma565, b565, dst);
    src += 2 * stride;
    dst += 2 * stride;
    std::swap(ma565[0], ma565[1]);
    std::swap(b565[0], b565[1]);
  }

  // Tail: the remaining row pair (skipped only when height == 1) draws its
  // new source rows partly or wholly from the bottom border.
  Circulate5PointersBy2<uint16_t>(sum5);
  Circulate5PointersBy2<uint32_t>(square_sum5);
  if ((height & 1) == 0 || height > 1) {
    const uint8_t* sr[2];
    if ((height & 1) == 0) {
      sr[0] = bottom_border;
      sr[1] = bottom_border + stride;
    } else {
      sr[0] = src + 2 * stride;
      sr[1] = bottom_border;
    }
    BoxFilterPass1(src + 3, sr[0], sr[1], stride, sum5, square_sum5, width,
                   scale, w0, ma565, b565, dst);
  }
  // Odd height leaves one final row, filtered against the bottom border.
  if ((height & 1) != 0) {
    if (height > 1) {
      src += 2 * stride;
      dst += 2 * stride;
      std::swap(ma565[0], ma565[1]);
      std::swap(b565[0], b565[1]);
      Circulate5PointersBy2<uint16_t>(sum5);
      Circulate5PointersBy2<uint32_t>(square_sum5);
    }
    BoxFilterPass1LastRow(src + 3, bottom_border + stride, width, scale, w0,
                          sum5, square_sum5, ma565[0], b565[0], dst);
  }
}
1770
// Driver for the pass-2-only (3x3 box) self-guided filter. Maintains the
// 3-row sums plus the 343/444 intermediates and filters one row per loop
// iteration with BoxFilterPass2, finishing with rows from |bottom_border|.
inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info,
                                  const uint8_t* src,
                                  const uint8_t* const top_border,
                                  const uint8_t* bottom_border,
                                  const ptrdiff_t stride, const int width,
                                  const int height, SgrBuffer* const sgr_buffer,
                                  uint8_t* dst) {
  // Pass-2-only means the pass 1 weight must be zero.
  assert(restoration_info.sgr_proj_info.multiplier[0] == 0);
  // |width| rounded up to a multiple of 8 lanes; sum rows get 8 extra
  // entries.
  const auto temp_stride = Align<ptrdiff_t>(width, 8);
  const ptrdiff_t sum_stride = temp_stride + 8;
  const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
  // The two blend weights sum to 1 << kSgrProjPrecisionBits.
  const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1;
  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
  const uint32_t scale = kSgrScaleParameter[sgr_proj_index][1];  // < 2^12.
  // Carve per-row ring-buffer pointers out of the contiguous SgrBuffer
  // arrays.
  uint16_t *sum3[3], *ma343[3], *ma444[2];
  uint32_t *square_sum3[3], *b343[3], *b444[2];
  sum3[0] = sgr_buffer->sum3;
  square_sum3[0] = sgr_buffer->square_sum3;
  ma343[0] = sgr_buffer->ma343;
  b343[0] = sgr_buffer->b343;
  for (int i = 1; i <= 2; ++i) {
    sum3[i] = sum3[i - 1] + sum_stride;
    square_sum3[i] = square_sum3[i - 1] + sum_stride;
    ma343[i] = ma343[i - 1] + temp_stride;
    b343[i] = b343[i - 1] + temp_stride;
  }
  ma444[0] = sgr_buffer->ma444;
  ma444[1] = ma444[0] + temp_stride;
  b444[0] = sgr_buffer->b444;
  b444[1] = b444[0] + temp_stride;
  assert(scale != 0);
  // Seed the 3x3 column sums from the two rows of top border, then
  // preprocess the first two input rows to fill the initial intermediates.
  BoxSum<3>(top_border, stride, 2, sum_stride, sum3[0], square_sum3[0]);
  BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3, ma343[0],
                                 nullptr, b343[0], nullptr);
  Circulate3PointersBy1<uint16_t>(sum3);
  Circulate3PointersBy1<uint32_t>(square_sum3);
  const uint8_t* s;
  if (height > 1) {
    s = src + stride;
  } else {
    // Single-row unit: take the second preprocess row from the bottom border
    // and advance it so the main output loop reads the following border row.
    s = bottom_border;
    bottom_border += stride;
  }
  BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, ma343[1],
                                ma444[0], b343[1], b444[0]);

  // Main loop: one output row per iteration, rotating the sum and
  // intermediate buffers by one row each time.
  for (int y = height - 2; y > 0; --y) {
    Circulate3PointersBy1<uint16_t>(sum3);
    Circulate3PointersBy1<uint32_t>(square_sum3);
    BoxFilterPass2(src + 2, src + 2 * stride, width, scale, w0, sum3,
                   square_sum3, ma343, ma444, b343, b444, dst);
    src += stride;
    dst += stride;
    Circulate3PointersBy1<uint16_t>(ma343);
    Circulate3PointersBy1<uint32_t>(b343);
    std::swap(ma444[0], ma444[1]);
    std::swap(b444[0], b444[1]);
  }

  // Tail: the last one or two rows take their new source rows from the
  // bottom border.
  src += 2;
  int y = std::min(height, 2);
  do {
    Circulate3PointersBy1<uint16_t>(sum3);
    Circulate3PointersBy1<uint32_t>(square_sum3);
    BoxFilterPass2(src, bottom_border, width, scale, w0, sum3, square_sum3,
                   ma343, ma444, b343, b444, dst);
    src += stride;
    dst += stride;
    bottom_border += stride;
    Circulate3PointersBy1<uint16_t>(ma343);
    Circulate3PointersBy1<uint32_t>(b343);
    std::swap(ma444[0], ma444[1]);
    std::swap(b444[0], b444[1]);
  } while (--y != 0);
}
1846
// If |width| is a non-multiple of 8, up to 7 more pixels are written to |dest|
// at the end of each row. It is safe to overwrite the output as it will not be
// part of the visible frame.
SelfGuidedFilter_NEON(const RestorationUnitInfo & restoration_info,const void * const source,const void * const top_border,const void * const bottom_border,const ptrdiff_t stride,const int width,const int height,RestorationBuffer * const restoration_buffer,void * const dest)1850 void SelfGuidedFilter_NEON(
1851 const RestorationUnitInfo& restoration_info, const void* const source,
1852 const void* const top_border, const void* const bottom_border,
1853 const ptrdiff_t stride, const int width, const int height,
1854 RestorationBuffer* const restoration_buffer, void* const dest) {
1855 const int index = restoration_info.sgr_proj_info.index;
1856 const int radius_pass_0 = kSgrProjParams[index][0]; // 2 or 0
1857 const int radius_pass_1 = kSgrProjParams[index][2]; // 1 or 0
1858 const auto* const src = static_cast<const uint8_t*>(source);
1859 const auto* top = static_cast<const uint8_t*>(top_border);
1860 const auto* bottom = static_cast<const uint8_t*>(bottom_border);
1861 auto* const dst = static_cast<uint8_t*>(dest);
1862 SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer;
1863 if (radius_pass_1 == 0) {
1864 // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the
1865 // following assertion.
1866 assert(radius_pass_0 != 0);
1867 BoxFilterProcessPass1(restoration_info, src - 3, top - 3, bottom - 3,
1868 stride, width, height, sgr_buffer, dst);
1869 } else if (radius_pass_0 == 0) {
1870 BoxFilterProcessPass2(restoration_info, src - 2, top - 2, bottom - 2,
1871 stride, width, height, sgr_buffer, dst);
1872 } else {
1873 BoxFilterProcess(restoration_info, src - 3, top - 3, bottom - 3, stride,
1874 width, height, sgr_buffer, dst);
1875 }
1876 }
1877
// Registers the NEON loop restoration implementations in the 8-bit dsp
// function table.
void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
  assert(dsp != nullptr);
  dsp->loop_restorations[0] = WienerFilter_NEON;
  dsp->loop_restorations[1] = SelfGuidedFilter_NEON;
}
1884
1885 } // namespace
1886 } // namespace low_bitdepth
1887
LoopRestorationInit_NEON()1888 void LoopRestorationInit_NEON() { low_bitdepth::Init8bpp(); }
1889
1890 } // namespace dsp
1891 } // namespace libgav1
1892
1893 #else // !LIBGAV1_ENABLE_NEON
1894 namespace libgav1 {
1895 namespace dsp {
1896
LoopRestorationInit_NEON()1897 void LoopRestorationInit_NEON() {}
1898
1899 } // namespace dsp
1900 } // namespace libgav1
1901 #endif // LIBGAV1_ENABLE_NEON
1902