1 /*
2 * Copyright(c) 2019 Netflix, Inc.
3 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
4 *
5 * This source code is subject to the terms of the BSD 2 Clause License and
6 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
7 * was not distributed with this source code in the LICENSE file, you can
8 * obtain it at https://www.aomedia.org/license/software-license. If the Alliance for Open
9 * Media Patent License 1.0 was not distributed with this source code in the
10 * PATENTS file, you can obtain it at https://www.aomedia.org/license/patent-license.
11 */
12 
13 #include "EbDecParseInterBlock.h"
14 #include "EbCommonUtils.h"
15 #include "EbWarpedMotion.h"
16 
17 typedef const int (*ColorCost)[PALETTE_SIZES][PALETTE_COLOR_INDEX_CONTEXTS][PALETTE_COLORS];
18 typedef AomCdfProb (*MapCdf)[PALETTE_SIZES][PALETTE_COLOR_INDEX_CONTEXTS];
19 
20 #define MAX_COLOR_CONTEXT_HASH 8
21 #define NUM_PALETTE_NEIGHBORS 3 // left, top-left and top.
22 #define COLOR_MAP_STRIDE 128 // worst case
23 
24 // Negative values are invalid
25 extern int palette_color_index_context_lookup[MAX_COLOR_CONTEXT_HASH + 1];
26 
27 static uint16_t compound_mode_ctx_map[3][COMP_NEWMV_CTXS] = {
28     {0, 1, 1, 1, 1},
29     {1, 2, 3, 4, 4},
30     {4, 4, 5, 6, 7},
31 };
32 
svt_collect_neighbors_ref_counts(PartitionInfo * pi)33 static INLINE void svt_collect_neighbors_ref_counts(PartitionInfo *pi) {
34     ZERO_ARRAY(&pi->neighbors_ref_counts[0], sizeof(pi->neighbors_ref_counts[0]) * REF_FRAMES);
35 
36     uint8_t *const ref_counts = pi->neighbors_ref_counts;
37 
38     const BlockModeInfo *const above_mbmi     = pi->above_mbmi;
39     const BlockModeInfo *const left_mbmi      = pi->left_mbmi;
40     const int                  above_in_image = pi->up_available;
41     const int                  left_in_image  = pi->left_available;
42 
43     // Above neighbor
44     if (above_in_image && is_inter_block(above_mbmi)) {
45         ref_counts[above_mbmi->ref_frame[0]]++;
46         if (has_second_ref(above_mbmi))
47             ref_counts[above_mbmi->ref_frame[1]]++;
48     }
49 
50     // Left neighbor
51     if (left_in_image && is_inter_block(left_mbmi)) {
52         ref_counts[left_mbmi->ref_frame[0]]++;
53         if (has_second_ref(left_mbmi))
54             ref_counts[left_mbmi->ref_frame[1]]++;
55     }
56 }
57 
is_inside(TileInfo * tile,int mi_col,int mi_row)58 static INLINE int is_inside(TileInfo *tile, int mi_col, int mi_row) {
59     return (mi_col >= tile->mi_col_start && mi_col < tile->mi_col_end &&
60             mi_row >= tile->mi_row_start && mi_row < tile->mi_row_end);
61 }
62 
get_reference_mode_context(const PartitionInfo * xd)63 static int get_reference_mode_context(const PartitionInfo *xd) {
64     int                        ctx;
65     const BlockModeInfo *const above_mbmi = xd->above_mbmi;
66     const BlockModeInfo *const left_mbmi  = xd->left_mbmi;
67     const int                  has_above  = xd->up_available;
68     const int                  has_left   = xd->left_available;
69 
70     // Note:
71     // The mode info data structure has a one element border above and to the
72     // left of the entries corresponding to real macroblocks.
73     // The prediction flags in these dummy entries are initialized to 0.
74     if (has_above && has_left) { // both edges available
75         if (!has_second_ref(above_mbmi) && !has_second_ref(left_mbmi))
76             // neither edge uses comp pred (0/1)
77             ctx = IS_BACKWARD_REF_FRAME(above_mbmi->ref_frame[0]) ^
78                 IS_BACKWARD_REF_FRAME(left_mbmi->ref_frame[0]);
79         else if (!has_second_ref(above_mbmi))
80             // one of two edges uses comp pred (2/3)
81             ctx = 2 +
82                 (IS_BACKWARD_REF_FRAME(above_mbmi->ref_frame[0]) || !is_inter_block(above_mbmi));
83         else if (!has_second_ref(left_mbmi))
84             // one of two edges uses comp pred (2/3)
85             ctx = 2 +
86                 (IS_BACKWARD_REF_FRAME(left_mbmi->ref_frame[0]) || !is_inter_block(left_mbmi));
87         else // both edges use comp pred (4)
88             ctx = 4;
89     } else if (has_above || has_left) { // one edge available
90         const BlockModeInfo *edge_mbmi = has_above ? above_mbmi : left_mbmi;
91 
92         if (!has_second_ref(edge_mbmi))
93             // edge does not use comp pred (0/1)
94             ctx = IS_BACKWARD_REF_FRAME(edge_mbmi->ref_frame[0]);
95         else
96             // edge uses comp pred (3)
97             ctx = 3;
98     } else { // no edges available (1)
99         ctx = 1;
100     }
101     assert(ctx >= 0 && ctx < COMP_INTER_CONTEXTS);
102     return ctx;
103 }
104 
get_pred_context_comp_ref_p(PartitionInfo * pi)105 static int32_t get_pred_context_comp_ref_p(PartitionInfo *pi) {
106     const uint8_t *const ref_counts = &pi->neighbors_ref_counts[0];
107 
108     // Count of LAST + LAST2
109     const int32_t last_last2_count = ref_counts[LAST_FRAME] + ref_counts[LAST2_FRAME];
110     // Count of LAST3 + GOLDEN
111     const int32_t last3_gld_count = ref_counts[LAST3_FRAME] + ref_counts[GOLDEN_FRAME];
112 
113     const int32_t pred_context = (last_last2_count == last3_gld_count)
114         ? 1
115         : ((last_last2_count < last3_gld_count) ? 0 : 2);
116 
117     assert(pred_context >= 0 && pred_context < REF_CONTEXTS);
118     return pred_context;
119 }
120 
get_pred_context_comp_bwdref_p(PartitionInfo * pi)121 static int32_t get_pred_context_comp_bwdref_p(PartitionInfo *pi) {
122     const uint8_t *const ref_counts = &pi->neighbors_ref_counts[0];
123 
124     // Counts of BWDREF, ALTREF2, or ALTREF frames (b, A2, or A)
125     const int32_t brfarf2_count = ref_counts[BWDREF_FRAME] + ref_counts[ALTREF2_FRAME];
126     const int32_t arf_count     = ref_counts[ALTREF_FRAME];
127 
128     const int32_t pred_context = (brfarf2_count == arf_count)
129         ? 1
130         : ((brfarf2_count < arf_count) ? 0 : 2);
131 
132     assert(pred_context >= 0 && pred_context < REF_CONTEXTS);
133     return pred_context;
134 }
135 
get_pred_context_comp_bwdref_p1(PartitionInfo * pi)136 static int32_t get_pred_context_comp_bwdref_p1(PartitionInfo *pi) {
137     const uint8_t *const ref_counts = &pi->neighbors_ref_counts[0];
138 
139     // Count of BWDREF frames (b)
140     const int32_t brf_count = ref_counts[BWDREF_FRAME];
141     // Count of ALTREF2 frames (A2)
142     const int32_t arf2_count = ref_counts[ALTREF2_FRAME];
143 
144     const int32_t pred_context = (brf_count == arf2_count) ? 1 : ((brf_count < arf2_count) ? 0 : 2);
145 
146     assert(pred_context >= 0 && pred_context < REF_CONTEXTS);
147     return pred_context;
148 }
149 
get_pred_context_uni_comp_ref_p2(PartitionInfo * pi)150 static int get_pred_context_uni_comp_ref_p2(PartitionInfo *pi) {
151     const uint8_t *const ref_counts = &pi->neighbors_ref_counts[0];
152 
153     // Count of LAST3
154     const int last3_count = ref_counts[LAST3_FRAME];
155     // Count of GOLDEN
156     const int gld_count = ref_counts[GOLDEN_FRAME];
157 
158     const int pred_context = (last3_count == gld_count) ? 1 : ((last3_count < gld_count) ? 0 : 2);
159 
160     assert(pred_context >= 0 && pred_context < UNI_COMP_REF_CONTEXTS);
161     return pred_context;
162 }
163 
get_pred_context_uni_comp_ref_p1(PartitionInfo * pi)164 static int get_pred_context_uni_comp_ref_p1(PartitionInfo *pi) {
165     const uint8_t *const ref_counts = &pi->neighbors_ref_counts[0];
166 
167     // Count of LAST2
168     const int last2_count = ref_counts[LAST2_FRAME];
169     // Count of LAST3 or GOLDEN
170     const int last3_or_gld_count = ref_counts[LAST3_FRAME] + ref_counts[GOLDEN_FRAME];
171 
172     const int pred_context = (last2_count == last3_or_gld_count)
173         ? 1
174         : ((last2_count < last3_or_gld_count) ? 0 : 2);
175 
176     assert(pred_context >= 0 && pred_context < UNI_COMP_REF_CONTEXTS);
177     return pred_context;
178 }
179 
get_pred_context_uni_comp_ref_p(PartitionInfo * pi)180 static int get_pred_context_uni_comp_ref_p(PartitionInfo *pi) {
181     const uint8_t *const ref_counts = &pi->neighbors_ref_counts[0];
182 
183     // Count of forward references (L, L2, L3, or G)
184     const int frf_count = ref_counts[LAST_FRAME] + ref_counts[LAST2_FRAME] +
185         ref_counts[LAST3_FRAME] + ref_counts[GOLDEN_FRAME];
186     // Count of backward references (b or A)
187     const int brf_count = ref_counts[BWDREF_FRAME] + ref_counts[ALTREF2_FRAME] +
188         ref_counts[ALTREF_FRAME];
189 
190     const int pred_context = (frf_count == brf_count) ? 1 : ((frf_count < brf_count) ? 0 : 2);
191 
192     assert(pred_context >= 0 && pred_context < UNI_COMP_REF_CONTEXTS);
193     return pred_context;
194 }
195 
get_pred_context_single_ref_p1(PartitionInfo * pi)196 static int32_t get_pred_context_single_ref_p1(PartitionInfo *pi) {
197     const uint8_t *const ref_counts = &pi->neighbors_ref_counts[0];
198 
199     // Count of forward reference frames
200     const int32_t fwd_count = ref_counts[LAST_FRAME] + ref_counts[LAST2_FRAME] +
201         ref_counts[LAST3_FRAME] + ref_counts[GOLDEN_FRAME];
202     // Count of backward reference frames
203     const int32_t bwd_count = ref_counts[BWDREF_FRAME] + ref_counts[ALTREF2_FRAME] +
204         ref_counts[ALTREF_FRAME];
205 
206     const int32_t pred_context = (fwd_count == bwd_count) ? 1 : ((fwd_count < bwd_count) ? 0 : 2);
207 
208     assert(pred_context >= 0 && pred_context < REF_CONTEXTS);
209     return pred_context;
210 }
211 
get_pred_context_single_ref_p4(PartitionInfo * pi)212 static int32_t get_pred_context_single_ref_p4(PartitionInfo *pi) {
213     const uint8_t *const ref_counts = &pi->neighbors_ref_counts[0];
214 
215     // Count of LAST
216     const int32_t last_count = ref_counts[LAST_FRAME];
217     // Count of LAST2
218     const int32_t last2_count = ref_counts[LAST2_FRAME];
219 
220     const int32_t pred_context = (last_count == last2_count) ? 1
221                                                              : ((last_count < last2_count) ? 0 : 2);
222 
223     assert(pred_context >= 0 && pred_context < REF_CONTEXTS);
224     return pred_context;
225 }
226 
get_pred_context_last3_or_gld(PartitionInfo * pi)227 static int32_t get_pred_context_last3_or_gld(PartitionInfo *pi) {
228     const uint8_t *const ref_counts = &pi->neighbors_ref_counts[0];
229 
230     // Count of LAST3
231     const int32_t last3_count = ref_counts[LAST3_FRAME];
232     // Count of GOLDEN
233     const int32_t gld_count = ref_counts[GOLDEN_FRAME];
234 
235     const int32_t pred_context = (last3_count == gld_count) ? 1
236                                                             : ((last3_count < gld_count) ? 0 : 2);
237 
238     assert(pred_context >= 0 && pred_context < REF_CONTEXTS);
239     return pred_context;
240 }
241 
/* Decode the reference frame(s) of an inter block into pi->mi->ref_frame[0..1].
 *
 * Precedence (first match wins):
 *   1. skip_mode: the pair comes from frame_header->skip_mode_params.
 *   2. SEG_LVL_REF_FRAME active: a single reference taken from the segment data.
 *   3. SEG_LVL_SKIP or SEG_LVL_GLOBALMV active: a single LAST_FRAME.
 *   4. Otherwise the choice is entropy-decoded.
 *      - A single/compound flag is read only when reference_mode is
 *        REFERENCE_MODE_SELECT and the block is at least 2 mi units in both
 *        dimensions. Otherwise single reference is implied.
 *      - The frames are then chosen by a binary tree of symbols. Their contexts
 *        come from the get_pred_context_*() helpers, which read
 *        pi->neighbors_ref_counts.
 * Every single-reference outcome sets ref_frame[1] to NONE_FRAME.
 * The order of svt_read_symbol() calls is the bitstream order and must be kept.
 */
