/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <immintrin.h>

#include "config/av1_rtcd.h"
#include "av1/encoder/encoder.h"
#include "av1/encoder/temporal_filter.h"

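// Each row of the temporary SSE buffer carries 2 extra columns on each side
// of the BW data columns, so the 5-wide horizontal window can read "past" the
// block edges without going out of bounds.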
#define SSE_STRIDE (BW + 4)

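// Row i of sse_bytemask keeps 5 consecutive 32-bit lanes starting at dword
// position i (i = 0..3). xx_mask_and_hadd() uses it to isolate the 5-wide
// horizontal window belonging to one of the 4 output columns being processed.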
DECLARE_ALIGNED(32, static const uint32_t, sse_bytemask[4][8]) = {
  { 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0, 0 },
  { 0, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0, 0 },
  { 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0 },
  { 0, 0, 0, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF }
};

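// Per-row squared error for a 16-pixel-wide highbd block. The 16-bit
// differences are squared with a mullo/mulhi pair; unpacklo/unpackhi then
// interleave them into 32-bit squares, but only within each 128-bit lane:
//   v_lo = [ sq0  sq1  sq2  sq3  | sq8  sq9  sq10 sq11 ]
//   v_hi = [ sq4  sq5  sq6  sq7  | sq12 sq13 sq14 sq15 ]
// The extract/insert pair below restores natural order, giving
//   diff_lo = [ sq0 .. sq7 ]  and  diff_hi = [ sq8 .. sq15 ].
// Results are written at a 2-dword offset into frame_sse, leaving 2 columns
// of left padding for the 5x5 window code.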
static AOM_FORCE_INLINE void get_squared_error_16x16_avx2(
    const uint16_t *frame1, const unsigned int stride, const uint16_t *frame2,
    const unsigned int stride2, const int block_width, const int block_height,
    uint32_t *frame_sse, const unsigned int sse_stride) {
  (void)block_width;
  const uint16_t *src1 = frame1;
  const uint16_t *src2 = frame2;
  uint32_t *dst = frame_sse + 2;
  for (int i = 0; i < block_height; i++) {
    __m256i v_src1 = _mm256_loadu_si256((__m256i *)src1);
    __m256i v_src2 = _mm256_loadu_si256((__m256i *)src2);
    __m256i v_diff = _mm256_sub_epi16(v_src1, v_src2);
    __m256i v_mullo = _mm256_mullo_epi16(v_diff, v_diff);
    __m256i v_mulhi = _mm256_mulhi_epi16(v_diff, v_diff);

    __m256i v_lo = _mm256_unpacklo_epi16(v_mullo, v_mulhi);
    __m256i v_hi = _mm256_unpackhi_epi16(v_mullo, v_mulhi);
    __m256i diff_lo =
        _mm256_inserti128_si256(v_lo, _mm256_extracti128_si256(v_hi, 0), 1);
    __m256i diff_hi =
        _mm256_inserti128_si256(v_hi, _mm256_extracti128_si256(v_lo, 1), 0);

    _mm256_storeu_si256((__m256i *)dst, diff_lo);
    dst += 8;
    _mm256_storeu_si256((__m256i *)dst, diff_hi);

    src1 += stride, src2 += stride2;
    dst += sse_stride - 8;
  }
}

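// 32-pixel-wide variant of the above: each row is handled as two 16-wide
// vector groups with the same square-and-reorder sequence.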
static AOM_FORCE_INLINE void get_squared_error_32x32_avx2(
    const uint16_t *frame1, const unsigned int stride, const uint16_t *frame2,
    const unsigned int stride2, const int block_width, const int block_height,
    uint32_t *frame_sse, const unsigned int sse_stride) {
  (void)block_width;
  const uint16_t *src1 = frame1;
  const uint16_t *src2 = frame2;
  uint32_t *dst = frame_sse + 2;
  for (int i = 0; i < block_height; i++) {
    __m256i v_src1 = _mm256_loadu_si256((__m256i *)src1);
    __m256i v_src2 = _mm256_loadu_si256((__m256i *)src2);
    __m256i v_diff = _mm256_sub_epi16(v_src1, v_src2);
    __m256i v_mullo = _mm256_mullo_epi16(v_diff, v_diff);
    __m256i v_mulhi = _mm256_mulhi_epi16(v_diff, v_diff);

    __m256i v_lo = _mm256_unpacklo_epi16(v_mullo, v_mulhi);
    __m256i v_hi = _mm256_unpackhi_epi16(v_mullo, v_mulhi);
    __m256i diff_lo =
        _mm256_inserti128_si256(v_lo, _mm256_extracti128_si256(v_hi, 0), 1);
    __m256i diff_hi =
        _mm256_inserti128_si256(v_hi, _mm256_extracti128_si256(v_lo, 1), 0);

    _mm256_storeu_si256((__m256i *)dst, diff_lo);
    _mm256_storeu_si256((__m256i *)(dst + 8), diff_hi);

    v_src1 = _mm256_loadu_si256((__m256i *)(src1 + 16));
    v_src2 = _mm256_loadu_si256((__m256i *)(src2 + 16));
    v_diff = _mm256_sub_epi16(v_src1, v_src2);
    v_mullo = _mm256_mullo_epi16(v_diff, v_diff);
    v_mulhi = _mm256_mulhi_epi16(v_diff, v_diff);

    v_lo = _mm256_unpacklo_epi16(v_mullo, v_mulhi);
    v_hi = _mm256_unpackhi_epi16(v_mullo, v_mulhi);
    diff_lo =
        _mm256_inserti128_si256(v_lo, _mm256_extracti128_si256(v_hi, 0), 1);
    diff_hi =
        _mm256_inserti128_si256(v_hi, _mm256_extracti128_si256(v_lo, 1), 0);

    _mm256_storeu_si256((__m256i *)(dst + 16), diff_lo);
    _mm256_storeu_si256((__m256i *)(dst + 24), diff_hi);

    src1 += stride;
    src2 += stride2;
    dst += sse_stride;
  }
}

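// Load one row of frame_sse and synthesize the 2 padding columns required by
// the 5x5 window: the edge SSE value is replicated into the adjacent padding
// slots with an in-lane 32-bit shuffle followed by a 128-bit lane insert.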
static AOM_FORCE_INLINE void xx_load_and_pad_left(uint32_t *src,
                                                  __m256i *v256tmp) {
  *v256tmp = _mm256_loadu_si256((__m256i *)src);
  // For the first column, replicate the first element twice to the left
  __m256i v256tmp1 = _mm256_shuffle_epi32(*v256tmp, 0xEA);
  *v256tmp = _mm256_inserti128_si256(*v256tmp,
                                     _mm256_extracti128_si256(v256tmp1, 0), 0);
}

static AOM_FORCE_INLINE void xx_load_and_pad_right(uint32_t *src,
                                                   __m256i *v256tmp) {
  *v256tmp = _mm256_loadu_si256((__m256i *)src);
  // For the last column, replicate the last element twice to the right
  __m256i v256tmp1 = _mm256_shuffle_epi32(*v256tmp, 0x54);
  *v256tmp = _mm256_inserti128_si256(*v256tmp,
                                     _mm256_extracti128_si256(v256tmp1, 1), 1);
}

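// Horizontal reduction for one output column: keep the 5 consecutive 32-bit
// SSE values starting at dword position i (see sse_bytemask) and add them up.
// E.g. i = 1 keeps dwords 1..5, the 5-wide window of the second of the 4
// columns currently in flight.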
static AOM_FORCE_INLINE int32_t xx_mask_and_hadd(__m256i vsum, int i) {
  // Mask the required 5 values inside the vector
  __m256i vtmp = _mm256_and_si256(vsum, *(__m256i *)sse_bytemask[i]);
  __m128i v128a, v128b;
  // Extract 256b as two 128b registers A and B
  v128a = _mm256_castsi256_si128(vtmp);
  v128b = _mm256_extracti128_si256(vtmp, 1);
  // A = [A0+B0, A1+B1, A2+B2, A3+B3]
  v128a = _mm_add_epi32(v128a, v128b);
  // B = [A2+B2, A3+B3, 0, 0]
  v128b = _mm_srli_si128(v128a, 8);
  // A = [A0+B0+A2+B2, A1+B1+A3+B3, X, X]
  v128a = _mm_add_epi32(v128a, v128b);
  // B = [A1+B1+A3+B3, 0, 0, 0]
  v128b = _mm_srli_si128(v128a, 4);
  // A = [A0+B0+A2+B2+A1+B1+A3+B3, X, X, X]
  v128a = _mm_add_epi32(v128a, v128b);
  return _mm_extract_epi32(v128a, 0);
}

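// Core highbd filter for one plane. A 5x5 windowed SSE is accumulated into
// acc_5x5_sse for every pixel using a rolling buffer of 5 row vectors
// (vsrc[0..4]); the top and bottom 2 window rows reuse the edge row, matching
// the left/right column replication done by xx_load_and_pad_*. Columns are
// processed 4 at a time: one left-padded pass, interior passes, then one
// right-padded pass. The window error is then combined with the subblock MSE
// and the precomputed decay factors to produce the per-pixel filter weight.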
static void highbd_apply_temporal_filter(
    const uint16_t *frame1, const unsigned int stride, const uint16_t *frame2,
    const unsigned int stride2, const int block_width, const int block_height,
    const int *subblock_mses, unsigned int *accumulator, uint16_t *count,
    uint32_t *frame_sse, uint32_t *luma_sse_sum, int bd,
    const double inv_num_ref_pixels, const double decay_factor,
    const double inv_factor, const double weight_factor, double *d_factor) {
  assert(((block_width == 16) || (block_width == 32)) &&
         ((block_height == 16) || (block_height == 32)));

  uint32_t acc_5x5_sse[BH][BW];

  if (block_width == 32) {
    get_squared_error_32x32_avx2(frame1, stride, frame2, stride2, block_width,
                                 block_height, frame_sse, SSE_STRIDE);
  } else {
    get_squared_error_16x16_avx2(frame1, stride, frame2, stride2, block_width,
                                 block_height, frame_sse, SSE_STRIDE);
  }

  __m256i vsrc[5];

  // Traverse 4 columns at a time
  // First and last columns will require padding
  int col;
  uint32_t *src = frame_sse;
  for (int i = 2; i < 5; i++) {
    xx_load_and_pad_left(src, &vsrc[i]);
    src += SSE_STRIDE;
  }

  // Copy first row to first 2 vectors
  vsrc[0] = vsrc[2];
  vsrc[1] = vsrc[2];

  for (int row = 0; row < block_height - 3; row++) {
    __m256i vsum1 = _mm256_add_epi32(vsrc[0], vsrc[1]);
    __m256i vsum2 = _mm256_add_epi32(vsrc[2], vsrc[3]);
    __m256i vsum3 = _mm256_add_epi32(vsum1, vsum2);
    __m256i vsum = _mm256_add_epi32(vsum3, vsrc[4]);

    for (int i = 0; i < 4; i++) {
      vsrc[i] = vsrc[i + 1];
    }

    xx_load_and_pad_left(src, &vsrc[4]);
    src += SSE_STRIDE;

    acc_5x5_sse[row][0] = xx_mask_and_hadd(vsum, 0);
    acc_5x5_sse[row][1] = xx_mask_and_hadd(vsum, 1);
    acc_5x5_sse[row][2] = xx_mask_and_hadd(vsum, 2);
    acc_5x5_sse[row][3] = xx_mask_and_hadd(vsum, 3);
  }
  for (int row = block_height - 3; row < block_height; row++) {
    __m256i vsum1 = _mm256_add_epi32(vsrc[0], vsrc[1]);
    __m256i vsum2 = _mm256_add_epi32(vsrc[2], vsrc[3]);
    __m256i vsum3 = _mm256_add_epi32(vsum1, vsum2);
    __m256i vsum = _mm256_add_epi32(vsum3, vsrc[4]);

    for (int i = 0; i < 4; i++) {
      vsrc[i] = vsrc[i + 1];
    }

    acc_5x5_sse[row][0] = xx_mask_and_hadd(vsum, 0);
    acc_5x5_sse[row][1] = xx_mask_and_hadd(vsum, 1);
    acc_5x5_sse[row][2] = xx_mask_and_hadd(vsum, 2);
    acc_5x5_sse[row][3] = xx_mask_and_hadd(vsum, 3);
  }
  for (col = 4; col < block_width - 4; col += 4) {
    src = frame_sse + col;

    // Load 3 rows from the top (interior columns need no edge padding)
    for (int i = 2; i < 5; i++) {
      vsrc[i] = _mm256_loadu_si256((__m256i *)src);
      src += SSE_STRIDE;
    }

    // Copy first row to first 2 vectors
    vsrc[0] = vsrc[2];
    vsrc[1] = vsrc[2];

    for (int row = 0; row < block_height - 3; row++) {
      __m256i vsum1 = _mm256_add_epi32(vsrc[0], vsrc[1]);
      __m256i vsum2 = _mm256_add_epi32(vsrc[2], vsrc[3]);
      __m256i vsum3 = _mm256_add_epi32(vsum1, vsum2);
      __m256i vsum = _mm256_add_epi32(vsum3, vsrc[4]);

      for (int i = 0; i < 4; i++) {
        vsrc[i] = vsrc[i + 1];
      }

      vsrc[4] = _mm256_loadu_si256((__m256i *)src);

      src += SSE_STRIDE;

      acc_5x5_sse[row][col] = xx_mask_and_hadd(vsum, 0);
      acc_5x5_sse[row][col + 1] = xx_mask_and_hadd(vsum, 1);
      acc_5x5_sse[row][col + 2] = xx_mask_and_hadd(vsum, 2);
      acc_5x5_sse[row][col + 3] = xx_mask_and_hadd(vsum, 3);
    }
    for (int row = block_height - 3; row < block_height; row++) {
      __m256i vsum1 = _mm256_add_epi32(vsrc[0], vsrc[1]);
      __m256i vsum2 = _mm256_add_epi32(vsrc[2], vsrc[3]);
      __m256i vsum3 = _mm256_add_epi32(vsum1, vsum2);
      __m256i vsum = _mm256_add_epi32(vsum3, vsrc[4]);

      for (int i = 0; i < 4; i++) {
        vsrc[i] = vsrc[i + 1];
      }

      acc_5x5_sse[row][col] = xx_mask_and_hadd(vsum, 0);
      acc_5x5_sse[row][col + 1] = xx_mask_and_hadd(vsum, 1);
      acc_5x5_sse[row][col + 2] = xx_mask_and_hadd(vsum, 2);
      acc_5x5_sse[row][col + 3] = xx_mask_and_hadd(vsum, 3);
    }
  }

  src = frame_sse + col;

  // Load and pad (for the last columns) 3 rows from the top
  for (int i = 2; i < 5; i++) {
    xx_load_and_pad_right(src, &vsrc[i]);
    src += SSE_STRIDE;
  }

  // Copy first row to first 2 vectors
  vsrc[0] = vsrc[2];
  vsrc[1] = vsrc[2];

  for (int row = 0; row < block_height - 3; row++) {
    __m256i vsum1 = _mm256_add_epi32(vsrc[0], vsrc[1]);
    __m256i vsum2 = _mm256_add_epi32(vsrc[2], vsrc[3]);
    __m256i vsum3 = _mm256_add_epi32(vsum1, vsum2);
    __m256i vsum = _mm256_add_epi32(vsum3, vsrc[4]);

    for (int i = 0; i < 4; i++) {
      vsrc[i] = vsrc[i + 1];
    }

    xx_load_and_pad_right(src, &vsrc[4]);
    src += SSE_STRIDE;

    acc_5x5_sse[row][col] = xx_mask_and_hadd(vsum, 0);
    acc_5x5_sse[row][col + 1] = xx_mask_and_hadd(vsum, 1);
    acc_5x5_sse[row][col + 2] = xx_mask_and_hadd(vsum, 2);
    acc_5x5_sse[row][col + 3] = xx_mask_and_hadd(vsum, 3);
  }
  for (int row = block_height - 3; row < block_height; row++) {
    __m256i vsum1 = _mm256_add_epi32(vsrc[0], vsrc[1]);
    __m256i vsum2 = _mm256_add_epi32(vsrc[2], vsrc[3]);
    __m256i vsum3 = _mm256_add_epi32(vsum1, vsum2);
    __m256i vsum = _mm256_add_epi32(vsum3, vsrc[4]);

    for (int i = 0; i < 4; i++) {
      vsrc[i] = vsrc[i + 1];
    }

    acc_5x5_sse[row][col] = xx_mask_and_hadd(vsum, 0);
    acc_5x5_sse[row][col + 1] = xx_mask_and_hadd(vsum, 1);
    acc_5x5_sse[row][col + 2] = xx_mask_and_hadd(vsum, 2);
    acc_5x5_sse[row][col + 3] = xx_mask_and_hadd(vsum, 3);
  }

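  // Convert each pixel's combined window + subblock error into a
  // non-local-means style weight, exp(-scaled_error) * TF_WEIGHT_SCALE (with
  // scaled_error capped at 7), and accumulate both the weight and the
  // weighted pixel value so the caller can normalize the filtered output.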
  for (int i = 0, k = 0; i < block_height; i++) {
    for (int j = 0; j < block_width; j++, k++) {
      const int pixel_value = frame2[i * stride2 + j];
      uint32_t diff_sse = acc_5x5_sse[i][j] + luma_sse_sum[i * BW + j];

      // Scale down the difference for high bit depth input.
      diff_sse >>= ((bd - 8) * 2);

      const double window_error = diff_sse * inv_num_ref_pixels;
      const int subblock_idx =
          (i >= block_height / 2) * 2 + (j >= block_width / 2);
      const double block_error = (double)subblock_mses[subblock_idx];
      const double combined_error =
          weight_factor * window_error + block_error * inv_factor;

      double scaled_error =
          combined_error * d_factor[subblock_idx] * decay_factor;
      scaled_error = AOMMIN(scaled_error, 7);
      const int weight = (int)(exp(-scaled_error) * TF_WEIGHT_SCALE);

      count[k] += weight;
      accumulator[k] += weight * pixel_value;
    }
  }
}

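// AVX2 entry point (dispatched through av1_rtcd) that filters one 32x32
// highbd block across all planes. Per-subblock distance factors and the
// q/strength decay terms are computed once; each plane's squared errors are
// written into frame_sse, and for the chroma planes the co-located luma SSE
// is folded in through luma_sse_sum before calling the core filter above.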
void av1_highbd_apply_temporal_filter_avx2(
    const YV12_BUFFER_CONFIG *frame_to_filter, const MACROBLOCKD *mbd,
    const BLOCK_SIZE block_size, const int mb_row, const int mb_col,
    const int num_planes, const double *noise_levels, const MV *subblock_mvs,
    const int *subblock_mses, const int q_factor, const int filter_strength,
    const uint8_t *pred, uint32_t *accum, uint16_t *count) {
  const int is_high_bitdepth = frame_to_filter->flags & YV12_FLAG_HIGHBITDEPTH;
  assert(block_size == BLOCK_32X32 && "Only support 32x32 block with avx2!");
  assert(TF_WINDOW_LENGTH == 5 && "Only support window length 5 with avx2!");
  assert(num_planes >= 1 && num_planes <= MAX_MB_PLANE);
  (void)is_high_bitdepth;

  const int mb_height = block_size_high[block_size];
  const int mb_width = block_size_wide[block_size];
  const int frame_height = frame_to_filter->y_crop_height;
  const int frame_width = frame_to_filter->y_crop_width;
  const int min_frame_size = AOMMIN(frame_height, frame_width);
  // Variables to simplify combined error calculation.
  const double inv_factor = 1.0 / ((TF_WINDOW_BLOCK_BALANCE_WEIGHT + 1) *
                                   TF_SEARCH_ERROR_NORM_WEIGHT);
  const double weight_factor =
      (double)TF_WINDOW_BLOCK_BALANCE_WEIGHT * inv_factor;
  // Decay factors for non-local mean approach.
  // Smaller q -> smaller filtering weight.
  double q_decay = pow((double)q_factor / TF_Q_DECAY_THRESHOLD, 2);
  q_decay = CLIP(q_decay, 1e-5, 1);
  // Smaller strength -> smaller filtering weight.
  double s_decay = pow((double)filter_strength / TF_STRENGTH_THRESHOLD, 2);
  s_decay = CLIP(s_decay, 1e-5, 1);
  double d_factor[4] = { 0 };
  uint32_t frame_sse[SSE_STRIDE * BH] = { 0 };
  uint32_t luma_sse_sum[BW * BH] = { 0 };
  uint16_t *pred1 = CONVERT_TO_SHORTPTR(pred);

  for (int subblock_idx = 0; subblock_idx < 4; subblock_idx++) {
    // Larger motion vector -> smaller filtering weight.
    const MV mv = subblock_mvs[subblock_idx];
    const double distance = sqrt(pow(mv.row, 2) + pow(mv.col, 2));
    double distance_threshold = min_frame_size * TF_SEARCH_DISTANCE_THRESHOLD;
    distance_threshold = AOMMAX(distance_threshold, 1);
    d_factor[subblock_idx] = distance / distance_threshold;
    d_factor[subblock_idx] = AOMMAX(d_factor[subblock_idx], 1);
  }

  // Handle planes in sequence.
  int plane_offset = 0;
  for (int plane = 0; plane < num_planes; ++plane) {
    const uint32_t plane_h = mb_height >> mbd->plane[plane].subsampling_y;
    const uint32_t plane_w = mb_width >> mbd->plane[plane].subsampling_x;
    const uint32_t frame_stride = frame_to_filter->strides[plane == 0 ? 0 : 1];
    const int frame_offset = mb_row * plane_h * frame_stride + mb_col * plane_w;

    const uint16_t *ref =
        CONVERT_TO_SHORTPTR(frame_to_filter->buffers[plane]) + frame_offset;
    const int ss_x_shift =
        mbd->plane[plane].subsampling_x - mbd->plane[AOM_PLANE_Y].subsampling_x;
    const int ss_y_shift =
        mbd->plane[plane].subsampling_y - mbd->plane[AOM_PLANE_Y].subsampling_y;
    const int num_ref_pixels = TF_WINDOW_LENGTH * TF_WINDOW_LENGTH +
                               ((plane) ? (1 << (ss_x_shift + ss_y_shift)) : 0);
    const double inv_num_ref_pixels = 1.0 / num_ref_pixels;
    // Larger noise -> larger filtering weight.
    const double n_decay = 0.5 + log(2 * noise_levels[plane] + 5.0);
    const double decay_factor = 1 / (n_decay * q_decay * s_decay);

    // Filter U-plane and V-plane using Y-plane. This is because motion
    // search is only done on Y-plane, so the information from Y-plane
    // will be more accurate. The luma sse sum is reused in both chroma
    // planes.
    if (plane == AOM_PLANE_U) {
      for (unsigned int i = 0, k = 0; i < plane_h; i++) {
        for (unsigned int j = 0; j < plane_w; j++, k++) {
          for (int ii = 0; ii < (1 << ss_y_shift); ++ii) {
            for (int jj = 0; jj < (1 << ss_x_shift); ++jj) {
              const int yy = (i << ss_y_shift) + ii;  // Y-coord on Y-plane.
              const int xx = (j << ss_x_shift) + jj;  // X-coord on Y-plane.
              luma_sse_sum[i * BW + j] += frame_sse[yy * SSE_STRIDE + xx + 2];
            }
          }
        }
      }
    }

    highbd_apply_temporal_filter(
        ref, frame_stride, pred1 + plane_offset, plane_w, plane_w, plane_h,
        subblock_mses, accum + plane_offset, count + plane_offset, frame_sse,
        luma_sse_sum, mbd->bd, inv_num_ref_pixels, decay_factor, inv_factor,
        weight_factor, d_factor);
    plane_offset += plane_h * plane_w;
  }
}