1 /*****************************************************************************
2  * This file is part of Kvazaar HEVC encoder.
3  *
4  * Copyright (c) 2021, Tampere University, ITU/ISO/IEC, project contributors
5  * All rights reserved.
6  *
7  * Redistribution and use in source and binary forms, with or without modification,
8  * are permitted provided that the following conditions are met:
9  *
10  * * Redistributions of source code must retain the above copyright notice, this
11  *   list of conditions and the following disclaimer.
12  *
13  * * Redistributions in binary form must reproduce the above copyright notice, this
14  *   list of conditions and the following disclaimer in the documentation and/or
15  *   other materials provided with the distribution.
16  *
17  * * Neither the name of the Tampere University or ITU/ISO/IEC nor the names of its
18  *   contributors may be used to endorse or promote products derived from
19  *   this software without specific prior written permission.
20  *
21  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
22  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
23  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
24  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
25  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
26  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
27  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
28  * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
29  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
30  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
31  ****************************************************************************/
32 
33 /**
34  * \file
 * \brief AVX2 optimized pixel comparison functions (SAD, SATD, SSD, pixel variance).
35  */
36 
37 #include "global.h"
38 
39 #if COMPILE_INTEL_AVX2
40 #include "kvazaar.h"
41 #if KVZ_BIT_DEPTH == 8
42 #include "strategies/avx2/picture-avx2.h"
43 #include "strategies/avx2/reg_sad_pow2_widths-avx2.h"
44 
45 #include <immintrin.h>
46 #include <emmintrin.h>
47 #include <mmintrin.h>
48 #include <xmmintrin.h>
49 #include <string.h>
50 #include "strategies/strategies-picture.h"
51 #include "strategyselector.h"
52 #include "strategies/generic/picture-generic.h"
53 
54 /**
55  * \brief Calculate Sum of Absolute Differences (SAD)
56  *
57  * Calculate Sum of Absolute Differences (SAD) between two rectangular regions
58  * located in arbitrary points in the picture.
59  *
60  * \param data1   Starting point of the first picture.
61  * \param data2   Starting point of the second picture.
62  * \param width   Width of the region for which SAD is calculated.
63  * \param height  Height of the region for which SAD is calculated.
64  * \param stride1  Stride of the first pixel array.
 * \param stride2  Stride of the second pixel array.
65  *
66  * \returns Sum of Absolute Differences
67  */
68 uint32_t kvz_reg_sad_avx2(const uint8_t * const data1, const uint8_t * const data2,
69                           const int width, const int height, const unsigned stride1, const unsigned stride2)
70 {
71   if (width == 0)
72     return 0;
73   if (width == 4)
74     return reg_sad_w4(data1, data2, height, stride1, stride2);
75   if (width == 8)
76     return reg_sad_w8(data1, data2, height, stride1, stride2);
77   if (width == 12)
78     return reg_sad_w12(data1, data2, height, stride1, stride2);
79   if (width == 16)
80     return reg_sad_w16(data1, data2, height, stride1, stride2);
81   if (width == 24)
82     return reg_sad_w24(data1, data2, height, stride1, stride2);
83   if (width == 32)
84     return reg_sad_w32(data1, data2, height, stride1, stride2);
85   if (width == 64)
86     return reg_sad_w64(data1, data2, height, stride1, stride2);
87   else
88     return reg_sad_arbitrary(data1, data2, width, height, stride1, stride2);
89 }
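
/* A minimal usage sketch (illustration only, not called by the encoder):
 * SAD of a 16x16 block at offset (x, y) between two frames that share the
 * same stride. frame_a, frame_b, frame_stride, x and y are hypothetical names.
 *
 *   uint32_t sad = kvz_reg_sad_avx2(frame_a + y * frame_stride + x,
 *                                   frame_b + y * frame_stride + x,
 *                                   16, 16, frame_stride, frame_stride);
 */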
90 
91 /**
92 * \brief Calculate SAD for an 8x8 block of bytes in contiguous memory.
93 */
94 static INLINE __m256i inline_8bit_sad_8x8_avx2(const __m256i *const a, const __m256i *const b)
95 {
96   __m256i sum0, sum1;
97   sum0 = _mm256_sad_epu8(_mm256_load_si256(a + 0), _mm256_load_si256(b + 0));
98   sum1 = _mm256_sad_epu8(_mm256_load_si256(a + 1), _mm256_load_si256(b + 1));
99 
100   return _mm256_add_epi32(sum0, sum1);
101 }
102 
103 
104 /**
105 * \brief Calculate SAD for a 16x16 block of bytes in contiguous memory.
106 */
107 static INLINE __m256i inline_8bit_sad_16x16_avx2(const __m256i *const a, const __m256i *const b)
108 {
109   const unsigned size_of_8x8 = 8 * 8 / sizeof(__m256i);
110 
111   // Calculate in 4 chunks of 16x4.
112   __m256i sum0, sum1, sum2, sum3;
113   sum0 = inline_8bit_sad_8x8_avx2(a + 0 * size_of_8x8, b + 0 * size_of_8x8);
114   sum1 = inline_8bit_sad_8x8_avx2(a + 1 * size_of_8x8, b + 1 * size_of_8x8);
115   sum2 = inline_8bit_sad_8x8_avx2(a + 2 * size_of_8x8, b + 2 * size_of_8x8);
116   sum3 = inline_8bit_sad_8x8_avx2(a + 3 * size_of_8x8, b + 3 * size_of_8x8);
117 
118   sum0 = _mm256_add_epi32(sum0, sum1);
119   sum2 = _mm256_add_epi32(sum2, sum3);
120 
121   return _mm256_add_epi32(sum0, sum2);
122 }
123 
124 
125 /**
126 * \brief Get sum of the low 32 bits of four 64 bit numbers from __m256i as uint32_t.
127 */
128 static INLINE uint32_t m256i_horizontal_sum(const __m256i sum)
129 {
130   // Add the high 128 bits to low 128 bits.
131   __m128i mm128_result = _mm_add_epi32(_mm256_castsi256_si128(sum), _mm256_extractf128_si256(sum, 1));
132   // Add the high 64 bits  to low 64 bits.
133   uint32_t result[4];
134   _mm_storeu_si128((__m128i*)result, mm128_result);
135   return result[0] + result[2];
136 }
137 
138 
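/**
 * \brief SAD of two contiguous 8x8 blocks of 8-bit pixels. The buffers must be
 *        32-byte aligned because aligned 256-bit loads are used.
 */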
139 static unsigned sad_8bit_8x8_avx2(const uint8_t *buf1, const uint8_t *buf2)
140 {
141   const __m256i *const a = (const __m256i *)buf1;
142   const __m256i *const b = (const __m256i *)buf2;
143   __m256i sum = inline_8bit_sad_8x8_avx2(a, b);
144 
145   return m256i_horizontal_sum(sum);
146 }
147 
148 
149 static unsigned sad_8bit_16x16_avx2(const uint8_t *buf1, const uint8_t *buf2)
150 {
151   const __m256i *const a = (const __m256i *)buf1;
152   const __m256i *const b = (const __m256i *)buf2;
153   __m256i sum = inline_8bit_sad_16x16_avx2(a, b);
154 
155   return m256i_horizontal_sum(sum);
156 }
157 
158 
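/**
 * \brief SAD of two contiguous 32x32 blocks of 8-bit pixels, accumulated
 *        64 bytes at a time with inline_8bit_sad_8x8_avx2.
 */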
159 static unsigned sad_8bit_32x32_avx2(const uint8_t *buf1, const uint8_t *buf2)
160 {
161   const __m256i *const a = (const __m256i *)buf1;
162   const __m256i *const b = (const __m256i *)buf2;
163 
164   const unsigned size_of_8x8 = 8 * 8 / sizeof(__m256i);
165   const unsigned size_of_32x32 = 32 * 32 / sizeof(__m256i);
166 
167   // Looping 512 bytes at a time seems faster than letting VC figure it out
168   // through inlining, like inline_8bit_sad_16x16_avx2 does.
169   __m256i sum0 = inline_8bit_sad_8x8_avx2(a, b);
170   for (unsigned i = size_of_8x8; i < size_of_32x32; i += size_of_8x8) {
171     __m256i sum1 = inline_8bit_sad_8x8_avx2(a + i, b + i);
172     sum0 = _mm256_add_epi32(sum0, sum1);
173   }
174 
175   return m256i_horizontal_sum(sum0);
176 }
177 
178 
179 static unsigned sad_8bit_64x64_avx2(const uint8_t * buf1, const uint8_t * buf2)
180 {
181   const __m256i *const a = (const __m256i *)buf1;
182   const __m256i *const b = (const __m256i *)buf2;
183 
184   const unsigned size_of_8x8 = 8 * 8 / sizeof(__m256i);
185   const unsigned size_of_64x64 = 64 * 64 / sizeof(__m256i);
186 
187   // Looping 512 bytes at a time seems faster than letting VC figure it out
188   // through inlining, like inline_8bit_sad_16x16_avx2 does.
189   __m256i sum0 = inline_8bit_sad_8x8_avx2(a, b);
190   for (unsigned i = size_of_8x8; i < size_of_64x64; i += size_of_8x8) {
191     __m256i sum1 = inline_8bit_sad_8x8_avx2(a + i, b + i);
192     sum0 = _mm256_add_epi32(sum0, sum1);
193   }
194 
195   return m256i_horizontal_sum(sum0);
196 }
197 
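/**
 * \brief SATD of a 4x4 block: the pixel differences are Hadamard-transformed
 *        horizontally and vertically, and the absolute transform coefficients
 *        are summed and rounded with (sum + 1) >> 1.
 */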
198 static unsigned satd_4x4_8bit_avx2(const uint8_t *org, const uint8_t *cur)
199 {
200 
201   __m128i original = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)org));
202   __m128i current = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)cur));
203 
204   __m128i diff_lo = _mm_sub_epi16(current, original);
205 
206   original = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(org + 8)));
207   current = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(cur + 8)));
208 
209   __m128i diff_hi = _mm_sub_epi16(current, original);
210 
211 
212   //Hor
213   __m128i row0 = _mm_hadd_epi16(diff_lo, diff_hi);
214   __m128i row1 = _mm_hsub_epi16(diff_lo, diff_hi);
215 
216   __m128i row2 = _mm_hadd_epi16(row0, row1);
217   __m128i row3 = _mm_hsub_epi16(row0, row1);
218 
219   //Ver
220   row0 = _mm_hadd_epi16(row2, row3);
221   row1 = _mm_hsub_epi16(row2, row3);
222 
223   row2 = _mm_hadd_epi16(row0, row1);
224   row3 = _mm_hsub_epi16(row0, row1);
225 
226   //Abs and sum
227   row2 = _mm_abs_epi16(row2);
228   row3 = _mm_abs_epi16(row3);
229 
230   row3 = _mm_add_epi16(row2, row3);
231 
232   row3 = _mm_add_epi16(row3, _mm_shuffle_epi32(row3, _MM_SHUFFLE(1, 0, 3, 2) ));
233   row3 = _mm_add_epi16(row3, _mm_shuffle_epi32(row3, _MM_SHUFFLE(0, 1, 0, 1) ));
234   row3 = _mm_add_epi16(row3, _mm_shufflelo_epi16(row3, _MM_SHUFFLE(0, 1, 0, 1) ));
235 
236   unsigned sum = _mm_extract_epi16(row3, 0);
237   unsigned satd = (sum + 1) >> 1;
238 
239   return satd;
240 }
241 
242 
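/**
 * \brief 4x4 SATD of two predictions against the same original, computed in
 *        parallel with one prediction per 128-bit lane of the 256-bit registers.
 */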
243 static void satd_8bit_4x4_dual_avx2(
244   const pred_buffer preds, const uint8_t * const orig, unsigned num_modes, unsigned *satds_out)
245 {
246 
247   __m256i original = _mm256_broadcastsi128_si256(_mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)orig)));
248   __m256i pred = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)preds[0]));
249   pred = _mm256_inserti128_si256(pred, _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)preds[1])), 1);
250 
251   __m256i diff_lo = _mm256_sub_epi16(pred, original);
252 
253   original = _mm256_broadcastsi128_si256(_mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(orig + 8))));
254   pred = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(preds[0] + 8)));
255   pred = _mm256_inserti128_si256(pred, _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(preds[1] + 8))), 1);
256 
257   __m256i diff_hi = _mm256_sub_epi16(pred, original);
258 
259   //Hor
260   __m256i row0 = _mm256_hadd_epi16(diff_lo, diff_hi);
261   __m256i row1 = _mm256_hsub_epi16(diff_lo, diff_hi);
262 
263   __m256i row2 = _mm256_hadd_epi16(row0, row1);
264   __m256i row3 = _mm256_hsub_epi16(row0, row1);
265 
266   //Ver
267   row0 = _mm256_hadd_epi16(row2, row3);
268   row1 = _mm256_hsub_epi16(row2, row3);
269 
270   row2 = _mm256_hadd_epi16(row0, row1);
271   row3 = _mm256_hsub_epi16(row0, row1);
272 
273   //Abs and sum
274   row2 = _mm256_abs_epi16(row2);
275   row3 = _mm256_abs_epi16(row3);
276 
277   row3 = _mm256_add_epi16(row2, row3);
278 
279   row3 = _mm256_add_epi16(row3, _mm256_shuffle_epi32(row3, _MM_SHUFFLE(1, 0, 3, 2) ));
280   row3 = _mm256_add_epi16(row3, _mm256_shuffle_epi32(row3, _MM_SHUFFLE(0, 1, 0, 1) ));
281   row3 = _mm256_add_epi16(row3, _mm256_shufflelo_epi16(row3, _MM_SHUFFLE(0, 1, 0, 1) ));
282 
283   unsigned sum1 = _mm_extract_epi16(_mm256_castsi256_si128(row3), 0);
284   sum1 = (sum1 + 1) >> 1;
285 
286   unsigned sum2 = _mm_extract_epi16(_mm256_extracti128_si256(row3, 1), 0);
287   sum2 = (sum2 + 1) >> 1;
288 
289   satds_out[0] = sum1;
290   satds_out[1] = sum2;
291 }
292 
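/**
 * \brief In-place horizontal 8-point Hadamard transform of one row of eight
 *        16-bit residuals, done as three butterfly stages built from sign
 *        masks and shuffles.
 */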
293 static INLINE void hor_transform_row_avx2(__m128i* row){
294 
295   __m128i mask_pos = _mm_set1_epi16(1);
296   __m128i mask_neg = _mm_set1_epi16(-1);
297   __m128i sign_mask = _mm_unpacklo_epi64(mask_pos, mask_neg);
298   __m128i temp = _mm_shuffle_epi32(*row, _MM_SHUFFLE(1, 0, 3, 2));
299   *row = _mm_sign_epi16(*row, sign_mask);
300   *row = _mm_add_epi16(*row, temp);
301 
302   sign_mask = _mm_unpacklo_epi32(mask_pos, mask_neg);
303   temp = _mm_shuffle_epi32(*row, _MM_SHUFFLE(2, 3, 0, 1));
304   *row = _mm_sign_epi16(*row, sign_mask);
305   *row = _mm_add_epi16(*row, temp);
306 
307   sign_mask = _mm_unpacklo_epi16(mask_pos, mask_neg);
308   temp = _mm_shufflelo_epi16(*row, _MM_SHUFFLE(2,3,0,1));
309   temp = _mm_shufflehi_epi16(temp, _MM_SHUFFLE(2,3,0,1));
310   *row = _mm_sign_epi16(*row, sign_mask);
311   *row = _mm_add_epi16(*row, temp);
312 }
313 
314 static INLINE void hor_transform_row_dual_avx2(__m256i* row){
315 
316   __m256i mask_pos = _mm256_set1_epi16(1);
317   __m256i mask_neg = _mm256_set1_epi16(-1);
318   __m256i sign_mask = _mm256_unpacklo_epi64(mask_pos, mask_neg);
319   __m256i temp = _mm256_shuffle_epi32(*row, _MM_SHUFFLE(1, 0, 3, 2));
320   *row = _mm256_sign_epi16(*row, sign_mask);
321   *row = _mm256_add_epi16(*row, temp);
322 
323   sign_mask = _mm256_unpacklo_epi32(mask_pos, mask_neg);
324   temp = _mm256_shuffle_epi32(*row, _MM_SHUFFLE(2, 3, 0, 1));
325   *row = _mm256_sign_epi16(*row, sign_mask);
326   *row = _mm256_add_epi16(*row, temp);
327 
328   sign_mask = _mm256_unpacklo_epi16(mask_pos, mask_neg);
329   temp = _mm256_shufflelo_epi16(*row, _MM_SHUFFLE(2,3,0,1));
330   temp = _mm256_shufflehi_epi16(temp, _MM_SHUFFLE(2,3,0,1));
331   *row = _mm256_sign_epi16(*row, sign_mask);
332   *row = _mm256_add_epi16(*row, temp);
333 }
334 
335 static INLINE void add_sub_avx2(__m128i *out, __m128i *in, unsigned out_idx0, unsigned out_idx1, unsigned in_idx0, unsigned in_idx1)
336 {
337   out[out_idx0] = _mm_add_epi16(in[in_idx0], in[in_idx1]);
338   out[out_idx1] = _mm_sub_epi16(in[in_idx0], in[in_idx1]);
339 }
340 
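/**
 * \brief In-place vertical 8-point Hadamard transform across the eight rows of
 *        the block, done as three add/subtract butterfly stages.
 */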
341 static INLINE void ver_transform_block_avx2(__m128i (*rows)[8]){
342 
343   __m128i temp0[8];
344   add_sub_avx2(temp0, (*rows), 0, 1, 0, 1);
345   add_sub_avx2(temp0, (*rows), 2, 3, 2, 3);
346   add_sub_avx2(temp0, (*rows), 4, 5, 4, 5);
347   add_sub_avx2(temp0, (*rows), 6, 7, 6, 7);
348 
349   __m128i temp1[8];
350   add_sub_avx2(temp1, temp0, 0, 1, 0, 2);
351   add_sub_avx2(temp1, temp0, 2, 3, 1, 3);
352   add_sub_avx2(temp1, temp0, 4, 5, 4, 6);
353   add_sub_avx2(temp1, temp0, 6, 7, 5, 7);
354 
355   add_sub_avx2((*rows), temp1, 0, 1, 0, 4);
356   add_sub_avx2((*rows), temp1, 2, 3, 1, 5);
357   add_sub_avx2((*rows), temp1, 4, 5, 2, 6);
358   add_sub_avx2((*rows), temp1, 6, 7, 3, 7);
359 
360 }
361 
362 static INLINE void add_sub_dual_avx2(__m256i *out, __m256i *in, unsigned out_idx0, unsigned out_idx1, unsigned in_idx0, unsigned in_idx1)
363 {
364   out[out_idx0] = _mm256_add_epi16(in[in_idx0], in[in_idx1]);
365   out[out_idx1] = _mm256_sub_epi16(in[in_idx0], in[in_idx1]);
366 }
367 
368 
369 static INLINE void ver_transform_block_dual_avx2(__m256i (*rows)[8]){
370 
371   __m256i temp0[8];
372   add_sub_dual_avx2(temp0, (*rows), 0, 1, 0, 1);
373   add_sub_dual_avx2(temp0, (*rows), 2, 3, 2, 3);
374   add_sub_dual_avx2(temp0, (*rows), 4, 5, 4, 5);
375   add_sub_dual_avx2(temp0, (*rows), 6, 7, 6, 7);
376 
377   __m256i temp1[8];
378   add_sub_dual_avx2(temp1, temp0, 0, 1, 0, 2);
379   add_sub_dual_avx2(temp1, temp0, 2, 3, 1, 3);
380   add_sub_dual_avx2(temp1, temp0, 4, 5, 4, 6);
381   add_sub_dual_avx2(temp1, temp0, 6, 7, 5, 7);
382 
383   add_sub_dual_avx2((*rows), temp1, 0, 1, 0, 4);
384   add_sub_dual_avx2((*rows), temp1, 2, 3, 1, 5);
385   add_sub_dual_avx2((*rows), temp1, 4, 5, 2, 6);
386   add_sub_dual_avx2((*rows), temp1, 6, 7, 3, 7);
387 
388 }
389 
390 INLINE static void haddwd_accumulate_avx2(__m128i *accumulate, __m128i *ver_row)
391 {
392   __m128i abs_value = _mm_abs_epi16(*ver_row);
393   *accumulate = _mm_add_epi32(*accumulate, _mm_madd_epi16(abs_value, _mm_set1_epi16(1)));
394 }
395 
396 INLINE static void haddwd_accumulate_dual_avx2(__m256i *accumulate, __m256i *ver_row)
397 {
398   __m256i abs_value = _mm256_abs_epi16(*ver_row);
399   *accumulate = _mm256_add_epi32(*accumulate, _mm256_madd_epi16(abs_value, _mm256_set1_epi16(1)));
400 }
401 
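/**
 * \brief Sum the absolute values of the eight transformed rows with madd
 *        against ones and reduce the vector to a single 32-bit scalar.
 */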
402 INLINE static unsigned sum_block_avx2(__m128i *ver_row)
403 {
404   __m128i sad = _mm_setzero_si128();
405   haddwd_accumulate_avx2(&sad, ver_row + 0);
406   haddwd_accumulate_avx2(&sad, ver_row + 1);
407   haddwd_accumulate_avx2(&sad, ver_row + 2);
408   haddwd_accumulate_avx2(&sad, ver_row + 3);
409   haddwd_accumulate_avx2(&sad, ver_row + 4);
410   haddwd_accumulate_avx2(&sad, ver_row + 5);
411   haddwd_accumulate_avx2(&sad, ver_row + 6);
412   haddwd_accumulate_avx2(&sad, ver_row + 7);
413 
414   sad = _mm_add_epi32(sad, _mm_shuffle_epi32(sad, _MM_SHUFFLE(1, 0, 3, 2)));
415   sad = _mm_add_epi32(sad, _mm_shuffle_epi32(sad, _MM_SHUFFLE(0, 1, 0, 1)));
416 
417   return _mm_cvtsi128_si32(sad);
418 }
419 
420 INLINE static void sum_block_dual_avx2(__m256i *ver_row, unsigned *sum0, unsigned *sum1)
421 {
422   __m256i sad = _mm256_setzero_si256();
423   haddwd_accumulate_dual_avx2(&sad, ver_row + 0);
424   haddwd_accumulate_dual_avx2(&sad, ver_row + 1);
425   haddwd_accumulate_dual_avx2(&sad, ver_row + 2);
426   haddwd_accumulate_dual_avx2(&sad, ver_row + 3);
427   haddwd_accumulate_dual_avx2(&sad, ver_row + 4);
428   haddwd_accumulate_dual_avx2(&sad, ver_row + 5);
429   haddwd_accumulate_dual_avx2(&sad, ver_row + 6);
430   haddwd_accumulate_dual_avx2(&sad, ver_row + 7);
431 
432   sad = _mm256_add_epi32(sad, _mm256_shuffle_epi32(sad, _MM_SHUFFLE(1, 0, 3, 2)));
433   sad = _mm256_add_epi32(sad, _mm256_shuffle_epi32(sad, _MM_SHUFFLE(0, 1, 0, 1)));
434 
435   *sum0 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sad, 0));
436   *sum1 = _mm_cvtsi128_si32(_mm256_extracti128_si256(sad, 1));
437 }
438 
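/**
 * \brief Load eight 8-bit pixels from each buffer and return their
 *        element-wise difference widened to 16 bits.
 */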
439 INLINE static __m128i diff_row_avx2(const uint8_t *buf1, const uint8_t *buf2)
440 {
441   __m128i buf1_row = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)buf1));
442   __m128i buf2_row = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)buf2));
443   return _mm_sub_epi16(buf1_row, buf2_row);
444 }
445 
446 INLINE static __m256i diff_row_dual_avx2(const uint8_t *buf1, const uint8_t *buf2, const uint8_t *orig)
447 {
448   __m128i temp1 = _mm_loadl_epi64((__m128i*)buf1);
449   __m128i temp2 = _mm_loadl_epi64((__m128i*)buf2);
450   __m128i temp3 = _mm_loadl_epi64((__m128i*)orig);
451   __m256i buf1_row = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(temp1, temp2));
452   __m256i buf2_row = _mm256_cvtepu8_epi16(_mm_broadcastq_epi64(temp3));
453 
454   return _mm256_sub_epi16(buf1_row, buf2_row);
455 }
456 
457 INLINE static void diff_blocks_avx2(__m128i (*row_diff)[8],
458                                                            const uint8_t * buf1, unsigned stride1,
459                                                            const uint8_t * orig, unsigned stride_orig)
460 {
461   (*row_diff)[0] = diff_row_avx2(buf1 + 0 * stride1, orig + 0 * stride_orig);
462   (*row_diff)[1] = diff_row_avx2(buf1 + 1 * stride1, orig + 1 * stride_orig);
463   (*row_diff)[2] = diff_row_avx2(buf1 + 2 * stride1, orig + 2 * stride_orig);
464   (*row_diff)[3] = diff_row_avx2(buf1 + 3 * stride1, orig + 3 * stride_orig);
465   (*row_diff)[4] = diff_row_avx2(buf1 + 4 * stride1, orig + 4 * stride_orig);
466   (*row_diff)[5] = diff_row_avx2(buf1 + 5 * stride1, orig + 5 * stride_orig);
467   (*row_diff)[6] = diff_row_avx2(buf1 + 6 * stride1, orig + 6 * stride_orig);
468   (*row_diff)[7] = diff_row_avx2(buf1 + 7 * stride1, orig + 7 * stride_orig);
469 
470 }
471 
472 INLINE static void diff_blocks_dual_avx2(__m256i (*row_diff)[8],
473                                                            const uint8_t * buf1, unsigned stride1,
474                                                            const uint8_t * buf2, unsigned stride2,
475                                                            const uint8_t * orig, unsigned stride_orig)
476 {
477   (*row_diff)[0] = diff_row_dual_avx2(buf1 + 0 * stride1, buf2 + 0 * stride2, orig + 0 * stride_orig);
478   (*row_diff)[1] = diff_row_dual_avx2(buf1 + 1 * stride1, buf2 + 1 * stride2, orig + 1 * stride_orig);
479   (*row_diff)[2] = diff_row_dual_avx2(buf1 + 2 * stride1, buf2 + 2 * stride2, orig + 2 * stride_orig);
480   (*row_diff)[3] = diff_row_dual_avx2(buf1 + 3 * stride1, buf2 + 3 * stride2, orig + 3 * stride_orig);
481   (*row_diff)[4] = diff_row_dual_avx2(buf1 + 4 * stride1, buf2 + 4 * stride2, orig + 4 * stride_orig);
482   (*row_diff)[5] = diff_row_dual_avx2(buf1 + 5 * stride1, buf2 + 5 * stride2, orig + 5 * stride_orig);
483   (*row_diff)[6] = diff_row_dual_avx2(buf1 + 6 * stride1, buf2 + 6 * stride2, orig + 6 * stride_orig);
484   (*row_diff)[7] = diff_row_dual_avx2(buf1 + 7 * stride1, buf2 + 7 * stride2, orig + 7 * stride_orig);
485 
486 }
487 
488 INLINE static void hor_transform_block_avx2(__m128i (*row_diff)[8])
489 {
490   hor_transform_row_avx2((*row_diff) + 0);
491   hor_transform_row_avx2((*row_diff) + 1);
492   hor_transform_row_avx2((*row_diff) + 2);
493   hor_transform_row_avx2((*row_diff) + 3);
494   hor_transform_row_avx2((*row_diff) + 4);
495   hor_transform_row_avx2((*row_diff) + 5);
496   hor_transform_row_avx2((*row_diff) + 6);
497   hor_transform_row_avx2((*row_diff) + 7);
498 }
499 
500 INLINE static void hor_transform_block_dual_avx2(__m256i (*row_diff)[8])
501 {
502   hor_transform_row_dual_avx2((*row_diff) + 0);
503   hor_transform_row_dual_avx2((*row_diff) + 1);
504   hor_transform_row_dual_avx2((*row_diff) + 2);
505   hor_transform_row_dual_avx2((*row_diff) + 3);
506   hor_transform_row_dual_avx2((*row_diff) + 4);
507   hor_transform_row_dual_avx2((*row_diff) + 5);
508   hor_transform_row_dual_avx2((*row_diff) + 6);
509   hor_transform_row_dual_avx2((*row_diff) + 7);
510 }
511 
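/**
 * \brief Calculate 8x8 SATD of two predictions (buf1, buf2) against the same
 *        original at once, one prediction per 128-bit lane. Both results are
 *        rounded with (sum + 2) >> 2.
 */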
512 static void kvz_satd_8bit_8x8_general_dual_avx2(const uint8_t * buf1, unsigned stride1,
513                                                 const uint8_t * buf2, unsigned stride2,
514                                                 const uint8_t * orig, unsigned stride_orig,
515                                                 unsigned *sum0, unsigned *sum1)
516 {
517   __m256i temp[8];
518 
519   diff_blocks_dual_avx2(&temp, buf1, stride1, buf2, stride2, orig, stride_orig);
520   hor_transform_block_dual_avx2(&temp);
521   ver_transform_block_dual_avx2(&temp);
522 
523   sum_block_dual_avx2(temp, sum0, sum1);
524 
525   *sum0 = (*sum0 + 2) >> 2;
526   *sum1 = (*sum1 + 2) >> 2;
527 }
528 
529 /**
530 * \brief  Calculate SATD between two 4x4 blocks inside bigger arrays.
531 */
532 static unsigned kvz_satd_4x4_subblock_8bit_avx2(const uint8_t * buf1,
533                                                 const int32_t     stride1,
534                                                 const uint8_t * buf2,
535                                                 const int32_t     stride2)
536 {
537   // TODO: AVX2 implementation
538   return kvz_satd_4x4_subblock_generic(buf1, stride1, buf2, stride2);
539 }
540 
541 static void kvz_satd_4x4_subblock_quad_avx2(const uint8_t *preds[4],
542                                        const int stride,
543                                        const uint8_t *orig,
544                                        const int orig_stride,
545                                        unsigned costs[4])
546 {
547   // TODO: AVX2 implementation
548   kvz_satd_4x4_subblock_quad_generic(preds, stride, orig, orig_stride, costs);
549 }
550 
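/**
* \brief  Calculate SATD between two 8x8 blocks inside bigger arrays.
*/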
551 static unsigned satd_8x8_subblock_8bit_avx2(const uint8_t * buf1, unsigned stride1, const uint8_t * buf2, unsigned stride2)
552 {
553   __m128i temp[8];
554 
555   diff_blocks_avx2(&temp, buf1, stride1, buf2, stride2);
556   hor_transform_block_avx2(&temp);
557   ver_transform_block_avx2(&temp);
558 
559   unsigned sad = sum_block_avx2(temp);
560 
561   unsigned result = (sad + 2) >> 2;
562   return result;
563 }
564 
565 static void satd_8x8_subblock_quad_avx2(const uint8_t **preds,
566   const int stride,
567   const uint8_t *orig,
568   const int orig_stride,
569   unsigned *costs)
570 {
571   kvz_satd_8bit_8x8_general_dual_avx2(preds[0], stride, preds[1], stride, orig, orig_stride, &costs[0], &costs[1]);
572   kvz_satd_8bit_8x8_general_dual_avx2(preds[2], stride, preds[3], stride, orig, orig_stride, &costs[2], &costs[3]);
573 }
574 
575 SATD_NxN(8bit_avx2,  8)
576 SATD_NxN(8bit_avx2, 16)
577 SATD_NxN(8bit_avx2, 32)
578 SATD_NxN(8bit_avx2, 64)
579 SATD_ANY_SIZE(8bit_avx2)
580 
581 // Function macro for defining dual SATD functions for fixed-size blocks.
582 // They calculate the Hadamard-based SATD of two predictions at once for
583 // block sizes that are integer multiples of 8x8, using the 8x8 dual function.
584 #define SATD_NXN_DUAL_AVX2(n) \
585 static void satd_8bit_ ## n ## x ## n ## _dual_avx2( \
586   const pred_buffer preds, const uint8_t * const orig, unsigned num_modes, unsigned *satds_out)  \
587 { \
588   unsigned x, y; \
589   satds_out[0] = 0; \
590   satds_out[1] = 0; \
591   unsigned sum1 = 0; \
592   unsigned sum2 = 0; \
593   for (y = 0; y < (n); y += 8) { \
594   unsigned row = y * (n); \
595   for (x = 0; x < (n); x += 8) { \
596   kvz_satd_8bit_8x8_general_dual_avx2(&preds[0][row + x], (n), &preds[1][row + x], (n), &orig[row + x], (n), &sum1, &sum2); \
597   satds_out[0] += sum1; \
598   satds_out[1] += sum2; \
599     } \
600     } \
601   satds_out[0] >>= (KVZ_BIT_DEPTH-8); \
602   satds_out[1] >>= (KVZ_BIT_DEPTH-8); \
603 }
604 
605 static void satd_8bit_8x8_dual_avx2(
606   const pred_buffer preds, const uint8_t * const orig, unsigned num_modes, unsigned *satds_out)
607 {
608   unsigned x, y;
609   satds_out[0] = 0;
610   satds_out[1] = 0;
611   unsigned sum1 = 0;
612   unsigned sum2 = 0;
613   for (y = 0; y < (8); y += 8) {
614   unsigned row = y * (8);
615   for (x = 0; x < (8); x += 8) {
616   kvz_satd_8bit_8x8_general_dual_avx2(&preds[0][row + x], (8), &preds[1][row + x], (8), &orig[row + x], (8), &sum1, &sum2);
617   satds_out[0] += sum1;
618   satds_out[1] += sum2;
619       }
620       }
621   satds_out[0] >>= (KVZ_BIT_DEPTH-8);
622   satds_out[1] >>= (KVZ_BIT_DEPTH-8);
623 }
624 
625 //SATD_NXN_DUAL_AVX2(8) //Use the non-macro version
626 SATD_NXN_DUAL_AVX2(16)
627 SATD_NXN_DUAL_AVX2(32)
628 SATD_NXN_DUAL_AVX2(64)
629 
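// Function macro for defining SATD functions of any block size that evaluate
// num_parallel_blocks predictions against one original: edge strips whose
// width or height is not a multiple of 8 are handled with 4x4 sub-blocks and
// the rest of the block with 8x8 sub-blocks.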
630 #define SATD_ANY_SIZE_MULTI_AVX2(suffix, num_parallel_blocks) \
631   static cost_pixel_any_size_multi_func satd_any_size_## suffix; \
632   static void satd_any_size_ ## suffix ( \
633       int width, int height, \
634       const uint8_t **preds, \
635       const int stride, \
636       const uint8_t *orig, \
637       const int orig_stride, \
638       unsigned num_modes, \
639       unsigned *costs_out, \
640       int8_t *valid) \
641   { \
642     unsigned sums[num_parallel_blocks] = { 0 }; \
643     const uint8_t *pred_ptrs[4] = { preds[0], preds[1], preds[2], preds[3] };\
644     const uint8_t *orig_ptr = orig; \
645     costs_out[0] = 0; costs_out[1] = 0; costs_out[2] = 0; costs_out[3] = 0; \
646     if (width % 8 != 0) { \
647       /* Process the first column using 4x4 blocks. */ \
648       for (int y = 0; y < height; y += 4) { \
649         kvz_satd_4x4_subblock_ ## suffix(preds, stride, orig, orig_stride, sums); \
650             } \
651       orig_ptr += 4; \
652       for(int blk = 0; blk < num_parallel_blocks; ++blk){\
653         pred_ptrs[blk] += 4; \
654             }\
655       width -= 4; \
656             } \
657     if (height % 8 != 0) { \
658       /* Process the first row using 4x4 blocks. */ \
659       for (int x = 0; x < width; x += 4 ) { \
660         kvz_satd_4x4_subblock_ ## suffix(pred_ptrs, stride, orig_ptr, orig_stride, sums); \
661             } \
662       orig_ptr += 4 * orig_stride; \
663       for(int blk = 0; blk < num_parallel_blocks; ++blk){\
664         pred_ptrs[blk] += 4 * stride; \
665             }\
666       height -= 4; \
667         } \
668     /* The rest can now be processed with 8x8 blocks. */ \
669     for (int y = 0; y < height; y += 8) { \
670       orig_ptr = &orig[y * orig_stride]; \
671       pred_ptrs[0] = &preds[0][y * stride]; \
672       pred_ptrs[1] = &preds[1][y * stride]; \
673       pred_ptrs[2] = &preds[2][y * stride]; \
674       pred_ptrs[3] = &preds[3][y * stride]; \
675       for (int x = 0; x < width; x += 8) { \
676         satd_8x8_subblock_ ## suffix(pred_ptrs, stride, orig_ptr, orig_stride, sums); \
677         orig_ptr += 8; \
678         pred_ptrs[0] += 8; \
679         pred_ptrs[1] += 8; \
680         pred_ptrs[2] += 8; \
681         pred_ptrs[3] += 8; \
682         costs_out[0] += sums[0]; \
683         costs_out[1] += sums[1]; \
684         costs_out[2] += sums[2]; \
685         costs_out[3] += sums[3]; \
686       } \
687     } \
688     for(int i = 0; i < num_parallel_blocks; ++i){\
689       costs_out[i] = costs_out[i] >> (KVZ_BIT_DEPTH - 8);\
690     } \
691     return; \
692   }
693 
694 SATD_ANY_SIZE_MULTI_AVX2(quad_avx2, 4)
695 
696 
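/**
 * \brief Sum of squared differences between a reference and a reconstruction
 *        block of size width x width, normalized for bit depth. A 4x4 block is
 *        handled as a special case; larger sizes are processed in 8x8 tiles.
 */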
697 static unsigned pixels_calc_ssd_avx2(const uint8_t *const ref, const uint8_t *const rec,
698                  const int ref_stride, const int rec_stride,
699                  const int width)
700 {
701   __m256i ssd_part;
702   __m256i diff = _mm256_setzero_si256();
703   __m128i sum;
704 
705   __m256i ref_epi16;
706   __m256i rec_epi16;
707 
708   __m128i ref_row0, ref_row1, ref_row2, ref_row3;
709   __m128i rec_row0, rec_row1, rec_row2, rec_row3;
710 
711   int ssd;
712 
713   switch (width) {
714 
715   case 4:
716 
717     ref_row0 = _mm_cvtsi32_si128(*(int32_t*)&(ref[0 * ref_stride]));
718     ref_row1 = _mm_cvtsi32_si128(*(int32_t*)&(ref[1 * ref_stride]));
719     ref_row2 = _mm_cvtsi32_si128(*(int32_t*)&(ref[2 * ref_stride]));
720     ref_row3 = _mm_cvtsi32_si128(*(int32_t*)&(ref[3 * ref_stride]));
721 
722     ref_row0 = _mm_unpacklo_epi32(ref_row0, ref_row1);
723     ref_row1 = _mm_unpacklo_epi32(ref_row2, ref_row3);
724     ref_epi16 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(ref_row0, ref_row1) );
725 
726     rec_row0 = _mm_cvtsi32_si128(*(int32_t*)&(rec[0 * rec_stride]));
727     rec_row1 = _mm_cvtsi32_si128(*(int32_t*)&(rec[1 * rec_stride]));
728     rec_row2 = _mm_cvtsi32_si128(*(int32_t*)&(rec[2 * rec_stride]));
729     rec_row3 = _mm_cvtsi32_si128(*(int32_t*)&(rec[3 * rec_stride]));
730 
731     rec_row0 = _mm_unpacklo_epi32(rec_row0, rec_row1);
732     rec_row1 = _mm_unpacklo_epi32(rec_row2, rec_row3);
733     rec_epi16 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(rec_row0, rec_row1) );
734 
735     diff = _mm256_sub_epi16(ref_epi16, rec_epi16);
736     ssd_part =  _mm256_madd_epi16(diff, diff);
737 
738     sum = _mm_add_epi32(_mm256_castsi256_si128(ssd_part), _mm256_extracti128_si256(ssd_part, 1));
739     sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2)));
740     sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(0, 1, 0, 1)));
741 
742     ssd = _mm_cvtsi128_si32(sum);
743 
744     return ssd >> (2*(KVZ_BIT_DEPTH-8));
745     break;
746 
747   default:
748 
749     ssd_part = _mm256_setzero_si256();
750     for (int y = 0; y < width; y += 8) {
751       for (int x = 0; x < width; x += 8) {
752         for (int i = 0; i < 8; i += 2) {
753           ref_epi16 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((__m128i*)&(ref[x + (y + i) * ref_stride])), _mm_loadl_epi64((__m128i*)&(ref[x + (y + i + 1) * ref_stride]))));
754           rec_epi16 = _mm256_cvtepu8_epi16(_mm_unpacklo_epi64(_mm_loadl_epi64((__m128i*)&(rec[x + (y + i) * rec_stride])), _mm_loadl_epi64((__m128i*)&(rec[x + (y + i + 1) * rec_stride]))));
755           diff = _mm256_sub_epi16(ref_epi16, rec_epi16);
756           ssd_part = _mm256_add_epi32(ssd_part, _mm256_madd_epi16(diff, diff));
757         }
758       }
759     }
760 
761     sum = _mm_add_epi32(_mm256_castsi256_si128(ssd_part), _mm256_extracti128_si256(ssd_part, 1));
762     sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(1, 0, 3, 2)));
763     sum = _mm_add_epi32(sum, _mm_shuffle_epi32(sum, _MM_SHUFFLE(0, 1, 0, 1)));
764 
765     ssd = _mm_cvtsi128_si32(sum);
766 
767     return ssd >> (2*(KVZ_BIT_DEPTH-8));
768     break;
769   }
770 }
771 
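/**
 * \brief Blend two inter predictions into lcu->rec for both luma and chroma.
 *        Each sample pair is averaged with a rounding offset and clipped;
 *        inputs come either from the high-precision buffers or from 8-bit
 *        pixels shifted up to the same precision.
 */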
772 static void inter_recon_bipred_avx2(const int hi_prec_luma_rec0,
773  const int hi_prec_luma_rec1,
774  const int hi_prec_chroma_rec0,
775  const int hi_prec_chroma_rec1,
776  const int height,
777  const int width,
778  const int ypos,
779  const int xpos,
780  const hi_prec_buf_t*high_precision_rec0,
781  const hi_prec_buf_t*high_precision_rec1,
782  lcu_t* lcu,
783  uint8_t* temp_lcu_y,
784  uint8_t* temp_lcu_u,
785  uint8_t* temp_lcu_v,
786 bool predict_luma,
787 bool predict_chroma)
788 {
789   int y_in_lcu, x_in_lcu;
790   int shift = 15 - KVZ_BIT_DEPTH;
791   int offset = 1 << (shift - 1);
792   __m256i temp_epi8, temp_y_epi32, sample0_epi32, sample1_epi32, temp_epi16;
793   int32_t * pointer = 0;
794   __m256i offset_epi32 = _mm256_set1_epi32(offset);
795 
796   for (int temp_y = 0; temp_y < height; ++temp_y) {
797 
798    y_in_lcu = ((ypos + temp_y) & ((LCU_WIDTH)-1));
799 
800    for (int temp_x = 0; temp_x < width; temp_x += 8) {
801     x_in_lcu = ((xpos + temp_x) & ((LCU_WIDTH)-1));
802 
803     if (predict_luma) {
804       bool use_8_elements = ((temp_x + 8) <= width);
805 
806       if (!use_8_elements) {
807         if (width < 4) {
808           // If width is smaller than 4 there's no need to use SIMD
809           for (int temp_i = 0; temp_i < width; ++temp_i) {
810             x_in_lcu = ((xpos + temp_i) & ((LCU_WIDTH)-1));
811 
812             int sample0_y = (hi_prec_luma_rec0 ? high_precision_rec0->y[y_in_lcu * LCU_WIDTH + x_in_lcu] : (temp_lcu_y[y_in_lcu * LCU_WIDTH + x_in_lcu] << (14 - KVZ_BIT_DEPTH)));
813             int sample1_y = (hi_prec_luma_rec1 ? high_precision_rec1->y[y_in_lcu * LCU_WIDTH + x_in_lcu] : (lcu->rec.y[y_in_lcu * LCU_WIDTH + x_in_lcu] << (14 - KVZ_BIT_DEPTH)));
814 
815             lcu->rec.y[y_in_lcu * LCU_WIDTH + x_in_lcu] = (uint8_t)kvz_fast_clip_32bit_to_pixel((sample0_y + sample1_y + offset) >> shift);
816           }
817         }
818 
819         else {
820           // Load total of 4 elements from memory to vector
821           sample0_epi32 = hi_prec_luma_rec0 ? _mm256_cvtepi16_epi32(_mm_loadl_epi64((__m128i*) &(high_precision_rec0->y[y_in_lcu * LCU_WIDTH + x_in_lcu]))) :
822             _mm256_slli_epi32(_mm256_cvtepu8_epi32(_mm_cvtsi32_si128(*(int32_t*)&(temp_lcu_y[y_in_lcu * LCU_WIDTH + x_in_lcu]))), 14 - KVZ_BIT_DEPTH);
823 
824 
825           sample1_epi32 = hi_prec_luma_rec1 ? _mm256_cvtepi16_epi32(_mm_loadl_epi64((__m128i*) &(high_precision_rec1->y[y_in_lcu * LCU_WIDTH + x_in_lcu]))) :
826             _mm256_slli_epi32(_mm256_cvtepu8_epi32(_mm_cvtsi32_si128(*(int32_t*) &(lcu->rec.y[y_in_lcu * LCU_WIDTH + x_in_lcu]))), 14 - KVZ_BIT_DEPTH);
827 
828 
829           // (sample1 + sample2 + offset)>>shift
830           temp_y_epi32 = _mm256_add_epi32(sample0_epi32, sample1_epi32);
831           temp_y_epi32 = _mm256_add_epi32(temp_y_epi32, offset_epi32);
832           temp_y_epi32 = _mm256_srai_epi32(temp_y_epi32, shift);
833 
834           // Pack the bits from 32-bit to 8-bit
835           temp_epi16 = _mm256_packs_epi32(temp_y_epi32, temp_y_epi32);
836           temp_epi16 = _mm256_permute4x64_epi64(temp_epi16, _MM_SHUFFLE(3, 1, 2, 0));
837           temp_epi8 = _mm256_packus_epi16(temp_epi16, temp_epi16);
838 
839           pointer = (int32_t*)&(lcu->rec.y[(y_in_lcu)* LCU_WIDTH + x_in_lcu]);
840           *pointer = _mm_cvtsi128_si32(_mm256_castsi256_si128(temp_epi8));
841 
842 
843 
844           for (int temp_i = temp_x + 4; temp_i < width; ++temp_i) {
845             x_in_lcu = ((xpos + temp_i) & ((LCU_WIDTH)-1));
846 
847             int16_t sample0_y = (hi_prec_luma_rec0 ? high_precision_rec0->y[y_in_lcu * LCU_WIDTH + x_in_lcu] : (temp_lcu_y[y_in_lcu * LCU_WIDTH + x_in_lcu] << (14 - KVZ_BIT_DEPTH)));
848             int16_t sample1_y = (hi_prec_luma_rec1 ? high_precision_rec1->y[y_in_lcu * LCU_WIDTH + x_in_lcu] : (lcu->rec.y[y_in_lcu * LCU_WIDTH + x_in_lcu] << (14 - KVZ_BIT_DEPTH)));
849 
850             lcu->rec.y[y_in_lcu * LCU_WIDTH + x_in_lcu] = (uint8_t)kvz_fast_clip_32bit_to_pixel((sample0_y + sample1_y + offset) >> shift);
851           }
852 
853         }
854       } else {
855         // Load total of 8 elements from memory to vector
856         sample0_epi32 = hi_prec_luma_rec0 ? _mm256_cvtepi16_epi32(_mm_loadu_si128((__m128i*) &(high_precision_rec0->y[y_in_lcu * LCU_WIDTH + x_in_lcu]))) :
857           _mm256_slli_epi32(_mm256_cvtepu8_epi32((_mm_loadl_epi64((__m128i*) &(temp_lcu_y[y_in_lcu * LCU_WIDTH + x_in_lcu])))), 14 - KVZ_BIT_DEPTH);
858 
859         sample1_epi32 = hi_prec_luma_rec1 ? _mm256_cvtepi16_epi32(_mm_loadu_si128((__m128i*) &(high_precision_rec1->y[y_in_lcu * LCU_WIDTH + x_in_lcu]))) :
860           _mm256_slli_epi32(_mm256_cvtepu8_epi32((_mm_loadl_epi64((__m128i*) &(lcu->rec.y[y_in_lcu * LCU_WIDTH + x_in_lcu])))), 14 - KVZ_BIT_DEPTH);
861 
862         // (sample1 + sample2 + offset)>>shift
863         temp_y_epi32 = _mm256_add_epi32(sample0_epi32, sample1_epi32);
864         temp_y_epi32 = _mm256_add_epi32(temp_y_epi32, offset_epi32);
865         temp_y_epi32 = _mm256_srai_epi32(temp_y_epi32, shift);
866 
867         // Pack the bits from 32-bit to 8-bit
868         temp_epi16 = _mm256_packs_epi32(temp_y_epi32, temp_y_epi32);
869         temp_epi16 = _mm256_permute4x64_epi64(temp_epi16, _MM_SHUFFLE(3, 1, 2, 0));
870         temp_epi8 = _mm256_packus_epi16(temp_epi16, temp_epi16);
871 
872         // Store 64-bits from vector to memory
873         _mm_storel_epi64((__m128i*)&(lcu->rec.y[(y_in_lcu)* LCU_WIDTH + x_in_lcu]), _mm256_castsi256_si128(temp_epi8));
874       }
875     }
876    }
877   }
878   for (int temp_y = 0; temp_y < height >> 1; ++temp_y) {
879    int y_in_lcu = (((ypos >> 1) + temp_y) & (LCU_WIDTH_C - 1));
880 
881    for (int temp_x = 0; temp_x < width >> 1; temp_x += 8) {
882 
883     int x_in_lcu = (((xpos >> 1) + temp_x) & (LCU_WIDTH_C - 1));
884 
885     if (predict_chroma) {
886       if ((width >> 1) < 4) {
887         // If width>>1 is smaller than 4 there's no need to use SIMD
888 
889         for (int temp_i = 0; temp_i < width >> 1; ++temp_i) {
890           int temp_x_in_lcu = (((xpos >> 1) + temp_i) & (LCU_WIDTH_C - 1));
891           int16_t sample0_u = (hi_prec_chroma_rec0 ? high_precision_rec0->u[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] : (temp_lcu_u[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] << (14 - KVZ_BIT_DEPTH)));
892           int16_t sample1_u = (hi_prec_chroma_rec1 ? high_precision_rec1->u[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] : (lcu->rec.u[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] << (14 - KVZ_BIT_DEPTH)));
893           lcu->rec.u[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] = (uint8_t)kvz_fast_clip_32bit_to_pixel((sample0_u + sample1_u + offset) >> shift);
894 
895           int16_t sample0_v = (hi_prec_chroma_rec0 ? high_precision_rec0->v[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] : (temp_lcu_v[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] << (14 - KVZ_BIT_DEPTH)));
896           int16_t sample1_v = (hi_prec_chroma_rec1 ? high_precision_rec1->v[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] : (lcu->rec.v[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] << (14 - KVZ_BIT_DEPTH)));
897           lcu->rec.v[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] = (uint8_t)kvz_fast_clip_32bit_to_pixel((sample0_v + sample1_v + offset) >> shift);
898         }
899       }
900 
901       else {
902 
903         bool use_8_elements = ((temp_x + 8) <= (width >> 1));
904 
905         __m256i temp_u_epi32, temp_v_epi32;
906 
907         if (!use_8_elements) {
908           // Load 4 pixels to vector
909           sample0_epi32 = hi_prec_chroma_rec0 ? _mm256_cvtepi16_epi32(_mm_loadl_epi64((__m128i*) &(high_precision_rec0->u[y_in_lcu * LCU_WIDTH_C + x_in_lcu]))) :
910             _mm256_slli_epi32(_mm256_cvtepu8_epi32(_mm_cvtsi32_si128(*(int32_t*) &(temp_lcu_u[y_in_lcu * LCU_WIDTH_C + x_in_lcu]))), 14 - KVZ_BIT_DEPTH);
911 
912           sample1_epi32 = hi_prec_chroma_rec1 ? _mm256_cvtepi16_epi32(_mm_loadl_epi64((__m128i*) &(high_precision_rec1->u[y_in_lcu * LCU_WIDTH_C + x_in_lcu]))) :
913             _mm256_slli_epi32(_mm256_cvtepu8_epi32(_mm_cvtsi32_si128(*(int32_t*) &(lcu->rec.u[y_in_lcu * LCU_WIDTH_C + x_in_lcu]))), 14 - KVZ_BIT_DEPTH);
914 
915           // (sample1 + sample2 + offset)>>shift
916           temp_u_epi32 = _mm256_add_epi32(sample0_epi32, sample1_epi32);
917           temp_u_epi32 = _mm256_add_epi32(temp_u_epi32, offset_epi32);
918           temp_u_epi32 = _mm256_srai_epi32(temp_u_epi32, shift);
919 
920 
921 
922           sample0_epi32 = hi_prec_chroma_rec0 ? _mm256_cvtepi16_epi32(_mm_loadl_epi64((__m128i*) &(high_precision_rec0->v[y_in_lcu * LCU_WIDTH_C + x_in_lcu]))) :
923             _mm256_slli_epi32(_mm256_cvtepu8_epi32(_mm_cvtsi32_si128(*(int32_t*) &(temp_lcu_v[y_in_lcu * LCU_WIDTH_C + x_in_lcu]))), 14 - KVZ_BIT_DEPTH);
924 
925           sample1_epi32 = hi_prec_chroma_rec1 ? _mm256_cvtepi16_epi32(_mm_loadl_epi64((__m128i*) &(high_precision_rec1->v[y_in_lcu * LCU_WIDTH_C + x_in_lcu]))) :
926             _mm256_slli_epi32(_mm256_cvtepu8_epi32(_mm_cvtsi32_si128(*(int32_t*) &(lcu->rec.v[y_in_lcu * LCU_WIDTH_C + x_in_lcu]))), 14 - KVZ_BIT_DEPTH);
927 
928 
929           // (sample1 + sample2 + offset)>>shift
930           temp_v_epi32 = _mm256_add_epi32(sample0_epi32, sample1_epi32);
931           temp_v_epi32 = _mm256_add_epi32(temp_v_epi32, offset_epi32);
932           temp_v_epi32 = _mm256_srai_epi32(temp_v_epi32, shift);
933 
934 
935           temp_epi16 = _mm256_packs_epi32(temp_u_epi32, temp_u_epi32);
936           temp_epi16 = _mm256_permute4x64_epi64(temp_epi16, _MM_SHUFFLE(3, 1, 2, 0));
937           temp_epi8 = _mm256_packus_epi16(temp_epi16, temp_epi16);
938 
939           pointer = (int32_t*)&(lcu->rec.u[(y_in_lcu)* LCU_WIDTH_C + x_in_lcu]);
940           *pointer = _mm_cvtsi128_si32(_mm256_castsi256_si128(temp_epi8));
941 
942 
943           temp_epi16 = _mm256_packs_epi32(temp_v_epi32, temp_v_epi32);
944           temp_epi16 = _mm256_permute4x64_epi64(temp_epi16, _MM_SHUFFLE(3, 1, 2, 0));
945           temp_epi8 = _mm256_packus_epi16(temp_epi16, temp_epi16);
946 
947           pointer = (int32_t*)&(lcu->rec.v[(y_in_lcu)* LCU_WIDTH_C + x_in_lcu]);
948           *pointer = _mm_cvtsi128_si32(_mm256_castsi256_si128(temp_epi8));
949 
950           for (int temp_i = 4; temp_i < width >> 1; ++temp_i) {
951 
952             // Used only when width>>1 is not divisible by 4
953             int temp_x_in_lcu = (((xpos >> 1) + temp_i) & (LCU_WIDTH_C - 1));
954             int16_t sample0_u = (hi_prec_chroma_rec0 ? high_precision_rec0->u[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] : (temp_lcu_u[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] << (14 - KVZ_BIT_DEPTH)));
955             int16_t sample1_u = (hi_prec_chroma_rec1 ? high_precision_rec1->u[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] : (lcu->rec.u[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] << (14 - KVZ_BIT_DEPTH)));
956             lcu->rec.u[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] = (uint8_t)kvz_fast_clip_32bit_to_pixel((sample0_u + sample1_u + offset) >> shift);
957 
958             int16_t sample0_v = (hi_prec_chroma_rec0 ? high_precision_rec0->v[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] : (temp_lcu_v[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] << (14 - KVZ_BIT_DEPTH)));
959             int16_t sample1_v = (hi_prec_chroma_rec1 ? high_precision_rec1->v[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] : (lcu->rec.v[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] << (14 - KVZ_BIT_DEPTH)));
960             lcu->rec.v[y_in_lcu * LCU_WIDTH_C + temp_x_in_lcu] = (uint8_t)kvz_fast_clip_32bit_to_pixel((sample0_v + sample1_v + offset) >> shift);
961           }
962         } else {
963           // Load 8 pixels to vector
964           sample0_epi32 = hi_prec_chroma_rec0 ? _mm256_cvtepi16_epi32(_mm_loadu_si128((__m128i*) &(high_precision_rec0->u[y_in_lcu * LCU_WIDTH_C + x_in_lcu]))) :
965             _mm256_slli_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*) &(temp_lcu_u[y_in_lcu * LCU_WIDTH_C + x_in_lcu]))), 14 - KVZ_BIT_DEPTH);
966 
967           sample1_epi32 = hi_prec_chroma_rec1 ? _mm256_cvtepi16_epi32(_mm_loadu_si128((__m128i*) &(high_precision_rec1->u[y_in_lcu * LCU_WIDTH_C + x_in_lcu]))) :
968             _mm256_slli_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*) &(lcu->rec.u[y_in_lcu * LCU_WIDTH_C + x_in_lcu]))), 14 - KVZ_BIT_DEPTH);
969 
970           // (sample1 + sample2 + offset)>>shift
971           temp_u_epi32 = _mm256_add_epi32(sample0_epi32, sample1_epi32);
972           temp_u_epi32 = _mm256_add_epi32(temp_u_epi32, offset_epi32);
973           temp_u_epi32 = _mm256_srai_epi32(temp_u_epi32, shift);
974 
975           sample0_epi32 = hi_prec_chroma_rec0 ? _mm256_cvtepi16_epi32(_mm_loadu_si128((__m128i*) &(high_precision_rec0->v[y_in_lcu * LCU_WIDTH_C + x_in_lcu]))) :
976             _mm256_slli_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*) &(temp_lcu_v[y_in_lcu * LCU_WIDTH_C + x_in_lcu]))), 14 - KVZ_BIT_DEPTH);
977 
978           sample1_epi32 = hi_prec_chroma_rec1 ? _mm256_cvtepi16_epi32(_mm_loadu_si128((__m128i*) &(high_precision_rec1->v[y_in_lcu * LCU_WIDTH_C + x_in_lcu]))) :
979             _mm256_slli_epi32(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*) &(lcu->rec.v[y_in_lcu * LCU_WIDTH_C + x_in_lcu]))), 14 - KVZ_BIT_DEPTH);
980 
981 
982           // (sample1 + sample2 + offset)>>shift
983           temp_v_epi32 = _mm256_add_epi32(sample0_epi32, sample1_epi32);
984           temp_v_epi32 = _mm256_add_epi32(temp_v_epi32, offset_epi32);
985           temp_v_epi32 = _mm256_srai_epi32(temp_v_epi32, shift);
986 
987           temp_epi16 = _mm256_packs_epi32(temp_u_epi32, temp_u_epi32);
988           temp_epi16 = _mm256_permute4x64_epi64(temp_epi16, _MM_SHUFFLE(3, 1, 2, 0));
989           temp_epi8 = _mm256_packus_epi16(temp_epi16, temp_epi16);
990 
991           // Store 64-bit integer into memory
992           _mm_storel_epi64((__m128i*)&(lcu->rec.u[(y_in_lcu)* LCU_WIDTH_C + x_in_lcu]), _mm256_castsi256_si128(temp_epi8));
993 
994           temp_epi16 = _mm256_packs_epi32(temp_v_epi32, temp_v_epi32);
995           temp_epi16 = _mm256_permute4x64_epi64(temp_epi16, _MM_SHUFFLE(3, 1, 2, 0));
996           temp_epi8 = _mm256_packus_epi16(temp_epi16, temp_epi16);
997 
998           // Store 64-bit integer into memory
999           _mm_storel_epi64((__m128i*)&(lcu->rec.v[(y_in_lcu)* LCU_WIDTH_C + x_in_lcu]), _mm256_castsi256_si128(temp_epi8));
1000         }
1001       }
1002     }
1003    }
1004   }
1005 }
1006 
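/**
 * \brief Return a SAD implementation specialized for the given block width,
 *        or NULL when no specialization exists and the caller must fall back
 *        to a generic path.
 *
 * A minimal usage sketch (illustration only), assuming optimized_sad_func_ptr_t
 * matches the reg_sad_wN helpers, i.e. (data1, data2, height, stride1, stride2):
 *
 *   optimized_sad_func_ptr_t sad_fn = get_optimized_sad_avx2(width);
 *   uint32_t cost = sad_fn
 *     ? sad_fn(cur, ref, height, cur_stride, ref_stride)
 *     : kvz_reg_sad_avx2(cur, ref, width, height, cur_stride, ref_stride);
 */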
1007 static optimized_sad_func_ptr_t get_optimized_sad_avx2(int32_t width)
1008 {
1009   if (width == 0)
1010     return reg_sad_w0;
1011   if (width == 4)
1012     return reg_sad_w4;
1013   if (width == 8)
1014     return reg_sad_w8;
1015   if (width == 12)
1016     return reg_sad_w12;
1017   if (width == 16)
1018     return reg_sad_w16;
1019   if (width == 24)
1020     return reg_sad_w24;
1021   if (width == 32)
1022     return reg_sad_w32;
1023   if (width == 64)
1024     return reg_sad_w64;
1025   else
1026     return NULL;
1027 }
1028 
1029 static uint32_t ver_sad_avx2(const uint8_t *pic_data, const uint8_t *ref_data,
1030                              int32_t width, int32_t height, uint32_t stride)
1031 {
1032   if (width == 0)
1033     return 0;
1034   if (width == 4)
1035     return ver_sad_w4(pic_data, ref_data, height, stride);
1036   if (width == 8)
1037     return ver_sad_w8(pic_data, ref_data, height, stride);
1038   if (width == 12)
1039     return ver_sad_w12(pic_data, ref_data, height, stride);
1040   if (width == 16)
1041     return ver_sad_w16(pic_data, ref_data, height, stride);
1042   else
1043     return ver_sad_arbitrary(pic_data, ref_data, width, height, stride);
1044 }
1045 
1046 static uint32_t hor_sad_avx2(const uint8_t *pic_data, const uint8_t *ref_data,
1047                              int32_t width, int32_t height, uint32_t pic_stride,
1048                              uint32_t ref_stride, uint32_t left, uint32_t right)
1049 {
1050   if (width == 4)
1051     return hor_sad_sse41_w4(pic_data, ref_data, height,
1052                             pic_stride, ref_stride, left, right);
1053   if (width == 8)
1054     return hor_sad_sse41_w8(pic_data, ref_data, height,
1055                             pic_stride, ref_stride, left, right);
1056   if (width == 16)
1057     return hor_sad_sse41_w16(pic_data, ref_data, height,
1058                              pic_stride, ref_stride, left, right);
1059   if (width == 32)
1060     return hor_sad_avx2_w32 (pic_data, ref_data, height,
1061                              pic_stride, ref_stride, left, right);
1062   else
1063     return hor_sad_sse41_arbitrary(pic_data, ref_data, width, height,
1064                                    pic_stride, ref_stride, left, right);
1065 }
1066 
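/**
 * \brief Variance of a pixel buffer, computed in two passes: SAD against zero
 *        to get the pixel sum for the mean, then squared deviations
 *        accumulated in single-precision floats. The tail of len mod 32
 *        pixels is handled with scalar code.
 */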
1067 static double pixel_var_avx2_largebuf(const uint8_t *buf, const uint32_t len)
1068 {
1069   const float len_f  = (float)len;
1070   const __m256i zero = _mm256_setzero_si256();
1071 
1072   int64_t sum;
1073   size_t i;
1074   __m256i sums = zero;
1075   for (i = 0; i + 31 < len; i += 32) {
1076     __m256i curr = _mm256_loadu_si256((const __m256i *)(buf + i));
1077     __m256i curr_sum = _mm256_sad_epu8(curr, zero);
1078             sums = _mm256_add_epi64(sums, curr_sum);
1079   }
1080   __m128i sum_lo = _mm256_castsi256_si128  (sums);
1081   __m128i sum_hi = _mm256_extracti128_si256(sums,   1);
1082   __m128i sum_3  = _mm_add_epi64           (sum_lo, sum_hi);
1083   __m128i sum_4  = _mm_shuffle_epi32       (sum_3,  _MM_SHUFFLE(1, 0, 3, 2));
1084   __m128i sum_5  = _mm_add_epi64           (sum_3,  sum_4);
1085 
1086   _mm_storel_epi64((__m128i *)&sum, sum_5);
1087 
1088   // Remaining len mod 32 pixels
1089   for (; i < len; ++i) {
1090     sum += buf[i];
1091   }
1092 
1093   float   mean_f = (float)sum / len_f;
1094   __m256  mean   = _mm256_set1_ps(mean_f);
1095   __m256  accum  = _mm256_setzero_ps();
1096 
1097   for (i = 0; i + 31 < len; i += 32) {
1098     __m128i curr0    = _mm_loadl_epi64((const __m128i *)(buf + i +  0));
1099     __m128i curr1    = _mm_loadl_epi64((const __m128i *)(buf + i +  8));
1100     __m128i curr2    = _mm_loadl_epi64((const __m128i *)(buf + i + 16));
1101     __m128i curr3    = _mm_loadl_epi64((const __m128i *)(buf + i + 24));
1102 
1103     __m256i curr0_32 = _mm256_cvtepu8_epi32(curr0);
1104     __m256i curr1_32 = _mm256_cvtepu8_epi32(curr1);
1105     __m256i curr2_32 = _mm256_cvtepu8_epi32(curr2);
1106     __m256i curr3_32 = _mm256_cvtepu8_epi32(curr3);
1107 
1108     __m256  curr0_f  = _mm256_cvtepi32_ps  (curr0_32);
1109     __m256  curr1_f  = _mm256_cvtepi32_ps  (curr1_32);
1110     __m256  curr2_f  = _mm256_cvtepi32_ps  (curr2_32);
1111     __m256  curr3_f  = _mm256_cvtepi32_ps  (curr3_32);
1112 
1113     __m256  curr0_sd = _mm256_sub_ps       (curr0_f,  mean);
1114     __m256  curr1_sd = _mm256_sub_ps       (curr1_f,  mean);
1115     __m256  curr2_sd = _mm256_sub_ps       (curr2_f,  mean);
1116     __m256  curr3_sd = _mm256_sub_ps       (curr3_f,  mean);
1117 
1118     __m256  curr0_v  = _mm256_mul_ps       (curr0_sd, curr0_sd);
1119     __m256  curr1_v  = _mm256_mul_ps       (curr1_sd, curr1_sd);
1120     __m256  curr2_v  = _mm256_mul_ps       (curr2_sd, curr2_sd);
1121     __m256  curr3_v  = _mm256_mul_ps       (curr3_sd, curr3_sd);
1122 
1123     __m256  curr01   = _mm256_add_ps       (curr0_v,  curr1_v);
1124     __m256  curr23   = _mm256_add_ps       (curr2_v,  curr3_v);
1125     __m256  curr     = _mm256_add_ps       (curr01,   curr23);
1126             accum    = _mm256_add_ps       (accum,    curr);
1127   }
1128   __m256d accum_d  = _mm256_castps_pd     (accum);
1129   __m256d accum2_d = _mm256_permute4x64_pd(accum_d, _MM_SHUFFLE(1, 0, 3, 2));
1130   __m256  accum2   = _mm256_castpd_ps     (accum2_d);
1131 
1132   __m256  accum3   = _mm256_add_ps        (accum,  accum2);
1133   __m256  accum4   = _mm256_permute_ps    (accum3, _MM_SHUFFLE(1, 0, 3, 2));
1134   __m256  accum5   = _mm256_add_ps        (accum3, accum4);
1135   __m256  accum6   = _mm256_permute_ps    (accum5, _MM_SHUFFLE(2, 3, 0, 1));
1136   __m256  accum7   = _mm256_add_ps        (accum5, accum6);
1137 
1138   __m128  accum8   = _mm256_castps256_ps128(accum7);
1139   float   var_sum  = _mm_cvtss_f32         (accum8);
1140 
1141   // Remaining len mod 32 pixels
1142   for (; i < len; ++i) {
1143     float diff = buf[i] - mean_f;
1144     var_sum += diff * diff;
1145   }
1146 
1147   return  var_sum / len_f;
1148 }
1149 
1150 #ifdef INACCURATE_VARIANCE_CALCULATION
1151 
1152 // Assumes that u is a power of two
1153 static INLINE uint32_t ilog2(uint32_t u)
1154 {
1155   return _tzcnt_u32(u);
1156 }
1157 
1158 // A B C D | E F G H (8x32b)
1159 //        ==>
1160 // A+B C+D | E+F G+H (4x64b)
1161 static __m256i hsum_epi32_to_epi64(const __m256i v)
1162 {
1163   const __m256i zero    = _mm256_setzero_si256();
1164         __m256i v_shufd = _mm256_shuffle_epi32(v, _MM_SHUFFLE(3, 3, 1, 1));
1165         __m256i sums_32 = _mm256_add_epi32    (v, v_shufd);
1166         __m256i sums_64 = _mm256_blend_epi32  (sums_32, zero, 0xaa);
1167   return        sums_64;
1168 }
1169 
1170 static double pixel_var_avx2(const uint8_t *buf, const uint32_t len)
1171 {
1172   assert(sizeof(*buf) == 1);
1173   assert((len & 31) == 0);
1174 
1175   // Uses Q8.7 numbers to measure mean and deviation, so variances are Q16.14
1176   const uint64_t sum_maxwid     = ilog2(len) + (8 * sizeof(*buf));
1177   const __m128i normalize_sum   = _mm_cvtsi32_si128(sum_maxwid - 15); // Normalize mean to [0, 32767], so signed 16-bit subtraction never overflows
1178   const __m128i debias_sum      = _mm_cvtsi32_si128(1 << (sum_maxwid - 16));
1179   const float varsum_to_f       = 1.0f / (float)(1 << (14 + ilog2(len)));
1180 
1181   const bool power_of_two = (len & (len - 1)) == 0;
1182   if (sum_maxwid > 32 || sum_maxwid < 15 || !power_of_two) {
1183     return pixel_var_avx2_largebuf(buf, len);
1184   }
1185 
1186   const __m256i zero      = _mm256_setzero_si256();
1187   const __m256i himask_15 = _mm256_set1_epi16(0x7f00);
1188 
1189   uint64_t vars;
1190   size_t i;
1191   __m256i sums = zero;
1192   for (i = 0; i < len; i += 32) {
1193     __m256i curr = _mm256_loadu_si256((const __m256i *)(buf + i));
1194     __m256i curr_sum = _mm256_sad_epu8(curr, zero);
1195             sums = _mm256_add_epi64(sums, curr_sum);
1196   }
1197   __m128i sum_lo = _mm256_castsi256_si128  (sums);
1198   __m128i sum_hi = _mm256_extracti128_si256(sums,   1);
1199   __m128i sum_3  = _mm_add_epi64           (sum_lo, sum_hi);
1200   __m128i sum_4  = _mm_shuffle_epi32       (sum_3,  _MM_SHUFFLE(1, 0, 3, 2));
1201   __m128i sum_5  = _mm_add_epi64           (sum_3,  sum_4);
1202   __m128i sum_5n = _mm_srl_epi32           (sum_5,  normalize_sum);
1203           sum_5n = _mm_add_epi32           (sum_5n, debias_sum);
1204 
1205   __m256i sum_n  = _mm256_broadcastw_epi16 (sum_5n);
1206 
1207   __m256i accum = zero;
1208   for (i = 0; i < len; i += 32) {
1209     __m256i curr = _mm256_loadu_si256((const __m256i *)(buf + i));
1210 
1211     __m256i curr0    = _mm256_slli_epi16  (curr,  7);
1212     __m256i curr1    = _mm256_srli_epi16  (curr,  1);
1213             curr0    = _mm256_and_si256   (curr0, himask_15);
1214             curr1    = _mm256_and_si256   (curr1, himask_15);
1215 
1216     __m256i dev0     = _mm256_sub_epi16   (curr0, sum_n);
1217     __m256i dev1     = _mm256_sub_epi16   (curr1, sum_n);
1218 
1219     __m256i vars0    = _mm256_madd_epi16  (dev0,  dev0);
1220     __m256i vars1    = _mm256_madd_epi16  (dev1,  dev1);
1221 
1222     __m256i varsum   = _mm256_add_epi32   (vars0, vars1);
1223             varsum   = hsum_epi32_to_epi64(varsum);
1224             accum    = _mm256_add_epi64   (accum, varsum);
1225   }
1226   __m256i accum2 = _mm256_permute4x64_epi64(accum,  _MM_SHUFFLE(1, 0, 3, 2));
1227   __m256i accum3 = _mm256_add_epi64        (accum,  accum2);
1228   __m256i accum4 = _mm256_permute4x64_epi64(accum3, _MM_SHUFFLE(2, 3, 1, 0));
1229   __m256i v_tot  = _mm256_add_epi64        (accum3, accum4);
1230   __m128i vt128  = _mm256_castsi256_si128  (v_tot);
1231 
1232   _mm_storel_epi64((__m128i *)&vars, vt128);
1233 
1234   return (float)vars * varsum_to_f;
1235 }
1236 
1237 #else // INACCURATE_VARIANCE_CALCULATION
1238 
1239 static double pixel_var_avx2(const uint8_t *buf, const uint32_t len)
1240 {
1241   return pixel_var_avx2_largebuf(buf, len);
1242 }
1243 
1244 #endif // !INACCURATE_VARIANCE_CALCULATION
1245 
1246 #endif // KVZ_BIT_DEPTH == 8
1247 #endif //COMPILE_INTEL_AVX2
1248 
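/**
 * \brief Register the AVX2 implementations of the pixel functions with the
 *        strategy selector. The functions are only registered when AVX2
 *        support is compiled in, KVZ_BIT_DEPTH is 8 and the runtime bitdepth
 *        is 8.
 */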
1249 int kvz_strategy_register_picture_avx2(void* opaque, uint8_t bitdepth)
1250 {
1251   bool success = true;
1252 #if COMPILE_INTEL_AVX2
1253 #if KVZ_BIT_DEPTH == 8
1254   // We don't actually use SAD for intra right now, other than 4x4 for
1255   // transform skip, but we might again one day and this is some of the
1256   // simplest code to look at for anyone interested in doing more
1257   // optimizations, so it's worth it to keep this maintained.
1258   if (bitdepth == 8){
1259 
1260     success &= kvz_strategyselector_register(opaque, "reg_sad", "avx2", 40, &kvz_reg_sad_avx2);
1261     success &= kvz_strategyselector_register(opaque, "sad_8x8", "avx2", 40, &sad_8bit_8x8_avx2);
1262     success &= kvz_strategyselector_register(opaque, "sad_16x16", "avx2", 40, &sad_8bit_16x16_avx2);
1263     success &= kvz_strategyselector_register(opaque, "sad_32x32", "avx2", 40, &sad_8bit_32x32_avx2);
1264     success &= kvz_strategyselector_register(opaque, "sad_64x64", "avx2", 40, &sad_8bit_64x64_avx2);
1265 
1266     success &= kvz_strategyselector_register(opaque, "satd_4x4", "avx2", 40, &satd_4x4_8bit_avx2);
1267     success &= kvz_strategyselector_register(opaque, "satd_8x8", "avx2", 40, &satd_8x8_8bit_avx2);
1268     success &= kvz_strategyselector_register(opaque, "satd_16x16", "avx2", 40, &satd_16x16_8bit_avx2);
1269     success &= kvz_strategyselector_register(opaque, "satd_32x32", "avx2", 40, &satd_32x32_8bit_avx2);
1270     success &= kvz_strategyselector_register(opaque, "satd_64x64", "avx2", 40, &satd_64x64_8bit_avx2);
1271 
1272     success &= kvz_strategyselector_register(opaque, "satd_4x4_dual", "avx2", 40, &satd_8bit_4x4_dual_avx2);
1273     success &= kvz_strategyselector_register(opaque, "satd_8x8_dual", "avx2", 40, &satd_8bit_8x8_dual_avx2);
1274     success &= kvz_strategyselector_register(opaque, "satd_16x16_dual", "avx2", 40, &satd_8bit_16x16_dual_avx2);
1275     success &= kvz_strategyselector_register(opaque, "satd_32x32_dual", "avx2", 40, &satd_8bit_32x32_dual_avx2);
1276     success &= kvz_strategyselector_register(opaque, "satd_64x64_dual", "avx2", 40, &satd_8bit_64x64_dual_avx2);
1277     success &= kvz_strategyselector_register(opaque, "satd_any_size", "avx2", 40, &satd_any_size_8bit_avx2);
1278     success &= kvz_strategyselector_register(opaque, "satd_any_size_quad", "avx2", 40, &satd_any_size_quad_avx2);
1279 
1280     success &= kvz_strategyselector_register(opaque, "pixels_calc_ssd", "avx2", 40, &pixels_calc_ssd_avx2);
1281     success &= kvz_strategyselector_register(opaque, "inter_recon_bipred", "avx2", 40, &inter_recon_bipred_avx2);
1282     success &= kvz_strategyselector_register(opaque, "get_optimized_sad", "avx2", 40, &get_optimized_sad_avx2);
1283     success &= kvz_strategyselector_register(opaque, "ver_sad", "avx2", 40, &ver_sad_avx2);
1284     success &= kvz_strategyselector_register(opaque, "hor_sad", "avx2", 40, &hor_sad_avx2);
1285 
1286     success &= kvz_strategyselector_register(opaque, "pixel_var", "avx2", 40, &pixel_var_avx2);
1287 
1288   }
1289 #endif // KVZ_BIT_DEPTH == 8
1290 #endif // COMPILE_INTEL_AVX2
1291   return success;
1292 }
1293