static void read_ref_frames(ParseCtxt *parse_ctxt, PartitionInfo *const pi) {
    SvtReader *const          rd     = &parse_ctxt->r;
    const int                 seg_id = pi->mi->segment_id;
    MvReferenceFrame *const   out    = pi->mi->ref_frame;
    SegmentationParams *const seg    = &parse_ctxt->frame_header->segmentation_params;

    if (pi->mi->skip_mode) {
        out[0] = (MvReferenceFrame)(parse_ctxt->frame_header->skip_mode_params.ref_frame_idx_0);
        out[1] = (MvReferenceFrame)(parse_ctxt->frame_header->skip_mode_params.ref_frame_idx_1);
        return;
    }
    if (seg_feature_active(seg, seg_id, SEG_LVL_REF_FRAME)) {
        out[0] = (MvReferenceFrame)get_segdata(seg, seg_id, SEG_LVL_REF_FRAME);
        out[1] = NONE_FRAME;
        return;
    }
    if (seg_feature_active(seg, seg_id, SEG_LVL_SKIP) ||
        seg_feature_active(seg, seg_id, SEG_LVL_GLOBALMV)) {
        out[0] = LAST_FRAME;
        out[1] = NONE_FRAME;
        return;
    }

    const int     wide4 = mi_size_wide[pi->mi->sb_type];
    const int     high4 = mi_size_high[pi->mi->sb_type];
    ReferenceMode mode  = SINGLE_REFERENCE;
    if (parse_ctxt->frame_header->reference_mode == REFERENCE_MODE_SELECT &&
        AOMMIN(wide4, high4) >= 2) {
        const int mode_ctx = get_reference_mode_context(pi);
        mode               = (ReferenceMode)svt_read_symbol(
            rd, parse_ctxt->cur_tile_ctx.comp_inter_cdf[mode_ctx], 2, ACCT_STR);
    }

    switch (mode) {
    case COMPOUND_REFERENCE: {
        const int               type_ctx  = get_comp_reference_type_context(pi);
        const CompReferenceType comp_type = (CompReferenceType)svt_read_symbol(
            rd, parse_ctxt->cur_tile_ctx.comp_ref_type_cdf[type_ctx], 2, ACCT_STR);

        if (comp_type == UNIDIR_COMP_REFERENCE) {
            // Pairs: {BWDREF, ALTREF}, or LAST with LAST2 / LAST3 / GOLDEN.
            const int bwd_pair = svt_read_symbol(
                rd,
                parse_ctxt->cur_tile_ctx.uni_comp_ref_cdf[get_pred_context_uni_comp_ref_p(pi)][0],
                2,
                ACCT_STR);
            if (bwd_pair) {
                out[0] = BWDREF_FRAME;
                out[1] = ALTREF_FRAME;
                return;
            }
            const int beyond_last2 = svt_read_symbol(
                rd,
                parse_ctxt->cur_tile_ctx.uni_comp_ref_cdf[get_pred_context_uni_comp_ref_p1(pi)][1],
                2,
                ACCT_STR);
            MvReferenceFrame partner = LAST2_FRAME;
            if (beyond_last2) {
                const int use_gold = svt_read_symbol(
                    rd,
                    parse_ctxt->cur_tile_ctx
                        .uni_comp_ref_cdf[get_pred_context_uni_comp_ref_p2(pi)][2],
                    2,
                    ACCT_STR);
                partner = use_gold ? GOLDEN_FRAME : LAST3_FRAME;
            }
            out[0] = LAST_FRAME;
            out[1] = partner;
            return;
        }

        assert(comp_type == BIDIR_COMP_REFERENCE);

        // Forward member: LAST/LAST2 or LAST3/GOLDEN.
        const int fwd_l3_gld = svt_read_symbol(
            rd, parse_ctxt->cur_tile_ctx.comp_ref_cdf[get_pred_context_comp_ref_p(pi)][0], 2, ACCT_STR);
        if (fwd_l3_gld) {
            const int use_gold = svt_read_symbol(
                rd,
                parse_ctxt->cur_tile_ctx.comp_ref_cdf[get_pred_context_last3_or_gld(pi)][2],
                2,
                ACCT_STR);
            out[0] = use_gold ? GOLDEN_FRAME : LAST3_FRAME;
        } else {
            const int use_last2 = svt_read_symbol(
                rd,
                parse_ctxt->cur_tile_ctx.comp_ref_cdf[get_pred_context_single_ref_p4(pi)][1],
                2,
                ACCT_STR);
            out[0] = use_last2 ? LAST2_FRAME : LAST_FRAME;
        }

        // Backward member: ALTREF, or BWDREF/ALTREF2.
        const int bwd_is_alt = svt_read_symbol(
            rd,
            parse_ctxt->cur_tile_ctx.comp_bwdref_cdf[get_pred_context_comp_bwdref_p(pi)][0],
            2,
            ACCT_STR);
        if (bwd_is_alt) {
            out[1] = ALTREF_FRAME;
        } else {
            const int use_alt2 = svt_read_symbol(
                rd,
                parse_ctxt->cur_tile_ctx.comp_bwdref_cdf[get_pred_context_comp_bwdref_p1(pi)][1],
                2,
                ACCT_STR);
            out[1] = use_alt2 ? ALTREF2_FRAME : BWDREF_FRAME;
        }
        break;
    }
    case SINGLE_REFERENCE: {
        const int is_bwd = svt_read_symbol(
            rd,
            parse_ctxt->cur_tile_ctx.single_ref_cdf[get_pred_context_single_ref_p1(pi)][0],
            2,
            ACCT_STR);
        if (is_bwd) {
            const int is_alt = svt_read_symbol(
                rd,
                parse_ctxt->cur_tile_ctx.single_ref_cdf[get_pred_context_comp_bwdref_p(pi)][1],
                2,
                ACCT_STR);
            if (is_alt) {
                out[0] = ALTREF_FRAME;
            } else {
                const int use_alt2 = svt_read_symbol(
                    rd,
                    parse_ctxt->cur_tile_ctx.single_ref_cdf[get_pred_context_comp_bwdref_p1(pi)][5],
                    2,
                    ACCT_STR);
                out[0] = use_alt2 ? ALTREF2_FRAME : BWDREF_FRAME;
            }
        } else {
            const int l3_or_gld = svt_read_symbol(
                rd,
                parse_ctxt->cur_tile_ctx.single_ref_cdf[get_pred_context_comp_ref_p(pi)][2],
                2,
                ACCT_STR);
            if (l3_or_gld) {
                const int use_gold = svt_read_symbol(
                    rd,
                    parse_ctxt->cur_tile_ctx.single_ref_cdf[get_pred_context_last3_or_gld(pi)][4],
                    2,
                    ACCT_STR);
                out[0] = use_gold ? GOLDEN_FRAME : LAST3_FRAME;
            } else {
                const int use_last2 = svt_read_symbol(
                    rd,
                    parse_ctxt->cur_tile_ctx.single_ref_cdf[get_pred_context_single_ref_p4(pi)][3],
                    2,
                    ACCT_STR);
                out[0] = use_last2 ? LAST2_FRAME : LAST_FRAME;
            }
        }
        out[1] = NONE_FRAME;
        break;
    }
    default:
        assert(0 && "Invalid prediction mode.");
        break;
    }
}
382 
/* Return 1 for the prediction modes that contain a NEWMV component (NEWMV,
 * NEW_NEWMV, NEAR_NEWMV, NEW_NEARMV, NEAREST_NEWMV, NEW_NEARESTMV), else 0. */
int has_newmv(PredictionMode mode) {
    switch (mode) {
    case NEWMV:
    case NEW_NEWMV:
    case NEAR_NEWMV:
    case NEW_NEARMV:
    case NEAREST_NEWMV:
    case NEW_NEARESTMV:
        return 1;
    default:
        return 0;
    }
}
387 
/* Merge one neighbouring block's motion vector(s) into the candidate stack for
 * the reference frame (pair) rf.
 *
 * Intra candidates are ignored.
 * - Single reference (rf[1] == NONE_FRAME): each of the candidate's two
 *   reference slots that equals rf[0] contributes one MV.
 * - Compound reference: the candidate must use exactly rf[0] and rf[1], in that
 *   order.
 * gm_mv_candidates[] replaces the candidate's own MV whenever
 * is_global_mv_block() holds for the current picture's global-motion type.
 *
 * An MV already on the stack gets `weight` added to its entry. Otherwise it is
 * appended with that weight while fewer than MAX_REF_MV_STACK_SIZE entries
 * exist; once the stack is full the MV is dropped, but the counters are still
 * bumped. Each contribution increments *found_match, and also *newmv_count
 * when has_newmv(candidate->mode) holds.
 */
static void add_ref_mv_candidate(EbDecHandle *dec_handle, const BlockModeInfo *const candidate,
                                 const MvReferenceFrame rf[2], uint8_t *num_mv_found,
                                 uint8_t *found_match, uint8_t *newmv_count,
                                 CandidateMv *ref_mv_stack, IntMv *gm_mv_candidates, int weight) {
    if (!is_inter_block(candidate))
        return; // for intrabc
    assert(weight % 2 == 0);

    EbDecPicBuf *const pic = dec_handle->cur_pic_buf[0];

    if (rf[1] != NONE_FRAME) {
        // Compound: only an exact, ordered match of both references counts.
        if (candidate->ref_frame[0] != rf[0] || candidate->ref_frame[1] != rf[1])
            return;

        IntMv pair[2];
        for (int k = 0; k < 2; ++k) {
            const int use_gm = is_global_mv_block(
                candidate->mode, candidate->sb_type, pic->global_motion[rf[k]].gm_type);
            pair[k] = use_gm ? gm_mv_candidates[k] : candidate->mv[k];
        }

        int slot = 0;
        while (slot < *num_mv_found &&
               (ref_mv_stack[slot].this_mv.as_int != pair[0].as_int ||
                ref_mv_stack[slot].comp_mv.as_int != pair[1].as_int))
            ++slot;

        if (slot < *num_mv_found) {
            ref_mv_stack[slot].weight += weight;
        } else if (*num_mv_found < MAX_REF_MV_STACK_SIZE) {
            ref_mv_stack[slot].this_mv.as_int = pair[0].as_int;
            ref_mv_stack[slot].comp_mv.as_int = pair[1].as_int;
            ref_mv_stack[slot].weight         = weight;
            ++(*num_mv_found);
        }
        if (has_newmv(candidate->mode))
            ++*newmv_count;
        ++*found_match;
        return;
    }

    // Single reference: every candidate slot that uses rf[0] contributes.
    for (int k = 0; k < 2; ++k) {
        if (candidate->ref_frame[k] != rf[0])
            continue;

        const int use_gm = is_global_mv_block(
            candidate->mode, candidate->sb_type, pic->global_motion[rf[0]].gm_type);
        const IntMv mv = use_gm ? gm_mv_candidates[0] : candidate->mv[k];

        int slot = 0;
        while (slot < *num_mv_found && ref_mv_stack[slot].this_mv.as_int != mv.as_int)
            ++slot;

        if (slot < *num_mv_found) {
            ref_mv_stack[slot].weight += weight;
        } else if (*num_mv_found < MAX_REF_MV_STACK_SIZE) {
            ref_mv_stack[slot].this_mv.as_int = mv.as_int;
            ref_mv_stack[slot].weight         = weight;
            ++(*num_mv_found);
        }
        if (has_newmv(candidate->mode))
            ++*newmv_count;
        ++*found_match;
    }
}
459 
/* Walk the mi row at vertical offset delta_row from the block and feed each
 * candidate block found there to add_ref_mv_candidate().
 *
 * Extent and start:
 * - At most min(block width, columns left in the frame, 16) mi units are
 *   covered.
 * - When |delta_row| > 1 the walk starts one column to the right, unless the
 *   block is narrower than 8x8 and sits on an odd mi column.
 * - The walk stops at the first position outside the current tile.
 *
 * Stepping and weight:
 * - Each candidate advances the walk by min(block width, candidate width).
 *   This is raised to at least a 16x16 width for blocks 16 mi or wider, or
 *   else to at least an 8x8 width when |delta_row| > 1.
 * - The weight is 2, or larger when the block is at least 8x8 wide and no wider
 *   than the candidate. In that case *processed_rows is also set to
 *   inc - delta_row - 1.
 * - The contribution passed on is step * weight.
 */
static void scan_row_mbmi(EbDecHandle *dec_handle, ParseCtxt *parse_ctx, PartitionInfo *pi,
                          int delta_row, const MvReferenceFrame rf[2], CandidateMv *ref_mv_stack,
                          uint8_t *num_mv_found, uint8_t *found_match, uint8_t *newmv_count,
                          IntMv *gm_mv_candidates, int max_row_offset, int *processed_rows) {
    const int blk_w4   = mi_size_wide[pi->mi->sb_type];
    const int step_8x8 = mi_size_wide[BLOCK_8X8];
    const int step_16  = mi_size_wide[BLOCK_16X16];
    const int far_row  = abs(delta_row) > 1;
    const int span     = AOMMIN(
        AOMMIN(blk_w4, (int)dec_handle->frame_header.mi_cols - pi->mi_col), 16);

    int col_off = 0;
    if (far_row && !((pi->mi_col & 0x01) && blk_w4 < step_8x8))
        col_off = 1;

    for (int pos = 0; pos < span;) {
        const int cand_row = pi->mi_row + delta_row;
        const int cand_col = pi->mi_col + col_off + pos;
        if (!is_inside(&parse_ctx->cur_tile_info, cand_col, cand_row))
            break;

        BlockModeInfo *cand    = get_cur_mode_info(dec_handle, cand_row, cand_col, pi->sb_info);
        const int      cand_w4 = mi_size_wide[cand->sb_type];

        int step = AOMMIN(blk_w4, cand_w4);
        if (blk_w4 >= 16)
            step = AOMMAX(step_16, step);
        else if (far_row)
            step = AOMMAX(step_8x8, step);

        int weight = 2;
        if (blk_w4 >= step_8x8 && blk_w4 <= cand_w4) {
            const int inc = AOMMIN(-max_row_offset + delta_row + 1, mi_size_high[cand->sb_type]);
            weight          = AOMMAX(weight, inc);
            *processed_rows = inc - delta_row - 1;
        }

        add_ref_mv_candidate(dec_handle,
                             cand,
                             rf,
                             num_mv_found,
                             found_match,
                             newmv_count,
                             ref_mv_stack,
                             gm_mv_candidates,
                             step * weight);
        pos += step;
    }
}
514 
/* Walk the mi column at horizontal offset delta_col from the block and feed
 * each candidate block found there to add_ref_mv_candidate().
 *
 * Extent and start:
 * - At most min(block height, rows left in the frame, 16) mi units are covered.
 * - When |delta_col| > 1 the walk starts one row down, unless the block is
 *   shorter than 8x8 and sits on an odd mi row.
 * - The walk stops at the first position outside the current tile.
 *
 * Stepping and weight:
 * - Each candidate advances the walk by min(block height, candidate height).
 *   This is raised to at least a 16x16 height for blocks 16 mi or taller, or
 *   else to at least an 8x8 height when |delta_col| > 1.
 * - The weight is 2, or larger when the block is at least 8x8 tall and no
 *   taller than the candidate. In that case *processed_cols is also set to
 *   inc - delta_col - 1.
 *
 * The step sizes use mi_size_high[BLOCK_8X8] / [BLOCK_16X16] instead of the
 * literals 2 / 4, with the same if/else-if rule as scan_row_mbmi(). The
 * results are identical, because max(max(2, len), 4) == max(4, len).
 */
static void scan_col_mbmi(EbDecHandle *dec_handle, ParseCtxt *parse_ctx, PartitionInfo *pi,
                          int delta_col, const MvReferenceFrame rf[2], CandidateMv *ref_mv_stack,
                          uint8_t *num_mv_found, uint8_t *found_match, uint8_t *newmv_count,
                          IntMv *gm_mv_candidates, int max_col_offset, int *processed_cols) {
    int          mi_row      = pi->mi_row;
    int          mi_col      = pi->mi_col;
    int          bh4         = mi_size_high[pi->mi->sb_type];
    FrameHeader *frm_header  = &dec_handle->frame_header;
    int          end4        = AOMMIN(AOMMIN(bh4, (int)frm_header->mi_rows - mi_row), 16);
    int          delta_row   = 0;
    int          use_step_16 = (bh4 >= 16);
    const int    n8_h_8      = mi_size_high[BLOCK_8X8];
    const int    n8_h_16     = mi_size_high[BLOCK_16X16];

    if (abs(delta_col) > 1) {
        delta_row = 1;
        if ((mi_row & 0x01) && bh4 < n8_h_8)
            --delta_row;
    }

    for (int i = 0; i < end4;) {
        int mv_row = mi_row + delta_row + i;
        int mv_col = mi_col + delta_col;
        if (!is_inside(&parse_ctx->cur_tile_info, mv_col, mv_row))
            break;
        BlockModeInfo *candidate = get_cur_mode_info(dec_handle, mv_row, mv_col, pi->sb_info);
        const int      n4_h      = mi_size_high[candidate->sb_type];
        int            len       = AOMMIN(bh4, n4_h);
        if (use_step_16)
            len = AOMMAX(n8_h_16, len);
        else if (abs(delta_col) > 1)
            len = AOMMAX(n8_h_8, len);

        int weight = 2;
        if (bh4 >= n8_h_8 && bh4 <= n4_h) {
            int inc = AOMMIN(-max_col_offset + delta_col + 1, mi_size_wide[candidate->sb_type]);
            // Obtain range used in weight calculation.
            weight          = AOMMAX(weight, inc);
            *processed_cols = inc - delta_col - 1;
        }

        add_ref_mv_candidate(dec_handle,
                             candidate,
                             rf,
                             num_mv_found,
                             found_match,
                             newmv_count,
                             ref_mv_stack,
                             gm_mv_candidates,
                             len * weight);

        i += len;
    }
}
568 
/* Offer the single block at (pi->mi_row + delta_row, pi->mi_col + delta_col)
 * to add_ref_mv_candidate() with a fixed weight of 4. Positions outside the
 * current tile are ignored.
 *
 * num_mv_found points at one counter, the current depth of ref_mv_stack, which
 * add_ref_mv_candidate() reads and increments. It is declared as a plain
 * pointer; an array declarator is adjusted to the same type by C, so callers
 * are unaffected.
 */
