/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at https://www.aomedia.org/license/software-license. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at https://www.aomedia.org/license/patent-license.
 */

#include "EbDefinitions.h"
#include <immintrin.h>
#include "common_dsp_rtcd.h"
#include "EbRestoration.h"
#include "synonyms.h"
#include "synonyms_avx2.h"
#include "transpose_avx2.h"
#include "transpose_sse2.h"

static INLINE void cvt_16to32bit_8x8(const __m128i s[8], __m256i r[8]) {
    r[0] = _mm256_cvtepu16_epi32(s[0]);
    r[1] = _mm256_cvtepu16_epi32(s[1]);
    r[2] = _mm256_cvtepu16_epi32(s[2]);
    r[3] = _mm256_cvtepu16_epi32(s[3]);
    r[4] = _mm256_cvtepu16_epi32(s[4]);
    r[5] = _mm256_cvtepu16_epi32(s[5]);
    r[6] = _mm256_cvtepu16_epi32(s[6]);
    r[7] = _mm256_cvtepu16_epi32(s[7]);
}

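// Running sum across the eight registers: on return, r[i] holds
// neighbor + r[0] + ... + r[i] of the original inputs.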
static INLINE void add_32bit_8x8(const __m256i neighbor, __m256i r[8]) {
    r[0] = _mm256_add_epi32(neighbor, r[0]);
    r[1] = _mm256_add_epi32(r[0], r[1]);
    r[2] = _mm256_add_epi32(r[1], r[2]);
    r[3] = _mm256_add_epi32(r[2], r[3]);
    r[4] = _mm256_add_epi32(r[3], r[4]);
    r[5] = _mm256_add_epi32(r[4], r[5]);
    r[6] = _mm256_add_epi32(r[5], r[6]);
    r[7] = _mm256_add_epi32(r[6], r[7]);
}

static INLINE void store_32bit_8x8(const __m256i r[8], int32_t *const buf,
                                   const int32_t buf_stride) {
    _mm256_storeu_si256((__m256i *)(buf + 0 * buf_stride), r[0]);
    _mm256_storeu_si256((__m256i *)(buf + 1 * buf_stride), r[1]);
    _mm256_storeu_si256((__m256i *)(buf + 2 * buf_stride), r[2]);
    _mm256_storeu_si256((__m256i *)(buf + 3 * buf_stride), r[3]);
    _mm256_storeu_si256((__m256i *)(buf + 4 * buf_stride), r[4]);
    _mm256_storeu_si256((__m256i *)(buf + 5 * buf_stride), r[5]);
    _mm256_storeu_si256((__m256i *)(buf + 6 * buf_stride), r[6]);
    _mm256_storeu_si256((__m256i *)(buf + 7 * buf_stride), r[7]);
}

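// Build the integral images D (running sums) and C (running sums of squares)
// of the 8-bit source, one 8x8 tile at a time. Within a tile the samples are
// effectively transposed so that horizontal prefix sums can be taken with
// vector adds (d_left/c_left carry those sums across tiles in the same row
// band); the tile is then transposed back and a second prefix sum, seeded
// with the integral-image row directly above, completes the 2-D accumulation.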
static AOM_FORCE_INLINE void integral_images(const uint8_t *src, int32_t src_stride, int32_t width,
                                             int32_t height, int32_t *C, int32_t *D,
                                             int32_t buf_stride) {
    const uint8_t *src_t = src;
    int32_t *      ct    = C + buf_stride + 1;
    int32_t *      dt    = D + buf_stride + 1;
    const __m256i  zero  = _mm256_setzero_si256();

    memset(C, 0, sizeof(*C) * (width + 8));
    memset(D, 0, sizeof(*D) * (width + 8));

    int y = 0;
    do {
        __m256i c_left = _mm256_setzero_si256();
        __m256i d_left = _mm256_setzero_si256();

        // zero the left column.
        ct[0 * buf_stride - 1] = dt[0 * buf_stride - 1] = 0;
        ct[1 * buf_stride - 1] = dt[1 * buf_stride - 1] = 0;
        ct[2 * buf_stride - 1] = dt[2 * buf_stride - 1] = 0;
        ct[3 * buf_stride - 1] = dt[3 * buf_stride - 1] = 0;
        ct[4 * buf_stride - 1] = dt[4 * buf_stride - 1] = 0;
        ct[5 * buf_stride - 1] = dt[5 * buf_stride - 1] = 0;
        ct[6 * buf_stride - 1] = dt[6 * buf_stride - 1] = 0;
        ct[7 * buf_stride - 1] = dt[7 * buf_stride - 1] = 0;

        int x = 0;
        do {
            __m128i s[8];
            __m256i r32[8];

            s[0] = _mm_loadl_epi64((__m128i *)(src_t + 0 * src_stride + x));
            s[1] = _mm_loadl_epi64((__m128i *)(src_t + 1 * src_stride + x));
            s[2] = _mm_loadl_epi64((__m128i *)(src_t + 2 * src_stride + x));
            s[3] = _mm_loadl_epi64((__m128i *)(src_t + 3 * src_stride + x));
            s[4] = _mm_loadl_epi64((__m128i *)(src_t + 4 * src_stride + x));
            s[5] = _mm_loadl_epi64((__m128i *)(src_t + 5 * src_stride + x));
            s[6] = _mm_loadl_epi64((__m128i *)(src_t + 6 * src_stride + x));
            s[7] = _mm_loadl_epi64((__m128i *)(src_t + 7 * src_stride + x));

            partial_transpose_8bit_8x8(s, s);

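            // The unpacks below zero-extend the transposed bytes to 16 bits.
            // They are written from s[7] down to s[0] so that each source
            // register is consumed before it is overwritten.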
            s[7] = _mm_unpackhi_epi8(s[3], _mm_setzero_si128());
            s[6] = _mm_unpacklo_epi8(s[3], _mm_setzero_si128());
            s[5] = _mm_unpackhi_epi8(s[2], _mm_setzero_si128());
            s[4] = _mm_unpacklo_epi8(s[2], _mm_setzero_si128());
            s[3] = _mm_unpackhi_epi8(s[1], _mm_setzero_si128());
            s[2] = _mm_unpacklo_epi8(s[1], _mm_setzero_si128());
            s[1] = _mm_unpackhi_epi8(s[0], _mm_setzero_si128());
            s[0] = _mm_unpacklo_epi8(s[0], _mm_setzero_si128());

            cvt_16to32bit_8x8(s, r32);
            add_32bit_8x8(d_left, r32);
            d_left = r32[7];

            transpose_32bit_8x8_avx2(r32, r32);

            const __m256i d_top = _mm256_loadu_si256((__m256i *)(dt - buf_stride + x));
            add_32bit_8x8(d_top, r32);
            store_32bit_8x8(r32, dt + x, buf_stride);

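            // 255 * 255 = 65025 fits in 16 bits, so mullo_epi16 squares the
            // 8-bit samples exactly; the squares feed the sum-of-squares
            // image C.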
            s[0] = _mm_mullo_epi16(s[0], s[0]);
            s[1] = _mm_mullo_epi16(s[1], s[1]);
            s[2] = _mm_mullo_epi16(s[2], s[2]);
            s[3] = _mm_mullo_epi16(s[3], s[3]);
            s[4] = _mm_mullo_epi16(s[4], s[4]);
            s[5] = _mm_mullo_epi16(s[5], s[5]);
            s[6] = _mm_mullo_epi16(s[6], s[6]);
            s[7] = _mm_mullo_epi16(s[7], s[7]);

            cvt_16to32bit_8x8(s, r32);
            add_32bit_8x8(c_left, r32);
            c_left = r32[7];

            transpose_32bit_8x8_avx2(r32, r32);

            const __m256i c_top = _mm256_loadu_si256((__m256i *)(ct - buf_stride + x));
            add_32bit_8x8(c_top, r32);
            store_32bit_8x8(r32, ct + x, buf_stride);
            x += 8;
        } while (x < width);

        /* Zero the eight columns just beyond the right border; calc_ab and
           calc_ab_fast read them when their boxes extend past that border. */
        for (int ln = 0; ln < 8; ++ln) {
            _mm256_storeu_si256((__m256i *)(ct + x + ln * buf_stride), zero);
            _mm256_storeu_si256((__m256i *)(dt + x + ln * buf_stride), zero);
        }

        src_t += 8 * src_stride;
        ct += 8 * buf_stride;
        dt += 8 * buf_stride;
        y += 8;
    } while (y < height);
}

static AOM_FORCE_INLINE void integral_images_highbd(const uint16_t *src, int32_t src_stride,
                                                    int32_t width, int32_t height, int32_t *C,
                                                    int32_t *D, int32_t buf_stride) {
    const uint16_t *src_t = src;
    int32_t *       ct    = C + buf_stride + 1;
    int32_t *       dt    = D + buf_stride + 1;
    const __m256i   zero  = _mm256_setzero_si256();

    memset(C, 0, sizeof(*C) * (width + 8));
    memset(D, 0, sizeof(*D) * (width + 8));

    int y = 0;
    do {
        __m256i c_left = _mm256_setzero_si256();
        __m256i d_left = _mm256_setzero_si256();

        // zero the left column.
        ct[0 * buf_stride - 1] = dt[0 * buf_stride - 1] = 0;
        ct[1 * buf_stride - 1] = dt[1 * buf_stride - 1] = 0;
        ct[2 * buf_stride - 1] = dt[2 * buf_stride - 1] = 0;
        ct[3 * buf_stride - 1] = dt[3 * buf_stride - 1] = 0;
        ct[4 * buf_stride - 1] = dt[4 * buf_stride - 1] = 0;
        ct[5 * buf_stride - 1] = dt[5 * buf_stride - 1] = 0;
        ct[6 * buf_stride - 1] = dt[6 * buf_stride - 1] = 0;
        ct[7 * buf_stride - 1] = dt[7 * buf_stride - 1] = 0;

        int x = 0;
        do {
            __m128i s[8];
            __m256i r32[8], a32[8];

            s[0] = _mm_loadu_si128((__m128i *)(src_t + 0 * src_stride + x));
            s[1] = _mm_loadu_si128((__m128i *)(src_t + 1 * src_stride + x));
            s[2] = _mm_loadu_si128((__m128i *)(src_t + 2 * src_stride + x));
            s[3] = _mm_loadu_si128((__m128i *)(src_t + 3 * src_stride + x));
            s[4] = _mm_loadu_si128((__m128i *)(src_t + 4 * src_stride + x));
            s[5] = _mm_loadu_si128((__m128i *)(src_t + 5 * src_stride + x));
            s[6] = _mm_loadu_si128((__m128i *)(src_t + 6 * src_stride + x));
            s[7] = _mm_loadu_si128((__m128i *)(src_t + 7 * src_stride + x));

            transpose_16bit_8x8(s, s);

            cvt_16to32bit_8x8(s, r32);

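            // Each 32-bit lane of r32 holds a zero-extended sample (at most
            // 12 bits), so madd_epi16(x, x) reduces to sample * sample + 0 * 0,
            // i.e. an exact square for the sum-of-squares image C.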
            a32[0] = _mm256_madd_epi16(r32[0], r32[0]);
            a32[1] = _mm256_madd_epi16(r32[1], r32[1]);
            a32[2] = _mm256_madd_epi16(r32[2], r32[2]);
            a32[3] = _mm256_madd_epi16(r32[3], r32[3]);
            a32[4] = _mm256_madd_epi16(r32[4], r32[4]);
            a32[5] = _mm256_madd_epi16(r32[5], r32[5]);
            a32[6] = _mm256_madd_epi16(r32[6], r32[6]);
            a32[7] = _mm256_madd_epi16(r32[7], r32[7]);

            add_32bit_8x8(c_left, a32);
            c_left = a32[7];

            transpose_32bit_8x8_avx2(a32, a32);

            const __m256i c_top = _mm256_loadu_si256((__m256i *)(ct - buf_stride + x));
            add_32bit_8x8(c_top, a32);
            store_32bit_8x8(a32, ct + x, buf_stride);

            add_32bit_8x8(d_left, r32);
            d_left = r32[7];

            transpose_32bit_8x8_avx2(r32, r32);

            const __m256i d_top = _mm256_loadu_si256((__m256i *)(dt - buf_stride + x));
            add_32bit_8x8(d_top, r32);
            store_32bit_8x8(r32, dt + x, buf_stride);
            x += 8;
        } while (x < width);

        /* Zero the eight columns just beyond the right border; calc_ab and
           calc_ab_fast read them when their boxes extend past that border. */
        for (int ln = 0; ln < 8; ++ln) {
            _mm256_storeu_si256((__m256i *)(ct + x + ln * buf_stride), zero);
            _mm256_storeu_si256((__m256i *)(dt + x + ln * buf_stride), zero);
        }

        src_t += 8 * src_stride;
        ct += 8 * buf_stride;
        dt += 8 * buf_stride;
        y += 8;
    } while (y < height);
}

// Compute 8 values of boxsum from the given integral image. ii should point
// at the middle of the box (for the first value). r is the box radius.
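// The (2r+1)x(2r+1) box sum follows the usual four-corner identity
//   sum = ii[br] - ii[bl] - ii[tr] + ii[tl],
// evaluated below as (br - bl) - (tr - tl).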
static INLINE __m256i boxsum_from_ii(const int32_t *ii, int32_t stride, int32_t r) {
    const __m256i tl = yy_loadu_256(ii - (r + 1) - (r + 1) * stride);
    const __m256i tr = yy_loadu_256(ii + (r + 0) - (r + 1) * stride);
    const __m256i bl = yy_loadu_256(ii - (r + 1) + r * stride);
    const __m256i br = yy_loadu_256(ii + (r + 0) + r * stride);
    const __m256i u  = _mm256_sub_epi32(tr, tl);
    const __m256i v  = _mm256_sub_epi32(br, bl);
    return _mm256_sub_epi32(v, u);
}

static INLINE __m256i round_for_shift(unsigned shift) {
    return _mm256_set1_epi32((1 << shift) >> 1);
}

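// p = n * sum(x^2) - (sum x)^2, i.e. n^2 times the variance of the n box
// samples. Only used for 8-bit input, where sum1 <= 25 * 255 < 2^15, so the
// 16-bit madd squares it exactly.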
static INLINE __m256i compute_p(__m256i sum1, __m256i sum2, int32_t n) {
    const __m256i bb = _mm256_madd_epi16(sum1, sum1);
    const __m256i an = _mm256_mullo_epi32(sum2, _mm256_set1_epi32(n));
    return _mm256_sub_epi32(an, bb);
}

static INLINE __m256i compute_p_highbd(__m256i sum1, __m256i sum2, int32_t bit_depth, int32_t n) {
    const __m256i rounding_a = round_for_shift(2 * (bit_depth - 8));
    const __m256i rounding_b = round_for_shift(bit_depth - 8);
    const __m128i shift_a    = _mm_cvtsi32_si128(2 * (bit_depth - 8));
    const __m128i shift_b    = _mm_cvtsi32_si128(bit_depth - 8);
    const __m256i a          = _mm256_srl_epi32(_mm256_add_epi32(sum2, rounding_a), shift_a);
    const __m256i b          = _mm256_srl_epi32(_mm256_add_epi32(sum1, rounding_b), shift_b);
    // b < 2^14, so we can use a 16-bit madd rather than a 32-bit
    // mullo to square it
    const __m256i bb = _mm256_madd_epi16(b, b);
    const __m256i an = _mm256_max_epi32(_mm256_mullo_epi32(a, _mm256_set1_epi32(n)), bb);
    return _mm256_sub_epi32(an, bb);
}

// Assumes that C, D are integral images for the original buffer which has been
// extended to have a padding of SGRPROJ_BORDER_VERT/SGRPROJ_BORDER_HORZ pixels
// on the sides. A, b, C, D point at logical position (0, 0).
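// For every pixel this computes
//   z = min((p * s + rnd) >> SGRPROJ_MTABLE_BITS, 255)
//   a = eb_x_by_xplus1[z]                         (at most SGRPROJ_SGR)
//   b = ((SGRPROJ_SGR - a) * sum1 * one_over_n + rnd) >> SGRPROJ_RECIP_BITS,
// i.e. b is roughly (SGRPROJ_SGR - a) times the box mean. One extra row and
// column of A/b is produced on every side (the loops start one row/column up
// and left and run to height + 2 and width + 2) because the final filter's
// cross sums read the full 3x3 neighborhood.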
static AOM_FORCE_INLINE void calc_ab(int32_t *A, int32_t *b, const int32_t *C, const int32_t *D,
                                     int32_t width, int32_t height, int32_t buf_stride,
                                     int32_t bit_depth, int32_t sgr_params_idx,
                                     int32_t radius_idx) {
    const SgrParamsType *const params = &eb_sgr_params[sgr_params_idx];
    const int32_t              r      = params->r[radius_idx];
    const int32_t              n      = (2 * r + 1) * (2 * r + 1);
    const __m256i              s      = _mm256_set1_epi32(params->s[radius_idx]);
    // one_over_n[n-1] is 2^12/n, so easily fits in an int16
    const __m256i one_over_n = _mm256_set1_epi32(eb_one_by_x[n - 1]);
    const __m256i rnd_z      = round_for_shift(SGRPROJ_MTABLE_BITS);
    const __m256i rnd_res    = round_for_shift(SGRPROJ_RECIP_BITS);

    A -= buf_stride + 1;
    b -= buf_stride + 1;
    C -= buf_stride + 1;
    D -= buf_stride + 1;

    int32_t i = height + 2;

    if (bit_depth == 8) {
        do {
            int32_t j = 0;
            do {
                const __m256i sum1 = boxsum_from_ii(D + j, buf_stride, r);
                const __m256i sum2 = boxsum_from_ii(C + j, buf_stride, r);
                const __m256i p    = compute_p(sum1, sum2, n);
                const __m256i z    = _mm256_min_epi32(
                    _mm256_srli_epi32(_mm256_add_epi32(_mm256_mullo_epi32(p, s), rnd_z),
                                      SGRPROJ_MTABLE_BITS),
                    _mm256_set1_epi32(255));
                const __m256i a_res = _mm256_i32gather_epi32(eb_x_by_xplus1, z, 4);
                yy_storeu_256(A + j, a_res);

                const __m256i a_complement = _mm256_sub_epi32(_mm256_set1_epi32(SGRPROJ_SGR),
                                                              a_res);

                // sum1 might have lanes greater than 2^15, so we can't use madd to do
                // multiplication involving sum1. However, a_complement and one_over_n
                // are both less than 256, so we can multiply them first.
                const __m256i a_comp_over_n = _mm256_madd_epi16(a_complement, one_over_n);
                const __m256i b_int         = _mm256_mullo_epi32(a_comp_over_n, sum1);
                const __m256i b_res         = _mm256_srli_epi32(_mm256_add_epi32(b_int, rnd_res),
                                                        SGRPROJ_RECIP_BITS);
                yy_storeu_256(b + j, b_res);
                j += 8;
            } while (j < width + 2);

            A += buf_stride;
            b += buf_stride;
            C += buf_stride;
            D += buf_stride;
        } while (--i);
    } else {
        do {
            int32_t j = 0;
            do {
                const __m256i sum1 = boxsum_from_ii(D + j, buf_stride, r);
                const __m256i sum2 = boxsum_from_ii(C + j, buf_stride, r);
                const __m256i p    = compute_p_highbd(sum1, sum2, bit_depth, n);
                const __m256i z    = _mm256_min_epi32(
                    _mm256_srli_epi32(_mm256_add_epi32(_mm256_mullo_epi32(p, s), rnd_z),
                                      SGRPROJ_MTABLE_BITS),
                    _mm256_set1_epi32(255));
                const __m256i a_res = _mm256_i32gather_epi32(eb_x_by_xplus1, z, 4);
                yy_storeu_256(A + j, a_res);

                const __m256i a_complement = _mm256_sub_epi32(_mm256_set1_epi32(SGRPROJ_SGR),
                                                              a_res);

                // sum1 might have lanes greater than 2^15, so we can't use madd to do
                // multiplication involving sum1. However, a_complement and one_over_n
                // are both less than 256, so we can multiply them first.
                const __m256i a_comp_over_n = _mm256_madd_epi16(a_complement, one_over_n);
                const __m256i b_int         = _mm256_mullo_epi32(a_comp_over_n, sum1);
                const __m256i b_res         = _mm256_srli_epi32(_mm256_add_epi32(b_int, rnd_res),
                                                        SGRPROJ_RECIP_BITS);
                yy_storeu_256(b + j, b_res);
                j += 8;
            } while (j < width + 2);

            A += buf_stride;
            b += buf_stride;
            C += buf_stride;
            D += buf_stride;
        } while (--i);
    }
}

// Calculate 8 values of the "cross sum" starting at buf. This is a 3x3 filter
// where the outer four corners have weight 3 and all other pixels have weight
// 4.
//
// Pixels are indexed as follows:
// xtl  xt   xtr
// xl    x   xr
// xbl  xb   xbr
//
// buf points to x
//
// fours = xl + xt + xr + xb + x
// threes = xtl + xtr + xbr + xbl
// cross_sum = 4 * fours + 3 * threes
//           = 4 * (fours + threes) - threes
//           = (fours + threes) << 2 - threes
static INLINE __m256i cross_sum(const int32_t *buf, int32_t stride) {
    const __m256i xtl = yy_loadu_256(buf - 1 - stride);
    const __m256i xt  = yy_loadu_256(buf - stride);
    const __m256i xtr = yy_loadu_256(buf + 1 - stride);
    const __m256i xl  = yy_loadu_256(buf - 1);
    const __m256i x   = yy_loadu_256(buf);
    const __m256i xr  = yy_loadu_256(buf + 1);
    const __m256i xbl = yy_loadu_256(buf - 1 + stride);
    const __m256i xb  = yy_loadu_256(buf + stride);
    const __m256i xbr = yy_loadu_256(buf + 1 + stride);

    const __m256i fours = _mm256_add_epi32(
        xl, _mm256_add_epi32(xt, _mm256_add_epi32(xr, _mm256_add_epi32(xb, x))));
    const __m256i threes = _mm256_add_epi32(xtl, _mm256_add_epi32(xtr, _mm256_add_epi32(xbr, xbl)));

    return _mm256_sub_epi32(_mm256_slli_epi32(_mm256_add_epi32(fours, threes), 2), threes);
}

// The final filter for self-guided restoration. Computes a weighted average
// across A, b with "cross sums" (see cross_sum implementation above).
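// The cross-sum weights total 32 and each A value is at most SGRPROJ_SGR
// (256), so the 'a' lanes stay within 16 bits (at most 8192) and
// madd_epi16(a, src) multiplies exactly, even for 12-bit sources.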
static AOM_FORCE_INLINE void final_filter(int32_t *dst, int32_t dst_stride, const int32_t *A,
                                          const int32_t *B, int32_t buf_stride, const uint8_t *dgd8,
                                          int32_t dgd_stride, int32_t width, int32_t height,
                                          int32_t highbd) {
    const int32_t nb       = 5;
    const __m256i rounding = round_for_shift(SGRPROJ_SGR_BITS + nb - SGRPROJ_RST_BITS);
    int32_t       i        = height;

    if (!highbd) {
        do {
            int32_t j = 0;
            do {
                const __m256i a   = cross_sum(A + j, buf_stride);
                const __m256i b   = cross_sum(B + j, buf_stride);
                const __m128i raw = xx_loadl_64(dgd8 + j);
                const __m256i src = _mm256_cvtepu8_epi32(raw);
                const __m256i v   = _mm256_add_epi32(_mm256_madd_epi16(a, src), b);
                const __m256i w   = _mm256_srai_epi32(_mm256_add_epi32(v, rounding),
                                                    SGRPROJ_SGR_BITS + nb - SGRPROJ_RST_BITS);
                yy_storeu_256(dst + j, w);
                j += 8;
            } while (j < width);

            A += buf_stride;
            B += buf_stride;
            dgd8 += dgd_stride;
            dst += dst_stride;
        } while (--i);
    } else {
        const uint16_t *dgd_real = CONVERT_TO_SHORTPTR(dgd8);

        do {
            int32_t j = 0;
            do {
                const __m256i a   = cross_sum(A + j, buf_stride);
                const __m256i b   = cross_sum(B + j, buf_stride);
                const __m128i raw = xx_loadu_128(dgd_real + j);
                const __m256i src = _mm256_cvtepu16_epi32(raw);
                const __m256i v   = _mm256_add_epi32(_mm256_madd_epi16(a, src), b);
                const __m256i w   = _mm256_srai_epi32(_mm256_add_epi32(v, rounding),
                                                    SGRPROJ_SGR_BITS + nb - SGRPROJ_RST_BITS);
                yy_storeu_256(dst + j, w);
                j += 8;
            } while (j < width);

            A += buf_stride;
            B += buf_stride;
            dgd_real += dgd_stride;
            dst += dst_stride;
        } while (--i);
    }
}

// Assumes that C, D are integral images for the original buffer which has been
// extended to have a padding of SGRPROJ_BORDER_VERT/SGRPROJ_BORDER_HORZ pixels
// on the sides. A, b, C, D point at logical position (0, 0).
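// Same math as calc_ab, but A and b are evaluated only on every other row
// (the pointers advance by 2 * buf_stride), starting one row above the
// output. final_filter_fast then reads these rows directly on odd output rows
// and blends the rows above and below on even output rows.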
static AOM_FORCE_INLINE void calc_ab_fast(int32_t *A, int32_t *b, const int32_t *C,
                                          const int32_t *D, int32_t width, int32_t height,
                                          int32_t buf_stride, int32_t bit_depth,
                                          int32_t sgr_params_idx, int32_t radius_idx) {
    const SgrParamsType *const params = &eb_sgr_params[sgr_params_idx];
    const int32_t              r      = params->r[radius_idx];
    const int32_t              n      = (2 * r + 1) * (2 * r + 1);
    const __m256i              s      = _mm256_set1_epi32(params->s[radius_idx]);
    // one_over_n[n-1] is 2^12/n, so easily fits in an int16
    const __m256i one_over_n = _mm256_set1_epi32(eb_one_by_x[n - 1]);
    const __m256i rnd_z      = round_for_shift(SGRPROJ_MTABLE_BITS);
    const __m256i rnd_res    = round_for_shift(SGRPROJ_RECIP_BITS);

    A -= buf_stride + 1;
    b -= buf_stride + 1;
    C -= buf_stride + 1;
    D -= buf_stride + 1;

    int32_t i = 0;
    if (bit_depth == 8) {
        do {
            int32_t j = 0;
            do {
                const __m256i sum1 = boxsum_from_ii(D + j, buf_stride, r);
                const __m256i sum2 = boxsum_from_ii(C + j, buf_stride, r);
                const __m256i p    = compute_p(sum1, sum2, n);
                const __m256i z    = _mm256_min_epi32(
                    _mm256_srli_epi32(_mm256_add_epi32(_mm256_mullo_epi32(p, s), rnd_z),
                                      SGRPROJ_MTABLE_BITS),
                    _mm256_set1_epi32(255));
                const __m256i a_res = _mm256_i32gather_epi32(eb_x_by_xplus1, z, 4);
                yy_storeu_256(A + j, a_res);

                const __m256i a_complement = _mm256_sub_epi32(_mm256_set1_epi32(SGRPROJ_SGR),
                                                              a_res);

                // sum1 might have lanes greater than 2^15, so we can't use madd to do
                // multiplication involving sum1. However, a_complement and one_over_n
                // are both less than 256, so we can multiply them first.
                const __m256i a_comp_over_n = _mm256_madd_epi16(a_complement, one_over_n);
                const __m256i b_int         = _mm256_mullo_epi32(a_comp_over_n, sum1);
                const __m256i b_res         = _mm256_srli_epi32(_mm256_add_epi32(b_int, rnd_res),
                                                        SGRPROJ_RECIP_BITS);
                yy_storeu_256(b + j, b_res);
                j += 8;
            } while (j < width + 2);

            A += 2 * buf_stride;
            b += 2 * buf_stride;
            C += 2 * buf_stride;
            D += 2 * buf_stride;
            i += 2;
        } while (i < height + 2);
    } else {
        do {
            int32_t j = 0;
            do {
                const __m256i sum1 = boxsum_from_ii(D + j, buf_stride, r);
                const __m256i sum2 = boxsum_from_ii(C + j, buf_stride, r);
                const __m256i p    = compute_p_highbd(sum1, sum2, bit_depth, n);
                const __m256i z    = _mm256_min_epi32(
                    _mm256_srli_epi32(_mm256_add_epi32(_mm256_mullo_epi32(p, s), rnd_z),
                                      SGRPROJ_MTABLE_BITS),
                    _mm256_set1_epi32(255));
                const __m256i a_res = _mm256_i32gather_epi32(eb_x_by_xplus1, z, 4);
                yy_storeu_256(A + j, a_res);

                const __m256i a_complement = _mm256_sub_epi32(_mm256_set1_epi32(SGRPROJ_SGR),
                                                              a_res);

                // sum1 might have lanes greater than 2^15, so we can't use madd to do
                // multiplication involving sum1. However, a_complement and one_over_n
                // are both less than 256, so we can multiply them first.
                const __m256i a_comp_over_n = _mm256_madd_epi16(a_complement, one_over_n);
                const __m256i b_int         = _mm256_mullo_epi32(a_comp_over_n, sum1);
                const __m256i b_res         = _mm256_srli_epi32(_mm256_add_epi32(b_int, rnd_res),
                                                        SGRPROJ_RECIP_BITS);
                yy_storeu_256(b + j, b_res);
                j += 8;
            } while (j < width + 2);

            A += 2 * buf_stride;
            b += 2 * buf_stride;
            C += 2 * buf_stride;
            D += 2 * buf_stride;
            i += 2;
        } while (i < height + 2);
    }
}

// Calculate 8 values of the "cross sum" starting at buf.
//
// Pixels are indexed like this:
// xtl  xt   xtr
//  -   buf   -
// xbl  xb   xbr
//
// Pixels are weighted like this:
//  5    6    5
//  0    0    0
//  5    6    5
//
// fives = xtl + xtr + xbl + xbr
// sixes = xt + xb
// cross_sum = 6 * sixes + 5 * fives
//           = 5 * (fives + sixes) + sixes
//           = (fives + sixes) << 2 + (fives + sixes) + sixes
static INLINE __m256i cross_sum_fast_even_row(const int32_t *buf, int32_t stride) {
    const __m256i xtl = yy_loadu_256(buf - 1 - stride);
    const __m256i xt  = yy_loadu_256(buf - stride);
    const __m256i xtr = yy_loadu_256(buf + 1 - stride);
    const __m256i xbl = yy_loadu_256(buf - 1 + stride);
    const __m256i xb  = yy_loadu_256(buf + stride);
    const __m256i xbr = yy_loadu_256(buf + 1 + stride);

    const __m256i fives = _mm256_add_epi32(xtl, _mm256_add_epi32(xtr, _mm256_add_epi32(xbr, xbl)));
    const __m256i sixes = _mm256_add_epi32(xt, xb);
    const __m256i fives_plus_sixes = _mm256_add_epi32(fives, sixes);

    return _mm256_add_epi32(
        _mm256_add_epi32(_mm256_slli_epi32(fives_plus_sixes, 2), fives_plus_sixes), sixes);
}

// Calculate 8 values of the "cross sum" starting at buf.
//
// Pixels are indexed like this:
// xl    x   xr
//
// Pixels are weighted like this:
//  5    6    5
//
// buf points to x
//
// fives = xl + xr
// sixes = x
// cross_sum = 5 * fives + 6 * sixes
//           = 4 * (fives + sixes) + (fives + sixes) + sixes
//           = (fives + sixes) << 2 + (fives + sixes) + sixes
static INLINE __m256i cross_sum_fast_odd_row(const int32_t *buf) {
    const __m256i xl = yy_loadu_256(buf - 1);
    const __m256i x  = yy_loadu_256(buf);
    const __m256i xr = yy_loadu_256(buf + 1);

    const __m256i fives = _mm256_add_epi32(xl, xr);
    const __m256i sixes = x;

    const __m256i fives_plus_sixes = _mm256_add_epi32(fives, sixes);

    return _mm256_add_epi32(
        _mm256_add_epi32(_mm256_slli_epi32(fives_plus_sixes, 2), fives_plus_sixes), sixes);
}

// The final filter for the self-guided restoration. Computes a
// weighted average across A, b with "cross sums" (see cross_sum_...
// implementations above).
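// Even rows use the two-row 5/6/5 cross sum (weights total 32, hence the
// nb0 = 5 extra shift bits); odd rows use the single-row 5/6/5 sum (weights
// total 16, hence nb1 = 4).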
static AOM_FORCE_INLINE void final_filter_fast(int32_t *dst, int32_t dst_stride, const int32_t *A,
                                               const int32_t *B, int32_t buf_stride,
                                               const uint8_t *dgd8, int32_t dgd_stride,
                                               int32_t width, int32_t height, int32_t highbd) {
    const int32_t nb0       = 5;
    const int32_t nb1       = 4;
    const __m256i rounding0 = round_for_shift(SGRPROJ_SGR_BITS + nb0 - SGRPROJ_RST_BITS);
    const __m256i rounding1 = round_for_shift(SGRPROJ_SGR_BITS + nb1 - SGRPROJ_RST_BITS);
    int32_t       i         = 0;

    if (!highbd) {
        do {
            if (!(i & 1)) { // even row
                int32_t j = 0;
                do {
                    const __m256i a   = cross_sum_fast_even_row(A + j, buf_stride);
                    const __m256i b   = cross_sum_fast_even_row(B + j, buf_stride);
                    const __m128i raw = xx_loadl_64(dgd8 + j);
                    const __m256i src = _mm256_cvtepu8_epi32(raw);
                    const __m256i v   = _mm256_add_epi32(_mm256_madd_epi16(a, src), b);
                    const __m256i w   = _mm256_srai_epi32(_mm256_add_epi32(v, rounding0),
                                                        SGRPROJ_SGR_BITS + nb0 - SGRPROJ_RST_BITS);
                    yy_storeu_256(dst + j, w);
                    j += 8;
                } while (j < width);
            } else { // odd row
                int32_t j = 0;
                do {
                    const __m256i a   = cross_sum_fast_odd_row(A + j);
                    const __m256i b   = cross_sum_fast_odd_row(B + j);
                    const __m128i raw = xx_loadl_64(dgd8 + j);
                    const __m256i src = _mm256_cvtepu8_epi32(raw);
                    const __m256i v   = _mm256_add_epi32(_mm256_madd_epi16(a, src), b);
                    const __m256i w   = _mm256_srai_epi32(_mm256_add_epi32(v, rounding1),
                                                        SGRPROJ_SGR_BITS + nb1 - SGRPROJ_RST_BITS);
                    yy_storeu_256(dst + j, w);
                    j += 8;
                } while (j < width);
            }

            A += buf_stride;
            B += buf_stride;
            dgd8 += dgd_stride;
            dst += dst_stride;
        } while (++i < height);
    } else {
        const uint16_t *dgd_real = CONVERT_TO_SHORTPTR(dgd8);

        do {
            if (!(i & 1)) { // even row
                int32_t j = 0;
                do {
                    const __m256i a   = cross_sum_fast_even_row(A + j, buf_stride);
                    const __m256i b   = cross_sum_fast_even_row(B + j, buf_stride);
                    const __m128i raw = xx_loadu_128(dgd_real + j);
                    const __m256i src = _mm256_cvtepu16_epi32(raw);
                    const __m256i v   = _mm256_add_epi32(_mm256_madd_epi16(a, src), b);
                    const __m256i w   = _mm256_srai_epi32(_mm256_add_epi32(v, rounding0),
                                                        SGRPROJ_SGR_BITS + nb0 - SGRPROJ_RST_BITS);
                    yy_storeu_256(dst + j, w);
                    j += 8;
                } while (j < width);
            } else { // odd row
                int32_t j = 0;
                do {
                    const __m256i a   = cross_sum_fast_odd_row(A + j);
                    const __m256i b   = cross_sum_fast_odd_row(B + j);
                    const __m128i raw = xx_loadu_128(dgd_real + j);
                    const __m256i src = _mm256_cvtepu16_epi32(raw);
                    const __m256i v   = _mm256_add_epi32(_mm256_madd_epi16(a, src), b);
                    const __m256i w   = _mm256_srai_epi32(_mm256_add_epi32(v, rounding1),
                                                        SGRPROJ_SGR_BITS + nb1 - SGRPROJ_RST_BITS);
                    yy_storeu_256(dst + j, w);
                    j += 8;
                } while (j < width);
            }

            A += buf_stride;
            B += buf_stride;
            dgd_real += dgd_stride;
            dst += dst_stride;
        } while (++i < height);
    }
}

void svt_av1_selfguided_restoration_avx2(const uint8_t *dgd8, int32_t width, int32_t height,
                                         int32_t dgd_stride, int32_t *flt0, int32_t *flt1,
                                         int32_t flt_stride, int32_t sgr_params_idx,
                                         int32_t bit_depth, int32_t highbd) {
    // The ALIGN_POWER_OF_TWO macro here ensures that column 1 of atl, btl,
    // ctl and dtl is 32-byte aligned.
    const int32_t buf_elts = ALIGN_POWER_OF_TWO(RESTORATION_PROC_UNIT_PELS, 3);

    DECLARE_ALIGNED(32, int32_t, buf[4 * ALIGN_POWER_OF_TWO(RESTORATION_PROC_UNIT_PELS, 3)]);

    const int32_t width_ext  = width + 2 * SGRPROJ_BORDER_HORZ;
    const int32_t height_ext = height + 2 * SGRPROJ_BORDER_VERT;

    // Adjusting the stride of A and b here appears to avoid bad cache effects,
    // leading to a significant speed improvement.
    // We also align the stride to a multiple of 32 bytes for efficiency.
    int32_t buf_stride = ALIGN_POWER_OF_TWO(width_ext + 16, 3);

    // The "tl" pointers point at the top-left of the initialised data for the
    // array.
    int32_t *atl = buf + 0 * buf_elts + 7;
    int32_t *btl = buf + 1 * buf_elts + 7;
    int32_t *ctl = buf + 2 * buf_elts + 7;
    int32_t *dtl = buf + 3 * buf_elts + 7;

    // The "0" pointers are (- SGRPROJ_BORDER_VERT, -SGRPROJ_BORDER_HORZ). Note
    // there's a zero row and column in A, b (integral images), so we move down
    // and right one for them.
    const int32_t buf_diag_border = SGRPROJ_BORDER_HORZ + buf_stride * SGRPROJ_BORDER_VERT;

    int32_t *a0 = atl + 1 + buf_stride;
    int32_t *b0 = btl + 1 + buf_stride;
    int32_t *c0 = ctl + 1 + buf_stride;
    int32_t *d0 = dtl + 1 + buf_stride;

    // Finally, A, b, C, D point at position (0, 0).
    int32_t *A = a0 + buf_diag_border;
    int32_t *b = b0 + buf_diag_border;
    int32_t *C = c0 + buf_diag_border;
    int32_t *D = d0 + buf_diag_border;

    const int32_t  dgd_diag_border = SGRPROJ_BORDER_HORZ + dgd_stride * SGRPROJ_BORDER_VERT;
    const uint8_t *dgd0            = dgd8 - dgd_diag_border;

    // Generate integral images from the input. C will contain sums of squares; D
    // will contain just sums
    if (highbd)
        integral_images_highbd(
            CONVERT_TO_SHORTPTR(dgd0), dgd_stride, width_ext, height_ext, ctl, dtl, buf_stride);
    else
        integral_images(dgd0, dgd_stride, width_ext, height_ext, ctl, dtl, buf_stride);

    const SgrParamsType *const params = &eb_sgr_params[sgr_params_idx];
    // Write to flt0 and flt1
    // If params->r == 0 we skip the corresponding filter. We only allow one of
    // the radii to be 0, as having both equal to 0 would be equivalent to
    // skipping SGR entirely.
    assert(!(params->r[0] == 0 && params->r[1] == 0));
    assert(params->r[0] < 3); // AOMMIN(SGRPROJ_BORDER_VERT, SGRPROJ_BORDER_HORZ) == 3
    assert(params->r[1] < 3); // AOMMIN(SGRPROJ_BORDER_VERT, SGRPROJ_BORDER_HORZ) == 3

    if (params->r[0] > 0) {
        calc_ab_fast(A, b, C, D, width, height, buf_stride, bit_depth, sgr_params_idx, 0);
        final_filter_fast(
            flt0, flt_stride, A, b, buf_stride, dgd8, dgd_stride, width, height, highbd);
    }

    if (params->r[1] > 0) {
        calc_ab(A, b, C, D, width, height, buf_stride, bit_depth, sgr_params_idx, 1);
        final_filter(flt1, flt_stride, A, b, buf_stride, dgd8, dgd_stride, width, height, highbd);
    }
}

void svt_apply_selfguided_restoration_avx2(const uint8_t *dat8, int32_t width, int32_t height,
                                           int32_t stride, int32_t eps, const int32_t *xqd,
                                           uint8_t *dst8, int32_t dst_stride, int32_t *tmpbuf,
                                           int32_t bit_depth, int32_t highbd) {
    int32_t *flt0 = tmpbuf;
    int32_t *flt1 = flt0 + RESTORATION_UNITPELS_MAX;
    assert(width * height <= RESTORATION_UNITPELS_MAX);
    svt_av1_selfguided_restoration_avx2(
        dat8, width, height, stride, flt0, flt1, width, eps, bit_depth, highbd);
    const SgrParamsType *const params = &eb_sgr_params[eps];
    int32_t                    xq[2];
    svt_decode_xq(xqd, xq, params);

    const __m256i xq0      = _mm256_set1_epi32(xq[0]);
    const __m256i xq1      = _mm256_set1_epi32(xq[1]);
    const __m256i rounding = round_for_shift(SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);

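    // Per pixel: u = src << SGRPROJ_RST_BITS, and
    //   v = (u << SGRPROJ_PRJ_BITS) + xq[0] * (flt0 - u) + xq[1] * (flt1 - u)
    // (each term only when the corresponding radius is non-zero); the output
    // is (v + rounding) >> (SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS), clamped to
    // the pixel range.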
    int32_t i = height;

    if (!highbd) {
        const __m256i idx = _mm256_setr_epi32(0, 4, 1, 5, 0, 0, 0, 0);

        do {
            // Calculate output in batches of 16 pixels
            int32_t j = 0;
            do {
                const __m128i src  = xx_loadu_128(dat8 + j);
                const __m256i ep_0 = _mm256_cvtepu8_epi32(src);
                const __m256i ep_1 = _mm256_cvtepu8_epi32(_mm_srli_si128(src, 8));
                const __m256i u_0  = _mm256_slli_epi32(ep_0, SGRPROJ_RST_BITS);
                const __m256i u_1  = _mm256_slli_epi32(ep_1, SGRPROJ_RST_BITS);
                __m256i       v_0  = _mm256_slli_epi32(u_0, SGRPROJ_PRJ_BITS);
                __m256i       v_1  = _mm256_slli_epi32(u_1, SGRPROJ_PRJ_BITS);

                if (params->r[0] > 0) {
                    const __m256i f1_0 = _mm256_sub_epi32(yy_loadu_256(&flt0[j + 0]), u_0);
                    const __m256i f1_1 = _mm256_sub_epi32(yy_loadu_256(&flt0[j + 8]), u_1);
                    v_0                = _mm256_add_epi32(v_0, _mm256_mullo_epi32(xq0, f1_0));
                    v_1                = _mm256_add_epi32(v_1, _mm256_mullo_epi32(xq0, f1_1));
                }

                if (params->r[1] > 0) {
                    const __m256i f2_0 = _mm256_sub_epi32(yy_loadu_256(&flt1[j + 0]), u_0);
                    const __m256i f2_1 = _mm256_sub_epi32(yy_loadu_256(&flt1[j + 8]), u_1);
                    v_0                = _mm256_add_epi32(v_0, _mm256_mullo_epi32(xq1, f2_0));
                    v_1                = _mm256_add_epi32(v_1, _mm256_mullo_epi32(xq1, f2_1));
                }

                const __m256i w_0 = _mm256_srai_epi32(_mm256_add_epi32(v_0, rounding),
                                                      SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);
                const __m256i w_1 = _mm256_srai_epi32(_mm256_add_epi32(v_1, rounding),
                                                      SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);

                // Pack into 8 bits and clamp to [0, 256)
                // Note that each pack messes up the order of the bits,
                // so we use a permute function to correct this
                // 0, 1, 4, 5, 2, 3, 6, 7
                const __m256i tmp = _mm256_packus_epi32(w_0, w_1);
                // 0, 1, 4, 5, 2, 3, 6, 7, 0, 1, 4, 5, 2, 3, 6, 7
                const __m256i tmp2 = _mm256_packus_epi16(tmp, tmp);
                // 0, 1, 2, 3, 4, 5, 6, 7, ...
                const __m256i tmp3 = _mm256_permutevar8x32_epi32(tmp2, idx);
                const __m128i res  = _mm256_castsi256_si128(tmp3);
                xx_storeu_128(dst8 + j, res);
                j += 16;
            } while (j < width);

            dat8 += stride;
            flt0 += width;
            flt1 += width;
            dst8 += dst_stride;
        } while (--i);
    } else {
        const __m256i   max   = _mm256_set1_epi16((1 << bit_depth) - 1);
        const uint16_t *dat16 = CONVERT_TO_SHORTPTR(dat8);
        uint16_t *      dst16 = CONVERT_TO_SHORTPTR(dst8);

        do {
            // Calculate output in batches of 16 pixels
            int32_t j = 0;
            do {
                const __m128i src_0 = xx_loadu_128(dat16 + j + 0);
                const __m128i src_1 = xx_loadu_128(dat16 + j + 8);
                const __m256i ep_0  = _mm256_cvtepu16_epi32(src_0);
                const __m256i ep_1  = _mm256_cvtepu16_epi32(src_1);
                const __m256i u_0   = _mm256_slli_epi32(ep_0, SGRPROJ_RST_BITS);
                const __m256i u_1   = _mm256_slli_epi32(ep_1, SGRPROJ_RST_BITS);
                __m256i       v_0   = _mm256_slli_epi32(u_0, SGRPROJ_PRJ_BITS);
                __m256i       v_1   = _mm256_slli_epi32(u_1, SGRPROJ_PRJ_BITS);

                if (params->r[0] > 0) {
                    const __m256i f1_0 = _mm256_sub_epi32(yy_loadu_256(&flt0[j + 0]), u_0);
                    const __m256i f1_1 = _mm256_sub_epi32(yy_loadu_256(&flt0[j + 8]), u_1);
                    v_0                = _mm256_add_epi32(v_0, _mm256_mullo_epi32(xq0, f1_0));
                    v_1                = _mm256_add_epi32(v_1, _mm256_mullo_epi32(xq0, f1_1));
                }

                if (params->r[1] > 0) {
                    const __m256i f2_0 = _mm256_sub_epi32(yy_loadu_256(&flt1[j + 0]), u_0);
                    const __m256i f2_1 = _mm256_sub_epi32(yy_loadu_256(&flt1[j + 8]), u_1);
                    v_0                = _mm256_add_epi32(v_0, _mm256_mullo_epi32(xq1, f2_0));
                    v_1                = _mm256_add_epi32(v_1, _mm256_mullo_epi32(xq1, f2_1));
                }

                const __m256i w_0 = _mm256_srai_epi32(_mm256_add_epi32(v_0, rounding),
                                                      SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);
                const __m256i w_1 = _mm256_srai_epi32(_mm256_add_epi32(v_1, rounding),
                                                      SGRPROJ_PRJ_BITS + SGRPROJ_RST_BITS);

                // Pack into 16 bits and clamp to [0, 2^bit_depth)
                // Note that packing into 16 bits messes up the order of the bits,
                // so we use a permute function to correct this
                const __m256i tmp  = _mm256_packus_epi32(w_0, w_1);
                const __m256i tmp2 = _mm256_permute4x64_epi64(tmp, 0xd8);
                const __m256i res  = _mm256_min_epi16(tmp2, max);
                yy_storeu_256(dst16 + j, res);
                j += 16;
            } while (j < width);

            dat16 += stride;
            flt0 += width;
            flt1 += width;
            dst16 += dst_stride;
        } while (--i);
    }
}