1 /*
2  * Copyright © 2020, VideoLAN and dav1d authors
3  * Copyright © 2020, Two Orioles, LLC
4  * All rights reserved.
5  *
6  * Redistribution and use in source and binary forms, with or without
7  * modification, are permitted provided that the following conditions are met:
8  *
9  * 1. Redistributions of source code must retain the above copyright notice, this
10  *    list of conditions and the following disclaimer.
11  *
12  * 2. Redistributions in binary form must reproduce the above copyright notice,
13  *    this list of conditions and the following disclaimer in the documentation
14  *    and/or other materials provided with the distribution.
15  *
16  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
17  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
18  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19  * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
20  * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
21  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
22  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
23  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
25  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26  */
27 
28 #include "config.h"
29 
30 #include <limits.h>
31 #include <stdlib.h>
32 
33 #include "dav1d/common.h"
34 
35 #include "common/intops.h"
36 
37 #include "src/env.h"
38 #include "src/refmvs.h"
39 
/*
 * Merge the motion vector(s) of the spatial neighbour `b` into `mvstack`.
 *
 * Nothing happens if b's first MV is INVALID_MV, or if b's reference(s) do
 * not match `ref` (single reference: either of b's two references equals
 * ref.ref[0]; compound: the whole reference pair is equal).
 *
 * On a reference match:
 *  - if bit 0 of b->mf is set and the corresponding gmv[] entry is not
 *    INVALID_MV, that gmv[] entry is used instead of b's own MV;
 *  - if the candidate already is in the list, its weight grows by `weight`;
 *    otherwise it is appended with that weight while fewer than 8 entries
 *    are stored (*cnt is incremented);
 *  - *have_refmv_match is set to 1 and (b->mf >> 1) is ORed into
 *    *have_newmv_match, whether or not the list was full.
 */
static void add_spatial_candidate(refmvs_candidate *const mvstack, int *const cnt,
                                  const int weight, const refmvs_block *const b,
                                  const union refmvs_refpair ref, const mv gmv[2],
                                  int *const have_newmv_match,
                                  int *const have_refmv_match)
{
    if (b->mv.mv[0].n == INVALID_MV) // intra block, no intrabc
        return;

    const int n_cands = *cnt;
    int slot = 0; // index of an equal entry, or n_cands if there is none

    if (ref.ref[1] == -1) {
        // single reference: pick the first of b's references equal to ours
        int idx = 0;
        while (idx < 2 && b->ref.ref[idx] != ref.ref[0])
            idx++;
        if (idx == 2)
            return;

        mv cand = b->mv.mv[idx];
        if ((b->mf & 1) && gmv[0].n != INVALID_MV)
            cand = gmv[0];

        while (slot < n_cands && mvstack[slot].mv.mv[0].n != cand.n)
            slot++;
        if (slot == n_cands && n_cands < 8)
            mvstack[n_cands].mv.mv[0] = cand;
    } else {
        // compound reference: both references have to match
        if (b->ref.pair != ref.pair)
            return;

        refmvs_mvpair cand;
        cand.mv[0] = b->mv.mv[0];
        cand.mv[1] = b->mv.mv[1];
        if (b->mf & 1) {
            if (gmv[0].n != INVALID_MV)
                cand.mv[0] = gmv[0];
            if (gmv[1].n != INVALID_MV)
                cand.mv[1] = gmv[1];
        }

        while (slot < n_cands && mvstack[slot].mv.n != cand.n)
            slot++;
        if (slot == n_cands && n_cands < 8)
            mvstack[n_cands].mv = cand;
    }

    if (slot < n_cands) {
        // already known: just accumulate the weight
        mvstack[slot].weight += weight;
    } else if (n_cands < 8) {
        // new entry (its MV was stored above)
        mvstack[n_cands].weight = weight;
        *cnt = n_cands + 1;
    }

    *have_refmv_match = 1;
    *have_newmv_match |= b->mf >> 1;
}
97 
/*
 * Feed the blocks of one row, starting at `b`, to add_spatial_candidate().
 *
 * If the first candidate block is at least as wide as the current block
 * (bw4 <= its width), it is the only candidate used and the function returns
 * half of the weight factor computed for it. Otherwise candidates are visited
 * left to right, advancing by max(step, candidate width) each time, until the
 * offset reaches w4; each is added with weight 2 * advance and the function
 * returns 1.
 */
static int scan_row(refmvs_candidate *const mvstack, int *const cnt,
                    const union refmvs_refpair ref, const mv gmv[2],
                    const refmvs_block *b, const int bw4, const int w4,
                    const int max_rows, const int step,
                    int *const have_newmv_match, int *const have_refmv_match)
{
    const refmvs_block *cand_b = b;
    const uint8_t *const first_dim = dav1d_block_dimensions[cand_b->bs];
    int cand_bw4 = first_dim[0];

    if (cand_bw4 >= bw4) {
        // FIXME weight can be higher for odd blocks (bx4 & 1), but then the
        // position of the first block has to be odd already, i.e. not just
        // for row_offset=-3/-5
        // FIXME why can this not be cand_bw4?
        const int weight = bw4 == 1 ? 2 :
                           imax(2, imin(2 * max_rows, first_dim[1]));
        // min(bw4, cand_bw4) is bw4 in this branch
        add_spatial_candidate(mvstack, cnt, imax(step, bw4) * weight,
                              cand_b, ref, gmv,
                              have_newmv_match, have_refmv_match);
        return weight >> 1;
    }

    int x = 0;
    for (;;) {
        // every candidate visited here is narrower than the current block
        const int len = imax(step, cand_bw4);

        // FIXME if we overhang above, we could fill a bitmask so we don't have
        // to repeat the add_spatial_candidate() for the next row, but just increase
        // the weight here
        add_spatial_candidate(mvstack, cnt, len * 2, cand_b, ref, gmv,
                              have_newmv_match, have_refmv_match);
        x += len;
        if (x >= w4)
            break;

        cand_b = &b[x];
        cand_bw4 = dav1d_block_dimensions[cand_b->bs][0];
        assert(cand_bw4 < bw4);
    }
    return 1;
}
136 
/*
 * Feed the blocks of one column (column bx4 of the row pointers in `b`) to
 * add_spatial_candidate().
 *
 * If the first candidate block is at least as tall as the current block
 * (bh4 <= its height), it is the only candidate used and the function returns
 * half of the weight factor computed for it. Otherwise candidates are visited
 * top to bottom, advancing by max(step, candidate height) each time, until
 * the offset reaches h4; each is added with weight 2 * advance and the
 * function returns 1.
 */