static void scan_blk_mbmi(EbDecHandle *dec_handle, ParseCtxt *parse_ctx, PartitionInfo *pi,
                          int delta_row, int delta_col, const MvReferenceFrame rf[2],
                          CandidateMv *ref_mv_stack, uint8_t *found_match, uint8_t *newmv_count,
                          IntMv *gm_mv_candidates, uint8_t *num_mv_found) {
    const int mv_row = pi->mi_row + delta_row;
    const int mv_col = pi->mi_col + delta_col;

    if (!is_inside(&parse_ctx->cur_tile_info, mv_col, mv_row))
        return;

    const BlockModeInfo *const candidate = get_cur_mode_info(
        dec_handle, mv_row, mv_col, pi->sb_info);

    add_ref_mv_candidate(dec_handle,
                         candidate,
                         rf,
                         num_mv_found,
                         found_match,
                         newmv_count,
                         ref_mv_stack,
                         gm_mv_candidates,
                         4);
}
591 
592 /* TODO: Harmonize with Encoder. */
has_top_right(EbDecHandle * dec_handle,PartitionInfo * pi,int bs)593 static int has_top_right(EbDecHandle *dec_handle, PartitionInfo *pi, int bs) {
594     int       n4_w       = mi_size_wide[pi->mi->sb_type];
595     int       n4_h       = mi_size_high[pi->mi->sb_type];
596     const int sb_mi_size = mi_size_wide[dec_handle->seq_header.sb_size];
597     const int mask_row   = pi->mi_row & (sb_mi_size - 1);
598     const int mask_col   = pi->mi_col & (sb_mi_size - 1);
599 
600     if (bs > mi_size_wide[BLOCK_64X64])
601         return 0;
602     int has_tr = !((mask_row & bs) && (mask_col & bs));
603 
604     assert(bs > 0 && !(bs & (bs - 1)));
605 
606     while (bs < sb_mi_size) {
607         if (mask_col & bs) {
608             if ((mask_col & (2 * bs)) && (mask_row & (2 * bs))) {
609                 has_tr = 0;
610                 break;
611             }
612         } else {
613             break;
614         }
615         bs <<= 1;
616     }
617 
618     if (n4_w < n4_h)
619         if (!pi->is_sec_rect)
620             has_tr = 1;
621     if (n4_w > n4_h)
622         if (pi->is_sec_rect)
623             has_tr = 0;
624     if (pi->mi->partition == PARTITION_VERT_A) {
625         if (n4_w == n4_h)
626             if (mask_row & bs)
627                 has_tr = 0;
628     }
629     return has_tr;
630 }
631 
/*
 * Adds the temporal (motion-field projected) MV sampled near
 * (mi_row + blk_row, mi_col + blk_col) to the candidate stack of ref_frame.
 *
 * The sample coordinates are forced odd ("| 1") and must lie inside the
 * current tile. The stored MV is projected to rf[0] (and rf[1] for compound
 * references) and reduced to the frame's MV precision. An identical entry in
 * the stack gets +2 weight; otherwise a new entry with weight 2 is appended
 * while the stack has room (MAX_REF_MV_STACK_SIZE).
 *
 * For the block origin sample (blk_row == 0 && blk_col == 0), GLOBALMV_OFFSET
 * is set in mode_context[ref_frame] when any projected MV component differs
 * from the matching global-motion candidate by 16 or more.
 *
 * Returns 0 when the sample is outside the tile or the stored MV is
 * INVALID_MV, 1 otherwise.
 */
static int add_tpl_ref_mv(EbDecHandle *dec_handle, ParseCtxt *parse_ctx, int mi_row, int mi_col,
                          MvReferenceFrame ref_frame, int blk_row, int blk_col,
                          IntMv *gm_mv_candidates, uint8_t *num_mv_found,
                          CandidateMv ref_mv_stacks[][MAX_REF_MV_STACK_SIZE],
                          int16_t *   mode_context) {
    const FrameHeader *const hdr     = &dec_handle->frame_header;
    const int                pos_row = (mi_row + blk_row) | 1;
    const int                pos_col = (mi_col + blk_col) | 1;

    if (!is_inside(&parse_ctx->cur_tile_info, pos_col, pos_row))
        return 0;

    MvReferenceFrame rf[2];
    av1_set_ref_frame(rf, ref_frame);
    const int is_compound = (rf[1] != NONE_FRAME);

    /* tpl_mvs is indexed at half the mi resolution (one entry per 2x2 mi). */
    const TemporalMvRef *const tpl = dec_handle->main_frame_buf.tpl_mvs +
        (pos_row >> 1) * (hdr->mi_stride >> 1) + (pos_col >> 1);
    const IntMv stored_mv = tpl->mf_mv0;

    const int cur_hint  = dec_handle->cur_pic_buf[0]->order_hint;
    const int ref0_hint = get_ref_frame_buf(dec_handle, rf[0])->order_hint;
    const int dist_0    = get_relative_dist(
        &dec_handle->seq_header.order_hint_info, cur_hint, ref0_hint);
    int dist_1 = 0;
    if (is_compound) {
        const int ref1_hint = get_ref_frame_buf(dec_handle, rf[1])->order_hint;
        dist_1              = get_relative_dist(
            &dec_handle->seq_header.order_hint_info, cur_hint, ref1_hint);
    }

    if (stored_mv.as_int == INVALID_MV)
        return 0;

    IntMv proj_0, proj_1;
    proj_1.as_int = 0;
    get_mv_projection(&proj_0.as_mv, stored_mv.as_mv, dist_0, tpl->ref_frame_offset);
    if (is_compound)
        get_mv_projection(&proj_1.as_mv, stored_mv.as_mv, dist_1, tpl->ref_frame_offset);

    lower_mv_precision(&proj_0.as_mv, hdr->allow_high_precision_mv, hdr->force_integer_mv);
    if (is_compound)
        lower_mv_precision(&proj_1.as_mv, hdr->allow_high_precision_mv, hdr->force_integer_mv);

    if (blk_row == 0 && blk_col == 0) {
        int far_from_global = abs(proj_0.as_mv.row - gm_mv_candidates[0].as_mv.row) >= 16 ||
            abs(proj_0.as_mv.col - gm_mv_candidates[0].as_mv.col) >= 16;
        if (is_compound && !far_from_global) {
            far_from_global = abs(proj_1.as_mv.row - gm_mv_candidates[1].as_mv.row) >= 16 ||
                abs(proj_1.as_mv.col - gm_mv_candidates[1].as_mv.col) >= 16;
        }
        if (far_from_global)
            mode_context[ref_frame] |= (1 << GLOBALMV_OFFSET); /* zero_mv_ctxt */
    }

    CandidateMv *const stack = ref_mv_stacks[is_compound ? ref_frame : rf[0]];
    const uint8_t      count = *num_mv_found;
    uint8_t            slot  = 0;
    while (slot < count &&
           !(stack[slot].this_mv.as_int == proj_0.as_int &&
             (!is_compound || stack[slot].comp_mv.as_int == proj_1.as_int)))
        ++slot;

    if (slot < count) {
        stack[slot].weight += 2;
    } else if (count < MAX_REF_MV_STACK_SIZE) {
        stack[slot].this_mv.as_int = proj_0.as_int;
        if (is_compound)
            stack[slot].comp_mv.as_int = proj_1.as_int;
        stack[slot].weight = 2;
        ++(*num_mv_found);
    }
    return 1;
}
745 
/*
 * Compound-reference extra search: harvests the MVs of one spatial neighbour.
 *
 * For each inter reference of `candidate` and each target reference
 * rf[0]/rf[1]:
 *   - if the neighbour used that same reference and ref_id[target] has fewer
 *     than two entries, its MV is appended to ref_id[target];
 *   - otherwise, if ref_diff[target] has fewer than two entries, the MV is
 *     appended there, negated when the two references' sign biases differ.
 */
static void add_extra_mv_candidate(BlockModeInfo *candidate, EbDecHandle *dec_handle,
                                   MvReferenceFrame *rf, IntMv ref_id[2][2], int ref_id_count[2],
                                   IntMv ref_diff[2][2], int ref_diff_count[2]) {
    const FrameHeader *const hdr = &dec_handle->frame_header;

    for (int side = 0; side < 2; ++side) {
        const MvReferenceFrame nbr_ref = candidate->ref_frame[side];
        if (nbr_ref <= INTRA_FRAME)
            continue;

        for (int target = 0; target < 2; ++target) {
            if (nbr_ref == rf[target] && ref_id_count[target] < 2) {
                ref_id[target][ref_id_count[target]++] = candidate->mv[side];
                continue;
            }
            if (ref_diff_count[target] >= 2)
                continue;

            IntMv mv = candidate->mv[side];
            if (hdr->ref_frame_sign_bias[nbr_ref] != hdr->ref_frame_sign_bias[rf[target]]) {
                mv.as_mv.row = -mv.as_mv.row;
                mv.as_mv.col = -mv.as_mv.col;
            }
            ref_diff[target][ref_diff_count[target]++] = mv;
        }
    }
}
771 
/*
 * Single-reference extra search: for each inter reference of `candidate`,
 * takes its MV (negated when that reference's sign bias differs from
 * ref_frame's) and appends it to ref_mv_stack[ref_frame] with weight 2,
 * unless an identical MV is already present. refmv_count[ref_frame] is
 * incremented per insertion. No capacity check is done here; the caller in
 * dec_setup_ref_mv_list only calls this while fewer than
 * MAX_MV_REF_CANDIDATES candidates exist.
 */
