1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/dsp/distance_weighted_blend.h"
16 #include "src/utils/cpu.h"
17
18 #if LIBGAV1_TARGETING_SSE4_1
19
20 #include <xmmintrin.h>
21
22 #include <cassert>
23 #include <cstddef>
24 #include <cstdint>
25
26 #include "src/dsp/constants.h"
27 #include "src/dsp/dsp.h"
28 #include "src/dsp/x86/common_sse4.h"
29 #include "src/utils/common.h"
30
31 namespace libgav1 {
32 namespace dsp {
33 namespace low_bitdepth {
34 namespace {
35
36 constexpr int kInterPostRoundBit = 4;
37
ComputeWeightedAverage8(const __m128i & pred0,const __m128i & pred1,const __m128i & weights)38 inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
39 const __m128i& pred1,
40 const __m128i& weights) {
41 // TODO(https://issuetracker.google.com/issues/150325685): Investigate range.
42 const __m128i preds_lo = _mm_unpacklo_epi16(pred0, pred1);
43 const __m128i mult_lo = _mm_madd_epi16(preds_lo, weights);
44 const __m128i result_lo =
45 RightShiftWithRounding_S32(mult_lo, kInterPostRoundBit + 4);
46
47 const __m128i preds_hi = _mm_unpackhi_epi16(pred0, pred1);
48 const __m128i mult_hi = _mm_madd_epi16(preds_hi, weights);
49 const __m128i result_hi =
50 RightShiftWithRounding_S32(mult_hi, kInterPostRoundBit + 4);
51
52 return _mm_packs_epi32(result_lo, result_hi);
53 }
54
55 template <int height>
DistanceWeightedBlend4xH_SSE4_1(const int16_t * LIBGAV1_RESTRICT pred_0,const int16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight_0,const uint8_t weight_1,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)56 inline void DistanceWeightedBlend4xH_SSE4_1(
57 const int16_t* LIBGAV1_RESTRICT pred_0,
58 const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
59 const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest,
60 const ptrdiff_t dest_stride) {
61 auto* dst = static_cast<uint8_t*>(dest);
62 const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
63
64 for (int y = 0; y < height; y += 4) {
65 // TODO(b/150326556): Use larger loads.
66 const __m128i src_00 = LoadLo8(pred_0);
67 const __m128i src_10 = LoadLo8(pred_1);
68 pred_0 += 4;
69 pred_1 += 4;
70 __m128i src_0 = LoadHi8(src_00, pred_0);
71 __m128i src_1 = LoadHi8(src_10, pred_1);
72 pred_0 += 4;
73 pred_1 += 4;
74 const __m128i res0 = ComputeWeightedAverage8(src_0, src_1, weights);
75
76 const __m128i src_01 = LoadLo8(pred_0);
77 const __m128i src_11 = LoadLo8(pred_1);
78 pred_0 += 4;
79 pred_1 += 4;
80 src_0 = LoadHi8(src_01, pred_0);
81 src_1 = LoadHi8(src_11, pred_1);
82 pred_0 += 4;
83 pred_1 += 4;
84 const __m128i res1 = ComputeWeightedAverage8(src_0, src_1, weights);
85
86 const __m128i result_pixels = _mm_packus_epi16(res0, res1);
87 Store4(dst, result_pixels);
88 dst += dest_stride;
89 const int result_1 = _mm_extract_epi32(result_pixels, 1);
90 memcpy(dst, &result_1, sizeof(result_1));
91 dst += dest_stride;
92 const int result_2 = _mm_extract_epi32(result_pixels, 2);
93 memcpy(dst, &result_2, sizeof(result_2));
94 dst += dest_stride;
95 const int result_3 = _mm_extract_epi32(result_pixels, 3);
96 memcpy(dst, &result_3, sizeof(result_3));
97 dst += dest_stride;
98 }
99 }
100
101 template <int height>
DistanceWeightedBlend8xH_SSE4_1(const int16_t * LIBGAV1_RESTRICT pred_0,const int16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight_0,const uint8_t weight_1,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)102 inline void DistanceWeightedBlend8xH_SSE4_1(
103 const int16_t* LIBGAV1_RESTRICT pred_0,
104 const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
105 const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest,
106 const ptrdiff_t dest_stride) {
107 auto* dst = static_cast<uint8_t*>(dest);
108 const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
109
110 for (int y = 0; y < height; y += 2) {
111 const __m128i src_00 = LoadAligned16(pred_0);
112 const __m128i src_10 = LoadAligned16(pred_1);
113 pred_0 += 8;
114 pred_1 += 8;
115 const __m128i res0 = ComputeWeightedAverage8(src_00, src_10, weights);
116
117 const __m128i src_01 = LoadAligned16(pred_0);
118 const __m128i src_11 = LoadAligned16(pred_1);
119 pred_0 += 8;
120 pred_1 += 8;
121 const __m128i res1 = ComputeWeightedAverage8(src_01, src_11, weights);
122
123 const __m128i result_pixels = _mm_packus_epi16(res0, res1);
124 StoreLo8(dst, result_pixels);
125 dst += dest_stride;
126 StoreHi8(dst, result_pixels);
127 dst += dest_stride;
128 }
129 }
130
DistanceWeightedBlendLarge_SSE4_1(const int16_t * LIBGAV1_RESTRICT pred_0,const int16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)131 inline void DistanceWeightedBlendLarge_SSE4_1(
132 const int16_t* LIBGAV1_RESTRICT pred_0,
133 const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
134 const uint8_t weight_1, const int width, const int height,
135 void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
136 auto* dst = static_cast<uint8_t*>(dest);
137 const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
138
139 int y = height;
140 do {
141 int x = 0;
142 do {
143 const __m128i src_0_lo = LoadAligned16(pred_0 + x);
144 const __m128i src_1_lo = LoadAligned16(pred_1 + x);
145 const __m128i res_lo =
146 ComputeWeightedAverage8(src_0_lo, src_1_lo, weights);
147
148 const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8);
149 const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8);
150 const __m128i res_hi =
151 ComputeWeightedAverage8(src_0_hi, src_1_hi, weights);
152
153 StoreUnaligned16(dst + x, _mm_packus_epi16(res_lo, res_hi));
154 x += 16;
155 } while (x < width);
156 dst += dest_stride;
157 pred_0 += width;
158 pred_1 += width;
159 } while (--y != 0);
160 }
161
DistanceWeightedBlend_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)162 void DistanceWeightedBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
163 const void* LIBGAV1_RESTRICT prediction_1,
164 const uint8_t weight_0,
165 const uint8_t weight_1, const int width,
166 const int height,
167 void* LIBGAV1_RESTRICT const dest,
168 const ptrdiff_t dest_stride) {
169 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
170 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
171 if (width == 4) {
172 if (height == 4) {
173 DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
174 dest, dest_stride);
175 } else if (height == 8) {
176 DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
177 dest, dest_stride);
178 } else {
179 assert(height == 16);
180 DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
181 dest, dest_stride);
182 }
183 return;
184 }
185
186 if (width == 8) {
187 switch (height) {
188 case 4:
189 DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
190 dest, dest_stride);
191 return;
192 case 8:
193 DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
194 dest, dest_stride);
195 return;
196 case 16:
197 DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
198 dest, dest_stride);
199 return;
200 default:
201 assert(height == 32);
202 DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1,
203 dest, dest_stride);
204
205 return;
206 }
207 }
208
209 DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width,
210 height, dest, dest_stride);
211 }
212
Init8bpp()213 void Init8bpp() {
214 Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
215 assert(dsp != nullptr);
216 #if DSP_ENABLED_8BPP_SSE4_1(DistanceWeightedBlend)
217 dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
218 #endif
219 }
220
221 } // namespace
222 } // namespace low_bitdepth
223
224 #if LIBGAV1_MAX_BITDEPTH >= 10
225 namespace high_bitdepth {
226 namespace {
227
228 constexpr int kMax10bppSample = (1 << 10) - 1;
229 constexpr int kInterPostRoundBit = 4;
230
ComputeWeightedAverage8(const __m128i & pred0,const __m128i & pred1,const __m128i & weight0,const __m128i & weight1)231 inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
232 const __m128i& pred1,
233 const __m128i& weight0,
234 const __m128i& weight1) {
235 // This offset is a combination of round_factor and round_offset
236 // which are to be added and subtracted respectively.
237 // Here kInterPostRoundBit + 4 is considering bitdepth=10.
238 constexpr int offset =
239 (1 << ((kInterPostRoundBit + 4) - 1)) - (kCompoundOffset << 4);
240 const __m128i zero = _mm_setzero_si128();
241 const __m128i bias = _mm_set1_epi32(offset);
242 const __m128i clip_high = _mm_set1_epi16(kMax10bppSample);
243
244 __m128i prediction0 = _mm_cvtepu16_epi32(pred0);
245 __m128i mult0 = _mm_mullo_epi32(prediction0, weight0);
246 __m128i prediction1 = _mm_cvtepu16_epi32(pred1);
247 __m128i mult1 = _mm_mullo_epi32(prediction1, weight1);
248 __m128i sum = _mm_add_epi32(mult0, mult1);
249 sum = _mm_add_epi32(sum, bias);
250 const __m128i result0 = _mm_srai_epi32(sum, kInterPostRoundBit + 4);
251
252 prediction0 = _mm_unpackhi_epi16(pred0, zero);
253 mult0 = _mm_mullo_epi32(prediction0, weight0);
254 prediction1 = _mm_unpackhi_epi16(pred1, zero);
255 mult1 = _mm_mullo_epi32(prediction1, weight1);
256 sum = _mm_add_epi32(mult0, mult1);
257 sum = _mm_add_epi32(sum, bias);
258 const __m128i result1 = _mm_srai_epi32(sum, kInterPostRoundBit + 4);
259 const __m128i pack = _mm_packus_epi32(result0, result1);
260
261 return _mm_min_epi16(pack, clip_high);
262 }
263
264 template <int height>
DistanceWeightedBlend4xH_SSE4_1(const uint16_t * LIBGAV1_RESTRICT pred_0,const uint16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight_0,const uint8_t weight_1,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)265 inline void DistanceWeightedBlend4xH_SSE4_1(
266 const uint16_t* LIBGAV1_RESTRICT pred_0,
267 const uint16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
268 const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest,
269 const ptrdiff_t dest_stride) {
270 auto* dst = static_cast<uint16_t*>(dest);
271 const __m128i weight0 = _mm_set1_epi32(weight_0);
272 const __m128i weight1 = _mm_set1_epi32(weight_1);
273
274 int y = height;
275 do {
276 const __m128i src_00 = LoadLo8(pred_0);
277 const __m128i src_10 = LoadLo8(pred_1);
278 pred_0 += 4;
279 pred_1 += 4;
280 __m128i src_0 = LoadHi8(src_00, pred_0);
281 __m128i src_1 = LoadHi8(src_10, pred_1);
282 pred_0 += 4;
283 pred_1 += 4;
284 const __m128i res0 =
285 ComputeWeightedAverage8(src_0, src_1, weight0, weight1);
286
287 const __m128i src_01 = LoadLo8(pred_0);
288 const __m128i src_11 = LoadLo8(pred_1);
289 pred_0 += 4;
290 pred_1 += 4;
291 src_0 = LoadHi8(src_01, pred_0);
292 src_1 = LoadHi8(src_11, pred_1);
293 pred_0 += 4;
294 pred_1 += 4;
295 const __m128i res1 =
296 ComputeWeightedAverage8(src_0, src_1, weight0, weight1);
297
298 StoreLo8(dst, res0);
299 dst += dest_stride;
300 StoreHi8(dst, res0);
301 dst += dest_stride;
302 StoreLo8(dst, res1);
303 dst += dest_stride;
304 StoreHi8(dst, res1);
305 dst += dest_stride;
306 y -= 4;
307 } while (y != 0);
308 }
309
310 template <int height>
DistanceWeightedBlend8xH_SSE4_1(const uint16_t * LIBGAV1_RESTRICT pred_0,const uint16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight_0,const uint8_t weight_1,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)311 inline void DistanceWeightedBlend8xH_SSE4_1(
312 const uint16_t* LIBGAV1_RESTRICT pred_0,
313 const uint16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
314 const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest,
315 const ptrdiff_t dest_stride) {
316 auto* dst = static_cast<uint16_t*>(dest);
317 const __m128i weight0 = _mm_set1_epi32(weight_0);
318 const __m128i weight1 = _mm_set1_epi32(weight_1);
319
320 int y = height;
321 do {
322 const __m128i src_00 = LoadAligned16(pred_0);
323 const __m128i src_10 = LoadAligned16(pred_1);
324 pred_0 += 8;
325 pred_1 += 8;
326 const __m128i res0 =
327 ComputeWeightedAverage8(src_00, src_10, weight0, weight1);
328
329 const __m128i src_01 = LoadAligned16(pred_0);
330 const __m128i src_11 = LoadAligned16(pred_1);
331 pred_0 += 8;
332 pred_1 += 8;
333 const __m128i res1 =
334 ComputeWeightedAverage8(src_01, src_11, weight0, weight1);
335
336 StoreUnaligned16(dst, res0);
337 dst += dest_stride;
338 StoreUnaligned16(dst, res1);
339 dst += dest_stride;
340 y -= 2;
341 } while (y != 0);
342 }
343
DistanceWeightedBlendLarge_SSE4_1(const uint16_t * LIBGAV1_RESTRICT pred_0,const uint16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)344 inline void DistanceWeightedBlendLarge_SSE4_1(
345 const uint16_t* LIBGAV1_RESTRICT pred_0,
346 const uint16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
347 const uint8_t weight_1, const int width, const int height,
348 void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
349 auto* dst = static_cast<uint16_t*>(dest);
350 const __m128i weight0 = _mm_set1_epi32(weight_0);
351 const __m128i weight1 = _mm_set1_epi32(weight_1);
352
353 int y = height;
354 do {
355 int x = 0;
356 do {
357 const __m128i src_0_lo = LoadAligned16(pred_0 + x);
358 const __m128i src_1_lo = LoadAligned16(pred_1 + x);
359 const __m128i res_lo =
360 ComputeWeightedAverage8(src_0_lo, src_1_lo, weight0, weight1);
361
362 const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8);
363 const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8);
364 const __m128i res_hi =
365 ComputeWeightedAverage8(src_0_hi, src_1_hi, weight0, weight1);
366
367 StoreUnaligned16(dst + x, res_lo);
368 x += 8;
369 StoreUnaligned16(dst + x, res_hi);
370 x += 8;
371 } while (x < width);
372 dst += dest_stride;
373 pred_0 += width;
374 pred_1 += width;
375 } while (--y != 0);
376 }
377
DistanceWeightedBlend_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)378 void DistanceWeightedBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
379 const void* LIBGAV1_RESTRICT prediction_1,
380 const uint8_t weight_0,
381 const uint8_t weight_1, const int width,
382 const int height,
383 void* LIBGAV1_RESTRICT const dest,
384 const ptrdiff_t dest_stride) {
385 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
386 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
387 const ptrdiff_t dst_stride = dest_stride / sizeof(*pred_0);
388 if (width == 4) {
389 if (height == 4) {
390 DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
391 dest, dst_stride);
392 } else if (height == 8) {
393 DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
394 dest, dst_stride);
395 } else {
396 assert(height == 16);
397 DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
398 dest, dst_stride);
399 }
400 return;
401 }
402
403 if (width == 8) {
404 switch (height) {
405 case 4:
406 DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
407 dest, dst_stride);
408 return;
409 case 8:
410 DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
411 dest, dst_stride);
412 return;
413 case 16:
414 DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
415 dest, dst_stride);
416 return;
417 default:
418 assert(height == 32);
419 DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1,
420 dest, dst_stride);
421
422 return;
423 }
424 }
425
426 DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width,
427 height, dest, dst_stride);
428 }
429
Init10bpp()430 void Init10bpp() {
431 Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
432 assert(dsp != nullptr);
433 #if DSP_ENABLED_10BPP_SSE4_1(DistanceWeightedBlend)
434 dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
435 #endif
436 }
437
438 } // namespace
439 } // namespace high_bitdepth
440 #endif // LIBGAV1_MAX_BITDEPTH >= 10
441
// Runs the 8bpp initializer and, when LIBGAV1_MAX_BITDEPTH is at least 10, the
// 10bpp initializer as well.
void DistanceWeightedBlendInit_SSE4_1() {
  low_bitdepth::Init8bpp();
#if LIBGAV1_MAX_BITDEPTH >= 10
  high_bitdepth::Init10bpp();
#endif
}
448
449 } // namespace dsp
450 } // namespace libgav1
451
452 #else // !LIBGAV1_TARGETING_SSE4_1
453
454 namespace libgav1 {
455 namespace dsp {
456
// No-op definition used when LIBGAV1_TARGETING_SSE4_1 is false, so the symbol
// still exists for callers.
void DistanceWeightedBlendInit_SSE4_1() {}
458
459 } // namespace dsp
460 } // namespace libgav1
461 #endif // LIBGAV1_TARGETING_SSE4_1
462