static int scan_col(refmvs_candidate *const mvstack, int *const cnt,
                    const union refmvs_refpair ref, const mv gmv[2],
                    /*const*/ refmvs_block *const *b, const int bh4, const int h4,
                    const int bx4, const int max_cols, const int step,
                    int *const have_newmv_match, int *const have_refmv_match)
{
    const refmvs_block *cand_b = &b[0][bx4];
    const uint8_t *const first_dim = dav1d_block_dimensions[cand_b->bs];
    int cand_bh4 = first_dim[1];

    if (cand_bh4 >= bh4) {
        // FIXME weight can be higher for odd blocks (by4 & 1), but then the
        // position of the first block has to be odd already, i.e. not just
        // for col_offset=-3/-5
        // FIXME why can this not be cand_bh4?
        const int weight = bh4 == 1 ? 2 :
                           imax(2, imin(2 * max_cols, first_dim[0]));
        // min(bh4, cand_bh4) is bh4 in this branch
        add_spatial_candidate(mvstack, cnt, imax(step, bh4) * weight,
                              cand_b, ref, gmv,
                              have_newmv_match, have_refmv_match);
        return weight >> 1;
    }

    int y = 0;
    for (;;) {
        // every candidate visited here is shorter than the current block
        const int len = imax(step, cand_bh4);

        // FIXME if we overhang above, we could fill a bitmask so we don't have
        // to repeat the add_spatial_candidate() for the next row, but just increase
        // the weight here
        add_spatial_candidate(mvstack, cnt, len * 2, cand_b, ref, gmv,
                              have_newmv_match, have_refmv_match);
        y += len;
        if (y >= h4)
            break;

        cand_b = &b[y][bx4];
        cand_bh4 = dav1d_block_dimensions[cand_b->bs][1];
        assert(cand_bh4 < bh4);
    }
    return 1;
}
175 
mv_projection(const union mv mv,const int num,const int den)176 static inline union mv mv_projection(const union mv mv, const int num, const int den) {
177     static const uint16_t div_mult[32] = {
178            0, 16384, 8192, 5461, 4096, 3276, 2730, 2340,
179         2048,  1820, 1638, 1489, 1365, 1260, 1170, 1092,
180         1024,   963,  910,  862,  819,  780,  744,  712,
181          682,   655,  630,  606,  585,  564,  546,  528
182     };
183     assert(den > 0 && den < 32);
184     assert(num > -32 && num < 32);
185     const int dm = div_mult[den];
186     const int y = mv.y * num * dm, x = mv.x * num * dm;
187     return (union mv) { .y = (y + 8192 + (y >> 31)) >> 14,
188                         .x = (x + 8192 + (x >> 31)) >> 14 };
189 }
190 
/*
 * Merge the projected MV of temporal block `rb` into `mvstack`.
 *
 * rb->mv is scaled by mv_projection() with rf->pocdiff[ref - 1] as numerator
 * and rb->ref as denominator, once for ref.ref[0] and, for compound
 * references (ref.ref[1] != -1), once more for ref.ref[1]; each result is
 * passed through fix_mv_precision(). Nothing happens if rb->mv is INVALID_MV.
 *
 * A candidate already in the list has its weight increased by 2; a new one
 * is appended with weight 2 while fewer than 8 entries are stored.
 *
 * For single references, if `globalmv_ctx` is non-NULL, it receives 1 when
 * the projected MV differs from gmv[0] by 16 or more in either component and
 * 0 otherwise. `gmv` is only read in that case, so it may be NULL whenever
 * `globalmv_ctx` is NULL.
 */
static void add_temporal_candidate(const refmvs_frame *const rf,
                                   refmvs_candidate *const mvstack, int *const cnt,
                                   const refmvs_temporal_block *const rb,
                                   const union refmvs_refpair ref, int *const globalmv_ctx,
                                   const union mv gmv[])
{
    if (rb->mv.n == INVALID_MV)
        return;

    union mv proj0 = mv_projection(rb->mv, rf->pocdiff[ref.ref[0] - 1], rb->ref);
    fix_mv_precision(rf->frm_hdr, &proj0);

    const int n_cands = *cnt;
    int slot = 0; // index of an equal entry, or n_cands if there is none

    if (ref.ref[1] != -1) {
        // compound: project towards the second reference as well
        refmvs_mvpair pair;
        pair.mv[0] = proj0;
        pair.mv[1] = mv_projection(rb->mv, rf->pocdiff[ref.ref[1] - 1], rb->ref);
        fix_mv_precision(rf->frm_hdr, &pair.mv[1]);

        while (slot < n_cands && mvstack[slot].mv.n != pair.n)
            slot++;
        if (slot == n_cands && n_cands < 8)
            mvstack[n_cands].mv = pair;
    } else {
        if (globalmv_ctx) {
            const int dx = abs(proj0.x - gmv[0].x);
            const int dy = abs(proj0.y - gmv[0].y);
            *globalmv_ctx = (dx | dy) >= 16;
        }

        while (slot < n_cands && mvstack[slot].mv.mv[0].n != proj0.n)
            slot++;
        if (slot == n_cands && n_cands < 8)
            mvstack[n_cands].mv.mv[0] = proj0;
    }

    if (slot < n_cands) {
        // already known: just accumulate the weight
        mvstack[slot].weight += 2;
    } else if (n_cands < 8) {
        // new entry (its MV was stored above)
        mvstack[n_cands].weight = 2;
        *cnt = n_cands + 1;
    }
}
236 
/*
 * Collect fallback MVs for a compound reference from neighbour `cand_b`.
 *
 * `same` points at four scratch candidates and `same_count` at four
 * counters: same[0..1] with same_count[0..1] hold, per reference list, MVs
 * whose reference equals ref.ref[list]; same[2..3] with same_count[2..3]
 * (called "diff" below) hold, per list, MVs taken from other references.
 * Each list stores at most two MVs.
 *
 * For each of cand_b's references (stopping at the first one <= 0):
 *  - if it equals ref.ref[0], the MV goes to "same" list 0 and "diff" list 1;
 *  - else if it equals ref.ref[1], it goes to "same" list 1 and "diff" list 0;
 *  - otherwise it goes to both "diff" lists.
 * An MV stored in "diff" list N is negated when signN differs from
 * sign_bias[cand_ref - 1].
 */