static void process_single_ref_mv_candidate(BlockModeInfo *candidate, EbDecHandle *dec_handle,
                                            MvReferenceFrame ref_frame,
                                            uint8_t          refmv_count[MODE_CTX_REF_FRAMES],
                                            CandidateMv ref_mv_stack[][MAX_REF_MV_STACK_SIZE]) {
    const FrameHeader *const hdr   = &dec_handle->frame_header;
    CandidateMv *const       stack = ref_mv_stack[ref_frame];

    for (int side = 0; side < 2; ++side) {
        const MvReferenceFrame nbr_ref = candidate->ref_frame[side];
        if (nbr_ref <= INTRA_FRAME)
            continue;

        IntMv mv = candidate->mv[side];
        if (hdr->ref_frame_sign_bias[nbr_ref] != hdr->ref_frame_sign_bias[ref_frame]) {
            mv.as_mv.row = -mv.as_mv.row;
            mv.as_mv.col = -mv.as_mv.col;
        }

        int duplicate = 0;
        for (int i = 0; i < refmv_count[ref_frame] && !duplicate; ++i)
            duplicate = (stack[i].this_mv.as_int == mv.as_int);

        if (!duplicate) {
            const int slot       = refmv_count[ref_frame];
            stack[slot].this_mv  = mv;
            stack[slot].weight   = 2;
            ++refmv_count[ref_frame];
        }
    }
}
800 
clamp_mv_ref(MV * mv,int bw,int bh,PartitionInfo * pi)801 static INLINE void clamp_mv_ref(MV *mv, int bw, int bh, PartitionInfo *pi) {
802     clamp_mv(mv,
803              pi->mb_to_left_edge - bw * 8 - MV_BORDER,
804              pi->mb_to_right_edge + bw * 8 + MV_BORDER,
805              pi->mb_to_top_edge - bh * 8 - MV_BORDER,
806              pi->mb_to_bottom_edge + bh * 8 + MV_BORDER);
807 }
808 
dec_setup_ref_mv_list(EbDecHandle * dec_handle,ParseCtxt * parse_ctx,PartitionInfo * pi,MvReferenceFrame ref_frame,CandidateMv ref_mv_stack[][MAX_REF_MV_STACK_SIZE],IntMv mv_ref_list[][MAX_MV_REF_CANDIDATES],IntMv * gm_mv_candidates,int16_t * mode_context,MvCount * mv_cnt)809 static void dec_setup_ref_mv_list(EbDecHandle *dec_handle, ParseCtxt *parse_ctx, PartitionInfo *pi,
810                                   MvReferenceFrame ref_frame,
811                                   CandidateMv      ref_mv_stack[][MAX_REF_MV_STACK_SIZE],
812                                   IntMv            mv_ref_list[][MAX_MV_REF_CANDIDATES],
813                                   IntMv *gm_mv_candidates, int16_t *mode_context, MvCount *mv_cnt) {
814     int              n4_w = mi_size_wide[pi->mi->sb_type];
815     int              n4_h = mi_size_high[pi->mi->sb_type];
816     const int        bs   = AOMMAX(n4_w, n4_h);
817     MvReferenceFrame rf[2];
818 
819     FrameHeader *         frame_info     = parse_ctx->frame_header;
820     const TileInfo *const tile           = &parse_ctx->cur_tile_info;
821     int                   max_row_offset = 0, max_col_offset = 0;
822     int32_t               mi_row         = pi->mi_row;
823     int32_t               mi_col         = pi->mi_col;
824     const int             row_adj        = (n4_h < mi_size_high[BLOCK_8X8]) && (mi_row & 0x01);
825     const int             col_adj        = (n4_w < mi_size_wide[BLOCK_8X8]) && (mi_col & 0x01);
826     int                   processed_rows = 0;
827     int                   processed_cols = 0;
828 
829     av1_set_ref_frame(rf, ref_frame);
830     mode_context[ref_frame] = 0;
831 
832     // Find valid maximum row/col offset.
833     if (pi->up_available) {
834         max_row_offset = -(MVREF_ROW_COLS << 1) + row_adj;
835 
836         if (n4_h < mi_size_high[BLOCK_8X8])
837             max_row_offset = -(2 << 1) + row_adj;
838 
839         max_row_offset = clamp(
840             max_row_offset, tile->mi_row_start - mi_row, tile->mi_row_end - mi_row - 1);
841     }
842 
843     if (pi->left_available) {
844         max_col_offset = -(MVREF_ROW_COLS << 1) + col_adj;
845 
846         if (n4_w < mi_size_wide[BLOCK_8X8])
847             max_col_offset = -(2 << 1) + col_adj;
848 
849         max_col_offset = clamp(
850             max_col_offset, tile->mi_col_start - mi_col, tile->mi_col_end - mi_col - 1);
851     }
852     memset(mv_cnt, 0, sizeof(*mv_cnt));
853 
854     // Scan the first above row mode info. row_offset = -1;
855     if (abs(max_row_offset) >= 1) {
856         scan_row_mbmi(dec_handle,
857                       parse_ctx,
858                       pi,
859                       -1,
860                       rf,
861                       ref_mv_stack[ref_frame],
862                       &mv_cnt->num_mv_found[ref_frame],
863                       &mv_cnt->found_above_match,
864                       &mv_cnt->newmv_count,
865                       gm_mv_candidates,
866                       max_row_offset,
867                       &processed_rows);
868     }
869 
870     // Scan the first left column mode info. col_offset = -1;
871     if (abs(max_col_offset) >= 1) {
872         scan_col_mbmi(dec_handle,
873                       parse_ctx,
874                       pi,
875                       -1,
876                       rf,
877                       ref_mv_stack[ref_frame],
878                       &mv_cnt->num_mv_found[ref_frame],
879                       &mv_cnt->found_left_match,
880                       &mv_cnt->newmv_count,
881                       gm_mv_candidates,
882                       max_col_offset,
883                       &processed_cols);
884     }
885 
886     if (has_top_right(dec_handle, pi, bs)) {
887         scan_blk_mbmi(dec_handle,
888                       parse_ctx,
889                       pi,
890                       -1,
891                       n4_w,
892                       rf,
893                       ref_mv_stack[ref_frame],
894                       &mv_cnt->found_above_match,
895                       &mv_cnt->newmv_count,
896                       gm_mv_candidates,
897                       &mv_cnt->num_mv_found[ref_frame]);
898     }
899 
900     const uint8_t nearest_match = (mv_cnt->found_above_match > 0) + (mv_cnt->found_left_match > 0);
901     const uint8_t num_nearest   = mv_cnt->num_mv_found[ref_frame];
902     const uint8_t num_new       = mv_cnt->newmv_count;
903 
904     for (int idx = 0; idx < num_nearest; ++idx)
905         ref_mv_stack[ref_frame][idx].weight += REF_CAT_LEVEL;
906 
907     if (frame_info->use_ref_frame_mvs) {
908         int       is_available = 0;
909         const int voffset      = AOMMAX(mi_size_high[BLOCK_8X8], n4_h);
910         const int hoffset      = AOMMAX(mi_size_wide[BLOCK_8X8], n4_w);
911         const int blk_row_end  = AOMMIN(n4_h, mi_size_high[BLOCK_64X64]);
912         const int blk_col_end  = AOMMIN(n4_w, mi_size_wide[BLOCK_64X64]);
913 
914         const int tpl_sample_pos[3][2] = {
915             {voffset, -2},
916             {voffset, hoffset},
917             {voffset - 2, hoffset},
918         };
919         const int allow_extension = (n4_h >= mi_size_high[BLOCK_8X8]) &&
920             (n4_h < mi_size_high[BLOCK_64X64]) && (n4_w >= mi_size_wide[BLOCK_8X8]) &&
921             (n4_w < mi_size_wide[BLOCK_64X64]);
922 
923         const int step_h = (n4_h >= mi_size_high[BLOCK_64X64]) ? mi_size_high[BLOCK_16X16]
924                                                                : mi_size_high[BLOCK_8X8];
925         const int step_w = (n4_w >= mi_size_wide[BLOCK_64X64]) ? mi_size_wide[BLOCK_16X16]
926                                                                : mi_size_wide[BLOCK_8X8];
927 
928         for (int blk_row = 0; blk_row < blk_row_end; blk_row += step_h) {
929             for (int blk_col = 0; blk_col < blk_col_end; blk_col += step_w) {
930                 int ret = add_tpl_ref_mv(dec_handle,
931                                          parse_ctx,
932                                          mi_row,
933                                          mi_col,
934                                          ref_frame,
935                                          blk_row,
936                                          blk_col,
937                                          gm_mv_candidates,
938                                          &mv_cnt->num_mv_found[ref_frame],
939                                          ref_mv_stack,
940                                          mode_context);
941                 if (blk_row == 0 && blk_col == 0)
942                     is_available = ret;
943             }
944         }
945 
946         if (is_available == 0)
947             mode_context[ref_frame] |= (1 << GLOBALMV_OFFSET);
948 
949         if (allow_extension) {
950             for (int i = 0; i < 3; ++i) {
951                 const int blk_row = tpl_sample_pos[i][0];
952                 const int blk_col = tpl_sample_pos[i][1];
953 
954                 if (check_sb_border(mi_row, mi_col, blk_row, blk_col)) {
955                     add_tpl_ref_mv(dec_handle,
956                                    parse_ctx,
957                                    mi_row,
958                                    mi_col,
959                                    ref_frame,
960                                    blk_row,
961                                    blk_col,
962                                    gm_mv_candidates,
963                                    &mv_cnt->num_mv_found[ref_frame],
964                                    ref_mv_stack,
965                                    mode_context);
966                 }
967             }
968         }
969     }
970 
971     // Scan the second outer area.
972     scan_blk_mbmi(dec_handle,
973                   parse_ctx,
974                   pi,
975                   -1,
976                   -1,
977                   rf,
978                   ref_mv_stack[ref_frame],
979                   &mv_cnt->found_above_match,
980                   &mv_cnt->newmv_count,
981                   gm_mv_candidates,
982                   &mv_cnt->num_mv_found[ref_frame]);
983 
984     for (int idx = 2; idx <= MVREF_ROW_COLS; ++idx) {
985         const int row_offset = -(idx << 1) + 1 + row_adj;
986         const int col_offset = -(idx << 1) + 1 + col_adj;
987         if (abs(row_offset) <= abs(max_row_offset) && abs(row_offset) > processed_rows) {
988             scan_row_mbmi(dec_handle,
989                           parse_ctx,
990                           pi,
991                           row_offset,
992                           rf,
993                           ref_mv_stack[ref_frame],
994                           &mv_cnt->num_mv_found[ref_frame],
995                           &mv_cnt->found_above_match,
996                           &mv_cnt->newmv_count,
997                           gm_mv_candidates,
998                           max_row_offset,
999                           &processed_rows);
1000         }
1001 
1002         if (abs(col_offset) <= abs(max_col_offset) && abs(col_offset) > processed_cols) {
1003             scan_col_mbmi(dec_handle,
1004                           parse_ctx,
1005                           pi,
1006                           col_offset,
1007                           rf,
1008                           ref_mv_stack[ref_frame],
1009                           &mv_cnt->num_mv_found[ref_frame],
1010                           &mv_cnt->found_left_match,
1011                           &mv_cnt->newmv_count,
1012                           gm_mv_candidates,
1013                           max_col_offset,
1014                           &processed_cols);
1015         }
1016     }
1017 
1018     /* sorting process*/
1019     int start = 0;
1020     int end   = num_nearest;
1021     while (end > start) {
1022         int new_end = start;
1023         for (int idx = start + 1; idx < end; ++idx) {
1024             if (ref_mv_stack[ref_frame][idx - 1].weight < ref_mv_stack[ref_frame][idx].weight) {
1025                 CandidateMv tmp_mv               = ref_mv_stack[ref_frame][idx - 1];
1026                 ref_mv_stack[ref_frame][idx - 1] = ref_mv_stack[ref_frame][idx];
1027                 ref_mv_stack[ref_frame][idx]     = tmp_mv;
1028                 new_end                          = idx;
1029             }
1030         }
1031         end = new_end;
1032     }
1033 
1034     start = num_nearest;
1035     end   = mv_cnt->num_mv_found[ref_frame];
1036     while (end > start) {
1037         int new_end = start;
1038         for (int idx = start + 1; idx < end; ++idx) {
1039             if (ref_mv_stack[ref_frame][idx - 1].weight < ref_mv_stack[ref_frame][idx].weight) {
1040                 CandidateMv tmp_mv               = ref_mv_stack[ref_frame][idx - 1];
1041                 ref_mv_stack[ref_frame][idx - 1] = ref_mv_stack[ref_frame][idx];
1042                 ref_mv_stack[ref_frame][idx]     = tmp_mv;
1043                 new_end                          = idx;
1044             }
1045         }
1046         end = new_end;
1047     }
1048 
1049     /* extra search process */
1050     if (mv_cnt->num_mv_found[ref_frame] < MAX_MV_REF_CANDIDATES) {
1051         IntMv ref_id[2][2], ref_diff[2][2];
1052         int   ref_id_count[2] = {0}, ref_diff_count[2] = {0};
1053 
1054         int mi_width  = AOMMIN(16, n4_w);
1055         mi_width      = AOMMIN(mi_width, (int)frame_info->mi_cols - mi_col);
1056         int mi_height = AOMMIN(16, n4_h);
1057         mi_height     = AOMMIN(mi_height, (int)frame_info->mi_rows - mi_row);
1058         int mi_size   = AOMMIN(mi_width, mi_height);
1059 
1060         for (int pass = 0; pass < 2; pass++) {
1061             int idx = 0;
1062             while (idx < mi_size && mv_cnt->num_mv_found[ref_frame] < MAX_MV_REF_CANDIDATES) {
1063                 int mv_row, mv_col;
1064                 if (pass == 0) {
1065                     mv_row = mi_row - 1;
1066                     mv_col = mi_col + idx;
1067                 } else {
1068                     mv_row = mi_row + idx;
1069                     mv_col = mi_col - 1;
1070                 }
1071 
1072                 if (!is_inside(&parse_ctx->cur_tile_info, mv_col, mv_row))
1073                     break;
1074 
1075                 BlockModeInfo *nbr = get_cur_mode_info(dec_handle, mv_row, mv_col, pi->sb_info);
1076 
1077                 if (rf[1] != NONE_FRAME)
1078                     add_extra_mv_candidate(
1079                         nbr, dec_handle, rf, ref_id, ref_id_count, ref_diff, ref_diff_count);
1080                 else
1081                     process_single_ref_mv_candidate(
1082                         nbr, dec_handle, ref_frame, mv_cnt->num_mv_found, ref_mv_stack);
1083 
1084                 idx += pass ? mi_size_high[nbr->sb_type] : mi_size_wide[nbr->sb_type];
1085             }
1086         }
1087 
1088         if (rf[1] > NONE_FRAME) {
1089             IntMv comp_list[3][2];
1090 
1091             for (int idx = 0; idx < 2; ++idx) {
1092                 int comp_idx = 0;
1093                 for (int list_idx = 0; list_idx < ref_id_count[idx]; ++list_idx, comp_idx++) {
1094                     comp_list[comp_idx][idx] = ref_id[idx][list_idx];
1095                 }
1096 
1097                 for (int list_idx = 0; list_idx < ref_diff_count[idx] && comp_idx < 2;
1098                      ++list_idx, ++comp_idx) {
1099                     comp_list[comp_idx][idx] = ref_diff[idx][list_idx];
1100                 }
1101 
1102                 for (; comp_idx < 2; ++comp_idx) comp_list[comp_idx][idx] = gm_mv_candidates[idx];
1103             }
1104 
1105             if (mv_cnt->num_mv_found[ref_frame]) {
1106                 assert(mv_cnt->num_mv_found[ref_frame] == 1);
1107                 if (comp_list[0][0].as_int == ref_mv_stack[ref_frame][0].this_mv.as_int &&
1108                     comp_list[0][1].as_int == ref_mv_stack[ref_frame][0].comp_mv.as_int) {
1109                     ref_mv_stack[ref_frame][mv_cnt->num_mv_found[ref_frame]].this_mv =
1110                         comp_list[1][0];
1111                     ref_mv_stack[ref_frame][mv_cnt->num_mv_found[ref_frame]].comp_mv =
1112                         comp_list[1][1];
1113                 } else {
1114                     ref_mv_stack[ref_frame][mv_cnt->num_mv_found[ref_frame]].this_mv =
1115                         comp_list[0][0];
1116                     ref_mv_stack[ref_frame][mv_cnt->num_mv_found[ref_frame]].comp_mv =
1117                         comp_list[0][1];
1118                 }
1119                 ref_mv_stack[ref_frame][mv_cnt->num_mv_found[ref_frame]].weight = 2;
1120                 ++mv_cnt->num_mv_found[ref_frame];
1121             } else {
1122                 for (int idx = 0; idx < MAX_MV_REF_CANDIDATES; ++idx) {
1123                     ref_mv_stack[ref_frame][mv_cnt->num_mv_found[ref_frame]].this_mv =
1124                         comp_list[idx][0];
1125                     ref_mv_stack[ref_frame][mv_cnt->num_mv_found[ref_frame]].comp_mv =
1126                         comp_list[idx][1];
1127                     ref_mv_stack[ref_frame][mv_cnt->num_mv_found[ref_frame]].weight = 2;
1128                     ++mv_cnt->num_mv_found[ref_frame];
1129                 }
1130             }
1131         }
1132     }
1133 
1134     /* context and clamping process */
1135     //int num_lists = rf[1] > NONE_FRAME ? 2 : 1;
1136     //for (int list = 0; list < num_lists; list++) {
1137     //    for (int idx = 0; idx < mv_cnt->num_mv_found[ref_frame]; idx++) {
1138     //        IntMv refMv = ref_mv_stack[ref_frame][idx].this_mv;
1139     //        refMv.as_mv.row = clamp_mv_row(dec_handle, pi, mi_row,
1140     //            refMv.as_mv.row, MV_BORDER + n4_h * 8);
1141     //        refMv.as_mv.col = clamp_mv_col(dec_handle, pi, mi_col,
1142     //            refMv.as_mv.col, MV_BORDER + n4_w * 8);
1143     //        ref_mv_stack[ref_frame][idx].this_mv = refMv;
1144     //    }
1145     //}
1146 
1147     if (rf[1] > NONE_FRAME) {
1148         for (int idx = 0; idx < mv_cnt->num_mv_found[ref_frame]; ++idx) {
1149             clamp_mv_ref(&ref_mv_stack[ref_frame][idx].this_mv.as_mv,
1150                          n4_w << MI_SIZE_LOG2,
1151                          n4_h << MI_SIZE_LOG2,
1152                          pi);
1153             clamp_mv_ref(&ref_mv_stack[ref_frame][idx].comp_mv.as_mv,
1154                          n4_w << MI_SIZE_LOG2,
1155                          n4_h << MI_SIZE_LOG2,
1156                          pi);
1157         }
1158     } else {
1159         for (int idx = 0; idx < mv_cnt->num_mv_found[ref_frame]; ++idx) {
1160             clamp_mv_ref(&ref_mv_stack[ref_frame][idx].this_mv.as_mv,
1161                          n4_w << MI_SIZE_LOG2,
1162                          n4_h << MI_SIZE_LOG2,
1163                          pi);
1164         }
1165     }
1166 
1167     const uint8_t ref_match_count = (mv_cnt->found_above_match > 0) +
1168         (mv_cnt->found_left_match > 0);
1169     switch (nearest_match) {
1170     case 0:
1171         mode_context[ref_frame] |= 0;
1172         if (ref_match_count >= 1)
1173             mode_context[ref_frame] |= 1;
1174         if (ref_match_count == 1)
1175             mode_context[ref_frame] |= (1 << REFMV_OFFSET);
1176         else if (ref_match_count >= 2)
1177             mode_context[ref_frame] |= (2 << REFMV_OFFSET);
1178         break;
1179     case 1:
1180         mode_context[ref_frame] |= (num_new > 0) ? 2 : 3;
1181         if (ref_match_count == 1)
1182             mode_context[ref_frame] |= (3 << REFMV_OFFSET);
1183         else if (ref_match_count >= 2)
1184             mode_context[ref_frame] |= (4 << REFMV_OFFSET);
1185         break;
1186     case 2:
1187     default:
1188         if (num_new >= 1)
1189             mode_context[ref_frame] |= 4;
1190         else
1191             mode_context[ref_frame] |= 5;
1192 
1193         mode_context[ref_frame] |= (5 << REFMV_OFFSET);
1194         break;
1195     }
1196 
1197     if (rf[1] == NONE_FRAME && mv_ref_list != NULL) {
1198         for (int idx = mv_cnt->num_mv_found[ref_frame]; idx < MAX_MV_REF_CANDIDATES; ++idx)
1199             mv_ref_list[rf[0]][idx].as_int = gm_mv_candidates[0].as_int;
1200 
1201         for (int idx = 0; idx < AOMMIN(MAX_MV_REF_CANDIDATES, mv_cnt->num_mv_found[ref_frame]);
1202              ++idx) {
1203             mv_ref_list[rf[0]][idx].as_int = ref_mv_stack[ref_frame][idx].this_mv.as_int;
1204         }
1205     }
1206 }
1207 
svt_mode_context_analyzer(const int16_t * const mode_context,const MvReferenceFrame * const rf)1208 static INLINE int16_t svt_mode_context_analyzer(const int16_t *const          mode_context,
1209                                                 const MvReferenceFrame *const rf) {
1210     const int8_t ref_frame = av1_ref_frame_type(rf);
1211 
1212     if (rf[1] <= INTRA_FRAME)
1213         return mode_context[ref_frame];
1214 
1215     const int16_t newmv_ctx = mode_context[ref_frame] & NEWMV_CTX_MASK;
1216     const int16_t refmv_ctx = (mode_context[ref_frame] >> REFMV_OFFSET) & REFMV_CTX_MASK;
1217 
1218     const int16_t comp_ctx =
1219         compound_mode_ctx_map[refmv_ctx >> 1][AOMMIN(newmv_ctx, COMP_NEWMV_CTXS - 1)];
1220     return comp_ctx;
1221 }
1222 
/*
 * Computes the global-motion MV candidate(s) for ref_frame into
 * global_mvs[0..1], then builds the block's reference MV stack (and, for a
 * single reference, mv_ref_list) with dec_setup_ref_mv_list().
 *
 * For INTRA_FRAME both global MVs are 0. For a single reference,
 * global_mvs[1] stays 0. For compound, global_mvs[1] comes from rf[1]'s
 * global-motion parameters of the current picture.
 */
void svt_av1_find_mv_refs(EbDecHandle *dec_handle, PartitionInfo *pi, ParseCtxt *parse_ctx,
                          MvReferenceFrame ref_frame,
                          CandidateMv      ref_mv_stack[][MAX_REF_MV_STACK_SIZE],
                          IntMv mv_ref_list[][MAX_MV_REF_CANDIDATES], IntMv global_mvs[2],
                          int16_t *mode_context, MvCount *mv_cnt) {
    MvReferenceFrame rf[2];
    av1_set_ref_frame(rf, ref_frame);

    global_mvs[0].as_int = 0;
    global_mvs[1].as_int = 0;

    if (ref_frame != INTRA_FRAME) {
        const BlockSize    bsize   = pi->mi->sb_type;
        FrameHeader *const frm_hdr = &dec_handle->frame_header;
        EbDecPicBuf *const cur_buf = dec_handle->cur_pic_buf[0];

        global_mvs[0].as_int = gm_get_motion_vector(&cur_buf->global_motion[rf[0]],
                                                    frm_hdr->allow_high_precision_mv,
                                                    bsize,
                                                    pi->mi_col,
                                                    pi->mi_row,
                                                    frm_hdr->force_integer_mv)
                                   .as_int;

        if (rf[1] != NONE_FRAME) {
            global_mvs[1].as_int = gm_get_motion_vector(&cur_buf->global_motion[rf[1]],
                                                        frm_hdr->allow_high_precision_mv,
                                                        bsize,
                                                        pi->mi_col,
                                                        pi->mi_row,
                                                        frm_hdr->force_integer_mv)
                                       .as_int;
        }
    }

    dec_setup_ref_mv_list(dec_handle,
                          parse_ctx,
                          pi,
                          ref_frame,
                          ref_mv_stack,
                          mv_ref_list,
                          global_mvs,
                          mode_context,
                          mv_cnt);
}
1266 
read_inter_compound_mode(ParseCtxt * parse_ctxt,int16_t ctx)1267 static PredictionMode read_inter_compound_mode(ParseCtxt *parse_ctxt, int16_t ctx) {
1268     SvtReader *r    = &parse_ctxt->r;
1269     const int  mode = svt_read_symbol(
1270         r, parse_ctxt->cur_tile_ctx.inter_compound_mode_cdf[ctx], INTER_COMPOUND_MODES, ACCT_STR);
1271     assert(is_inter_compound_mode(NEAREST_NEARESTMV + mode));
1272     return NEAREST_NEARESTMV + mode;
1273 }
1274 
has_nearmv(PredictionMode mode)1275 static INLINE int has_nearmv(PredictionMode mode) {
1276     return (mode == NEARMV || mode == NEAR_NEARMV || mode == NEAR_NEWMV || mode == NEW_NEARMV);
1277 }
1278 
get_drl_ctx(const CandidateMv * ref_mv_stack,int ref_idx)1279 static INLINE uint8_t get_drl_ctx(const CandidateMv *ref_mv_stack, int ref_idx) {
1280     if (ref_mv_stack[ref_idx].weight >= REF_CAT_LEVEL &&
1281         ref_mv_stack[ref_idx + 1].weight < REF_CAT_LEVEL) {
1282         return 1;
1283     }
1284 
1285     if (ref_mv_stack[ref_idx].weight < REF_CAT_LEVEL &&
1286         ref_mv_stack[ref_idx + 1].weight < REF_CAT_LEVEL) {
1287         return 2;
1288     }
1289 
1290     return 0;
1291 }
1292 
/*
 * Reads the DRL index of a NEWMV/NEW_NEWMV or NEAR-type block into
 * mbmi->ref_mv_idx (initialised to 0).
 *
 * A DRL bit is only coded for stack position idx while
 * num_mv_found > idx + 1, and a 0 bit ends the search.
 *  - NEWMV / NEW_NEWMV examine idx 0 and 1 and store idx + bit.
 *  - NEAR modes (have_nearmv_in_inter_mode) examine idx 1 and 2 and store
 *    idx + bit - 1.
 */
