1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/distance_weighted_blend.h"
16 #include "src/utils/cpu.h"
17 
18 #if LIBGAV1_TARGETING_SSE4_1
19 
20 #include <xmmintrin.h>
21 
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
25 
26 #include "src/dsp/constants.h"
27 #include "src/dsp/dsp.h"
28 #include "src/dsp/x86/common_sse4.h"
29 #include "src/utils/common.h"
30 
31 namespace libgav1 {
32 namespace dsp {
33 namespace low_bitdepth {
34 namespace {
35 
36 constexpr int kInterPostRoundBit = 4;
37 
ComputeWeightedAverage8(const __m128i & pred0,const __m128i & pred1,const __m128i & weights)38 inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
39                                        const __m128i& pred1,
40                                        const __m128i& weights) {
41   // TODO(https://issuetracker.google.com/issues/150325685): Investigate range.
42   const __m128i preds_lo = _mm_unpacklo_epi16(pred0, pred1);
43   const __m128i mult_lo = _mm_madd_epi16(preds_lo, weights);
44   const __m128i result_lo =
45       RightShiftWithRounding_S32(mult_lo, kInterPostRoundBit + 4);
46 
47   const __m128i preds_hi = _mm_unpackhi_epi16(pred0, pred1);
48   const __m128i mult_hi = _mm_madd_epi16(preds_hi, weights);
49   const __m128i result_hi =
50       RightShiftWithRounding_S32(mult_hi, kInterPostRoundBit + 4);
51 
52   return _mm_packs_epi32(result_lo, result_hi);
53 }
54 
55 template <int height>
DistanceWeightedBlend4xH_SSE4_1(const int16_t * LIBGAV1_RESTRICT pred_0,const int16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight_0,const uint8_t weight_1,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)56 inline void DistanceWeightedBlend4xH_SSE4_1(
57     const int16_t* LIBGAV1_RESTRICT pred_0,
58     const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
59     const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest,
60     const ptrdiff_t dest_stride) {
61   auto* dst = static_cast<uint8_t*>(dest);
62   const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
63 
64   for (int y = 0; y < height; y += 4) {
65     // TODO(b/150326556): Use larger loads.
66     const __m128i src_00 = LoadLo8(pred_0);
67     const __m128i src_10 = LoadLo8(pred_1);
68     pred_0 += 4;
69     pred_1 += 4;
70     __m128i src_0 = LoadHi8(src_00, pred_0);
71     __m128i src_1 = LoadHi8(src_10, pred_1);
72     pred_0 += 4;
73     pred_1 += 4;
74     const __m128i res0 = ComputeWeightedAverage8(src_0, src_1, weights);
75 
76     const __m128i src_01 = LoadLo8(pred_0);
77     const __m128i src_11 = LoadLo8(pred_1);
78     pred_0 += 4;
79     pred_1 += 4;
80     src_0 = LoadHi8(src_01, pred_0);
81     src_1 = LoadHi8(src_11, pred_1);
82     pred_0 += 4;
83     pred_1 += 4;
84     const __m128i res1 = ComputeWeightedAverage8(src_0, src_1, weights);
85 
86     const __m128i result_pixels = _mm_packus_epi16(res0, res1);
87     Store4(dst, result_pixels);
88     dst += dest_stride;
89     const int result_1 = _mm_extract_epi32(result_pixels, 1);
90     memcpy(dst, &result_1, sizeof(result_1));
91     dst += dest_stride;
92     const int result_2 = _mm_extract_epi32(result_pixels, 2);
93     memcpy(dst, &result_2, sizeof(result_2));
94     dst += dest_stride;
95     const int result_3 = _mm_extract_epi32(result_pixels, 3);
96     memcpy(dst, &result_3, sizeof(result_3));
97     dst += dest_stride;
98   }
99 }
100 
101 template <int height>
DistanceWeightedBlend8xH_SSE4_1(const int16_t * LIBGAV1_RESTRICT pred_0,const int16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight_0,const uint8_t weight_1,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)102 inline void DistanceWeightedBlend8xH_SSE4_1(
103     const int16_t* LIBGAV1_RESTRICT pred_0,
104     const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
105     const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest,
106     const ptrdiff_t dest_stride) {
107   auto* dst = static_cast<uint8_t*>(dest);
108   const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
109 
110   for (int y = 0; y < height; y += 2) {
111     const __m128i src_00 = LoadAligned16(pred_0);
112     const __m128i src_10 = LoadAligned16(pred_1);
113     pred_0 += 8;
114     pred_1 += 8;
115     const __m128i res0 = ComputeWeightedAverage8(src_00, src_10, weights);
116 
117     const __m128i src_01 = LoadAligned16(pred_0);
118     const __m128i src_11 = LoadAligned16(pred_1);
119     pred_0 += 8;
120     pred_1 += 8;
121     const __m128i res1 = ComputeWeightedAverage8(src_01, src_11, weights);
122 
123     const __m128i result_pixels = _mm_packus_epi16(res0, res1);
124     StoreLo8(dst, result_pixels);
125     dst += dest_stride;
126     StoreHi8(dst, result_pixels);
127     dst += dest_stride;
128   }
129 }
130 
DistanceWeightedBlendLarge_SSE4_1(const int16_t * LIBGAV1_RESTRICT pred_0,const int16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)131 inline void DistanceWeightedBlendLarge_SSE4_1(
132     const int16_t* LIBGAV1_RESTRICT pred_0,
133     const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
134     const uint8_t weight_1, const int width, const int height,
135     void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
136   auto* dst = static_cast<uint8_t*>(dest);
137   const __m128i weights = _mm_set1_epi32(weight_0 | (weight_1 << 16));
138 
139   int y = height;
140   do {
141     int x = 0;
142     do {
143       const __m128i src_0_lo = LoadAligned16(pred_0 + x);
144       const __m128i src_1_lo = LoadAligned16(pred_1 + x);
145       const __m128i res_lo =
146           ComputeWeightedAverage8(src_0_lo, src_1_lo, weights);
147 
148       const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8);
149       const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8);
150       const __m128i res_hi =
151           ComputeWeightedAverage8(src_0_hi, src_1_hi, weights);
152 
153       StoreUnaligned16(dst + x, _mm_packus_epi16(res_lo, res_hi));
154       x += 16;
155     } while (x < width);
156     dst += dest_stride;
157     pred_0 += width;
158     pred_1 += width;
159   } while (--y != 0);
160 }
161 
DistanceWeightedBlend_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)162 void DistanceWeightedBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
163                                   const void* LIBGAV1_RESTRICT prediction_1,
164                                   const uint8_t weight_0,
165                                   const uint8_t weight_1, const int width,
166                                   const int height,
167                                   void* LIBGAV1_RESTRICT const dest,
168                                   const ptrdiff_t dest_stride) {
169   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
170   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
171   if (width == 4) {
172     if (height == 4) {
173       DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
174                                          dest, dest_stride);
175     } else if (height == 8) {
176       DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
177                                          dest, dest_stride);
178     } else {
179       assert(height == 16);
180       DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
181                                           dest, dest_stride);
182     }
183     return;
184   }
185 
186   if (width == 8) {
187     switch (height) {
188       case 4:
189         DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
190                                            dest, dest_stride);
191         return;
192       case 8:
193         DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
194                                            dest, dest_stride);
195         return;
196       case 16:
197         DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
198                                             dest, dest_stride);
199         return;
200       default:
201         assert(height == 32);
202         DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1,
203                                             dest, dest_stride);
204 
205         return;
206     }
207   }
208 
209   DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width,
210                                     height, dest, dest_stride);
211 }
212 
Init8bpp()213 void Init8bpp() {
214   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
215   assert(dsp != nullptr);
216 #if DSP_ENABLED_8BPP_SSE4_1(DistanceWeightedBlend)
217   dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
218 #endif
219 }
220 
221 }  // namespace
222 }  // namespace low_bitdepth
223 
224 #if LIBGAV1_MAX_BITDEPTH >= 10
225 namespace high_bitdepth {
226 namespace {
227 
228 constexpr int kMax10bppSample = (1 << 10) - 1;
229 constexpr int kInterPostRoundBit = 4;
230 
ComputeWeightedAverage8(const __m128i & pred0,const __m128i & pred1,const __m128i & weight0,const __m128i & weight1)231 inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
232                                        const __m128i& pred1,
233                                        const __m128i& weight0,
234                                        const __m128i& weight1) {
235   // This offset is a combination of round_factor and round_offset
236   // which are to be added and subtracted respectively.
237   // Here kInterPostRoundBit + 4 is considering bitdepth=10.
238   constexpr int offset =
239       (1 << ((kInterPostRoundBit + 4) - 1)) - (kCompoundOffset << 4);
240   const __m128i zero = _mm_setzero_si128();
241   const __m128i bias = _mm_set1_epi32(offset);
242   const __m128i clip_high = _mm_set1_epi16(kMax10bppSample);
243 
244   __m128i prediction0 = _mm_cvtepu16_epi32(pred0);
245   __m128i mult0 = _mm_mullo_epi32(prediction0, weight0);
246   __m128i prediction1 = _mm_cvtepu16_epi32(pred1);
247   __m128i mult1 = _mm_mullo_epi32(prediction1, weight1);
248   __m128i sum = _mm_add_epi32(mult0, mult1);
249   sum = _mm_add_epi32(sum, bias);
250   const __m128i result0 = _mm_srai_epi32(sum, kInterPostRoundBit + 4);
251 
252   prediction0 = _mm_unpackhi_epi16(pred0, zero);
253   mult0 = _mm_mullo_epi32(prediction0, weight0);
254   prediction1 = _mm_unpackhi_epi16(pred1, zero);
255   mult1 = _mm_mullo_epi32(prediction1, weight1);
256   sum = _mm_add_epi32(mult0, mult1);
257   sum = _mm_add_epi32(sum, bias);
258   const __m128i result1 = _mm_srai_epi32(sum, kInterPostRoundBit + 4);
259   const __m128i pack = _mm_packus_epi32(result0, result1);
260 
261   return _mm_min_epi16(pack, clip_high);
262 }
263 
264 template <int height>
DistanceWeightedBlend4xH_SSE4_1(const uint16_t * LIBGAV1_RESTRICT pred_0,const uint16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight_0,const uint8_t weight_1,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)265 inline void DistanceWeightedBlend4xH_SSE4_1(
266     const uint16_t* LIBGAV1_RESTRICT pred_0,
267     const uint16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
268     const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest,
269     const ptrdiff_t dest_stride) {
270   auto* dst = static_cast<uint16_t*>(dest);
271   const __m128i weight0 = _mm_set1_epi32(weight_0);
272   const __m128i weight1 = _mm_set1_epi32(weight_1);
273 
274   int y = height;
275   do {
276     const __m128i src_00 = LoadLo8(pred_0);
277     const __m128i src_10 = LoadLo8(pred_1);
278     pred_0 += 4;
279     pred_1 += 4;
280     __m128i src_0 = LoadHi8(src_00, pred_0);
281     __m128i src_1 = LoadHi8(src_10, pred_1);
282     pred_0 += 4;
283     pred_1 += 4;
284     const __m128i res0 =
285         ComputeWeightedAverage8(src_0, src_1, weight0, weight1);
286 
287     const __m128i src_01 = LoadLo8(pred_0);
288     const __m128i src_11 = LoadLo8(pred_1);
289     pred_0 += 4;
290     pred_1 += 4;
291     src_0 = LoadHi8(src_01, pred_0);
292     src_1 = LoadHi8(src_11, pred_1);
293     pred_0 += 4;
294     pred_1 += 4;
295     const __m128i res1 =
296         ComputeWeightedAverage8(src_0, src_1, weight0, weight1);
297 
298     StoreLo8(dst, res0);
299     dst += dest_stride;
300     StoreHi8(dst, res0);
301     dst += dest_stride;
302     StoreLo8(dst, res1);
303     dst += dest_stride;
304     StoreHi8(dst, res1);
305     dst += dest_stride;
306     y -= 4;
307   } while (y != 0);
308 }
309 
310 template <int height>
DistanceWeightedBlend8xH_SSE4_1(const uint16_t * LIBGAV1_RESTRICT pred_0,const uint16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight_0,const uint8_t weight_1,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)311 inline void DistanceWeightedBlend8xH_SSE4_1(
312     const uint16_t* LIBGAV1_RESTRICT pred_0,
313     const uint16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
314     const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest,
315     const ptrdiff_t dest_stride) {
316   auto* dst = static_cast<uint16_t*>(dest);
317   const __m128i weight0 = _mm_set1_epi32(weight_0);
318   const __m128i weight1 = _mm_set1_epi32(weight_1);
319 
320   int y = height;
321   do {
322     const __m128i src_00 = LoadAligned16(pred_0);
323     const __m128i src_10 = LoadAligned16(pred_1);
324     pred_0 += 8;
325     pred_1 += 8;
326     const __m128i res0 =
327         ComputeWeightedAverage8(src_00, src_10, weight0, weight1);
328 
329     const __m128i src_01 = LoadAligned16(pred_0);
330     const __m128i src_11 = LoadAligned16(pred_1);
331     pred_0 += 8;
332     pred_1 += 8;
333     const __m128i res1 =
334         ComputeWeightedAverage8(src_01, src_11, weight0, weight1);
335 
336     StoreUnaligned16(dst, res0);
337     dst += dest_stride;
338     StoreUnaligned16(dst, res1);
339     dst += dest_stride;
340     y -= 2;
341   } while (y != 0);
342 }
343 
DistanceWeightedBlendLarge_SSE4_1(const uint16_t * LIBGAV1_RESTRICT pred_0,const uint16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)344 inline void DistanceWeightedBlendLarge_SSE4_1(
345     const uint16_t* LIBGAV1_RESTRICT pred_0,
346     const uint16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
347     const uint8_t weight_1, const int width, const int height,
348     void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
349   auto* dst = static_cast<uint16_t*>(dest);
350   const __m128i weight0 = _mm_set1_epi32(weight_0);
351   const __m128i weight1 = _mm_set1_epi32(weight_1);
352 
353   int y = height;
354   do {
355     int x = 0;
356     do {
357       const __m128i src_0_lo = LoadAligned16(pred_0 + x);
358       const __m128i src_1_lo = LoadAligned16(pred_1 + x);
359       const __m128i res_lo =
360           ComputeWeightedAverage8(src_0_lo, src_1_lo, weight0, weight1);
361 
362       const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8);
363       const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8);
364       const __m128i res_hi =
365           ComputeWeightedAverage8(src_0_hi, src_1_hi, weight0, weight1);
366 
367       StoreUnaligned16(dst + x, res_lo);
368       x += 8;
369       StoreUnaligned16(dst + x, res_hi);
370       x += 8;
371     } while (x < width);
372     dst += dest_stride;
373     pred_0 += width;
374     pred_1 += width;
375   } while (--y != 0);
376 }
377 
DistanceWeightedBlend_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)378 void DistanceWeightedBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
379                                   const void* LIBGAV1_RESTRICT prediction_1,
380                                   const uint8_t weight_0,
381                                   const uint8_t weight_1, const int width,
382                                   const int height,
383                                   void* LIBGAV1_RESTRICT const dest,
384                                   const ptrdiff_t dest_stride) {
385   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
386   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
387   const ptrdiff_t dst_stride = dest_stride / sizeof(*pred_0);
388   if (width == 4) {
389     if (height == 4) {
390       DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
391                                          dest, dst_stride);
392     } else if (height == 8) {
393       DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
394                                          dest, dst_stride);
395     } else {
396       assert(height == 16);
397       DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
398                                           dest, dst_stride);
399     }
400     return;
401   }
402 
403   if (width == 8) {
404     switch (height) {
405       case 4:
406         DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
407                                            dest, dst_stride);
408         return;
409       case 8:
410         DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
411                                            dest, dst_stride);
412         return;
413       case 16:
414         DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
415                                             dest, dst_stride);
416         return;
417       default:
418         assert(height == 32);
419         DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1,
420                                             dest, dst_stride);
421 
422         return;
423     }
424   }
425 
426   DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width,
427                                     height, dest, dst_stride);
428 }
429 
Init10bpp()430 void Init10bpp() {
431   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
432   assert(dsp != nullptr);
433 #if DSP_ENABLED_10BPP_SSE4_1(DistanceWeightedBlend)
434   dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
435 #endif
436 }
437 
438 }  // namespace
439 }  // namespace high_bitdepth
440 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
441 
// Registers the SSE4.1 distance-weighted blend implementations. The 8bpp
// initializer is always called. The 10bpp initializer is called only when
// the build supports bitdepths of 10 or more (LIBGAV1_MAX_BITDEPTH >= 10).
void DistanceWeightedBlendInit_SSE4_1() {
  low_bitdepth::Init8bpp();
#if LIBGAV1_MAX_BITDEPTH >= 10
  high_bitdepth::Init10bpp();
#endif
}
448 
449 }  // namespace dsp
450 }  // namespace libgav1
451 
452 #else   // !LIBGAV1_TARGETING_SSE4_1
453 
454 namespace libgav1 {
455 namespace dsp {
456 
// Fallback used when SSE4.1 is not targeted: registers nothing.
void DistanceWeightedBlendInit_SSE4_1() {}
458 
459 }  // namespace dsp
460 }  // namespace libgav1
461 #endif  // LIBGAV1_TARGETING_SSE4_1
462