static void add_compound_extended_candidate(refmvs_candidate *const same,
                                            int *const same_count,
                                            const refmvs_block *const cand_b,
                                            const int sign0, const int sign1,
                                            const union refmvs_refpair ref,
                                            const uint8_t *const sign_bias)
{
    refmvs_candidate *const diff = &same[2];
    int *const diff_count = &same_count[2];
    const int sign[2] = { sign0, sign1 };

    for (int n = 0; n < 2; n++) {
        const int cand_ref = cand_b->ref.ref[n];
        if (cand_ref <= 0)
            return;

        const mv cand_mv = cand_b->mv.mv[n];
        mv neg_mv = cand_mv;
        neg_mv.y = -neg_mv.y;
        neg_mv.x = -neg_mv.x;

        // a match on the first reference takes precedence over the second
        const int matches0 = cand_ref == ref.ref[0];
        const int matches1 = !matches0 && cand_ref == ref.ref[1];

        for (int list = 0; list < 2; list++) {
            const int matches = list ? matches1 : matches0;

            if (matches) {
                if (same_count[list] < 2)
                    same[same_count[list]++].mv.mv[list] = cand_mv;
            } else if (diff_count[list] < 2) {
                const int flip = sign[list] ^ sign_bias[cand_ref - 1];
                diff[diff_count[list]++].mv.mv[list] = flip ? neg_mv : cand_mv;
            }
        }
    }
}
293 
/*
 * Append the MVs of neighbour `cand_b` to a single-reference candidate list.
 *
 * Each of cand_b's references is considered in turn, stopping at the first
 * one <= 0. Its MV is negated when `sign` differs from
 * sign_bias[cand_ref - 1], and is appended with weight 2 unless an equal MV
 * already is in the list. No capacity check is done here; the caller has to
 * leave room for up to two new entries.
 */
static void add_single_extended_candidate(refmvs_candidate mvstack[8], int *const cnt,
                                          const refmvs_block *const cand_b,
                                          const int sign, const uint8_t *const sign_bias)
{
    for (int n = 0; n < 2; n++) {
        const int cand_ref = cand_b->ref.ref[n];
        if (cand_ref <= 0)
            return;

        // we need to continue even if cand_ref == ref.ref[0], since
        // the candidate could have been added as a globalmv variant,
        // which changes the value
        // FIXME if scan_{row,col}() returned a mask for the nearest
        // edge, we could skip the appropriate ones here

        mv cand = cand_b->mv.mv[n];
        if (sign ^ sign_bias[cand_ref - 1]) {
            cand.y = -cand.y;
            cand.x = -cand.x;
        }

        const int n_cands = *cnt;
        int known = 0;
        for (int m = 0; m < n_cands && !known; m++)
            known = mvstack[m].mv.mv[0].n == cand.n;
        if (known)
            continue;

        mvstack[n_cands].mv.mv[0] = cand;
        mvstack[n_cands].weight = 2; // "minimal"
        *cnt = n_cands + 1;
    }
}
326 
/*
 * refmvs_frame allocates memory for one sbrow (32 blocks high, whole frame
 * wide) of 4x4-resolution refmvs_block entries for spatial MV referencing.
 * The r[] array of refmvs_tile keeps a list of 35 (32 + 3 above) pointers
 * into this memory, and for each sbrow, the bottom entries (y=27/29/31) are
 * exchanged with the top (-5/-3/-1) pointers by calling
 * dav1d_refmvs_tile_sbrow_init() at the start of each tile/sbrow.
 *
 * For temporal MV referencing, we call dav1d_refmvs_save_tmvs() at the end of
 * each tile/sbrow (when tile column threading is enabled), or at the start of
 * each interleaved sbrow (i.e. once for all tile columns together, when tile
 * column threading is disabled). This will copy the 4x4-resolution spatial MVs
 * into 8x8-resolution refmvs_temporal_block structures. Then, for subsequent
 * frames, at the start of each tile/sbrow (when tile column threading is
 * enabled) or at the start of each interleaved sbrow (when tile column
 * threading is disabled), we call dav1d_refmvs_load_tmvs(), which will
 * project the MVs to their respective position in the current frame.
 */
345 
/*
 * Build the list of reference MV candidates for one block.
 *
 * rt          tile context; supplies the frame context (rt->rf), the row
 *             pointers rt->r[] into the spatial MV buffer, the projected
 *             temporal MVs (rt->rp_proj) and the tile boundaries
 * mvstack     output: candidate list (at most 8 entries)
 * cnt         output: number of entries written to mvstack
 * ctx         output: packed context value, see the end of each path
 * ref         reference (pair) of the block; ref.ref[1] == -1 means single
 *             reference
 * bs          block size, used to index dav1d_block_dimensions[]
 * edge_flags  only EDGE_I444_TOP_HAS_RIGHT is tested here
 * by4, bx4    block position in 4x4 units
 *
 * Stages, in order: nearest spatial neighbours (top row, left column,
 * top/right), temporal candidates, "secondary" spatial neighbours (top/left
 * and rows/columns further away), context derivation, sorting by weight
 * (nearest and secondary entries are sorted separately), extension of lists
 * with fewer than two entries, and clamping.
 */