static void read_drl_idx(ParseCtxt *parse_ctxt, PartitionInfo *pi, BlockModeInfo *mbmi,
                         int num_mv_found) {
    SvtReader *const r        = &parse_ctxt->r;
    const uint8_t    ref_type = av1_ref_frame_type(mbmi->ref_frame);

    mbmi->ref_mv_idx = 0;

    if (mbmi->mode == NEWMV || mbmi->mode == NEW_NEWMV) {
        for (int idx = 0; idx < 2; ++idx) {
            if (num_mv_found <= idx + 1)
                continue;
            const uint8_t ctx = get_drl_ctx(pi->ref_mv_stack[ref_type], idx);
            const int     bit = svt_read_symbol(r, parse_ctxt->cur_tile_ctx.drl_cdf[ctx], 2, ACCT_STR);
            mbmi->ref_mv_idx  = bit ? idx + 1 : idx;
            if (!bit)
                return;
        }
    }

    if (have_nearmv_in_inter_mode(mbmi->mode)) {
        for (int idx = 1; idx < 3; ++idx) {
            if (num_mv_found <= idx + 1)
                continue;
            const uint8_t ctx = get_drl_ctx(pi->ref_mv_stack[ref_type], idx);
            const int     bit = svt_read_symbol(r, parse_ctxt->cur_tile_ctx.drl_cdf[ctx], 2, ACCT_STR);
            mbmi->ref_mv_idx  = idx + bit - 1;
            if (!bit)
                return;
        }
    }
}
1324 
1325 /* TODO: Harmonize*/
/*
 * Applies lower_mv_precision() in place to every one of the
 * MAX_MV_REF_CANDIDATES entries of `mvlist`. Then copies entry 0 to
 * *nearest_mv and entry 1 to *near_mv.
 */
static void svt_find_best_ref_mvs(int allow_hp, IntMv *mvlist, IntMv *nearest_mv, IntMv *near_mv,
                                  int is_integer) {
    IntMv *const end = mvlist + MAX_MV_REF_CANDIDATES;
    for (IntMv *cand = mvlist; cand != end; ++cand)
        lower_mv_precision(&cand->as_mv, allow_hp, is_integer);

    *nearest_mv = mvlist[0];
    *near_mv    = mvlist[1];
}
1336 
/*
 * Decodes one motion-vector component difference (row or column) from the
 * bitstream.
 *
 * r          - entropy reader.
 * mvcomp     - CDFs for this component.
 * use_subpel - when 0, no fractional symbols are read. fr defaults to 3 and
 *              hp to 1.
 * usehp      - when non-zero, and only when use_subpel is also non-zero, the
 *              high-precision bit is read. Otherwise hp defaults to 1.
 *
 * Returns the signed difference. The magnitude is
 * class offset + ((d << 3) | (fr << 1) | hp) + 1.
 */
int read_mv_component(SvtReader *r, NmvComponent *mvcomp, int use_subpel, int usehp) {
    const int sign     = svt_read_symbol(r, mvcomp->sign_cdf, 2, ACCT_STR);
    const int mv_class = svt_read_symbol(r, mvcomp->classes_cdf, MV_CLASSES, ACCT_STR);
    const int class0   = mv_class == MV_CLASS_0;
    int       mag, d;

    // Integer part
    if (class0) {
        d   = svt_read_symbol(r, mvcomp->class0_cdf, CLASS0_SIZE, ACCT_STR);
        mag = 0;
    } else {
        d = 0;
        for (int i = 0; i < mv_class; ++i)
            d |= svt_read_symbol(r, mvcomp->bits_cdf[i], 2, ACCT_STR) << i;
        mag = CLASS0_SIZE << (mv_class + 2);
    }

    // Fractional part. The hp bit is nested under use_subpel (as in the libaom
    // reference decoder). A caller passing use_subpel == 0 with usehp != 0
    // must not consume an extra symbol from the bitstream.
    int fr = 3;
    int hp = 1;
    if (use_subpel) {
        fr = svt_read_symbol(
            r, class0 ? mvcomp->class0_fp_cdf[d] : mvcomp->fp_cdf, MV_FP_SIZE, ACCT_STR);
        if (usehp)
            hp = svt_read_symbol(
                r, class0 ? mvcomp->class0_hp_cdf : mvcomp->hp_cdf, 2, ACCT_STR);
    }

    // Result
    mag += ((d << 3) | (fr << 1) | hp) + 1;
    return sign ? -mag : mag;
}
1366 
read_mv(SvtReader * r,MV * mv,MV * ref,NmvContext * ctx,MvSubpelPrecision precision)1367 static INLINE void read_mv(SvtReader *r, MV *mv, MV *ref, NmvContext *ctx,
1368                            MvSubpelPrecision precision) {
1369     MV diff = k_zero_mv;
1370 
1371     const MvJointType joint_type = (MvJointType)svt_read_symbol(
1372         r, ctx->joints_cdf, MV_JOINTS, ACCT_STR);
1373 
1374     if (mv_joint_vertical(joint_type))
1375         diff.row = read_mv_component(
1376             r, &ctx->comps[0], precision > MV_SUBPEL_NONE, precision > MV_SUBPEL_LOW_PRECISION);
1377 
1378     if (mv_joint_horizontal(joint_type))
1379         diff.col = read_mv_component(
1380             r, &ctx->comps[1], precision > MV_SUBPEL_NONE, precision > MV_SUBPEL_LOW_PRECISION);
1381 
1382     mv->row = ref->row + diff.row;
1383     mv->col = ref->col + diff.col;
1384 }
1385 
assign_mv(ParseCtxt * parse_ctxt,PartitionInfo * pi,IntMv mv[2],IntMv * global_mvs,IntMv ref_mv[2],IntMv nearest_mv[2],IntMv near_mv[2],int is_compound,int allow_hp)1386 static INLINE int assign_mv(ParseCtxt *parse_ctxt, PartitionInfo *pi, IntMv mv[2],
1387                             IntMv *global_mvs, IntMv ref_mv[2], IntMv nearest_mv[2],
1388                             IntMv near_mv[2], int is_compound, int allow_hp) {
1389     SvtReader *    r    = &parse_ctxt->r;
1390     BlockModeInfo *mbmi = pi->mi;
1391 
1392     if (parse_ctxt->frame_header->force_integer_mv)
1393         allow_hp = MV_SUBPEL_NONE;
1394 
1395     switch (mbmi->mode) {
1396     case NEWMV: {
1397         NmvContext *const nmvc = &parse_ctxt->cur_tile_ctx.nmvc;
1398         read_mv(r, &mv[0].as_mv, &ref_mv[0].as_mv, nmvc, allow_hp);
1399         break;
1400     }
1401     case NEARESTMV: {
1402         mv[0].as_int = nearest_mv[0].as_int;
1403         break;
1404     }
1405     case NEARMV: {
1406         mv[0].as_int = near_mv[0].as_int;
1407         break;
1408     }
1409     case GLOBALMV: {
1410         mv[0].as_int = global_mvs[0].as_int;
1411         break;
1412     }
1413     case NEW_NEWMV: {
1414         assert(is_compound);
1415         for (int i = 0; i < 2; ++i) {
1416             NmvContext *const nmvc = &parse_ctxt->cur_tile_ctx.nmvc;
1417             read_mv(r, &mv[i].as_mv, &ref_mv[i].as_mv, nmvc, allow_hp);
1418         }
1419         break;
1420     }
1421     case NEAREST_NEARESTMV: {
1422         assert(is_compound);
1423         mv[0].as_int = nearest_mv[0].as_int;
1424         mv[1].as_int = nearest_mv[1].as_int;
1425         break;
1426     }
1427     case NEAR_NEARMV: {
1428         assert(is_compound);
1429         mv[0].as_int = near_mv[0].as_int;
1430         mv[1].as_int = near_mv[1].as_int;
1431         break;
1432     }
1433     case NEW_NEARESTMV: {
1434         NmvContext *const nmvc = &parse_ctxt->cur_tile_ctx.nmvc;
1435         read_mv(r, &mv[0].as_mv, &ref_mv[0].as_mv, nmvc, allow_hp);
1436         assert(is_compound);
1437         mv[1].as_int = nearest_mv[1].as_int;
1438         break;
1439     }
1440     case NEAREST_NEWMV: {
1441         mv[0].as_int           = nearest_mv[0].as_int;
1442         NmvContext *const nmvc = &parse_ctxt->cur_tile_ctx.nmvc;
1443         read_mv(r, &mv[1].as_mv, &ref_mv[1].as_mv, nmvc, allow_hp);
1444         assert(is_compound);
1445         break;
1446     }
1447     case NEAR_NEWMV: {
1448         mv[0].as_int           = near_mv[0].as_int;
1449         NmvContext *const nmvc = &parse_ctxt->cur_tile_ctx.nmvc;
1450         read_mv(r, &mv[1].as_mv, &ref_mv[1].as_mv, nmvc, allow_hp);
1451         assert(is_compound);
1452         break;
1453     }
1454     case NEW_NEARMV: {
1455         NmvContext *const nmvc = &parse_ctxt->cur_tile_ctx.nmvc;
1456         read_mv(r, &mv[0].as_mv, &ref_mv[0].as_mv, nmvc, allow_hp);
1457         assert(is_compound);
1458         mv[1].as_int = near_mv[1].as_int;
1459         break;
1460     }
1461     case GLOBAL_GLOBALMV: {
1462         assert(is_compound);
1463         mv[0].as_int = global_mvs[0].as_int;
1464         mv[1].as_int = global_mvs[1].as_int;
1465         break;
1466     }
1467     default: {
1468         return 0;
1469     }
1470     }
1471 
1472     int ret = is_mv_valid(&mv[0].as_mv);
1473     if (is_compound)
1474         ret = ret && is_mv_valid(&mv[1].as_mv);
1475     return ret;
1476 }
1477 
is_dv_valid(MV dv,ParseCtxt * parse_ctx,PartitionInfo * pi)1478 static INLINE int is_dv_valid(MV dv, ParseCtxt *parse_ctx, PartitionInfo *pi) {
1479     int       mi_row         = pi->mi_row;
1480     int       mi_col         = pi->mi_col;
1481     int       mib_size_log2  = parse_ctx->seq_header->sb_size_log2;
1482     int       subsampling_x  = parse_ctx->seq_header->color_config.subsampling_x;
1483     int       subsampling_y  = parse_ctx->seq_header->color_config.subsampling_y;
1484     BlockSize bsize          = pi->mi->sb_type;
1485     const int bw             = block_size_wide[bsize];
1486     const int bh             = block_size_high[bsize];
1487     const int scale_px_to_mv = 8;
1488     if (((dv.row & (scale_px_to_mv - 1)) || (dv.col & (scale_px_to_mv - 1))))
1489         return 0;
1490 
1491     TileInfo *tile          = &parse_ctx->cur_tile_info;
1492     const int src_top_edge  = mi_row * MI_SIZE * scale_px_to_mv + dv.row;
1493     const int tile_top_edge = tile->mi_row_start * MI_SIZE * scale_px_to_mv;
1494     if (src_top_edge < tile_top_edge)
1495         return 0;
1496     const int src_left_edge  = mi_col * MI_SIZE * scale_px_to_mv + dv.col;
1497     const int tile_left_edge = tile->mi_col_start * MI_SIZE * scale_px_to_mv;
1498     if (src_left_edge < tile_left_edge)
1499         return 0;
1500     const int src_bottom_edge  = (mi_row * MI_SIZE + bh) * scale_px_to_mv + dv.row;
1501     const int tile_bottom_edge = tile->mi_row_end * MI_SIZE * scale_px_to_mv;
1502     if (src_bottom_edge > tile_bottom_edge)
1503         return 0;
1504     const int src_right_edge  = (mi_col * MI_SIZE + bw) * scale_px_to_mv + dv.col;
1505     const int tile_right_edge = tile->mi_col_end * MI_SIZE * scale_px_to_mv;
1506     if (src_right_edge > tile_right_edge)
1507         return 0;
1508 
1509     // Special case for sub 8x8 chroma cases, to prevent referring to chroma
1510     // pixels outside current tile.
1511     int     num_planes    = parse_ctx->seq_header->color_config.mono_chrome ? 1 : MAX_MB_PLANE;
1512     int32_t is_chroma_ref = pi->is_chroma_ref;
1513     for (int plane = 1; plane < num_planes; ++plane) {
1514         if (is_chroma_ref) {
1515             if (bw < 8 && subsampling_x)
1516                 if (src_left_edge < tile_left_edge + 4 * scale_px_to_mv)
1517                     return 0;
1518             if (bh < 8 && subsampling_y)
1519                 if (src_top_edge < tile_top_edge + 4 * scale_px_to_mv)
1520                     return 0;
1521         }
1522     }
1523 
1524     const int max_mib_size       = 1 << mib_size_log2;
1525     const int active_sb_row      = mi_row >> mib_size_log2;
1526     const int active_sb64_col    = (mi_col * MI_SIZE) >> 6;
1527     const int sb_size            = max_mib_size * MI_SIZE;
1528     const int src_sb_row         = ((src_bottom_edge >> 3) - 1) / sb_size;
1529     const int src_sb64_col       = ((src_right_edge >> 3) - 1) >> 6;
1530     const int total_sb64_per_row = ((tile->mi_col_end - tile->mi_col_start - 1) >> 4) + 1;
1531     const int active_sb64        = active_sb_row * total_sb64_per_row + active_sb64_col;
1532     const int src_sb64           = src_sb_row * total_sb64_per_row + src_sb64_col;
1533     if (src_sb64 >= active_sb64 - INTRABC_DELAY_SB64)
1534         return 0;
1535 
1536     // Wavefront constraint: use only top left area of frame for reference.
1537     const int gradient  = 1 + INTRABC_DELAY_SB64 + (sb_size > 64);
1538     const int wf_offset = gradient * (active_sb_row - src_sb_row);
1539     if (src_sb_row > active_sb_row ||
1540         src_sb64_col >= active_sb64_col - INTRABC_DELAY_SB64 + wf_offset)
1541         return 0;
1542 
1543     return 1;
1544 }
1545 
/*
 * Reads an intra block copy displacement vector into *mv, relative to
 * *ref_mv. Reading uses the ndvc context with MV_SUBPEL_NONE.
 *
 * Both components are then rounded down to a multiple of 8 (whole pixels).
 * Returns 1 if the vector passes is_mv_valid() and is_dv_valid(), else 0.
 */
