1 /*
2  * Copyright (c) 2019, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <stdint.h>
13 #include <float.h>
14 
15 #include "config/aom_config.h"
16 #include "config/aom_dsp_rtcd.h"
17 #include "config/aom_scale_rtcd.h"
18 
19 #include "aom/aom_codec.h"
20 #include "aom_ports/system_state.h"
21 
22 #include "av1/common/av1_common_int.h"
23 #include "av1/common/enums.h"
24 #include "av1/common/idct.h"
25 #include "av1/common/reconintra.h"
26 
27 #include "av1/encoder/encoder.h"
28 #include "av1/encoder/ethread.h"
29 #include "av1/encoder/encodeframe_utils.h"
30 #include "av1/encoder/encode_strategy.h"
31 #include "av1/encoder/hybrid_fwd_txfm.h"
32 #include "av1/encoder/motion_search_facade.h"
33 #include "av1/encoder/rd.h"
34 #include "av1/encoder/rdopt.h"
35 #include "av1/encoder/reconinter_enc.h"
36 #include "av1/encoder/tpl_model.h"
37 
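// Applies forward quantization (FP path) to the transform coefficients of one
// block, then computes the reconstruction error between coeff and dqcoeff
// together with the SSE. Both values are scaled down by 2 bits for transform
// sizes other than 32x32 and clamped to at least 1.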
38 static AOM_INLINE void get_quantize_error(const MACROBLOCK *x, int plane,
39                                           const tran_low_t *coeff,
40                                           tran_low_t *qcoeff,
41                                           tran_low_t *dqcoeff, TX_SIZE tx_size,
42                                           uint16_t *eob, int64_t *recon_error,
43                                           int64_t *sse) {
44   const struct macroblock_plane *const p = &x->plane[plane];
45   const MACROBLOCKD *xd = &x->e_mbd;
46   const SCAN_ORDER *const scan_order = &av1_scan_orders[tx_size][DCT_DCT];
47   int pix_num = 1 << num_pels_log2_lookup[txsize_to_bsize[tx_size]];
48   const int shift = tx_size == TX_32X32 ? 0 : 2;
49 
50   QUANT_PARAM quant_param;
51   av1_setup_quant(tx_size, 0, AV1_XFORM_QUANT_FP, 0, &quant_param);
52 
53 #if CONFIG_AV1_HIGHBITDEPTH
54   if (is_cur_buf_hbd(xd)) {
55     av1_highbd_quantize_fp_facade(coeff, pix_num, p, qcoeff, dqcoeff, eob,
56                                   scan_order, &quant_param);
57     *recon_error =
58         av1_highbd_block_error(coeff, dqcoeff, pix_num, sse, xd->bd) >> shift;
59   } else {
60     av1_quantize_fp_facade(coeff, pix_num, p, qcoeff, dqcoeff, eob, scan_order,
61                            &quant_param);
62     *recon_error = av1_block_error(coeff, dqcoeff, pix_num, sse) >> shift;
63   }
64 #else
65   (void)xd;
66   av1_quantize_fp_facade(coeff, pix_num, p, qcoeff, dqcoeff, eob, scan_order,
67                          &quant_param);
68   *recon_error = av1_block_error(coeff, dqcoeff, pix_num, sse) >> shift;
69 #endif  // CONFIG_AV1_HIGHBITDEPTH
70 
71   *recon_error = AOMMAX(*recon_error, 1);
72 
73   *sse = (*sse) >> shift;
74   *sse = AOMMAX(*sse, 1);
75 }
76 
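// Forward transforms the residual block with a DCT_DCT transform of the given
// size, honoring the current bit depth / high bit depth configuration.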
77 static AOM_INLINE void tpl_fwd_txfm(const int16_t *src_diff, int bw,
78                                     tran_low_t *coeff, TX_SIZE tx_size,
79                                     int bit_depth, int is_hbd) {
80   TxfmParam txfm_param;
81   txfm_param.tx_type = DCT_DCT;
82   txfm_param.tx_size = tx_size;
83   txfm_param.lossless = 0;
84   txfm_param.tx_set_type = EXT_TX_SET_ALL16;
85 
86   txfm_param.bd = bit_depth;
87   txfm_param.is_hbd = is_hbd;
88   av1_fwd_txfm(src_diff, coeff, bw, &txfm_param);
89 }
90 
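// Computes the residual between src and dst, forward transforms it and
// returns the SATD of the transform coefficients as a cheap cost estimate.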
91 static AOM_INLINE int64_t tpl_get_satd_cost(const MACROBLOCK *x,
92                                             int16_t *src_diff, int diff_stride,
93                                             const uint8_t *src, int src_stride,
94                                             const uint8_t *dst, int dst_stride,
95                                             tran_low_t *coeff, int bw, int bh,
96                                             TX_SIZE tx_size) {
97   const MACROBLOCKD *xd = &x->e_mbd;
98   const int pix_num = bw * bh;
99 
100   av1_subtract_block(xd, bh, bw, src_diff, diff_stride, src, src_stride, dst,
101                      dst_stride);
102   tpl_fwd_txfm(src_diff, bw, coeff, tx_size, xd->bd, is_cur_buf_hbd(xd));
103   return aom_satd(coeff, pix_num);
104 }
105 
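// Coarse rate estimate for a quantized block: roughly
// floor(log2(|level| + 1)) + 1 per coefficient up to eob, scaled into the
// AV1_PROB_COST domain.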
106 static int rate_estimator(const tran_low_t *qcoeff, int eob, TX_SIZE tx_size) {
107   const SCAN_ORDER *const scan_order = &av1_scan_orders[tx_size][DCT_DCT];
108 
109   assert((1 << num_pels_log2_lookup[txsize_to_bsize[tx_size]]) >= eob);
110   aom_clear_system_state();
111   int rate_cost = 1;
112 
113   for (int idx = 0; idx < eob; ++idx) {
114     int abs_level = abs(qcoeff[scan_order->scan[idx]]);
115     rate_cost += (int)(log(abs_level + 1.0) / log(2.0)) + 1;
116   }
117 
118   return (rate_cost << AV1_PROB_COST_SHIFT);
119 }
120 
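// For one plane block: computes the residual, forward transforms and
// quantizes it, estimates the coding rate and reconstruction error, and
// reconstructs the prediction in dst via the inverse transform.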
121 static AOM_INLINE void txfm_quant_rdcost(
122     const MACROBLOCK *x, int16_t *src_diff, int diff_stride, uint8_t *src,
123     int src_stride, uint8_t *dst, int dst_stride, tran_low_t *coeff,
124     tran_low_t *qcoeff, tran_low_t *dqcoeff, int bw, int bh, TX_SIZE tx_size,
125     int *rate_cost, int64_t *recon_error, int64_t *sse) {
126   const MACROBLOCKD *xd = &x->e_mbd;
127   uint16_t eob;
128   av1_subtract_block(xd, bh, bw, src_diff, diff_stride, src, src_stride, dst,
129                      dst_stride);
130   tpl_fwd_txfm(src_diff, diff_stride, coeff, tx_size, xd->bd,
131                is_cur_buf_hbd(xd));
132 
133   get_quantize_error(x, 0, coeff, qcoeff, dqcoeff, tx_size, &eob, recon_error,
134                      sse);
135 
136   *rate_cost = rate_estimator(qcoeff, eob, tx_size);
137 
138   av1_inverse_transform_block(xd, dqcoeff, 0, DCT_DCT, tx_size, dst, dst_stride,
139                               eob, 0);
140 }
141 
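// Performs a full-pixel motion search seeded at center_mv followed by subpel
// refinement; the refined MV is written to best_mv and the error reported by
// the subpel search is returned.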
142 static uint32_t motion_estimation(AV1_COMP *cpi, MACROBLOCK *x,
143                                   uint8_t *cur_frame_buf,
144                                   uint8_t *ref_frame_buf, int stride,
145                                   int stride_ref, BLOCK_SIZE bsize,
146                                   MV center_mv, int_mv *best_mv) {
147   AV1_COMMON *cm = &cpi->common;
148   MACROBLOCKD *const xd = &x->e_mbd;
149   TPL_SPEED_FEATURES *tpl_sf = &cpi->sf.tpl_sf;
150   int step_param;
151   uint32_t bestsme = UINT_MAX;
152   int distortion;
153   uint32_t sse;
154   int cost_list[5];
155   FULLPEL_MV start_mv = get_fullmv_from_mv(&center_mv);
156 
157   // Setup frame pointers
158   x->plane[0].src.buf = cur_frame_buf;
159   x->plane[0].src.stride = stride;
160   xd->plane[0].pre[0].buf = ref_frame_buf;
161   xd->plane[0].pre[0].stride = stride_ref;
162 
163   step_param = tpl_sf->reduce_first_step_size;
164   step_param = AOMMIN(step_param, MAX_MVSEARCH_STEPS - 2);
165 
166   const search_site_config *search_site_cfg =
167       cpi->mv_search_params.search_site_cfg[SS_CFG_SRC];
168   if (search_site_cfg->stride != stride_ref)
169     search_site_cfg = cpi->mv_search_params.search_site_cfg[SS_CFG_LOOKAHEAD];
170   assert(search_site_cfg->stride == stride_ref);
171 
172   FULLPEL_MOTION_SEARCH_PARAMS full_ms_params;
173   av1_make_default_fullpel_ms_params(&full_ms_params, cpi, x, bsize, &center_mv,
174                                      search_site_cfg,
175                                      /*fine_search_interval=*/0);
176   av1_set_mv_search_method(&full_ms_params, search_site_cfg,
177                            tpl_sf->search_method);
178 
179   av1_full_pixel_search(start_mv, &full_ms_params, step_param,
180                         cond_cost_list(cpi, cost_list), &best_mv->as_fullmv,
181                         NULL);
182 
183   SUBPEL_MOTION_SEARCH_PARAMS ms_params;
184   av1_make_default_subpel_ms_params(&ms_params, cpi, x, bsize, &center_mv,
185                                     cost_list);
186   ms_params.forced_stop = tpl_sf->subpel_force_stop;
187   ms_params.var_params.subpel_search_type = USE_2_TAPS;
188   ms_params.mv_cost_params.mv_cost_type = MV_COST_NONE;
189   MV subpel_start_mv = get_mv_from_fullmv(&best_mv->as_fullmv);
190   bestsme = cpi->mv_search_params.find_fractional_mv_step(
191       xd, cm, &ms_params, subpel_start_mv, &best_mv->as_mv, &distortion, &sse,
192       NULL);
193 
194   return bestsme;
195 }
196 
197 typedef struct {
198   int_mv mv;
199   int sad;
200 } center_mv_t;
201 
202 static int compare_sad(const void *a, const void *b) {
203   const int diff = ((center_mv_t *)a)->sad - ((center_mv_t *)b)->sad;
204   if (diff < 0)
205     return -1;
206   else if (diff > 0)
207     return 1;
208   return 0;
209 }
210 
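// Returns 1 if candidate_mv lies within a (1/8-pel) threshold of any MV
// already collected in center_mvs, so nearly identical starting MVs can be
// skipped. The threshold is selected by the skip_alike_starting_mv speed
// feature level.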
211 static int is_alike_mv(int_mv candidate_mv, center_mv_t *center_mvs,
212                        int center_mvs_count, int skip_alike_starting_mv) {
213   // MV difference threshold is in 1/8 precision.
214   const int mv_diff_thr[3] = { 1, (8 << 3), (16 << 3) };
215   int thr = mv_diff_thr[skip_alike_starting_mv];
216   int i;
217 
218   for (i = 0; i < center_mvs_count; i++) {
219     if (abs(center_mvs[i].mv.as_mv.col - candidate_mv.as_mv.col) < thr &&
220         abs(center_mvs[i].mv.as_mv.row - candidate_mv.as_mv.row) < thr)
221       return 1;
222   }
223 
224   return 0;
225 }
226 
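// Builds the intra or inter (single or compound) prediction for every plane
// of the block at (mi_row, mi_col) into the reconstruction buffers, then
// accumulates the estimated rate and reconstruction error over the planes.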
227 static void get_rate_distortion(
228     int *rate_cost, int64_t *recon_error, int16_t *src_diff, tran_low_t *coeff,
229     tran_low_t *qcoeff, tran_low_t *dqcoeff, AV1_COMMON *cm, MACROBLOCK *x,
230     const YV12_BUFFER_CONFIG *ref_frame_ptr[2], uint8_t *rec_buffer_pool[3],
231     const int rec_stride_pool[3], TX_SIZE tx_size, PREDICTION_MODE best_mode,
232     int mi_row, int mi_col) {
233   *rate_cost = 0;
234   *recon_error = 1;
235 
236   MACROBLOCKD *xd = &x->e_mbd;
237   int is_compound = (best_mode == NEW_NEWMV);
238 
239   uint8_t *src_buffer_pool[MAX_MB_PLANE] = {
240     xd->cur_buf->y_buffer,
241     xd->cur_buf->u_buffer,
242     xd->cur_buf->v_buffer,
243   };
244   const int src_stride_pool[MAX_MB_PLANE] = {
245     xd->cur_buf->y_stride,
246     xd->cur_buf->uv_stride,
247     xd->cur_buf->uv_stride,
248   };
249 
250   const int_interpfilters kernel =
251       av1_broadcast_interp_filter(EIGHTTAP_REGULAR);
252 
253   for (int plane = 0; plane < MAX_MB_PLANE; ++plane) {
254     struct macroblockd_plane *pd = &xd->plane[plane];
255     BLOCK_SIZE bsize_plane =
256         ss_size_lookup[txsize_to_bsize[tx_size]][pd->subsampling_x]
257                       [pd->subsampling_y];
258 
259     int dst_buffer_stride = rec_stride_pool[plane];
260     int dst_mb_offset =
261         ((mi_row * MI_SIZE * dst_buffer_stride) >> pd->subsampling_y) +
262         ((mi_col * MI_SIZE) >> pd->subsampling_x);
263     uint8_t *dst_buffer = rec_buffer_pool[plane] + dst_mb_offset;
264     for (int ref = 0; ref < 1 + is_compound; ++ref) {
265       if (!is_inter_mode(best_mode)) {
266         av1_predict_intra_block(
267             cm, xd, block_size_wide[bsize_plane], block_size_high[bsize_plane],
268             max_txsize_rect_lookup[bsize_plane], best_mode, 0, 0,
269             FILTER_INTRA_MODES, dst_buffer, dst_buffer_stride, dst_buffer,
270             dst_buffer_stride, 0, 0, plane);
271       } else {
272         int_mv best_mv = xd->mi[0]->mv[ref];
273         uint8_t *ref_buffer_pool[MAX_MB_PLANE] = {
274           ref_frame_ptr[ref]->y_buffer,
275           ref_frame_ptr[ref]->u_buffer,
276           ref_frame_ptr[ref]->v_buffer,
277         };
278         InterPredParams inter_pred_params;
279         struct buf_2d ref_buf = {
280           NULL, ref_buffer_pool[plane],
281           plane ? ref_frame_ptr[ref]->uv_width : ref_frame_ptr[ref]->y_width,
282           plane ? ref_frame_ptr[ref]->uv_height : ref_frame_ptr[ref]->y_height,
283           plane ? ref_frame_ptr[ref]->uv_stride : ref_frame_ptr[ref]->y_stride
284         };
285         av1_init_inter_params(&inter_pred_params, block_size_wide[bsize_plane],
286                               block_size_high[bsize_plane],
287                               (mi_row * MI_SIZE) >> pd->subsampling_y,
288                               (mi_col * MI_SIZE) >> pd->subsampling_x,
289                               pd->subsampling_x, pd->subsampling_y, xd->bd,
290                               is_cur_buf_hbd(xd), 0,
291                               xd->block_ref_scale_factors[0], &ref_buf, kernel);
292         if (is_compound) av1_init_comp_mode(&inter_pred_params);
293         inter_pred_params.conv_params = get_conv_params_no_round(
294             ref, plane, xd->tmp_conv_dst, MAX_SB_SIZE, is_compound, xd->bd);
295 
296         av1_enc_build_one_inter_predictor(dst_buffer, dst_buffer_stride,
297                                           &best_mv.as_mv, &inter_pred_params);
298       }
299     }
300 
301     int src_stride = src_stride_pool[plane];
302     int src_mb_offset = ((mi_row * MI_SIZE * src_stride) >> pd->subsampling_y) +
303                         ((mi_col * MI_SIZE) >> pd->subsampling_x);
304 
305     int this_rate = 1;
306     int64_t this_recon_error = 1;
307     int64_t sse;
308     txfm_quant_rdcost(
309         x, src_diff, block_size_wide[bsize_plane],
310         src_buffer_pool[plane] + src_mb_offset, src_stride, dst_buffer,
311         dst_buffer_stride, coeff, qcoeff, dqcoeff, block_size_wide[bsize_plane],
312         block_size_high[bsize_plane], max_txsize_rect_lookup[bsize_plane],
313         &this_rate, &this_recon_error, &sse);
314 
315     *recon_error += this_recon_error;
316     *rate_cost += this_rate;
317   }
318 }
319 
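// Evaluates one tpl block: searches intra prediction modes, single-reference
// NEWMV candidates for every available reference and a small set of compound
// (NEW_NEWMV) reference pairs, then fills tpl_stats with the best intra/inter
// costs, source/reconstruction rates and distortions, MVs and reference
// indices used later for dependency propagation.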
320 static AOM_INLINE void mode_estimation(AV1_COMP *cpi, MACROBLOCK *x, int mi_row,
321                                        int mi_col, BLOCK_SIZE bsize,
322                                        TX_SIZE tx_size,
323                                        TplDepStats *tpl_stats) {
324   AV1_COMMON *cm = &cpi->common;
325   const GF_GROUP *gf_group = &cpi->gf_group;
326 
327   (void)gf_group;
328 
329   MACROBLOCKD *xd = &x->e_mbd;
330   TplParams *tpl_data = &cpi->tpl_data;
331   TplDepFrame *tpl_frame = &tpl_data->tpl_frame[tpl_data->frame_idx];
332   const uint8_t block_mis_log2 = tpl_data->tpl_stats_block_mis_log2;
333 
334   const int bw = 4 << mi_size_wide_log2[bsize];
335   const int bh = 4 << mi_size_high_log2[bsize];
336   const int_interpfilters kernel =
337       av1_broadcast_interp_filter(EIGHTTAP_REGULAR);
338 
339   int64_t best_intra_cost = INT64_MAX;
340   int64_t intra_cost;
341   PREDICTION_MODE best_mode = DC_PRED;
342 
343   int mb_y_offset = mi_row * MI_SIZE * xd->cur_buf->y_stride + mi_col * MI_SIZE;
344   uint8_t *src_mb_buffer = xd->cur_buf->y_buffer + mb_y_offset;
345   int src_stride = xd->cur_buf->y_stride;
346 
347   int dst_mb_offset =
348       mi_row * MI_SIZE * tpl_frame->rec_picture->y_stride + mi_col * MI_SIZE;
349   uint8_t *dst_buffer = tpl_frame->rec_picture->y_buffer + dst_mb_offset;
350   int dst_buffer_stride = tpl_frame->rec_picture->y_stride;
351 
352   uint8_t *rec_buffer_pool[3] = {
353     tpl_frame->rec_picture->y_buffer,
354     tpl_frame->rec_picture->u_buffer,
355     tpl_frame->rec_picture->v_buffer,
356   };
357 
358   const int rec_stride_pool[3] = {
359     tpl_frame->rec_picture->y_stride,
360     tpl_frame->rec_picture->uv_stride,
361     tpl_frame->rec_picture->uv_stride,
362   };
363 
364   for (int plane = 1; plane < MAX_MB_PLANE; ++plane) {
365     struct macroblockd_plane *pd = &xd->plane[plane];
366     pd->subsampling_x = xd->cur_buf->subsampling_x;
367     pd->subsampling_y = xd->cur_buf->subsampling_y;
368   }
369 
370   // Number of pixels in a tpl block
371   const int tpl_block_pels = tpl_data->tpl_bsize_1d * tpl_data->tpl_bsize_1d;
372   // Allocate temporary buffers used in motion estimation.
373   uint8_t *predictor8 = aom_memalign(32, tpl_block_pels * 2 * sizeof(uint8_t));
374   int16_t *src_diff = aom_memalign(32, tpl_block_pels * sizeof(int16_t));
375   tran_low_t *coeff = aom_memalign(32, tpl_block_pels * sizeof(tran_low_t));
376   tran_low_t *qcoeff = aom_memalign(32, tpl_block_pels * sizeof(tran_low_t));
377   tran_low_t *dqcoeff = aom_memalign(32, tpl_block_pels * sizeof(tran_low_t));
378   uint8_t *predictor =
379       is_cur_buf_hbd(xd) ? CONVERT_TO_BYTEPTR(predictor8) : predictor8;
380   int64_t recon_error = 1;
381 
382   memset(tpl_stats, 0, sizeof(*tpl_stats));
383   tpl_stats->ref_frame_index[0] = -1;
384   tpl_stats->ref_frame_index[1] = -1;
385 
386   const int mi_width = mi_size_wide[bsize];
387   const int mi_height = mi_size_high[bsize];
388   set_mode_info_offsets(&cpi->common.mi_params, &cpi->mbmi_ext_info, x, xd,
389                         mi_row, mi_col);
390   set_mi_row_col(xd, &xd->tile, mi_row, mi_height, mi_col, mi_width,
391                  cm->mi_params.mi_rows, cm->mi_params.mi_cols);
392   set_plane_n4(xd, mi_size_wide[bsize], mi_size_high[bsize],
393                av1_num_planes(cm));
394   xd->mi[0]->bsize = bsize;
395   xd->mi[0]->motion_mode = SIMPLE_TRANSLATION;
396 
397   // Intra prediction search
398   xd->mi[0]->ref_frame[0] = INTRA_FRAME;
399 
400   // Pre-load the bottom left line.
401   if (xd->left_available &&
402       mi_row + tx_size_high_unit[tx_size] < xd->tile.mi_row_end) {
403 #if CONFIG_AV1_HIGHBITDEPTH
404     if (is_cur_buf_hbd(xd)) {
405       uint16_t *dst = CONVERT_TO_SHORTPTR(dst_buffer);
406       for (int i = 0; i < bw; ++i)
407         dst[(bw + i) * dst_buffer_stride - 1] =
408             dst[(bw - 1) * dst_buffer_stride - 1];
409     } else {
410       for (int i = 0; i < bw; ++i)
411         dst_buffer[(bw + i) * dst_buffer_stride - 1] =
412             dst_buffer[(bw - 1) * dst_buffer_stride - 1];
413     }
414 #else
415     for (int i = 0; i < bw; ++i)
416       dst_buffer[(bw + i) * dst_buffer_stride - 1] =
417           dst_buffer[(bw - 1) * dst_buffer_stride - 1];
418 #endif
419   }
420 
421   // If cpi->sf.tpl_sf.prune_intra_modes is on, then search only DC_PRED,
422   // H_PRED, and V_PRED.
423   const PREDICTION_MODE last_intra_mode =
424       cpi->sf.tpl_sf.prune_intra_modes ? D45_PRED : INTRA_MODE_END;
425   for (PREDICTION_MODE mode = INTRA_MODE_START; mode < last_intra_mode;
426        ++mode) {
427     av1_predict_intra_block(cm, xd, block_size_wide[bsize],
428                             block_size_high[bsize], tx_size, mode, 0, 0,
429                             FILTER_INTRA_MODES, dst_buffer, dst_buffer_stride,
430                             predictor, bw, 0, 0, 0);
431 
432     intra_cost = tpl_get_satd_cost(x, src_diff, bw, src_mb_buffer, src_stride,
433                                    predictor, bw, coeff, bw, bh, tx_size);
434 
435     if (intra_cost < best_intra_cost) {
436       best_intra_cost = intra_cost;
437       best_mode = mode;
438     }
439   }
440 
441   // Motion compensated prediction
442   xd->mi[0]->ref_frame[0] = INTRA_FRAME;
443   xd->mi[0]->ref_frame[1] = NONE_FRAME;
444   xd->mi[0]->compound_idx = 1;
445 
446   int best_rf_idx = -1;
447   int_mv best_mv[2];
448   int64_t inter_cost;
449   int64_t best_inter_cost = INT64_MAX;
450   int rf_idx;
451   int_mv single_mv[INTER_REFS_PER_FRAME];
452 
453   best_mv[0].as_int = INVALID_MV;
454   best_mv[1].as_int = INVALID_MV;
455 
456   for (rf_idx = 0; rf_idx < INTER_REFS_PER_FRAME; ++rf_idx) {
457     single_mv[rf_idx].as_int = INVALID_MV;
458     if (tpl_data->ref_frame[rf_idx] == NULL ||
459         tpl_data->src_ref_frame[rf_idx] == NULL) {
460       tpl_stats->mv[rf_idx].as_int = INVALID_MV;
461       continue;
462     }
463 
464     const YV12_BUFFER_CONFIG *ref_frame_ptr = tpl_data->src_ref_frame[rf_idx];
465     int ref_mb_offset =
466         mi_row * MI_SIZE * ref_frame_ptr->y_stride + mi_col * MI_SIZE;
467     uint8_t *ref_mb = ref_frame_ptr->y_buffer + ref_mb_offset;
468     int ref_stride = ref_frame_ptr->y_stride;
469 
470     int_mv best_rfidx_mv = { 0 };
471     uint32_t bestsme = UINT32_MAX;
472 
473     center_mv_t center_mvs[4] = { { { 0 }, INT_MAX },
474                                   { { 0 }, INT_MAX },
475                                   { { 0 }, INT_MAX },
476                                   { { 0 }, INT_MAX } };
477     int refmv_count = 1;
478     int idx;
479 
480     if (xd->up_available) {
481       TplDepStats *ref_tpl_stats = &tpl_frame->tpl_stats_ptr[av1_tpl_ptr_pos(
482           mi_row - mi_height, mi_col, tpl_frame->stride, block_mis_log2)];
483       if (!is_alike_mv(ref_tpl_stats->mv[rf_idx], center_mvs, refmv_count,
484                        cpi->sf.tpl_sf.skip_alike_starting_mv)) {
485         center_mvs[refmv_count].mv.as_int = ref_tpl_stats->mv[rf_idx].as_int;
486         ++refmv_count;
487       }
488     }
489 
490     if (xd->left_available) {
491       TplDepStats *ref_tpl_stats = &tpl_frame->tpl_stats_ptr[av1_tpl_ptr_pos(
492           mi_row, mi_col - mi_width, tpl_frame->stride, block_mis_log2)];
493       if (!is_alike_mv(ref_tpl_stats->mv[rf_idx], center_mvs, refmv_count,
494                        cpi->sf.tpl_sf.skip_alike_starting_mv)) {
495         center_mvs[refmv_count].mv.as_int = ref_tpl_stats->mv[rf_idx].as_int;
496         ++refmv_count;
497       }
498     }
499 
500     if (xd->up_available && mi_col + mi_width < xd->tile.mi_col_end) {
501       TplDepStats *ref_tpl_stats = &tpl_frame->tpl_stats_ptr[av1_tpl_ptr_pos(
502           mi_row - mi_height, mi_col + mi_width, tpl_frame->stride,
503           block_mis_log2)];
504       if (!is_alike_mv(ref_tpl_stats->mv[rf_idx], center_mvs, refmv_count,
505                        cpi->sf.tpl_sf.skip_alike_starting_mv)) {
506         center_mvs[refmv_count].mv.as_int = ref_tpl_stats->mv[rf_idx].as_int;
507         ++refmv_count;
508       }
509     }
510 
511     // Prune starting mvs
512     if (cpi->sf.tpl_sf.prune_starting_mv) {
513       // Get each center mv's sad.
514       for (idx = 0; idx < refmv_count; ++idx) {
515         FULLPEL_MV mv = get_fullmv_from_mv(&center_mvs[idx].mv.as_mv);
516         clamp_fullmv(&mv, &x->mv_limits);
517         center_mvs[idx].sad = (int)cpi->fn_ptr[bsize].sdf(
518             src_mb_buffer, src_stride, &ref_mb[mv.row * ref_stride + mv.col],
519             ref_stride);
520       }
521 
522       // Rank center_mv using sad.
523       if (refmv_count > 1) {
524         qsort(center_mvs, refmv_count, sizeof(center_mvs[0]), compare_sad);
525       }
526       refmv_count = AOMMIN(4 - cpi->sf.tpl_sf.prune_starting_mv, refmv_count);
527       // Further reduce number of refmv based on sad difference.
528       if (refmv_count > 1) {
529         int last_sad = center_mvs[refmv_count - 1].sad;
530         int second_to_last_sad = center_mvs[refmv_count - 2].sad;
531         if ((last_sad - second_to_last_sad) * 5 > second_to_last_sad)
532           refmv_count--;
533       }
534     }
535 
536     for (idx = 0; idx < refmv_count; ++idx) {
537       int_mv this_mv;
538       uint32_t thissme = motion_estimation(cpi, x, src_mb_buffer, ref_mb,
539                                            src_stride, ref_stride, bsize,
540                                            center_mvs[idx].mv.as_mv, &this_mv);
541 
542       if (thissme < bestsme) {
543         bestsme = thissme;
544         best_rfidx_mv = this_mv;
545       }
546     }
547 
548     tpl_stats->mv[rf_idx].as_int = best_rfidx_mv.as_int;
549     single_mv[rf_idx] = best_rfidx_mv;
550 
551     struct buf_2d ref_buf = { NULL, ref_frame_ptr->y_buffer,
552                               ref_frame_ptr->y_width, ref_frame_ptr->y_height,
553                               ref_frame_ptr->y_stride };
554     InterPredParams inter_pred_params;
555     av1_init_inter_params(&inter_pred_params, bw, bh, mi_row * MI_SIZE,
556                           mi_col * MI_SIZE, 0, 0, xd->bd, is_cur_buf_hbd(xd), 0,
557                           &tpl_data->sf, &ref_buf, kernel);
558     inter_pred_params.conv_params = get_conv_params(0, 0, xd->bd);
559 
560     av1_enc_build_one_inter_predictor(predictor, bw, &best_rfidx_mv.as_mv,
561                                       &inter_pred_params);
562 
563     inter_cost = tpl_get_satd_cost(x, src_diff, bw, src_mb_buffer, src_stride,
564                                    predictor, bw, coeff, bw, bh, tx_size);
565     // Store inter cost for each ref frame
566     tpl_stats->pred_error[rf_idx] = AOMMAX(1, inter_cost);
567 
568     if (inter_cost < best_inter_cost) {
569       best_rf_idx = rf_idx;
570 
571       best_inter_cost = inter_cost;
572       best_mv[0].as_int = best_rfidx_mv.as_int;
573       if (best_inter_cost < best_intra_cost) {
574         best_mode = NEWMV;
575         xd->mi[0]->ref_frame[0] = best_rf_idx + LAST_FRAME;
576         xd->mi[0]->mv[0].as_int = best_mv[0].as_int;
577       }
578     }
579   }
580 
581   int comp_ref_frames[3][2] = {
582     { 0, 4 },
583     { 0, 6 },
584     { 3, 6 },
585   };
586 
587   xd->mi_row = mi_row;
588   xd->mi_col = mi_col;
589   int best_cmp_rf_idx = -1;
590   for (int cmp_rf_idx = 0; cmp_rf_idx < 3 && cpi->sf.tpl_sf.allow_compound_pred;
591        ++cmp_rf_idx) {
592     int rf_idx0 = comp_ref_frames[cmp_rf_idx][0];
593     int rf_idx1 = comp_ref_frames[cmp_rf_idx][1];
594 
595     if (tpl_data->ref_frame[rf_idx0] == NULL ||
596         tpl_data->src_ref_frame[rf_idx0] == NULL ||
597         tpl_data->ref_frame[rf_idx1] == NULL ||
598         tpl_data->src_ref_frame[rf_idx1] == NULL) {
599       continue;
600     }
601 
602     const YV12_BUFFER_CONFIG *ref_frame_ptr[2] = {
603       tpl_data->src_ref_frame[rf_idx0],
604       tpl_data->src_ref_frame[rf_idx1],
605     };
606 
607     xd->mi[0]->ref_frame[0] = LAST_FRAME;
608     xd->mi[0]->ref_frame[1] = ALTREF_FRAME;
609 
610     struct buf_2d yv12_mb[2][MAX_MB_PLANE];
611     for (int i = 0; i < 2; ++i) {
612       av1_setup_pred_block(xd, yv12_mb[i], ref_frame_ptr[i],
613                            xd->block_ref_scale_factors[i],
614                            xd->block_ref_scale_factors[i], MAX_MB_PLANE);
615       for (int plane = 0; plane < MAX_MB_PLANE; ++plane) {
616         xd->plane[plane].pre[i] = yv12_mb[i][plane];
617       }
618     }
619 
620     int_mv tmp_mv[2] = { single_mv[rf_idx0], single_mv[rf_idx1] };
621     int rate_mv;
622     av1_joint_motion_search(cpi, x, bsize, tmp_mv, NULL, 0, &rate_mv, 1);
623 
624     for (int ref = 0; ref < 2; ++ref) {
625       struct buf_2d ref_buf = { NULL, ref_frame_ptr[ref]->y_buffer,
626                                 ref_frame_ptr[ref]->y_width,
627                                 ref_frame_ptr[ref]->y_height,
628                                 ref_frame_ptr[ref]->y_stride };
629       InterPredParams inter_pred_params;
630       av1_init_inter_params(&inter_pred_params, bw, bh, mi_row * MI_SIZE,
631                             mi_col * MI_SIZE, 0, 0, xd->bd, is_cur_buf_hbd(xd),
632                             0, &tpl_data->sf, &ref_buf, kernel);
633       av1_init_comp_mode(&inter_pred_params);
634 
635       inter_pred_params.conv_params = get_conv_params_no_round(
636           ref, 0, xd->tmp_conv_dst, MAX_SB_SIZE, 1, xd->bd);
637 
638       av1_enc_build_one_inter_predictor(predictor, bw, &tmp_mv[ref].as_mv,
639                                         &inter_pred_params);
640     }
641     inter_cost = tpl_get_satd_cost(x, src_diff, bw, src_mb_buffer, src_stride,
642                                    predictor, bw, coeff, bw, bh, tx_size);
643     if (inter_cost < best_inter_cost) {
644       best_cmp_rf_idx = cmp_rf_idx;
645       best_inter_cost = inter_cost;
646       best_mv[0] = tmp_mv[0];
647       best_mv[1] = tmp_mv[1];
648 
649       if (best_inter_cost < best_intra_cost) {
650         best_mode = NEW_NEWMV;
651         xd->mi[0]->ref_frame[0] = rf_idx0 + LAST_FRAME;
652         xd->mi[0]->ref_frame[1] = rf_idx1 + LAST_FRAME;
653       }
654     }
655   }
656 
657   if (best_inter_cost < INT64_MAX) {
658     xd->mi[0]->mv[0].as_int = best_mv[0].as_int;
659     xd->mi[0]->mv[1].as_int = best_mv[1].as_int;
660     const YV12_BUFFER_CONFIG *ref_frame_ptr[2] = {
661       best_cmp_rf_idx >= 0
662           ? tpl_data->src_ref_frame[comp_ref_frames[best_cmp_rf_idx][0]]
663           : tpl_data->src_ref_frame[best_rf_idx],
664       best_cmp_rf_idx >= 0
665           ? tpl_data->src_ref_frame[comp_ref_frames[best_cmp_rf_idx][1]]
666           : NULL,
667     };
668     int rate_cost = 1;
669     get_rate_distortion(&rate_cost, &recon_error, src_diff, coeff, qcoeff,
670                         dqcoeff, cm, x, ref_frame_ptr, rec_buffer_pool,
671                         rec_stride_pool, tx_size, best_mode, mi_row, mi_col);
672     tpl_stats->srcrf_rate = rate_cost << TPL_DEP_COST_SCALE_LOG2;
673   }
674 
675   best_intra_cost = AOMMAX(best_intra_cost, 1);
676   best_inter_cost = AOMMIN(best_intra_cost, best_inter_cost);
677   tpl_stats->inter_cost = best_inter_cost << TPL_DEP_COST_SCALE_LOG2;
678   tpl_stats->intra_cost = best_intra_cost << TPL_DEP_COST_SCALE_LOG2;
679 
680   tpl_stats->srcrf_dist = recon_error << (TPL_DEP_COST_SCALE_LOG2);
681 
682   // Final encode
683   int rate_cost = 0;
684   const YV12_BUFFER_CONFIG *ref_frame_ptr[2];
685 
686   ref_frame_ptr[0] =
687       best_mode == NEW_NEWMV
688           ? tpl_data->ref_frame[comp_ref_frames[best_cmp_rf_idx][0]]
689           : best_rf_idx >= 0 ? tpl_data->ref_frame[best_rf_idx] : NULL;
690   ref_frame_ptr[1] =
691       best_mode == NEW_NEWMV
692           ? tpl_data->ref_frame[comp_ref_frames[best_cmp_rf_idx][1]]
693           : NULL;
694   get_rate_distortion(&rate_cost, &recon_error, src_diff, coeff, qcoeff,
695                       dqcoeff, cm, x, ref_frame_ptr, rec_buffer_pool,
696                       rec_stride_pool, tx_size, best_mode, mi_row, mi_col);
697 
698   tpl_stats->recrf_dist = recon_error << (TPL_DEP_COST_SCALE_LOG2);
699   tpl_stats->recrf_rate = rate_cost << TPL_DEP_COST_SCALE_LOG2;
700   if (!is_inter_mode(best_mode)) {
701     tpl_stats->srcrf_dist = recon_error << (TPL_DEP_COST_SCALE_LOG2);
702     tpl_stats->srcrf_rate = rate_cost << TPL_DEP_COST_SCALE_LOG2;
703   }
704 
705   tpl_stats->recrf_dist = AOMMAX(tpl_stats->srcrf_dist, tpl_stats->recrf_dist);
706   tpl_stats->recrf_rate = AOMMAX(tpl_stats->srcrf_rate, tpl_stats->recrf_rate);
707 
708   if (best_mode == NEW_NEWMV) {
709     ref_frame_ptr[0] = tpl_data->ref_frame[comp_ref_frames[best_cmp_rf_idx][0]];
710     ref_frame_ptr[1] =
711         tpl_data->src_ref_frame[comp_ref_frames[best_cmp_rf_idx][1]];
712     get_rate_distortion(&rate_cost, &recon_error, src_diff, coeff, qcoeff,
713                         dqcoeff, cm, x, ref_frame_ptr, rec_buffer_pool,
714                         rec_stride_pool, tx_size, best_mode, mi_row, mi_col);
715     tpl_stats->cmp_recrf_dist[0] = recon_error << TPL_DEP_COST_SCALE_LOG2;
716     tpl_stats->cmp_recrf_rate[0] = rate_cost << TPL_DEP_COST_SCALE_LOG2;
717 
718     tpl_stats->cmp_recrf_dist[0] =
719         AOMMAX(tpl_stats->srcrf_dist, tpl_stats->cmp_recrf_dist[0]);
720     tpl_stats->cmp_recrf_rate[0] =
721         AOMMAX(tpl_stats->srcrf_rate, tpl_stats->cmp_recrf_rate[0]);
722 
723     tpl_stats->cmp_recrf_dist[0] =
724         AOMMIN(tpl_stats->recrf_dist, tpl_stats->cmp_recrf_dist[0]);
725     tpl_stats->cmp_recrf_rate[0] =
726         AOMMIN(tpl_stats->recrf_rate, tpl_stats->cmp_recrf_rate[0]);
727 
728     rate_cost = 0;
729     ref_frame_ptr[0] =
730         tpl_data->src_ref_frame[comp_ref_frames[best_cmp_rf_idx][0]];
731     ref_frame_ptr[1] = tpl_data->ref_frame[comp_ref_frames[best_cmp_rf_idx][1]];
732     get_rate_distortion(&rate_cost, &recon_error, src_diff, coeff, qcoeff,
733                         dqcoeff, cm, x, ref_frame_ptr, rec_buffer_pool,
734                         rec_stride_pool, tx_size, best_mode, mi_row, mi_col);
735     tpl_stats->cmp_recrf_dist[1] = recon_error << TPL_DEP_COST_SCALE_LOG2;
736     tpl_stats->cmp_recrf_rate[1] = rate_cost << TPL_DEP_COST_SCALE_LOG2;
737 
738     tpl_stats->cmp_recrf_dist[1] =
739         AOMMAX(tpl_stats->srcrf_dist, tpl_stats->cmp_recrf_dist[1]);
740     tpl_stats->cmp_recrf_rate[1] =
741         AOMMAX(tpl_stats->srcrf_rate, tpl_stats->cmp_recrf_rate[1]);
742 
743     tpl_stats->cmp_recrf_dist[1] =
744         AOMMIN(tpl_stats->recrf_dist, tpl_stats->cmp_recrf_dist[1]);
745     tpl_stats->cmp_recrf_rate[1] =
746         AOMMIN(tpl_stats->recrf_rate, tpl_stats->cmp_recrf_rate[1]);
747   }
748 
749   if (best_mode == NEWMV) {
750     tpl_stats->mv[best_rf_idx] = best_mv[0];
751     tpl_stats->ref_frame_index[0] = best_rf_idx;
752     tpl_stats->ref_frame_index[1] = NONE_FRAME;
753   } else if (best_mode == NEW_NEWMV) {
754     tpl_stats->ref_frame_index[0] = comp_ref_frames[best_cmp_rf_idx][0];
755     tpl_stats->ref_frame_index[1] = comp_ref_frames[best_cmp_rf_idx][1];
756     tpl_stats->mv[tpl_stats->ref_frame_index[0]] = best_mv[0];
757     tpl_stats->mv[tpl_stats->ref_frame_index[1]] = best_mv[1];
758   }
759 
760   for (int idy = 0; idy < mi_height; ++idy) {
761     for (int idx = 0; idx < mi_width; ++idx) {
762       if ((xd->mb_to_right_edge >> (3 + MI_SIZE_LOG2)) + mi_width > idx &&
763           (xd->mb_to_bottom_edge >> (3 + MI_SIZE_LOG2)) + mi_height > idy) {
764         xd->mi[idx + idy * cm->mi_params.mi_stride] = xd->mi[0];
765       }
766     }
767   }
768 
769   // Free temporary buffers.
770   aom_free(predictor8);
771   aom_free(src_diff);
772   aom_free(coeff);
773   aom_free(qcoeff);
774   aom_free(dqcoeff);
775 }
776 
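// Divides ref_pos by bsize_pix rounding toward negative infinity, e.g.
// round_floor(-1, 16) == -1 whereas plain integer division would give 0.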
777 static int round_floor(int ref_pos, int bsize_pix) {
778   int round;
779   if (ref_pos < 0)
780     round = -(1 + (-ref_pos - 1) / bsize_pix);
781   else
782     round = ref_pos / bsize_pix;
783 
784   return round;
785 }
786 
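// Returns the pixel overlap between the motion-compensated block anchored at
// (ref_pos_row, ref_pos_col) and one of the four grid-aligned blocks it may
// straddle; 'block' selects the quadrant (0: top-left, 1: top-right,
// 2: bottom-left, 3: bottom-right).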
787 static int get_overlap_area(int grid_pos_row, int grid_pos_col, int ref_pos_row,
788                             int ref_pos_col, int block, BLOCK_SIZE bsize) {
789   int width = 0, height = 0;
790   int bw = 4 << mi_size_wide_log2[bsize];
791   int bh = 4 << mi_size_high_log2[bsize];
792 
793   switch (block) {
794     case 0:
795       width = grid_pos_col + bw - ref_pos_col;
796       height = grid_pos_row + bh - ref_pos_row;
797       break;
798     case 1:
799       width = ref_pos_col + bw - grid_pos_col;
800       height = grid_pos_row + bh - ref_pos_row;
801       break;
802     case 2:
803       width = grid_pos_col + bw - ref_pos_col;
804       height = ref_pos_row + bh - grid_pos_row;
805       break;
806     case 3:
807       width = ref_pos_col + bw - grid_pos_col;
808       height = ref_pos_row + bh - grid_pos_row;
809       break;
810     default: assert(0);
811   }
812 
813   return width * height;
814 }
815 
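// Maps an mi-unit position to an index in the tpl stats buffer. For example,
// with right_shift == 2 (one stats entry per 4x4 group of mi units) and
// stride == 8, (mi_row, mi_col) == (6, 10) maps to (6 >> 2) * 8 + (10 >> 2)
// == 10.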
816 int av1_tpl_ptr_pos(int mi_row, int mi_col, int stride, uint8_t right_shift) {
817   return (mi_row >> right_shift) * stride + (mi_col >> right_shift);
818 }
819 
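// Translates a rate difference (delta_rate, in the scaled tpl cost domain)
// into the rate cost that is propagated to the reference frame. The scaling
// is driven by beta = srcrf_dist / recrf_dist under a 0.5 * log2() style rate
// model; when srcrf_dist is very small (<= 128), delta_rate is returned
// unchanged.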
820 static int64_t delta_rate_cost(int64_t delta_rate, int64_t recrf_dist,
821                                int64_t srcrf_dist, int pix_num) {
822   double beta = (double)srcrf_dist / recrf_dist;
823   int64_t rate_cost = delta_rate;
824 
825   if (srcrf_dist <= 128) return rate_cost;
826 
827   double dr =
828       (double)(delta_rate >> (TPL_DEP_COST_SCALE_LOG2 + AV1_PROB_COST_SHIFT)) /
829       pix_num;
830 
831   double log_den = log(beta) / log(2.0) + 2.0 * dr;
832 
833   if (log_den > log(10.0) / log(2.0)) {
834     rate_cost = (int64_t)((log(1.0 / beta) * pix_num) / log(2.0) / 2.0);
835     rate_cost <<= (TPL_DEP_COST_SCALE_LOG2 + AV1_PROB_COST_SHIFT);
836     return rate_cost;
837   }
838 
839   double num = pow(2.0, log_den);
840   double den = num * beta + (1 - beta) * beta;
841 
842   rate_cost = (int64_t)((pix_num * log(num / den)) / log(2.0) / 2.0);
843 
844   rate_cost <<= (TPL_DEP_COST_SCALE_LOG2 + AV1_PROB_COST_SHIFT);
845 
846   return rate_cost;
847 }
848 
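// Back-propagates the dependency statistics of the tpl block at
// (mi_row, mi_col) to the blocks it predicts from in reference 'ref'. The
// motion-compensated source area generally straddles up to four grid-aligned
// blocks in the reference tpl frame; each one is credited with a share of the
// current block's dependent distortion and rate proportional to the pixel
// overlap.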
849 static AOM_INLINE void tpl_model_update_b(TplParams *const tpl_data, int mi_row,
850                                           int mi_col, const BLOCK_SIZE bsize,
851                                           int frame_idx, int ref) {
852   TplDepFrame *tpl_frame_ptr = &tpl_data->tpl_frame[frame_idx];
853   TplDepStats *tpl_ptr = tpl_frame_ptr->tpl_stats_ptr;
854   TplDepFrame *tpl_frame = tpl_data->tpl_frame;
855   const uint8_t block_mis_log2 = tpl_data->tpl_stats_block_mis_log2;
856   TplDepStats *tpl_stats_ptr = &tpl_ptr[av1_tpl_ptr_pos(
857       mi_row, mi_col, tpl_frame->stride, block_mis_log2)];
858 
859   int is_compound = tpl_stats_ptr->ref_frame_index[1] >= 0;
860 
861   if (tpl_stats_ptr->ref_frame_index[ref] < 0) return;
862   const int ref_frame_index = tpl_stats_ptr->ref_frame_index[ref];
863   TplDepFrame *ref_tpl_frame =
864       &tpl_frame[tpl_frame[frame_idx].ref_map_index[ref_frame_index]];
865   TplDepStats *ref_stats_ptr = ref_tpl_frame->tpl_stats_ptr;
866 
867   if (tpl_frame[frame_idx].ref_map_index[ref_frame_index] < 0) return;
868 
869   const FULLPEL_MV full_mv =
870       get_fullmv_from_mv(&tpl_stats_ptr->mv[ref_frame_index].as_mv);
871   const int ref_pos_row = mi_row * MI_SIZE + full_mv.row;
872   const int ref_pos_col = mi_col * MI_SIZE + full_mv.col;
873 
874   const int bw = 4 << mi_size_wide_log2[bsize];
875   const int bh = 4 << mi_size_high_log2[bsize];
876   const int mi_height = mi_size_high[bsize];
877   const int mi_width = mi_size_wide[bsize];
878   const int pix_num = bw * bh;
879 
880   // Top-left grid block position, in pixels.
881   int grid_pos_row_base = round_floor(ref_pos_row, bh) * bh;
882   int grid_pos_col_base = round_floor(ref_pos_col, bw) * bw;
883   int block;
884 
885   int64_t srcrf_dist = is_compound ? tpl_stats_ptr->cmp_recrf_dist[!ref]
886                                    : tpl_stats_ptr->srcrf_dist;
887   int64_t srcrf_rate = is_compound ? tpl_stats_ptr->cmp_recrf_rate[!ref]
888                                    : tpl_stats_ptr->srcrf_rate;
889 
890   int64_t cur_dep_dist = tpl_stats_ptr->recrf_dist - srcrf_dist;
891   int64_t mc_dep_dist =
892       (int64_t)(tpl_stats_ptr->mc_dep_dist *
893                 ((double)(tpl_stats_ptr->recrf_dist - srcrf_dist) /
894                  tpl_stats_ptr->recrf_dist));
895   int64_t delta_rate = tpl_stats_ptr->recrf_rate - srcrf_rate;
896   int64_t mc_dep_rate =
897       delta_rate_cost(tpl_stats_ptr->mc_dep_rate, tpl_stats_ptr->recrf_dist,
898                       srcrf_dist, pix_num);
899 
900   for (block = 0; block < 4; ++block) {
901     int grid_pos_row = grid_pos_row_base + bh * (block >> 1);
902     int grid_pos_col = grid_pos_col_base + bw * (block & 0x01);
903 
904     if (grid_pos_row >= 0 && grid_pos_row < ref_tpl_frame->mi_rows * MI_SIZE &&
905         grid_pos_col >= 0 && grid_pos_col < ref_tpl_frame->mi_cols * MI_SIZE) {
906       int overlap_area = get_overlap_area(
907           grid_pos_row, grid_pos_col, ref_pos_row, ref_pos_col, block, bsize);
908       int ref_mi_row = round_floor(grid_pos_row, bh) * mi_height;
909       int ref_mi_col = round_floor(grid_pos_col, bw) * mi_width;
910       const int step = 1 << block_mis_log2;
911 
912       for (int idy = 0; idy < mi_height; idy += step) {
913         for (int idx = 0; idx < mi_width; idx += step) {
914           TplDepStats *des_stats = &ref_stats_ptr[av1_tpl_ptr_pos(
915               ref_mi_row + idy, ref_mi_col + idx, ref_tpl_frame->stride,
916               block_mis_log2)];
917           des_stats->mc_dep_dist +=
918               ((cur_dep_dist + mc_dep_dist) * overlap_area) / pix_num;
919           des_stats->mc_dep_rate +=
920               ((delta_rate + mc_dep_rate) * overlap_area) / pix_num;
921 
922           assert(overlap_area >= 0);
923         }
924       }
925     }
926   }
927 }
928 
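// Back-propagates tpl statistics for every tpl-stats sub-block of the coding
// block at (mi_row, mi_col), once for each of the two possible references.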
929 static AOM_INLINE void tpl_model_update(TplParams *const tpl_data, int mi_row,
930                                         int mi_col, const BLOCK_SIZE bsize,
931                                         int frame_idx) {
932   const int mi_height = mi_size_high[bsize];
933   const int mi_width = mi_size_wide[bsize];
934   const int step = 1 << tpl_data->tpl_stats_block_mis_log2;
935   const BLOCK_SIZE tpl_stats_block_size =
936       convert_length_to_bsize(MI_SIZE << tpl_data->tpl_stats_block_mis_log2);
937 
938   for (int idy = 0; idy < mi_height; idy += step) {
939     for (int idx = 0; idx < mi_width; idx += step) {
940       tpl_model_update_b(tpl_data, mi_row + idy, mi_col + idx,
941                          tpl_stats_block_size, frame_idx, 0);
942       tpl_model_update_b(tpl_data, mi_row + idy, mi_col + idx,
943                          tpl_stats_block_size, frame_idx, 1);
944     }
945   }
946 }
947 
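// Copies the per-block statistics into every tpl-stats entry covered by the
// block. Costs, rates and distortions are first normalized by the number of
// tpl-stats units in the block and clamped to a minimum of 1; MVs, prediction
// errors and reference indices are copied as-is.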
948 static AOM_INLINE void tpl_model_store(TplDepStats *tpl_stats_ptr, int mi_row,
949                                        int mi_col, BLOCK_SIZE bsize, int stride,
950                                        const TplDepStats *src_stats,
951                                        uint8_t block_mis_log2) {
952   const int mi_height = mi_size_high[bsize];
953   const int mi_width = mi_size_wide[bsize];
954   const int step = 1 << block_mis_log2;
955   const int div = (mi_height >> block_mis_log2) * (mi_width >> block_mis_log2);
956 
957   int64_t intra_cost = src_stats->intra_cost / div;
958   int64_t inter_cost = src_stats->inter_cost / div;
959   int64_t srcrf_dist = src_stats->srcrf_dist / div;
960   int64_t recrf_dist = src_stats->recrf_dist / div;
961   int64_t srcrf_rate = src_stats->srcrf_rate / div;
962   int64_t recrf_rate = src_stats->recrf_rate / div;
963   int64_t cmp_recrf_dist[2] = {
964     src_stats->cmp_recrf_dist[0] / div,
965     src_stats->cmp_recrf_dist[1] / div,
966   };
967   int64_t cmp_recrf_rate[2] = {
968     src_stats->cmp_recrf_rate[0] / div,
969     src_stats->cmp_recrf_rate[1] / div,
970   };
971 
972   intra_cost = AOMMAX(1, intra_cost);
973   inter_cost = AOMMAX(1, inter_cost);
974   srcrf_dist = AOMMAX(1, srcrf_dist);
975   recrf_dist = AOMMAX(1, recrf_dist);
976   srcrf_rate = AOMMAX(1, srcrf_rate);
977   recrf_rate = AOMMAX(1, recrf_rate);
978   cmp_recrf_dist[0] = AOMMAX(1, cmp_recrf_dist[0]);
979   cmp_recrf_dist[1] = AOMMAX(1, cmp_recrf_dist[1]);
980   cmp_recrf_rate[0] = AOMMAX(1, cmp_recrf_rate[0]);
981   cmp_recrf_rate[1] = AOMMAX(1, cmp_recrf_rate[1]);
982 
983   for (int idy = 0; idy < mi_height; idy += step) {
984     TplDepStats *tpl_ptr = &tpl_stats_ptr[av1_tpl_ptr_pos(
985         mi_row + idy, mi_col, stride, block_mis_log2)];
986     for (int idx = 0; idx < mi_width; idx += step) {
987       tpl_ptr->intra_cost = intra_cost;
988       tpl_ptr->inter_cost = inter_cost;
989       tpl_ptr->srcrf_dist = srcrf_dist;
990       tpl_ptr->recrf_dist = recrf_dist;
991       tpl_ptr->srcrf_rate = srcrf_rate;
992       tpl_ptr->recrf_rate = recrf_rate;
993       tpl_ptr->cmp_recrf_dist[0] = cmp_recrf_dist[0];
994       tpl_ptr->cmp_recrf_dist[1] = cmp_recrf_dist[1];
995       tpl_ptr->cmp_recrf_rate[0] = cmp_recrf_rate[0];
996       tpl_ptr->cmp_recrf_rate[1] = cmp_recrf_rate[1];
997       memcpy(tpl_ptr->mv, src_stats->mv, sizeof(tpl_ptr->mv));
998       memcpy(tpl_ptr->pred_error, src_stats->pred_error,
999              sizeof(tpl_ptr->pred_error));
1000       tpl_ptr->ref_frame_index[0] = src_stats->ref_frame_index[0];
1001       tpl_ptr->ref_frame_index[1] = src_stats->ref_frame_index[1];
1002       ++tpl_ptr;
1003     }
1004   }
1005 }
1006 
1007 // Reset the ref and source frame pointers of tpl_data.
1008 static AOM_INLINE void tpl_reset_src_ref_frames(TplParams *tpl_data) {
1009   for (int i = 0; i < INTER_REFS_PER_FRAME; ++i) {
1010     tpl_data->ref_frame[i] = NULL;
1011     tpl_data->src_ref_frame[i] = NULL;
1012   }
1013 }
1014 
1015 static AOM_INLINE int get_gop_length(const GF_GROUP *gf_group) {
1016   int gop_length = AOMMIN(gf_group->size, MAX_TPL_FRAME_IDX - 1);
1017   return gop_length;
1018 }
1019 
1020 // Initialize the mc_flow parameters used in computing tpl data.
1021 static AOM_INLINE void init_mc_flow_dispenser(AV1_COMP *cpi, int frame_idx,
1022                                               int pframe_qindex) {
1023   TplParams *const tpl_data = &cpi->tpl_data;
1024   TplDepFrame *tpl_frame = &tpl_data->tpl_frame[frame_idx];
1025   const YV12_BUFFER_CONFIG *this_frame = tpl_frame->gf_picture;
1026   const YV12_BUFFER_CONFIG *ref_frames_ordered[INTER_REFS_PER_FRAME];
1027   uint32_t ref_frame_display_indices[INTER_REFS_PER_FRAME];
1028   GF_GROUP *gf_group = &cpi->gf_group;
1029   int ref_pruning_enabled = is_frame_eligible_for_ref_pruning(
1030       gf_group, cpi->sf.inter_sf.selective_ref_frame,
1031       cpi->sf.tpl_sf.prune_ref_frames_in_tpl, frame_idx);
1032   int gop_length = get_gop_length(gf_group);
1033   int ref_frame_flags;
1034   AV1_COMMON *cm = &cpi->common;
1035   int rdmult, idx;
1036   ThreadData *td = &cpi->td;
1037   MACROBLOCK *x = &td->mb;
1038   MACROBLOCKD *xd = &x->e_mbd;
1039   tpl_data->frame_idx = frame_idx;
1040   tpl_reset_src_ref_frames(tpl_data);
1041   av1_tile_init(&xd->tile, cm, 0, 0);
1042 
1043   // Setup scaling factor
1044   av1_setup_scale_factors_for_frame(
1045       &tpl_data->sf, this_frame->y_crop_width, this_frame->y_crop_height,
1046       this_frame->y_crop_width, this_frame->y_crop_height);
1047 
1048   xd->cur_buf = this_frame;
1049 
1050   for (idx = 0; idx < INTER_REFS_PER_FRAME; ++idx) {
1051     TplDepFrame *tpl_ref_frame =
1052         &tpl_data->tpl_frame[tpl_frame->ref_map_index[idx]];
1053     tpl_data->ref_frame[idx] = tpl_ref_frame->rec_picture;
1054     tpl_data->src_ref_frame[idx] = tpl_ref_frame->gf_picture;
1055     ref_frame_display_indices[idx] = tpl_ref_frame->frame_display_index;
1056   }
1057 
1058   // Store the reference frames based on priority order
1059   for (int i = 0; i < INTER_REFS_PER_FRAME; ++i) {
1060     ref_frames_ordered[i] =
1061         tpl_data->ref_frame[ref_frame_priority_order[i] - 1];
1062   }
1063 
1064   // Work out which reference frame slots may be used.
1065   ref_frame_flags = get_ref_frame_flags(&cpi->sf, ref_frames_ordered,
1066                                         cpi->ext_flags.ref_frame_flags);
1067 
1068   enforce_max_ref_frames(cpi, &ref_frame_flags, ref_frame_display_indices,
1069                          tpl_frame->frame_display_index);
1070 
1071   // Prune reference frames
1072   for (idx = 0; idx < INTER_REFS_PER_FRAME; ++idx) {
1073     if ((ref_frame_flags & (1 << idx)) == 0) {
1074       tpl_data->ref_frame[idx] = NULL;
1075     }
1076   }
1077 
1078   // Skip motion estimation w.r.t. reference frames which are not
1079   // considered in RD search, using "selective_ref_frame" speed feature.
1080   // The reference frame pruning is not enabled for frames beyond the gop
1081   // length, as there are fewer reference frames and the reference frames
1082   // differ from the frames considered during RD search.
1083   if (ref_pruning_enabled && (frame_idx < gop_length)) {
1084     for (idx = 0; idx < INTER_REFS_PER_FRAME; ++idx) {
1085       const MV_REFERENCE_FRAME refs[2] = { idx + 1, NONE_FRAME };
1086       if (prune_ref_by_selective_ref_frame(cpi, NULL, refs,
1087                                            ref_frame_display_indices)) {
1088         tpl_data->ref_frame[idx] = NULL;
1089       }
1090     }
1091   }
1092 
1093   // Make a temporary mbmi for tpl model
1094   MB_MODE_INFO mbmi;
1095   memset(&mbmi, 0, sizeof(mbmi));
1096   MB_MODE_INFO *mbmi_ptr = &mbmi;
1097   xd->mi = &mbmi_ptr;
1098 
1099   xd->block_ref_scale_factors[0] = &tpl_data->sf;
1100   xd->block_ref_scale_factors[1] = &tpl_data->sf;
1101 
1102   const int base_qindex = pframe_qindex;
1103   // Get rd multiplier set up.
1104   rdmult = (int)av1_compute_rd_mult(cpi, base_qindex);
1105   if (rdmult < 1) rdmult = 1;
1106   av1_set_error_per_bit(&x->errorperbit, rdmult);
1107   av1_set_sad_per_bit(cpi, &x->sadperbit, base_qindex);
1108 
1109   tpl_frame->is_valid = 1;
1110 
1111   cm->quant_params.base_qindex = base_qindex;
1112   av1_frame_init_quantizer(cpi);
1113 
1114   tpl_frame->base_rdmult =
1115       av1_compute_rd_mult_based_on_qindex(cpi, pframe_qindex) / 6;
1116 }
1117 
1118 // This function stores the motion estimation dependencies of all the blocks in
1119 // a row
1120 void av1_mc_flow_dispenser_row(AV1_COMP *cpi, MACROBLOCK *x, int mi_row,
1121                                BLOCK_SIZE bsize, TX_SIZE tx_size) {
1122   AV1_COMMON *const cm = &cpi->common;
1123   MultiThreadInfo *const mt_info = &cpi->mt_info;
1124   AV1TplRowMultiThreadInfo *const tpl_row_mt = &mt_info->tpl_row_mt;
1125   const CommonModeInfoParams *const mi_params = &cm->mi_params;
1126   const int mi_width = mi_size_wide[bsize];
1127   TplParams *const tpl_data = &cpi->tpl_data;
1128   TplDepFrame *tpl_frame = &tpl_data->tpl_frame[tpl_data->frame_idx];
1129   MACROBLOCKD *xd = &x->e_mbd;
1130 
1131   const int tplb_cols_in_tile =
1132       ROUND_POWER_OF_TWO(mi_params->mi_cols, mi_size_wide_log2[bsize]);
1133   const int tplb_row = ROUND_POWER_OF_TWO(mi_row, mi_size_high_log2[bsize]);
1134 
1135   for (int mi_col = 0, tplb_col_in_tile = 0; mi_col < mi_params->mi_cols;
1136        mi_col += mi_width, tplb_col_in_tile++) {
1137     (*tpl_row_mt->sync_read_ptr)(&tpl_data->tpl_mt_sync, tplb_row,
1138                                  tplb_col_in_tile);
1139     TplDepStats tpl_stats;
1140 
1141     // Motion estimation column boundary
1142     av1_set_mv_col_limits(mi_params, &x->mv_limits, mi_col, mi_width,
1143                           tpl_data->border_in_pixels);
1144     xd->mb_to_left_edge = -GET_MV_SUBPEL(mi_col * MI_SIZE);
1145     xd->mb_to_right_edge =
1146         GET_MV_SUBPEL(mi_params->mi_cols - mi_width - mi_col);
1147     mode_estimation(cpi, x, mi_row, mi_col, bsize, tx_size, &tpl_stats);
1148 
1149     // Motion flow dependency dispenser.
1150     tpl_model_store(tpl_frame->tpl_stats_ptr, mi_row, mi_col, bsize,
1151                     tpl_frame->stride, &tpl_stats,
1152                     tpl_data->tpl_stats_block_mis_log2);
1153     (*tpl_row_mt->sync_write_ptr)(&tpl_data->tpl_mt_sync, tplb_row,
1154                                   tplb_col_in_tile, tplb_cols_in_tile);
1155   }
1156 }
1157 
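// Runs the tpl mode/motion estimation over every block row of the current
// frame, setting up the vertical motion search limits for each row.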
1158 static AOM_INLINE void mc_flow_dispenser(AV1_COMP *cpi) {
1159   AV1_COMMON *cm = &cpi->common;
1160   const CommonModeInfoParams *const mi_params = &cm->mi_params;
1161   ThreadData *td = &cpi->td;
1162   MACROBLOCK *x = &td->mb;
1163   MACROBLOCKD *xd = &x->e_mbd;
1164   const BLOCK_SIZE bsize = convert_length_to_bsize(cpi->tpl_data.tpl_bsize_1d);
1165   const TX_SIZE tx_size = max_txsize_lookup[bsize];
1166   const int mi_height = mi_size_high[bsize];
1167   for (int mi_row = 0; mi_row < mi_params->mi_rows; mi_row += mi_height) {
1168     // Motion estimation row boundary
1169     av1_set_mv_row_limits(mi_params, &x->mv_limits, mi_row, mi_height,
1170                           cpi->tpl_data.border_in_pixels);
1171     xd->mb_to_top_edge = -GET_MV_SUBPEL(mi_row * MI_SIZE);
1172     xd->mb_to_bottom_edge =
1173         GET_MV_SUBPEL((mi_params->mi_rows - mi_height - mi_row) * MI_SIZE);
1174     av1_mc_flow_dispenser_row(cpi, x, mi_row, bsize, tx_size);
1175   }
1176 }
1177 
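// Back-propagates the stored tpl statistics for every block of frame
// 'frame_idx' toward its reference frames (nothing is propagated for
// frame_idx == 0).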
1178 static void mc_flow_synthesizer(AV1_COMP *cpi, int frame_idx) {
1179   AV1_COMMON *cm = &cpi->common;
1180   TplParams *const tpl_data = &cpi->tpl_data;
1181 
1182   const BLOCK_SIZE bsize = convert_length_to_bsize(tpl_data->tpl_bsize_1d);
1183   const int mi_height = mi_size_high[bsize];
1184   const int mi_width = mi_size_wide[bsize];
1185 
1186   for (int mi_row = 0; mi_row < cm->mi_params.mi_rows; mi_row += mi_height) {
1187     for (int mi_col = 0; mi_col < cm->mi_params.mi_cols; mi_col += mi_width) {
1188       if (frame_idx) {
1189         tpl_model_update(tpl_data, mi_row, mi_col, bsize, frame_idx);
1190       }
1191     }
1192   }
1193 }
1194 
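// Sets up the per-frame tpl state for the GOP: gf/rec picture pointers,
// display indices and reference map indices for each frame in the gf group
// (plus a few extension frames beyond it), and records the qindex of the
// regular (LF_UPDATE) frames in *pframe_qindex.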
1195 static AOM_INLINE void init_gop_frames_for_tpl(
1196     AV1_COMP *cpi, const EncodeFrameParams *const init_frame_params,
1197     GF_GROUP *gf_group, int gop_eval, int *tpl_group_frames,
1198     const EncodeFrameInput *const frame_input, int *pframe_qindex) {
1199   AV1_COMMON *cm = &cpi->common;
1200   int cur_frame_idx = gf_group->index;
1201   *pframe_qindex = 0;
1202 
1203   RefBufferStack ref_buffer_stack = cpi->ref_buffer_stack;
1204   EncodeFrameParams frame_params = *init_frame_params;
1205   TplParams *const tpl_data = &cpi->tpl_data;
1206 
1207   int ref_picture_map[REF_FRAMES];
1208 
1209   for (int i = 0; i < REF_FRAMES; ++i) {
1210     if (frame_params.frame_type == KEY_FRAME) {
1211       tpl_data->tpl_frame[-i - 1].gf_picture = NULL;
1212       tpl_data->tpl_frame[-i - 1].rec_picture = NULL;
1213       tpl_data->tpl_frame[-i - 1].frame_display_index = 0;
1214     } else {
1215       tpl_data->tpl_frame[-i - 1].gf_picture = &cm->ref_frame_map[i]->buf;
1216       tpl_data->tpl_frame[-i - 1].rec_picture = &cm->ref_frame_map[i]->buf;
1217       tpl_data->tpl_frame[-i - 1].frame_display_index =
1218           cm->ref_frame_map[i]->display_order_hint;
1219     }
1220 
1221     ref_picture_map[i] = -i - 1;
1222   }
1223 
1224   *tpl_group_frames = cur_frame_idx;
1225 
1226   int gf_index;
1227   int anc_frame_offset = gop_eval ? 0 : gf_group->cur_frame_idx[cur_frame_idx];
1228   int process_frame_count = 0;
1229   const int gop_length = get_gop_length(gf_group);
1230 
1231   for (gf_index = cur_frame_idx; gf_index < gop_length; ++gf_index) {
1232     TplDepFrame *tpl_frame = &tpl_data->tpl_frame[gf_index];
1233     FRAME_UPDATE_TYPE frame_update_type = gf_group->update_type[gf_index];
1234     int frame_display_index = gf_index == gf_group->size
1235                                   ? cpi->rc.baseline_gf_interval
1236                                   : gf_group->cur_frame_idx[gf_index] +
1237                                         gf_group->arf_src_offset[gf_index];
1238 
1239     int lookahead_index = frame_display_index - anc_frame_offset;
1240 
1241     frame_params.show_frame = frame_update_type != ARF_UPDATE &&
1242                               frame_update_type != INTNL_ARF_UPDATE;
1243     frame_params.show_existing_frame =
1244         frame_update_type == INTNL_OVERLAY_UPDATE ||
1245         frame_update_type == OVERLAY_UPDATE;
1246     frame_params.frame_type = gf_group->frame_type[gf_index];
1247 
1248     if (frame_update_type == LF_UPDATE)
1249       *pframe_qindex = gf_group->q_val[gf_index];
1250 
1251     struct lookahead_entry *buf;
1252     if (gf_index == cur_frame_idx) {
1253       buf = av1_lookahead_peek(cpi->lookahead, lookahead_index,
1254                                cpi->compressor_stage);
1255       tpl_frame->gf_picture = gop_eval ? &buf->img : frame_input->source;
1256     } else {
1257       buf = av1_lookahead_peek(cpi->lookahead, lookahead_index,
1258                                cpi->compressor_stage);
1259       if (buf == NULL) break;
1260       tpl_frame->gf_picture = &buf->img;
1261     }
1262     if (gop_eval && cpi->rc.frames_since_key > 0 &&
1263         gf_group->arf_index == gf_index)
1264       tpl_frame->gf_picture = &cpi->alt_ref_buffer;
1265 
1266     // 'cm->current_frame.frame_number' is the display number
1267     // of the current frame.
1268     // 'anc_frame_offset' is the number of frames displayed so
1269     // far within the gf group. 'cm->current_frame.frame_number -
1270     // anc_frame_offset' is the offset of the first frame in the gf group.
1271     // 'frame_display_index' is the frame offset within the gf group.
1272     // 'frame_display_index + cm->current_frame.frame_number - anc_frame_offset'
1273     // is the display index of the frame.
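    // For example, with frame_number == 30, anc_frame_offset == 2 and
    // frame_display_index == 5, the display index works out to 5 + 30 - 2 = 33.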
1274     tpl_frame->frame_display_index =
1275         frame_display_index + cm->current_frame.frame_number - anc_frame_offset;
1276     assert(buf->display_idx == cpi->frame_index_set.show_frame_count -
1277                                    anc_frame_offset + frame_display_index);
1278 
1279     if (frame_update_type != OVERLAY_UPDATE &&
1280         frame_update_type != INTNL_OVERLAY_UPDATE) {
1281       tpl_frame->rec_picture = &tpl_data->tpl_rec_pool[process_frame_count];
1282       tpl_frame->tpl_stats_ptr = tpl_data->tpl_stats_pool[process_frame_count];
1283       ++process_frame_count;
1284     }
1285 
1286     av1_get_ref_frames(cpi, &ref_buffer_stack);
1287     int refresh_mask = av1_get_refresh_frame_flags(
1288         cpi, &frame_params, frame_update_type, &ref_buffer_stack);
1289 
1290     int refresh_frame_map_index = av1_get_refresh_ref_frame_map(refresh_mask);
1291     av1_update_ref_frame_map(cpi, frame_update_type, frame_params.frame_type,
1292                              frame_params.show_existing_frame,
1293                              refresh_frame_map_index, &ref_buffer_stack);
1294 
1295     for (int i = LAST_FRAME; i <= ALTREF_FRAME; ++i)
1296       tpl_frame->ref_map_index[i - LAST_FRAME] =
1297           ref_picture_map[cm->remapped_ref_idx[i - LAST_FRAME]];
1298 
1299     if (refresh_mask) ref_picture_map[refresh_frame_map_index] = gf_index;
1300 
1301     ++*tpl_group_frames;
1302   }
1303 
1304   if (cpi->rc.frames_since_key == 0) return;
1305 
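  // Extend the TPL group past the end of the current GF group so that
  // dependencies from frames beyond the group are still captured, but by no
  // more than MAX_TPL_EXTEND frames and never past the next key frame.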
1306   int extend_frame_count = 0;
1307   int extend_frame_length = AOMMIN(
1308       MAX_TPL_EXTEND, cpi->rc.frames_to_key - cpi->rc.baseline_gf_interval);
1309   int frame_display_index = gf_group->cur_frame_idx[gop_length - 1] +
1310                             gf_group->arf_src_offset[gop_length - 1] + 1;
1311 
1312   for (;
1313        gf_index < MAX_TPL_FRAME_IDX && extend_frame_count < extend_frame_length;
1314        ++gf_index) {
1315     TplDepFrame *tpl_frame = &tpl_data->tpl_frame[gf_index];
1316     FRAME_UPDATE_TYPE frame_update_type = LF_UPDATE;
1317     frame_params.show_frame = frame_update_type != ARF_UPDATE &&
1318                               frame_update_type != INTNL_ARF_UPDATE;
1319     frame_params.show_existing_frame =
1320         frame_update_type == INTNL_OVERLAY_UPDATE;
1321     frame_params.frame_type = INTER_FRAME;
1322 
1323     int lookahead_index = frame_display_index - anc_frame_offset;
1324     struct lookahead_entry *buf = av1_lookahead_peek(
1325         cpi->lookahead, lookahead_index, cpi->compressor_stage);
1326 
1327     if (buf == NULL) break;
1328 
1329     tpl_frame->gf_picture = &buf->img;
1330     tpl_frame->rec_picture = &tpl_data->tpl_rec_pool[process_frame_count];
1331     tpl_frame->tpl_stats_ptr = tpl_data->tpl_stats_pool[process_frame_count];
1332     // 'cm->current_frame.frame_number' is the display number
1333     // of the current frame.
1334     // 'anc_frame_offset' is the number of frames displayed so
1335     // far within the gf group. 'cm->current_frame.frame_number -
1336     // anc_frame_offset' is the offset of the first frame in the gf group.
1337     // 'frame_display_index' is the frame offset within the gf group.
1338     // 'frame_display_index + cm->current_frame.frame_number - anc_frame_offset'
1339     // is the display index of the frame.
1340     tpl_frame->frame_display_index =
1341         frame_display_index + cm->current_frame.frame_number - anc_frame_offset;
1342 
1343     ++process_frame_count;
1344 
1345     gf_group->update_type[gf_index] = LF_UPDATE;
1346     gf_group->q_val[gf_index] = *pframe_qindex;
1347 
1348     av1_get_ref_frames(cpi, &ref_buffer_stack);
1349     int refresh_mask = av1_get_refresh_frame_flags(
1350         cpi, &frame_params, frame_update_type, &ref_buffer_stack);
1351     int refresh_frame_map_index = av1_get_refresh_ref_frame_map(refresh_mask);
1352     av1_update_ref_frame_map(cpi, frame_update_type, frame_params.frame_type,
1353                              frame_params.show_existing_frame,
1354                              refresh_frame_map_index, &ref_buffer_stack);
1355 
1356     for (int i = LAST_FRAME; i <= ALTREF_FRAME; ++i)
1357       tpl_frame->ref_map_index[i - LAST_FRAME] =
1358           ref_picture_map[cm->remapped_ref_idx[i - LAST_FRAME]];
1359 
1360     tpl_frame->ref_map_index[ALTREF_FRAME - LAST_FRAME] = -1;
1361     tpl_frame->ref_map_index[LAST3_FRAME - LAST_FRAME] = -1;
1362     tpl_frame->ref_map_index[BWDREF_FRAME - LAST_FRAME] = -1;
1363     tpl_frame->ref_map_index[ALTREF2_FRAME - LAST_FRAME] = -1;
1364 
1365     if (refresh_mask) ref_picture_map[refresh_frame_map_index] = gf_index;
1366 
1367     ++*tpl_group_frames;
1368     ++extend_frame_count;
1369     ++frame_display_index;
1370   }
1371 
1372   av1_get_ref_frames(cpi, &cpi->ref_buffer_stack);
1373 }
1374 
1375 void av1_init_tpl_stats(TplParams *const tpl_data) {
1376   for (int frame_idx = 0; frame_idx < MAX_LAG_BUFFERS; ++frame_idx) {
1377     TplDepFrame *tpl_frame = &tpl_data->tpl_stats_buffer[frame_idx];
1378     if (tpl_data->tpl_stats_pool[frame_idx] == NULL) continue;
1379     memset(tpl_data->tpl_stats_pool[frame_idx], 0,
1380            tpl_frame->height * tpl_frame->width *
1381                sizeof(*tpl_frame->tpl_stats_ptr));
1382     tpl_frame->is_valid = 0;
1383   }
1384 }
1385 
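// Runs the TPL model over the frames of the current GF group: a forward pass
// that estimates per-block intra/inter costs and motion, followed by a
// backward pass that propagates those costs to the reference frames. When
// gop_eval is set, the return value is a flag derived from the ARF dependency
// factors (beta) that the caller can use to evaluate the chosen GOP length.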
1386 int av1_tpl_setup_stats(AV1_COMP *cpi, int gop_eval,
1387                         const EncodeFrameParams *const frame_params,
1388                         const EncodeFrameInput *const frame_input) {
1389 #if CONFIG_COLLECT_COMPONENT_TIMING
1390   start_timing(cpi, av1_tpl_setup_stats_time);
1391 #endif
1392   AV1_COMMON *cm = &cpi->common;
1393   MultiThreadInfo *const mt_info = &cpi->mt_info;
1394   AV1TplRowMultiThreadInfo *const tpl_row_mt = &mt_info->tpl_row_mt;
1395   GF_GROUP *gf_group = &cpi->gf_group;
1396   int bottom_index, top_index;
1397   EncodeFrameParams this_frame_params = *frame_params;
1398   TplParams *const tpl_data = &cpi->tpl_data;
1399 
1400   if (cpi->superres_mode != AOM_SUPERRES_NONE) {
1401     assert(cpi->superres_mode != AOM_SUPERRES_AUTO);
1402     av1_init_tpl_stats(tpl_data);
1403     return 0;
1404   }
1405 
1406   cm->current_frame.frame_type = frame_params->frame_type;
1407   for (int gf_index = gf_group->index; gf_index < gf_group->size; ++gf_index) {
1408     cm->current_frame.frame_type = gf_group->frame_type[gf_index];
1409     av1_configure_buffer_updates(cpi, &this_frame_params.refresh_frame,
1410                                  gf_group->update_type[gf_index],
1411                                  cm->current_frame.frame_type, 0);
1412 
1413     memcpy(&cpi->refresh_frame, &this_frame_params.refresh_frame,
1414            sizeof(cpi->refresh_frame));
1415 
1416     cm->show_frame = gf_group->update_type[gf_index] != ARF_UPDATE &&
1417                      gf_group->update_type[gf_index] != INTNL_ARF_UPDATE;
1418 
1419     gf_group->q_val[gf_index] =
1420         av1_rc_pick_q_and_bounds(cpi, &cpi->rc, cm->width, cm->height, gf_index,
1421                                  &bottom_index, &top_index);
1422   }
1423 
1424   int pframe_qindex;
1425   int tpl_gf_group_frames;
1426   init_gop_frames_for_tpl(cpi, frame_params, gf_group, gop_eval,
1427                           &tpl_gf_group_frames, frame_input, &pframe_qindex);
1428 
1429   cpi->rc.base_layer_qp = pframe_qindex;
1430 
1431   av1_init_tpl_stats(tpl_data);
1432 
1433   tpl_row_mt->sync_read_ptr = av1_tpl_row_mt_sync_read_dummy;
1434   tpl_row_mt->sync_write_ptr = av1_tpl_row_mt_sync_write_dummy;
1435 
1436   av1_setup_scale_factors_for_frame(&cm->sf_identity, cm->width, cm->height,
1437                                     cm->width, cm->height);
1438 
1439   if (frame_params->frame_type == KEY_FRAME) {
1440     av1_init_mv_probs(cm);
1441   }
1442   av1_fill_mv_costs(cm->fc, cm->features.cur_frame_force_integer_mv,
1443                     cm->features.allow_high_precision_mv, cpi->td.mb.mv_costs);
1444 
1445   // Forward pass: estimate per-block costs and motion for each frame of the TPL group.
1446   for (int frame_idx = gf_group->index; frame_idx < tpl_gf_group_frames;
1447        ++frame_idx) {
1448     if (gf_group->update_type[frame_idx] == INTNL_OVERLAY_UPDATE ||
1449         gf_group->update_type[frame_idx] == OVERLAY_UPDATE)
1450       continue;
1451 
1452     init_mc_flow_dispenser(cpi, frame_idx, pframe_qindex);
1453     if (mt_info->num_workers > 1 && !cpi->sf.tpl_sf.allow_compound_pred) {
1454       tpl_row_mt->sync_read_ptr = av1_tpl_row_mt_sync_read;
1455       tpl_row_mt->sync_write_ptr = av1_tpl_row_mt_sync_write;
1456       av1_mc_flow_dispenser_mt(cpi);
1457     } else {
1458       mc_flow_dispenser(cpi);
1459     }
1460 
1461     aom_extend_frame_borders(tpl_data->tpl_frame[frame_idx].rec_picture,
1462                              av1_num_planes(cm));
1463   }
1464 
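  // Backward pass: propagate the dependency statistics from the last frame of
  // the TPL group back towards the current frame, skipping overlay frames.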
1465   for (int frame_idx = tpl_gf_group_frames - 1; frame_idx >= gf_group->index;
1466        --frame_idx) {
1467     if (gf_group->update_type[frame_idx] == INTNL_OVERLAY_UPDATE ||
1468         gf_group->update_type[frame_idx] == OVERLAY_UPDATE)
1469       continue;
1470 
1471     mc_flow_synthesizer(cpi, frame_idx);
1472   }
1473 
1474   av1_configure_buffer_updates(cpi, &this_frame_params.refresh_frame,
1475                                gf_group->update_type[gf_group->index],
1476                                frame_params->frame_type, 0);
1477   cm->current_frame.frame_type = frame_params->frame_type;
1478   cm->show_frame = frame_params->show_frame;
1479 
1480   if (cpi->common.tiles.large_scale) return 0;
1481   if (gf_group->max_layer_depth_allowed == 0) return 1;
1482   if (!gop_eval) return 0;
1483   assert(gf_group->arf_index >= 0);
1484 
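  // Measure the dependency factor beta = mc_dep_cost / intra_cost for the base
  // layer ARF (beta[0]) and for the next processed frame in the group
  // (beta[1]), typically the first intermediate ARF.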
1485   double beta[2] = { 0.0 };
1486   for (int frame_idx = gf_group->arf_index;
1487        frame_idx <= AOMMIN(tpl_gf_group_frames - 1, gf_group->arf_index + 1);
1488        ++frame_idx) {
1489     TplDepFrame *tpl_frame = &tpl_data->tpl_frame[frame_idx];
1490     TplDepStats *tpl_stats = tpl_frame->tpl_stats_ptr;
1491     int tpl_stride = tpl_frame->stride;
1492     int64_t intra_cost_base = 0;
1493     int64_t mc_dep_cost_base = 0;
1494     const int step = 1 << tpl_data->tpl_stats_block_mis_log2;
1495     const int row_step = step;
1496     const int col_step_sr =
1497         coded_to_superres_mi(step, cm->superres_scale_denominator);
1498     const int mi_cols_sr = av1_pixels_to_mi(cm->superres_upscaled_width);
1499 
1500     for (int row = 0; row < cm->mi_params.mi_rows; row += row_step) {
1501       for (int col = 0; col < mi_cols_sr; col += col_step_sr) {
1502         TplDepStats *this_stats = &tpl_stats[av1_tpl_ptr_pos(
1503             row, col, tpl_stride, tpl_data->tpl_stats_block_mis_log2)];
1504         int64_t mc_dep_delta =
1505             RDCOST(tpl_frame->base_rdmult, this_stats->mc_dep_rate,
1506                    this_stats->mc_dep_dist);
1507         intra_cost_base += (this_stats->recrf_dist << RDDIV_BITS);
1508         mc_dep_cost_base +=
1509             (this_stats->recrf_dist << RDDIV_BITS) + mc_dep_delta;
1510       }
1511     }
1512     if (intra_cost_base == 0) {
1513       // This should happen very rarely. If it does, assign a dummy value,
1514       // since it is unlikely to influence the GOP size decision.
1515       beta[frame_idx - gf_group->arf_index] = 0;
1516     } else {
1517       beta[frame_idx - gf_group->arf_index] =
1518           (double)mc_dep_cost_base / intra_cost_base;
1519     }
1520   }
1521 
1522 #if CONFIG_COLLECT_COMPONENT_TIMING
1523   end_timing(cpi, av1_tpl_setup_stats_time);
1524 #endif
1525 
1526   // Allow larger GOP size if the base layer ARF has higher dependency factor
1527   // than the intermediate ARF and both ARFs have reasonably high dependency
1528   // factors.
1529   return (beta[0] >= beta[1] + 0.7) && beta[0] > 8.0;
1530 }
1531 
1532 void av1_tpl_rdmult_setup(AV1_COMP *cpi) {
1533   const AV1_COMMON *const cm = &cpi->common;
1534   const GF_GROUP *const gf_group = &cpi->gf_group;
1535   const int tpl_idx = gf_group->index;
1536 
1537   assert(IMPLIES(gf_group->size > 0, tpl_idx < gf_group->size));
1538 
1539   TplParams *const tpl_data = &cpi->tpl_data;
1540   const TplDepFrame *const tpl_frame = &tpl_data->tpl_frame[tpl_idx];
1541 
1542   if (!tpl_frame->is_valid) return;
1543 
1544   const TplDepStats *const tpl_stats = tpl_frame->tpl_stats_ptr;
1545   const int tpl_stride = tpl_frame->stride;
1546   const int mi_cols_sr = av1_pixels_to_mi(cm->superres_upscaled_width);
1547 
1548   const int block_size = BLOCK_16X16;
1549   const int num_mi_w = mi_size_wide[block_size];
1550   const int num_mi_h = mi_size_high[block_size];
1551   const int num_cols = (mi_cols_sr + num_mi_w - 1) / num_mi_w;
1552   const int num_rows = (cm->mi_params.mi_rows + num_mi_h - 1) / num_mi_h;
1553   const double c = 1.2;
1554   const int step = 1 << tpl_data->tpl_stats_block_mis_log2;
1555 
1556   aom_clear_system_state();
1557 
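  // For each block, rk = intra_cost / mc_dep_cost measures how much of the
  // block's cost propagates to later frames; cpi->rd.r0 is the analogous
  // frame-level ratio (set elsewhere from the frame's TPL statistics), so
  // rk / r0 compares the block with the frame average, with 'c' as a fixed
  // offset.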
1558   // Loop through each 'block_size' X 'block_size' block.
1559   for (int row = 0; row < num_rows; row++) {
1560     for (int col = 0; col < num_cols; col++) {
1561       double intra_cost = 0.0, mc_dep_cost = 0.0;
1562       // Loop through each mi block.
1563       for (int mi_row = row * num_mi_h; mi_row < (row + 1) * num_mi_h;
1564            mi_row += step) {
1565         for (int mi_col = col * num_mi_w; mi_col < (col + 1) * num_mi_w;
1566              mi_col += step) {
1567           if (mi_row >= cm->mi_params.mi_rows || mi_col >= mi_cols_sr) continue;
1568           const TplDepStats *this_stats = &tpl_stats[av1_tpl_ptr_pos(
1569               mi_row, mi_col, tpl_stride, tpl_data->tpl_stats_block_mis_log2)];
1570           int64_t mc_dep_delta =
1571               RDCOST(tpl_frame->base_rdmult, this_stats->mc_dep_rate,
1572                      this_stats->mc_dep_dist);
1573           intra_cost += (double)(this_stats->recrf_dist << RDDIV_BITS);
1574           mc_dep_cost +=
1575               (double)(this_stats->recrf_dist << RDDIV_BITS) + mc_dep_delta;
1576         }
1577       }
1578       const double rk = intra_cost / mc_dep_cost;
1579       const int index = row * num_cols + col;
1580       cpi->tpl_rdmult_scaling_factors[index] = rk / cpi->rd.r0 + c;
1581     }
1582   }
1583   aom_clear_system_state();
1584 }
1585 
1586 void av1_tpl_rdmult_setup_sb(AV1_COMP *cpi, MACROBLOCK *const x,
1587                              BLOCK_SIZE sb_size, int mi_row, int mi_col) {
1588   AV1_COMMON *const cm = &cpi->common;
1589   GF_GROUP *gf_group = &cpi->gf_group;
1590   assert(IMPLIES(cpi->gf_group.size > 0,
1591                  cpi->gf_group.index < cpi->gf_group.size));
1592   const int tpl_idx = cpi->gf_group.index;
1593   TplDepFrame *tpl_frame = &cpi->tpl_data.tpl_frame[tpl_idx];
1594 
1595   if (tpl_frame->is_valid == 0) return;
1596   if (!is_frame_tpl_eligible(gf_group, gf_group->index)) return;
1597   if (tpl_idx >= MAX_TPL_FRAME_IDX) return;
1598   if (cpi->oxcf.q_cfg.aq_mode != NO_AQ) return;
1599 
1600   const int mi_col_sr =
1601       coded_to_superres_mi(mi_col, cm->superres_scale_denominator);
1602   const int mi_cols_sr = av1_pixels_to_mi(cm->superres_upscaled_width);
1603   const int sb_mi_width_sr = coded_to_superres_mi(
1604       mi_size_wide[sb_size], cm->superres_scale_denominator);
1605 
1606   const int bsize_base = BLOCK_16X16;
1607   const int num_mi_w = mi_size_wide[bsize_base];
1608   const int num_mi_h = mi_size_high[bsize_base];
1609   const int num_cols = (mi_cols_sr + num_mi_w - 1) / num_mi_w;
1610   const int num_rows = (cm->mi_params.mi_rows + num_mi_h - 1) / num_mi_h;
1611   const int num_bcols = (sb_mi_width_sr + num_mi_w - 1) / num_mi_w;
1612   const int num_brows = (mi_size_high[sb_size] + num_mi_h - 1) / num_mi_h;
1613   int row, col;
1614 
1615   double base_block_count = 0.0;
1616   double log_sum = 0.0;
1617 
1618   aom_clear_system_state();
1619   for (row = mi_row / num_mi_w;
1620        row < num_rows && row < mi_row / num_mi_w + num_brows; ++row) {
1621     for (col = mi_col_sr / num_mi_h;
1622          col < num_cols && col < mi_col_sr / num_mi_h + num_bcols; ++col) {
1623       const int index = row * num_cols + col;
1624       log_sum += log(cpi->tpl_rdmult_scaling_factors[index]);
1625       base_block_count += 1.0;
1626     }
1627   }
1628 
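  // Renormalize the per-block factors within this superblock so that their
  // geometric mean matches the rd multiplier ratio implied by the
  // superblock's delta_qindex.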
1629   const CommonQuantParams *quant_params = &cm->quant_params;
1630   const int orig_rdmult = av1_compute_rd_mult(
1631       cpi, quant_params->base_qindex + quant_params->y_dc_delta_q);
1632   const int new_rdmult =
1633       av1_compute_rd_mult(cpi, quant_params->base_qindex + x->delta_qindex +
1634                                    quant_params->y_dc_delta_q);
1635   const double scaling_factor = (double)new_rdmult / (double)orig_rdmult;
1636 
1637   double scale_adj = log(scaling_factor) - log_sum / base_block_count;
1638   scale_adj = exp(scale_adj);
1639 
1640   for (row = mi_row / num_mi_w;
1641        row < num_rows && row < mi_row / num_mi_w + num_brows; ++row) {
1642     for (col = mi_col_sr / num_mi_h;
1643          col < num_cols && col < mi_col_sr / num_mi_h + num_bcols; ++col) {
1644       const int index = row * num_cols + col;
1645       cpi->tpl_sb_rdmult_scaling_factors[index] =
1646           scale_adj * cpi->tpl_rdmult_scaling_factors[index];
1647     }
1648   }
1649   aom_clear_system_state();
1650 }
1651 
1652 #define EPSILON (0.0000001)
1653 
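// Entropy (in bits) of an exponentially distributed magnitude with scale 'b'
// quantized with step 'q_step': the bin probabilities form a geometric
// sequence with ratio z = exp(-q_step / b), whose entropy is
// -log2(1 - z) - z * log2(z) / (1 - z).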
1654 double av1_exponential_entropy(double q_step, double b) {
1655   aom_clear_system_state();
1656   double z = fmax(exp(-q_step / b), EPSILON);
1657   return -log2(1 - z) - z * log2(z) / (1 - z);
1658 }
1659 
1660 double av1_laplace_entropy(double q_step, double b, double zero_bin_ratio) {
1661   aom_clear_system_state();
1662   // zero bin's size is zero_bin_ratio * q_step
1663   // non-zero bin's size is q_step
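  // For a zero-mean Laplacian with scale b, z is the probability that the
  // coefficient falls outside the zero bin. The first two terms of 'r' below
  // are the entropy of the zero / non-zero decision; the z * (h + 1) term adds
  // the magnitude entropy h plus one sign bit for non-zero coefficients.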
1664   double z = fmax(exp(-zero_bin_ratio / 2 * q_step / b), EPSILON);
1665   double h = av1_exponential_entropy(q_step, b);
1666   double r = -(1 - z) * log2(1 - z) - z * log2(z) + z * (h + 1);
1667   return r;
1668 }
1669 
1670 double av1_laplace_estimate_frame_rate(int q_index, int block_count,
1671                                        const double *abs_coeff_mean,
1672                                        int coeff_num) {
1673   aom_clear_system_state();
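  // abs_coeff_mean[i] is used as the Laplace scale parameter b for the i-th
  // transform coefficient (for a zero-mean Laplacian, E|X| equals b). The
  // estimate sums one DC term and (coeff_num - 1) AC terms per block, then
  // scales by the number of blocks.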
1674   double zero_bin_ratio = 2;
1675   double dc_q_step = av1_dc_quant_QTX(q_index, 0, AOM_BITS_8) / 4.;
1676   double ac_q_step = av1_ac_quant_QTX(q_index, 0, AOM_BITS_8) / 4.;
1677   double est_rate = 0;
1678   // dc coeff
1679   est_rate += av1_laplace_entropy(dc_q_step, abs_coeff_mean[0], zero_bin_ratio);
1680   // ac coeff
1681   for (int i = 1; i < coeff_num; ++i) {
1682     est_rate +=
1683         av1_laplace_entropy(ac_q_step, abs_coeff_mean[i], zero_bin_ratio);
1684   }
1685   est_rate *= block_count;
1686   return est_rate;
1687 }
1688