void dav1d_refmvs_find(const refmvs_tile *const rt,
                       refmvs_candidate mvstack[8], int *const cnt,
                       int *const ctx,
                       const union refmvs_refpair ref, const enum BlockSize bs,
                       const enum EdgeFlags edge_flags,
                       const int by4, const int bx4)
{
    const refmvs_frame *const rf = rt->rf;
    const uint8_t *const b_dim = dav1d_block_dimensions[bs];
    // w4/h4: block size limited to 16 units and to the tile's right/bottom end
    const int bw4 = b_dim[0], w4 = imin(imin(bw4, 16), rt->tile_col.end - bx4);
    const int bh4 = b_dim[1], h4 = imin(imin(bh4, 16), rt->tile_row.end - by4);
    // tgmv[] always holds a usable global MV; gmv[] is INVALID_MV unless the
    // warp type is above DAV1D_WM_TYPE_TRANSLATION, so that spatial
    // candidates only get replaced by it in that case
    mv gmv[2], tgmv[2];

    *cnt = 0;
    assert(ref.ref[0] >=  0 && ref.ref[0] <= 8 &&
           ref.ref[1] >= -1 && ref.ref[1] <= 8);
    if (ref.ref[0] > 0) {
        tgmv[0] = get_gmv_2d(&rf->frm_hdr->gmv[ref.ref[0] - 1],
                             bx4, by4, bw4, bh4, rf->frm_hdr);
        gmv[0] = rf->frm_hdr->gmv[ref.ref[0] - 1].type > DAV1D_WM_TYPE_TRANSLATION ?
                 tgmv[0] : (mv) { .n = INVALID_MV };
    } else {
        tgmv[0] = (mv) { .n = 0 };
        gmv[0] = (mv) { .n = INVALID_MV };
    }
    if (ref.ref[1] > 0) {
        tgmv[1] = get_gmv_2d(&rf->frm_hdr->gmv[ref.ref[1] - 1],
                             bx4, by4, bw4, bh4, rf->frm_hdr);
        gmv[1] = rf->frm_hdr->gmv[ref.ref[1] - 1].type > DAV1D_WM_TYPE_TRANSLATION ?
                 tgmv[1] : (mv) { .n = INVALID_MV };
    }

    // top
    // n_rows/n_cols stay at ~0U when there is no row above/column to the
    // left inside the tile; b_top/b_left are only valid otherwise
    int have_newmv = 0, have_col_mvs = 0, have_row_mvs = 0;
    unsigned max_rows = 0, n_rows = ~0;
    const refmvs_block *b_top;
    if (by4 > rt->tile_row.start) {
        max_rows = imin((by4 - rt->tile_row.start + 1) >> 1, 2 + (bh4 > 1));
        b_top = &rt->r[(by4 & 31) - 1 + 5][bx4];
        n_rows = scan_row(mvstack, cnt, ref, gmv, b_top,
                          bw4, w4, max_rows, bw4 >= 16 ? 4 : 1,
                          &have_newmv, &have_row_mvs);
    }

    // left
    unsigned max_cols = 0, n_cols = ~0U;
    refmvs_block *const *b_left;
    if (bx4 > rt->tile_col.start) {
        max_cols = imin((bx4 - rt->tile_col.start + 1) >> 1, 2 + (bw4 > 1));
        b_left = &rt->r[(by4 & 31) + 5];
        n_cols = scan_col(mvstack, cnt, ref, gmv, b_left,
                          bh4, h4, bx4 - 1, max_cols, bh4 >= 16 ? 4 : 1,
                          &have_newmv, &have_col_mvs);
    }

    // top/right
    if (n_rows != ~0U && edge_flags & EDGE_I444_TOP_HAS_RIGHT &&
        imax(bw4, bh4) <= 16 && bw4 + bx4 < rt->tile_col.end)
    {
        add_spatial_candidate(mvstack, cnt, 4, &b_top[bw4], ref, gmv,
                              &have_newmv, &have_row_mvs);
    }

    // everything found so far counts as "nearest"; the weight offset keeps
    // these entries apart from the ones added below
    const int nearest_match = have_col_mvs + have_row_mvs;
    const int nearest_cnt = *cnt;
    for (int n = 0; n < nearest_cnt; n++)
        mvstack[n].weight += 640;

    // temporal
    int globalmv_ctx = rf->frm_hdr->use_ref_frame_mvs;
    if (rf->use_ref_frame_mvs) {
        const ptrdiff_t stride = rf->rp_stride;
        const int by8 = by4 >> 1, bx8 = bx4 >> 1;
        const refmvs_temporal_block *const rbi = &rt->rp_proj[(by8 & 15) * stride + bx8];
        const refmvs_temporal_block *rb = rbi;
        const int step_h = bw4 >= 16 ? 2 : 1, step_v = bh4 >= 16 ? 2 : 1;
        const int w8 = imin((w4 + 1) >> 1, 8), h8 = imin((h4 + 1) >> 1, 8);
        for (int y = 0; y < h8; y += step_v) {
            for (int x = 0; x < w8; x+= step_h) {
                // only the very first position updates globalmv_ctx
                add_temporal_candidate(rf, mvstack, cnt, &rb[x], ref,
                                       !(x | y) ? &globalmv_ctx : NULL, tgmv);
            }
            rb += stride * step_v;
        }
        if (imin(bw4, bh4) >= 2 && imax(bw4, bh4) < 16) {
            // positions just outside the block: bottom/left, bottom/right
            // and right of the last row, each limited to the tile and to the
            // surrounding 8-aligned group of 8x8 units
            const int bh8 = bh4 >> 1, bw8 = bw4 >> 1;
            rb = &rbi[bh8 * stride];
            const int has_bottom = by8 + bh8 < imin(rt->tile_row.end >> 1,
                                                    (by8 & ~7) + 8);
            if (has_bottom && bx8 - 1 >= imax(rt->tile_col.start >> 1, bx8 & ~7)) {
                add_temporal_candidate(rf, mvstack, cnt, &rb[-1], ref,
                                       NULL, NULL);
            }
            if (bx8 + bw8 < imin(rt->tile_col.end >> 1, (bx8 & ~7) + 8)) {
                if (has_bottom) {
                    add_temporal_candidate(rf, mvstack, cnt, &rb[bw8], ref,
                                           NULL, NULL);
                }
                if (by8 + bh8 - 1 < imin(rt->tile_row.end >> 1, (by8 & ~7) + 8)) {
                    add_temporal_candidate(rf, mvstack, cnt, &rb[bw8 - stride],
                                           ref, NULL, NULL);
                }
            }
        }
    }
    assert(*cnt <= 8);

    // top/left (which, confusingly, is part of "secondary" references)
    // The newmv result of secondary candidates is not used. The helpers OR
    // into this variable, so give it a defined value to avoid reading an
    // indeterminate one.
    int have_dummy_newmv_match = 0;
    // true only if both the top row and the left column are available
    if ((n_rows | n_cols) != ~0U) {
        add_spatial_candidate(mvstack, cnt, 4, &b_top[-1], ref, gmv,
                              &have_dummy_newmv_match, &have_row_mvs);
    }

    // "secondary" (non-direct neighbour) top & left edges
    // what is different about secondary is that everything is now in 8x8 resolution
    for (int n = 2; n <= 3; n++) {
        // n_rows/n_cols == ~0U (edge not available) never passes "n > ..."
        if ((unsigned) n > n_rows && (unsigned) n <= max_rows) {
            n_rows += scan_row(mvstack, cnt, ref, gmv,
                               &rt->r[(((by4 & 31) - 2 * n + 1) | 1) + 5][bx4 | 1],
                               bw4, w4, 1 + max_rows - n, bw4 >= 16 ? 4 : 2,
                               &have_dummy_newmv_match, &have_row_mvs);
        }

        if ((unsigned) n > n_cols && (unsigned) n <= max_cols) {
            n_cols += scan_col(mvstack, cnt, ref, gmv, &rt->r[((by4 & 31) | 1) + 5],
                               bh4, h4, (bx4 - n * 2 + 1) | 1,
                               1 + max_cols - n, bh4 >= 16 ? 4 : 2,
                               &have_dummy_newmv_match, &have_col_mvs);
        }
    }
    assert(*cnt <= 8);

    const int ref_match_count = have_col_mvs + have_row_mvs;

    // context build-up
    // nearest_match is the sum of two 0/1 flags, so it is 0, 1 or 2
    int refmv_ctx, newmv_ctx;
    switch (nearest_match) {
    case 0:
        refmv_ctx = imin(2, ref_match_count);
        newmv_ctx = ref_match_count > 0;
        break;
    case 1:
        refmv_ctx = imin(ref_match_count * 3, 4);
        newmv_ctx = 3 - have_newmv;
        break;
    case 2:
        refmv_ctx = 5;
        newmv_ctx = 5 - have_newmv;
        break;
    }

    // sorting (nearest, then "secondary")
    // bubble sort by descending weight; `last` remembers the position of the
    // final exchange so that the next pass can stop there
#define EXCHANGE(a, b) do { refmvs_candidate tmp = a; a = b; b = tmp; } while (0)
    int len = nearest_cnt;
    while (len) {
        int last = 0;
        for (int n = 1; n < len; n++) {
            if (mvstack[n - 1].weight < mvstack[n].weight) {
                EXCHANGE(mvstack[n - 1], mvstack[n]);
                last = n;
            }
        }
        len = last;
    }
    len = *cnt;
    while (len > nearest_cnt) {
        int last = nearest_cnt;
        for (int n = nearest_cnt + 1; n < len; n++) {
            if (mvstack[n - 1].weight < mvstack[n].weight) {
                EXCHANGE(mvstack[n - 1], mvstack[n]);
                last = n;
            }
        }
        len = last;
    }
#undef EXCHANGE

    if (ref.ref[1] > 0) {
        if (*cnt < 2) {
            const int sign0 = rf->sign_bias[ref.ref[0] - 1];
            const int sign1 = rf->sign_bias[ref.ref[1] - 1];
            const int sz4 = imin(w4, h4);
            // scratch space directly behind the existing entries:
            // same[0..1] for matching references, same[2..3] for others
            refmvs_candidate *const same = &mvstack[*cnt];
            int same_count[4] = { 0 };

            // non-self references in top
            if (n_rows != ~0U) for (int x = 0; x < sz4;) {
                const refmvs_block *const cand_b = &b_top[x];
                add_compound_extended_candidate(same, same_count, cand_b,
                                                sign0, sign1, ref, rf->sign_bias);
                x += dav1d_block_dimensions[cand_b->bs][0];
            }

            // non-self references in left
            if (n_cols != ~0U) for (int y = 0; y < sz4;) {
                const refmvs_block *const cand_b = &b_left[y][bx4 - 1];
                add_compound_extended_candidate(same, same_count, cand_b,
                                                sign0, sign1, ref, rf->sign_bias);
                y += dav1d_block_dimensions[cand_b->bs][1];
            }

            refmvs_candidate *const diff = &same[2];
            const int *const diff_count = &same_count[2];

            // merge together
            // per reference list: fill same[0..1] first from the "diff"
            // entries, then from the global MV
            for (int n = 0; n < 2; n++) {
                int m = same_count[n];

                if (m >= 2) continue;

                const int l = diff_count[n];
                if (l) {
                    same[m].mv.mv[n] = diff[0].mv.mv[n];
                    if (++m == 2) continue;
                    if (l == 2) {
                        same[1].mv.mv[n] = diff[1].mv.mv[n];
                        continue;
                    }
                }
                do {
                    same[m].mv.mv[n] = tgmv[n];
                } while (++m < 2);
            }

            // if the first extended was the same as the non-extended one,
            // then replace it with the second extended one
            int n = *cnt;
            if (n == 1 && mvstack[0].mv.n == same[0].mv.n)
                mvstack[1].mv = mvstack[2].mv;
            do {
                mvstack[n].weight = 2;
            } while (++n < 2);
            *cnt = 2;
        }

        // clamping
        // *cnt is at least 2 here, so the do/while below is safe
        const int left = -(bx4 + bw4 + 4) * 4 * 8;
        const int right = (rf->iw4 - bx4 + 4) * 4 * 8;
        const int top = -(by4 + bh4 + 4) * 4 * 8;
        const int bottom = (rf->ih4 - by4 + 4) * 4 * 8;

        const int n_refmvs = *cnt;
        int n = 0;
        do {
            mvstack[n].mv.mv[0].x = iclip(mvstack[n].mv.mv[0].x, left, right);
            mvstack[n].mv.mv[0].y = iclip(mvstack[n].mv.mv[0].y, top, bottom);
            mvstack[n].mv.mv[1].x = iclip(mvstack[n].mv.mv[1].x, left, right);
            mvstack[n].mv.mv[1].y = iclip(mvstack[n].mv.mv[1].y, top, bottom);
        } while (++n < n_refmvs);

        // compound references: *ctx is a single value in the range 0..7
        switch (refmv_ctx >> 1) {
        case 0:
            *ctx = imin(newmv_ctx, 1);
            break;
        case 1:
            *ctx = 1 + imin(newmv_ctx, 3);
            break;
        case 2:
            *ctx = iclip(3 + newmv_ctx, 4, 7);
            break;
        }

        return;
    } else if (*cnt < 2 && ref.ref[0] > 0) {
        const int sign = rf->sign_bias[ref.ref[0] - 1];
        const int sz4 = imin(w4, h4);

        // non-self references in top
        if (n_rows != ~0U) for (int x = 0; x < sz4 && *cnt < 2;) {
            const refmvs_block *const cand_b = &b_top[x];
            add_single_extended_candidate(mvstack, cnt, cand_b, sign, rf->sign_bias);
            x += dav1d_block_dimensions[cand_b->bs][0];
        }

        // non-self references in left
        if (n_cols != ~0U) for (int y = 0; y < sz4 && *cnt < 2;) {
            const refmvs_block *const cand_b = &b_left[y][bx4 - 1];
            add_single_extended_candidate(mvstack, cnt, cand_b, sign, rf->sign_bias);
            y += dav1d_block_dimensions[cand_b->bs][1];
        }
    }
    assert(*cnt <= 8);

    // clamping
    int n_refmvs = *cnt;
    if (n_refmvs) {
        const int left = -(bx4 + bw4 + 4) * 4 * 8;
        const int right = (rf->iw4 - bx4 + 4) * 4 * 8;
        const int top = -(by4 + bh4 + 4) * 4 * 8;
        const int bottom = (rf->ih4 - by4 + 4) * 4 * 8;

        int n = 0;
        do {
            mvstack[n].mv.mv[0].x = iclip(mvstack[n].mv.mv[0].x, left, right);
            mvstack[n].mv.mv[0].y = iclip(mvstack[n].mv.mv[0].y, top, bottom);
        } while (++n < n_refmvs);
    }

    // the first two entries always carry an MV, even beyond *cnt
    for (int n = *cnt; n < 2; n++)
        mvstack[n].mv.mv[0] = tgmv[0];

    // single reference: all three contexts are packed into one value
    *ctx = (refmv_ctx << 4) | (globalmv_ctx << 3) | newmv_ctx;
}
650 
/*
 * Set up the tile context `rt` for one superblock row.
 *
 * rt->rp_proj and the row pointers in rt->r[] are pointed into the buffers
 * of `rf` belonging to tile_row_idx (forced to 0 when rf->n_tile_threads is
 * 1). Relative to `off`, rt->r[5 .. 5 + sbsz - 1] receive the first sbsz
 * rows, rt->r[0], [2] and [4] the three rows following them, and rt->r[1]
 * and [3] are set to NULL. For odd `sby`, entries 0/2/4 are exchanged with
 * entries sbsz + 0/2/4.
 *
 * The tile boundaries are stored as given, with the end positions limited to
 * rf->ih4 and rf->iw4.
 */