int dec_assign_dv(ParseCtxt *parse_ctxt, PartitionInfo *pi, IntMv *mv, IntMv *ref_mv) {
    MV *const dv = &mv->as_mv;

    read_mv(&parse_ctxt->r, dv, &ref_mv->as_mv, &parse_ctxt->cur_tile_ctx.ndvc, MV_SUBPEL_NONE);

    // DV should not have sub-pel.
    assert(((dv->row | dv->col) & 7) == 0);
    dv->row = (dv->row >> 3) * 8;
    dv->col = (dv->col >> 3) * 8;

    if (!is_mv_valid(dv))
        return 0;
    return is_dv_valid(*dv, parse_ctxt, pi);
}
1558 
/*
 * Derives the intra block copy reference DV for pi and decodes the block's DV
 * into pi->mi->mv[0].
 *
 * Steps:
 *  1. Lower the precision of the INTRA_FRAME candidates to integer.
 *  2. Take the nearest candidate, falling back to the near one when the
 *     nearest is zero.
 *  3. If still zero, use av1_find_ref_dv().
 *  4. Round the reference to whole pixels, then read the DV.
 */
void assign_intrabc_mv(ParseCtxt *parse_ctxt, IntMv ref_mvs[INTRA_FRAME + 1][MAX_MV_REF_CANDIDATES],
                       PartitionInfo *pi) {
    IntMv cand_nearest, cand_near;
    svt_find_best_ref_mvs(0, ref_mvs[INTRA_FRAME], &cand_nearest, &cand_near, 0);

    IntMv dv_ref = cand_nearest;
    if (dv_ref.as_int == 0)
        dv_ref = cand_near;
    if (dv_ref.as_int == 0)
        av1_find_ref_dv(&dv_ref,
                        &parse_ctxt->cur_tile_info,
                        parse_ctxt->seq_header->sb_mi_size,
                        pi->mi_row,
                        pi->mi_col);

    // Ref DV should not have sub-pel.
    dv_ref.as_mv.row = (dv_ref.as_mv.row >> 3) * 8;
    dv_ref.as_mv.col = (dv_ref.as_mv.col >> 3) * 8;

    // NOTE(review): the validity result of dec_assign_dv() is discarded here.
    // The libaom reference decoder marks the block as corrupted when the DV is
    // invalid; consider propagating this result to an error flag.
    (void)dec_assign_dv(parse_ctxt, pi, &pi->mi->mv[0], &dv_ref);
}
1577 
/*
 * Parses the inter-intra syntax for a single-reference inter block.
 *
 * mbmi->is_inter_intra and interintra_mode_params.wedge_interintra are first
 * set to 0. They therefore have defined values when the syntax is not coded,
 * instead of relying on the caller to have zeroed the mode info (the AV1 spec
 * sets interintra = 0 when the element is absent).
 *
 * The interintra flag is read when all of the following hold:
 *  - the sequence enables inter-intra compound;
 *  - the block is not skip_mode;
 *  - is_interintra_allowed() accepts the block.
 *
 * When the flag is set:
 *  - the inter-intra mode is read;
 *  - ref_frame[1] becomes INTRA_FRAME;
 *  - angle deltas and filter-intra are cleared;
 *  - for block sizes where is_interintra_wedge_used() holds, the wedge flag
 *    and, if set, the wedge index are read.
 */
void read_interintra_mode(ParseCtxt *parse_ctxt, BlockModeInfo *mbmi) {
    SvtReader *    r       = &parse_ctxt->r;
    FRAME_CONTEXT *frm_ctx = &parse_ctxt->cur_tile_ctx;
    BlockSize      bsize   = mbmi->sb_type;

    // Defaults for when the syntax elements below are not present.
    mbmi->is_inter_intra                          = 0;
    mbmi->interintra_mode_params.wedge_interintra = 0;

    if (!parse_ctxt->seq_header->enable_interintra_compound || mbmi->skip_mode ||
        !is_interintra_allowed(mbmi))
        return;

    const int bsize_group = size_group_lookup[bsize];
    mbmi->is_inter_intra  = svt_read_symbol(r, frm_ctx->interintra_cdf[bsize_group], 2, ACCT_STR);
    assert(mbmi->ref_frame[1] == NONE_FRAME);
    if (!mbmi->is_inter_intra)
        return;

    mbmi->interintra_mode_params.interintra_mode = (InterIntraMode)svt_read_symbol(
        r, frm_ctx->interintra_mode_cdf[bsize_group], INTERINTRA_MODES, ACCT_STR);
    mbmi->ref_frame[1]                            = INTRA_FRAME;
    mbmi->angle_delta[PLANE_TYPE_Y]               = 0;
    mbmi->angle_delta[PLANE_TYPE_UV]              = 0;
    mbmi->filter_intra_mode_info.use_filter_intra = 0;

    if (is_interintra_wedge_used(bsize)) {
        mbmi->interintra_mode_params.wedge_interintra = svt_read_symbol(
            r, frm_ctx->wedge_interintra_cdf[bsize], 2, ACCT_STR);
        if (mbmi->interintra_mode_params.wedge_interintra) {
            mbmi->interintra_mode_params.interintra_wedge_index = svt_read_symbol(
                r, frm_ctx->wedge_idx_cdf[bsize], 16, ACCT_STR);
        }
    }
}
1606 
add_samples(BlockModeInfo * mbmi,int * pts,int * pts_inref,int row_offset,int sign_r,int col_offset,int sign_c)1607 static INLINE void add_samples(BlockModeInfo *mbmi, int *pts, int *pts_inref, int row_offset,
1608                                int sign_r, int col_offset, int sign_c) {
1609     int bw = block_size_wide[mbmi->sb_type];
1610     int bh = block_size_high[mbmi->sb_type];
1611     int x  = col_offset * MI_SIZE + sign_c * AOMMAX(bw, MI_SIZE) / 2 - 1;
1612     int y  = row_offset * MI_SIZE + sign_r * AOMMAX(bh, MI_SIZE) / 2 - 1;
1613 
1614     pts[0]       = (x * 8);
1615     pts[1]       = (y * 8);
1616     pts_inref[0] = (x * 8) + mbmi->mv[0].as_mv.col;
1617     pts_inref[1] = (y * 8) + mbmi->mv[0].as_mv.row;
1618 }
1619 
find_warp_samples(EbDecHandle * dec_handle,TileInfo * tile,PartitionInfo * pi,int * pts,int * pts_inref)1620 int find_warp_samples(EbDecHandle *dec_handle, TileInfo *tile, PartitionInfo *pi, int *pts,
1621                       int *pts_inref) {
1622     int mi_row = pi->mi_row;
1623     int mi_col = pi->mi_col;
1624 
1625     BlockModeInfo *const mbmi0          = pi->mi;
1626     int                  ref_frame      = mbmi0->ref_frame[0];
1627     int                  up_available   = pi->up_available;
1628     int                  left_available = pi->left_available;
1629     int                  i, mi_step = 1, np = 0;
1630 
1631     int do_tl = 1;
1632     int do_tr = 1;
1633     int b4_w  = mi_size_wide[pi->mi->sb_type];
1634     int b4_h  = mi_size_high[pi->mi->sb_type];
1635 
1636     // scan the nearest above rows
1637     if (up_available) {
1638         BlockModeInfo *mbmi = get_cur_mode_info(dec_handle, mi_row - 1, mi_col, pi->sb_info);
1639         uint8_t        n4_w = mi_size_wide[mbmi->sb_type];
1640 
1641         if (b4_w <= n4_w) {
1642             // Handle "current block width <= above block width" case.
1643             int col_offset = -mi_col % n4_w;
1644 
1645             if (col_offset < 0)
1646                 do_tl = 0;
1647             if (col_offset + n4_w > b4_w)
1648                 do_tr = 0;
1649 
1650             if (mbmi->ref_frame[0] == ref_frame && mbmi->ref_frame[1] == NONE_FRAME) {
1651                 add_samples(mbmi, pts, pts_inref, 0, -1, col_offset, 1);
1652                 pts += 2;
1653                 pts_inref += 2;
1654                 np++;
1655                 if (np >= LEAST_SQUARES_SAMPLES_MAX)
1656                     return LEAST_SQUARES_SAMPLES_MAX;
1657             }
1658         } else {
1659             // Handle "current block width > above block width" case.
1660             for (i = 0; i < AOMMIN(b4_w, tile->mi_col_end - mi_col); i += mi_step) {
1661                 mbmi    = get_cur_mode_info(dec_handle, mi_row - 1, mi_col + i, pi->sb_info);
1662                 n4_w    = mi_size_wide[mbmi->sb_type];
1663                 mi_step = AOMMIN(b4_w, n4_w);
1664 
1665                 if (mbmi->ref_frame[0] == ref_frame && mbmi->ref_frame[1] == NONE_FRAME) {
1666                     add_samples(mbmi, pts, pts_inref, 0, -1, i, 1);
1667                     pts += 2;
1668                     pts_inref += 2;
1669                     np++;
1670                     if (np >= LEAST_SQUARES_SAMPLES_MAX)
1671                         return LEAST_SQUARES_SAMPLES_MAX;
1672                 }
1673             }
1674         }
1675     }
1676     assert(np <= LEAST_SQUARES_SAMPLES_MAX);
1677 
1678     // scan the nearest left columns
1679     if (left_available) {
1680         BlockModeInfo *mbmi = get_cur_mode_info(dec_handle, mi_row, mi_col - 1, pi->sb_info);
1681         uint8_t        n4_h = mi_size_high[mbmi->sb_type];
1682 
1683         if (b4_h <= n4_h) {
1684             // Handle "current block height <= above block height" case.
1685             int row_offset = -mi_row % n4_h;
1686 
1687             if (row_offset < 0)
1688                 do_tl = 0;
1689 
1690             if (mbmi->ref_frame[0] == ref_frame && mbmi->ref_frame[1] == NONE_FRAME) {
1691                 add_samples(mbmi, pts, pts_inref, row_offset, 1, 0, -1);
1692                 pts += 2;
1693                 pts_inref += 2;
1694                 np++;
1695                 if (np >= LEAST_SQUARES_SAMPLES_MAX)
1696                     return LEAST_SQUARES_SAMPLES_MAX;
1697             }
1698         } else {
1699             // Handle "current block height > above block height" case.
1700             for (i = 0; i < AOMMIN(b4_h, tile->mi_row_end - mi_row); i += mi_step) {
1701                 mbmi    = get_cur_mode_info(dec_handle, mi_row + i, mi_col - 1, pi->sb_info);
1702                 n4_h    = mi_size_high[mbmi->sb_type];
1703                 mi_step = AOMMIN(b4_h, n4_h);
1704 
1705                 if (mbmi->ref_frame[0] == ref_frame && mbmi->ref_frame[1] == NONE_FRAME) {
1706                     add_samples(mbmi, pts, pts_inref, i, 1, 0, -1);
1707                     pts += 2;
1708                     pts_inref += 2;
1709                     np++;
1710                     if (np >= LEAST_SQUARES_SAMPLES_MAX)
1711                         return LEAST_SQUARES_SAMPLES_MAX;
1712                 }
1713             }
1714         }
1715     }
1716     assert(np <= LEAST_SQUARES_SAMPLES_MAX);
1717 
1718     // Top-left block
1719     if (do_tl && left_available && up_available) {
1720         BlockModeInfo *mbmi = get_cur_mode_info(dec_handle, mi_row - 1, mi_col - 1, pi->sb_info);
1721 
1722         if (mbmi->ref_frame[0] == ref_frame && mbmi->ref_frame[1] == NONE_FRAME) {
1723             add_samples(mbmi, pts, pts_inref, 0, -1, 0, -1);
1724             pts += 2;
1725             pts_inref += 2;
1726             np++;
1727             if (np >= LEAST_SQUARES_SAMPLES_MAX)
1728                 return LEAST_SQUARES_SAMPLES_MAX;
1729         }
1730     }
1731     assert(np <= LEAST_SQUARES_SAMPLES_MAX);
1732 
1733     // Top-right block
1734     if (do_tr && has_top_right(dec_handle, pi, AOMMAX(b4_w, b4_h))) {
1735         int mv_row = mi_row - 1;
1736         int mv_col = mi_col + b4_w;
1737 
1738         if (is_inside(tile, mv_col, mv_row)) {
1739             BlockModeInfo *mbmi = get_cur_mode_info(dec_handle, mv_row, mv_col, pi->sb_info);
1740 
1741             if (mbmi->ref_frame[0] == ref_frame && mbmi->ref_frame[1] == NONE_FRAME) {
1742                 add_samples(mbmi, pts, pts_inref, 0, -1, b4_w, 1);
1743                 np++;
1744                 if (np >= LEAST_SQUARES_SAMPLES_MAX)
1745                     return LEAST_SQUARES_SAMPLES_MAX;
1746             }
1747         }
1748     }
1749     assert(np <= LEAST_SQUARES_SAMPLES_MAX);
1750 
1751     return np;
1752 }
1753 
/*
 * Returns 1 if pi's block size allows motion variation and at least one
 * probed neighbour is an inter block; 0 otherwise.
 *
 * Probed neighbours:
 *  - the above row (when up_available), at odd MI columns across the block
 *    width, clipped to the tile;
 *  - the left column (when left_available), at odd MI rows across the block
 *    height, clipped to the tile.
 */
int has_overlappable_cand(EbDecHandle *dec_handle, ParseCtxt *parse_ctx, PartitionInfo *pi) {
    const BlockSize bsize = pi->mi->sb_type;
    if (!is_motion_variation_allowed_bsize(bsize))
        return 0;

    const TileInfo *const tile = &parse_ctx->cur_tile_info;

    if (pi->up_available) {
        const int col_end = AOMMIN(tile->mi_col_end, pi->mi_col + mi_size_wide[bsize]);
        for (int col = pi->mi_col; col < col_end;) {
            const BlockModeInfo *nb = get_cur_mode_info(
                dec_handle, pi->mi_row - 1, col | 1, pi->sb_info);
            col += AOMMAX(2, mi_size_wide[nb->sb_type] >> 2);
            if (is_inter_block(nb))
                return 1;
        }
    }

    if (pi->left_available) {
        const int row_end = AOMMIN(tile->mi_row_end, pi->mi_row + mi_size_high[bsize]);
        for (int row = pi->mi_row; row < row_end;) {
            const BlockModeInfo *nb = get_cur_mode_info(
                dec_handle, row | 1, pi->mi_col - 1, pi->sb_info);
            row += AOMMAX(2, mi_size_high[nb->sb_type] >> 2);
            if (is_inter_block(nb))
                return 1;
        }
    }

    return 0;
}
1786 
is_motion_mode_allowed(EbDecHandle * dec_handle,ParseCtxt * parse_ctx,GlobalMotionParams * gm_params,PartitionInfo * pi,int allow_warped_motion)1787 static INLINE MotionMode is_motion_mode_allowed(EbDecHandle *dec_handle, ParseCtxt *parse_ctx,
1788                                                 GlobalMotionParams *gm_params, PartitionInfo *pi,
1789                                                 int allow_warped_motion) {
1790     BlockModeInfo *mbmi = pi->mi;
1791     if (dec_handle->frame_header.force_integer_mv == 0) {
1792         const TransformationType gm_type = gm_params[mbmi->ref_frame[0]].gm_type;
1793         if (is_global_mv_block(mbmi->mode, mbmi->sb_type, gm_type))
1794             return SIMPLE_TRANSLATION;
1795     }
1796     if ((block_size_wide[mbmi->sb_type] >= 8 && block_size_high[mbmi->sb_type] >= 8) &&
1797         (mbmi->mode >= NEARESTMV && mbmi->mode < MB_MODE_COUNT) &&
1798         mbmi->ref_frame[1] != INTRA_FRAME && !has_second_ref(mbmi)) {
1799         if (!has_overlappable_cand(dec_handle, parse_ctx, pi))
1800             return SIMPLE_TRANSLATION;
1801         assert(!has_second_ref(mbmi));
1802 
1803         if (pi->num_samples >= 1 && (allow_warped_motion && !av1_is_scaled(pi->block_ref_sf[0]))) {
1804             if (dec_handle->frame_header.force_integer_mv) {
1805                 return OBMC_CAUSAL;
1806             }
1807             return WARPED_CAUSAL;
1808         }
1809         return OBMC_CAUSAL;
1810     } else {
1811         return SIMPLE_TRANSLATION;
1812     }
1813 }
1814 
read_motion_mode(EbDecHandle * dec_handle,ParseCtxt * parse_ctxt,PartitionInfo * pi)1815 MotionMode read_motion_mode(EbDecHandle *dec_handle, ParseCtxt *parse_ctxt, PartitionInfo *pi) {
1816     SvtReader *    r                   = &parse_ctxt->r;
1817     FRAME_CONTEXT *frm_ctx             = &parse_ctxt->cur_tile_ctx;
1818     FrameHeader *  frame_info          = &dec_handle->frame_header;
1819     int            allow_warped_motion = frame_info->allow_warped_motion;
1820     BlockModeInfo *mbmi                = pi->mi;
1821 
1822     if (dec_handle->frame_header.is_motion_mode_switchable == 0)
1823         return SIMPLE_TRANSLATION;
1824     if (mbmi->skip_mode)
1825         return SIMPLE_TRANSLATION;
1826 
1827     const MotionMode last_motion_mode_allowed = is_motion_mode_allowed(
1828         dec_handle, parse_ctxt, dec_handle->cur_pic_buf[0]->global_motion, pi, allow_warped_motion);
1829     int motion_mode;
1830 
1831     if (last_motion_mode_allowed == SIMPLE_TRANSLATION)
1832         return SIMPLE_TRANSLATION;
1833 
1834     if (last_motion_mode_allowed == OBMC_CAUSAL) {
1835         motion_mode = svt_read_symbol(r, frm_ctx->obmc_cdf[mbmi->sb_type], 2, ACCT_STR);
1836         return (MotionMode)(motion_mode);
1837     } else {
1838         motion_mode = svt_read_symbol(
1839             r, frm_ctx->motion_mode_cdf[mbmi->sb_type], MOTION_MODES, ACCT_STR);
1840         return (MotionMode)(motion_mode);
1841     }
1842 }
1843 
get_comp_group_idx_context(ParseCtxt * parse_ctxt,const PartitionInfo * xd)1844 static INLINE int get_comp_group_idx_context(ParseCtxt *parse_ctxt, const PartitionInfo *xd) {
1845     const BlockModeInfo *const above_mi  = xd->above_mbmi;
1846     const BlockModeInfo *const left_mi   = xd->left_mbmi;
1847     int                        above_ctx = 0, left_ctx = 0;
1848 
1849     if (above_mi) {
1850         if (has_second_ref(above_mi)) {
1851             above_ctx =
1852                 parse_ctxt->parse_above_nbr4x4_ctxt
1853                     ->above_comp_grp_idx[xd->mi_col - parse_ctxt->cur_tile_info.mi_col_start];
1854         } else if (above_mi->ref_frame[0] == ALTREF_FRAME)
1855             above_ctx = 3;
1856     }
1857     if (left_mi) {
1858         if (has_second_ref(left_mi)) {
1859             left_ctx = parse_ctxt->parse_left_nbr4x4_ctxt
1860                            ->left_comp_grp_idx[xd->mi_row - parse_ctxt->sb_row_mi];
1861         } else if (left_mi->ref_frame[0] == ALTREF_FRAME)
1862             left_ctx = 3;
1863     }
1864 
1865     return AOMMIN(5, above_ctx + left_ctx);
1866 }
1867 
/*
 * Context for compound_index_cdf.
 *
 * Base value:
 *  - 3 when the current frame is equally distant (absolute
 *    get_relative_dist) from the forward reference (ref_frame[1]) and the
 *    backward reference (ref_frame[0]);
 *  - 0 otherwise.
 * A missing reference buffer counts as order hint 0.
 *
 * Each available neighbour then adds:
 *  - its compound_idx when compound;
 *  - 1 when its first reference is ALTREF_FRAME.
 */
