/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <immintrin.h>

#include "config/av1_rtcd.h"
#include "av1/encoder/encoder.h"
#include "av1/encoder/temporal_filter.h"

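// Each row of the squared-error buffer carries 2 extra columns on each side
// of the block so that the 5x5 window sums can use full 8-lane loads at the
// block edges.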
#define SSE_STRIDE (BW + 4)

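// Byte masks selecting a run of 5 consecutive 32-bit lanes starting at offset
// 0..3 of an 8-lane vector; used by xx_mask_and_hadd() to pick the 5-wide
// horizontal window for each output column.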
DECLARE_ALIGNED(32, static const uint32_t, sse_bytemask[4][8]) = {
  { 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0 },
  { 0, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0 },
  { 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0 },
  { 0, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF }
};

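// Computes the per-pixel squared error between a 16-pixel-wide block of
// frame1 and frame2 and stores the 32-bit results into frame_sse, offset by
// 2 columns to leave room for the horizontal padding of the 5x5 window.
// block_width is unused since the width is fixed at 16.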
static AOM_FORCE_INLINE void get_squared_error_16x16_avx2(
    const uint16_t *frame1, const unsigned int stride, const uint16_t *frame2,
    const unsigned int stride2, const int block_width, const int block_height,
    uint32_t *frame_sse, const unsigned int sse_stride) {
  (void)block_width;
  const uint16_t *src1 = frame1;
  const uint16_t *src2 = frame2;
  uint32_t *dst = frame_sse + 2;
  for (int i = 0; i < block_height; i++) {
    __m256i v_src1 = _mm256_loadu_si256((__m256i *)src1);
    __m256i v_src2 = _mm256_loadu_si256((__m256i *)src2);
    __m256i v_diff = _mm256_sub_epi16(v_src1, v_src2);
    __m256i v_mullo = _mm256_mullo_epi16(v_diff, v_diff);
    __m256i v_mulhi = _mm256_mulhi_epi16(v_diff, v_diff);

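    // The unpacks interleave within 128-bit lanes, so v_lo holds the squares
    // of pixels 0-3|8-11 and v_hi those of pixels 4-7|12-15; the two
    // inserti128 operations below reassemble them into pixel order.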
    __m256i v_lo = _mm256_unpacklo_epi16(v_mullo, v_mulhi);
    __m256i v_hi = _mm256_unpackhi_epi16(v_mullo, v_mulhi);
    __m256i diff_lo =
        _mm256_inserti128_si256(v_lo, _mm256_extracti128_si256(v_hi, 0), 1);
    __m256i diff_hi =
        _mm256_inserti128_si256(v_hi, _mm256_extracti128_si256(v_lo, 1), 0);

    _mm256_storeu_si256((__m256i *)dst, diff_lo);
    dst += 8;
    _mm256_storeu_si256((__m256i *)dst, diff_hi);

    src1 += stride, src2 += stride2;
    dst += sse_stride - 8;
  }
}

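// Same as get_squared_error_16x16_avx2() but processes two 16-pixel halves
// per row for 32-pixel-wide blocks.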
static AOM_FORCE_INLINE void get_squared_error_32x32_avx2(
    const uint16_t *frame1, const unsigned int stride, const uint16_t *frame2,
    const unsigned int stride2, const int block_width, const int block_height,
    uint32_t *frame_sse, const unsigned int sse_stride) {
  (void)block_width;
  const uint16_t *src1 = frame1;
  const uint16_t *src2 = frame2;
  uint32_t *dst = frame_sse + 2;
  for (int i = 0; i < block_height; i++) {
    __m256i v_src1 = _mm256_loadu_si256((__m256i *)src1);
    __m256i v_src2 = _mm256_loadu_si256((__m256i *)src2);
    __m256i v_diff = _mm256_sub_epi16(v_src1, v_src2);
    __m256i v_mullo = _mm256_mullo_epi16(v_diff, v_diff);
    __m256i v_mulhi = _mm256_mulhi_epi16(v_diff, v_diff);

    __m256i v_lo = _mm256_unpacklo_epi16(v_mullo, v_mulhi);
    __m256i v_hi = _mm256_unpackhi_epi16(v_mullo, v_mulhi);
    __m256i diff_lo =
        _mm256_inserti128_si256(v_lo, _mm256_extracti128_si256(v_hi, 0), 1);
    __m256i diff_hi =
        _mm256_inserti128_si256(v_hi, _mm256_extracti128_si256(v_lo, 1), 0);

    _mm256_storeu_si256((__m256i *)dst, diff_lo);
    _mm256_storeu_si256((__m256i *)(dst + 8), diff_hi);

    v_src1 = _mm256_loadu_si256((__m256i *)(src1 + 16));
    v_src2 = _mm256_loadu_si256((__m256i *)(src2 + 16));
    v_diff = _mm256_sub_epi16(v_src1, v_src2);
    v_mullo = _mm256_mullo_epi16(v_diff, v_diff);
    v_mulhi = _mm256_mulhi_epi16(v_diff, v_diff);

    v_lo = _mm256_unpacklo_epi16(v_mullo, v_mulhi);
    v_hi = _mm256_unpackhi_epi16(v_mullo, v_mulhi);
    diff_lo =
        _mm256_inserti128_si256(v_lo, _mm256_extracti128_si256(v_hi, 0), 1);
    diff_hi =
        _mm256_inserti128_si256(v_hi, _mm256_extracti128_si256(v_lo, 1), 0);

    _mm256_storeu_si256((__m256i *)(dst + 16), diff_lo);
    _mm256_storeu_si256((__m256i *)(dst + 24), diff_hi);

    src1 += stride;
    src2 += stride2;
    dst += sse_stride;
  }
}

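// Loads 8 squared-error values for the leftmost column group and replicates
// the first in-block value into the two left padding slots.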
static AOM_FORCE_INLINE void xx_load_and_pad_left(uint32_t *src,
                                                  __m256i *v256tmp) {
  *v256tmp = _mm256_loadu_si256((__m256i *)src);
  // For the first column, replicate the first element twice to the left
  __m256i v256tmp1 = _mm256_shuffle_epi32(*v256tmp, 0xEA);
  *v256tmp = _mm256_inserti128_si256(*v256tmp,
                                     _mm256_extracti128_si256(v256tmp1, 0), 0);
}

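// Loads 8 squared-error values for the rightmost column group and replicates
// the last in-block value into the two right padding slots.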
static AOM_FORCE_INLINE void xx_load_and_pad_right(uint32_t *src,
                                                   __m256i *v256tmp) {
  *v256tmp = _mm256_loadu_si256((__m256i *)src);
  // For the last column, replicate the last element twice to the right
  __m256i v256tmp1 = _mm256_shuffle_epi32(*v256tmp, 0x54);
  *v256tmp = _mm256_inserti128_si256(*v256tmp,
                                     _mm256_extracti128_si256(v256tmp1, 1), 1);
}

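// Masks the 5 consecutive lanes starting at offset i (via sse_bytemask[i])
// and returns their horizontal sum. Applied to the vertical sum of 5 rows,
// this yields the full 5x5-window SSE for one output pixel.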
static AOM_FORCE_INLINE int32_t xx_mask_and_hadd(__m256i vsum, int i) {
  // Mask the required 5 values inside the vector
  __m256i vtmp = _mm256_and_si256(vsum, *(__m256i *)sse_bytemask[i]);
  __m128i v128a, v128b;
  // Extract 256b as two 128b registers A and B
  v128a = _mm256_castsi256_si128(vtmp);
  v128b = _mm256_extracti128_si256(vtmp, 1);
  // A = [A0+B0, A1+B1, A2+B2, A3+B3]
  v128a = _mm_add_epi32(v128a, v128b);
  // B = [A2+B2, A3+B3, 0, 0]
  v128b = _mm_srli_si128(v128a, 8);
  // A = [A0+B0+A2+B2, A1+B1+A3+B3, X, X]
  v128a = _mm_add_epi32(v128a, v128b);
  // B = [A1+B1+A3+B3, 0, 0, 0]
  v128b = _mm_srli_si128(v128a, 4);
  // A = [A0+B0+A2+B2+A1+B1+A3+B3, X, X, X]
  v128a = _mm_add_epi32(v128a, v128b);
  return _mm_extract_epi32(v128a, 0);
}

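// Computes the 5x5-window SSE for every pixel of the block (plus the
// co-located luma SSE for chroma planes), converts it into a temporal filter
// weight, and accumulates the weight and weighted pixel value into count[]
// and accumulator[].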
static void highbd_apply_temporal_filter(
    const uint16_t *frame1, const unsigned int stride, const uint16_t *frame2,
    const unsigned int stride2, const int block_width, const int block_height,
    const int *subblock_mses, unsigned int *accumulator, uint16_t *count,
    uint32_t *frame_sse, uint32_t *luma_sse_sum, int bd,
    const double inv_num_ref_pixels, const double decay_factor,
    const double inv_factor, const double weight_factor, double *d_factor) {
  assert(((block_width == 16) || (block_width == 32)) &&
         ((block_height == 16) || (block_height == 32)));

  uint32_t acc_5x5_sse[BH][BW];

  if (block_width == 32) {
    get_squared_error_32x32_avx2(frame1, stride, frame2, stride2, block_width,
                                 block_height, frame_sse, SSE_STRIDE);
  } else {
    get_squared_error_16x16_avx2(frame1, stride, frame2, stride2, block_width,
                                 block_height, frame_sse, SSE_STRIDE);
  }

  __m256i vsrc[5];

  // Traverse 4 columns at a time
  // First and last columns will require padding
  int col;
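  // Leftmost column group: the two columns to the left of the block are
  // padded by replicating the first column.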
  uint32_t *src = frame_sse;
  for (int i = 2; i < 5; i++) {
    xx_load_and_pad_left(src, &vsrc[i]);
    src += SSE_STRIDE;
  }

  // Copy first row to first 2 vectors
  vsrc[0] = vsrc[2];
  vsrc[1] = vsrc[2];

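  // Sliding window over rows: vsrc[0..4] hold 5 consecutive padded rows. The
  // first row is replicated twice upward, which provides the top padding for
  // the first two output rows.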
  for (int row = 0; row < block_height - 3; row++) {
    __m256i vsum1 = _mm256_add_epi32(vsrc[0], vsrc[1]);
    __m256i vsum2 = _mm256_add_epi32(vsrc[2], vsrc[3]);
    __m256i vsum3 = _mm256_add_epi32(vsum1, vsum2);
    __m256i vsum = _mm256_add_epi32(vsum3, vsrc[4]);

    for (int i = 0; i < 4; i++) {
      vsrc[i] = vsrc[i + 1];
    }

    xx_load_and_pad_left(src, &vsrc[4]);
    src += SSE_STRIDE;

    acc_5x5_sse[row][0] = xx_mask_and_hadd(vsum, 0);
    acc_5x5_sse[row][1] = xx_mask_and_hadd(vsum, 1);
    acc_5x5_sse[row][2] = xx_mask_and_hadd(vsum, 2);
    acc_5x5_sse[row][3] = xx_mask_and_hadd(vsum, 3);
  }
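  // Bottom rows: vsrc[4] is not reloaded, so the last row of the block is
  // effectively replicated downward as bottom padding.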
  for (int row = block_height - 3; row < block_height; row++) {
    __m256i vsum1 = _mm256_add_epi32(vsrc[0], vsrc[1]);
    __m256i vsum2 = _mm256_add_epi32(vsrc[2], vsrc[3]);
    __m256i vsum3 = _mm256_add_epi32(vsum1, vsum2);
    __m256i vsum = _mm256_add_epi32(vsum3, vsrc[4]);

    for (int i = 0; i < 4; i++) {
      vsrc[i] = vsrc[i + 1];
    }

    acc_5x5_sse[row][0] = xx_mask_and_hadd(vsum, 0);
    acc_5x5_sse[row][1] = xx_mask_and_hadd(vsum, 1);
    acc_5x5_sse[row][2] = xx_mask_and_hadd(vsum, 2);
    acc_5x5_sse[row][3] = xx_mask_and_hadd(vsum, 3);
  }
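  // Middle column groups: all lanes hold valid block data, so plain unaligned
  // loads are used and no horizontal padding is needed.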
  for (col = 4; col < block_width - 4; col += 4) {
    src = frame_sse + col;

    // Load 3 rows from the top (middle columns need no horizontal padding)
    for (int i = 2; i < 5; i++) {
      vsrc[i] = _mm256_loadu_si256((__m256i *)src);
      src += SSE_STRIDE;
    }

    // Copy first row to first 2 vectors
    vsrc[0] = vsrc[2];
    vsrc[1] = vsrc[2];

    for (int row = 0; row < block_height - 3; row++) {
      __m256i vsum1 = _mm256_add_epi32(vsrc[0], vsrc[1]);
      __m256i vsum2 = _mm256_add_epi32(vsrc[2], vsrc[3]);
      __m256i vsum3 = _mm256_add_epi32(vsum1, vsum2);
      __m256i vsum = _mm256_add_epi32(vsum3, vsrc[4]);

      for (int i = 0; i < 4; i++) {
        vsrc[i] = vsrc[i + 1];
      }

      vsrc[4] = _mm256_loadu_si256((__m256i *)src);

      src += SSE_STRIDE;

      acc_5x5_sse[row][col] = xx_mask_and_hadd(vsum, 0);
      acc_5x5_sse[row][col + 1] = xx_mask_and_hadd(vsum, 1);
      acc_5x5_sse[row][col + 2] = xx_mask_and_hadd(vsum, 2);
      acc_5x5_sse[row][col + 3] = xx_mask_and_hadd(vsum, 3);
    }
    for (int row = block_height - 3; row < block_height; row++) {
      __m256i vsum1 = _mm256_add_epi32(vsrc[0], vsrc[1]);
      __m256i vsum2 = _mm256_add_epi32(vsrc[2], vsrc[3]);
      __m256i vsum3 = _mm256_add_epi32(vsum1, vsum2);
      __m256i vsum = _mm256_add_epi32(vsum3, vsrc[4]);

      for (int i = 0; i < 4; i++) {
        vsrc[i] = vsrc[i + 1];
      }

      acc_5x5_sse[row][col] = xx_mask_and_hadd(vsum, 0);
      acc_5x5_sse[row][col + 1] = xx_mask_and_hadd(vsum, 1);
      acc_5x5_sse[row][col + 2] = xx_mask_and_hadd(vsum, 2);
      acc_5x5_sse[row][col + 3] = xx_mask_and_hadd(vsum, 3);
    }
  }

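  // Rightmost column group: the two columns to the right of the block are
  // padded by replicating the last column.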
  src = frame_sse + col;

  // Load and pad (for the last columns) 3 rows from the top
  for (int i = 2; i < 5; i++) {
    xx_load_and_pad_right(src, &vsrc[i]);
    src += SSE_STRIDE;
  }

  // Copy first row to first 2 vectors
  vsrc[0] = vsrc[2];
  vsrc[1] = vsrc[2];

  for (int row = 0; row < block_height - 3; row++) {
    __m256i vsum1 = _mm256_add_epi32(vsrc[0], vsrc[1]);
    __m256i vsum2 = _mm256_add_epi32(vsrc[2], vsrc[3]);
    __m256i vsum3 = _mm256_add_epi32(vsum1, vsum2);
    __m256i vsum = _mm256_add_epi32(vsum3, vsrc[4]);

    for (int i = 0; i < 4; i++) {
      vsrc[i] = vsrc[i + 1];
    }

    xx_load_and_pad_right(src, &vsrc[4]);
    src += SSE_STRIDE;

    acc_5x5_sse[row][col] = xx_mask_and_hadd(vsum, 0);
    acc_5x5_sse[row][col + 1] = xx_mask_and_hadd(vsum, 1);
    acc_5x5_sse[row][col + 2] = xx_mask_and_hadd(vsum, 2);
    acc_5x5_sse[row][col + 3] = xx_mask_and_hadd(vsum, 3);
  }
  for (int row = block_height - 3; row < block_height; row++) {
    __m256i vsum1 = _mm256_add_epi32(vsrc[0], vsrc[1]);
    __m256i vsum2 = _mm256_add_epi32(vsrc[2], vsrc[3]);
    __m256i vsum3 = _mm256_add_epi32(vsum1, vsum2);
    __m256i vsum = _mm256_add_epi32(vsum3, vsrc[4]);

    for (int i = 0; i < 4; i++) {
      vsrc[i] = vsrc[i + 1];
    }

    acc_5x5_sse[row][col] = xx_mask_and_hadd(vsum, 0);
    acc_5x5_sse[row][col + 1] = xx_mask_and_hadd(vsum, 1);
    acc_5x5_sse[row][col + 2] = xx_mask_and_hadd(vsum, 2);
    acc_5x5_sse[row][col + 3] = xx_mask_and_hadd(vsum, 3);
  }

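  // Convert each 5x5-window SSE into a non-local-means style filter weight
  // and accumulate. For chroma planes, luma_sse_sum carries the co-located
  // luma window SSE computed in the caller.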
  for (int i = 0, k = 0; i < block_height; i++) {
    for (int j = 0; j < block_width; j++, k++) {
      const int pixel_value = frame2[i * stride2 + j];
      uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];

      // Scale down the difference for high bit depth input.
      diff_sse >>= ((bd - 8) * 2);

      const double window_error = diff_sse * inv_num_ref_pixels;
      const int subblock_idx =
          (i >= block_height / 2) * 2 + (j >= block_width / 2);
      const double block_error = (double)subblock_mses[subblock_idx];
      const double combined_error =
          weight_factor * window_error + block_error * inv_factor;

      double scaled_error =
          combined_error * d_factor[subblock_idx] * decay_factor;
      scaled_error = AOMMIN(scaled_error, 7);
      const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);

      count[k] += weight;
      accumulator[k] += weight * pixel_value;
    }
  }
}

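// AVX2 entry point for high bit-depth temporal filtering of one 32x32 block:
// derives the distance and decay factors, accumulates the co-located luma SSE
// for the chroma planes, and applies highbd_apply_temporal_filter() to each
// plane in turn.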
void av1_highbd_apply_temporal_filter_avx2(
    const YV12_BUFFER_CONFIG *frame_to_filter, const MACROBLOCKD *mbd,
    const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
    const int num_planes, const double *noise_levels, const MV *subblock_mvs,
    const int *subblock_mses, const int q_factor, const int filter_strength,
    const uint8_t *pred, uint32_t *accum, uint16_t *count) {
  const int is_high_bitdepth = frame_to_filter->flags & YV12_FLAG_HIGHBITDEPTH;
  assert(block_size == BLOCK_32X32 && "Only support 32x32 block with avx2!");
  assert(TF_WINDOW_LENGTH == 5 && "Only support window length 5 with avx2!");
  assert(num_planes >= 1 && num_planes <= MAX_MB_PLANE);
  (void)is_high_bitdepth;

  const int mb_height = block_size_high[block_size];
  const int mb_width = block_size_wide[block_size];
  const int frame_height = frame_to_filter->y_crop_height;
  const int frame_width = frame_to_filter->y_crop_width;
  const int min_frame_size = AOMMIN(frame_height, frame_width);
  // Variables to simplify combined error calculation.
  const double inv_factor = 1.0 / ((TF_WINDOW_BLOCK_BALANCE_WEIGHT + 1) *
                                   TF_SEARCH_ERROR_NORM_WEIGHT);
  const double weight_factor =
      (double)TF_WINDOW_BLOCK_BALANCE_WEIGHT * inv_factor;
  // Decay factors for non-local mean approach.
  // Smaller q -> smaller filtering weight.
  double q_decay = pow((double)q_factor / TF_Q_DECAY_THRESHOLD, 2);
  q_decay = CLIP(q_decay, 1e-5, 1);
  // Smaller strength -> smaller filtering weight.
  double s_decay = pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2);
  s_decay = CLIP(s_decay, 1e-5, 1);
  double d_factor[4] = { 0 };
  uint32_t frame_sse[SSE_STRIDE * BH] = { 0 };
  uint32_t luma_sse_sum[BW * BH] = { 0 };
  uint16_t *pred1 = CONVERT_TO_SHORTPTR(pred);

  for (int subblock_idx = 0; subblock_idx < 4; subblock_idx++) {
    // Larger motion vector -> smaller filtering weight.
    const MV mv = subblock_mvs[subblock_idx];
    const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
    double distance_threshold = min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD;
    distance_threshold = AOMMAX(distance_threshold, 1);
    d_factor[subblock_idx] = distance / distance_threshold;
    d_factor[subblock_idx] = AOMMAX(d_factor[subblock_idx], 1);
  }

  // Handle planes in sequence.
  int plane_offset = 0;
  for (int plane = 0; plane < num_planes; ++plane) {
    const uint32_t plane_h = mb_height >> mbd->plane[plane].subsampling_y;
    const uint32_t plane_w = mb_width >> mbd->plane[plane].subsampling_x;
    const uint32_t frame_stride = frame_to_filter->strides[plane == 0 ? 0 : 1];
    const int frame_offset = mb_row * plane_h * frame_stride + mb_col * plane_w;

    const uint16_t *ref =
        CONVERT_TO_SHORTPTR(frame_to_filter->buffers[plane]) + frame_offset;
    const int ss_x_shift =
        mbd->plane[plane].subsampling_x - mbd->plane[AOM_PLANE_Y].subsampling_x;
    const int ss_y_shift =
        mbd->plane[plane].subsampling_y - mbd->plane[AOM_PLANE_Y].subsampling_y;
    const int num_ref_pixels = TF_WINDOW_LENGTH * TF_WINDOW_LENGTH +
                               ((plane) ? (1 << (ss_x_shift + ss_y_shift)) : 0);
    const double inv_num_ref_pixels = 1.0 / num_ref_pixels;
    // Larger noise -> larger filtering weight.
    const double n_decay = 0.5 + log(2 * noise_levels[plane] + 5.0);
    const double decay_factor = 1 / (n_decay * q_decay * s_decay);

    // Filter U-plane and V-plane using Y-plane. This is because motion
    // search is only done on Y-plane, so the information from Y-plane
    // will be more accurate. The luma sse sum is reused in both chroma
    // planes.
    if (plane == AOM_PLANE_U) {
      for (unsigned int i = 0, k = 0; i < plane_h; i++) {
        for (unsigned int j = 0; j < plane_w; j++, k++) {
          for (int ii = 0; ii < (1 << ss_y_shift); ++ii) {
            for (int jj = 0; jj < (1 << ss_x_shift); ++jj) {
              const int yy = (i << ss_y_shift) + ii;  // Y-coord on Y-plane.
              const int xx = (j << ss_x_shift) + jj;  // X-coord on Y-plane.
              luma_sse_sum[i * BW + j] += frame_sse[yy * SSE_STRIDE + xx + 2];
            }
          }
        }
      }
    }

    highbd_apply_temporal_filter(
        ref, frame_stride, pred1 + plane_offset, plane_w, plane_w, plane_h,
        subblock_mses, accum + plane_offset, count + plane_offset, frame_sse,
        luma_sse_sum, mbd->bd, inv_num_ref_pixels, decay_factor, inv_factor,
        weight_factor, d_factor);
    plane_offset += plane_h * plane_w;
  }
}