void dav1d_refmvs_tile_sbrow_init(refmvs_tile *const rt, const refmvs_frame *const rf,
                                  const int tile_col_start4, const int tile_col_end4,
                                  const int tile_row_start4, const int tile_row_end4,
                                  const int sby, int tile_row_idx)
{
    if (rf->n_tile_threads == 1)
        tile_row_idx = 0;

    rt->rf = rf;
    rt->tile_col.start = tile_col_start4;
    rt->tile_col.end = imin(tile_col_end4, rf->iw4);
    rt->tile_row.start = tile_row_start4;
    rt->tile_row.end = imin(tile_row_end4, rf->ih4);

    rt->rp_proj = &rf->rp_proj[16 * rf->rp_stride * tile_row_idx];

    const int sbsz = rf->sbsz;
    const int off = (sbsz * sby) & 16;
    refmvs_block *row = &rf->r[35 * rf->r_stride * tile_row_idx];

    // rows of the superblock row itself
    int i = 0;
    while (i < sbsz) {
        rt->r[off + 5 + i] = row;
        row += rf->r_stride;
        i++;
    }

    // three more rows go to the even entries in front; the odd ones are unused
    rt->r[off + 1] = NULL;
    rt->r[off + 3] = NULL;
    for (int k = 0; k < 3; k++)
        rt->r[off + 2 * k] = &row[k * rf->r_stride];

    if (sby & 1) {
        for (int k = 0; k <= 4; k += 2) {
            void *const tmp = rt->r[off + k];
            rt->r[off + k] = rt->r[off + sbsz + k];
            rt->r[off + sbsz + k] = tmp;
        }
    }
}
684 
/*
 * Build the projected temporal motion-vector field (rf->rp_proj) for a band
 * of rows, using the saved motion fields of the references listed in
 * rf->mfmv_ref[].
 *
 * rf:             frame-level state (buffers, strides, mfmv_* tables).
 * tile_row_idx:   selects which 16-row slab of rf->rp_proj is written; forced
 *                 to 0 when rf->n_tile_threads == 1.
 * col_start8/col_end8, row_start8/row_end8:
 *                 half-open target rectangle, in the same units as
 *                 rf->iw8 / rf->ih8. At most 16 rows may be requested
 *                 (asserted); row_end8 is clipped to rf->ih8.
 *
 * Step 1 marks every target entry as INVALID_MV. Step 2 walks each source
 * reference and, for every source block with a usable reference, writes the
 * block's motion vector into the projected position if that position falls
 * inside the allowed window.
 */