int get_comp_index_context(EbDecHandle *dec_handle, PartitionInfo *pi) {
    const BlockModeInfo *const mbmi     = pi->mi;
    const int                  cur_hint = dec_handle->frame_header.order_hint;

    EbDecPicBuf *const bck_buf  = get_ref_frame_buf(dec_handle, mbmi->ref_frame[0]);
    EbDecPicBuf *const fwd_buf  = get_ref_frame_buf(dec_handle, mbmi->ref_frame[1]);
    const int          bck_hint = bck_buf != NULL ? bck_buf->order_hint : 0;
    const int          fwd_hint = fwd_buf != NULL ? fwd_buf->order_hint : 0;

    const int fwd_dist = abs(
        get_relative_dist(&dec_handle->seq_header.order_hint_info, fwd_hint, cur_hint));
    const int bck_dist = abs(
        get_relative_dist(&dec_handle->seq_header.order_hint_info, cur_hint, bck_hint));

    int ctx = (fwd_dist == bck_dist) ? 3 : 0;

    const BlockModeInfo *const above = pi->above_mbmi;
    if (above != NULL) {
        if (has_second_ref(above))
            ctx += above->compound_idx;
        else if (above->ref_frame[0] == ALTREF_FRAME)
            ctx += 1;
    }

    const BlockModeInfo *const left = pi->left_mbmi;
    if (left != NULL) {
        if (has_second_ref(left))
            ctx += left->compound_idx;
        else if (left->ref_frame[0] == ALTREF_FRAME)
            ctx += 1;
    }

    return ctx;
}
1911 
/*
 * Records comp_grp_idx into the neighbour contexts for the block at
 * (blk_row, blk_col), in MI units.
 *  - Above context: mi_size_wide[bsize] entries starting at blk_col,
 *    measured from the tile's first column.
 *  - Left context: mi_size_high[bsize] entries starting at the row offset
 *    from sb_row_mi, masked with MAX_MIB_MASK.
 */
void update_compound_ctx(ParseCtxt *parse_ctxt, PartitionInfo *pi, uint32_t blk_row,
                         uint32_t blk_col, uint32_t comp_grp_idx) {
    const BlockSize bsize = pi->mi->sb_type;

    int8_t *const above_dst = parse_ctxt->parse_above_nbr4x4_ctxt->above_comp_grp_idx + blk_col -
        parse_ctxt->cur_tile_info.mi_col_start;
    int8_t *const left_dst = parse_ctxt->parse_left_nbr4x4_ctxt->left_comp_grp_idx +
        ((blk_row - parse_ctxt->sb_row_mi) & MAX_MIB_MASK);

    memset(above_dst, comp_grp_idx, mi_size_wide[bsize]);
    memset(left_dst, comp_grp_idx, mi_size_high[bsize]);
}
1928 
/*
 * Parses the inter-inter compound type of pi's block and updates the
 * comp_group_idx neighbour context.
 *
 * Defaults, set before any reads: compound_idx = 1 and type = COMPOUND_AVERAGE.
 * Previously the type was defaulted only for skip_mode blocks, so
 * single-reference blocks kept whatever value was already in the mode info.
 * The libaom reference decoder initialises it unconditionally.
 *
 * For compound, non-skip blocks:
 *  - comp_group_idx is read when masked compound is usable for the block
 *    size and enabled in the sequence.
 *  - Group 0: compound_idx is read when dist-weighted compound is enabled,
 *    selecting COMPOUND_AVERAGE (1) or COMPOUND_DISTWTD (0).
 *  - Group 1: a masked type is selected.
 *      - COMPOUND_WEDGE reads a wedge index (16 symbols) and a sign bit.
 *      - COMPOUND_DIFFWTD reads a MAX_DIFFWTD_MASK_BITS mask type.
 */
void read_compound_type(EbDecHandle *dec_handle, ParseCtxt *parse_ctxt, PartitionInfo *pi) {
    SvtReader *    r              = &parse_ctxt->r;
    BlockModeInfo *mbmi           = pi->mi;
    BlockSize      bsize          = mbmi->sb_type;
    int32_t        comp_group_idx = 0;
    FRAME_CONTEXT *frm_ctx        = &parse_ctxt->cur_tile_ctx;

    // Defaults for every block. Only compound, non-skip blocks override them below.
    mbmi->compound_idx              = 1;
    mbmi->inter_inter_compound.type = COMPOUND_AVERAGE;

    if (has_second_ref(mbmi) && !mbmi->skip_mode) {
        // Read idx to indicate current compound inter prediction mode group
        const int masked_compound_used = is_any_masked_compound_used(bsize) &&
            dec_handle->seq_header.enable_masked_compound;

        if (masked_compound_used) {
            const int ctx_comp_group_idx = get_comp_group_idx_context(parse_ctxt, pi);
            comp_group_idx               = svt_read_symbol(
                r, frm_ctx->comp_group_idx_cdf[ctx_comp_group_idx], 2, ACCT_STR);
        }

        if (comp_group_idx == 0) {
            if (dec_handle->seq_header.order_hint_info.enable_jnt_comp) {
                const int comp_index_ctx = get_comp_index_context(dec_handle, pi);
                mbmi->compound_idx       = svt_read_symbol(
                    r, frm_ctx->compound_index_cdf[comp_index_ctx], 2, ACCT_STR);
                mbmi->inter_inter_compound.type = mbmi->compound_idx ? COMPOUND_AVERAGE
                                                                     : COMPOUND_DISTWTD;
            } else {
                // Distance-weighted compound is disabled, so always use average
                mbmi->compound_idx              = 1;
                mbmi->inter_inter_compound.type = COMPOUND_AVERAGE;
            }
        } else {
            assert(dec_handle->frame_header.reference_mode != SINGLE_REFERENCE &&
                   is_inter_compound_mode(mbmi->mode) && mbmi->motion_mode == SIMPLE_TRANSLATION);
            assert(masked_compound_used);

            // compound_diffwtd, wedge
            if (is_interinter_compound_used(COMPOUND_WEDGE, bsize))
                mbmi->inter_inter_compound.type = COMPOUND_WEDGE +
                    svt_read_symbol(r,
                                    frm_ctx->compound_type_cdf[bsize],
                                    MASKED_COMPOUND_TYPES,
                                    ACCT_STR);
            else
                mbmi->inter_inter_compound.type = COMPOUND_DIFFWTD;

            if (mbmi->inter_inter_compound.type == COMPOUND_WEDGE) {
                assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
                mbmi->inter_inter_compound.wedge_index = svt_read_symbol(
                    r, frm_ctx->wedge_idx_cdf[bsize], 16, ACCT_STR);
                mbmi->inter_inter_compound.wedge_sign = svt_read_bit(r, ACCT_STR);
            } else {
                assert(mbmi->inter_inter_compound.type == COMPOUND_DIFFWTD);
                mbmi->inter_inter_compound.mask_type = svt_read_literal(
                    r, MAX_DIFFWTD_MASK_BITS, ACCT_STR);
            }
        }
    }

    update_compound_ctx(parse_ctxt, pi, pi->mi_row, pi->mi_col, comp_group_idx);
}
1993 
is_nontrans_global_motion(PartitionInfo * pi,GlobalMotionParams * gm_params)1994 static INLINE int is_nontrans_global_motion(PartitionInfo *pi, GlobalMotionParams *gm_params) {
1995     int            ref;
1996     BlockModeInfo *mbmi = pi->mi;
1997     // First check if all modes are GLOBALMV
1998     if (mbmi->mode != GLOBALMV && mbmi->mode != GLOBAL_GLOBALMV)
1999         return 0;
2000 
2001     if (AOMMIN(mi_size_wide[mbmi->sb_type], mi_size_high[mbmi->sb_type]) < 2)
2002         return 0;
2003 
2004     // Now check if all global motion is non translational
2005     for (ref = 0; ref < 1 + has_second_ref(mbmi); ++ref) {
2006         if (gm_params[mbmi->ref_frame[ref]].gm_type == TRANSLATION)
2007             return 0;
2008     }
2009     return 1;
2010 }
2011 
av1_is_interp_needed(PartitionInfo * pi,GlobalMotionParams * gm_params)2012 static INLINE int av1_is_interp_needed(PartitionInfo *pi, GlobalMotionParams *gm_params) {
2013     BlockModeInfo *mbmi = pi->mi;
2014     if (mbmi->skip_mode)
2015         return 0;
2016     if (mbmi->motion_mode == WARPED_CAUSAL)
2017         return 0;
2018     if (is_nontrans_global_motion(pi, gm_params))
2019         return 0;
2020     return 1;
2021 }
2022 
get_ref_filter_type(const BlockModeInfo * ref_mbmi,int dir,MvReferenceFrame ref_frame)2023 static InterpFilter get_ref_filter_type(const BlockModeInfo *ref_mbmi, int dir,
2024                                         MvReferenceFrame ref_frame) {
2025     return ((ref_mbmi->ref_frame[0] == ref_frame || ref_mbmi->ref_frame[1] == ref_frame)
2026                 ? av1_extract_interp_filter(ref_mbmi->interp_filters, dir & 0x01)
2027                 : SWITCHABLE_FILTERS);
2028 }
2029 
/* Computes the CDF context for reading the switchable interpolation filter of
 * the current block in direction dir (0 or 1).
 *
 * The context is the sum of three parts:
 *   - an offset of INTER_FILTER_COMP_OFFSET when the block is compound,
 *   - an offset of INTER_FILTER_DIR_OFFSET for direction 1,
 *   - a neighbor term taken from the left and above neighbors' filters, where
 *     a neighbor only counts if it is available and shares ref_frame[0].
 *
 * The neighbor term is resolved as follows:
 *   - when both neighbors agree, their common filter is used;
 *   - when only one neighbor provides a filter, that filter is used;
 *   - when they conflict, SWITCHABLE_FILTERS is used.
 */
int get_context_interp(PartitionInfo *pi, int dir) {
    const BlockModeInfo *const cur = pi->mi;
    assert(dir == 0 || dir == 1);

    const MvReferenceFrame first_ref = cur->ref_frame[0];

    const int left_type = pi->left_available
        ? get_ref_filter_type(pi->left_mbmi, dir, first_ref)
        : SWITCHABLE_FILTERS;
    const int above_type = pi->up_available
        ? get_ref_filter_type(pi->above_mbmi, dir, first_ref)
        : SWITCHABLE_FILTERS;

    int neighbor_term;
    if (left_type == above_type) {
        neighbor_term = left_type;
    } else if (left_type == SWITCHABLE_FILTERS) {
        assert(above_type != SWITCHABLE_FILTERS);
        neighbor_term = above_type;
    } else if (above_type == SWITCHABLE_FILTERS) {
        assert(left_type != SWITCHABLE_FILTERS);
        neighbor_term = left_type;
    } else {
        /* Neighbors disagree: fall back to the dedicated "mixed" context. */
        neighbor_term = SWITCHABLE_FILTERS;
    }

    const int is_compound = cur->ref_frame[1] > INTRA_FRAME;
    return is_compound * INTER_FILTER_COMP_OFFSET + (dir & 0x01) * INTER_FILTER_DIR_OFFSET +
        neighbor_term;
}
2060 
/* Parses the mode info of an inter-coded block into pi->mi.
 *
 * Steps, in order:
 *   1. Clear the palette sizes and collect neighbor reference counts.
 *   2. Read the reference frame(s) and build the MV candidate list
 *      (pi->ref_mv_stack).
 *   3. Read the inter prediction mode and, where applicable, the DRL index.
 *   4. Derive the reference, nearest and near MVs, then assign the final
 *      MV(s).
 *   5. Read the inter-intra mode, fetch per-reference scale factors
 *      (pi->block_ref_sf) and count warp samples (pi->num_samples).
 *   6. Read the motion mode and compound type.
 *   7. Either derive the interpolation filters or read them from the
 *      bitstream.
 */
