// Copyright 2020 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/motion_field_projection.h"
#include "src/utils/cpu.h"

#if LIBGAV1_ENABLE_NEON

#include <arm_neon.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/utils/common.h"
#include "src/utils/constants.h"
#include "src/utils/types.h"

namespace libgav1 {
namespace dsp {
namespace {

inline int16x8_t LoadDivision(const int8x8x2_t division_table,
                              const int8x8_t reference_offset) {
  const int8x8_t kOne = vcreate_s8(0x0100010001000100);
  const int8x16_t kOneQ = vcombine_s8(kOne, kOne);
  const int8x8_t t = vadd_s8(reference_offset, reference_offset);
  const int8x8x2_t tt = vzip_s8(t, t);
  const int8x16_t t1 = vcombine_s8(tt.val[0], tt.val[1]);
  const int8x16_t idx = vaddq_s8(t1, kOneQ);
  const int8x8_t idx_low = vget_low_s8(idx);
  const int8x8_t idx_high = vget_high_s8(idx);
  const int16x4_t d0 = vreinterpret_s16_s8(vtbl2_s8(division_table, idx_low));
  const int16x4_t d1 = vreinterpret_s16_s8(vtbl2_s8(division_table, idx_high));
  return vcombine_s16(d0, d1);
}
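// How the table lookup above works: for a lane holding reference type r,
// t = 2 * r, and zipping t with itself duplicates it into the byte pair
// {2r, 2r}. Adding kOne's repeating {0, 1} byte pattern turns that into the
// indices {2r, 2r + 1}, which vtbl2_s8 uses to gather the low and high bytes
// of the little-endian 16-bit entry division_table[r]. Each int16 lane of
// the result is therefore the projection division for its reference type.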

inline int16x4_t MvProjection(const int16x4_t mv, const int16x4_t denominator,
                              const int numerator) {
  const int32x4_t m0 = vmull_s16(mv, denominator);
  const int32x4_t m = vmulq_n_s32(m0, numerator);
  // Add the sign (0 or -1) to round towards zero.
  const int32x4_t add_sign = vsraq_n_s32(m, m, 31);
  return vqrshrn_n_s32(add_sign, 14);
}
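// Worked example of the rounding: for m = -8192 (-0.5 in Q14), m >> 31 = -1,
// so add_sign = -8193 and vqrshrn_n_s32 produces (-8193 + 8192) >> 14 = -1,
// matching the scalar -RightShiftWithRounding(8192, 14) used by
// GetMvProjection(). Without the correction the result would round to 0 and
// lose symmetry with +0.5 -> +1.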

inline int16x8_t MvProjectionClip(const int16x8_t mv,
                                  const int16x8_t denominator,
                                  const int numerator) {
  const int16x4_t mv0 = vget_low_s16(mv);
  const int16x4_t mv1 = vget_high_s16(mv);
  const int16x4_t s0 = MvProjection(mv0, vget_low_s16(denominator), numerator);
  const int16x4_t s1 = MvProjection(mv1, vget_high_s16(denominator), numerator);
  const int16x8_t projection = vcombine_s16(s0, s1);
  const int16x8_t projection_mv_clamp = vdupq_n_s16(kProjectionMvClamp);
  const int16x8_t clamp = vminq_s16(projection, projection_mv_clamp);
  return vmaxq_s16(clamp, vnegq_s16(projection_mv_clamp));
}

inline int8x8_t Project_NEON(const int16x8_t delta, const int16x8_t dst_sign) {
  // Add 63 to negative delta so that it shifts towards zero.
  const int16x8_t delta_sign = vshrq_n_s16(delta, 15);
  const uint16x8_t delta_u = vreinterpretq_u16_s16(delta);
  const uint16x8_t delta_sign_u = vreinterpretq_u16_s16(delta_sign);
  const uint16x8_t delta_adjust_u = vsraq_n_u16(delta_u, delta_sign_u, 10);
  const int16x8_t delta_adjust = vreinterpretq_s16_u16(delta_adjust_u);
  const int16x8_t offset0 = vshrq_n_s16(delta_adjust, 6);
  const int16x8_t offset1 = veorq_s16(offset0, dst_sign);
  const int16x8_t offset2 = vsubq_s16(offset1, dst_sign);
  return vqmovn_s16(offset2);
}
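// Worked example: for delta = -100, delta_sign is all ones, so the unsigned
// shift-right-accumulate adds 0xFFFF >> 10 = 63, giving -37, and
// -37 >> 6 = -1, i.e. -100 / 64 truncated towards zero (a plain arithmetic
// shift would give -2). The final xor/subtract pair computes (x ^ s) - s,
// which negates x when dst_sign is -1 and is a no-op when dst_sign is 0.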

inline void GetPosition(
    const int8x8x2_t division_table, const MotionVector* const mv,
    const int numerator, const int x8_start, const int x8_end, const int x8,
    const int8x8_t r_offsets, const int8x8_t source_reference_type8,
    const int8x8_t skip_r, const int8x8_t y8_floor8, const int8x8_t y8_ceiling8,
    const int16x8_t d_sign, const int delta, int8x8_t* const r,
    int8x8_t* const position_y8, int8x8_t* const position_x8,
    int64_t* const skip_64, int32x4_t mvs[2]) {
  const auto* const mv_int = reinterpret_cast<const int32_t*>(mv + x8);
  *r = vtbl1_s8(r_offsets, source_reference_type8);
  const int16x8_t denorm = LoadDivision(division_table, source_reference_type8);
  int16x8_t projection_mv[2];
  mvs[0] = vld1q_s32(mv_int + 0);
  mvs[1] = vld1q_s32(mv_int + 4);
  // Deinterleave x and y components.
  const int16x8_t mv0 = vreinterpretq_s16_s32(mvs[0]);
  const int16x8_t mv1 = vreinterpretq_s16_s32(mvs[1]);
  const int16x8x2_t mv_yx = vuzpq_s16(mv0, mv1);
  // numerator could be 0.
  projection_mv[0] = MvProjectionClip(mv_yx.val[0], denorm, numerator);
  projection_mv[1] = MvProjectionClip(mv_yx.val[1], denorm, numerator);
  // Do not update the motion vector if the block position is not valid or
  // if position_x8 is outside the current range of x8_start and x8_end.
  // Note that position_y8 will always be within the range of y8_start and
  // y8_end.
  // After subtracting the base, valid projections are within 8-bit.
  *position_y8 = Project_NEON(projection_mv[0], d_sign);
  const int8x8_t position_x = Project_NEON(projection_mv[1], d_sign);
  const int8x8_t k01234567 = vcreate_s8(uint64_t{0x0706050403020100});
  *position_x8 = vqadd_s8(position_x, k01234567);
  const int8x16_t position_xy = vcombine_s8(*position_x8, *position_y8);
  const int x8_floor = std::max(
      x8_start - x8, delta - kProjectionMvMaxHorizontalOffset);  // [-8, 8]
  const int x8_ceiling = std::min(
      x8_end - x8, delta + 8 + kProjectionMvMaxHorizontalOffset);  // [0, 16]
  const int8x8_t x8_floor8 = vdup_n_s8(x8_floor);
  const int8x8_t x8_ceiling8 = vdup_n_s8(x8_ceiling);
  const int8x16_t floor_xy = vcombine_s8(x8_floor8, y8_floor8);
  const int8x16_t ceiling_xy = vcombine_s8(x8_ceiling8, y8_ceiling8);
  const uint8x16_t underflow = vcltq_s8(position_xy, floor_xy);
  const uint8x16_t overflow = vcgeq_s8(position_xy, ceiling_xy);
  const int8x16_t out = vreinterpretq_s8_u8(vorrq_u8(underflow, overflow));
  const int8x8_t skip_low = vorr_s8(skip_r, vget_low_s8(out));
  const int8x8_t skip = vorr_s8(skip_low, vget_high_s8(out));
  *skip_64 = vget_lane_s64(vreinterpret_s64_s8(skip), 0);
}
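// On return, each byte of *skip_64 is either 0x00 (the projected block may be
// stored) or 0xFF (skip it): the OR of the per-reference skip flag and the
// x/y bounds checks above. A value of -1 therefore means all eight candidates
// are skipped, which the caller uses as an early exit.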

template <int idx>
inline void Store(const int16x8_t position, const int8x8_t reference_offset,
                  const int32x4_t mv, int8_t* dst_reference_offset,
                  MotionVector* dst_mv) {
  const ptrdiff_t offset = vgetq_lane_s16(position, idx);
  auto* const d_mv = reinterpret_cast<int32_t*>(&dst_mv[offset]);
  vst1q_lane_s32(d_mv, mv, idx & 3);
  vst1_lane_s8(&dst_reference_offset[offset], reference_offset, idx);
}
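// Note on the lane indexing: |position| holds eight int16 offsets, but each
// int32x4_t motion vector register covers only four blocks, so callers pass
// mvs[0] for idx 0..3 and mvs[1] for idx 4..7; idx & 3 selects the lane
// within the given register.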

template <int idx>
inline void CheckStore(const int8_t* skips, const int16x8_t position,
                       const int8x8_t reference_offset, const int32x4_t mv,
                       int8_t* dst_reference_offset, MotionVector* dst_mv) {
  if (skips[idx] == 0) {
    Store<idx>(position, reference_offset, mv, dst_reference_offset, dst_mv);
  }
}

// 7.9.2.
void MotionFieldProjectionKernel_NEON(const ReferenceInfo& reference_info,
                                      const int reference_to_current_with_sign,
                                      const int dst_sign, const int y8_start,
                                      const int y8_end, const int x8_start,
                                      const int x8_end,
                                      TemporalMotionField* const motion_field) {
  const ptrdiff_t stride = motion_field->mv.columns();
  // The column range has to be offset by kProjectionMvMaxHorizontalOffset
  // since coordinates in that range could end up being position_x8 because of
  // projection.
  const int adjusted_x8_start =
      std::max(x8_start - kProjectionMvMaxHorizontalOffset, 0);
  const int adjusted_x8_end = std::min(
      x8_end + kProjectionMvMaxHorizontalOffset, static_cast<int>(stride));
  const int adjusted_x8_end8 = adjusted_x8_end & ~7;
  const int leftover = adjusted_x8_end - adjusted_x8_end8;
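  // Note: adjusted_x8_end8 is adjusted_x8_end rounded down to a multiple of
  // 8, so leftover is in the range [0, 7].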
  const int8_t* const reference_offsets =
      reference_info.relative_distance_to.data();
  const bool* const skip_references = reference_info.skip_references.data();
  const int16_t* const projection_divisions =
      reference_info.projection_divisions.data();
  const ReferenceFrameType* source_reference_types =
      &reference_info.motion_field_reference_frame[y8_start][0];
  const MotionVector* mv = &reference_info.motion_field_mv[y8_start][0];
  int8_t* dst_reference_offset = motion_field->reference_offset[y8_start];
  MotionVector* dst_mv = motion_field->mv[y8_start];
  const int16x8_t d_sign = vdupq_n_s16(dst_sign);

  static_assert(sizeof(int8_t) == sizeof(bool), "");
  static_assert(sizeof(int8_t) == sizeof(ReferenceFrameType), "");
  static_assert(sizeof(int32_t) == sizeof(MotionVector), "");
  assert(dst_sign == 0 || dst_sign == -1);
  assert(stride == motion_field->reference_offset.columns());
  assert((y8_start & 7) == 0);
  assert((adjusted_x8_start & 7) == 0);
  // The final position calculation is represented with int16_t. Valid
  // position_y8 from its base is at most 7. After considering the horizontal
  // offset, which is at most |stride - 1|, we have the following assertion,
  // which means this optimization works for frame widths up to 32K pixels
  // (each position is an 8x8 block).
  assert(8 * stride <= 32768);
  const int8x8_t skip_reference =
      vld1_s8(reinterpret_cast<const int8_t*>(skip_references));
  const int8x8_t r_offsets = vld1_s8(reference_offsets);
  const int8x16_t table = vreinterpretq_s8_s16(vld1q_s16(projection_divisions));
  int8x8x2_t division_table;
  division_table.val[0] = vget_low_s8(table);
  division_table.val[1] = vget_high_s8(table);
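  // vtbl2_s8 indexes a pair of 8-byte tables, so the eight little-endian
  // int16 projection divisions (16 bytes in total) are split into two
  // int8x8_t halves here; LoadDivision() gathers the byte pairs from them.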

  int y8 = y8_start;
  do {
    const int y8_floor = (y8 & ~7) - y8;                         // [-7, 0]
    const int y8_ceiling = std::min(y8_end - y8, y8_floor + 8);  // [1, 8]
    const int8x8_t y8_floor8 = vdup_n_s8(y8_floor);
    const int8x8_t y8_ceiling8 = vdup_n_s8(y8_ceiling);
    int x8;

    for (x8 = adjusted_x8_start; x8 < adjusted_x8_end8; x8 += 8) {
      const int8x8_t source_reference_type8 =
          vld1_s8(reinterpret_cast<const int8_t*>(source_reference_types + x8));
      const int8x8_t skip_r = vtbl1_s8(skip_reference, source_reference_type8);
      const int64_t early_skip = vget_lane_s64(vreinterpret_s64_s8(skip_r), 0);
      // Early termination #1 if all are skips. Chance is typically ~30-40%.
      if (early_skip == -1) continue;
      int64_t skip_64;
      int8x8_t r, position_x8, position_y8;
      int32x4_t mvs[2];
      GetPosition(division_table, mv, reference_to_current_with_sign, x8_start,
                  x8_end, x8, r_offsets, source_reference_type8, skip_r,
                  y8_floor8, y8_ceiling8, d_sign, 0, &r, &position_y8,
                  &position_x8, &skip_64, mvs);
      // Early termination #2 if all are skips.
      // Chance is typically ~15-25% after Early termination #1.
      if (skip_64 == -1) continue;
      const int16x8_t p_y = vmovl_s8(position_y8);
      const int16x8_t p_x = vmovl_s8(position_x8);
      const int16x8_t pos = vmlaq_n_s16(p_x, p_y, stride);
      const int16x8_t position = vaddq_s16(pos, vdupq_n_s16(x8));
      if (skip_64 == 0) {
        // Store all. Chance is typically ~70-85% after Early termination #2.
        Store<0>(position, r, mvs[0], dst_reference_offset, dst_mv);
        Store<1>(position, r, mvs[0], dst_reference_offset, dst_mv);
        Store<2>(position, r, mvs[0], dst_reference_offset, dst_mv);
        Store<3>(position, r, mvs[0], dst_reference_offset, dst_mv);
        Store<4>(position, r, mvs[1], dst_reference_offset, dst_mv);
        Store<5>(position, r, mvs[1], dst_reference_offset, dst_mv);
        Store<6>(position, r, mvs[1], dst_reference_offset, dst_mv);
        Store<7>(position, r, mvs[1], dst_reference_offset, dst_mv);
      } else {
        // Check and store each.
        // Chance is typically ~15-30% after Early termination #2.
        // The compiler is smart enough to not create the local buffer skips[].
        int8_t skips[8];
        memcpy(skips, &skip_64, sizeof(skips));
        CheckStore<0>(skips, position, r, mvs[0], dst_reference_offset, dst_mv);
        CheckStore<1>(skips, position, r, mvs[0], dst_reference_offset, dst_mv);
        CheckStore<2>(skips, position, r, mvs[0], dst_reference_offset, dst_mv);
        CheckStore<3>(skips, position, r, mvs[0], dst_reference_offset, dst_mv);
        CheckStore<4>(skips, position, r, mvs[1], dst_reference_offset, dst_mv);
        CheckStore<5>(skips, position, r, mvs[1], dst_reference_offset, dst_mv);
        CheckStore<6>(skips, position, r, mvs[1], dst_reference_offset, dst_mv);
        CheckStore<7>(skips, position, r, mvs[1], dst_reference_offset, dst_mv);
      }
    }

    // The following leftover processing cannot be moved out of the do...while
    // loop. Doing so may change the result storing orders of the same position.
    if (leftover > 0) {
      // Use SIMD only when leftover is at least 4, and there are at least 8
      // elements in a row.
      if (leftover >= 4 && adjusted_x8_start < adjusted_x8_end8) {
        // Process the last 8 elements to avoid loading invalid memory. Some
        // elements may have been processed in the above loop, which is OK.
        const int delta = 8 - leftover;
        x8 = adjusted_x8_end - 8;
        const int8x8_t source_reference_type8 = vld1_s8(
            reinterpret_cast<const int8_t*>(source_reference_types + x8));
        const int8x8_t skip_r =
            vtbl1_s8(skip_reference, source_reference_type8);
        const int64_t early_skip =
            vget_lane_s64(vreinterpret_s64_s8(skip_r), 0);
        // Early termination #1 if all are skips.
        if (early_skip != -1) {
          int64_t skip_64;
          int8x8_t r, position_x8, position_y8;
          int32x4_t mvs[2];
          GetPosition(division_table, mv, reference_to_current_with_sign,
                      x8_start, x8_end, x8, r_offsets, source_reference_type8,
                      skip_r, y8_floor8, y8_ceiling8, d_sign, delta, &r,
                      &position_y8, &position_x8, &skip_64, mvs);
          // Early termination #2 if all are skips.
          if (skip_64 != -1) {
            const int16x8_t p_y = vmovl_s8(position_y8);
            const int16x8_t p_x = vmovl_s8(position_x8);
            const int16x8_t pos = vmlaq_n_s16(p_x, p_y, stride);
            const int16x8_t position = vaddq_s16(pos, vdupq_n_s16(x8));
            // Store up to 7 elements since leftover is at most 7.
            if (skip_64 == 0) {
              // Store all.
              Store<1>(position, r, mvs[0], dst_reference_offset, dst_mv);
              Store<2>(position, r, mvs[0], dst_reference_offset, dst_mv);
              Store<3>(position, r, mvs[0], dst_reference_offset, dst_mv);
              Store<4>(position, r, mvs[1], dst_reference_offset, dst_mv);
              Store<5>(position, r, mvs[1], dst_reference_offset, dst_mv);
              Store<6>(position, r, mvs[1], dst_reference_offset, dst_mv);
              Store<7>(position, r, mvs[1], dst_reference_offset, dst_mv);
            } else {
              // Check and store each.
              // The compiler is smart enough to not create the local buffer
              // skips[].
              int8_t skips[8];
              memcpy(skips, &skip_64, sizeof(skips));
              CheckStore<1>(skips, position, r, mvs[0], dst_reference_offset,
                            dst_mv);
              CheckStore<2>(skips, position, r, mvs[0], dst_reference_offset,
                            dst_mv);
              CheckStore<3>(skips, position, r, mvs[0], dst_reference_offset,
                            dst_mv);
              CheckStore<4>(skips, position, r, mvs[1], dst_reference_offset,
                            dst_mv);
              CheckStore<5>(skips, position, r, mvs[1], dst_reference_offset,
                            dst_mv);
              CheckStore<6>(skips, position, r, mvs[1], dst_reference_offset,
                            dst_mv);
              CheckStore<7>(skips, position, r, mvs[1], dst_reference_offset,
                            dst_mv);
            }
          }
        }
      } else {
        for (; x8 < adjusted_x8_end; ++x8) {
          const int source_reference_type = source_reference_types[x8];
          if (skip_references[source_reference_type]) continue;
          MotionVector projection_mv;
          // reference_to_current_with_sign could be 0.
          GetMvProjection(mv[x8], reference_to_current_with_sign,
                          projection_divisions[source_reference_type],
                          &projection_mv);
          // Do not update the motion vector if the block position is not valid
          // or if position_x8 is outside the current range of x8_start and
          // x8_end. Note that position_y8 will always be within the range of
          // y8_start and y8_end.
          const int position_y8 = Project(0, projection_mv.mv[0], dst_sign);
          if (position_y8 < y8_floor || position_y8 >= y8_ceiling) continue;
          const int x8_base = x8 & ~7;
          const int x8_floor =
              std::max(x8_start, x8_base - kProjectionMvMaxHorizontalOffset);
          const int x8_ceiling =
              std::min(x8_end, x8_base + 8 + kProjectionMvMaxHorizontalOffset);
          const int position_x8 = Project(x8, projection_mv.mv[1], dst_sign);
          if (position_x8 < x8_floor || position_x8 >= x8_ceiling) continue;
          dst_mv[position_y8 * stride + position_x8] = mv[x8];
          dst_reference_offset[position_y8 * stride + position_x8] =
              reference_offsets[source_reference_type];
        }
      }
    }

    source_reference_types += stride;
    mv += stride;
    dst_reference_offset += stride;
    dst_mv += stride;
  } while (++y8 < y8_end);
}

void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
  assert(dsp != nullptr);
  dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_NEON;
}

#if LIBGAV1_MAX_BITDEPTH >= 10
void Init10bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
  assert(dsp != nullptr);
  dsp->motion_field_projection_kernel = MotionFieldProjectionKernel_NEON;
}
#endif

}  // namespace

void MotionFieldProjectionInit_NEON() {
  Init8bpp();
#if LIBGAV1_MAX_BITDEPTH >= 10
  Init10bpp();
#endif
}

}  // namespace dsp
}  // namespace libgav1

#else   // !LIBGAV1_ENABLE_NEON
namespace libgav1 {
namespace dsp {

void MotionFieldProjectionInit_NEON() {}

}  // namespace dsp
}  // namespace libgav1
#endif  // LIBGAV1_ENABLE_NEON