void dav1d_refmvs_load_tmvs(const refmvs_frame *const rf, int tile_row_idx,
                            const int col_start8, const int col_end8,
                            const int row_start8, int row_end8)
{
    // With a single tile thread only the first slab of rp_proj is used.
    if (rf->n_tile_threads == 1) tile_row_idx = 0;
    assert(row_start8 >= 0);
    assert((unsigned) (row_end8 - row_start8) <= 16U);
    row_end8 = imin(row_end8, rf->ih8);
    // Source columns are scanned 8 columns beyond the target rectangle on
    // each side (clipped to the frame), since a projection may move a block
    // horizontally into the target area.
    const int col_start8i = imax(col_start8 - 8, 0);
    const int col_end8i = imin(col_end8 + 8, rf->iw8);

    const ptrdiff_t stride = rf->rp_stride;
    // Step 1: invalidate the target rectangle. Rows are addressed relative to
    // the slab start using the low 4 bits of the first row index.
    refmvs_temporal_block *rp_proj =
        &rf->rp_proj[16 * stride * tile_row_idx + (row_start8 & 15) * stride];
    for (int y = row_start8; y < row_end8; y++) {
        for (int x = col_start8; x < col_end8; x++)
            rp_proj[x].mv.n = INVALID_MV;
        rp_proj += stride;
    }

    // Step 2: from here on rp_proj points at the slab start; rows are
    // selected through (pos_y & 15) * stride below.
    rp_proj = &rf->rp_proj[16 * stride * tile_row_idx];
    for (int n = 0; n < rf->n_mfmvs; n++) {
        const int ref2cur = rf->mfmv_ref2cur[n];
        // INT_MIN is the "unusable" marker for this list entry.
        if (ref2cur == INT_MIN) continue;

        const int ref = rf->mfmv_ref[n];
        // Negative for reference indices below 4; XORed into the sign
        // argument of apply_sign() so the projection direction flips for
        // those references.
        // NOTE(review): assumes apply_sign() negates on a negative sign
        // argument - confirm in common/intops.h.
        const int ref_sign = ref - 4;
        const refmvs_temporal_block *r = &rf->rp_ref[ref][row_start8 * stride];
        for (int y = row_start8; y < row_end8; y++) {
            // A projected row is only accepted if it stays inside the same
            // aligned group of 8 rows as the source row, clipped to the
            // requested row range.
            const int y_sb_align = y & ~7;
            const int y_proj_start = imax(y_sb_align, row_start8);
            const int y_proj_end = imin(y_sb_align + 8, row_end8);
            for (int x = col_start8i; x < col_end8i; x++) {
                const refmvs_temporal_block *rb = &r[x];
                const int b_ref = rb->ref;
                // ref == 0 means no motion vector was stored for this block.
                if (!b_ref) continue;
                // 0 means the distance between the source reference and its
                // own reference is unusable for projection.
                const int ref2ref = rf->mfmv_ref2ref[n][b_ref - 1];
                if (!ref2ref) continue;
                const mv b_mv = rb->mv;
                const mv offset = mv_projection(b_mv, ref2cur, ref2ref);
                // Displacement: magnitude of the projected offset shifted
                // right by 6, with the sign re-applied afterwards.
                int pos_x = x + apply_sign(abs(offset.x) >> 6,
                                           offset.x ^ ref_sign);
                const int pos_y = y + apply_sign(abs(offset.y) >> 6,
                                                 offset.y ^ ref_sign);
                if (pos_y >= y_proj_start && pos_y < y_proj_end) {
                    const ptrdiff_t pos = (pos_y & 15) * stride;
                    // Run-length loop: consecutive source blocks with the
                    // same ref and mv share one projection, so x and pos_x
                    // advance together without recomputing the offset.
                    for (;;) {
                        const int x_sb_align = x & ~7;
                        // Horizontal window: from 8 columns before to 16
                        // columns after the aligned start of the source
                        // column's 8-column group, clipped to the target
                        // rectangle.
                        if (pos_x >= imax(x_sb_align - 8, col_start8) &&
                            pos_x < imin(x_sb_align + 16, col_end8))
                        {
                            // The stored mv is the source block's original
                            // mv, not the projected offset.
                            rp_proj[pos + pos_x].mv = rb->mv;
                            rp_proj[pos + pos_x].ref = ref2ref;
                        }
                        if (++x >= col_end8i) break;
                        rb++;
                        if (rb->ref != b_ref || rb->mv.n != b_mv.n) break;
                        pos_x++;
                    }
                } else {
                    // Projected row is out of range: skip the whole run of
                    // identical source blocks without writing anything.
                    for (;;) {
                        if (++x >= col_end8i) break;
                        rb++;
                        if (rb->ref != b_ref || rb->mv.n != b_mv.n) break;
                    }
                }
                // The inner loops leave x on the first column after the run;
                // step back so the outer loop's x++ lands on it.
                x--;
            }
            r += stride;
        }
    }
}
757 
dav1d_refmvs_save_tmvs(const refmvs_tile * const rt,const int col_start8,int col_end8,const int row_start8,int row_end8)758 void dav1d_refmvs_save_tmvs(const refmvs_tile *const rt,
759                             const int col_start8, int col_end8,
760                             const int row_start8, int row_end8)
761 {
762     const refmvs_frame *const rf = rt->rf;
763 
764     assert(row_start8 >= 0);
765     assert((unsigned) (row_end8 - row_start8) <= 16U);
766     row_end8 = imin(row_end8, rf->ih8);
767     col_end8 = imin(col_end8, rf->iw8);
768 
769     const ptrdiff_t stride = rf->rp_stride;
770     const uint8_t *const ref_sign = rf->mfmv_sign;
771     refmvs_temporal_block *rp = &rf->rp[row_start8 * stride];
772     for (int y = row_start8; y < row_end8; y++) {
773         const refmvs_block *const b = rt->r[6 + (y & 15) * 2];
774 
775         for (int x = col_start8; x < col_end8;) {
776             const refmvs_block *const cand_b = &b[x * 2 + 1];
777             const int bw8 = (dav1d_block_dimensions[cand_b->bs][0] + 1) >> 1;
778 
779             if (cand_b->ref.ref[1] > 0 && ref_sign[cand_b->ref.ref[1] - 1] &&
780                 (abs(cand_b->mv.mv[1].y) | abs(cand_b->mv.mv[1].x)) < 4096)
781             {
782                 for (int n = 0; n < bw8; n++, x++)
783                     rp[x] = (refmvs_temporal_block) { .mv = cand_b->mv.mv[1],
784                                                       .ref = cand_b->ref.ref[1] };
785             } else if (cand_b->ref.ref[0] > 0 && ref_sign[cand_b->ref.ref[0] - 1] &&
786                        (abs(cand_b->mv.mv[0].y) | abs(cand_b->mv.mv[0].x)) < 4096)
787             {
788                 for (int n = 0; n < bw8; n++, x++)
789                     rp[x] = (refmvs_temporal_block) { .mv = cand_b->mv.mv[0],
790                                                       .ref = cand_b->ref.ref[0] };
791             } else {
792                 for (int n = 0; n < bw8; n++, x++)
793                     rp[x].ref = 0; // "invalid"
794             }
795         }
796         rp += stride;
797     }
798 }
799 
/*
 * Prepare rf for decoding one frame.
 *
 * - Derives the frame dimensions (iw8/ih8: width/height rounded up, in
 *   8-pixel units; iw4/ih4: twice that) and rf->sbsz from the headers.
 * - (Re)allocates rf->r (35 rows of r_stride entries per tile row) and
 *   rf->rp_proj (16 rows of rp_stride entries per tile row) whenever the
 *   stride or the number of tile rows differs from the cached values.
 * - Records the per-reference order-hint relations (sign_bias, mfmv_sign,
 *   pocdiff) and sets up the list of references used for temporal motion
 *   vector projection (mfmv_ref, mfmv_ref2cur, mfmv_ref2ref).
 *
 * rf:             state to initialize; must have been set up with
 *                 dav1d_refmvs_init() before the first call.
 * seq_hdr:        sequence header (sb128, order_hint_n_bits are read).
 * frm_hdr:        frame header; the pointer is stored in rf->frm_hdr.
 * ref_poc:        order hint of each of the 7 references.
 * rp:             buffer the current frame's motion vectors are saved to;
 *                 stored in rf->rp, not owned by rf.
 * ref_ref_poc:    for each reference, the order hints of its own references.
 * rp_ref:         saved motion fields of the references (entries may be
 *                 NULL); stored in rf->rp_ref, not owned by rf.
 * n_tile_threads: number of tile threads; with more than one, buffers are
 *                 replicated per tile row.
 *
 * Returns 0 on success or DAV1D_ERR(ENOMEM) if an allocation fails. After a
 * failure the cached strides are reset so that the next call reallocates the
 * buffers instead of trusting stale geometry.
 */
