1 /*
2 * Copyright(c) 2019 Intel Corporation
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at https://www.aomedia.org/license/software-license. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at https://www.aomedia.org/license/patent-license.
10 */
11
12 #include "EbDefinitions.h"
13
14 #if EN_AVX512_SUPPORT
15 #include <immintrin.h>
16 #include "common_dsp_rtcd.h"
17 #include "EbBitstreamUnit.h"
18 #include "EbCdef.h"
19 #include "EbMemory_AVX2.h"
20
loadu_u16_8x4_avx512(const uint16_t * const src,const uint32_t stride)21 static INLINE __m512i loadu_u16_8x4_avx512(const uint16_t *const src, const uint32_t stride) {
22 const __m256i s0 = loadu_u16_8x2_avx2(src + 0 * stride, stride);
23 const __m256i s1 = loadu_u16_8x2_avx2(src + 2 * stride, stride);
24 return _mm512_inserti64x4(_mm512_castsi256_si512(s0), s1, 1);
25 }
26
27 // sign(a-b) * min(abs(a-b), max(0, threshold - (abs(a-b) >> adjdamp)))
constrain16_avx512(const __m512i in0,const __m512i in1,const __m512i threshold,const __m128i damping)28 static INLINE __m512i constrain16_avx512(const __m512i in0, const __m512i in1,
29 const __m512i threshold, const __m128i damping) {
30 const __m512i diff = _mm512_sub_epi16(in0, in1);
31 const __m512i sign = _mm512_srai_epi16(diff, 15);
32 const __m512i a = _mm512_abs_epi16(diff);
33 const __m512i l = _mm512_srl_epi16(a, damping);
34 const __m512i s = _mm512_subs_epu16(threshold, l);
35 const __m512i m = _mm512_min_epi16(a, s);
36 const __m512i d = _mm512_add_epi16(sign, m);
37 return _mm512_xor_si512(d, sign);
38 }
39
cdef_filter_block_8x8_16_pri_avx512(const uint16_t * const in,const __m128i damping,const int32_t po,const __m512i row,const __m512i strength,const __m512i pri_taps,__m512i * const max,__m512i * const min,__m512i * const sum)40 static INLINE void cdef_filter_block_8x8_16_pri_avx512(const uint16_t *const in,
41 const __m128i damping, const int32_t po,
42 const __m512i row, const __m512i strength,
43 const __m512i pri_taps, __m512i *const max,
44 __m512i *const min, __m512i *const sum) {
45 const __m512i mask = _mm512_set1_epi16(0x3FFF);
46 const __m512i p0 = loadu_u16_8x4_avx512(in + po, CDEF_BSTRIDE);
47 const __m512i p1 = loadu_u16_8x4_avx512(in - po, CDEF_BSTRIDE);
48
49 *max = _mm512_max_epi16(*max, _mm512_and_si512(p0, mask));
50 *max = _mm512_max_epi16(*max, _mm512_and_si512(p1, mask));
51 *min = _mm512_min_epi16(*min, p0);
52 *min = _mm512_min_epi16(*min, p1);
53
54 const __m512i q0 = constrain16_avx512(p0, row, strength, damping);
55 const __m512i q1 = constrain16_avx512(p1, row, strength, damping);
56
57 // sum += pri_taps * (p0 + p1)
58 *sum = _mm512_add_epi16(*sum, _mm512_mullo_epi16(pri_taps, _mm512_add_epi16(q0, q1)));
59 }
60
cdef_filter_block_8x8_16_sec_avx512(const uint16_t * const in,const __m128i damping,const int32_t so1,const int32_t so2,const __m512i row,const __m512i strength,const __m512i sec_taps,__m512i * const max,__m512i * const min,__m512i * const sum)61 static INLINE void cdef_filter_block_8x8_16_sec_avx512(const uint16_t *const in,
62 const __m128i damping, const int32_t so1,
63 const int32_t so2, const __m512i row,
64 const __m512i strength,
65 const __m512i sec_taps, __m512i *const max,
66 __m512i *const min, __m512i *const sum) {
67 const __m512i mask = _mm512_set1_epi16(0x3FFF);
68 const __m512i p0 = loadu_u16_8x4_avx512(in + so1, CDEF_BSTRIDE);
69 const __m512i p1 = loadu_u16_8x4_avx512(in - so1, CDEF_BSTRIDE);
70 const __m512i p2 = loadu_u16_8x4_avx512(in + so2, CDEF_BSTRIDE);
71 const __m512i p3 = loadu_u16_8x4_avx512(in - so2, CDEF_BSTRIDE);
72
73 *max = _mm512_max_epi16(*max, _mm512_and_si512(p0, mask));
74 *max = _mm512_max_epi16(*max, _mm512_and_si512(p1, mask));
75 *max = _mm512_max_epi16(*max, _mm512_and_si512(p2, mask));
76 *max = _mm512_max_epi16(*max, _mm512_and_si512(p3, mask));
77 *min = _mm512_min_epi16(*min, p0);
78 *min = _mm512_min_epi16(*min, p1);
79 *min = _mm512_min_epi16(*min, p2);
80 *min = _mm512_min_epi16(*min, p3);
81
82 const __m512i q0 = constrain16_avx512(p0, row, strength, damping);
83 const __m512i q1 = constrain16_avx512(p1, row, strength, damping);
84 const __m512i q2 = constrain16_avx512(p2, row, strength, damping);
85 const __m512i q3 = constrain16_avx512(p3, row, strength, damping);
86
87 // sum += sec_taps * (p0 + p1 + p2 + p3)
88 *sum = _mm512_add_epi16(
89 *sum,
90 _mm512_mullo_epi16(sec_taps,
91 _mm512_add_epi16(_mm512_add_epi16(q0, q1), _mm512_add_epi16(q2, q3))));
92 }
93
svt_cdef_filter_block_8x8_16_avx512(const uint16_t * const in,const int32_t pri_strength,const int32_t sec_strength,const int32_t dir,int32_t pri_damping,int32_t sec_damping,const int32_t coeff_shift,uint16_t * const dst,const int32_t dstride)94 void svt_cdef_filter_block_8x8_16_avx512(const uint16_t *const in, const int32_t pri_strength,
95 const int32_t sec_strength, const int32_t dir,
96 int32_t pri_damping, int32_t sec_damping,
97 const int32_t coeff_shift, uint16_t *const dst,
98 const int32_t dstride) {
99 const int32_t po1 = eb_cdef_directions[dir][0];
100 const int32_t po2 = eb_cdef_directions[dir][1];
101 const int32_t s1o1 = eb_cdef_directions[(dir + 2) & 7][0];
102 const int32_t s1o2 = eb_cdef_directions[(dir + 2) & 7][1];
103 const int32_t s2o1 = eb_cdef_directions[(dir + 6) & 7][0];
104 const int32_t s2o2 = eb_cdef_directions[(dir + 6) & 7][1];
105 const int32_t *pri_taps = eb_cdef_pri_taps[(pri_strength >> coeff_shift) & 1];
106 const int32_t *sec_taps = eb_cdef_sec_taps[(pri_strength >> coeff_shift) & 1];
107 const __m512i pri_taps_0 = _mm512_set1_epi16(pri_taps[0]);
108 const __m512i pri_taps_1 = _mm512_set1_epi16(pri_taps[1]);
109 const __m512i sec_taps_0 = _mm512_set1_epi16(sec_taps[0]);
110 const __m512i sec_taps_1 = _mm512_set1_epi16(sec_taps[1]);
111 const __m512i duplicate_8 = _mm512_set1_epi16(8);
112 const __m512i pri_strength_256 = _mm512_set1_epi16(pri_strength);
113 const __m512i sec_strength_256 = _mm512_set1_epi16(sec_strength);
114 const __m512i zero = _mm512_setzero_si512();
115
116 if (pri_strength)
117 pri_damping = AOMMAX(0, pri_damping - get_msb(pri_strength));
118 if (sec_strength)
119 sec_damping = AOMMAX(0, sec_damping - get_msb(sec_strength));
120
121 const __m128i pri_d = _mm_cvtsi32_si128(pri_damping);
122 const __m128i sec_d = _mm_cvtsi32_si128(sec_damping);
123
124 {
125 const __m512i row = loadu_u16_8x4_avx512(in, CDEF_BSTRIDE);
126 __m512i sum, res, max, min;
127
128 min = max = row;
129 sum = zero;
130
131 // Primary near taps
132 cdef_filter_block_8x8_16_pri_avx512(
133 in, pri_d, po1, row, pri_strength_256, pri_taps_0, &max, &min, &sum);
134
135 // Primary far taps
136 cdef_filter_block_8x8_16_pri_avx512(
137 in, pri_d, po2, row, pri_strength_256, pri_taps_1, &max, &min, &sum);
138
139 // Secondary near taps
140 cdef_filter_block_8x8_16_sec_avx512(
141 in, sec_d, s1o1, s2o1, row, sec_strength_256, sec_taps_0, &max, &min, &sum);
142
143 // Secondary far taps
144 cdef_filter_block_8x8_16_sec_avx512(
145 in, sec_d, s1o2, s2o2, row, sec_strength_256, sec_taps_1, &max, &min, &sum);
146
147 // res = row + ((sum - (sum < 0) + 8) >> 4)
148 const __mmask32 mask = _mm512_cmpgt_epi16_mask(zero, sum);
149 sum = _mm512_mask_add_epi16(sum, mask, sum, _mm512_set1_epi16(-1));
150 res = _mm512_add_epi16(sum, duplicate_8);
151 res = _mm512_srai_epi16(res, 4);
152 res = _mm512_add_epi16(row, res);
153 res = _mm512_max_epi16(res, min);
154 res = _mm512_min_epi16(res, max);
155
156 _mm_storeu_si128((__m128i *)&dst[0 * dstride], _mm512_castsi512_si128(res));
157 _mm_storeu_si128((__m128i *)&dst[1 * dstride], _mm512_extracti32x4_epi32(res, 1));
158 _mm_storeu_si128((__m128i *)&dst[2 * dstride], _mm512_extracti32x4_epi32(res, 2));
159 _mm_storeu_si128((__m128i *)&dst[3 * dstride], _mm512_extracti32x4_epi32(res, 3));
160 }
161
162 {
163 const __m512i row = loadu_u16_8x4_avx512(in + 4 * CDEF_BSTRIDE, CDEF_BSTRIDE);
164 __m512i sum, res, max, min;
165
166 min = max = row;
167 sum = zero;
168
169 // Primary near taps
170 cdef_filter_block_8x8_16_pri_avx512(
171 in + 4 * CDEF_BSTRIDE, pri_d, po1, row, pri_strength_256, pri_taps_0, &max, &min, &sum);
172
173 // Primary far taps
174 cdef_filter_block_8x8_16_pri_avx512(
175 in + 4 * CDEF_BSTRIDE, pri_d, po2, row, pri_strength_256, pri_taps_1, &max, &min, &sum);
176
177 // Secondary near taps
178 cdef_filter_block_8x8_16_sec_avx512(in + 4 * CDEF_BSTRIDE,
179 sec_d,
180 s1o1,
181 s2o1,
182 row,
183 sec_strength_256,
184 sec_taps_0,
185 &max,
186 &min,
187 &sum);
188
189 // Secondary far taps
190 cdef_filter_block_8x8_16_sec_avx512(in + 4 * CDEF_BSTRIDE,
191 sec_d,
192 s1o2,
193 s2o2,
194 row,
195 sec_strength_256,
196 sec_taps_1,
197 &max,
198 &min,
199 &sum);
200
201 // res = row + ((sum - (sum < 0) + 8) >> 4)
202 const __mmask32 mask = _mm512_cmpgt_epi16_mask(zero, sum);
203 sum = _mm512_mask_add_epi16(sum, mask, sum, _mm512_set1_epi16(-1));
204 res = _mm512_add_epi16(sum, duplicate_8);
205 res = _mm512_srai_epi16(res, 4);
206 res = _mm512_add_epi16(row, res);
207 res = _mm512_max_epi16(res, min);
208 res = _mm512_min_epi16(res, max);
209
210 _mm_storeu_si128((__m128i *)&dst[4 * dstride], _mm512_castsi512_si128(res));
211 _mm_storeu_si128((__m128i *)&dst[5 * dstride], _mm512_extracti32x4_epi32(res, 1));
212 _mm_storeu_si128((__m128i *)&dst[6 * dstride], _mm512_extracti32x4_epi32(res, 2));
213 _mm_storeu_si128((__m128i *)&dst[7 * dstride], _mm512_extracti32x4_epi32(res, 3));
214 }
215 }
216
217 #endif // EN_AVX512_SUPPORT
218