1 /*
2  * Copyright (c) 2019, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include "config/aom_config.h"
13 
14 #include "aom_ports/system_state.h"
15 
16 #include "av1/encoder/encodemv.h"
17 #if !CONFIG_REALTIME_ONLY
18 #include "av1/encoder/misc_model_weights.h"
19 #endif  // !CONFIG_REALTIME_ONLY
20 #include "av1/encoder/mv_prec.h"
21 
22 #if !CONFIG_REALTIME_ONLY
get_ref_mv_for_mv_stats(const MB_MODE_INFO * mbmi,const MB_MODE_INFO_EXT_FRAME * mbmi_ext_frame,int ref_idx)23 static AOM_INLINE int_mv get_ref_mv_for_mv_stats(
24     const MB_MODE_INFO *mbmi, const MB_MODE_INFO_EXT_FRAME *mbmi_ext_frame,
25     int ref_idx) {
26   int ref_mv_idx = mbmi->ref_mv_idx;
27   if (mbmi->mode == NEAR_NEWMV || mbmi->mode == NEW_NEARMV) {
28     assert(has_second_ref(mbmi));
29     ref_mv_idx += 1;
30   }
31 
32   const MV_REFERENCE_FRAME *ref_frames = mbmi->ref_frame;
33   const int8_t ref_frame_type = av1_ref_frame_type(ref_frames);
34   const CANDIDATE_MV *curr_ref_mv_stack = mbmi_ext_frame->ref_mv_stack;
35 
36   if (ref_frames[1] > INTRA_FRAME) {
37     assert(ref_idx == 0 || ref_idx == 1);
38     return ref_idx ? curr_ref_mv_stack[ref_mv_idx].comp_mv
39                    : curr_ref_mv_stack[ref_mv_idx].this_mv;
40   }
41 
42   assert(ref_idx == 0);
43   return ref_mv_idx < mbmi_ext_frame->ref_mv_count
44              ? curr_ref_mv_stack[ref_mv_idx].this_mv
45              : mbmi_ext_frame->global_mvs[ref_frame_type];
46 }
47 
get_symbol_cost(const aom_cdf_prob * cdf,int symbol)48 static AOM_INLINE int get_symbol_cost(const aom_cdf_prob *cdf, int symbol) {
49   const aom_cdf_prob cur_cdf = AOM_ICDF(cdf[symbol]);
50   const aom_cdf_prob prev_cdf = symbol ? AOM_ICDF(cdf[symbol - 1]) : 0;
51   const aom_cdf_prob p15 = AOMMAX(cur_cdf - prev_cdf, EC_MIN_PROB);
52 
53   return av1_cost_symbol(p15);
54 }
55 
keep_one_comp_stat(MV_STATS * mv_stats,int comp_val,int comp_idx,const AV1_COMP * cpi,int * rates)56 static AOM_INLINE int keep_one_comp_stat(MV_STATS *mv_stats, int comp_val,
57                                          int comp_idx, const AV1_COMP *cpi,
58                                          int *rates) {
59   assert(comp_val != 0 && "mv component should not have zero value!");
60   const int sign = comp_val < 0;
61   const int mag = sign ? -comp_val : comp_val;
62   const int mag_minus_1 = mag - 1;
63   int offset;
64   const int mv_class = av1_get_mv_class(mag_minus_1, &offset);
65   const int int_part = offset >> 3;         // int mv data
66   const int frac_part = (offset >> 1) & 3;  // fractional mv data
67   const int high_part = offset & 1;         // high precision mv data
68   const int use_hp = cpi->common.features.allow_high_precision_mv;
69   int r_idx = 0;
70 
71   const MACROBLOCK *const x = &cpi->td.mb;
72   const MACROBLOCKD *const xd = &x->e_mbd;
73   FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
74   nmv_context *nmvc = &ec_ctx->nmvc;
75   nmv_component *mvcomp_ctx = nmvc->comps;
76   nmv_component *cur_mvcomp_ctx = &mvcomp_ctx[comp_idx];
77   aom_cdf_prob *sign_cdf = cur_mvcomp_ctx->sign_cdf;
78   aom_cdf_prob *class_cdf = cur_mvcomp_ctx->classes_cdf;
79   aom_cdf_prob *class0_cdf = cur_mvcomp_ctx->class0_cdf;
80   aom_cdf_prob(*bits_cdf)[3] = cur_mvcomp_ctx->bits_cdf;
81   aom_cdf_prob *frac_part_cdf = mv_class
82                                     ? (cur_mvcomp_ctx->fp_cdf)
83                                     : (cur_mvcomp_ctx->class0_fp_cdf[int_part]);
84   aom_cdf_prob *high_part_cdf =
85       mv_class ? (cur_mvcomp_ctx->hp_cdf) : (cur_mvcomp_ctx->class0_hp_cdf);
86 
87   const int sign_rate = get_symbol_cost(sign_cdf, sign);
88   rates[r_idx++] = sign_rate;
89   update_cdf(sign_cdf, sign, 2);
90 
91   const int class_rate = get_symbol_cost(class_cdf, mv_class);
92   rates[r_idx++] = class_rate;
93   update_cdf(class_cdf, mv_class, MV_CLASSES);
94 
95   int int_bit_rate = 0;
96   if (mv_class == MV_CLASS_0) {
97     int_bit_rate = get_symbol_cost(class0_cdf, int_part);
98     update_cdf(class0_cdf, int_part, CLASS0_SIZE);
99   } else {
100     const int n = mv_class + CLASS0_BITS - 1;  // number of bits
101     for (int i = 0; i < n; ++i) {
102       int_bit_rate += get_symbol_cost(bits_cdf[i], (int_part >> i) & 1);
103       update_cdf(bits_cdf[i], (int_part >> i) & 1, 2);
104     }
105   }
106   rates[r_idx++] = int_bit_rate;
107   const int frac_part_rate = get_symbol_cost(frac_part_cdf, frac_part);
108   rates[r_idx++] = frac_part_rate;
109   update_cdf(frac_part_cdf, frac_part, MV_FP_SIZE);
110   const int high_part_rate =
111       use_hp ? get_symbol_cost(high_part_cdf, high_part) : 0;
112   if (use_hp) {
113     update_cdf(high_part_cdf, high_part, 2);
114   }
115   rates[r_idx++] = high_part_rate;
116 
117   mv_stats->last_bit_zero += !high_part;
118   mv_stats->last_bit_nonzero += high_part;
119   const int total_rate =
120       (sign_rate + class_rate + int_bit_rate + frac_part_rate + high_part_rate);
121   return total_rate;
122 }
123 
keep_one_mv_stat(MV_STATS * mv_stats,const MV * ref_mv,const MV * cur_mv,const AV1_COMP * cpi)124 static AOM_INLINE void keep_one_mv_stat(MV_STATS *mv_stats, const MV *ref_mv,
125                                         const MV *cur_mv, const AV1_COMP *cpi) {
126   const MACROBLOCK *const x = &cpi->td.mb;
127   const MACROBLOCKD *const xd = &x->e_mbd;
128   FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
129   nmv_context *nmvc = &ec_ctx->nmvc;
130   aom_cdf_prob *joint_cdf = nmvc->joints_cdf;
131   const int use_hp = cpi->common.features.allow_high_precision_mv;
132 
133   const MV diff = { cur_mv->row - ref_mv->row, cur_mv->col - ref_mv->col };
134   const int mv_joint = av1_get_mv_joint(&diff);
135   // TODO(chiyotsai@google.com): Estimate hp_diff when we are using lp
136   const MV hp_diff = diff;
137   const int hp_mv_joint = av1_get_mv_joint(&hp_diff);
138   const MV truncated_diff = { (diff.row / 2) * 2, (diff.col / 2) * 2 };
139   const MV lp_diff = use_hp ? truncated_diff : diff;
140   const int lp_mv_joint = av1_get_mv_joint(&lp_diff);
141 
142   aom_clear_system_state();
143   const int mv_joint_rate = get_symbol_cost(joint_cdf, mv_joint);
144   const int hp_mv_joint_rate = get_symbol_cost(joint_cdf, hp_mv_joint);
145   const int lp_mv_joint_rate = get_symbol_cost(joint_cdf, lp_mv_joint);
146 
147   update_cdf(joint_cdf, mv_joint, MV_JOINTS);
148 
149   mv_stats->total_mv_rate += mv_joint_rate;
150   mv_stats->hp_total_mv_rate += hp_mv_joint_rate;
151   mv_stats->lp_total_mv_rate += lp_mv_joint_rate;
152   mv_stats->mv_joint_count[mv_joint]++;
153 
154   for (int comp_idx = 0; comp_idx < 2; comp_idx++) {
155     const int comp_val = comp_idx ? diff.col : diff.row;
156     const int hp_comp_val = comp_idx ? hp_diff.col : hp_diff.row;
157     const int lp_comp_val = comp_idx ? lp_diff.col : lp_diff.row;
158     int rates[5];
159     av1_zero_array(rates, 5);
160 
161     const int comp_rate =
162         comp_val ? keep_one_comp_stat(mv_stats, comp_val, comp_idx, cpi, rates)
163                  : 0;
164     // TODO(chiyotsai@google.com): Properly get hp rate when use_hp is false
165     const int hp_rate =
166         hp_comp_val ? rates[0] + rates[1] + rates[2] + rates[3] + rates[4] : 0;
167     const int lp_rate =
168         lp_comp_val ? rates[0] + rates[1] + rates[2] + rates[3] : 0;
169 
170     mv_stats->total_mv_rate += comp_rate;
171     mv_stats->hp_total_mv_rate += hp_rate;
172     mv_stats->lp_total_mv_rate += lp_rate;
173   }
174 }
175 
collect_mv_stats_b(MV_STATS * mv_stats,const AV1_COMP * cpi,int mi_row,int mi_col)176 static AOM_INLINE void collect_mv_stats_b(MV_STATS *mv_stats,
177                                           const AV1_COMP *cpi, int mi_row,
178                                           int mi_col) {
179   const AV1_COMMON *cm = &cpi->common;
180   const CommonModeInfoParams *const mi_params = &cm->mi_params;
181 
182   if (mi_row >= mi_params->mi_rows || mi_col >= mi_params->mi_cols) {
183     return;
184   }
185 
186   const MB_MODE_INFO *mbmi =
187       mi_params->mi_grid_base[mi_row * mi_params->mi_stride + mi_col];
188   const MB_MODE_INFO_EXT_FRAME *mbmi_ext_frame =
189       cpi->mbmi_ext_info.frame_base +
190       get_mi_ext_idx(mi_row, mi_col, cm->mi_params.mi_alloc_bsize,
191                      cpi->mbmi_ext_info.stride);
192 
193   if (!is_inter_block(mbmi)) {
194     mv_stats->intra_count++;
195     return;
196   }
197   mv_stats->inter_count++;
198 
199   const PREDICTION_MODE mode = mbmi->mode;
200   const int is_compound = has_second_ref(mbmi);
201 
202   if (mode == NEWMV || mode == NEW_NEWMV) {
203     // All mvs are new
204     for (int ref_idx = 0; ref_idx < 1 + is_compound; ++ref_idx) {
205       const MV ref_mv =
206           get_ref_mv_for_mv_stats(mbmi, mbmi_ext_frame, ref_idx).as_mv;
207       const MV cur_mv = mbmi->mv[ref_idx].as_mv;
208       keep_one_mv_stat(mv_stats, &ref_mv, &cur_mv, cpi);
209     }
210   } else if (mode == NEAREST_NEWMV || mode == NEAR_NEWMV ||
211              mode == NEW_NEARESTMV || mode == NEW_NEARMV) {
212     // has exactly one new_mv
213     mv_stats->default_mvs += 1;
214 
215     const int ref_idx = (mode == NEAREST_NEWMV || mode == NEAR_NEWMV);
216     const MV ref_mv =
217         get_ref_mv_for_mv_stats(mbmi, mbmi_ext_frame, ref_idx).as_mv;
218     const MV cur_mv = mbmi->mv[ref_idx].as_mv;
219 
220     keep_one_mv_stat(mv_stats, &ref_mv, &cur_mv, cpi);
221   } else {
222     // No new_mv
223     mv_stats->default_mvs += 1 + is_compound;
224   }
225 
226   // Add texture information
227   const BLOCK_SIZE bsize = mbmi->bsize;
228   const int num_rows = block_size_high[bsize];
229   const int num_cols = block_size_wide[bsize];
230   const int y_stride = cpi->source->y_stride;
231   const int px_row = 4 * mi_row, px_col = 4 * mi_col;
232   const int buf_is_hbd = cpi->source->flags & YV12_FLAG_HIGHBITDEPTH;
233   const int bd = cm->seq_params.bit_depth;
234   if (buf_is_hbd) {
235     uint16_t *source_buf =
236         CONVERT_TO_SHORTPTR(cpi->source->y_buffer) + px_row * y_stride + px_col;
237     for (int row = 0; row < num_rows - 1; row++) {
238       for (int col = 0; col < num_cols - 1; col++) {
239         const int offset = row * y_stride + col;
240         const int horz_diff =
241             abs(source_buf[offset + 1] - source_buf[offset]) >> (bd - 8);
242         const int vert_diff =
243             abs(source_buf[offset + y_stride] - source_buf[offset]) >> (bd - 8);
244         mv_stats->horz_text += horz_diff;
245         mv_stats->vert_text += vert_diff;
246         mv_stats->diag_text += horz_diff * vert_diff;
247       }
248     }
249   } else {
250     uint8_t *source_buf = cpi->source->y_buffer + px_row * y_stride + px_col;
251     for (int row = 0; row < num_rows - 1; row++) {
252       for (int col = 0; col < num_cols - 1; col++) {
253         const int offset = row * y_stride + col;
254         const int horz_diff = abs(source_buf[offset + 1] - source_buf[offset]);
255         const int vert_diff =
256             abs(source_buf[offset + y_stride] - source_buf[offset]);
257         mv_stats->horz_text += horz_diff;
258         mv_stats->vert_text += vert_diff;
259         mv_stats->diag_text += horz_diff * vert_diff;
260       }
261     }
262   }
263 }
264 
265 // Split block
collect_mv_stats_sb(MV_STATS * mv_stats,const AV1_COMP * cpi,int mi_row,int mi_col,BLOCK_SIZE bsize)266 static AOM_INLINE void collect_mv_stats_sb(MV_STATS *mv_stats,
267                                            const AV1_COMP *cpi, int mi_row,
268                                            int mi_col, BLOCK_SIZE bsize) {
269   assert(bsize < BLOCK_SIZES_ALL);
270   const AV1_COMMON *cm = &cpi->common;
271 
272   if (mi_row >= cm->mi_params.mi_rows || mi_col >= cm->mi_params.mi_cols)
273     return;
274 
275   const PARTITION_TYPE partition = get_partition(cm, mi_row, mi_col, bsize);
276   const BLOCK_SIZE subsize = get_partition_subsize(bsize, partition);
277 
278   const int hbs = mi_size_wide[bsize] / 2;
279   const int qbs = mi_size_wide[bsize] / 4;
280   switch (partition) {
281     case PARTITION_NONE:
282       collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
283       break;
284     case PARTITION_HORZ:
285       collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
286       collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col);
287       break;
288     case PARTITION_VERT:
289       collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
290       collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs);
291       break;
292     case PARTITION_SPLIT:
293       collect_mv_stats_sb(mv_stats, cpi, mi_row, mi_col, subsize);
294       collect_mv_stats_sb(mv_stats, cpi, mi_row, mi_col + hbs, subsize);
295       collect_mv_stats_sb(mv_stats, cpi, mi_row + hbs, mi_col, subsize);
296       collect_mv_stats_sb(mv_stats, cpi, mi_row + hbs, mi_col + hbs, subsize);
297       break;
298     case PARTITION_HORZ_A:
299       collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
300       collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs);
301       collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col);
302       break;
303     case PARTITION_HORZ_B:
304       collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
305       collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col);
306       collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col + hbs);
307       break;
308     case PARTITION_VERT_A:
309       collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
310       collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col);
311       collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs);
312       break;
313     case PARTITION_VERT_B:
314       collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col);
315       collect_mv_stats_b(mv_stats, cpi, mi_row, mi_col + hbs);
316       collect_mv_stats_b(mv_stats, cpi, mi_row + hbs, mi_col + hbs);
317       break;
318     case PARTITION_HORZ_4:
319       for (int i = 0; i < 4; ++i) {
320         const int this_mi_row = mi_row + i * qbs;
321         collect_mv_stats_b(mv_stats, cpi, this_mi_row, mi_col);
322       }
323       break;
324     case PARTITION_VERT_4:
325       for (int i = 0; i < 4; ++i) {
326         const int this_mi_col = mi_col + i * qbs;
327         collect_mv_stats_b(mv_stats, cpi, mi_row, this_mi_col);
328       }
329       break;
330     default: assert(0);
331   }
332 }
333 
collect_mv_stats_tile(MV_STATS * mv_stats,const AV1_COMP * cpi,const TileInfo * tile_info)334 static AOM_INLINE void collect_mv_stats_tile(MV_STATS *mv_stats,
335                                              const AV1_COMP *cpi,
336                                              const TileInfo *tile_info) {
337   const AV1_COMMON *cm = &cpi->common;
338   const int mi_row_start = tile_info->mi_row_start;
339   const int mi_row_end = tile_info->mi_row_end;
340   const int mi_col_start = tile_info->mi_col_start;
341   const int mi_col_end = tile_info->mi_col_end;
342   const int sb_size_mi = cm->seq_params.mib_size;
343   BLOCK_SIZE sb_size = cm->seq_params.sb_size;
344   for (int mi_row = mi_row_start; mi_row < mi_row_end; mi_row += sb_size_mi) {
345     for (int mi_col = mi_col_start; mi_col < mi_col_end; mi_col += sb_size_mi) {
346       collect_mv_stats_sb(mv_stats, cpi, mi_row, mi_col, sb_size);
347     }
348   }
349 }
350 
// Gathers MV statistics for the whole frame into cpi->mv_stats.
//
// Tiles are visited in raster order. Before a tile is scanned, its tile-data
// entropy context is reset to a copy of the frame context (cm->fc) and
// installed as the macroblock's tile context, so the statistics pass adapts a
// per-tile copy. Afterwards the stats are stamped with `current_q` and the
// frame's order hint and marked valid.
void av1_collect_mv_stats(AV1_COMP *cpi, int current_q) {
  MV_STATS *const mv_stats = &cpi->mv_stats;
  const AV1_COMMON *const cm = &cpi->common;
  const int num_tile_rows = cm->tiles.rows;
  const int num_tile_cols = cm->tiles.cols;

  int tile_idx = 0;  // raster index: row * num_tile_cols + col
  for (int row = 0; row < num_tile_rows; ++row) {
    TileInfo tile_info;
    av1_tile_set_row(&tile_info, cm, row);
    for (int col = 0; col < num_tile_cols; ++col, ++tile_idx) {
      av1_tile_set_col(&tile_info, cm, col);

      FRAME_CONTEXT *const tile_ctx = &cpi->tile_data[tile_idx].tctx;
      *tile_ctx = *cm->fc;
      cpi->td.mb.e_mbd.tile_ctx = tile_ctx;

      collect_mv_stats_tile(mv_stats, cpi, &tile_info);
    }
  }

  mv_stats->q = current_q;
  mv_stats->order = cm->current_frame.order_hint;
  mv_stats->valid = 1;
}
373 
get_smart_mv_prec(AV1_COMP * cpi,const MV_STATS * mv_stats,int current_q)374 static AOM_INLINE int get_smart_mv_prec(AV1_COMP *cpi, const MV_STATS *mv_stats,
375                                         int current_q) {
376   const AV1_COMMON *cm = &cpi->common;
377   const int order_hint = cpi->common.current_frame.order_hint;
378   const int order_diff = order_hint - mv_stats->order;
379   aom_clear_system_state();
380   const float area = (float)(cm->width * cm->height);
381   float features[MV_PREC_FEATURE_SIZE] = {
382     (float)current_q,
383     (float)mv_stats->q,
384     (float)order_diff,
385     mv_stats->inter_count / area,
386     mv_stats->intra_count / area,
387     mv_stats->default_mvs / area,
388     mv_stats->mv_joint_count[0] / area,
389     mv_stats->mv_joint_count[1] / area,
390     mv_stats->mv_joint_count[2] / area,
391     mv_stats->mv_joint_count[3] / area,
392     mv_stats->last_bit_zero / area,
393     mv_stats->last_bit_nonzero / area,
394     mv_stats->total_mv_rate / area,
395     mv_stats->hp_total_mv_rate / area,
396     mv_stats->lp_total_mv_rate / area,
397     mv_stats->horz_text / area,
398     mv_stats->vert_text / area,
399     mv_stats->diag_text / area,
400   };
401 
402   for (int f_idx = 0; f_idx < MV_PREC_FEATURE_SIZE; f_idx++) {
403     features[f_idx] =
404         (features[f_idx] - av1_mv_prec_mean[f_idx]) / av1_mv_prec_std[f_idx];
405   }
406   float score = 0.0f;
407 
408   av1_nn_predict(features, &av1_mv_prec_dnn_config, 1, &score);
409 
410   const int use_high_hp = score >= 0.0f;
411   return use_high_hp;
412 }
413 #endif  // !CONFIG_REALTIME_ONLY
414 
// Chooses whether the frame uses high-precision MVs and applies the choice
// through av1_set_high_precision_mv().
//
// - QTR_ONLY always disables high precision.
// - Otherwise the default is "enabled when qindex is below
//   HIGH_PRECISION_MV_QTHRESH".
// - In non-realtime builds, LAST_MV_DATA overrides that default with the
//   model-based decision when the frame allows it and valid MV stats exist.
void av1_pick_and_set_high_precision_mv(AV1_COMP *cpi, int qindex) {
  int use_hp;

  if (cpi->sf.hl_sf.high_precision_mv_usage == QTR_ONLY) {
    use_hp = 0;
  } else {
    use_hp = qindex < HIGH_PRECISION_MV_QTHRESH;
#if !CONFIG_REALTIME_ONLY
    if (cpi->sf.hl_sf.high_precision_mv_usage == LAST_MV_DATA &&
        av1_frame_allows_smart_mv(cpi) && cpi->mv_stats.valid) {
      use_hp = get_smart_mv_prec(cpi, &cpi->mv_stats, qindex);
    }
#endif  // !CONFIG_REALTIME_ONLY
  }

  av1_set_high_precision_mv(cpi, use_hp,
                            cpi->common.features.cur_frame_force_integer_mv);
}
431