// Copyright 2019 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/loop_restoration.h"
#include "src/utils/cpu.h"

#if LIBGAV1_ENABLE_NEON
#include <arm_neon.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "src/dsp/arm/common_neon.h"
#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/utils/common.h"
#include "src/utils/constants.h"

namespace libgav1 {
namespace dsp {
namespace low_bitdepth {
namespace {

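// Emulates a right shift of the wide value formed by concatenating
// src.val[0] (low part) and src.val[1] (high part): the result holds the
// elements starting |bytes| bytes into the concatenation.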
template <int bytes>
inline uint8x8_t VshrU128(const uint8x8x2_t src) {
  return vext_u8(src.val[0], src.val[1], bytes);
}

template <int bytes>
inline uint16x8_t VshrU128(const uint16x8x2_t src) {
  return vextq_u16(src.val[0], src.val[1], bytes / 2);
}

// Wiener

// Make a local copy of the coefficients so the compiler knows they do not
// overlap with other buffers; the 'const' keyword alone is not enough. In
// practice the compiler does not emit an actual copy, since there are enough
// registers in this case.
inline void PopulateWienerCoefficients(
    const RestorationUnitInfo& restoration_info, const int direction,
    int16_t filter[4]) {
  // In order to keep the horizontal pass intermediate values within 16 bits we
  // offset |filter[3]| by 128. The 128 offset will be added back in the loop.
  for (int i = 0; i < 4; ++i) {
    filter[i] = restoration_info.wiener_info.filter[direction][i];
  }
  if (direction == WienerInfo::kHorizontal) {
    filter[3] -= 128;
  }
}

inline int16x8_t WienerHorizontal2(const uint8x8_t s0, const uint8x8_t s1,
                                   const int16_t filter, const int16x8_t sum) {
  const int16x8_t ss = vreinterpretq_s16_u16(vaddl_u8(s0, s1));
  return vmlaq_n_s16(sum, ss, filter);
}

inline int16x8x2_t WienerHorizontal2(const uint8x16_t s0, const uint8x16_t s1,
                                     const int16_t filter,
                                     const int16x8x2_t sum) {
  int16x8x2_t d;
  d.val[0] =
      WienerHorizontal2(vget_low_u8(s0), vget_low_u8(s1), filter, sum.val[0]);
  d.val[1] =
      WienerHorizontal2(vget_high_u8(s0), vget_high_u8(s1), filter, sum.val[1]);
  return d;
}

inline void WienerHorizontalSum(const uint8x8_t s[3], const int16_t filter[4],
                                int16x8_t sum, int16_t* const wiener_buffer) {
  constexpr int offset =
      1 << (8 + kWienerFilterBits - kInterRoundBitsHorizontal - 1);
  constexpr int limit = (offset << 2) - 1;
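  // The clamp below keeps |sum| within a 13-bit range (see the comment in
  // WienerFilter_NEON() about values being saturated to 13 bits before
  // storing).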
  const int16x8_t s_0_2 = vreinterpretq_s16_u16(vaddl_u8(s[0], s[2]));
  const int16x8_t s_1 = ZeroExtend(s[1]);
  sum = vmlaq_n_s16(sum, s_0_2, filter[2]);
  sum = vmlaq_n_s16(sum, s_1, filter[3]);
  // Calculate the scaled-down offset correction and add it to |sum| here to
  // prevent signed 16-bit overflow.
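  // Since 128 * s_1 == s_1 << 7, the correction added back here is
  // |s_1| << (7 - kInterRoundBitsHorizontal), which equals
  // (128 * |s_1|) >> kInterRoundBitsHorizontal.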
  sum = vrsraq_n_s16(vshlq_n_s16(s_1, 7 - kInterRoundBitsHorizontal), sum,
                     kInterRoundBitsHorizontal);
  sum = vmaxq_s16(sum, vdupq_n_s16(-offset));
  sum = vminq_s16(sum, vdupq_n_s16(limit - offset));
  vst1q_s16(wiener_buffer, sum);
}

inline void WienerHorizontalSum(const uint8x16_t src[3],
                                const int16_t filter[4], int16x8x2_t sum,
                                int16_t* const wiener_buffer) {
  uint8x8_t s[3];
  s[0] = vget_low_u8(src[0]);
  s[1] = vget_low_u8(src[1]);
  s[2] = vget_low_u8(src[2]);
  WienerHorizontalSum(s, filter, sum.val[0], wiener_buffer);
  s[0] = vget_high_u8(src[0]);
  s[1] = vget_high_u8(src[1]);
  s[2] = vget_high_u8(src[2]);
  WienerHorizontalSum(s, filter, sum.val[1], wiener_buffer + 8);
}

inline void WienerHorizontalTap7(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const int16_t filter[4],
                                 int16_t** const wiener_buffer) {
  for (int y = height; y != 0; --y) {
    const uint8_t* src_ptr = src;
    uint8x16_t s[8];
    s[0] = vld1q_u8(src_ptr);
    ptrdiff_t x = width;
    do {
      src_ptr += 16;
      s[7] = vld1q_u8(src_ptr);
      s[1] = vextq_u8(s[0], s[7], 1);
      s[2] = vextq_u8(s[0], s[7], 2);
      s[3] = vextq_u8(s[0], s[7], 3);
      s[4] = vextq_u8(s[0], s[7], 4);
      s[5] = vextq_u8(s[0], s[7], 5);
      s[6] = vextq_u8(s[0], s[7], 6);
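      // s[0]..s[6] are 7 windows, each offset by one pixel; the symmetric
      // 7-tap filter applies the same coefficient to windows equidistant from
      // the center window s[3].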
      int16x8x2_t sum;
      sum.val[0] = sum.val[1] = vdupq_n_s16(0);
      sum = WienerHorizontal2(s[0], s[6], filter[0], sum);
      sum = WienerHorizontal2(s[1], s[5], filter[1], sum);
      WienerHorizontalSum(s + 2, filter, sum, *wiener_buffer);
      s[0] = s[7];
      *wiener_buffer += 16;
      x -= 16;
    } while (x != 0);
    src += src_stride;
  }
}

inline void WienerHorizontalTap5(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const int16_t filter[4],
                                 int16_t** const wiener_buffer) {
  for (int y = height; y != 0; --y) {
    const uint8_t* src_ptr = src;
    uint8x16_t s[6];
    s[0] = vld1q_u8(src_ptr);
    ptrdiff_t x = width;
    do {
      src_ptr += 16;
      s[5] = vld1q_u8(src_ptr);
      s[1] = vextq_u8(s[0], s[5], 1);
      s[2] = vextq_u8(s[0], s[5], 2);
      s[3] = vextq_u8(s[0], s[5], 3);
      s[4] = vextq_u8(s[0], s[5], 4);
      int16x8x2_t sum;
      sum.val[0] = sum.val[1] = vdupq_n_s16(0);
      sum = WienerHorizontal2(s[0], s[4], filter[1], sum);
      WienerHorizontalSum(s + 1, filter, sum, *wiener_buffer);
      s[0] = s[5];
      *wiener_buffer += 16;
      x -= 16;
    } while (x != 0);
    src += src_stride;
  }
}

inline void WienerHorizontalTap3(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const int16_t filter[4],
                                 int16_t** const wiener_buffer) {
  for (int y = height; y != 0; --y) {
    const uint8_t* src_ptr = src;
    uint8x16_t s[4];
    s[0] = vld1q_u8(src_ptr);
    ptrdiff_t x = width;
    do {
      src_ptr += 16;
      s[3] = vld1q_u8(src_ptr);
      s[1] = vextq_u8(s[0], s[3], 1);
      s[2] = vextq_u8(s[0], s[3], 2);
      int16x8x2_t sum;
      sum.val[0] = sum.val[1] = vdupq_n_s16(0);
      WienerHorizontalSum(s, filter, sum, *wiener_buffer);
      s[0] = s[3];
      *wiener_buffer += 16;
      x -= 16;
    } while (x != 0);
    src += src_stride;
  }
}

inline void WienerHorizontalTap1(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 int16_t** const wiener_buffer) {
  for (int y = height; y != 0; --y) {
    const uint8_t* src_ptr = src;
    ptrdiff_t x = width;
    do {
      const uint8x16_t s = vld1q_u8(src_ptr);
      const uint8x8_t s0 = vget_low_u8(s);
      const uint8x8_t s1 = vget_high_u8(s);
      const int16x8_t d0 = vreinterpretq_s16_u16(vshll_n_u8(s0, 4));
      const int16x8_t d1 = vreinterpretq_s16_u16(vshll_n_u8(s1, 4));
      vst1q_s16(*wiener_buffer + 0, d0);
      vst1q_s16(*wiener_buffer + 8, d1);
      src_ptr += 16;
      *wiener_buffer += 16;
      x -= 16;
    } while (x != 0);
    src += src_stride;
  }
}

inline int32x4x2_t WienerVertical2(const int16x8_t a0, const int16x8_t a1,
                                   const int16_t filter,
                                   const int32x4x2_t sum) {
  const int16x8_t a = vaddq_s16(a0, a1);
  int32x4x2_t d;
  d.val[0] = vmlal_n_s16(sum.val[0], vget_low_s16(a), filter);
  d.val[1] = vmlal_n_s16(sum.val[1], vget_high_s16(a), filter);
  return d;
}

inline uint8x8_t WienerVertical(const int16x8_t a[3], const int16_t filter[4],
                                const int32x4x2_t sum) {
  int32x4x2_t d = WienerVertical2(a[0], a[2], filter[2], sum);
  d.val[0] = vmlal_n_s16(d.val[0], vget_low_s16(a[1]), filter[3]);
  d.val[1] = vmlal_n_s16(d.val[1], vget_high_s16(a[1]), filter[3]);
  const uint16x4_t sum_lo_16 = vqrshrun_n_s32(d.val[0], 11);
  const uint16x4_t sum_hi_16 = vqrshrun_n_s32(d.val[1], 11);
  return vqmovn_u16(vcombine_u16(sum_lo_16, sum_hi_16));
}

inline uint8x8_t WienerVerticalTap7Kernel(const int16_t* const wiener_buffer,
                                          const ptrdiff_t wiener_stride,
                                          const int16_t filter[4],
                                          int16x8_t a[7]) {
  int32x4x2_t sum;
  a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride);
  a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride);
  a[5] = vld1q_s16(wiener_buffer + 5 * wiener_stride);
  a[6] = vld1q_s16(wiener_buffer + 6 * wiener_stride);
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  sum = WienerVertical2(a[0], a[6], filter[0], sum);
  sum = WienerVertical2(a[1], a[5], filter[1], sum);
  a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride);
  a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
  a[4] = vld1q_s16(wiener_buffer + 4 * wiener_stride);
  return WienerVertical(a + 2, filter, sum);
}

inline uint8x8x2_t WienerVerticalTap7Kernel2(const int16_t* const wiener_buffer,
                                             const ptrdiff_t wiener_stride,
                                             const int16_t filter[4]) {
  int16x8_t a[8];
  int32x4x2_t sum;
  uint8x8x2_t d;
  d.val[0] = WienerVerticalTap7Kernel(wiener_buffer, wiener_stride, filter, a);
  a[7] = vld1q_s16(wiener_buffer + 7 * wiener_stride);
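  // The second output row reuses a[1]..a[7], the same columns one row further
  // down, so only one additional row load is needed.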
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  sum = WienerVertical2(a[1], a[7], filter[0], sum);
  sum = WienerVertical2(a[2], a[6], filter[1], sum);
  d.val[1] = WienerVertical(a + 3, filter, sum);
  return d;
}

inline void WienerVerticalTap7(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t filter[4], uint8_t* dst,
                               const ptrdiff_t dst_stride) {
  for (int y = height >> 1; y != 0; --y) {
    uint8_t* dst_ptr = dst;
    ptrdiff_t x = width;
    do {
      uint8x8x2_t d[2];
      d[0] = WienerVerticalTap7Kernel2(wiener_buffer + 0, width, filter);
      d[1] = WienerVerticalTap7Kernel2(wiener_buffer + 8, width, filter);
      vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0]));
      vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1]));
      wiener_buffer += 16;
      dst_ptr += 16;
      x -= 16;
    } while (x != 0);
    wiener_buffer += width;
    dst += 2 * dst_stride;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = width;
    do {
      int16x8_t a[7];
      const uint8x8_t d0 =
          WienerVerticalTap7Kernel(wiener_buffer + 0, width, filter, a);
      const uint8x8_t d1 =
          WienerVerticalTap7Kernel(wiener_buffer + 8, width, filter, a);
      vst1q_u8(dst, vcombine_u8(d0, d1));
      wiener_buffer += 16;
      dst += 16;
      x -= 16;
    } while (x != 0);
  }
}

inline uint8x8_t WienerVerticalTap5Kernel(const int16_t* const wiener_buffer,
                                          const ptrdiff_t wiener_stride,
                                          const int16_t filter[4],
                                          int16x8_t a[5]) {
  a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride);
  a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride);
  a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride);
  a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
  a[4] = vld1q_s16(wiener_buffer + 4 * wiener_stride);
  int32x4x2_t sum;
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  sum = WienerVertical2(a[0], a[4], filter[1], sum);
  return WienerVertical(a + 1, filter, sum);
}

inline uint8x8x2_t WienerVerticalTap5Kernel2(const int16_t* const wiener_buffer,
                                             const ptrdiff_t wiener_stride,
                                             const int16_t filter[4]) {
  int16x8_t a[6];
  int32x4x2_t sum;
  uint8x8x2_t d;
  d.val[0] = WienerVerticalTap5Kernel(wiener_buffer, wiener_stride, filter, a);
  a[5] = vld1q_s16(wiener_buffer + 5 * wiener_stride);
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  sum = WienerVertical2(a[1], a[5], filter[1], sum);
  d.val[1] = WienerVertical(a + 2, filter, sum);
  return d;
}

inline void WienerVerticalTap5(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t filter[4], uint8_t* dst,
                               const ptrdiff_t dst_stride) {
  for (int y = height >> 1; y != 0; --y) {
    uint8_t* dst_ptr = dst;
    ptrdiff_t x = width;
    do {
      uint8x8x2_t d[2];
      d[0] = WienerVerticalTap5Kernel2(wiener_buffer + 0, width, filter);
      d[1] = WienerVerticalTap5Kernel2(wiener_buffer + 8, width, filter);
      vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0]));
      vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1]));
      wiener_buffer += 16;
      dst_ptr += 16;
      x -= 16;
    } while (x != 0);
    wiener_buffer += width;
    dst += 2 * dst_stride;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = width;
    do {
      int16x8_t a[5];
      const uint8x8_t d0 =
          WienerVerticalTap5Kernel(wiener_buffer + 0, width, filter, a);
      const uint8x8_t d1 =
          WienerVerticalTap5Kernel(wiener_buffer + 8, width, filter, a);
      vst1q_u8(dst, vcombine_u8(d0, d1));
      wiener_buffer += 16;
      dst += 16;
      x -= 16;
    } while (x != 0);
  }
}

inline uint8x8_t WienerVerticalTap3Kernel(const int16_t* const wiener_buffer,
                                          const ptrdiff_t wiener_stride,
                                          const int16_t filter[4],
                                          int16x8_t a[3]) {
  a[0] = vld1q_s16(wiener_buffer + 0 * wiener_stride);
  a[1] = vld1q_s16(wiener_buffer + 1 * wiener_stride);
  a[2] = vld1q_s16(wiener_buffer + 2 * wiener_stride);
  int32x4x2_t sum;
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  return WienerVertical(a, filter, sum);
}

inline uint8x8x2_t WienerVerticalTap3Kernel2(const int16_t* const wiener_buffer,
                                             const ptrdiff_t wiener_stride,
                                             const int16_t filter[4]) {
  int16x8_t a[4];
  int32x4x2_t sum;
  uint8x8x2_t d;
  d.val[0] = WienerVerticalTap3Kernel(wiener_buffer, wiener_stride, filter, a);
  a[3] = vld1q_s16(wiener_buffer + 3 * wiener_stride);
  sum.val[0] = sum.val[1] = vdupq_n_s32(0);
  d.val[1] = WienerVertical(a + 1, filter, sum);
  return d;
}

inline void WienerVerticalTap3(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t filter[4], uint8_t* dst,
                               const ptrdiff_t dst_stride) {
  for (int y = height >> 1; y != 0; --y) {
    uint8_t* dst_ptr = dst;
    ptrdiff_t x = width;
    do {
      uint8x8x2_t d[2];
      d[0] = WienerVerticalTap3Kernel2(wiener_buffer + 0, width, filter);
      d[1] = WienerVerticalTap3Kernel2(wiener_buffer + 8, width, filter);
      vst1q_u8(dst_ptr, vcombine_u8(d[0].val[0], d[1].val[0]));
      vst1q_u8(dst_ptr + dst_stride, vcombine_u8(d[0].val[1], d[1].val[1]));
      wiener_buffer += 16;
      dst_ptr += 16;
      x -= 16;
    } while (x != 0);
    wiener_buffer += width;
    dst += 2 * dst_stride;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = width;
    do {
      int16x8_t a[3];
      const uint8x8_t d0 =
          WienerVerticalTap3Kernel(wiener_buffer + 0, width, filter, a);
      const uint8x8_t d1 =
          WienerVerticalTap3Kernel(wiener_buffer + 8, width, filter, a);
      vst1q_u8(dst, vcombine_u8(d0, d1));
      wiener_buffer += 16;
      dst += 16;
      x -= 16;
    } while (x != 0);
  }
}

inline void WienerVerticalTap1Kernel(const int16_t* const wiener_buffer,
                                     uint8_t* const dst) {
  const int16x8_t a0 = vld1q_s16(wiener_buffer + 0);
  const int16x8_t a1 = vld1q_s16(wiener_buffer + 8);
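  // Round and shift right by 4 to undo the << 4 applied in
  // WienerHorizontalTap1().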
  const uint8x8_t d0 = vqrshrun_n_s16(a0, 4);
  const uint8x8_t d1 = vqrshrun_n_s16(a1, 4);
  vst1q_u8(dst, vcombine_u8(d0, d1));
}

inline void WienerVerticalTap1(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               uint8_t* dst, const ptrdiff_t dst_stride) {
  for (int y = height >> 1; y != 0; --y) {
    uint8_t* dst_ptr = dst;
    ptrdiff_t x = width;
    do {
      WienerVerticalTap1Kernel(wiener_buffer, dst_ptr);
      WienerVerticalTap1Kernel(wiener_buffer + width, dst_ptr + dst_stride);
      wiener_buffer += 16;
      dst_ptr += 16;
      x -= 16;
    } while (x != 0);
    wiener_buffer += width;
    dst += 2 * dst_stride;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = width;
    do {
      WienerVerticalTap1Kernel(wiener_buffer, dst);
      wiener_buffer += 16;
      dst += 16;
      x -= 16;
    } while (x != 0);
  }
}

// For width 16 and up, store the horizontal results, and then do the vertical
// filter row by row. This is faster than processing column by column because
// of better cache behavior.
void WienerFilter_NEON(const RestorationUnitInfo& restoration_info,
                       const void* const source, const void* const top_border,
                       const void* const bottom_border, const ptrdiff_t stride,
                       const int width, const int height,
                       RestorationBuffer* const restoration_buffer,
                       void* const dest) {
  const int16_t* const number_leading_zero_coefficients =
      restoration_info.wiener_info.number_leading_zero_coefficients;
  const int number_rows_to_skip = std::max(
      static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]),
      1);
  const ptrdiff_t wiener_stride = Align(width, 16);
  int16_t* const wiener_buffer_vertical = restoration_buffer->wiener_buffer;
  // The values are saturated to 13 bits before storing.
  int16_t* wiener_buffer_horizontal =
      wiener_buffer_vertical + number_rows_to_skip * wiener_stride;
  int16_t filter_horizontal[(kWienerFilterTaps + 1) / 2];
  int16_t filter_vertical[(kWienerFilterTaps + 1) / 2];
  PopulateWienerCoefficients(restoration_info, WienerInfo::kHorizontal,
                             filter_horizontal);
  PopulateWienerCoefficients(restoration_info, WienerInfo::kVertical,
                             filter_vertical);

  // horizontal filtering.
  // Over-reads up to 15 - |kRestorationHorizontalBorder| values.
  const int height_horizontal =
      height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip;
  const int height_extra = (height_horizontal - height) >> 1;
  assert(height_extra <= 2);
  const auto* const src = static_cast<const uint8_t*>(source);
  const auto* const top = static_cast<const uint8_t*>(top_border);
  const auto* const bottom = static_cast<const uint8_t*>(bottom_border);
  if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) {
    WienerHorizontalTap7(top + (2 - height_extra) * stride - 3, stride,
                         wiener_stride, height_extra, filter_horizontal,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap7(src - 3, stride, wiener_stride, height,
                         filter_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap7(bottom - 3, stride, wiener_stride, height_extra,
                         filter_horizontal, &wiener_buffer_horizontal);
  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) {
    WienerHorizontalTap5(top + (2 - height_extra) * stride - 2, stride,
                         wiener_stride, height_extra, filter_horizontal,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap5(src - 2, stride, wiener_stride, height,
                         filter_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap5(bottom - 2, stride, wiener_stride, height_extra,
                         filter_horizontal, &wiener_buffer_horizontal);
  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) {
    // The maximum over-reads happen here.
    WienerHorizontalTap3(top + (2 - height_extra) * stride - 1, stride,
                         wiener_stride, height_extra, filter_horizontal,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap3(src - 1, stride, wiener_stride, height,
                         filter_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap3(bottom - 1, stride, wiener_stride, height_extra,
                         filter_horizontal, &wiener_buffer_horizontal);
  } else {
    assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3);
    WienerHorizontalTap1(top + (2 - height_extra) * stride, stride,
                         wiener_stride, height_extra,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap1(src, stride, wiener_stride, height,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap1(bottom, stride, wiener_stride, height_extra,
                         &wiener_buffer_horizontal);
  }

  // vertical filtering.
  // Over-writes up to 15 values.
  auto* dst = static_cast<uint8_t*>(dest);
  if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) {
    // Because the top row of |source| is a duplicate of the second row, and the
    // bottom row of |source| is a duplicate of the row above it, we can
    // duplicate the top and bottom rows of |wiener_buffer| accordingly.
    memcpy(wiener_buffer_horizontal, wiener_buffer_horizontal - wiener_stride,
           sizeof(*wiener_buffer_horizontal) * wiener_stride);
    memcpy(restoration_buffer->wiener_buffer,
           restoration_buffer->wiener_buffer + wiener_stride,
           sizeof(*restoration_buffer->wiener_buffer) * wiener_stride);
    WienerVerticalTap7(wiener_buffer_vertical, wiener_stride, height,
                       filter_vertical, dst, stride);
  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) {
    WienerVerticalTap5(wiener_buffer_vertical + wiener_stride, wiener_stride,
                       height, filter_vertical, dst, stride);
  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) {
    WienerVerticalTap3(wiener_buffer_vertical + 2 * wiener_stride,
                       wiener_stride, height, filter_vertical, dst, stride);
  } else {
    assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3);
    WienerVerticalTap1(wiener_buffer_vertical + 3 * wiener_stride,
                       wiener_stride, height, dst, stride);
  }
}

//------------------------------------------------------------------------------
// SGR

inline void Prepare3_8(const uint8x8x2_t src, uint8x8_t dst[3]) {
  dst[0] = VshrU128<0>(src);
  dst[1] = VshrU128<1>(src);
  dst[2] = VshrU128<2>(src);
}

inline void Prepare3_16(const uint16x8x2_t src, uint16x4_t low[3],
                        uint16x4_t high[3]) {
  uint16x8_t s[3];
  s[0] = VshrU128<0>(src);
  s[1] = VshrU128<2>(src);
  s[2] = VshrU128<4>(src);
  low[0] = vget_low_u16(s[0]);
  low[1] = vget_low_u16(s[1]);
  low[2] = vget_low_u16(s[2]);
  high[0] = vget_high_u16(s[0]);
  high[1] = vget_high_u16(s[1]);
  high[2] = vget_high_u16(s[2]);
}

inline void Prepare5_8(const uint8x8x2_t src, uint8x8_t dst[5]) {
  dst[0] = VshrU128<0>(src);
  dst[1] = VshrU128<1>(src);
  dst[2] = VshrU128<2>(src);
  dst[3] = VshrU128<3>(src);
  dst[4] = VshrU128<4>(src);
}

inline void Prepare5_16(const uint16x8x2_t src, uint16x4_t low[5],
                        uint16x4_t high[5]) {
  Prepare3_16(src, low, high);
  const uint16x8_t s3 = VshrU128<6>(src);
  const uint16x8_t s4 = VshrU128<8>(src);
  low[3] = vget_low_u16(s3);
  low[4] = vget_low_u16(s4);
  high[3] = vget_high_u16(s3);
  high[4] = vget_high_u16(s4);
}

inline uint16x8_t Sum3_16(const uint16x8_t src0, const uint16x8_t src1,
                          const uint16x8_t src2) {
  const uint16x8_t sum = vaddq_u16(src0, src1);
  return vaddq_u16(sum, src2);
}

inline uint16x8_t Sum3_16(const uint16x8_t src[3]) {
  return Sum3_16(src[0], src[1], src[2]);
}

inline uint32x4_t Sum3_32(const uint32x4_t src0, const uint32x4_t src1,
                          const uint32x4_t src2) {
  const uint32x4_t sum = vaddq_u32(src0, src1);
  return vaddq_u32(sum, src2);
}

inline uint32x4x2_t Sum3_32(const uint32x4x2_t src[3]) {
  uint32x4x2_t d;
  d.val[0] = Sum3_32(src[0].val[0], src[1].val[0], src[2].val[0]);
  d.val[1] = Sum3_32(src[0].val[1], src[1].val[1], src[2].val[1]);
  return d;
}

inline uint16x8_t Sum3W_16(const uint8x8_t src[3]) {
  const uint16x8_t sum = vaddl_u8(src[0], src[1]);
  return vaddw_u8(sum, src[2]);
}

inline uint32x4_t Sum3W_32(const uint16x4_t src[3]) {
  const uint32x4_t sum = vaddl_u16(src[0], src[1]);
  return vaddw_u16(sum, src[2]);
}

inline uint16x8_t Sum5_16(const uint16x8_t src[5]) {
  const uint16x8_t sum01 = vaddq_u16(src[0], src[1]);
  const uint16x8_t sum23 = vaddq_u16(src[2], src[3]);
  const uint16x8_t sum = vaddq_u16(sum01, sum23);
  return vaddq_u16(sum, src[4]);
}

inline uint32x4_t Sum5_32(const uint32x4_t src0, const uint32x4_t src1,
                          const uint32x4_t src2, const uint32x4_t src3,
                          const uint32x4_t src4) {
  const uint32x4_t sum01 = vaddq_u32(src0, src1);
  const uint32x4_t sum23 = vaddq_u32(src2, src3);
  const uint32x4_t sum = vaddq_u32(sum01, sum23);
  return vaddq_u32(sum, src4);
}

inline uint32x4x2_t Sum5_32(const uint32x4x2_t src[5]) {
  uint32x4x2_t d;
  d.val[0] = Sum5_32(src[0].val[0], src[1].val[0], src[2].val[0], src[3].val[0],
                     src[4].val[0]);
  d.val[1] = Sum5_32(src[0].val[1], src[1].val[1], src[2].val[1], src[3].val[1],
                     src[4].val[1]);
  return d;
}

inline uint32x4_t Sum5W_32(const uint16x4_t src[5]) {
  const uint32x4_t sum01 = vaddl_u16(src[0], src[1]);
  const uint32x4_t sum23 = vaddl_u16(src[2], src[3]);
  const uint32x4_t sum0123 = vaddq_u32(sum01, sum23);
  return vaddw_u16(sum0123, src[4]);
}

inline uint16x8_t Sum3Horizontal(const uint8x8x2_t src) {
  uint8x8_t s[3];
  Prepare3_8(src, s);
  return Sum3W_16(s);
}

inline uint32x4x2_t Sum3WHorizontal(const uint16x8x2_t src) {
  uint16x4_t low[3], high[3];
  uint32x4x2_t sum;
  Prepare3_16(src, low, high);
  sum.val[0] = Sum3W_32(low);
  sum.val[1] = Sum3W_32(high);
  return sum;
}

inline uint16x8_t Sum5Horizontal(const uint8x8x2_t src) {
  uint8x8_t s[5];
  Prepare5_8(src, s);
  const uint16x8_t sum01 = vaddl_u8(s[0], s[1]);
  const uint16x8_t sum23 = vaddl_u8(s[2], s[3]);
  const uint16x8_t sum0123 = vaddq_u16(sum01, sum23);
  return vaddw_u8(sum0123, s[4]);
}

inline uint32x4x2_t Sum5WHorizontal(const uint16x8x2_t src) {
  uint16x4_t low[5], high[5];
  Prepare5_16(src, low, high);
  uint32x4x2_t sum;
  sum.val[0] = Sum5W_32(low);
  sum.val[1] = Sum5W_32(high);
  return sum;
}

void SumHorizontal(const uint16x4_t src[5], uint32x4_t* const row_sq3,
                   uint32x4_t* const row_sq5) {
  const uint32x4_t sum04 = vaddl_u16(src[0], src[4]);
  const uint32x4_t sum12 = vaddl_u16(src[1], src[2]);
  *row_sq3 = vaddw_u16(sum12, src[3]);
  *row_sq5 = vaddq_u32(sum04, *row_sq3);
}

void SumHorizontal(const uint8x8x2_t src, const uint16x8x2_t sq,
                   uint16x8_t* const row3, uint16x8_t* const row5,
                   uint32x4x2_t* const row_sq3, uint32x4x2_t* const row_sq5) {
  uint8x8_t s[5];
  Prepare5_8(src, s);
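  // The 3-element sum s[1] + s[2] + s[3] is computed first and then reused to
  // form the 5-element sum by adding s[0] and s[4].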
  const uint16x8_t sum04 = vaddl_u8(s[0], s[4]);
  const uint16x8_t sum12 = vaddl_u8(s[1], s[2]);
  *row3 = vaddw_u8(sum12, s[3]);
  *row5 = vaddq_u16(sum04, *row3);
  uint16x4_t low[5], high[5];
  Prepare5_16(sq, low, high);
  SumHorizontal(low, &row_sq3->val[0], &row_sq5->val[0]);
  SumHorizontal(high, &row_sq3->val[1], &row_sq5->val[1]);
}

inline uint16x8_t Sum343(const uint8x8x2_t src) {
  uint8x8_t s[3];
  Prepare3_8(src, s);
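  // Compute 3 * (s[0] + s[1] + s[2]) + s[1], i.e. the weighted sum
  // 3*s[0] + 4*s[1] + 3*s[2].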
  const uint16x8_t sum = Sum3W_16(s);
  const uint16x8_t sum3 = Sum3_16(sum, sum, sum);
  return vaddw_u8(sum3, s[1]);
}

inline uint32x4_t Sum343W(const uint16x4_t src[3]) {
  const uint32x4_t sum = Sum3W_32(src);
  const uint32x4_t sum3 = Sum3_32(sum, sum, sum);
  return vaddw_u16(sum3, src[1]);
}

inline uint32x4x2_t Sum343W(const uint16x8x2_t src) {
  uint16x4_t low[3], high[3];
  uint32x4x2_t d;
  Prepare3_16(src, low, high);
  d.val[0] = Sum343W(low);
  d.val[1] = Sum343W(high);
  return d;
}

inline uint16x8_t Sum565(const uint8x8x2_t src) {
  uint8x8_t s[3];
  Prepare3_8(src, s);
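  // Compute 5 * (s[0] + s[1] + s[2]) + s[1], i.e. the weighted sum
  // 5*s[0] + 6*s[1] + 5*s[2].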
  const uint16x8_t sum = Sum3W_16(s);
  const uint16x8_t sum4 = vshlq_n_u16(sum, 2);
  const uint16x8_t sum5 = vaddq_u16(sum4, sum);
  return vaddw_u8(sum5, s[1]);
}

inline uint32x4_t Sum565W(const uint16x4_t src[3]) {
  const uint32x4_t sum = Sum3W_32(src);
  const uint32x4_t sum4 = vshlq_n_u32(sum, 2);
  const uint32x4_t sum5 = vaddq_u32(sum4, sum);
  return vaddw_u16(sum5, src[1]);
}

inline uint32x4x2_t Sum565W(const uint16x8x2_t src) {
  uint16x4_t low[3], high[3];
  uint32x4x2_t d;
  Prepare3_16(src, low, high);
  d.val[0] = Sum565W(low);
  d.val[1] = Sum565W(high);
  return d;
}

inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride,
                   const int height, const ptrdiff_t sum_stride, uint16_t* sum3,
                   uint16_t* sum5, uint32_t* square_sum3,
                   uint32_t* square_sum5) {
  int y = height;
  do {
    uint8x8x2_t s;
    uint16x8x2_t sq;
    s.val[0] = vld1_u8(src);
    sq.val[0] = vmull_u8(s.val[0], s.val[0]);
    ptrdiff_t x = 0;
    do {
      uint16x8_t row3, row5;
      uint32x4x2_t row_sq3, row_sq5;
      s.val[1] = vld1_u8(src + x + 8);
      sq.val[1] = vmull_u8(s.val[1], s.val[1]);
      SumHorizontal(s, sq, &row3, &row5, &row_sq3, &row_sq5);
      vst1q_u16(sum3, row3);
      vst1q_u16(sum5, row5);
      vst1q_u32(square_sum3 + 0, row_sq3.val[0]);
      vst1q_u32(square_sum3 + 4, row_sq3.val[1]);
      vst1q_u32(square_sum5 + 0, row_sq5.val[0]);
      vst1q_u32(square_sum5 + 4, row_sq5.val[1]);
      s.val[0] = s.val[1];
      sq.val[0] = sq.val[1];
      sum3 += 8;
      sum5 += 8;
      square_sum3 += 8;
      square_sum5 += 8;
      x += 8;
    } while (x < sum_stride);
    src += src_stride;
  } while (--y != 0);
}

template <int size>
inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride,
                   const int height, const ptrdiff_t sum_stride, uint16_t* sums,
                   uint32_t* square_sums) {
  static_assert(size == 3 || size == 5, "");
  int y = height;
  do {
    uint8x8x2_t s;
    uint16x8x2_t sq;
    s.val[0] = vld1_u8(src);
    sq.val[0] = vmull_u8(s.val[0], s.val[0]);
    ptrdiff_t x = 0;
    do {
      uint16x8_t row;
      uint32x4x2_t row_sq;
      s.val[1] = vld1_u8(src + x + 8);
      sq.val[1] = vmull_u8(s.val[1], s.val[1]);
      if (size == 3) {
        row = Sum3Horizontal(s);
        row_sq = Sum3WHorizontal(sq);
      } else {
        row = Sum5Horizontal(s);
        row_sq = Sum5WHorizontal(sq);
      }
      vst1q_u16(sums, row);
      vst1q_u32(square_sums + 0, row_sq.val[0]);
      vst1q_u32(square_sums + 4, row_sq.val[1]);
      s.val[0] = s.val[1];
      sq.val[0] = sq.val[1];
      sums += 8;
      square_sums += 8;
      x += 8;
    } while (x < sum_stride);
    src += src_stride;
  } while (--y != 0);
}

template <int n>
inline uint16x4_t CalculateMa(const uint16x4_t sum, const uint32x4_t sum_sq,
                              const uint32_t scale) {
  // a = |sum_sq|
  // d = |sum|
  // p = (a * n < d * d) ? 0 : a * n - d * d;
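  // Since d is the box sum and a the box sum of squares over n pixels,
  // p = a * n - d * d equals n * n times the variance of the pixels in the
  // box.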
  const uint32x4_t dxd = vmull_u16(sum, sum);
  const uint32x4_t axn = vmulq_n_u32(sum_sq, n);
  // Ensure |p| does not underflow by using saturating subtraction.
  const uint32x4_t p = vqsubq_u32(axn, dxd);
  const uint32x4_t pxs = vmulq_n_u32(p, scale);
  // vrshrn_n_u32() (narrowing shift) can shift by at most 16, but
  // kSgrProjScaleBits is 20, so shift and narrow in two steps.
  const uint32x4_t shifted = vrshrq_n_u32(pxs, kSgrProjScaleBits);
  return vmovn_u32(shifted);
}

template <int n>
inline void CalculateIntermediate(const uint16x8_t sum,
                                  const uint32x4x2_t sum_sq,
                                  const uint32_t scale, uint8x8_t* const ma,
                                  uint16x8_t* const b) {
  constexpr uint32_t one_over_n =
      ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n;
  const uint16x4_t z0 = CalculateMa<n>(vget_low_u16(sum), sum_sq.val[0], scale);
  const uint16x4_t z1 =
      CalculateMa<n>(vget_high_u16(sum), sum_sq.val[1], scale);
  const uint16x8_t z01 = vcombine_u16(z0, z1);
  // Using vqmovn_u16() needs an extra sign extension instruction.
  const uint16x8_t z = vminq_u16(z01, vdupq_n_u16(255));
  // Using vgetq_lane_s16() can save the sign extension instruction.
  const uint8_t lookup[8] = {
      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 0)],
      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 1)],
      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 2)],
      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 3)],
      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 4)],
      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 5)],
      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 6)],
      kSgrMaLookup[vgetq_lane_s16(vreinterpretq_s16_u16(z), 7)]};
  *ma = vld1_u8(lookup);
  // b = ma * b * one_over_n
  // |ma| = [0, 255]
  // |sum| is a box sum with radius 1 or 2.
  // For the first pass radius is 2. Maximum value is 5x5x255 = 6375.
  // For the second pass radius is 1. Maximum value is 3x3x255 = 2295.
  // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
  // When radius is 2 |n| is 25. |one_over_n| is 164.
  // When radius is 1 |n| is 9. |one_over_n| is 455.
  // |kSgrProjReciprocalBits| is 12.
  // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits).
  // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits).
  const uint16x8_t maq = vmovl_u8(*ma);
  const uint32x4_t m0 = vmull_u16(vget_low_u16(maq), vget_low_u16(sum));
  const uint32x4_t m1 = vmull_u16(vget_high_u16(maq), vget_high_u16(sum));
  const uint32x4_t m2 = vmulq_n_u32(m0, one_over_n);
  const uint32x4_t m3 = vmulq_n_u32(m1, one_over_n);
  const uint16x4_t b_lo = vrshrn_n_u32(m2, kSgrProjReciprocalBits);
  const uint16x4_t b_hi = vrshrn_n_u32(m3, kSgrProjReciprocalBits);
  *b = vcombine_u16(b_lo, b_hi);
}

inline void CalculateIntermediate5(const uint16x8_t s5[5],
                                   const uint32x4x2_t sq5[5],
                                   const uint32_t scale, uint8x8_t* const ma,
                                   uint16x8_t* const b) {
  const uint16x8_t sum = Sum5_16(s5);
  const uint32x4x2_t sum_sq = Sum5_32(sq5);
  CalculateIntermediate<25>(sum, sum_sq, scale, ma, b);
}

inline void CalculateIntermediate3(const uint16x8_t s3[3],
                                   const uint32x4x2_t sq3[3],
                                   const uint32_t scale, uint8x8_t* const ma,
                                   uint16x8_t* const b) {
  const uint16x8_t sum = Sum3_16(s3);
  const uint32x4x2_t sum_sq = Sum3_32(sq3);
  CalculateIntermediate<9>(sum, sum_sq, scale, ma, b);
}

inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3,
                         const ptrdiff_t x, uint16x8_t* const sum_ma343,
                         uint16x8_t* const sum_ma444,
                         uint32x4x2_t* const sum_b343,
                         uint32x4x2_t* const sum_b444, uint16_t* const ma343,
                         uint16_t* const ma444, uint32_t* const b343,
                         uint32_t* const b444) {
  uint8x8_t s[3];
  Prepare3_8(ma3, s);
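  // |sum_ma444| is 4 * (s[0] + s[1] + s[2]) and |sum_ma343| is
  // 3 * (s[0] + s[1] + s[2]) + s[1]; both are derived from the same 1-1-1 sum.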
  const uint16x8_t sum_ma111 = Sum3W_16(s);
  *sum_ma444 = vshlq_n_u16(sum_ma111, 2);
  const uint16x8_t sum333 = vsubq_u16(*sum_ma444, sum_ma111);
  *sum_ma343 = vaddw_u8(sum333, s[1]);
  uint16x4_t low[3], high[3];
  uint32x4x2_t sum_b111;
  Prepare3_16(b3, low, high);
  sum_b111.val[0] = Sum3W_32(low);
  sum_b111.val[1] = Sum3W_32(high);
  sum_b444->val[0] = vshlq_n_u32(sum_b111.val[0], 2);
  sum_b444->val[1] = vshlq_n_u32(sum_b111.val[1], 2);
  sum_b343->val[0] = vsubq_u32(sum_b444->val[0], sum_b111.val[0]);
  sum_b343->val[1] = vsubq_u32(sum_b444->val[1], sum_b111.val[1]);
  sum_b343->val[0] = vaddw_u16(sum_b343->val[0], low[1]);
  sum_b343->val[1] = vaddw_u16(sum_b343->val[1], high[1]);
  vst1q_u16(ma343 + x, *sum_ma343);
  vst1q_u16(ma444 + x, *sum_ma444);
  vst1q_u32(b343 + x + 0, sum_b343->val[0]);
  vst1q_u32(b343 + x + 4, sum_b343->val[1]);
  vst1q_u32(b444 + x + 0, sum_b444->val[0]);
  vst1q_u32(b444 + x + 4, sum_b444->val[1]);
}

inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3,
                         const ptrdiff_t x, uint16x8_t* const sum_ma343,
                         uint32x4x2_t* const sum_b343, uint16_t* const ma343,
                         uint16_t* const ma444, uint32_t* const b343,
                         uint32_t* const b444) {
  uint16x8_t sum_ma444;
  uint32x4x2_t sum_b444;
  Store343_444(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, &sum_b444, ma343,
               ma444, b343, b444);
}

inline void Store343_444(const uint8x8x2_t ma3, const uint16x8x2_t b3,
                         const ptrdiff_t x, uint16_t* const ma343,
                         uint16_t* const ma444, uint32_t* const b343,
                         uint32_t* const b444) {
  uint16x8_t sum_ma343;
  uint32x4x2_t sum_b343;
  Store343_444(ma3, b3, x, &sum_ma343, &sum_b343, ma343, ma444, b343, b444);
}

LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5(
    const uint8_t* const src0, const uint8_t* const src1, const ptrdiff_t x,
    const uint32_t scale, uint16_t* const sum5[5],
    uint32_t* const square_sum5[5], uint8x8x2_t s[2], uint16x8x2_t sq[2],
    uint8x8_t* const ma, uint16x8_t* const b) {
  uint16x8_t s5[5];
  uint32x4x2_t sq5[5];
  s[0].val[1] = vld1_u8(src0 + x + 8);
  s[1].val[1] = vld1_u8(src1 + x + 8);
  sq[0].val[1] = vmull_u8(s[0].val[1], s[0].val[1]);
  sq[1].val[1] = vmull_u8(s[1].val[1], s[1].val[1]);
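  // Rows 3 and 4 of the sliding 5-row window are computed from the current
  // source rows and stored; rows 0..2 are reloaded from sums produced earlier.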
1000   s5[3] = Sum5Horizontal(s[0]);
1001   s5[4] = Sum5Horizontal(s[1]);
1002   sq5[3] = Sum5WHorizontal(sq[0]);
1003   sq5[4] = Sum5WHorizontal(sq[1]);
1004   vst1q_u16(sum5[3] + x, s5[3]);
1005   vst1q_u16(sum5[4] + x, s5[4]);
1006   vst1q_u32(square_sum5[3] + x + 0, sq5[3].val[0]);
1007   vst1q_u32(square_sum5[3] + x + 4, sq5[3].val[1]);
1008   vst1q_u32(square_sum5[4] + x + 0, sq5[4].val[0]);
1009   vst1q_u32(square_sum5[4] + x + 4, sq5[4].val[1]);
1010   s5[0] = vld1q_u16(sum5[0] + x);
1011   s5[1] = vld1q_u16(sum5[1] + x);
1012   s5[2] = vld1q_u16(sum5[2] + x);
1013   sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
1014   sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
1015   sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
1016   sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
1017   sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
1018   sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
1019   CalculateIntermediate5(s5, sq5, scale, ma, b);
1020 }
1021 
BoxFilterPreProcess5LastRow(const uint8_t * const src,const ptrdiff_t x,const uint32_t scale,const uint16_t * const sum5[5],const uint32_t * const square_sum5[5],uint8x8x2_t * const s,uint16x8x2_t * const sq,uint8x8_t * const ma,uint16x8_t * const b)1022 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow(
1023     const uint8_t* const src, const ptrdiff_t x, const uint32_t scale,
1024     const uint16_t* const sum5[5], const uint32_t* const square_sum5[5],
1025     uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma,
1026     uint16x8_t* const b) {
1027   uint16x8_t s5[5];
1028   uint32x4x2_t sq5[5];
1029   s->val[1] = vld1_u8(src + x + 8);
1030   sq->val[1] = vmull_u8(s->val[1], s->val[1]);
1031   s5[3] = s5[4] = Sum5Horizontal(*s);
1032   sq5[3] = sq5[4] = Sum5WHorizontal(*sq);
1033   s5[0] = vld1q_u16(sum5[0] + x);
1034   s5[1] = vld1q_u16(sum5[1] + x);
1035   s5[2] = vld1q_u16(sum5[2] + x);
1036   sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
1037   sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
1038   sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
1039   sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
1040   sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
1041   sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
1042   CalculateIntermediate5(s5, sq5, scale, ma, b);
1043 }
1044 
BoxFilterPreProcess3(const uint8_t * const src,const ptrdiff_t x,const uint32_t scale,uint16_t * const sum3[3],uint32_t * const square_sum3[3],uint8x8x2_t * const s,uint16x8x2_t * const sq,uint8x8_t * const ma,uint16x8_t * const b)1045 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3(
1046     const uint8_t* const src, const ptrdiff_t x, const uint32_t scale,
1047     uint16_t* const sum3[3], uint32_t* const square_sum3[3],
1048     uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma,
1049     uint16x8_t* const b) {
1050   uint16x8_t s3[3];
1051   uint32x4x2_t sq3[3];
1052   s->val[1] = vld1_u8(src + x + 8);
1053   sq->val[1] = vmull_u8(s->val[1], s->val[1]);
1054   s3[2] = Sum3Horizontal(*s);
1055   sq3[2] = Sum3WHorizontal(*sq);
1056   vst1q_u16(sum3[2] + x, s3[2]);
1057   vst1q_u32(square_sum3[2] + x + 0, sq3[2].val[0]);
1058   vst1q_u32(square_sum3[2] + x + 4, sq3[2].val[1]);
1059   s3[0] = vld1q_u16(sum3[0] + x);
1060   s3[1] = vld1q_u16(sum3[1] + x);
1061   sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0);
1062   sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4);
1063   sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0);
1064   sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4);
1065   CalculateIntermediate3(s3, sq3, scale, ma, b);
1066 }
1067 
BoxFilterPreProcess(const uint8_t * const src0,const uint8_t * const src1,const ptrdiff_t x,const uint16_t scales[2],uint16_t * const sum3[4],uint16_t * const sum5[5],uint32_t * const square_sum3[4],uint32_t * const square_sum5[5],uint8x8x2_t s[2],uint16x8x2_t sq[2],uint8x8_t * const ma3_0,uint8x8_t * const ma3_1,uint16x8_t * const b3_0,uint16x8_t * const b3_1,uint8x8_t * const ma5,uint16x8_t * const b5)1068 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess(
1069     const uint8_t* const src0, const uint8_t* const src1, const ptrdiff_t x,
1070     const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5],
1071     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
1072     uint8x8x2_t s[2], uint16x8x2_t sq[2], uint8x8_t* const ma3_0,
1073     uint8x8_t* const ma3_1, uint16x8_t* const b3_0, uint16x8_t* const b3_1,
1074     uint8x8_t* const ma5, uint16x8_t* const b5) {
1075   uint16x8_t s3[4], s5[5];
1076   uint32x4x2_t sq3[4], sq5[5];
1077   s[0].val[1] = vld1_u8(src0 + x + 8);
1078   s[1].val[1] = vld1_u8(src1 + x + 8);
1079   sq[0].val[1] = vmull_u8(s[0].val[1], s[0].val[1]);
1080   sq[1].val[1] = vmull_u8(s[1].val[1], s[1].val[1]);
1081   SumHorizontal(s[0], sq[0], &s3[2], &s5[3], &sq3[2], &sq5[3]);
1082   SumHorizontal(s[1], sq[1], &s3[3], &s5[4], &sq3[3], &sq5[4]);
1083   vst1q_u16(sum3[2] + x, s3[2]);
1084   vst1q_u16(sum3[3] + x, s3[3]);
1085   vst1q_u32(square_sum3[2] + x + 0, sq3[2].val[0]);
1086   vst1q_u32(square_sum3[2] + x + 4, sq3[2].val[1]);
1087   vst1q_u32(square_sum3[3] + x + 0, sq3[3].val[0]);
1088   vst1q_u32(square_sum3[3] + x + 4, sq3[3].val[1]);
1089   vst1q_u16(sum5[3] + x, s5[3]);
1090   vst1q_u16(sum5[4] + x, s5[4]);
1091   vst1q_u32(square_sum5[3] + x + 0, sq5[3].val[0]);
1092   vst1q_u32(square_sum5[3] + x + 4, sq5[3].val[1]);
1093   vst1q_u32(square_sum5[4] + x + 0, sq5[4].val[0]);
1094   vst1q_u32(square_sum5[4] + x + 4, sq5[4].val[1]);
1095   s3[0] = vld1q_u16(sum3[0] + x);
1096   s3[1] = vld1q_u16(sum3[1] + x);
1097   sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0);
1098   sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4);
1099   sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0);
1100   sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4);
1101   s5[0] = vld1q_u16(sum5[0] + x);
1102   s5[1] = vld1q_u16(sum5[1] + x);
1103   s5[2] = vld1q_u16(sum5[2] + x);
1104   sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
1105   sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
1106   sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
1107   sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
1108   sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
1109   sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
1110   CalculateIntermediate3(s3, sq3, scales[1], ma3_0, b3_0);
1111   CalculateIntermediate3(s3 + 1, sq3 + 1, scales[1], ma3_1, b3_1);
1112   CalculateIntermediate5(s5, sq5, scales[0], ma5, b5);
1113 }
1114 
LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow(
    const uint8_t* const src, const ptrdiff_t x, const uint16_t scales[2],
    const uint16_t* const sum3[4], const uint16_t* const sum5[5],
    const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5],
    uint8x8x2_t* const s, uint16x8x2_t* const sq, uint8x8_t* const ma3,
    uint8x8_t* const ma5, uint16x8_t* const b3, uint16x8_t* const b5) {
  uint16x8_t s3[3], s5[5];
  uint32x4x2_t sq3[3], sq5[5];
  s->val[1] = vld1_u8(src + x + 8);
  sq->val[1] = vmull_u8(s->val[1], s->val[1]);
  SumHorizontal(*s, *sq, &s3[2], &s5[3], &sq3[2], &sq5[3]);
  s5[0] = vld1q_u16(sum5[0] + x);
  s5[1] = vld1q_u16(sum5[1] + x);
  s5[2] = vld1q_u16(sum5[2] + x);
  s5[4] = s5[3];
  sq5[0].val[0] = vld1q_u32(square_sum5[0] + x + 0);
  sq5[0].val[1] = vld1q_u32(square_sum5[0] + x + 4);
  sq5[1].val[0] = vld1q_u32(square_sum5[1] + x + 0);
  sq5[1].val[1] = vld1q_u32(square_sum5[1] + x + 4);
  sq5[2].val[0] = vld1q_u32(square_sum5[2] + x + 0);
  sq5[2].val[1] = vld1q_u32(square_sum5[2] + x + 4);
  sq5[4] = sq5[3];
  CalculateIntermediate5(s5, sq5, scales[0], ma5, b5);
  s3[0] = vld1q_u16(sum3[0] + x);
  s3[1] = vld1q_u16(sum3[1] + x);
  sq3[0].val[0] = vld1q_u32(square_sum3[0] + x + 0);
  sq3[0].val[1] = vld1q_u32(square_sum3[0] + x + 4);
  sq3[1].val[0] = vld1q_u32(square_sum3[1] + x + 0);
  sq3[1].val[1] = vld1q_u32(square_sum3[1] + x + 4);
  CalculateIntermediate3(s3, sq3, scales[1], ma3, b3);
}

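// Seeds the pass 1 intermediate buffers |ma565| and |b565| from the rows
// |src0| and |src1| before the main filtering loop starts.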
inline void BoxSumFilterPreProcess5(const uint8_t* const src0,
                                    const uint8_t* const src1, const int width,
                                    const uint32_t scale,
                                    uint16_t* const sum5[5],
                                    uint32_t* const square_sum5[5],
                                    uint16_t* ma565, uint32_t* b565) {
  uint8x8x2_t s[2], mas;
  uint16x8x2_t sq[2], bs;
  s[0].val[0] = vld1_u8(src0);
  s[1].val[0] = vld1_u8(src1);
  sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
  sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
  BoxFilterPreProcess5(src0, src1, 0, scale, sum5, square_sum5, s, sq,
                       &mas.val[0], &bs.val[0]);

  int x = 0;
  do {
    s[0].val[0] = s[0].val[1];
    s[1].val[0] = s[1].val[1];
    sq[0].val[0] = sq[0].val[1];
    sq[1].val[0] = sq[1].val[1];
    BoxFilterPreProcess5(src0, src1, x + 8, scale, sum5, square_sum5, s, sq,
                         &mas.val[1], &bs.val[1]);
    const uint16x8_t ma = Sum565(mas);
    const uint32x4x2_t b = Sum565W(bs);
    vst1q_u16(ma565, ma);
    vst1q_u32(b565 + 0, b.val[0]);
    vst1q_u32(b565 + 4, b.val[1]);
    mas.val[0] = mas.val[1];
    bs.val[0] = bs.val[1];
    ma565 += 8;
    b565 += 8;
    x += 8;
  } while (x < width);
}

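// Seeds the pass 2 intermediate buffers. When |calculate444| is true the
// 4-4-4 weighted sums are stored to |ma444|/|b444| in addition to the 3-4-3
// sums; otherwise only |ma343|/|b343| are written.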
template <bool calculate444>
LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3(
    const uint8_t* const src, const int width, const uint32_t scale,
    uint16_t* const sum3[3], uint32_t* const square_sum3[3], uint16_t* ma343,
    uint16_t* ma444, uint32_t* b343, uint32_t* b444) {
  uint8x8x2_t s, mas;
  uint16x8x2_t sq, bs;
  s.val[0] = vld1_u8(src);
  sq.val[0] = vmull_u8(s.val[0], s.val[0]);
  BoxFilterPreProcess3(src, 0, scale, sum3, square_sum3, &s, &sq, &mas.val[0],
                       &bs.val[0]);

  int x = 0;
  do {
    s.val[0] = s.val[1];
    sq.val[0] = sq.val[1];
    BoxFilterPreProcess3(src, x + 8, scale, sum3, square_sum3, &s, &sq,
                         &mas.val[1], &bs.val[1]);
    if (calculate444) {
      Store343_444(mas, bs, 0, ma343, ma444, b343, b444);
      ma444 += 8;
      b444 += 8;
    } else {
      const uint16x8_t ma = Sum343(mas);
      const uint32x4x2_t b = Sum343W(bs);
      vst1q_u16(ma343, ma);
      vst1q_u32(b343 + 0, b.val[0]);
      vst1q_u32(b343 + 4, b.val[1]);
    }
    mas.val[0] = mas.val[1];
    bs.val[0] = bs.val[1];
    ma343 += 8;
    b343 += 8;
    x += 8;
  } while (x < width);
}

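// Seeds the intermediate buffers for the combined filter, where pass 1 and
// pass 2 run together: |ma343|/|b343|, |ma444|/|b444| and |ma565|/|b565|.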
inline void BoxSumFilterPreProcess(
    const uint8_t* const src0, const uint8_t* const src1, const int width,
    const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5],
    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
    uint16_t* const ma343[4], uint16_t* const ma444[2], uint16_t* ma565,
    uint32_t* const b343[4], uint32_t* const b444[2], uint32_t* b565) {
  uint8x8x2_t s[2];
  uint8x8x2_t ma3[2], ma5;
  uint16x8x2_t sq[2], b3[2], b5;
  s[0].val[0] = vld1_u8(src0);
  s[1].val[0] = vld1_u8(src1);
  sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
  sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
  BoxFilterPreProcess(src0, src1, 0, scales, sum3, sum5, square_sum3,
                      square_sum5, s, sq, &ma3[0].val[0], &ma3[1].val[0],
                      &b3[0].val[0], &b3[1].val[0], &ma5.val[0], &b5.val[0]);

  int x = 0;
  do {
    s[0].val[0] = s[0].val[1];
    s[1].val[0] = s[1].val[1];
    sq[0].val[0] = sq[0].val[1];
    sq[1].val[0] = sq[1].val[1];
    BoxFilterPreProcess(src0, src1, x + 8, scales, sum3, sum5, square_sum3,
                        square_sum5, s, sq, &ma3[0].val[1], &ma3[1].val[1],
                        &b3[0].val[1], &b3[1].val[1], &ma5.val[1], &b5.val[1]);
    uint16x8_t ma = Sum343(ma3[0]);
    uint32x4x2_t b = Sum343W(b3[0]);
    vst1q_u16(ma343[0] + x, ma);
    vst1q_u32(b343[0] + x, b.val[0]);
    vst1q_u32(b343[0] + x + 4, b.val[1]);
    Store343_444(ma3[1], b3[1], x, ma343[1], ma444[0], b343[1], b444[0]);
    ma = Sum565(ma5);
    b = Sum565W(b5);
    vst1q_u16(ma565, ma);
    vst1q_u32(b565 + 0, b.val[0]);
    vst1q_u32(b565 + 4, b.val[1]);
    ma3[0].val[0] = ma3[0].val[1];
    ma3[1].val[0] = ma3[1].val[1];
    b3[0].val[0] = b3[0].val[1];
    b3[1].val[0] = b3[1].val[1];
    ma5.val[0] = ma5.val[1];
    b5.val[0] = b5.val[1];
    ma565 += 8;
    b565 += 8;
    x += 8;
  } while (x < width);
}

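// Computes 4 lanes of the box filter output: v = b - ma * src, followed by a
// rounding right shift by (kSgrProjSgrBits + shift - kSgrProjRestoreBits).
// The comments below track the intermediate bit widths.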
template <int shift>
inline int16x4_t FilterOutput(const uint16x4_t src, const uint16x4_t ma,
                              const uint32x4_t b) {
  // ma: 255 * 32 = 8160 (13 bits)
  // b: 65088 * 32 = 2082816 (21 bits)
  // v: b - ma * 255 (22 bits)
  const int32x4_t v = vreinterpretq_s32_u32(vmlsl_u16(b, ma, src));
  // kSgrProjSgrBits = 8
  // kSgrProjRestoreBits = 4
  // shift = 4 or 5
  // v >> 8 or 9 (13 bits)
  return vrshrn_n_s32(v, kSgrProjSgrBits + shift - kSgrProjRestoreBits);
}

template <int shift>
inline int16x8_t CalculateFilteredOutput(const uint8x8_t src,
                                         const uint16x8_t ma,
                                         const uint32x4x2_t b) {
  const uint16x8_t src_u16 = vmovl_u8(src);
  const int16x4_t dst_lo =
      FilterOutput<shift>(vget_low_u16(src_u16), vget_low_u16(ma), b.val[0]);
  const int16x4_t dst_hi =
      FilterOutput<shift>(vget_high_u16(src_u16), vget_high_u16(ma), b.val[1]);
  return vcombine_s16(dst_lo, dst_hi);  // 13 bits
}

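// Pass 1 output sums the |ma|/|b| intermediates of two rows (shift = 5);
// pass 2 sums three rows of 3x3 intermediates before filtering.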
inline int16x8_t CalculateFilteredOutputPass1(const uint8x8_t s,
                                              uint16x8_t ma[2],
                                              uint32x4x2_t b[2]) {
  const uint16x8_t ma_sum = vaddq_u16(ma[0], ma[1]);
  uint32x4x2_t b_sum;
  b_sum.val[0] = vaddq_u32(b[0].val[0], b[1].val[0]);
  b_sum.val[1] = vaddq_u32(b[0].val[1], b[1].val[1]);
  return CalculateFilteredOutput<5>(s, ma_sum, b_sum);
}

inline int16x8_t CalculateFilteredOutputPass2(const uint8x8_t s,
                                              uint16x8_t ma[3],
                                              uint32x4x2_t b[3]) {
  const uint16x8_t ma_sum = Sum3_16(ma);
  const uint32x4x2_t b_sum = Sum3_32(b);
  return CalculateFilteredOutput<5>(s, ma_sum, b_sum);
}

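// Final self-guided step: the weighted filter residual |v| is shifted down
// with rounding by (kSgrProjRestoreBits + kSgrProjPrecisionBits) bits, added
// to the source pixels and saturated to 8 bits.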
inline void SelfGuidedFinal(const uint8x8_t src, const int32x4_t v[2],
                            uint8_t* const dst) {
  const int16x4_t v_lo =
      vrshrn_n_s32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits);
  const int16x4_t v_hi =
      vrshrn_n_s32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits);
  const int16x8_t vv = vcombine_s16(v_lo, v_hi);
  const int16x8_t s = ZeroExtend(src);
  const int16x8_t d = vaddq_s16(s, vv);
  vst1_u8(dst, vqmovun_s16(d));
}

inline void SelfGuidedDoubleMultiplier(const uint8x8_t src,
                                       const int16x8_t filter[2], const int w0,
                                       const int w2, uint8_t* const dst) {
  int32x4_t v[2];
  v[0] = vmull_n_s16(vget_low_s16(filter[0]), w0);
  v[1] = vmull_n_s16(vget_high_s16(filter[0]), w0);
  v[0] = vmlal_n_s16(v[0], vget_low_s16(filter[1]), w2);
  v[1] = vmlal_n_s16(v[1], vget_high_s16(filter[1]), w2);
  SelfGuidedFinal(src, v, dst);
}

inline void SelfGuidedSingleMultiplier(const uint8x8_t src,
                                       const int16x8_t filter, const int w0,
                                       uint8_t* const dst) {
  // weight: -96 to 96 (Sgrproj_Xqd_Min/Max)
  int32x4_t v[2];
  v[0] = vmull_n_s16(vget_low_s16(filter), w0);
  v[1] = vmull_n_s16(vget_high_s16(filter), w0);
  SelfGuidedFinal(src, v, dst);
}

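// Applies the pass 1 (5x5 box) filter to two source rows and writes the
// final pixels to |dst| and |dst + stride|.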
LIBGAV1_ALWAYS_INLINE void BoxFilterPass1(
    const uint8_t* const src, const uint8_t* const src0,
    const uint8_t* const src1, const ptrdiff_t stride, uint16_t* const sum5[5],
    uint32_t* const square_sum5[5], const int width, const uint32_t scale,
    const int16_t w0, uint16_t* const ma565[2], uint32_t* const b565[2],
    uint8_t* const dst) {
  uint8x8x2_t s[2], mas;
  uint16x8x2_t sq[2], bs;
  s[0].val[0] = vld1_u8(src0);
  s[1].val[0] = vld1_u8(src1);
  sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
  sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
  BoxFilterPreProcess5(src0, src1, 0, scale, sum5, square_sum5, s, sq,
                       &mas.val[0], &bs.val[0]);

  int x = 0;
  do {
    s[0].val[0] = s[0].val[1];
    s[1].val[0] = s[1].val[1];
    sq[0].val[0] = sq[0].val[1];
    sq[1].val[0] = sq[1].val[1];
    BoxFilterPreProcess5(src0, src1, x + 8, scale, sum5, square_sum5, s, sq,
                         &mas.val[1], &bs.val[1]);
    uint16x8_t ma[2];
    uint32x4x2_t b[2];
    ma[1] = Sum565(mas);
    b[1] = Sum565W(bs);
    vst1q_u16(ma565[1] + x, ma[1]);
    vst1q_u32(b565[1] + x + 0, b[1].val[0]);
    vst1q_u32(b565[1] + x + 4, b[1].val[1]);
    const uint8x8_t sr0 = vld1_u8(src + x);
    const uint8x8_t sr1 = vld1_u8(src + stride + x);
    int16x8_t p0, p1;
    ma[0] = vld1q_u16(ma565[0] + x);
    b[0].val[0] = vld1q_u32(b565[0] + x + 0);
    b[0].val[1] = vld1q_u32(b565[0] + x + 4);
    p0 = CalculateFilteredOutputPass1(sr0, ma, b);
    p1 = CalculateFilteredOutput<4>(sr1, ma[1], b[1]);
    SelfGuidedSingleMultiplier(sr0, p0, w0, dst + x);
    SelfGuidedSingleMultiplier(sr1, p1, w0, dst + stride + x);
    mas.val[0] = mas.val[1];
    bs.val[0] = bs.val[1];
    x += 8;
  } while (x < width);
}

inline void BoxFilterPass1LastRow(const uint8_t* const src,
                                  const uint8_t* const src0, const int width,
                                  const uint32_t scale, const int16_t w0,
                                  uint16_t* const sum5[5],
                                  uint32_t* const square_sum5[5],
                                  uint16_t* ma565, uint32_t* b565,
                                  uint8_t* const dst) {
  uint8x8x2_t s, mas;
  uint16x8x2_t sq, bs;
  s.val[0] = vld1_u8(src0);
  sq.val[0] = vmull_u8(s.val[0], s.val[0]);
  BoxFilterPreProcess5LastRow(src0, 0, scale, sum5, square_sum5, &s, &sq,
                              &mas.val[0], &bs.val[0]);

  int x = 0;
  do {
    s.val[0] = s.val[1];
    sq.val[0] = sq.val[1];
    BoxFilterPreProcess5LastRow(src0, x + 8, scale, sum5, square_sum5, &s, &sq,
                                &mas.val[1], &bs.val[1]);
    uint16x8_t ma[2];
    uint32x4x2_t b[2];
    ma[1] = Sum565(mas);
    b[1] = Sum565W(bs);
    mas.val[0] = mas.val[1];
    bs.val[0] = bs.val[1];
    ma[0] = vld1q_u16(ma565);
    b[0].val[0] = vld1q_u32(b565 + 0);
    b[0].val[1] = vld1q_u32(b565 + 4);
    const uint8x8_t sr = vld1_u8(src + x);
    const int16x8_t p = CalculateFilteredOutputPass1(sr, ma, b);
    SelfGuidedSingleMultiplier(sr, p, w0, dst + x);
    ma565 += 8;
    b565 += 8;
    x += 8;
  } while (x < width);
}

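// Applies the pass 2 (3x3 box) filter to a single source row.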
LIBGAV1_ALWAYS_INLINE void BoxFilterPass2(
    const uint8_t* const src, const uint8_t* const src0, const int width,
    const uint32_t scale, const int16_t w0, uint16_t* const sum3[3],
    uint32_t* const square_sum3[3], uint16_t* const ma343[3],
    uint16_t* const ma444[2], uint32_t* const b343[3], uint32_t* const b444[2],
    uint8_t* const dst) {
  uint8x8x2_t s, mas;
  uint16x8x2_t sq, bs;
  s.val[0] = vld1_u8(src0);
  sq.val[0] = vmull_u8(s.val[0], s.val[0]);
  BoxFilterPreProcess3(src0, 0, scale, sum3, square_sum3, &s, &sq, &mas.val[0],
                       &bs.val[0]);

  int x = 0;
  do {
    s.val[0] = s.val[1];
    sq.val[0] = sq.val[1];
    BoxFilterPreProcess3(src0, x + 8, scale, sum3, square_sum3, &s, &sq,
                         &mas.val[1], &bs.val[1]);
    uint16x8_t ma[3];
    uint32x4x2_t b[3];
    Store343_444(mas, bs, x, &ma[2], &b[2], ma343[2], ma444[1], b343[2],
                 b444[1]);
    const uint8x8_t sr = vld1_u8(src + x);
    ma[0] = vld1q_u16(ma343[0] + x);
    ma[1] = vld1q_u16(ma444[0] + x);
    b[0].val[0] = vld1q_u32(b343[0] + x + 0);
    b[0].val[1] = vld1q_u32(b343[0] + x + 4);
    b[1].val[0] = vld1q_u32(b444[0] + x + 0);
    b[1].val[1] = vld1q_u32(b444[0] + x + 4);
    const int16x8_t p = CalculateFilteredOutputPass2(sr, ma, b);
    SelfGuidedSingleMultiplier(sr, p, w0, dst + x);
    mas.val[0] = mas.val[1];
    bs.val[0] = bs.val[1];
    x += 8;
  } while (x < width);
}

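// Applies both passes to two source rows and blends the two filter outputs
// with the weights |w0| and |w2|.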
LIBGAV1_ALWAYS_INLINE void BoxFilter(
    const uint8_t* const src, const uint8_t* const src0,
    const uint8_t* const src1, const ptrdiff_t stride, const int width,
    const uint16_t scales[2], const int16_t w0, const int16_t w2,
    uint16_t* const sum3[4], uint16_t* const sum5[5],
    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
    uint16_t* const ma343[4], uint16_t* const ma444[3],
    uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3],
    uint32_t* const b565[2], uint8_t* const dst) {
  uint8x8x2_t s[2], ma3[2], ma5;
  uint16x8x2_t sq[2], b3[2], b5;
  s[0].val[0] = vld1_u8(src0);
  s[1].val[0] = vld1_u8(src1);
  sq[0].val[0] = vmull_u8(s[0].val[0], s[0].val[0]);
  sq[1].val[0] = vmull_u8(s[1].val[0], s[1].val[0]);
  BoxFilterPreProcess(src0, src1, 0, scales, sum3, sum5, square_sum3,
                      square_sum5, s, sq, &ma3[0].val[0], &ma3[1].val[0],
                      &b3[0].val[0], &b3[1].val[0], &ma5.val[0], &b5.val[0]);

  int x = 0;
  do {
    s[0].val[0] = s[0].val[1];
    s[1].val[0] = s[1].val[1];
    sq[0].val[0] = sq[0].val[1];
    sq[1].val[0] = sq[1].val[1];
    BoxFilterPreProcess(src0, src1, x + 8, scales, sum3, sum5, square_sum3,
                        square_sum5, s, sq, &ma3[0].val[1], &ma3[1].val[1],
                        &b3[0].val[1], &b3[1].val[1], &ma5.val[1], &b5.val[1]);
    uint16x8_t ma[3][3];
    uint32x4x2_t b[3][3];
    Store343_444(ma3[0], b3[0], x, &ma[1][2], &ma[2][1], &b[1][2], &b[2][1],
                 ma343[2], ma444[1], b343[2], b444[1]);
    Store343_444(ma3[1], b3[1], x, &ma[2][2], &b[2][2], ma343[3], ma444[2],
                 b343[3], b444[2]);
    ma[0][1] = Sum565(ma5);
    b[0][1] = Sum565W(b5);
    vst1q_u16(ma565[1] + x, ma[0][1]);
    vst1q_u32(b565[1] + x, b[0][1].val[0]);
    vst1q_u32(b565[1] + x + 4, b[0][1].val[1]);
    ma3[0].val[0] = ma3[0].val[1];
    ma3[1].val[0] = ma3[1].val[1];
    b3[0].val[0] = b3[0].val[1];
    b3[1].val[0] = b3[1].val[1];
    ma5.val[0] = ma5.val[1];
    b5.val[0] = b5.val[1];
    int16x8_t p[2][2];
    const uint8x8_t sr0 = vld1_u8(src + x);
    const uint8x8_t sr1 = vld1_u8(src + stride + x);
    ma[0][0] = vld1q_u16(ma565[0] + x);
    b[0][0].val[0] = vld1q_u32(b565[0] + x);
    b[0][0].val[1] = vld1q_u32(b565[0] + x + 4);
    p[0][0] = CalculateFilteredOutputPass1(sr0, ma[0], b[0]);
    p[1][0] = CalculateFilteredOutput<4>(sr1, ma[0][1], b[0][1]);
    ma[1][0] = vld1q_u16(ma343[0] + x);
    ma[1][1] = vld1q_u16(ma444[0] + x);
    b[1][0].val[0] = vld1q_u32(b343[0] + x);
    b[1][0].val[1] = vld1q_u32(b343[0] + x + 4);
    b[1][1].val[0] = vld1q_u32(b444[0] + x);
    b[1][1].val[1] = vld1q_u32(b444[0] + x + 4);
    p[0][1] = CalculateFilteredOutputPass2(sr0, ma[1], b[1]);
    ma[2][0] = vld1q_u16(ma343[1] + x);
    b[2][0].val[0] = vld1q_u32(b343[1] + x);
    b[2][0].val[1] = vld1q_u32(b343[1] + x + 4);
    p[1][1] = CalculateFilteredOutputPass2(sr1, ma[2], b[2]);
    SelfGuidedDoubleMultiplier(sr0, p[0], w0, w2, dst + x);
    SelfGuidedDoubleMultiplier(sr1, p[1], w0, w2, dst + stride + x);
    x += 8;
  } while (x < width);
}

inline void BoxFilterLastRow(
    const uint8_t* const src, const uint8_t* const src0, const int width,
    const uint16_t scales[2], const int16_t w0, const int16_t w2,
    uint16_t* const sum3[4], uint16_t* const sum5[5],
    uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
    uint16_t* const ma343[4], uint16_t* const ma444[3],
    uint16_t* const ma565[2], uint32_t* const b343[4], uint32_t* const b444[3],
    uint32_t* const b565[2], uint8_t* const dst) {
  uint8x8x2_t s, ma3, ma5;
  uint16x8x2_t sq, b3, b5;
  uint16x8_t ma[3];
  uint32x4x2_t b[3];
  s.val[0] = vld1_u8(src0);
  sq.val[0] = vmull_u8(s.val[0], s.val[0]);
  BoxFilterPreProcessLastRow(src0, 0, scales, sum3, sum5, square_sum3,
                             square_sum5, &s, &sq, &ma3.val[0], &ma5.val[0],
                             &b3.val[0], &b5.val[0]);

  int x = 0;
  do {
    s.val[0] = s.val[1];
    sq.val[0] = sq.val[1];
    BoxFilterPreProcessLastRow(src0, x + 8, scales, sum3, sum5, square_sum3,
                               square_sum5, &s, &sq, &ma3.val[1], &ma5.val[1],
                               &b3.val[1], &b5.val[1]);
    ma[1] = Sum565(ma5);
    b[1] = Sum565W(b5);
    ma5.val[0] = ma5.val[1];
    b5.val[0] = b5.val[1];
    ma[2] = Sum343(ma3);
    b[2] = Sum343W(b3);
    ma3.val[0] = ma3.val[1];
    b3.val[0] = b3.val[1];
    const uint8x8_t sr = vld1_u8(src + x);
    int16x8_t p[2];
    ma[0] = vld1q_u16(ma565[0] + x);
    b[0].val[0] = vld1q_u32(b565[0] + x + 0);
    b[0].val[1] = vld1q_u32(b565[0] + x + 4);
    p[0] = CalculateFilteredOutputPass1(sr, ma, b);
    ma[0] = vld1q_u16(ma343[0] + x);
    ma[1] = vld1q_u16(ma444[0] + x);
    b[0].val[0] = vld1q_u32(b343[0] + x + 0);
    b[0].val[1] = vld1q_u32(b343[0] + x + 4);
    b[1].val[0] = vld1q_u32(b444[0] + x + 0);
    b[1].val[1] = vld1q_u32(b444[0] + x + 4);
    p[1] = CalculateFilteredOutputPass2(sr, ma, b);
    SelfGuidedDoubleMultiplier(sr, p, w0, w2, dst + x);
    x += 8;
  } while (x < width);
}

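// Driver for the case where both passes are enabled. Rows are processed in
// pairs; the circular |sum*| and intermediate buffers are rotated between
// pairs, and the bottom border rows are handled after the main loop.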
LIBGAV1_ALWAYS_INLINE void BoxFilterProcess(
    const RestorationUnitInfo& restoration_info, const uint8_t* src,
    const uint8_t* const top_border, const uint8_t* bottom_border,
    const ptrdiff_t stride, const int width, const int height,
    SgrBuffer* const sgr_buffer, uint8_t* dst) {
  const auto temp_stride = Align<ptrdiff_t>(width, 8);
  const ptrdiff_t sum_stride = temp_stride + 8;
  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
  const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index];  // < 2^12.
  const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
  const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
  const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1;
  uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2];
  uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2];
  sum3[0] = sgr_buffer->sum3;
  square_sum3[0] = sgr_buffer->square_sum3;
  ma343[0] = sgr_buffer->ma343;
  b343[0] = sgr_buffer->b343;
  for (int i = 1; i <= 3; ++i) {
    sum3[i] = sum3[i - 1] + sum_stride;
    square_sum3[i] = square_sum3[i - 1] + sum_stride;
    ma343[i] = ma343[i - 1] + temp_stride;
    b343[i] = b343[i - 1] + temp_stride;
  }
  sum5[0] = sgr_buffer->sum5;
  square_sum5[0] = sgr_buffer->square_sum5;
  for (int i = 1; i <= 4; ++i) {
    sum5[i] = sum5[i - 1] + sum_stride;
    square_sum5[i] = square_sum5[i - 1] + sum_stride;
  }
  ma444[0] = sgr_buffer->ma444;
  b444[0] = sgr_buffer->b444;
  for (int i = 1; i <= 2; ++i) {
    ma444[i] = ma444[i - 1] + temp_stride;
    b444[i] = b444[i - 1] + temp_stride;
  }
  ma565[0] = sgr_buffer->ma565;
  ma565[1] = ma565[0] + temp_stride;
  b565[0] = sgr_buffer->b565;
  b565[1] = b565[0] + temp_stride;
  assert(scales[0] != 0);
  assert(scales[1] != 0);
  BoxSum(top_border, stride, 2, sum_stride, sum3[0], sum5[1], square_sum3[0],
         square_sum5[1]);
  sum5[0] = sum5[1];
  square_sum5[0] = square_sum5[1];
  const uint8_t* const s = (height > 1) ? src + stride : bottom_border;
  BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3,
                         square_sum5, ma343, ma444, ma565[0], b343, b444,
                         b565[0]);
  sum5[0] = sgr_buffer->sum5;
  square_sum5[0] = sgr_buffer->square_sum5;

  for (int y = (height >> 1) - 1; y > 0; --y) {
    Circulate4PointersBy2<uint16_t>(sum3);
    Circulate4PointersBy2<uint32_t>(square_sum3);
    Circulate5PointersBy2<uint16_t>(sum5);
    Circulate5PointersBy2<uint32_t>(square_sum5);
    BoxFilter(src + 3, src + 2 * stride, src + 3 * stride, stride, width,
              scales, w0, w2, sum3, sum5, square_sum3, square_sum5, ma343,
              ma444, ma565, b343, b444, b565, dst);
    src += 2 * stride;
    dst += 2 * stride;
    Circulate4PointersBy2<uint16_t>(ma343);
    Circulate4PointersBy2<uint32_t>(b343);
    std::swap(ma444[0], ma444[2]);
    std::swap(b444[0], b444[2]);
    std::swap(ma565[0], ma565[1]);
    std::swap(b565[0], b565[1]);
  }

  Circulate4PointersBy2<uint16_t>(sum3);
  Circulate4PointersBy2<uint32_t>(square_sum3);
  Circulate5PointersBy2<uint16_t>(sum5);
  Circulate5PointersBy2<uint32_t>(square_sum5);
  if ((height & 1) == 0 || height > 1) {
    const uint8_t* sr[2];
    if ((height & 1) == 0) {
      sr[0] = bottom_border;
      sr[1] = bottom_border + stride;
    } else {
      sr[0] = src + 2 * stride;
      sr[1] = bottom_border;
    }
    BoxFilter(src + 3, sr[0], sr[1], stride, width, scales, w0, w2, sum3, sum5,
              square_sum3, square_sum5, ma343, ma444, ma565, b343, b444, b565,
              dst);
  }
  if ((height & 1) != 0) {
    if (height > 1) {
      src += 2 * stride;
      dst += 2 * stride;
      Circulate4PointersBy2<uint16_t>(sum3);
      Circulate4PointersBy2<uint32_t>(square_sum3);
      Circulate5PointersBy2<uint16_t>(sum5);
      Circulate5PointersBy2<uint32_t>(square_sum5);
      Circulate4PointersBy2<uint16_t>(ma343);
      Circulate4PointersBy2<uint32_t>(b343);
      std::swap(ma444[0], ma444[2]);
      std::swap(b444[0], b444[2]);
      std::swap(ma565[0], ma565[1]);
      std::swap(b565[0], b565[1]);
    }
    BoxFilterLastRow(src + 3, bottom_border + stride, width, scales, w0, w2,
                     sum3, sum5, square_sum3, square_sum5, ma343, ma444, ma565,
                     b343, b444, b565, dst);
  }
}

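// Driver for the pass 1 only case (radius_pass_1 == 0).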
inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info,
                                  const uint8_t* src,
                                  const uint8_t* const top_border,
                                  const uint8_t* bottom_border,
                                  const ptrdiff_t stride, const int width,
                                  const int height, SgrBuffer* const sgr_buffer,
                                  uint8_t* dst) {
  const auto temp_stride = Align<ptrdiff_t>(width, 8);
  const ptrdiff_t sum_stride = temp_stride + 8;
  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
  const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0];  // < 2^12.
  const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
  uint16_t *sum5[5], *ma565[2];
  uint32_t *square_sum5[5], *b565[2];
  sum5[0] = sgr_buffer->sum5;
  square_sum5[0] = sgr_buffer->square_sum5;
  for (int i = 1; i <= 4; ++i) {
    sum5[i] = sum5[i - 1] + sum_stride;
    square_sum5[i] = square_sum5[i - 1] + sum_stride;
  }
  ma565[0] = sgr_buffer->ma565;
  ma565[1] = ma565[0] + temp_stride;
  b565[0] = sgr_buffer->b565;
  b565[1] = b565[0] + temp_stride;
  assert(scale != 0);
  BoxSum<5>(top_border, stride, 2, sum_stride, sum5[1], square_sum5[1]);
  sum5[0] = sum5[1];
  square_sum5[0] = square_sum5[1];
  const uint8_t* const s = (height > 1) ? src + stride : bottom_border;
  BoxSumFilterPreProcess5(src, s, width, scale, sum5, square_sum5, ma565[0],
                          b565[0]);
  sum5[0] = sgr_buffer->sum5;
  square_sum5[0] = sgr_buffer->square_sum5;

  for (int y = (height >> 1) - 1; y > 0; --y) {
    Circulate5PointersBy2<uint16_t>(sum5);
    Circulate5PointersBy2<uint32_t>(square_sum5);
    BoxFilterPass1(src + 3, src + 2 * stride, src + 3 * stride, stride, sum5,
                   square_sum5, width, scale, w0, ma565, b565, dst);
    src += 2 * stride;
    dst += 2 * stride;
    std::swap(ma565[0], ma565[1]);
    std::swap(b565[0], b565[1]);
  }

  Circulate5PointersBy2<uint16_t>(sum5);
  Circulate5PointersBy2<uint32_t>(square_sum5);
  if ((height & 1) == 0 || height > 1) {
    const uint8_t* sr[2];
    if ((height & 1) == 0) {
      sr[0] = bottom_border;
      sr[1] = bottom_border + stride;
    } else {
      sr[0] = src + 2 * stride;
      sr[1] = bottom_border;
    }
    BoxFilterPass1(src + 3, sr[0], sr[1], stride, sum5, square_sum5, width,
                   scale, w0, ma565, b565, dst);
  }
  if ((height & 1) != 0) {
    if (height > 1) {
      src += 2 * stride;
      dst += 2 * stride;
      std::swap(ma565[0], ma565[1]);
      std::swap(b565[0], b565[1]);
      Circulate5PointersBy2<uint16_t>(sum5);
      Circulate5PointersBy2<uint32_t>(square_sum5);
    }
    BoxFilterPass1LastRow(src + 3, bottom_border + stride, width, scale, w0,
                          sum5, square_sum5, ma565[0], b565[0], dst);
  }
}

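// Driver for the pass 2 only case (radius_pass_0 == 0). Pass 2 works on
// single rows, so the loop advances one row at a time.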
inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info,
                                  const uint8_t* src,
                                  const uint8_t* const top_border,
                                  const uint8_t* bottom_border,
                                  const ptrdiff_t stride, const int width,
                                  const int height, SgrBuffer* const sgr_buffer,
                                  uint8_t* dst) {
  assert(restoration_info.sgr_proj_info.multiplier[0] == 0);
  const auto temp_stride = Align<ptrdiff_t>(width, 8);
  const ptrdiff_t sum_stride = temp_stride + 8;
  const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
  const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1;
  const int sgr_proj_index = restoration_info.sgr_proj_info.index;
  const uint32_t scale = kSgrScaleParameter[sgr_proj_index][1];  // < 2^12.
  uint16_t *sum3[3], *ma343[3], *ma444[2];
  uint32_t *square_sum3[3], *b343[3], *b444[2];
  sum3[0] = sgr_buffer->sum3;
  square_sum3[0] = sgr_buffer->square_sum3;
  ma343[0] = sgr_buffer->ma343;
  b343[0] = sgr_buffer->b343;
  for (int i = 1; i <= 2; ++i) {
    sum3[i] = sum3[i - 1] + sum_stride;
    square_sum3[i] = square_sum3[i - 1] + sum_stride;
    ma343[i] = ma343[i - 1] + temp_stride;
    b343[i] = b343[i - 1] + temp_stride;
  }
  ma444[0] = sgr_buffer->ma444;
  ma444[1] = ma444[0] + temp_stride;
  b444[0] = sgr_buffer->b444;
  b444[1] = b444[0] + temp_stride;
  assert(scale != 0);
  BoxSum<3>(top_border, stride, 2, sum_stride, sum3[0], square_sum3[0]);
  BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3, ma343[0],
                                 nullptr, b343[0], nullptr);
  Circulate3PointersBy1<uint16_t>(sum3);
  Circulate3PointersBy1<uint32_t>(square_sum3);
  const uint8_t* s;
  if (height > 1) {
    s = src + stride;
  } else {
    s = bottom_border;
    bottom_border += stride;
  }
  BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, ma343[1],
                                ma444[0], b343[1], b444[0]);

  for (int y = height - 2; y > 0; --y) {
    Circulate3PointersBy1<uint16_t>(sum3);
    Circulate3PointersBy1<uint32_t>(square_sum3);
    BoxFilterPass2(src + 2, src + 2 * stride, width, scale, w0, sum3,
                   square_sum3, ma343, ma444, b343, b444, dst);
    src += stride;
    dst += stride;
    Circulate3PointersBy1<uint16_t>(ma343);
    Circulate3PointersBy1<uint32_t>(b343);
    std::swap(ma444[0], ma444[1]);
    std::swap(b444[0], b444[1]);
  }

  src += 2;
  int y = std::min(height, 2);
  do {
    Circulate3PointersBy1<uint16_t>(sum3);
    Circulate3PointersBy1<uint32_t>(square_sum3);
    BoxFilterPass2(src, bottom_border, width, scale, w0, sum3, square_sum3,
                   ma343, ma444, b343, b444, dst);
    src += stride;
    dst += stride;
    bottom_border += stride;
    Circulate3PointersBy1<uint16_t>(ma343);
    Circulate3PointersBy1<uint32_t>(b343);
    std::swap(ma444[0], ma444[1]);
    std::swap(b444[0], b444[1]);
  } while (--y != 0);
}

// If |width| is not a multiple of 8, up to 7 extra pixels are written to
// |dest| at the end of each row. It is safe to overwrite the output as it
// will not be part of the visible frame.
void SelfGuidedFilter_NEON(
    const RestorationUnitInfo& restoration_info, const void* const source,
    const void* const top_border, const void* const bottom_border,
    const ptrdiff_t stride, const int width, const int height,
    RestorationBuffer* const restoration_buffer, void* const dest) {
  const int index = restoration_info.sgr_proj_info.index;
  const int radius_pass_0 = kSgrProjParams[index][0];  // 2 or 0
  const int radius_pass_1 = kSgrProjParams[index][2];  // 1 or 0
  const auto* const src = static_cast<const uint8_t*>(source);
  const auto* top = static_cast<const uint8_t*>(top_border);
  const auto* bottom = static_cast<const uint8_t*>(bottom_border);
  auto* const dst = static_cast<uint8_t*>(dest);
  SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer;
  if (radius_pass_1 == 0) {
    // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the
    // following assertion.
    assert(radius_pass_0 != 0);
    BoxFilterProcessPass1(restoration_info, src - 3, top - 3, bottom - 3,
                          stride, width, height, sgr_buffer, dst);
  } else if (radius_pass_0 == 0) {
    BoxFilterProcessPass2(restoration_info, src - 2, top - 2, bottom - 2,
                          stride, width, height, sgr_buffer, dst);
  } else {
    BoxFilterProcess(restoration_info, src - 3, top - 3, bottom - 3, stride,
                     width, height, sgr_buffer, dst);
  }
}

void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
  assert(dsp != nullptr);
  dsp->loop_restorations[0] = WienerFilter_NEON;
  dsp->loop_restorations[1] = SelfGuidedFilter_NEON;
}

}  // namespace
}  // namespace low_bitdepth

void LoopRestorationInit_NEON() { low_bitdepth::Init8bpp(); }

}  // namespace dsp
}  // namespace libgav1

#else  // !LIBGAV1_ENABLE_NEON
namespace libgav1 {
namespace dsp {

void LoopRestorationInit_NEON() {}

}  // namespace dsp
}  // namespace libgav1
#endif  // LIBGAV1_ENABLE_NEON