int dav1d_refmvs_init_frame(refmvs_frame *const rf,
                            const Dav1dSequenceHeader *const seq_hdr,
                            const Dav1dFrameHeader *const frm_hdr,
                            const unsigned ref_poc[7],
                            refmvs_temporal_block *const rp,
                            const unsigned ref_ref_poc[7][7],
                            /*const*/ refmvs_temporal_block *const rp_ref[7],
                            const int n_tile_threads)
{
    rf->sbsz = 16 << seq_hdr->sb128;
    rf->frm_hdr = frm_hdr;
    rf->iw8 = (frm_hdr->width[0] + 7) >> 3;
    rf->ih8 = (frm_hdr->height + 7) >> 3;
    rf->iw4 = rf->iw8 << 1;
    rf->ih4 = rf->ih8 << 1;

    // Width rounded up to a multiple of 128 pixels, expressed in 4-pixel
    // units.
    const ptrdiff_t r_stride = ((frm_hdr->width[0] + 127) & ~127) >> 2;
    const int n_tile_rows = n_tile_threads > 1 ? frm_hdr->tiling.rows : 1;
    if (r_stride != rf->r_stride || n_tile_rows != rf->n_tile_rows) {
        free(rf->r);
        rf->r = malloc(sizeof(*rf->r) * 35 * r_stride * n_tile_rows);
        if (!rf->r) {
            // Invalidate the cache key: otherwise a later call with the old
            // geometry would skip the allocation and run with rf->r == NULL.
            rf->r_stride = 0;
            return DAV1D_ERR(ENOMEM);
        }
        rf->r_stride = r_stride;
    }

    const ptrdiff_t rp_stride = r_stride >> 1;
    if (rp_stride != rf->rp_stride || n_tile_rows != rf->n_tile_rows) {
        free(rf->rp_proj);
        rf->rp_proj = malloc(sizeof(*rf->rp_proj) * 16 * rp_stride * n_tile_rows);
        if (!rf->rp_proj) {
            // rf->r may already have been resized for the new n_tile_rows
            // while rf->n_tile_rows still holds the old value, so force both
            // buffers to be reallocated on the next call.
            rf->rp_stride = 0;
            rf->r_stride = 0;
            return DAV1D_ERR(ENOMEM);
        }
        rf->rp_stride = rp_stride;
    }
    // Only updated once both buffers match the new geometry, because both
    // checks above compare against the previous value.
    rf->n_tile_rows = n_tile_rows;
    rf->n_tile_threads = n_tile_threads;
    rf->rp = rp;
    rf->rp_ref = rp_ref;
    const unsigned poc = frm_hdr->frame_offset;
    for (int i = 0; i < 7; i++) {
        const int poc_diff = get_poc_diff(seq_hdr->order_hint_n_bits,
                                          ref_poc[i], poc);
        rf->sign_bias[i] = poc_diff > 0;
        rf->mfmv_sign[i] = poc_diff < 0;
        // Distance from the current frame to the reference, limited to
        // [-31, 31].
        rf->pocdiff[i] = iclip(get_poc_diff(seq_hdr->order_hint_n_bits,
                                            poc, ref_poc[i]), -31, 31);
    }

    // temporal MV setup
    rf->n_mfmvs = 0;
    if (frm_hdr->use_ref_frame_mvs && seq_hdr->order_hint_n_bits) {
        // The last two candidates (altref, last2) are only added while fewer
        // than `total` entries have been collected.
        int total = 2;
        if (rp_ref[0] && ref_ref_poc[0][6] != ref_poc[3] /* alt-of-last != gold */) {
            rf->mfmv_ref[rf->n_mfmvs++] = 0; // last
            total = 3;
        }
        if (rp_ref[4] && get_poc_diff(seq_hdr->order_hint_n_bits, ref_poc[4],
                                      frm_hdr->frame_offset) > 0)
        {
            rf->mfmv_ref[rf->n_mfmvs++] = 4; // bwd
        }
        if (rp_ref[5] && get_poc_diff(seq_hdr->order_hint_n_bits, ref_poc[5],
                                      frm_hdr->frame_offset) > 0)
        {
            rf->mfmv_ref[rf->n_mfmvs++] = 5; // altref2
        }
        if (rf->n_mfmvs < total && rp_ref[6] &&
            get_poc_diff(seq_hdr->order_hint_n_bits, ref_poc[6],
                         frm_hdr->frame_offset) > 0)
        {
            rf->mfmv_ref[rf->n_mfmvs++] = 6; // altref
        }
        if (rf->n_mfmvs < total && rp_ref[1])
            rf->mfmv_ref[rf->n_mfmvs++] = 1; // last2

        for (int n = 0; n < rf->n_mfmvs; n++) {
            const unsigned rpoc = ref_poc[rf->mfmv_ref[n]];
            const int diff1 = get_poc_diff(seq_hdr->order_hint_n_bits,
                                           rpoc, frm_hdr->frame_offset);
            if (abs(diff1) > 31) {
                // Too far away: mark the entry so dav1d_refmvs_load_tmvs()
                // skips it.
                rf->mfmv_ref2cur[n] = INT_MIN;
            } else {
                // The distance is negated for reference indices below 4.
                rf->mfmv_ref2cur[n] = rf->mfmv_ref[n] < 4 ? -diff1 : diff1;
                for (int m = 0; m < 7; m++) {
                    const unsigned rrpoc = ref_ref_poc[rf->mfmv_ref[n]][m];
                    const int diff2 = get_poc_diff(seq_hdr->order_hint_n_bits,
                                                   rpoc, rrpoc);
                    // unsigned comparison also catches the < 0 case
                    rf->mfmv_ref2ref[n][m] = (unsigned) diff2 > 31U ? 0 : diff2;
                }
            }
        }
    }
    rf->use_ref_frame_mvs = rf->n_mfmvs > 0;

    return 0;
}
895 
/*
 * Put rf into its "no buffers allocated" state: both buffer pointers are
 * NULL and both cached strides are 0, so the first dav1d_refmvs_init_frame()
 * call sees a stride mismatch and allocates fresh buffers.
 */
void dav1d_refmvs_init(refmvs_frame *const rf)
{
    // Buffers owned by rf.
    rf->r = NULL;
    rf->rp_proj = NULL;

    // Strides the buffers were allocated for.
    rf->r_stride = 0;
    rf->rp_stride = 0;
}
902 
/*
 * Release the buffers owned by rf (rf->r and rf->rp_proj) and return rf to
 * the state set up by dav1d_refmvs_init().
 *
 * The pointers are set to NULL and the cached strides to 0 so that calling
 * this function twice is harmless, and so that a later
 * dav1d_refmvs_init_frame() call reallocates the buffers instead of seeing
 * matching strides and reusing freed memory.
 */
void dav1d_refmvs_clear(refmvs_frame *const rf) {
    // free(NULL) is a no-op, so no NULL checks are needed.
    free(rf->r);
    rf->r = NULL;
    rf->r_stride = 0;

    free(rf->rp_proj);
    rf->rp_proj = NULL;
    rf->rp_stride = 0;
}
907