/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <immintrin.h>
#include <smmintrin.h>

#include "aom_dsp/x86/synonyms.h"
#include "aom_dsp/x86/synonyms_avx2.h"
#include "aom/aom_integer.h"

#include "av1/common/reconinter.h"

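// Wedge mask values m[] lie in [0, MAX_MASK_VALUE]; with
// WEDGE_WEIGHT_BITS == 6 that is [0, 64].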
#define MAX_MASK_VALUE (1 << WEDGE_WEIGHT_BITS)

/**
 * See av1_wedge_sse_from_residuals_c
 */
uint64_t av1_wedge_sse_from_residuals_avx2(const int16_t *r1, const int16_t *d,
                                           const uint8_t *m, int N) {
  int n = -N;

  uint64_t csse;

  const __m256i v_mask_max_w = _mm256_set1_epi16(MAX_MASK_VALUE);
  const __m256i v_zext_q = yy_set1_64_from_32i(0xffffffff);

  __m256i v_acc0_q = _mm256_setzero_si256();

  assert(N % 64 == 0);

  r1 += N;
  d += N;
  m += N;

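  // Walk the buffers with a negative index that counts up to zero,
  // consuming 16 coefficients per iteration.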
  do {
    const __m256i v_r0_w = _mm256_lddqu_si256((__m256i *)(r1 + n));
    const __m256i v_d0_w = _mm256_lddqu_si256((__m256i *)(d + n));
    const __m128i v_m01_b = _mm_lddqu_si128((__m128i *)(m + n));

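    // Interleave d with r1 and m with MAX_MASK_VALUE so that each madd
    // below computes t = m * d + MAX_MASK_VALUE * r1 per coefficient.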
    const __m256i v_rd0l_w = _mm256_unpacklo_epi16(v_d0_w, v_r0_w);
    const __m256i v_rd0h_w = _mm256_unpackhi_epi16(v_d0_w, v_r0_w);
    const __m256i v_m0_w = _mm256_cvtepu8_epi16(v_m01_b);

    const __m256i v_m0l_w = _mm256_unpacklo_epi16(v_m0_w, v_mask_max_w);
    const __m256i v_m0h_w = _mm256_unpackhi_epi16(v_m0_w, v_mask_max_w);

    const __m256i v_t0l_d = _mm256_madd_epi16(v_rd0l_w, v_m0l_w);
    const __m256i v_t0h_d = _mm256_madd_epi16(v_rd0h_w, v_m0h_w);

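    // Saturating pack clamps each 32-bit t to the int16 range, matching the
    // clamp() in the C reference.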
    const __m256i v_t0_w = _mm256_packs_epi32(v_t0l_d, v_t0h_d);

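    // Square and pairwise-add: each 32-bit lane now holds t[i]^2 + t[i+1]^2.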
    const __m256i v_sq0_d = _mm256_madd_epi16(v_t0_w, v_t0_w);

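    // Zero-extend the even and odd 32-bit lanes to 64 bits (the squared
    // sums are non-negative) before accumulating, so the running sum cannot
    // overflow.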
    const __m256i v_sum0_q = _mm256_add_epi64(
        _mm256_and_si256(v_sq0_d, v_zext_q), _mm256_srli_epi64(v_sq0_d, 32));

    v_acc0_q = _mm256_add_epi64(v_acc0_q, v_sum0_q);

    n += 16;
  } while (n);

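  // Horizontally reduce the four 64-bit lanes to a single scalar sum.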
  v_acc0_q = _mm256_add_epi64(v_acc0_q, _mm256_srli_si256(v_acc0_q, 8));
  __m128i v_acc_q_0 = _mm256_castsi256_si128(v_acc0_q);
  __m128i v_acc_q_1 = _mm256_extracti128_si256(v_acc0_q, 1);
  v_acc_q_0 = _mm_add_epi64(v_acc_q_0, v_acc_q_1);
#if ARCH_X86_64
  csse = (uint64_t)_mm_extract_epi64(v_acc_q_0, 0);
#else
  xx_storel_64(&csse, v_acc_q_0);
#endif

  return ROUND_POWER_OF_TWO(csse, 2 * WEDGE_WEIGHT_BITS);
}

/**
 * See av1_wedge_sign_from_residuals_c
 */
int8_t av1_wedge_sign_from_residuals_avx2(const int16_t *ds, const uint8_t *m,
                                          int N, int64_t limit) {
  int64_t acc;
  __m256i v_acc0_d = _mm256_setzero_si256();

  // The 32-bit accumulators, with mask values in [0, 64], limit the input
  // size to 8192 elements. Larger inputs could overflow, although that is
  // practically impossible on real video input.
  assert(N < 8192);
  assert(N % 64 == 0);

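  // Each iteration accumulates the masked dot product ds[i] * m[i] for 64
  // coefficients into eight 32-bit lanes.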
  do {
    const __m256i v_m01_b = _mm256_lddqu_si256((__m256i *)(m));
    const __m256i v_m23_b = _mm256_lddqu_si256((__m256i *)(m + 32));

    const __m256i v_d0_w = _mm256_lddqu_si256((__m256i *)(ds));
    const __m256i v_d1_w = _mm256_lddqu_si256((__m256i *)(ds + 16));
    const __m256i v_d2_w = _mm256_lddqu_si256((__m256i *)(ds + 32));
    const __m256i v_d3_w = _mm256_lddqu_si256((__m256i *)(ds + 48));

    const __m256i v_m0_w =
        _mm256_cvtepu8_epi16(_mm256_castsi256_si128(v_m01_b));
    const __m256i v_m1_w =
        _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v_m01_b, 1));
    const __m256i v_m2_w =
        _mm256_cvtepu8_epi16(_mm256_castsi256_si128(v_m23_b));
    const __m256i v_m3_w =
        _mm256_cvtepu8_epi16(_mm256_extracti128_si256(v_m23_b, 1));

    const __m256i v_p0_d = _mm256_madd_epi16(v_d0_w, v_m0_w);
    const __m256i v_p1_d = _mm256_madd_epi16(v_d1_w, v_m1_w);
    const __m256i v_p2_d = _mm256_madd_epi16(v_d2_w, v_m2_w);
    const __m256i v_p3_d = _mm256_madd_epi16(v_d3_w, v_m3_w);

    const __m256i v_p01_d = _mm256_add_epi32(v_p0_d, v_p1_d);
    const __m256i v_p23_d = _mm256_add_epi32(v_p2_d, v_p3_d);

    const __m256i v_p0123_d = _mm256_add_epi32(v_p01_d, v_p23_d);

    v_acc0_d = _mm256_add_epi32(v_acc0_d, v_p0123_d);

    ds += 64;
    m += 64;

    N -= 64;
  } while (N);

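  // Sign-extend the 32-bit partial sums to 64 bits and pairwise-add them
  // before the final horizontal reduction.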
  __m256i v_sign_d = _mm256_srai_epi32(v_acc0_d, 31);
  v_acc0_d = _mm256_add_epi64(_mm256_unpacklo_epi32(v_acc0_d, v_sign_d),
                              _mm256_unpackhi_epi32(v_acc0_d, v_sign_d));

  __m256i v_acc_q = _mm256_add_epi64(v_acc0_d, _mm256_srli_si256(v_acc0_d, 8));

  __m128i v_acc_q_0 = _mm256_castsi256_si128(v_acc_q);
  __m128i v_acc_q_1 = _mm256_extracti128_si256(v_acc_q, 1);
  v_acc_q_0 = _mm_add_epi64(v_acc_q_0, v_acc_q_1);

#if ARCH_X86_64
  acc = _mm_extract_epi64(v_acc_q_0, 0);
#else
  xx_storel_64(&acc, v_acc_q_0);
#endif

  return acc > limit;
}

/**
 * See av1_wedge_compute_delta_squares_c
 */
void av1_wedge_compute_delta_squares_avx2(int16_t *d, const int16_t *a,
                                          const int16_t *b, int N) {
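  // As 16-bit lanes this constant is {1, -1, 1, -1, ...}; it is used with
  // _mm256_sign_epi16 to negate every other coefficient.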
  const __m256i v_neg_w = _mm256_set1_epi32(0xffff0001);

  assert(N % 64 == 0);

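  // Each iteration computes d[i] = a[i]^2 - b[i]^2, saturated to int16, for
  // 64 coefficients.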
  do {
    const __m256i v_a0_w = _mm256_lddqu_si256((__m256i *)(a));
    const __m256i v_b0_w = _mm256_lddqu_si256((__m256i *)(b));
    const __m256i v_a1_w = _mm256_lddqu_si256((__m256i *)(a + 16));
    const __m256i v_b1_w = _mm256_lddqu_si256((__m256i *)(b + 16));
    const __m256i v_a2_w = _mm256_lddqu_si256((__m256i *)(a + 32));
    const __m256i v_b2_w = _mm256_lddqu_si256((__m256i *)(b + 32));
    const __m256i v_a3_w = _mm256_lddqu_si256((__m256i *)(a + 48));
    const __m256i v_b3_w = _mm256_lddqu_si256((__m256i *)(b + 48));

    const __m256i v_ab0l_w = _mm256_unpacklo_epi16(v_a0_w, v_b0_w);
    const __m256i v_ab0h_w = _mm256_unpackhi_epi16(v_a0_w, v_b0_w);
    const __m256i v_ab1l_w = _mm256_unpacklo_epi16(v_a1_w, v_b1_w);
    const __m256i v_ab1h_w = _mm256_unpackhi_epi16(v_a1_w, v_b1_w);
    const __m256i v_ab2l_w = _mm256_unpacklo_epi16(v_a2_w, v_b2_w);
    const __m256i v_ab2h_w = _mm256_unpackhi_epi16(v_a2_w, v_b2_w);
    const __m256i v_ab3l_w = _mm256_unpacklo_epi16(v_a3_w, v_b3_w);
    const __m256i v_ab3h_w = _mm256_unpackhi_epi16(v_a3_w, v_b3_w);

    // Negate the top word of each (a, b) pair so that the madd below
    // computes a * a - b * b per coefficient.
    const __m256i v_abl0n_w = _mm256_sign_epi16(v_ab0l_w, v_neg_w);
    const __m256i v_abh0n_w = _mm256_sign_epi16(v_ab0h_w, v_neg_w);
    const __m256i v_abl1n_w = _mm256_sign_epi16(v_ab1l_w, v_neg_w);
    const __m256i v_abh1n_w = _mm256_sign_epi16(v_ab1h_w, v_neg_w);
    const __m256i v_abl2n_w = _mm256_sign_epi16(v_ab2l_w, v_neg_w);
    const __m256i v_abh2n_w = _mm256_sign_epi16(v_ab2h_w, v_neg_w);
    const __m256i v_abl3n_w = _mm256_sign_epi16(v_ab3l_w, v_neg_w);
    const __m256i v_abh3n_w = _mm256_sign_epi16(v_ab3h_w, v_neg_w);

    const __m256i v_r0l_w = _mm256_madd_epi16(v_ab0l_w, v_abl0n_w);
    const __m256i v_r0h_w = _mm256_madd_epi16(v_ab0h_w, v_abh0n_w);
    const __m256i v_r1l_w = _mm256_madd_epi16(v_ab1l_w, v_abl1n_w);
    const __m256i v_r1h_w = _mm256_madd_epi16(v_ab1h_w, v_abh1n_w);
    const __m256i v_r2l_w = _mm256_madd_epi16(v_ab2l_w, v_abl2n_w);
    const __m256i v_r2h_w = _mm256_madd_epi16(v_ab2h_w, v_abh2n_w);
    const __m256i v_r3l_w = _mm256_madd_epi16(v_ab3l_w, v_abl3n_w);
    const __m256i v_r3h_w = _mm256_madd_epi16(v_ab3h_w, v_abh3n_w);

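    // Saturating pack clamps each difference to int16, matching the clamp()
    // in the C reference.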
    const __m256i v_r0_w = _mm256_packs_epi32(v_r0l_w, v_r0h_w);
    const __m256i v_r1_w = _mm256_packs_epi32(v_r1l_w, v_r1h_w);
    const __m256i v_r2_w = _mm256_packs_epi32(v_r2l_w, v_r2h_w);
    const __m256i v_r3_w = _mm256_packs_epi32(v_r3l_w, v_r3h_w);

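    // Aligned stores: _mm256_store_si256 requires d to be 32-byte aligned.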
    _mm256_store_si256((__m256i *)(d), v_r0_w);
    _mm256_store_si256((__m256i *)(d + 16), v_r1_w);
    _mm256_store_si256((__m256i *)(d + 32), v_r2_w);
    _mm256_store_si256((__m256i *)(d + 48), v_r3_w);

    a += 64;
    b += 64;
    d += 64;
    N -= 64;
  } while (N);
}