1 /*
2 * Copyright(c) 2019 Intel Corporation
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at https://www.aomedia.org/license/software-license. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at https://www.aomedia.org/license/patent-license.
10 */
11 
12 #include "EbDefinitions.h"
13 
14 #if EN_AVX512_SUPPORT
15 #include <immintrin.h>
16 #include "common_dsp_rtcd.h"
17 #include "EbBitstreamUnit.h"
18 #include "EbCdef.h"
19 #include "EbMemory_AVX2.h"
20 
loadu_u16_8x4_avx512(const uint16_t * const src,const uint32_t stride)21 static INLINE __m512i loadu_u16_8x4_avx512(const uint16_t *const src, const uint32_t stride) {
22     const __m256i s0 = loadu_u16_8x2_avx2(src + 0 * stride, stride);
23     const __m256i s1 = loadu_u16_8x2_avx2(src + 2 * stride, stride);
24     return _mm512_inserti64x4(_mm512_castsi256_si512(s0), s1, 1);
25 }
26 
27 // sign(a-b) * min(abs(a-b), max(0, threshold - (abs(a-b) >> adjdamp)))
constrain16_avx512(const __m512i in0,const __m512i in1,const __m512i threshold,const __m128i damping)28 static INLINE __m512i constrain16_avx512(const __m512i in0, const __m512i in1,
29                                          const __m512i threshold, const __m128i damping) {
30     const __m512i diff = _mm512_sub_epi16(in0, in1);
31     const __m512i sign = _mm512_srai_epi16(diff, 15);
32     const __m512i a    = _mm512_abs_epi16(diff);
33     const __m512i l    = _mm512_srl_epi16(a, damping);
34     const __m512i s    = _mm512_subs_epu16(threshold, l);
35     const __m512i m    = _mm512_min_epi16(a, s);
36     const __m512i d    = _mm512_add_epi16(sign, m);
37     return _mm512_xor_si512(d, sign);
38 }
39 
cdef_filter_block_8x8_16_pri_avx512(const uint16_t * const in,const __m128i damping,const int32_t po,const __m512i row,const __m512i strength,const __m512i pri_taps,__m512i * const max,__m512i * const min,__m512i * const sum)40 static INLINE void cdef_filter_block_8x8_16_pri_avx512(const uint16_t *const in,
41                                                        const __m128i damping, const int32_t po,
42                                                        const __m512i row, const __m512i strength,
43                                                        const __m512i pri_taps, __m512i *const max,
44                                                        __m512i *const min, __m512i *const sum) {
45     const __m512i mask = _mm512_set1_epi16(0x3FFF);
46     const __m512i p0   = loadu_u16_8x4_avx512(in + po, CDEF_BSTRIDE);
47     const __m512i p1   = loadu_u16_8x4_avx512(in - po, CDEF_BSTRIDE);
48 
49     *max = _mm512_max_epi16(*max, _mm512_and_si512(p0, mask));
50     *max = _mm512_max_epi16(*max, _mm512_and_si512(p1, mask));
51     *min = _mm512_min_epi16(*min, p0);
52     *min = _mm512_min_epi16(*min, p1);
53 
54     const __m512i q0 = constrain16_avx512(p0, row, strength, damping);
55     const __m512i q1 = constrain16_avx512(p1, row, strength, damping);
56 
57     // sum += pri_taps * (p0 + p1)
58     *sum = _mm512_add_epi16(*sum, _mm512_mullo_epi16(pri_taps, _mm512_add_epi16(q0, q1)));
59 }
60 
cdef_filter_block_8x8_16_sec_avx512(const uint16_t * const in,const __m128i damping,const int32_t so1,const int32_t so2,const __m512i row,const __m512i strength,const __m512i sec_taps,__m512i * const max,__m512i * const min,__m512i * const sum)61 static INLINE void cdef_filter_block_8x8_16_sec_avx512(const uint16_t *const in,
62                                                        const __m128i damping, const int32_t so1,
63                                                        const int32_t so2, const __m512i row,
64                                                        const __m512i strength,
65                                                        const __m512i sec_taps, __m512i *const max,
66                                                        __m512i *const min, __m512i *const sum) {
67     const __m512i mask = _mm512_set1_epi16(0x3FFF);
68     const __m512i p0   = loadu_u16_8x4_avx512(in + so1, CDEF_BSTRIDE);
69     const __m512i p1   = loadu_u16_8x4_avx512(in - so1, CDEF_BSTRIDE);
70     const __m512i p2   = loadu_u16_8x4_avx512(in + so2, CDEF_BSTRIDE);
71     const __m512i p3   = loadu_u16_8x4_avx512(in - so2, CDEF_BSTRIDE);
72 
73     *max = _mm512_max_epi16(*max, _mm512_and_si512(p0, mask));
74     *max = _mm512_max_epi16(*max, _mm512_and_si512(p1, mask));
75     *max = _mm512_max_epi16(*max, _mm512_and_si512(p2, mask));
76     *max = _mm512_max_epi16(*max, _mm512_and_si512(p3, mask));
77     *min = _mm512_min_epi16(*min, p0);
78     *min = _mm512_min_epi16(*min, p1);
79     *min = _mm512_min_epi16(*min, p2);
80     *min = _mm512_min_epi16(*min, p3);
81 
82     const __m512i q0 = constrain16_avx512(p0, row, strength, damping);
83     const __m512i q1 = constrain16_avx512(p1, row, strength, damping);
84     const __m512i q2 = constrain16_avx512(p2, row, strength, damping);
85     const __m512i q3 = constrain16_avx512(p3, row, strength, damping);
86 
87     // sum += sec_taps * (p0 + p1 + p2 + p3)
88     *sum = _mm512_add_epi16(
89         *sum,
90         _mm512_mullo_epi16(sec_taps,
91                            _mm512_add_epi16(_mm512_add_epi16(q0, q1), _mm512_add_epi16(q2, q3))));
92 }
93 
svt_cdef_filter_block_8x8_16_avx512(const uint16_t * const in,const int32_t pri_strength,const int32_t sec_strength,const int32_t dir,int32_t pri_damping,int32_t sec_damping,const int32_t coeff_shift,uint16_t * const dst,const int32_t dstride)94 void svt_cdef_filter_block_8x8_16_avx512(const uint16_t *const in, const int32_t pri_strength,
95                                          const int32_t sec_strength, const int32_t dir,
96                                          int32_t pri_damping, int32_t sec_damping,
97                                          const int32_t coeff_shift, uint16_t *const dst,
98                                          const int32_t dstride) {
99     const int32_t  po1              = eb_cdef_directions[dir][0];
100     const int32_t  po2              = eb_cdef_directions[dir][1];
101     const int32_t  s1o1             = eb_cdef_directions[(dir + 2) & 7][0];
102     const int32_t  s1o2             = eb_cdef_directions[(dir + 2) & 7][1];
103     const int32_t  s2o1             = eb_cdef_directions[(dir + 6) & 7][0];
104     const int32_t  s2o2             = eb_cdef_directions[(dir + 6) & 7][1];
105     const int32_t *pri_taps         = eb_cdef_pri_taps[(pri_strength >> coeff_shift) & 1];
106     const int32_t *sec_taps         = eb_cdef_sec_taps[(pri_strength >> coeff_shift) & 1];
107     const __m512i  pri_taps_0       = _mm512_set1_epi16(pri_taps[0]);
108     const __m512i  pri_taps_1       = _mm512_set1_epi16(pri_taps[1]);
109     const __m512i  sec_taps_0       = _mm512_set1_epi16(sec_taps[0]);
110     const __m512i  sec_taps_1       = _mm512_set1_epi16(sec_taps[1]);
111     const __m512i  duplicate_8      = _mm512_set1_epi16(8);
112     const __m512i  pri_strength_256 = _mm512_set1_epi16(pri_strength);
113     const __m512i  sec_strength_256 = _mm512_set1_epi16(sec_strength);
114     const __m512i  zero             = _mm512_setzero_si512();
115 
116     if (pri_strength)
117         pri_damping = AOMMAX(0, pri_damping - get_msb(pri_strength));
118     if (sec_strength)
119         sec_damping = AOMMAX(0, sec_damping - get_msb(sec_strength));
120 
121     const __m128i pri_d = _mm_cvtsi32_si128(pri_damping);
122     const __m128i sec_d = _mm_cvtsi32_si128(sec_damping);
123 
124     {
125         const __m512i row = loadu_u16_8x4_avx512(in, CDEF_BSTRIDE);
126         __m512i       sum, res, max, min;
127 
128         min = max = row;
129         sum       = zero;
130 
131         // Primary near taps
132         cdef_filter_block_8x8_16_pri_avx512(
133             in, pri_d, po1, row, pri_strength_256, pri_taps_0, &max, &min, &sum);
134 
135         // Primary far taps
136         cdef_filter_block_8x8_16_pri_avx512(
137             in, pri_d, po2, row, pri_strength_256, pri_taps_1, &max, &min, &sum);
138 
139         // Secondary near taps
140         cdef_filter_block_8x8_16_sec_avx512(
141             in, sec_d, s1o1, s2o1, row, sec_strength_256, sec_taps_0, &max, &min, &sum);
142 
143         // Secondary far taps
144         cdef_filter_block_8x8_16_sec_avx512(
145             in, sec_d, s1o2, s2o2, row, sec_strength_256, sec_taps_1, &max, &min, &sum);
146 
147         // res = row + ((sum - (sum < 0) + 8) >> 4)
148         const __mmask32 mask = _mm512_cmpgt_epi16_mask(zero, sum);
149         sum                  = _mm512_mask_add_epi16(sum, mask, sum, _mm512_set1_epi16(-1));
150         res                  = _mm512_add_epi16(sum, duplicate_8);
151         res                  = _mm512_srai_epi16(res, 4);
152         res                  = _mm512_add_epi16(row, res);
153         res                  = _mm512_max_epi16(res, min);
154         res                  = _mm512_min_epi16(res, max);
155 
156         _mm_storeu_si128((__m128i *)&dst[0 * dstride], _mm512_castsi512_si128(res));
157         _mm_storeu_si128((__m128i *)&dst[1 * dstride], _mm512_extracti32x4_epi32(res, 1));
158         _mm_storeu_si128((__m128i *)&dst[2 * dstride], _mm512_extracti32x4_epi32(res, 2));
159         _mm_storeu_si128((__m128i *)&dst[3 * dstride], _mm512_extracti32x4_epi32(res, 3));
160     }
161 
162     {
163         const __m512i row = loadu_u16_8x4_avx512(in + 4 * CDEF_BSTRIDE, CDEF_BSTRIDE);
164         __m512i       sum, res, max, min;
165 
166         min = max = row;
167         sum       = zero;
168 
169         // Primary near taps
170         cdef_filter_block_8x8_16_pri_avx512(
171             in + 4 * CDEF_BSTRIDE, pri_d, po1, row, pri_strength_256, pri_taps_0, &max, &min, &sum);
172 
173         // Primary far taps
174         cdef_filter_block_8x8_16_pri_avx512(
175             in + 4 * CDEF_BSTRIDE, pri_d, po2, row, pri_strength_256, pri_taps_1, &max, &min, &sum);
176 
177         // Secondary near taps
178         cdef_filter_block_8x8_16_sec_avx512(in + 4 * CDEF_BSTRIDE,
179                                             sec_d,
180                                             s1o1,
181                                             s2o1,
182                                             row,
183                                             sec_strength_256,
184                                             sec_taps_0,
185                                             &max,
186                                             &min,
187                                             &sum);
188 
189         // Secondary far taps
190         cdef_filter_block_8x8_16_sec_avx512(in + 4 * CDEF_BSTRIDE,
191                                             sec_d,
192                                             s1o2,
193                                             s2o2,
194                                             row,
195                                             sec_strength_256,
196                                             sec_taps_1,
197                                             &max,
198                                             &min,
199                                             &sum);
200 
201         // res = row + ((sum - (sum < 0) + 8) >> 4)
202         const __mmask32 mask = _mm512_cmpgt_epi16_mask(zero, sum);
203         sum                  = _mm512_mask_add_epi16(sum, mask, sum, _mm512_set1_epi16(-1));
204         res                  = _mm512_add_epi16(sum, duplicate_8);
205         res                  = _mm512_srai_epi16(res, 4);
206         res                  = _mm512_add_epi16(row, res);
207         res                  = _mm512_max_epi16(res, min);
208         res                  = _mm512_min_epi16(res, max);
209 
210         _mm_storeu_si128((__m128i *)&dst[4 * dstride], _mm512_castsi512_si128(res));
211         _mm_storeu_si128((__m128i *)&dst[5 * dstride], _mm512_extracti32x4_epi32(res, 1));
212         _mm_storeu_si128((__m128i *)&dst[6 * dstride], _mm512_extracti32x4_epi32(res, 2));
213         _mm_storeu_si128((__m128i *)&dst[7 * dstride], _mm512_extracti32x4_epi32(res, 3));
214     }
215 }
216 
217 #endif // EN_AVX512_SUPPORT
218