/*
 * Copyright (c) 2017, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <stdio.h>
#include <tmmintrin.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/blend.h"
#include "aom/aom_integer.h"
#include "aom_dsp/x86/synonyms.h"

#include "aom_dsp/x86/masked_sad_intrin_ssse3.h"

// For width a multiple of 16
static INLINE unsigned int masked_sad_ssse3(const uint8_t *src_ptr,
                                            int src_stride,
                                            const uint8_t *a_ptr, int a_stride,
                                            const uint8_t *b_ptr, int b_stride,
                                            const uint8_t *m_ptr, int m_stride,
                                            int width, int height);
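
// A masked SAD blends the two predictors 'a' and 'b' with a 6-bit mask and
// then takes the SAD of the blend against the source. As a rough scalar
// sketch of what each pixel contributes (assuming xx_roundn_epu16 rounds to
// nearest, i.e. (v + (1 << (bits - 1))) >> bits):
//
//   pred = (a * m + b * (64 - m) + 32) >> 6;  // 64 == AOM_BLEND_A64_MAX_ALPHA
//   sad += abs(pred - src);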

#define MASKSADMXN_SSSE3(m, n)                                                \
  unsigned int aom_masked_sad##m##x##n##_ssse3(                               \
      const uint8_t *src, int src_stride, const uint8_t *ref,                 \
      int ref_stride, const uint8_t *second_pred, const uint8_t *msk,         \
      int msk_stride, int invert_mask) {                                      \
    if (!invert_mask)                                                         \
      return masked_sad_ssse3(src, src_stride, ref, ref_stride, second_pred,  \
                              m, msk, msk_stride, m, n);                      \
    else                                                                      \
      return masked_sad_ssse3(src, src_stride, second_pred, m, ref,           \
                              ref_stride, msk, msk_stride, m, n);             \
  }

#define MASKSAD8XN_SSSE3(n)                                                   \
  unsigned int aom_masked_sad8x##n##_ssse3(                                   \
      const uint8_t *src, int src_stride, const uint8_t *ref,                 \
      int ref_stride, const uint8_t *second_pred, const uint8_t *msk,         \
      int msk_stride, int invert_mask) {                                      \
    if (!invert_mask)                                                         \
      return aom_masked_sad8xh_ssse3(src, src_stride, ref, ref_stride,        \
                                     second_pred, 8, msk, msk_stride, n);     \
    else                                                                      \
      return aom_masked_sad8xh_ssse3(src, src_stride, second_pred, 8, ref,    \
                                     ref_stride, msk, msk_stride, n);         \
  }

#define MASKSAD4XN_SSSE3(n)                                                   \
  unsigned int aom_masked_sad4x##n##_ssse3(                                   \
      const uint8_t *src, int src_stride, const uint8_t *ref,                 \
      int ref_stride, const uint8_t *second_pred, const uint8_t *msk,         \
      int msk_stride, int invert_mask) {                                      \
    if (!invert_mask)                                                         \
      return aom_masked_sad4xh_ssse3(src, src_stride, ref, ref_stride,        \
                                     second_pred, 4, msk, msk_stride, n);     \
    else                                                                      \
      return aom_masked_sad4xh_ssse3(src, src_stride, second_pred, 4, ref,    \
                                     ref_stride, msk, msk_stride, n);         \
  }

MASKSADMXN_SSSE3(128, 128)
MASKSADMXN_SSSE3(128, 64)
MASKSADMXN_SSSE3(64, 128)
MASKSADMXN_SSSE3(64, 64)
MASKSADMXN_SSSE3(64, 32)
MASKSADMXN_SSSE3(32, 64)
MASKSADMXN_SSSE3(32, 32)
MASKSADMXN_SSSE3(32, 16)
MASKSADMXN_SSSE3(16, 32)
MASKSADMXN_SSSE3(16, 16)
MASKSADMXN_SSSE3(16, 8)
MASKSAD8XN_SSSE3(16)
MASKSAD8XN_SSSE3(8)
MASKSAD8XN_SSSE3(4)
MASKSAD4XN_SSSE3(8)
MASKSAD4XN_SSSE3(4)
MASKSAD4XN_SSSE3(16)
MASKSADMXN_SSSE3(16, 4)
MASKSAD8XN_SSSE3(32)
MASKSADMXN_SSSE3(32, 8)
MASKSADMXN_SSSE3(16, 64)
MASKSADMXN_SSSE3(64, 16)
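
// Each instantiation above defines a function such as aom_masked_sad16x8_ssse3
// with the prototype declared in aom_dsp_rtcd.h; 'invert_mask' selects which
// of 'ref' and 'second_pred' is weighted by the mask and which by its
// complement.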

static INLINE unsigned int masked_sad_ssse3(const uint8_t *src_ptr,
                                            int src_stride,
                                            const uint8_t *a_ptr, int a_stride,
                                            const uint8_t *b_ptr, int b_stride,
                                            const uint8_t *m_ptr, int m_stride,
                                            int width, int height) {
  int x, y;
  __m128i res = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));

  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 16) {
      const __m128i src = _mm_loadu_si128((const __m128i *)&src_ptr[x]);
      const __m128i a = _mm_loadu_si128((const __m128i *)&a_ptr[x]);
      const __m128i b = _mm_loadu_si128((const __m128i *)&b_ptr[x]);
      const __m128i m = _mm_loadu_si128((const __m128i *)&m_ptr[x]);
      const __m128i m_inv = _mm_sub_epi8(mask_max, m);

      // Calculate 16 predicted pixels.
      // Note that the maximum value of any entry of 'pred_l' or 'pred_r'
      // is 64 * 255, so we have plenty of space to add rounding constants.
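      // 'data_l' interleaves the bytes a[i], b[i] and 'mask_l' interleaves
      // m[i], 64 - m[i], so maddubs yields a[i] * m[i] + b[i] * (64 - m[i])
      // as one 16-bit value per pixel.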
      const __m128i data_l = _mm_unpacklo_epi8(a, b);
      const __m128i mask_l = _mm_unpacklo_epi8(m, m_inv);
      __m128i pred_l = _mm_maddubs_epi16(data_l, mask_l);
      pred_l = xx_roundn_epu16(pred_l, AOM_BLEND_A64_ROUND_BITS);

      const __m128i data_r = _mm_unpackhi_epi8(a, b);
      const __m128i mask_r = _mm_unpackhi_epi8(m, m_inv);
      __m128i pred_r = _mm_maddubs_epi16(data_r, mask_r);
      pred_r = xx_roundn_epu16(pred_r, AOM_BLEND_A64_ROUND_BITS);

      const __m128i pred = _mm_packus_epi16(pred_l, pred_r);
      res = _mm_add_epi32(res, _mm_sad_epu8(pred, src));
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // At this point, we have two 32-bit partial SADs in lanes 0 and 2 of 'res'.
  int32_t sad =
      _mm_cvtsi128_si32(res) + _mm_cvtsi128_si32(_mm_srli_si128(res, 8));
  return (sad + 31) >> 6;
}

unsigned int aom_masked_sad8xh_ssse3(const uint8_t *src_ptr, int src_stride,
                                     const uint8_t *a_ptr, int a_stride,
                                     const uint8_t *b_ptr, int b_stride,
                                     const uint8_t *m_ptr, int m_stride,
                                     int height) {
  int y;
  __m128i res = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));

  for (y = 0; y < height; y += 2) {
    const __m128i src = _mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)src_ptr),
        _mm_loadl_epi64((const __m128i *)&src_ptr[src_stride]));
    const __m128i a0 = _mm_loadl_epi64((const __m128i *)a_ptr);
    const __m128i a1 = _mm_loadl_epi64((const __m128i *)&a_ptr[a_stride]);
    const __m128i b0 = _mm_loadl_epi64((const __m128i *)b_ptr);
    const __m128i b1 = _mm_loadl_epi64((const __m128i *)&b_ptr[b_stride]);
    const __m128i m =
        _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)m_ptr),
                           _mm_loadl_epi64((const __m128i *)&m_ptr[m_stride]));
    const __m128i m_inv = _mm_sub_epi8(mask_max, m);

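    // 'm' carries both rows' masks (row 0 in the low half, row 1 in the high
    // half), so the unpacklo/unpackhi below pair each row's pixels with the
    // matching mask bytes.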
    const __m128i data_l = _mm_unpacklo_epi8(a0, b0);
    const __m128i mask_l = _mm_unpacklo_epi8(m, m_inv);
    __m128i pred_l = _mm_maddubs_epi16(data_l, mask_l);
    pred_l = xx_roundn_epu16(pred_l, AOM_BLEND_A64_ROUND_BITS);

    const __m128i data_r = _mm_unpacklo_epi8(a1, b1);
    const __m128i mask_r = _mm_unpackhi_epi8(m, m_inv);
    __m128i pred_r = _mm_maddubs_epi16(data_r, mask_r);
    pred_r = xx_roundn_epu16(pred_r, AOM_BLEND_A64_ROUND_BITS);

    const __m128i pred = _mm_packus_epi16(pred_l, pred_r);
    res = _mm_add_epi32(res, _mm_sad_epu8(pred, src));

    src_ptr += src_stride * 2;
    a_ptr += a_stride * 2;
    b_ptr += b_stride * 2;
    m_ptr += m_stride * 2;
  }
  int32_t sad =
      _mm_cvtsi128_si32(res) + _mm_cvtsi128_si32(_mm_srli_si128(res, 8));
  return (sad + 31) >> 6;
}

unsigned int aom_masked_sad4xh_ssse3(const uint8_t *src_ptr, int src_stride,
                                     const uint8_t *a_ptr, int a_stride,
                                     const uint8_t *b_ptr, int b_stride,
                                     const uint8_t *m_ptr, int m_stride,
                                     int height) {
  int y;
  __m128i res = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi8((1 << AOM_BLEND_A64_ROUND_BITS));

  for (y = 0; y < height; y += 2) {
    // Load two rows at a time, this seems to be a bit faster
    // than four rows at a time in this case.
    const __m128i src = _mm_unpacklo_epi32(
        _mm_cvtsi32_si128(*(uint32_t *)src_ptr),
        _mm_cvtsi32_si128(*(uint32_t *)&src_ptr[src_stride]));
    const __m128i a =
        _mm_unpacklo_epi32(_mm_cvtsi32_si128(*(uint32_t *)a_ptr),
                           _mm_cvtsi32_si128(*(uint32_t *)&a_ptr[a_stride]));
    const __m128i b =
        _mm_unpacklo_epi32(_mm_cvtsi32_si128(*(uint32_t *)b_ptr),
                           _mm_cvtsi32_si128(*(uint32_t *)&b_ptr[b_stride]));
    const __m128i m =
        _mm_unpacklo_epi32(_mm_cvtsi32_si128(*(uint32_t *)m_ptr),
                           _mm_cvtsi32_si128(*(uint32_t *)&m_ptr[m_stride]));
    const __m128i m_inv = _mm_sub_epi8(mask_max, m);

    const __m128i data = _mm_unpacklo_epi8(a, b);
    const __m128i mask = _mm_unpacklo_epi8(m, m_inv);
    __m128i pred_16bit = _mm_maddubs_epi16(data, mask);
    pred_16bit = xx_roundn_epu16(pred_16bit, AOM_BLEND_A64_ROUND_BITS);

    const __m128i pred = _mm_packus_epi16(pred_16bit, _mm_setzero_si128());
    res = _mm_add_epi32(res, _mm_sad_epu8(pred, src));

    src_ptr += src_stride * 2;
    a_ptr += a_stride * 2;
    b_ptr += b_stride * 2;
    m_ptr += m_stride * 2;
  }
  // At this point, the SAD is stored in lane 0 of 'res'
  int32_t sad = _mm_cvtsi128_si32(res);
  return (sad + 31) >> 6;
}

// For width a multiple of 8
static INLINE unsigned int highbd_masked_sad_ssse3(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m_ptr, int m_stride,
    int width, int height);
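
// The high-bitdepth path operates on 16-bit pixels, so the 8-bit maddubs
// trick above is replaced with _mm_madd_epi16 on interleaved (pixel, mask)
// pairs, and the blend rounding is applied explicitly via 'round_const' and
// a shift by AOM_BLEND_A64_ROUND_BITS.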

#define HIGHBD_MASKSADMXN_SSSE3(m, n)                                          \
  unsigned int aom_highbd_masked_sad##m##x##n##_ssse3(                         \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,                \
      int ref_stride, const uint8_t *second_pred8, const uint8_t *msk,         \
      int msk_stride, int invert_mask) {                                       \
    if (!invert_mask)                                                          \
      return highbd_masked_sad_ssse3(src8, src_stride, ref8, ref_stride,       \
                                     second_pred8, m, msk, msk_stride, m, n);  \
    else                                                                       \
      return highbd_masked_sad_ssse3(src8, src_stride, second_pred8, m, ref8,  \
                                     ref_stride, msk, msk_stride, m, n);       \
  }

#define HIGHBD_MASKSAD4XN_SSSE3(n)                                             \
  unsigned int aom_highbd_masked_sad4x##n##_ssse3(                             \
      const uint8_t *src8, int src_stride, const uint8_t *ref8,                \
      int ref_stride, const uint8_t *second_pred8, const uint8_t *msk,         \
      int msk_stride, int invert_mask) {                                       \
    if (!invert_mask)                                                          \
      return aom_highbd_masked_sad4xh_ssse3(src8, src_stride, ref8,            \
                                            ref_stride, second_pred8, 4, msk,  \
                                            msk_stride, n);                    \
    else                                                                       \
      return aom_highbd_masked_sad4xh_ssse3(src8, src_stride, second_pred8,    \
                                            4, ref8, ref_stride, msk,          \
                                            msk_stride, n);                    \
  }

HIGHBD_MASKSADMXN_SSSE3(128, 128)
HIGHBD_MASKSADMXN_SSSE3(128, 64)
HIGHBD_MASKSADMXN_SSSE3(64, 128)
HIGHBD_MASKSADMXN_SSSE3(64, 64)
HIGHBD_MASKSADMXN_SSSE3(64, 32)
HIGHBD_MASKSADMXN_SSSE3(32, 64)
HIGHBD_MASKSADMXN_SSSE3(32, 32)
HIGHBD_MASKSADMXN_SSSE3(32, 16)
HIGHBD_MASKSADMXN_SSSE3(16, 32)
HIGHBD_MASKSADMXN_SSSE3(16, 16)
HIGHBD_MASKSADMXN_SSSE3(16, 8)
HIGHBD_MASKSADMXN_SSSE3(8, 16)
HIGHBD_MASKSADMXN_SSSE3(8, 8)
HIGHBD_MASKSADMXN_SSSE3(8, 4)
HIGHBD_MASKSAD4XN_SSSE3(8)
HIGHBD_MASKSAD4XN_SSSE3(4)
HIGHBD_MASKSAD4XN_SSSE3(16)
HIGHBD_MASKSADMXN_SSSE3(16, 4)
HIGHBD_MASKSADMXN_SSSE3(8, 32)
HIGHBD_MASKSADMXN_SSSE3(32, 8)
HIGHBD_MASKSADMXN_SSSE3(16, 64)
HIGHBD_MASKSADMXN_SSSE3(64, 16)

static INLINE unsigned int highbd_masked_sad_ssse3(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m_ptr, int m_stride,
    int width, int height) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a_ptr = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b_ptr = CONVERT_TO_SHORTPTR(b8);
  int x, y;
  __m128i res = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i round_const =
      _mm_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m128i one = _mm_set1_epi16(1);
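  // round_const is the rounding term of the blend, i.e. 32, so each pixel is
  // predicted as (a * m + b * (64 - m) + 32) >> 6; 'one' is used later to
  // pair-sum 16-bit absolute differences into 32-bit lanes via madd.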

  for (y = 0; y < height; y++) {
    for (x = 0; x < width; x += 8) {
      const __m128i src = _mm_loadu_si128((const __m128i *)&src_ptr[x]);
      const __m128i a = _mm_loadu_si128((const __m128i *)&a_ptr[x]);
      const __m128i b = _mm_loadu_si128((const __m128i *)&b_ptr[x]);
      // Zero-extend mask to 16 bits
      const __m128i m = _mm_unpacklo_epi8(
          _mm_loadl_epi64((const __m128i *)&m_ptr[x]), _mm_setzero_si128());
      const __m128i m_inv = _mm_sub_epi16(mask_max, m);

      const __m128i data_l = _mm_unpacklo_epi16(a, b);
      const __m128i mask_l = _mm_unpacklo_epi16(m, m_inv);
      __m128i pred_l = _mm_madd_epi16(data_l, mask_l);
      pred_l = _mm_srai_epi32(_mm_add_epi32(pred_l, round_const),
                              AOM_BLEND_A64_ROUND_BITS);

      const __m128i data_r = _mm_unpackhi_epi16(a, b);
      const __m128i mask_r = _mm_unpackhi_epi16(m, m_inv);
      __m128i pred_r = _mm_madd_epi16(data_r, mask_r);
      pred_r = _mm_srai_epi32(_mm_add_epi32(pred_r, round_const),
                              AOM_BLEND_A64_ROUND_BITS);

      // Note: the maximum value in pred_l/r is (2^bd)-1 < 2^15,
      // so it is safe to do signed saturation here.
      const __m128i pred = _mm_packs_epi32(pred_l, pred_r);
      // There is no 16-bit SAD instruction, so we have to synthesize
      // an 8-element SAD. We do this by storing 4 32-bit partial SADs,
      // and accumulating them at the end
      const __m128i diff = _mm_abs_epi16(_mm_sub_epi16(pred, src));
      res = _mm_add_epi32(res, _mm_madd_epi16(diff, one));
    }

    src_ptr += src_stride;
    a_ptr += a_stride;
    b_ptr += b_stride;
    m_ptr += m_stride;
  }
  // At this point, we have four 32-bit partial SADs stored in 'res'.
  res = _mm_hadd_epi32(res, res);
  res = _mm_hadd_epi32(res, res);
  int sad = _mm_cvtsi128_si32(res);
  return (sad + 31) >> 6;
}

unsigned int aom_highbd_masked_sad4xh_ssse3(const uint8_t *src8,
                                            int src_stride, const uint8_t *a8,
                                            int a_stride, const uint8_t *b8,
                                            int b_stride, const uint8_t *m_ptr,
                                            int m_stride, int height) {
  const uint16_t *src_ptr = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a_ptr = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b_ptr = CONVERT_TO_SHORTPTR(b8);
  int y;
  __m128i res = _mm_setzero_si128();
  const __m128i mask_max = _mm_set1_epi16((1 << AOM_BLEND_A64_ROUND_BITS));
  const __m128i round_const =
      _mm_set1_epi32((1 << AOM_BLEND_A64_ROUND_BITS) >> 1);
  const __m128i one = _mm_set1_epi16(1);

  for (y = 0; y < height; y += 2) {
    const __m128i src = _mm_unpacklo_epi64(
        _mm_loadl_epi64((const __m128i *)src_ptr),
        _mm_loadl_epi64((const __m128i *)&src_ptr[src_stride]));
    const __m128i a =
        _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)a_ptr),
                           _mm_loadl_epi64((const __m128i *)&a_ptr[a_stride]));
    const __m128i b =
        _mm_unpacklo_epi64(_mm_loadl_epi64((const __m128i *)b_ptr),
                           _mm_loadl_epi64((const __m128i *)&b_ptr[b_stride]));
    // Zero-extend mask to 16 bits
    const __m128i m = _mm_unpacklo_epi8(
        _mm_unpacklo_epi32(
            _mm_cvtsi32_si128(*(const uint32_t *)m_ptr),
            _mm_cvtsi32_si128(*(const uint32_t *)&m_ptr[m_stride])),
        _mm_setzero_si128());
    const __m128i m_inv = _mm_sub_epi16(mask_max, m);

    const __m128i data_l = _mm_unpacklo_epi16(a, b);
    const __m128i mask_l = _mm_unpacklo_epi16(m, m_inv);
    __m128i pred_l = _mm_madd_epi16(data_l, mask_l);
    pred_l = _mm_srai_epi32(_mm_add_epi32(pred_l, round_const),
                            AOM_BLEND_A64_ROUND_BITS);

    const __m128i data_r = _mm_unpackhi_epi16(a, b);
    const __m128i mask_r = _mm_unpackhi_epi16(m, m_inv);
    __m128i pred_r = _mm_madd_epi16(data_r, mask_r);
    pred_r = _mm_srai_epi32(_mm_add_epi32(pred_r, round_const),
                            AOM_BLEND_A64_ROUND_BITS);

    const __m128i pred = _mm_packs_epi32(pred_l, pred_r);
    const __m128i diff = _mm_abs_epi16(_mm_sub_epi16(pred, src));
    res = _mm_add_epi32(res, _mm_madd_epi16(diff, one));

    src_ptr += src_stride * 2;
    a_ptr += a_stride * 2;
    b_ptr += b_stride * 2;
    m_ptr += m_stride * 2;
  }
  res = _mm_hadd_epi32(res, res);
  res = _mm_hadd_epi32(res, res);
  int sad = _mm_cvtsi128_si32(res);
  return (sad + 31) >> 6;
}