void inter_block_mode_info(EbDecHandle *dec_handle, ParseCtxt *parse_ctxt, PartitionInfo *pi) {
    BlockModeInfo *     mbmi     = pi->mi;
    SvtReader *         r        = &parse_ctxt->r;
    const int           allow_hp = dec_handle->frame_header.allow_high_precision_mv;
    IntMv               ref_mvs[MODE_CTX_REF_FRAMES][MAX_MV_REF_CANDIDATES] = {{{0}}};
    int16_t             inter_mode_ctx[MODE_CTX_REF_FRAMES];
    int                 pts[SAMPLES_ARRAY_SIZE], pts_inref[SAMPLES_ARRAY_SIZE];
    SegmentationParams *seg = &dec_handle->frame_header.segmentation_params;
    MvCount             mv_cnt;

    /* This block carries no palette; clear both palette sizes.
     * TODO: initialize the remaining palette info. */
    mbmi->palette_size[0] = 0;
    mbmi->palette_size[1] = 0;

    svt_collect_neighbors_ref_counts(pi);

    read_ref_frames(parse_ctxt, pi);
    const int is_compound = has_second_ref(mbmi);

    /* Combined reference-frame type; it indexes the MV candidate stack below. */
    MvReferenceFrame ref_frame = av1_ref_frame_type(mbmi->ref_frame);
    IntMv            global_mvs[2];
    svt_av1_find_mv_refs(dec_handle,
                         pi,
                         parse_ctxt,
                         ref_frame,
                         pi->ref_mv_stack,
                         ref_mvs,
                         global_mvs,
                         inter_mode_ctx,
                         &mv_cnt);

#if EXTRA_DUMP
    if (enable_dump) {
        SVT_LOG("\n mi_row: %d mi_col: %d\n", pi->mi_row, pi->mi_col);
        fflush(stdout);
    }
#endif

    int mode_ctx     = svt_mode_context_analyzer(inter_mode_ctx, mbmi->ref_frame);
    mbmi->ref_mv_idx = 0;

    if (mbmi->skip_mode) {
        /* Skip mode is only valid for compound prediction and fixes the mode. */
        assert(is_compound);
        mbmi->mode = NEAREST_NEARESTMV;
    } else {
        if (seg_feature_active(seg, mbmi->segment_id, SEG_LVL_SKIP) ||
            seg_feature_active(seg, mbmi->segment_id, SEG_LVL_GLOBALMV))
            mbmi->mode = GLOBALMV;
        else {
            if (is_compound)
                mbmi->mode = read_inter_compound_mode(parse_ctxt, mode_ctx);
            else {
                /* Single reference: a chain of binary symbols. A 0 at each
                 * step selects NEWMV, GLOBALMV and NEARESTMV respectively;
                 * three 1s select NEARMV. */
                int new_mv = svt_read_symbol(
                    r, parse_ctxt->cur_tile_ctx.newmv_cdf[mode_ctx & NEWMV_CTX_MASK], 2, ACCT_STR);
                if (new_mv) {
                    int zero_mv = svt_read_symbol(
                        r,
                        parse_ctxt->cur_tile_ctx
                            .zeromv_cdf[(mode_ctx >> GLOBALMV_OFFSET) & GLOBALMV_CTX_MASK],
                        2,
                        ACCT_STR);
                    if (zero_mv) {
                        int ref_mv = svt_read_symbol(
                            r,
                            parse_ctxt->cur_tile_ctx
                                .refmv_cdf[(mode_ctx >> REFMV_OFFSET) & REFMV_CTX_MASK],
                            2,
                            ACCT_STR);
                        mbmi->mode = ref_mv ? NEARMV : NEARESTMV;
                    } else
                        mbmi->mode = GLOBALMV;
                } else
                    mbmi->mode = NEWMV;
            }
            if (mbmi->mode == NEWMV || mbmi->mode == NEW_NEWMV || has_nearmv(mbmi->mode))
                read_drl_idx(parse_ctxt, pi, mbmi, mv_cnt.num_mv_found[ref_frame]);
        }
    }
    mbmi->uv_mode = UV_DC_PRED;

    IntMv ref_mv[2];
    IntMv nearestmv[2], nearmv[2];
    /* Zero both candidate arrays. Not every path below writes every entry
     * (e.g. GLOBAL_GLOBALMV, or nearmv[1] for single reference), but both
     * arrays are always passed to assign_mv(). */
    memset(nearestmv, 0, sizeof(nearestmv));
    memset(nearmv, 0, sizeof(nearmv));
    if (!is_compound && mbmi->mode != GLOBALMV) {
        svt_find_best_ref_mvs(allow_hp,
                              ref_mvs[mbmi->ref_frame[0]],
                              &nearestmv[0],
                              &nearmv[0],
                              dec_handle->frame_header.force_integer_mv);
    }
    if (is_compound && mbmi->mode != GLOBAL_GLOBALMV) {
        /* Compound nearest is stack entry 0; near is entry 1 + ref_mv_idx. */
        int ref_mv_idx = mbmi->ref_mv_idx + 1;
        nearestmv[0]   = pi->ref_mv_stack[ref_frame][0].this_mv;
        nearestmv[1]   = pi->ref_mv_stack[ref_frame][0].comp_mv;
        nearmv[0]      = pi->ref_mv_stack[ref_frame][ref_mv_idx].this_mv;
        nearmv[1]      = pi->ref_mv_stack[ref_frame][ref_mv_idx].comp_mv;
        lower_mv_precision(
            &nearestmv[0].as_mv, allow_hp, dec_handle->frame_header.force_integer_mv);
        lower_mv_precision(
            &nearestmv[1].as_mv, allow_hp, dec_handle->frame_header.force_integer_mv);
        lower_mv_precision(&nearmv[0].as_mv, allow_hp, dec_handle->frame_header.force_integer_mv);
        lower_mv_precision(&nearmv[1].as_mv, allow_hp, dec_handle->frame_header.force_integer_mv);
    } else if (mbmi->ref_mv_idx > 0 && mbmi->mode == NEARMV) {
        IntMv cur_mv = pi->ref_mv_stack[mbmi->ref_frame[0]][1 + mbmi->ref_mv_idx].this_mv;
        nearmv[0]    = cur_mv;
    }

    /* The reference MV defaults to the nearest MV; NEWMV components override it. */
    ref_mv[0] = nearestmv[0];
    ref_mv[1] = nearestmv[1];

    if (is_compound) {
        int ref_mv_idx = mbmi->ref_mv_idx;
        // Special case: NEAR_NEWMV and NEW_NEARMV modes use
        // 1 + mbmi->ref_mv_idx (like NEARMV) instead of
        // mbmi->ref_mv_idx (like NEWMV)
        if (mbmi->mode == NEAR_NEWMV || mbmi->mode == NEW_NEARMV)
            ref_mv_idx = 1 + mbmi->ref_mv_idx;

        if (compound_ref0_mode(mbmi->mode) == NEWMV)
            ref_mv[0] = pi->ref_mv_stack[ref_frame][ref_mv_idx].this_mv;

        if (compound_ref1_mode(mbmi->mode) == NEWMV)
            ref_mv[1] = pi->ref_mv_stack[ref_frame][ref_mv_idx].comp_mv;
    } else {
        if (mbmi->mode == NEWMV) {
            if (mv_cnt.num_mv_found[ref_frame] > 1)
                ref_mv[0] = pi->ref_mv_stack[ref_frame][mbmi->ref_mv_idx].this_mv;
        }
    }

    assign_mv(
        parse_ctxt, pi, mbmi->mv, global_mvs, ref_mv, nearestmv, nearmv, is_compound, allow_hp);

#if EXTRA_DUMP
    if (enable_dump) {
        SVT_LOG("\n mode %d MV %d %d \n", mbmi->mode, mbmi->mv[0].as_mv.row, mbmi->mv[0].as_mv.col);
        fflush(stdout);
    }
#endif
    read_interintra_mode(parse_ctxt, mbmi);

    for (int ref = 0; ref < 1 + has_second_ref(mbmi); ++ref) {
        const MvReferenceFrame frame = mbmi->ref_frame[ref];
        pi->block_ref_sf[ref]        = get_ref_scale_factors(dec_handle, frame);
    }

    pi->num_samples = find_warp_samples(dec_handle, &parse_ctxt->cur_tile_info, pi, pts, pts_inref);

    mbmi->motion_mode = read_motion_mode(dec_handle, parse_ctxt, pi);

    read_compound_type(dec_handle, parse_ctxt, pi);

    if (!av1_is_interp_needed(pi, dec_handle->cur_pic_buf[0]->global_motion)) {
        /* Filter is not coded: use the frame filter (EIGHTTAP-class default if switchable). */
        mbmi->interp_filters = av1_broadcast_interp_filter(
            av1_unswitchable_filter(dec_handle->frame_header.interpolation_filter));
    } else {
        if (dec_handle->frame_header.interpolation_filter != SWITCHABLE) {
            mbmi->interp_filters = av1_broadcast_interp_filter(
                dec_handle->frame_header.interpolation_filter);
        } else {
            /* ref0_filter[dir] holds the filter for direction dir. Without dual
             * filter only direction 0 is read and copied to direction 1. */
            InterpFilter ref0_filter[2] = {EIGHTTAP_REGULAR, EIGHTTAP_REGULAR};
            for (int dir = 0; dir < 2; ++dir) {
                const int ctx    = get_context_interp(pi, dir);
                ref0_filter[dir] = (InterpFilter)svt_read_symbol(
                    r,
                    parse_ctxt->cur_tile_ctx.switchable_interp_cdf[ctx],
                    SWITCHABLE_FILTERS,
                    ACCT_STR);
                if (dec_handle->seq_header.enable_dual_filter == 0) {
                    ref0_filter[1] = ref0_filter[0];
                    break;
                }
            }
            mbmi->interp_filters = av1_make_interp_filters(ref0_filter[0], ref0_filter[1]);
        }
    }
}
2248 
get_palette_color_context(uint8_t (* color_map)[COLOR_MAP_STRIDE][COLOR_MAP_STRIDE],int r,int c,int palette_size,uint8_t * color_order)2249 int get_palette_color_context(uint8_t (*color_map)[COLOR_MAP_STRIDE][COLOR_MAP_STRIDE], int r,
2250                               int c, int palette_size, uint8_t *color_order) {
2251     // Get color indices of neighbors.
2252     int color_neighbors[NUM_PALETTE_NEIGHBORS] = {
2253         c - 1 >= 0 ? (*color_map)[r][c - 1] : -1,
2254         c - 1 >= 0 && r - 1 >= 0 ? (*color_map)[r - 1][c - 1] : -1,
2255         r - 1 >= 0 ? (*color_map)[r - 1][c] : -1};
2256     int              scores[PALETTE_MAX_SIZE + 10]  = {0};
2257     static const int weights[NUM_PALETTE_NEIGHBORS] = {2, 1, 2};
2258     for (int i = 0; i < NUM_PALETTE_NEIGHBORS; ++i)
2259         if (color_neighbors[i] >= 0)
2260             scores[color_neighbors[i]] += weights[i];
2261 
2262     for (int i = 0; i < PALETTE_MAX_SIZE; ++i) color_order[i] = i;
2263 
2264     for (int i = 0; i < NUM_PALETTE_NEIGHBORS; i++) {
2265         int max_score = scores[i];
2266         int max_id    = i;
2267         for (int j = i + 1; j < palette_size; j++) {
2268             if (scores[j] > max_score) {
2269                 max_score = scores[j];
2270                 max_id    = j;
2271             }
2272         }
2273         if (max_id != i) {
2274             max_score           = scores[max_id];
2275             int max_color_order = color_order[max_id];
2276             for (int k = max_id; k > i; k--) {
2277                 scores[k]      = scores[k - 1];
2278                 color_order[k] = color_order[k - 1];
2279             }
2280             scores[i]      = max_score;
2281             color_order[i] = max_color_order;
2282         }
2283     }
2284     int              color_index_ctx_hash                    = 0;
2285     static const int hash_multipliers[NUM_PALETTE_NEIGHBORS] = {1, 2, 2};
2286     for (int i = 0; i < NUM_PALETTE_NEIGHBORS; i++)
2287         color_index_ctx_hash += scores[i] * hash_multipliers[i];
2288     assert(color_index_ctx_hash > 0);
2289     assert(color_index_ctx_hash <= MAX_COLOR_CONTEXT_HASH);
2290 
2291     const int color_index_ctx = palette_color_index_context_lookup[color_index_ctx_hash];
2292     assert(color_index_ctx >= 0);
2293     assert(color_index_ctx < PALETTE_COLOR_INDEX_CONTEXTS);
2294     return color_index_ctx;
2295 }
2296 
/* Decodes the palette color index map(s) of the current block and writes the
 * palette prediction into the reconstruction buffer.
 *
 * Color maps:
 *   - Planes below PLANE_TYPES (luma and the first chroma plane) decode a
 *     color map from the bitstream.
 *   - The last plane reuses the map decoded for plane 1 and takes its colors
 *     from parse_ctx->palette_colors[2].
 *
 * Decoding: only the on-screen part of the block (limited by the frame's
 * mi_rows/mi_cols) is decoded, in wavefront (anti-diagonal) order. The rest
 * of the block is filled by replicating the last visible column, then the
 * last visible row.
 *
 * A plane is skipped when its palette size is 0, or when it is a chroma
 * plane and the block is not a chroma reference.
 */
void palette_tokens(EbDecHandle *dec_handle, ParseCtxt *parse_ctx, PartitionInfo *pi) {
    const int            mi_row  = pi->mi_row;
    const int            mi_col  = pi->mi_col;
    const BlockModeInfo *mbmi    = pi->mi;
    const BlockSize      bsize   = mbmi->sb_type;
    FRAME_CONTEXT *      frm_ctx = &parse_ctx->cur_tile_ctx;
    SvtReader *          r       = &parse_ctx->r;

    /* Luma block geometry; each plane's geometry is derived from these. */
    const int luma_height = block_size_high[bsize];
    const int luma_width  = block_size_wide[bsize];
    const int luma_on_screen_height =
        MIN(luma_height, (dec_handle->frame_header.mi_rows - mi_row) * MI_SIZE);
    const int luma_on_screen_width =
        MIN(luma_width, (dec_handle->frame_header.mi_cols - mi_col) * MI_SIZE);

    const int32_t is_chroma_ref = pi->is_chroma_ref;
    uint8_t       color_order[PALETTE_MAX_SIZE];
    uint8_t       color_map[COLOR_MAP_STRIDE][COLOR_MAP_STRIDE];

    for (int plane = 0; plane < MAX_MB_PLANE; plane++) {
        const uint8_t palette_size = mbmi->palette_size[plane != 0];

        /* Nothing to decode or predict for this plane. */
        if (!palette_size || (plane && !is_chroma_ref))
            continue;

        const int sub_x = plane ? dec_handle->seq_header.color_config.subsampling_x : 0;
        const int sub_y = plane ? dec_handle->seq_header.color_config.subsampling_y : 0;

        /* Per-plane geometry is recomputed from the luma values on every
         * iteration, so no plane depends on state left by an earlier one. */
        int block_width      = luma_width >> sub_x;
        int block_height     = luma_height >> sub_y;
        int on_screen_width  = luma_on_screen_width >> sub_x;
        int on_screen_height = luma_on_screen_height >> sub_y;
        /* Chroma planes narrower/shorter than 4 samples are widened by 2. */
        if (plane && block_width < 4) {
            block_width += 2;
            on_screen_width += 2;
        }
        if (plane && block_height < 4) {
            block_height += 2;
            on_screen_height += 2;
        }

        if (plane < PLANE_TYPES) {
            /* The first index is coded directly; the rest use context-coded symbols. */
            color_map[0][0] = svt_read_ns_ae(r, palette_size, ACCT_STR);

            /* Wavefront order: anti-diagonal i contains the cells (i - j, j). */
            for (int i = 1; i < on_screen_height + on_screen_width - 1; i++) {
                for (int j = MIN(i, on_screen_width - 1); j >= MAX(0, i - on_screen_height + 1);
                     j--) {
                    const int color_ctx = get_palette_color_context(
                        &color_map, (i - j), j, palette_size, color_order);
                    const int palette_color_idx = svt_read_symbol(
                        r,
                        plane ? frm_ctx->palette_uv_color_index_cdf[palette_size -
                                                                    PALETTE_MIN_SIZE][color_ctx]
                              : frm_ctx->palette_y_color_index_cdf[palette_size -
                                                                   PALETTE_MIN_SIZE][color_ctx],
                        palette_size,
                        ACCT_STR);
                    color_map[i - j][j] = color_order[palette_color_idx];
                }
            }

            /* Replicate the last visible column across the off-screen columns. */
            if (on_screen_width < block_width) {
                for (int i = 0; i < on_screen_height; i++)
                    memset(&color_map[i][on_screen_width],
                           color_map[i][on_screen_width - 1],
                           (size_t)(block_width - on_screen_width));
            }
            /* Replicate the last visible row across the off-screen rows. */
            for (int i = on_screen_height; i < block_height; i++)
                memcpy(color_map[i], color_map[on_screen_height - 1], (size_t)block_width);
        }

        /* Palette prediction: map each color index to its palette entry. */
        void *               blk_recon_buf;
        int32_t              recon_stride;
        EbPictureBufferDesc *recon_picture_buf = dec_handle->cur_pic_buf[0]->ps_pic_buf;

        derive_blk_pointers(recon_picture_buf,
                            plane,
                            (mi_col >> sub_x) * MI_SIZE,
                            (mi_row >> sub_y) * MI_SIZE,
                            &blk_recon_buf,
                            &recon_stride,
                            sub_x,
                            sub_y);
        const uint16_t *palette = parse_ctx->palette_colors[plane];

        if (recon_picture_buf->bit_depth == EB_8BIT && !(dec_handle->is_16bit_pipeline)) {
            uint8_t *dst = (uint8_t *)blk_recon_buf;
            for (int i = 0; i < block_height; i++)
                for (int j = 0; j < block_width; j++)
                    dst[i * recon_stride + j] = (uint8_t)palette[color_map[i][j]];
        } else {
            uint16_t *dst = (uint16_t *)blk_recon_buf;
            for (int i = 0; i < block_height; i++)
                for (int j = 0; j < block_width; j++)
                    dst[i * recon_stride + j] = palette[color_map[i][j]];
        }
    }
}
2403