/*
 * Copyright (c) 2016, Alliance for Open Media. All rights reserved
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <limits.h>  // INT_MAX, used in od_compute_dist_8x8()
#include <math.h>
#include <stdlib.h>  // qsort(), used in inter_modes_info_sort()

#include "config/aom_dsp_rtcd.h"
#include "config/av1_rtcd.h"

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/blend.h"
#include "aom_mem/aom_mem.h"
#include "aom_ports/aom_timer.h"
#include "aom_ports/mem.h"
#include "aom_ports/system_state.h"

#include "av1/common/cfl.h"
#include "av1/common/common.h"
#include "av1/common/common_data.h"
#include "av1/common/entropy.h"
#include "av1/common/entropymode.h"
#include "av1/common/idct.h"
#include "av1/common/mvref_common.h"
#include "av1/common/obmc.h"
#include "av1/common/pred_common.h"
#include "av1/common/quant_common.h"
#include "av1/common/reconinter.h"
#include "av1/common/reconintra.h"
#include "av1/common/scan.h"
#include "av1/common/seg_common.h"
#include "av1/common/txb_common.h"
#include "av1/common/warped_motion.h"

#include "av1/encoder/aq_variance.h"
#include "av1/encoder/av1_quantize.h"
#include "av1/encoder/cost.h"
#include "av1/encoder/encodemb.h"
#include "av1/encoder/encodemv.h"
#include "av1/encoder/encoder.h"
#include "av1/encoder/encodetxb.h"
#include "av1/encoder/hybrid_fwd_txfm.h"
#include "av1/encoder/mcomp.h"
#include "av1/encoder/ml.h"
#include "av1/encoder/palette.h"
#include "av1/encoder/pustats.h"
#include "av1/encoder/random.h"
#include "av1/encoder/ratectrl.h"
#include "av1/encoder/rd.h"
#include "av1/encoder/rdopt.h"
#include "av1/encoder/reconinter_enc.h"
#include "av1/encoder/tokenize.h"
#include "av1/encoder/tx_prune_model_weights.h"

typedef void (*model_rd_for_sb_type)(
    const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
    int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
    int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
    int *plane_rate, int64_t *plane_sse, int64_t *plane_dist);
typedef void (*model_rd_from_sse_type)(const AV1_COMP *const cpi,
                                       const MACROBLOCK *const x,
                                       BLOCK_SIZE plane_bsize, int plane,
                                       int64_t sse, int num_samples, int *rate,
                                       int64_t *dist);

static void model_rd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bsize,
                            MACROBLOCK *x, MACROBLOCKD *xd, int plane_from,
                            int plane_to, int mi_row, int mi_col,
                            int *out_rate_sum, int64_t *out_dist_sum,
                            int *skip_txfm_sb, int64_t *skip_sse_sb,
                            int *plane_rate, int64_t *plane_sse,
                            int64_t *plane_dist);
static void model_rd_for_sb_with_curvfit(
    const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
    int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
    int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
    int *plane_rate, int64_t *plane_sse, int64_t *plane_dist);
static void model_rd_for_sb_with_surffit(
    const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
    int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
    int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
    int *plane_rate, int64_t *plane_sse, int64_t *plane_dist);
static void model_rd_for_sb_with_dnn(
    const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
    int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
    int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
    int *plane_rate, int64_t *plane_sse, int64_t *plane_dist);
static void model_rd_for_sb_with_fullrdy(
    const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
    int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
    int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
    int *plane_rate, int64_t *plane_sse, int64_t *plane_dist);
static void model_rd_from_sse(const AV1_COMP *const cpi,
                              const MACROBLOCK *const x, BLOCK_SIZE plane_bsize,
                              int plane, int64_t sse, int num_samples,
                              int *rate, int64_t *dist);
static void model_rd_with_dnn(const AV1_COMP *const cpi,
                              const MACROBLOCK *const x, BLOCK_SIZE plane_bsize,
                              int plane, int64_t sse, int num_samples,
                              int *rate, int64_t *dist);
static void model_rd_with_curvfit(const AV1_COMP *const cpi,
                                  const MACROBLOCK *const x,
                                  BLOCK_SIZE plane_bsize, int plane,
                                  int64_t sse, int num_samples, int *rate,
                                  int64_t *dist);
static void model_rd_with_surffit(const AV1_COMP *const cpi,
                                  const MACROBLOCK *const x,
                                  BLOCK_SIZE plane_bsize, int plane,
                                  int64_t sse, int num_samples, int *rate,
                                  int64_t *dist);

typedef enum {
  MODELRD_LEGACY,
  MODELRD_CURVFIT,
  MODELRD_SUFFIT,
  MODELRD_DNN,
  MODELRD_FULLRDY,
  MODELRD_TYPES
} ModelRdType;

static model_rd_for_sb_type model_rd_sb_fn[MODELRD_TYPES] = {
  model_rd_for_sb, model_rd_for_sb_with_curvfit, model_rd_for_sb_with_surffit,
  model_rd_for_sb_with_dnn, model_rd_for_sb_with_fullrdy
};

static model_rd_from_sse_type model_rd_sse_fn[MODELRD_TYPES] = {
  model_rd_from_sse, model_rd_with_curvfit, model_rd_with_surffit,
  model_rd_with_dnn, NULL
};

// 0: Legacy model
// 1: Curve fit model
// 2: Surface fit model
// 3: DNN regression model
// 4: Full rd model
#define MODELRD_TYPE_INTERP_FILTER 1
#define MODELRD_TYPE_TX_SEARCH_PRUNE 2
#define MODELRD_TYPE_MASKED_COMPOUND 1
#define MODELRD_TYPE_INTERINTRA 1
#define MODELRD_TYPE_INTRA 1
#define MODELRD_TYPE_JNT_COMPOUND 1
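
// The ModelRdType values index the two dispatch tables above, and each
// MODELRD_TYPE_* macro selects the model used for one encoder task. An
// illustrative dispatch (a sketch; real call sites pass actual plane ranges
// and output buffers):
//
//   model_rd_sb_fn[MODELRD_TYPE_INTERP_FILTER](
//       cpi, bsize, x, xd, 0, num_planes - 1, mi_row, mi_col, &rate_sum,
//       &dist_sum, &skip_txfm_sb, &skip_sse_sb, NULL, NULL, NULL);
//
// With MODELRD_TYPE_INTERP_FILTER == 1, this resolves to
// model_rd_for_sb_with_curvfit().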

#define DUAL_FILTER_SET_SIZE (SWITCHABLE_FILTERS * SWITCHABLE_FILTERS)
static const InterpFilters filter_sets[DUAL_FILTER_SET_SIZE] = {
  0x00000000, 0x00010000, 0x00020000,  // y = 0
  0x00000001, 0x00010001, 0x00020001,  // y = 1
  0x00000002, 0x00010002, 0x00020002,  // y = 2
};
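
// Each entry packs the vertical (y) filter into the low 16 bits and the
// horizontal (x) filter into the high 16 bits, enumerating all
// SWITCHABLE_FILTERS * SWITCHABLE_FILTERS dual-filter combinations:
//
//   const InterpFilters f = filter_sets[y * SWITCHABLE_FILTERS + x];
//   assert(f == (InterpFilters)((x << 16) | y));  // e.g. x=2, y=1 -> 0x00020001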

#define SECOND_REF_FRAME_MASK                                         \
  ((1 << ALTREF_FRAME) | (1 << ALTREF2_FRAME) | (1 << BWDREF_FRAME) | \
   (1 << GOLDEN_FRAME) | (1 << LAST2_FRAME) | 0x01)

#define ANGLE_SKIP_THRESH 10

static const double ADST_FLIP_SVM[8] = {
  /* vertical */
  -6.6623, -2.8062, -3.2531, 3.1671,
  /* horizontal */
  -7.7051, -3.2234, -3.6193, 3.4533
};

typedef struct {
  PREDICTION_MODE mode;
  MV_REFERENCE_FRAME ref_frame[2];
} MODE_DEFINITION;

typedef struct {
  MV_REFERENCE_FRAME ref_frame[2];
} REF_DEFINITION;

typedef enum {
  FTXS_NONE = 0,
  FTXS_DCT_AND_1D_DCT_ONLY = 1 << 0,
  FTXS_DISABLE_TRELLIS_OPT = 1 << 1,
  FTXS_USE_TRANSFORM_DOMAIN = 1 << 2
} FAST_TX_SEARCH_MODE;

static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
                               RD_STATS *rd_stats, BLOCK_SIZE bsize, int mi_row,
                               int mi_col, int64_t ref_best_rd);

static int inter_block_uvrd(const AV1_COMP *cpi, MACROBLOCK *x,
                            RD_STATS *rd_stats, BLOCK_SIZE bsize,
                            int64_t non_skip_ref_best_rd,
                            int64_t skip_ref_best_rd,
                            FAST_TX_SEARCH_MODE ftxs_mode);

struct rdcost_block_args {
  const AV1_COMP *cpi;
  MACROBLOCK *x;
  ENTROPY_CONTEXT t_above[MAX_MIB_SIZE];
  ENTROPY_CONTEXT t_left[MAX_MIB_SIZE];
  RD_STATS rd_stats;
  int64_t this_rd;
  int64_t best_rd;
  int exit_early;
  int incomplete_exit;
  int use_fast_coef_costing;
  FAST_TX_SEARCH_MODE ftxs_mode;
};

#define LAST_NEW_MV_INDEX 6
static const MODE_DEFINITION av1_mode_order[MAX_MODES] = {
  { NEARESTMV, { LAST_FRAME, NONE_FRAME } },
  { NEARESTMV, { LAST2_FRAME, NONE_FRAME } },
  { NEARESTMV, { LAST3_FRAME, NONE_FRAME } },
  { NEARESTMV, { BWDREF_FRAME, NONE_FRAME } },
  { NEARESTMV, { ALTREF2_FRAME, NONE_FRAME } },
  { NEARESTMV, { ALTREF_FRAME, NONE_FRAME } },
  { NEARESTMV, { GOLDEN_FRAME, NONE_FRAME } },

  { NEWMV, { LAST_FRAME, NONE_FRAME } },
  { NEWMV, { LAST2_FRAME, NONE_FRAME } },
  { NEWMV, { LAST3_FRAME, NONE_FRAME } },
  { NEWMV, { BWDREF_FRAME, NONE_FRAME } },
  { NEWMV, { ALTREF2_FRAME, NONE_FRAME } },
  { NEWMV, { ALTREF_FRAME, NONE_FRAME } },
  { NEWMV, { GOLDEN_FRAME, NONE_FRAME } },

  { NEARMV, { LAST_FRAME, NONE_FRAME } },
  { NEARMV, { LAST2_FRAME, NONE_FRAME } },
  { NEARMV, { LAST3_FRAME, NONE_FRAME } },
  { NEARMV, { BWDREF_FRAME, NONE_FRAME } },
  { NEARMV, { ALTREF2_FRAME, NONE_FRAME } },
  { NEARMV, { ALTREF_FRAME, NONE_FRAME } },
  { NEARMV, { GOLDEN_FRAME, NONE_FRAME } },

  { GLOBALMV, { LAST_FRAME, NONE_FRAME } },
  { GLOBALMV, { LAST2_FRAME, NONE_FRAME } },
  { GLOBALMV, { LAST3_FRAME, NONE_FRAME } },
  { GLOBALMV, { BWDREF_FRAME, NONE_FRAME } },
  { GLOBALMV, { ALTREF2_FRAME, NONE_FRAME } },
  { GLOBALMV, { GOLDEN_FRAME, NONE_FRAME } },
  { GLOBALMV, { ALTREF_FRAME, NONE_FRAME } },

  // TODO(zoeliu): May need to reconsider the order on the modes to check

  { NEAREST_NEARESTMV, { LAST_FRAME, ALTREF_FRAME } },
  { NEAREST_NEARESTMV, { LAST2_FRAME, ALTREF_FRAME } },
  { NEAREST_NEARESTMV, { LAST3_FRAME, ALTREF_FRAME } },
  { NEAREST_NEARESTMV, { GOLDEN_FRAME, ALTREF_FRAME } },
  { NEAREST_NEARESTMV, { LAST_FRAME, BWDREF_FRAME } },
  { NEAREST_NEARESTMV, { LAST2_FRAME, BWDREF_FRAME } },
  { NEAREST_NEARESTMV, { LAST3_FRAME, BWDREF_FRAME } },
  { NEAREST_NEARESTMV, { GOLDEN_FRAME, BWDREF_FRAME } },
  { NEAREST_NEARESTMV, { LAST_FRAME, ALTREF2_FRAME } },
  { NEAREST_NEARESTMV, { LAST2_FRAME, ALTREF2_FRAME } },
  { NEAREST_NEARESTMV, { LAST3_FRAME, ALTREF2_FRAME } },
  { NEAREST_NEARESTMV, { GOLDEN_FRAME, ALTREF2_FRAME } },

  { NEAREST_NEARESTMV, { LAST_FRAME, LAST2_FRAME } },
  { NEAREST_NEARESTMV, { LAST_FRAME, LAST3_FRAME } },
  { NEAREST_NEARESTMV, { LAST_FRAME, GOLDEN_FRAME } },
  { NEAREST_NEARESTMV, { BWDREF_FRAME, ALTREF_FRAME } },

  { NEAR_NEARMV, { LAST_FRAME, ALTREF_FRAME } },
  { NEW_NEARESTMV, { LAST_FRAME, ALTREF_FRAME } },
  { NEAREST_NEWMV, { LAST_FRAME, ALTREF_FRAME } },
  { NEW_NEARMV, { LAST_FRAME, ALTREF_FRAME } },
  { NEAR_NEWMV, { LAST_FRAME, ALTREF_FRAME } },
  { NEW_NEWMV, { LAST_FRAME, ALTREF_FRAME } },
  { GLOBAL_GLOBALMV, { LAST_FRAME, ALTREF_FRAME } },

  { NEAR_NEARMV, { LAST2_FRAME, ALTREF_FRAME } },
  { NEW_NEARESTMV, { LAST2_FRAME, ALTREF_FRAME } },
  { NEAREST_NEWMV, { LAST2_FRAME, ALTREF_FRAME } },
  { NEW_NEARMV, { LAST2_FRAME, ALTREF_FRAME } },
  { NEAR_NEWMV, { LAST2_FRAME, ALTREF_FRAME } },
  { NEW_NEWMV, { LAST2_FRAME, ALTREF_FRAME } },
  { GLOBAL_GLOBALMV, { LAST2_FRAME, ALTREF_FRAME } },

  { NEAR_NEARMV, { LAST3_FRAME, ALTREF_FRAME } },
  { NEW_NEARESTMV, { LAST3_FRAME, ALTREF_FRAME } },
  { NEAREST_NEWMV, { LAST3_FRAME, ALTREF_FRAME } },
  { NEW_NEARMV, { LAST3_FRAME, ALTREF_FRAME } },
  { NEAR_NEWMV, { LAST3_FRAME, ALTREF_FRAME } },
  { NEW_NEWMV, { LAST3_FRAME, ALTREF_FRAME } },
  { GLOBAL_GLOBALMV, { LAST3_FRAME, ALTREF_FRAME } },

  { NEAR_NEARMV, { GOLDEN_FRAME, ALTREF_FRAME } },
  { NEW_NEARESTMV, { GOLDEN_FRAME, ALTREF_FRAME } },
  { NEAREST_NEWMV, { GOLDEN_FRAME, ALTREF_FRAME } },
  { NEW_NEARMV, { GOLDEN_FRAME, ALTREF_FRAME } },
  { NEAR_NEWMV, { GOLDEN_FRAME, ALTREF_FRAME } },
  { NEW_NEWMV, { GOLDEN_FRAME, ALTREF_FRAME } },
  { GLOBAL_GLOBALMV, { GOLDEN_FRAME, ALTREF_FRAME } },

  { NEAR_NEARMV, { LAST_FRAME, BWDREF_FRAME } },
  { NEW_NEARESTMV, { LAST_FRAME, BWDREF_FRAME } },
  { NEAREST_NEWMV, { LAST_FRAME, BWDREF_FRAME } },
  { NEW_NEARMV, { LAST_FRAME, BWDREF_FRAME } },
  { NEAR_NEWMV, { LAST_FRAME, BWDREF_FRAME } },
  { NEW_NEWMV, { LAST_FRAME, BWDREF_FRAME } },
  { GLOBAL_GLOBALMV, { LAST_FRAME, BWDREF_FRAME } },

  { NEAR_NEARMV, { LAST2_FRAME, BWDREF_FRAME } },
  { NEW_NEARESTMV, { LAST2_FRAME, BWDREF_FRAME } },
  { NEAREST_NEWMV, { LAST2_FRAME, BWDREF_FRAME } },
  { NEW_NEARMV, { LAST2_FRAME, BWDREF_FRAME } },
  { NEAR_NEWMV, { LAST2_FRAME, BWDREF_FRAME } },
  { NEW_NEWMV, { LAST2_FRAME, BWDREF_FRAME } },
  { GLOBAL_GLOBALMV, { LAST2_FRAME, BWDREF_FRAME } },

  { NEAR_NEARMV, { LAST3_FRAME, BWDREF_FRAME } },
  { NEW_NEARESTMV, { LAST3_FRAME, BWDREF_FRAME } },
  { NEAREST_NEWMV, { LAST3_FRAME, BWDREF_FRAME } },
  { NEW_NEARMV, { LAST3_FRAME, BWDREF_FRAME } },
  { NEAR_NEWMV, { LAST3_FRAME, BWDREF_FRAME } },
  { NEW_NEWMV, { LAST3_FRAME, BWDREF_FRAME } },
  { GLOBAL_GLOBALMV, { LAST3_FRAME, BWDREF_FRAME } },

  { NEAR_NEARMV, { GOLDEN_FRAME, BWDREF_FRAME } },
  { NEW_NEARESTMV, { GOLDEN_FRAME, BWDREF_FRAME } },
  { NEAREST_NEWMV, { GOLDEN_FRAME, BWDREF_FRAME } },
  { NEW_NEARMV, { GOLDEN_FRAME, BWDREF_FRAME } },
  { NEAR_NEWMV, { GOLDEN_FRAME, BWDREF_FRAME } },
  { NEW_NEWMV, { GOLDEN_FRAME, BWDREF_FRAME } },
  { GLOBAL_GLOBALMV, { GOLDEN_FRAME, BWDREF_FRAME } },

  { NEAR_NEARMV, { LAST_FRAME, ALTREF2_FRAME } },
  { NEW_NEARESTMV, { LAST_FRAME, ALTREF2_FRAME } },
  { NEAREST_NEWMV, { LAST_FRAME, ALTREF2_FRAME } },
  { NEW_NEARMV, { LAST_FRAME, ALTREF2_FRAME } },
  { NEAR_NEWMV, { LAST_FRAME, ALTREF2_FRAME } },
  { NEW_NEWMV, { LAST_FRAME, ALTREF2_FRAME } },
  { GLOBAL_GLOBALMV, { LAST_FRAME, ALTREF2_FRAME } },

  { NEAR_NEARMV, { LAST2_FRAME, ALTREF2_FRAME } },
  { NEW_NEARESTMV, { LAST2_FRAME, ALTREF2_FRAME } },
  { NEAREST_NEWMV, { LAST2_FRAME, ALTREF2_FRAME } },
  { NEW_NEARMV, { LAST2_FRAME, ALTREF2_FRAME } },
  { NEAR_NEWMV, { LAST2_FRAME, ALTREF2_FRAME } },
  { NEW_NEWMV, { LAST2_FRAME, ALTREF2_FRAME } },
  { GLOBAL_GLOBALMV, { LAST2_FRAME, ALTREF2_FRAME } },

  { NEAR_NEARMV, { LAST3_FRAME, ALTREF2_FRAME } },
  { NEW_NEARESTMV, { LAST3_FRAME, ALTREF2_FRAME } },
  { NEAREST_NEWMV, { LAST3_FRAME, ALTREF2_FRAME } },
  { NEW_NEARMV, { LAST3_FRAME, ALTREF2_FRAME } },
  { NEAR_NEWMV, { LAST3_FRAME, ALTREF2_FRAME } },
  { NEW_NEWMV, { LAST3_FRAME, ALTREF2_FRAME } },
  { GLOBAL_GLOBALMV, { LAST3_FRAME, ALTREF2_FRAME } },

  { NEAR_NEARMV, { GOLDEN_FRAME, ALTREF2_FRAME } },
  { NEW_NEARESTMV, { GOLDEN_FRAME, ALTREF2_FRAME } },
  { NEAREST_NEWMV, { GOLDEN_FRAME, ALTREF2_FRAME } },
  { NEW_NEARMV, { GOLDEN_FRAME, ALTREF2_FRAME } },
  { NEAR_NEWMV, { GOLDEN_FRAME, ALTREF2_FRAME } },
  { NEW_NEWMV, { GOLDEN_FRAME, ALTREF2_FRAME } },
  { GLOBAL_GLOBALMV, { GOLDEN_FRAME, ALTREF2_FRAME } },

  { NEAR_NEARMV, { LAST_FRAME, LAST2_FRAME } },
  { NEW_NEARESTMV, { LAST_FRAME, LAST2_FRAME } },
  { NEAREST_NEWMV, { LAST_FRAME, LAST2_FRAME } },
  { NEW_NEARMV, { LAST_FRAME, LAST2_FRAME } },
  { NEAR_NEWMV, { LAST_FRAME, LAST2_FRAME } },
  { NEW_NEWMV, { LAST_FRAME, LAST2_FRAME } },
  { GLOBAL_GLOBALMV, { LAST_FRAME, LAST2_FRAME } },

  { NEAR_NEARMV, { LAST_FRAME, LAST3_FRAME } },
  { NEW_NEARESTMV, { LAST_FRAME, LAST3_FRAME } },
  { NEAREST_NEWMV, { LAST_FRAME, LAST3_FRAME } },
  { NEW_NEARMV, { LAST_FRAME, LAST3_FRAME } },
  { NEAR_NEWMV, { LAST_FRAME, LAST3_FRAME } },
  { NEW_NEWMV, { LAST_FRAME, LAST3_FRAME } },
  { GLOBAL_GLOBALMV, { LAST_FRAME, LAST3_FRAME } },

  { NEAR_NEARMV, { LAST_FRAME, GOLDEN_FRAME } },
  { NEW_NEARESTMV, { LAST_FRAME, GOLDEN_FRAME } },
  { NEAREST_NEWMV, { LAST_FRAME, GOLDEN_FRAME } },
  { NEW_NEARMV, { LAST_FRAME, GOLDEN_FRAME } },
  { NEAR_NEWMV, { LAST_FRAME, GOLDEN_FRAME } },
  { NEW_NEWMV, { LAST_FRAME, GOLDEN_FRAME } },
  { GLOBAL_GLOBALMV, { LAST_FRAME, GOLDEN_FRAME } },

  { NEAR_NEARMV, { BWDREF_FRAME, ALTREF_FRAME } },
  { NEW_NEARESTMV, { BWDREF_FRAME, ALTREF_FRAME } },
  { NEAREST_NEWMV, { BWDREF_FRAME, ALTREF_FRAME } },
  { NEW_NEARMV, { BWDREF_FRAME, ALTREF_FRAME } },
  { NEAR_NEWMV, { BWDREF_FRAME, ALTREF_FRAME } },
  { NEW_NEWMV, { BWDREF_FRAME, ALTREF_FRAME } },
  { GLOBAL_GLOBALMV, { BWDREF_FRAME, ALTREF_FRAME } },

  // intra modes
  { DC_PRED, { INTRA_FRAME, NONE_FRAME } },
  { PAETH_PRED, { INTRA_FRAME, NONE_FRAME } },
  { SMOOTH_PRED, { INTRA_FRAME, NONE_FRAME } },
  { SMOOTH_V_PRED, { INTRA_FRAME, NONE_FRAME } },
  { SMOOTH_H_PRED, { INTRA_FRAME, NONE_FRAME } },
  { H_PRED, { INTRA_FRAME, NONE_FRAME } },
  { V_PRED, { INTRA_FRAME, NONE_FRAME } },
  { D135_PRED, { INTRA_FRAME, NONE_FRAME } },
  { D203_PRED, { INTRA_FRAME, NONE_FRAME } },
  { D157_PRED, { INTRA_FRAME, NONE_FRAME } },
  { D67_PRED, { INTRA_FRAME, NONE_FRAME } },
  { D113_PRED, { INTRA_FRAME, NONE_FRAME } },
  { D45_PRED, { INTRA_FRAME, NONE_FRAME } },
};

static const int16_t intra_to_mode_idx[INTRA_MODE_NUM] = {
  7,    // DC_PRED,
  134,  // V_PRED,
  133,  // H_PRED,
  140,  // D45_PRED,
  135,  // D135_PRED,
  139,  // D113_PRED,
  137,  // D157_PRED,
  136,  // D203_PRED,
  138,  // D67_PRED,
  46,   // SMOOTH_PRED,
  47,   // SMOOTH_V_PRED,
  48,   // SMOOTH_H_PRED,
  45,   // PAETH_PRED,
};

/* clang-format off */
static const int16_t single_inter_to_mode_idx[SINGLE_INTER_MODE_NUM]
                                             [REF_FRAMES] = {
  // NEARESTMV,
  { -1, 0, 1, 2, 6, 3, 4, 5, },
  // NEARMV,
  { -1, 15, 16, 17, 21, 18, 19, 20, },
  // GLOBALMV,
  { -1, 22, 23, 24, 27, 25, 26, 28, },
  // NEWMV,
  { -1, 8, 9, 10, 14, 11, 12, 13, },
};
/* clang-format on */

/* clang-format off */
static const int16_t comp_inter_to_mode_idx[COMP_INTER_MODE_NUM][REF_FRAMES]
                                     [REF_FRAMES] = {
  // NEAREST_NEARESTMV,
  {
      { -1, -1, -1, -1, -1, -1, -1, -1, },
      { -1, -1, 41, 42, 43, 33, 37, 29, },
      { -1, -1, -1, -1, -1, 34, 38, 30, },
      { -1, -1, -1, -1, -1, 35, 39, 31, },
      { -1, -1, -1, -1, -1, 36, 40, 32, },
      { -1, -1, -1, -1, -1, -1, -1, 44, },
      { -1, -1, -1, -1, -1, -1, -1, -1, },
      { -1, -1, -1, -1, -1, -1, -1, -1, },
  },
  // NEAR_NEARMV,
  {
      { -1, -1, -1, -1, -1, -1, -1, -1, },
      { -1, -1, 141, 148, 155, 77, 105, 49, },
      { -1, -1, -1, -1, -1, 84, 112, 56, },
      { -1, -1, -1, -1, -1, 91, 119, 63, },
      { -1, -1, -1, -1, -1, 98, 126, 70, },
      { -1, -1, -1, -1, -1, -1, -1, 162, },
      { -1, -1, -1, -1, -1, -1, -1, -1, },
      { -1, -1, -1, -1, -1, -1, -1, -1, },
  },
  // NEAREST_NEWMV,
  {
      { -1, -1, -1, -1, -1, -1, -1, -1, },
      { -1, -1, 143, 150, 157, 79, 107, 51, },
      { -1, -1, -1, -1, -1, 86, 114, 58, },
      { -1, -1, -1, -1, -1, 93, 121, 65, },
      { -1, -1, -1, -1, -1, 100, 128, 72, },
      { -1, -1, -1, -1, -1, -1, -1, 164, },
      { -1, -1, -1, -1, -1, -1, -1, -1, },
      { -1, -1, -1, -1, -1, -1, -1, -1, },
  },
  // NEW_NEARESTMV,
  {
      { -1, -1, -1, -1, -1, -1, -1, -1, },
      { -1, -1, 142, 149, 156, 78, 106, 50, },
      { -1, -1, -1, -1, -1, 85, 113, 57, },
      { -1, -1, -1, -1, -1, 92, 120, 64, },
      { -1, -1, -1, -1, -1, 99, 127, 71, },
      { -1, -1, -1, -1, -1, -1, -1, 163, },
      { -1, -1, -1, -1, -1, -1, -1, -1, },
      { -1, -1, -1, -1, -1, -1, -1, -1, },
  },
  // NEAR_NEWMV,
  {
      { -1, -1, -1, -1, -1, -1, -1, -1, },
      { -1, -1, 145, 152, 159, 81, 109, 53, },
      { -1, -1, -1, -1, -1, 88, 116, 60, },
      { -1, -1, -1, -1, -1, 95, 123, 67, },
      { -1, -1, -1, -1, -1, 102, 130, 74, },
      { -1, -1, -1, -1, -1, -1, -1, 166, },
      { -1, -1, -1, -1, -1, -1, -1, -1, },
      { -1, -1, -1, -1, -1, -1, -1, -1, },
  },
  // NEW_NEARMV,
  {
      { -1, -1, -1, -1, -1, -1, -1, -1, },
      { -1, -1, 144, 151, 158, 80, 108, 52, },
      { -1, -1, -1, -1, -1, 87, 115, 59, },
      { -1, -1, -1, -1, -1, 94, 122, 66, },
      { -1, -1, -1, -1, -1, 101, 129, 73, },
      { -1, -1, -1, -1, -1, -1, -1, 165, },
      { -1, -1, -1, -1, -1, -1, -1, -1, },
      { -1, -1, -1, -1, -1, -1, -1, -1, },
  },
  // GLOBAL_GLOBALMV,
  {
      { -1, -1, -1, -1, -1, -1, -1, -1, },
      { -1, -1, 147, 154, 161, 83, 111, 55, },
      { -1, -1, -1, -1, -1, 90, 118, 62, },
      { -1, -1, -1, -1, -1, 97, 125, 69, },
      { -1, -1, -1, -1, -1, 104, 132, 76, },
      { -1, -1, -1, -1, -1, -1, -1, 168, },
      { -1, -1, -1, -1, -1, -1, -1, -1, },
      { -1, -1, -1, -1, -1, -1, -1, -1, },
  },
  // NEW_NEWMV,
  {
      { -1, -1, -1, -1, -1, -1, -1, -1, },
      { -1, -1, 146, 153, 160, 82, 110, 54, },
      { -1, -1, -1, -1, -1, 89, 117, 61, },
      { -1, -1, -1, -1, -1, 96, 124, 68, },
      { -1, -1, -1, -1, -1, 103, 131, 75, },
      { -1, -1, -1, -1, -1, -1, -1, 167, },
      { -1, -1, -1, -1, -1, -1, -1, -1, },
      { -1, -1, -1, -1, -1, -1, -1, -1, },
  },
};
/* clang-format on */

static int get_prediction_mode_idx(PREDICTION_MODE this_mode,
                                   MV_REFERENCE_FRAME ref_frame,
                                   MV_REFERENCE_FRAME second_ref_frame) {
  if (this_mode < INTRA_MODE_END) {
    assert(ref_frame == INTRA_FRAME);
    assert(second_ref_frame == NONE_FRAME);
    return intra_to_mode_idx[this_mode - INTRA_MODE_START];
  }
  if (this_mode >= SINGLE_INTER_MODE_START &&
      this_mode < SINGLE_INTER_MODE_END) {
    assert((ref_frame > INTRA_FRAME) && (ref_frame <= ALTREF_FRAME));
    return single_inter_to_mode_idx[this_mode - SINGLE_INTER_MODE_START]
                                   [ref_frame];
  }
  if (this_mode >= COMP_INTER_MODE_START && this_mode < COMP_INTER_MODE_END) {
    assert((ref_frame > INTRA_FRAME) && (ref_frame <= ALTREF_FRAME));
    assert((second_ref_frame > INTRA_FRAME) &&
           (second_ref_frame <= ALTREF_FRAME));
    return comp_inter_to_mode_idx[this_mode - COMP_INTER_MODE_START][ref_frame]
                                 [second_ref_frame];
  }
  assert(0);
  return -1;
}
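
// Worked examples of the lookup above (the returned indices appear to follow
// the THR_* mode ordering used by the mode-threshold tables, not the order
// of av1_mode_order):
//   get_prediction_mode_idx(NEARESTMV, LAST_FRAME, NONE_FRAME)           -> 0
//   get_prediction_mode_idx(DC_PRED, INTRA_FRAME, NONE_FRAME)            -> 7
//   get_prediction_mode_idx(NEAREST_NEARESTMV, LAST_FRAME, ALTREF_FRAME) -> 29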

static const PREDICTION_MODE intra_rd_search_mode_order[INTRA_MODES] = {
  DC_PRED,       H_PRED,        V_PRED,    SMOOTH_PRED, PAETH_PRED,
  SMOOTH_V_PRED, SMOOTH_H_PRED, D135_PRED, D203_PRED,   D157_PRED,
  D67_PRED,      D113_PRED,     D45_PRED,
};

static const UV_PREDICTION_MODE uv_rd_search_mode_order[UV_INTRA_MODES] = {
  UV_DC_PRED,     UV_CFL_PRED,   UV_H_PRED,        UV_V_PRED,
  UV_SMOOTH_PRED, UV_PAETH_PRED, UV_SMOOTH_V_PRED, UV_SMOOTH_H_PRED,
  UV_D135_PRED,   UV_D203_PRED,  UV_D157_PRED,     UV_D67_PRED,
  UV_D113_PRED,   UV_D45_PRED,
};

typedef struct SingleInterModeState {
  int64_t rd;
  MV_REFERENCE_FRAME ref_frame;
  int valid;
} SingleInterModeState;

typedef struct InterModeSearchState {
  int64_t best_rd;
  MB_MODE_INFO best_mbmode;
  int best_rate_y;
  int best_rate_uv;
  int best_mode_skippable;
  int best_skip2;
  int best_mode_index;
  int skip_intra_modes;
  int num_available_refs;
  int64_t dist_refs[REF_FRAMES];
  int dist_order_refs[REF_FRAMES];
  int64_t mode_threshold[MAX_MODES];
  PREDICTION_MODE best_intra_mode;
  int64_t best_intra_rd;
  int angle_stats_ready;
  uint8_t directional_mode_skip_mask[INTRA_MODES];
  unsigned int best_pred_sse;
  int rate_uv_intra[TX_SIZES_ALL];
  int rate_uv_tokenonly[TX_SIZES_ALL];
  int64_t dist_uvs[TX_SIZES_ALL];
  int skip_uvs[TX_SIZES_ALL];
  UV_PREDICTION_MODE mode_uv[TX_SIZES_ALL];
  PALETTE_MODE_INFO pmi_uv[TX_SIZES_ALL];
  int8_t uv_angle_delta[TX_SIZES_ALL];
  int64_t best_pred_rd[REFERENCE_MODES];
  int64_t best_pred_diff[REFERENCE_MODES];
  // Save a set of single_newmv for each checked ref_mv.
  int_mv single_newmv[MAX_REF_MV_SERCH][REF_FRAMES];
  int single_newmv_rate[MAX_REF_MV_SERCH][REF_FRAMES];
  int single_newmv_valid[MAX_REF_MV_SERCH][REF_FRAMES];
  int64_t modelled_rd[MB_MODE_COUNT][MAX_REF_MV_SERCH][REF_FRAMES];
  // The rd of simple translation in single inter modes
  int64_t simple_rd[MB_MODE_COUNT][MAX_REF_MV_SERCH][REF_FRAMES];

  // Single search results by [directions][modes][reference frames]
  SingleInterModeState single_state[2][SINGLE_INTER_MODE_NUM][FWD_REFS];
  int single_state_cnt[2][SINGLE_INTER_MODE_NUM];
  SingleInterModeState single_state_modelled[2][SINGLE_INTER_MODE_NUM]
                                            [FWD_REFS];
  int single_state_modelled_cnt[2][SINGLE_INTER_MODE_NUM];

  MV_REFERENCE_FRAME single_rd_order[2][SINGLE_INTER_MODE_NUM][FWD_REFS];
} InterModeSearchState;

#if CONFIG_COLLECT_INTER_MODE_RD_STATS
int inter_mode_data_block_idx(BLOCK_SIZE bsize) {
  if (bsize == BLOCK_8X8) return 1;
  if (bsize == BLOCK_16X16) return 2;
  if (bsize == BLOCK_32X32) return 3;
  return -1;
}

void av1_inter_mode_data_init(TileDataEnc *tile_data) {
  for (int i = 0; i < BLOCK_SIZES_ALL; ++i) {
    InterModeRdModel *md = &tile_data->inter_mode_rd_models[i];
    md->ready = 0;
    md->num = 0;
    md->dist_sum = 0;
    md->ld_sum = 0;
    md->sse_sum = 0;
    md->sse_sse_sum = 0;
    md->sse_ld_sum = 0;
  }
}

static int get_est_rate_dist(TileDataEnc *tile_data, BLOCK_SIZE bsize,
                             int64_t sse, int *est_residue_cost,
                             int64_t *est_dist) {
  aom_clear_system_state();
  const InterModeRdModel *md = &tile_data->inter_mode_rd_models[bsize];
  if (md->ready) {
    const double est_ld = md->a * sse + md->b;
    if (sse < md->dist_mean) {
      *est_residue_cost = 0;
      *est_dist = sse;
    } else {
      *est_residue_cost = (int)round((sse - md->dist_mean) / est_ld);
      *est_dist = (int64_t)round(md->dist_mean);
    }
    return 1;
  }
  return 0;
}
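
// The estimate above inverts a per-block-size linear model ld ~= a * sse + b
// fitted in av1_inter_mode_data_fit() below, where ld is the distortion
// reduction per unit rate, (sse - dist) / rate. With the distortion pinned at
// the running mean, the residue rate estimate is
// (sse - dist_mean) / (a * sse + b); for sse below dist_mean the model would
// go negative, so the code falls back to rate 0 and dist = sse.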

static int64_t get_est_rd(TileDataEnc *tile_data, BLOCK_SIZE bsize, int rdmult,
                          int64_t sse, int curr_cost) {
  int est_residue_cost;
  int64_t est_dist;
  if (get_est_rate_dist(tile_data, bsize, sse, &est_residue_cost, &est_dist)) {
    int rate = est_residue_cost + curr_cost;
    int64_t est_rd = RDCOST(rdmult, rate, est_dist);
    return est_rd;
  }
  return 0;
}

void av1_inter_mode_data_fit(TileDataEnc *tile_data, int rdmult) {
  aom_clear_system_state();
  for (int bsize = 0; bsize < BLOCK_SIZES_ALL; ++bsize) {
    const int block_idx = inter_mode_data_block_idx(bsize);
    InterModeRdModel *md = &tile_data->inter_mode_rd_models[bsize];
    if (block_idx == -1) continue;
    if ((md->ready == 0 && md->num < 200) || (md->ready == 1 && md->num < 64)) {
      continue;
    } else {
      if (md->ready == 0) {
        md->dist_mean = md->dist_sum / md->num;
        md->ld_mean = md->ld_sum / md->num;
        md->sse_mean = md->sse_sum / md->num;
        md->sse_sse_mean = md->sse_sse_sum / md->num;
        md->sse_ld_mean = md->sse_ld_sum / md->num;
      } else {
        const double factor = 3;
        md->dist_mean =
            (md->dist_mean * factor + (md->dist_sum / md->num)) / (factor + 1);
        md->ld_mean =
            (md->ld_mean * factor + (md->ld_sum / md->num)) / (factor + 1);
        md->sse_mean =
            (md->sse_mean * factor + (md->sse_sum / md->num)) / (factor + 1);
        md->sse_sse_mean =
            (md->sse_sse_mean * factor + (md->sse_sse_sum / md->num)) /
            (factor + 1);
        md->sse_ld_mean =
            (md->sse_ld_mean * factor + (md->sse_ld_sum / md->num)) /
            (factor + 1);
      }

      const double my = md->ld_mean;
      const double mx = md->sse_mean;
      const double dx = sqrt(md->sse_sse_mean);
      const double dxy = md->sse_ld_mean;

      md->a = (dxy - mx * my) / (dx * dx - mx * mx);
      md->b = my - md->a * mx;
      md->ready = 1;

      md->num = 0;
      md->dist_sum = 0;
      md->ld_sum = 0;
      md->sse_sum = 0;
      md->sse_sse_sum = 0;
      md->sse_ld_sum = 0;
    }
    (void)rdmult;
  }
}
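
// The fit above is an ordinary least-squares regression of ld on sse over the
// samples gathered since the last refit:
//   a = (E[sse * ld] - E[sse] * E[ld]) / (E[sse^2] - E[sse]^2)
//   b = E[ld] - a * E[sse]
// Once the model is warm, the running means are blended 3:1 toward the
// previous fit. Note that dx is sqrt(E[sse^2]), so dx * dx - mx * mx is the
// variance of sse.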

static void inter_mode_data_push(TileDataEnc *tile_data, BLOCK_SIZE bsize,
                                 int64_t sse, int64_t dist, int residue_cost) {
  if (residue_cost == 0 || sse == dist) return;
  const int block_idx = inter_mode_data_block_idx(bsize);
  if (block_idx == -1) return;
  InterModeRdModel *rd_model = &tile_data->inter_mode_rd_models[bsize];
  if (rd_model->num < INTER_MODE_RD_DATA_OVERALL_SIZE) {
    aom_clear_system_state();
    const double ld = (sse - dist) * 1. / residue_cost;
    ++rd_model->num;
    rd_model->dist_sum += dist;
    rd_model->ld_sum += ld;
    rd_model->sse_sum += sse;
    rd_model->sse_sse_sum += sse * sse;
    rd_model->sse_ld_sum += sse * ld;
  }
}

static void inter_modes_info_push(InterModesInfo *inter_modes_info,
                                  int mode_rate, int64_t sse, int64_t est_rd,
                                  const MB_MODE_INFO *mbmi) {
  const int num = inter_modes_info->num;
  assert(num < MAX_INTER_MODES);
  inter_modes_info->mbmi_arr[num] = *mbmi;
  inter_modes_info->mode_rate_arr[num] = mode_rate;
  inter_modes_info->sse_arr[num] = sse;
  inter_modes_info->est_rd_arr[num] = est_rd;
  ++inter_modes_info->num;
}

static int compare_rd_idx_pair(const void *a, const void *b) {
  if (((const RdIdxPair *)a)->rd == ((const RdIdxPair *)b)->rd) {
    return 0;
  } else if (((const RdIdxPair *)a)->rd > ((const RdIdxPair *)b)->rd) {
    return 1;
  } else {
    return -1;
  }
}

static void inter_modes_info_sort(const InterModesInfo *inter_modes_info,
                                  RdIdxPair *rd_idx_pair_arr) {
  if (inter_modes_info->num == 0) {
    return;
  }
  for (int i = 0; i < inter_modes_info->num; ++i) {
    rd_idx_pair_arr[i].idx = i;
    rd_idx_pair_arr[i].rd = inter_modes_info->est_rd_arr[i];
  }
  qsort(rd_idx_pair_arr, inter_modes_info->num, sizeof(rd_idx_pair_arr[0]),
        compare_rd_idx_pair);
}
#endif  // CONFIG_COLLECT_INTER_MODE_RD_STATS

static INLINE int write_uniform_cost(int n, int v) {
  const int l = get_unsigned_bits(n);
  const int m = (1 << l) - n;
  if (l == 0) return 0;
  if (v < m)
    return av1_cost_literal(l - 1);
  else
    return av1_cost_literal(l);
}
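
// write_uniform_cost() prices a symbol drawn uniformly from n values with a
// truncated binary code: given l = floor(log2(n)) + 1 and m = 2^l - n, the
// first m values take l - 1 bits and the remaining n - m values take l bits.
// For example, n = 13 gives l = 4 and m = 3, so v in {0, 1, 2} costs
// av1_cost_literal(3) and v in {3, ..., 12} costs av1_cost_literal(4).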

// Similar to store_cfl_required(), but for use during the RDO process,
// where we haven't yet determined whether this block uses CfL.
static INLINE CFL_ALLOWED_TYPE store_cfl_required_rdo(const AV1_COMMON *cm,
                                                      const MACROBLOCK *x) {
  const MACROBLOCKD *xd = &x->e_mbd;

  if (cm->seq_params.monochrome || x->skip_chroma_rd) return CFL_DISALLOWED;

  if (!xd->cfl.is_chroma_reference) {
    // For non-chroma-reference blocks, we should always store the luma pixels,
    // in case the corresponding chroma-reference block uses CfL.
    // Note that this can only happen for block sizes which are <8 on
    // their shortest side, as otherwise they would be chroma reference
    // blocks.
    return CFL_ALLOWED;
  }

  // For chroma reference blocks, we should store data in the encoder iff we're
  // allowed to try out CfL.
  return is_cfl_allowed(xd);
}

// constants for prune 1 and prune 2 decision boundaries
#define FAST_EXT_TX_CORR_MID 0.0
#define FAST_EXT_TX_EDST_MID 0.1
#define FAST_EXT_TX_CORR_MARGIN 0.5
#define FAST_EXT_TX_EDST_MARGIN 0.3

static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
                           RD_STATS *rd_stats, BLOCK_SIZE bsize,
                           int64_t ref_best_rd, FAST_TX_SEARCH_MODE ftxs_mode);

static unsigned pixel_dist_visible_only(
    const AV1_COMP *const cpi, const MACROBLOCK *x, const uint8_t *src,
    const int src_stride, const uint8_t *dst, const int dst_stride,
    const BLOCK_SIZE tx_bsize, int txb_rows, int txb_cols, int visible_rows,
    int visible_cols) {
  unsigned sse;

  if (txb_rows == visible_rows && txb_cols == visible_cols) {
    cpi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
    return sse;
  }
  const MACROBLOCKD *xd = &x->e_mbd;

  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
    uint64_t sse64 = aom_highbd_sse_odd_size(src, src_stride, dst, dst_stride,
                                             visible_cols, visible_rows);
    return (unsigned int)ROUND_POWER_OF_TWO(sse64, (xd->bd - 8) * 2);
  }
  sse = aom_sse_odd_size(src, src_stride, dst, dst_stride, visible_cols,
                         visible_rows);
  return sse;
}
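
// pixel_dist_visible_only() restricts the SSE to pixels inside the visible
// frame area: a fully visible block uses the per-block-size fn_ptr variance
// function (whose sse output parameter is what we return), while a partially
// visible one falls back to the odd-size SSE helpers on the visible
// sub-rectangle, with high-bitdepth results rounded back to an 8-bit scale.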

#if CONFIG_DIST_8X8
static uint64_t cdef_dist_8x8_16bit(uint16_t *dst, int dstride, uint16_t *src,
                                    int sstride, int coeff_shift) {
  uint64_t svar = 0;
  uint64_t dvar = 0;
  uint64_t sum_s = 0;
  uint64_t sum_d = 0;
  uint64_t sum_s2 = 0;
  uint64_t sum_d2 = 0;
  uint64_t sum_sd = 0;
  uint64_t dist = 0;

  int i, j;
  for (i = 0; i < 8; i++) {
    for (j = 0; j < 8; j++) {
      sum_s += src[i * sstride + j];
      sum_d += dst[i * dstride + j];
      sum_s2 += src[i * sstride + j] * src[i * sstride + j];
      sum_d2 += dst[i * dstride + j] * dst[i * dstride + j];
      sum_sd += src[i * sstride + j] * dst[i * dstride + j];
    }
  }
  /* Compute the variance -- the calculation cannot go negative. */
  svar = sum_s2 - ((sum_s * sum_s + 32) >> 6);
  dvar = sum_d2 - ((sum_d * sum_d + 32) >> 6);

  // Tuning of jm's original dering distortion metric used in CDEF tool,
  // suggested by jm
  const uint64_t a = 4;
  const uint64_t b = 2;
  const uint64_t c1 = (400 * a << 2 * coeff_shift);
  const uint64_t c2 = (b * 20000 * a * a << 4 * coeff_shift);

  dist = (uint64_t)floor(.5 + (sum_d2 + sum_s2 - 2 * sum_sd) * .5 *
                                  (svar + dvar + c1) /
                                  (sqrt(svar * (double)dvar + c2)));

  // Calibrate dist to have similar rate for the same QP with MSE only
  // distortion (as in master branch)
  dist = (uint64_t)((float)dist * 0.75);

  return dist;
}
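
// In cdef_dist_8x8_16bit() above, the 8x8 block has 64 pixels, so svar and
// dvar follow the identity N * Var = sum(x^2) - sum(x)^2 / N with N = 64;
// the "+ 32" before ">> 6" makes the division by 64 a rounding one.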

static int od_compute_var_4x4(uint16_t *x, int stride) {
  int sum;
  int s2;
  int i;
  sum = 0;
  s2 = 0;
  for (i = 0; i < 4; i++) {
    int j;
    for (j = 0; j < 4; j++) {
      int t;

      t = x[i * stride + j];
      sum += t;
      s2 += t * t;
    }
  }

  return (s2 - (sum * sum >> 4)) >> 4;
}
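
// od_compute_var_4x4() applies the same identity at 4x4 scale and also
// divides by the 16-pixel count, returning the (truncated) per-pixel
// variance: (sum(x^2) - sum(x)^2 / 16) / 16.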

/* OD_DIST_LP_MID controls the frequency weighting filter used for computing
   the distortion. For a value X, the filter is [1 X 1]/(X + 2) and
   is applied both horizontally and vertically. For X=5, the filter is
   a good approximation for the OD_QM8_Q4_HVS quantization matrix. */
#define OD_DIST_LP_MID (5)
#define OD_DIST_LP_NORM (OD_DIST_LP_MID + 2)

static double od_compute_dist_8x8(int use_activity_masking, uint16_t *x,
                                  uint16_t *y, od_coeff *e_lp, int stride) {
  double sum;
  int min_var;
  double mean_var;
  double var_stat;
  double activity;
  double calibration;
  int i;
  int j;
  double vardist;

  vardist = 0;

#if 1
  min_var = INT_MAX;
  mean_var = 0;
  for (i = 0; i < 3; i++) {
    for (j = 0; j < 3; j++) {
      int varx;
      int vary;
      varx = od_compute_var_4x4(x + 2 * i * stride + 2 * j, stride);
      vary = od_compute_var_4x4(y + 2 * i * stride + 2 * j, stride);
      min_var = OD_MINI(min_var, varx);
      mean_var += 1. / (1 + varx);
      /* The cast to (double) is to avoid an overflow before the sqrt.*/
      vardist += varx - 2 * sqrt(varx * (double)vary) + vary;
    }
  }
  /* We use a different variance statistic depending on whether activity
     masking is used, since the harmonic mean appeared slightly worse with
     masking off. The calibration constant just ensures that we preserve the
     rate compared to activity=1. */
  if (use_activity_masking) {
    calibration = 1.95;
    var_stat = 9. / mean_var;
  } else {
    calibration = 1.62;
    var_stat = min_var;
  }
  /* 1.62 is a calibration constant, 0.25 is a noise floor and 1/6 is the
     activity masking constant. */
  activity = calibration * pow(.25 + var_stat, -1. / 6);
#else
  activity = 1;
#endif  // 1
  sum = 0;
  for (i = 0; i < 8; i++) {
    for (j = 0; j < 8; j++)
      sum += e_lp[i * stride + j] * (double)e_lp[i * stride + j];
  }
  /* Normalize the filter to unit DC response. */
  sum *= 1. / (OD_DIST_LP_NORM * OD_DIST_LP_NORM * OD_DIST_LP_NORM *
               OD_DIST_LP_NORM);
  return activity * activity * (sum + vardist);
}

// Note: Inputs x and y are in a pixel domain
static double od_compute_dist_common(int activity_masking, uint16_t *x,
                                     uint16_t *y, int bsize_w, int bsize_h,
                                     int qindex, od_coeff *tmp,
                                     od_coeff *e_lp) {
  int i, j;
  double sum = 0;
  const int mid = OD_DIST_LP_MID;

  for (j = 0; j < bsize_w; j++) {
    e_lp[j] = mid * tmp[j] + 2 * tmp[bsize_w + j];
    e_lp[(bsize_h - 1) * bsize_w + j] = mid * tmp[(bsize_h - 1) * bsize_w + j] +
                                        2 * tmp[(bsize_h - 2) * bsize_w + j];
  }
  for (i = 1; i < bsize_h - 1; i++) {
    for (j = 0; j < bsize_w; j++) {
      e_lp[i * bsize_w + j] = mid * tmp[i * bsize_w + j] +
                              tmp[(i - 1) * bsize_w + j] +
                              tmp[(i + 1) * bsize_w + j];
    }
  }
  for (i = 0; i < bsize_h; i += 8) {
    for (j = 0; j < bsize_w; j += 8) {
      sum += od_compute_dist_8x8(activity_masking, &x[i * bsize_w + j],
                                 &y[i * bsize_w + j], &e_lp[i * bsize_w + j],
                                 bsize_w);
    }
  }
  /* Scale according to linear regression against SSE, for 8x8 blocks. */
  if (activity_masking) {
    sum *= 2.2 + (1.7 - 2.2) * (qindex - 99) / (210 - 99) +
           (qindex < 99 ? 2.5 * (qindex - 99) / 99 * (qindex - 99) / 99 : 0);
  } else {
    sum *= qindex >= 128
               ? 1.4 + (0.9 - 1.4) * (qindex - 128) / (209 - 128)
               : qindex <= 43 ? 1.5 + (2.0 - 1.5) * (qindex - 43) / (16 - 43)
                              : 1.5 + (1.4 - 1.5) * (qindex - 43) / (128 - 43);
  }

  return sum;
}

static double od_compute_dist(uint16_t *x, uint16_t *y, int bsize_w,
                              int bsize_h, int qindex) {
  assert(bsize_w >= 8 && bsize_h >= 8);

  int activity_masking = 0;

  int i, j;
  DECLARE_ALIGNED(16, od_coeff, e[MAX_SB_SQUARE]);
  DECLARE_ALIGNED(16, od_coeff, tmp[MAX_SB_SQUARE]);
  DECLARE_ALIGNED(16, od_coeff, e_lp[MAX_SB_SQUARE]);
  for (i = 0; i < bsize_h; i++) {
    for (j = 0; j < bsize_w; j++) {
      e[i * bsize_w + j] = x[i * bsize_w + j] - y[i * bsize_w + j];
    }
  }
  int mid = OD_DIST_LP_MID;
  for (i = 0; i < bsize_h; i++) {
    tmp[i * bsize_w] = mid * e[i * bsize_w] + 2 * e[i * bsize_w + 1];
    tmp[i * bsize_w + bsize_w - 1] =
        mid * e[i * bsize_w + bsize_w - 1] + 2 * e[i * bsize_w + bsize_w - 2];
    for (j = 1; j < bsize_w - 1; j++) {
      tmp[i * bsize_w + j] = mid * e[i * bsize_w + j] + e[i * bsize_w + j - 1] +
                             e[i * bsize_w + j + 1];
    }
  }
  return od_compute_dist_common(activity_masking, x, y, bsize_w, bsize_h,
                                qindex, tmp, e_lp);
}

static double od_compute_dist_diff(uint16_t *x, int16_t *e, int bsize_w,
                                   int bsize_h, int qindex) {
  assert(bsize_w >= 8 && bsize_h >= 8);

  int activity_masking = 0;

  DECLARE_ALIGNED(16, uint16_t, y[MAX_SB_SQUARE]);
  DECLARE_ALIGNED(16, od_coeff, tmp[MAX_SB_SQUARE]);
  DECLARE_ALIGNED(16, od_coeff, e_lp[MAX_SB_SQUARE]);
  int i, j;
  for (i = 0; i < bsize_h; i++) {
    for (j = 0; j < bsize_w; j++) {
      y[i * bsize_w + j] = x[i * bsize_w + j] - e[i * bsize_w + j];
    }
  }
  int mid = OD_DIST_LP_MID;
  for (i = 0; i < bsize_h; i++) {
    tmp[i * bsize_w] = mid * e[i * bsize_w] + 2 * e[i * bsize_w + 1];
    tmp[i * bsize_w + bsize_w - 1] =
        mid * e[i * bsize_w + bsize_w - 1] + 2 * e[i * bsize_w + bsize_w - 2];
    for (j = 1; j < bsize_w - 1; j++) {
      tmp[i * bsize_w + j] = mid * e[i * bsize_w + j] + e[i * bsize_w + j - 1] +
                             e[i * bsize_w + j + 1];
    }
  }
  return od_compute_dist_common(activity_masking, x, y, bsize_w, bsize_h,
                                qindex, tmp, e_lp);
}

int64_t av1_dist_8x8(const AV1_COMP *const cpi, const MACROBLOCK *x,
                     const uint8_t *src, int src_stride, const uint8_t *dst,
                     int dst_stride, const BLOCK_SIZE tx_bsize, int bsw,
                     int bsh, int visible_w, int visible_h, int qindex) {
  int64_t d = 0;
  int i, j;
  const MACROBLOCKD *xd = &x->e_mbd;

  DECLARE_ALIGNED(16, uint16_t, orig[MAX_SB_SQUARE]);
  DECLARE_ALIGNED(16, uint16_t, rec[MAX_SB_SQUARE]);

  assert(bsw >= 8);
  assert(bsh >= 8);
  assert((bsw & 0x07) == 0);
  assert((bsh & 0x07) == 0);

  if (x->tune_metric == AOM_TUNE_CDEF_DIST ||
      x->tune_metric == AOM_TUNE_DAALA_DIST) {
    if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
      for (j = 0; j < bsh; j++)
        for (i = 0; i < bsw; i++)
          orig[j * bsw + i] = CONVERT_TO_SHORTPTR(src)[j * src_stride + i];

      if ((bsw == visible_w) && (bsh == visible_h)) {
        for (j = 0; j < bsh; j++)
          for (i = 0; i < bsw; i++)
            rec[j * bsw + i] = CONVERT_TO_SHORTPTR(dst)[j * dst_stride + i];
      } else {
        for (j = 0; j < visible_h; j++)
          for (i = 0; i < visible_w; i++)
            rec[j * bsw + i] = CONVERT_TO_SHORTPTR(dst)[j * dst_stride + i];

        if (visible_w < bsw) {
          for (j = 0; j < bsh; j++)
            for (i = visible_w; i < bsw; i++)
              rec[j * bsw + i] = CONVERT_TO_SHORTPTR(src)[j * src_stride + i];
        }

        if (visible_h < bsh) {
          for (j = visible_h; j < bsh; j++)
            for (i = 0; i < bsw; i++)
              rec[j * bsw + i] = CONVERT_TO_SHORTPTR(src)[j * src_stride + i];
        }
      }
    } else {
      for (j = 0; j < bsh; j++)
        for (i = 0; i < bsw; i++) orig[j * bsw + i] = src[j * src_stride + i];

      if ((bsw == visible_w) && (bsh == visible_h)) {
        for (j = 0; j < bsh; j++)
          for (i = 0; i < bsw; i++) rec[j * bsw + i] = dst[j * dst_stride + i];
      } else {
        for (j = 0; j < visible_h; j++)
          for (i = 0; i < visible_w; i++)
            rec[j * bsw + i] = dst[j * dst_stride + i];

        if (visible_w < bsw) {
          for (j = 0; j < bsh; j++)
            for (i = visible_w; i < bsw; i++)
              rec[j * bsw + i] = src[j * src_stride + i];
        }

        if (visible_h < bsh) {
          for (j = visible_h; j < bsh; j++)
            for (i = 0; i < bsw; i++)
              rec[j * bsw + i] = src[j * src_stride + i];
        }
      }
    }
  }

  if (x->tune_metric == AOM_TUNE_DAALA_DIST) {
    d = (int64_t)od_compute_dist(orig, rec, bsw, bsh, qindex);
  } else if (x->tune_metric == AOM_TUNE_CDEF_DIST) {
    int coeff_shift = AOMMAX(xd->bd - 8, 0);

    for (i = 0; i < bsh; i += 8) {
      for (j = 0; j < bsw; j += 8) {
        d += cdef_dist_8x8_16bit(&rec[i * bsw + j], bsw, &orig[i * bsw + j],
                                 bsw, coeff_shift);
      }
    }
    if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
      d = ((uint64_t)d) >> 2 * coeff_shift;
  } else {
    // Otherwise, MSE by default
    d = pixel_dist_visible_only(cpi, x, src, src_stride, dst, dst_stride,
                                tx_bsize, bsh, bsw, visible_h, visible_w);
  }

  return d;
}

static int64_t dist_8x8_diff(const MACROBLOCK *x, const uint8_t *src,
                             int src_stride, const int16_t *diff,
                             int diff_stride, int bsw, int bsh, int visible_w,
                             int visible_h, int qindex) {
  int64_t d = 0;
  int i, j;
  const MACROBLOCKD *xd = &x->e_mbd;

  DECLARE_ALIGNED(16, uint16_t, orig[MAX_SB_SQUARE]);
  DECLARE_ALIGNED(16, int16_t, diff16[MAX_SB_SQUARE]);

  assert(bsw >= 8);
  assert(bsh >= 8);
  assert((bsw & 0x07) == 0);
  assert((bsh & 0x07) == 0);

  if (x->tune_metric == AOM_TUNE_CDEF_DIST ||
      x->tune_metric == AOM_TUNE_DAALA_DIST) {
    if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
      for (j = 0; j < bsh; j++)
        for (i = 0; i < bsw; i++)
          orig[j * bsw + i] = CONVERT_TO_SHORTPTR(src)[j * src_stride + i];
    } else {
      for (j = 0; j < bsh; j++)
        for (i = 0; i < bsw; i++) orig[j * bsw + i] = src[j * src_stride + i];
    }

    if ((bsw == visible_w) && (bsh == visible_h)) {
      for (j = 0; j < bsh; j++)
        for (i = 0; i < bsw; i++)
          diff16[j * bsw + i] = diff[j * diff_stride + i];
    } else {
      for (j = 0; j < visible_h; j++)
        for (i = 0; i < visible_w; i++)
          diff16[j * bsw + i] = diff[j * diff_stride + i];

      if (visible_w < bsw) {
        for (j = 0; j < bsh; j++)
          for (i = visible_w; i < bsw; i++) diff16[j * bsw + i] = 0;
      }

      if (visible_h < bsh) {
        for (j = visible_h; j < bsh; j++)
          for (i = 0; i < bsw; i++) diff16[j * bsw + i] = 0;
      }
    }
  }

  if (x->tune_metric == AOM_TUNE_DAALA_DIST) {
    d = (int64_t)od_compute_dist_diff(orig, diff16, bsw, bsh, qindex);
  } else if (x->tune_metric == AOM_TUNE_CDEF_DIST) {
    int coeff_shift = AOMMAX(xd->bd - 8, 0);
    DECLARE_ALIGNED(16, uint16_t, dst16[MAX_SB_SQUARE]);

    for (i = 0; i < bsh; i++) {
      for (j = 0; j < bsw; j++) {
        dst16[i * bsw + j] = orig[i * bsw + j] - diff16[i * bsw + j];
      }
    }

    for (i = 0; i < bsh; i += 8) {
      for (j = 0; j < bsw; j += 8) {
        d += cdef_dist_8x8_16bit(&dst16[i * bsw + j], bsw, &orig[i * bsw + j],
                                 bsw, coeff_shift);
      }
    }
    // Don't scale 'd' for HBD since it will be done by caller side for diff
    // input
  } else {
    // Otherwise, MSE by default
    d = aom_sum_squares_2d_i16(diff, diff_stride, visible_w, visible_h);
  }

  return d;
}
#endif  // CONFIG_DIST_8X8

static void get_energy_distribution_fine(const AV1_COMP *cpi, BLOCK_SIZE bsize,
                                         const uint8_t *src, int src_stride,
                                         const uint8_t *dst, int dst_stride,
                                         int need_4th, double *hordist,
                                         double *verdist) {
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  unsigned int esq[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };

  if (bsize < BLOCK_16X16 || (bsize >= BLOCK_4X16 && bsize <= BLOCK_32X8)) {
    // Special cases: calculate 'esq' values manually, as we don't have 'vf'
    // functions for the 16 (very small) sub-blocks of this block.
    const int w_shift = (bw == 4) ? 0 : (bw == 8) ? 1 : (bw == 16) ? 2 : 3;
    const int h_shift = (bh == 4) ? 0 : (bh == 8) ? 1 : (bh == 16) ? 2 : 3;
    assert(bw <= 32);
    assert(bh <= 32);
    assert(((bw - 1) >> w_shift) + (((bh - 1) >> h_shift) << 2) == 15);
    if (cpi->common.seq_params.use_highbitdepth) {
      const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
      const uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst);
      for (int i = 0; i < bh; ++i)
        for (int j = 0; j < bw; ++j) {
          const int index = (j >> w_shift) + ((i >> h_shift) << 2);
          esq[index] +=
              (src16[j + i * src_stride] - dst16[j + i * dst_stride]) *
              (src16[j + i * src_stride] - dst16[j + i * dst_stride]);
        }
    } else {
      for (int i = 0; i < bh; ++i)
        for (int j = 0; j < bw; ++j) {
          const int index = (j >> w_shift) + ((i >> h_shift) << 2);
          esq[index] += (src[j + i * src_stride] - dst[j + i * dst_stride]) *
                        (src[j + i * src_stride] - dst[j + i * dst_stride]);
        }
    }
  } else {  // Calculate 'esq' values using 'vf' functions on the 16 sub-blocks.
    const int f_index =
        (bsize < BLOCK_SIZES) ? bsize - BLOCK_16X16 : bsize - BLOCK_8X16;
    assert(f_index >= 0 && f_index < BLOCK_SIZES_ALL);
    const BLOCK_SIZE subsize = (BLOCK_SIZE)f_index;
    assert(block_size_wide[bsize] == 4 * block_size_wide[subsize]);
    assert(block_size_high[bsize] == 4 * block_size_high[subsize]);
    cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[0]);
    cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
                            &esq[1]);
    cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
                            &esq[2]);
    cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                            dst_stride, &esq[3]);
    src += bh / 4 * src_stride;
    dst += bh / 4 * dst_stride;

    cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[4]);
    cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
                            &esq[5]);
    cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
                            &esq[6]);
    cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                            dst_stride, &esq[7]);
    src += bh / 4 * src_stride;
    dst += bh / 4 * dst_stride;

    cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[8]);
    cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
                            &esq[9]);
    cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
                            &esq[10]);
    cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                            dst_stride, &esq[11]);
    src += bh / 4 * src_stride;
    dst += bh / 4 * dst_stride;

    cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[12]);
    cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
                            &esq[13]);
    cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
                            &esq[14]);
    cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
                            dst_stride, &esq[15]);
  }

  double total = (double)esq[0] + esq[1] + esq[2] + esq[3] + esq[4] + esq[5] +
                 esq[6] + esq[7] + esq[8] + esq[9] + esq[10] + esq[11] +
                 esq[12] + esq[13] + esq[14] + esq[15];
  if (total > 0) {
    const double e_recip = 1.0 / total;
    hordist[0] = ((double)esq[0] + esq[4] + esq[8] + esq[12]) * e_recip;
    hordist[1] = ((double)esq[1] + esq[5] + esq[9] + esq[13]) * e_recip;
    hordist[2] = ((double)esq[2] + esq[6] + esq[10] + esq[14]) * e_recip;
    if (need_4th) {
      hordist[3] = ((double)esq[3] + esq[7] + esq[11] + esq[15]) * e_recip;
    }
    verdist[0] = ((double)esq[0] + esq[1] + esq[2] + esq[3]) * e_recip;
    verdist[1] = ((double)esq[4] + esq[5] + esq[6] + esq[7]) * e_recip;
    verdist[2] = ((double)esq[8] + esq[9] + esq[10] + esq[11]) * e_recip;
    if (need_4th) {
      verdist[3] = ((double)esq[12] + esq[13] + esq[14] + esq[15]) * e_recip;
    }
  } else {
    hordist[0] = verdist[0] = 0.25;
    hordist[1] = verdist[1] = 0.25;
    hordist[2] = verdist[2] = 0.25;
    if (need_4th) {
      hordist[3] = verdist[3] = 0.25;
    }
  }
}
1348 
adst_vs_flipadst(const AV1_COMP * cpi,BLOCK_SIZE bsize,const uint8_t * src,int src_stride,const uint8_t * dst,int dst_stride)1349 static int adst_vs_flipadst(const AV1_COMP *cpi, BLOCK_SIZE bsize,
1350                             const uint8_t *src, int src_stride,
1351                             const uint8_t *dst, int dst_stride) {
1352   int prune_bitmask = 0;
1353   double svm_proj_h = 0, svm_proj_v = 0;
1354   double hdist[3] = { 0, 0, 0 }, vdist[3] = { 0, 0, 0 };
1355   get_energy_distribution_fine(cpi, bsize, src, src_stride, dst, dst_stride, 0,
1356                                hdist, vdist);
1357 
1358   svm_proj_v = vdist[0] * ADST_FLIP_SVM[0] + vdist[1] * ADST_FLIP_SVM[1] +
1359                vdist[2] * ADST_FLIP_SVM[2] + ADST_FLIP_SVM[3];
1360   svm_proj_h = hdist[0] * ADST_FLIP_SVM[4] + hdist[1] * ADST_FLIP_SVM[5] +
1361                hdist[2] * ADST_FLIP_SVM[6] + ADST_FLIP_SVM[7];
1362   if (svm_proj_v > FAST_EXT_TX_EDST_MID + FAST_EXT_TX_EDST_MARGIN)
1363     prune_bitmask |= 1 << FLIPADST_1D;
1364   else if (svm_proj_v < FAST_EXT_TX_EDST_MID - FAST_EXT_TX_EDST_MARGIN)
1365     prune_bitmask |= 1 << ADST_1D;
1366 
1367   if (svm_proj_h > FAST_EXT_TX_EDST_MID + FAST_EXT_TX_EDST_MARGIN)
1368     prune_bitmask |= 1 << (FLIPADST_1D + 8);
1369   else if (svm_proj_h < FAST_EXT_TX_EDST_MID - FAST_EXT_TX_EDST_MARGIN)
1370     prune_bitmask |= 1 << (ADST_1D + 8);
1371 
1372   return prune_bitmask;
1373 }
1374 
get_horver_correlation(const int16_t * diff,int stride,int w,int h,double * hcorr,double * vcorr)1375 static void get_horver_correlation(const int16_t *diff, int stride, int w,
1376                                    int h, double *hcorr, double *vcorr) {
1377   // Returns hor/ver correlation coefficient
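  // Specifically, for each pixel x = diff[i][j] (i, j >= 1) with left
  // neighbor y = diff[i][j - 1] and top neighbor z = diff[i - 1][j], this
  // computes the Pearson correlation coefficients
  //   hcorr = cov(x, y) / sqrt(var(x) * var(y))
  //   vcorr = cov(x, z) / sqrt(var(x) * var(z))
  // over the (h - 1) * (w - 1) pixels that have both neighbors, clamping the
  // results to be non-negative.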
  const int num = (h - 1) * (w - 1);
  double num_r;
  int i, j;
  int64_t xy_sum = 0, xz_sum = 0;
  int64_t x_sum = 0, y_sum = 0, z_sum = 0;
  int64_t x2_sum = 0, y2_sum = 0, z2_sum = 0;
  double x_var_n, y_var_n, z_var_n, xy_var_n, xz_var_n;
  *hcorr = *vcorr = 1;

  assert(num > 0);
  num_r = 1.0 / num;
  for (i = 1; i < h; ++i) {
    for (j = 1; j < w; ++j) {
      const int16_t x = diff[i * stride + j];
      const int16_t y = diff[i * stride + j - 1];
      const int16_t z = diff[(i - 1) * stride + j];
      xy_sum += x * y;
      xz_sum += x * z;
      x_sum += x;
      y_sum += y;
      z_sum += z;
      x2_sum += x * x;
      y2_sum += y * y;
      z2_sum += z * z;
    }
  }
  x_var_n = x2_sum - (x_sum * x_sum) * num_r;
  y_var_n = y2_sum - (y_sum * y_sum) * num_r;
  z_var_n = z2_sum - (z_sum * z_sum) * num_r;
  xy_var_n = xy_sum - (x_sum * y_sum) * num_r;
  xz_var_n = xz_sum - (x_sum * z_sum) * num_r;
  if (x_var_n > 0 && y_var_n > 0) {
    *hcorr = xy_var_n / sqrt(x_var_n * y_var_n);
    *hcorr = *hcorr < 0 ? 0 : *hcorr;
  }
  if (x_var_n > 0 && z_var_n > 0) {
    *vcorr = xz_var_n / sqrt(x_var_n * z_var_n);
    *vcorr = *vcorr < 0 ? 0 : *vcorr;
  }
}

static int dct_vs_idtx(const int16_t *diff, int stride, int w, int h) {
  double hcorr, vcorr;
  int prune_bitmask = 0;
  get_horver_correlation(diff, stride, w, h, &hcorr, &vcorr);

  if (vcorr > FAST_EXT_TX_CORR_MID + FAST_EXT_TX_CORR_MARGIN)
    prune_bitmask |= 1 << IDTX_1D;
  else if (vcorr < FAST_EXT_TX_CORR_MID - FAST_EXT_TX_CORR_MARGIN)
    prune_bitmask |= 1 << DCT_1D;

  if (hcorr > FAST_EXT_TX_CORR_MID + FAST_EXT_TX_CORR_MARGIN)
    prune_bitmask |= 1 << (IDTX_1D + 8);
  else if (hcorr < FAST_EXT_TX_CORR_MID - FAST_EXT_TX_CORR_MARGIN)
    prune_bitmask |= 1 << (DCT_1D + 8);
  return prune_bitmask;
}

// Performance drop: 0.5%, Speed improvement: 24%
static int prune_two_for_sby(const AV1_COMP *cpi, BLOCK_SIZE bsize,
                             MACROBLOCK *x, const MACROBLOCKD *xd,
                             int adst_flipadst, int dct_idtx) {
  int prune = 0;

  if (adst_flipadst) {
    const struct macroblock_plane *const p = &x->plane[0];
    const struct macroblockd_plane *const pd = &xd->plane[0];
    prune |= adst_vs_flipadst(cpi, bsize, p->src.buf, p->src.stride,
                              pd->dst.buf, pd->dst.stride);
  }
  if (dct_idtx) {
    av1_subtract_plane(x, bsize, 0);
    const struct macroblock_plane *const p = &x->plane[0];
    const int bw = block_size_wide[bsize];
    const int bh = block_size_high[bsize];
    prune |= dct_vs_idtx(p->src_diff, bw, bw, bh);
  }

  return prune;
}

// Performance drop: 0.3%, Speed improvement: 5%
static int prune_one_for_sby(const AV1_COMP *cpi, BLOCK_SIZE bsize,
                             const MACROBLOCK *x, const MACROBLOCKD *xd) {
  const struct macroblock_plane *const p = &x->plane[0];
  const struct macroblockd_plane *const pd = &xd->plane[0];
  return adst_vs_flipadst(cpi, bsize, p->src.buf, p->src.stride, pd->dst.buf,
                          pd->dst.stride);
}

// 1D transforms used in the inter set. This needs to be changed if
// ext_tx_used_inter is changed.
static const int ext_tx_used_inter_1D[EXT_TX_SETS_INTER][TX_TYPES_1D] = {
  { 1, 0, 0, 0 },
  { 1, 1, 1, 1 },
  { 1, 1, 1, 1 },
  { 1, 0, 0, 1 },
};

static void get_energy_distribution_finer(const int16_t *diff, int stride,
                                          int bw, int bh, float *hordist,
                                          float *verdist) {
  // First compute downscaled block energy values (esq); downscale factors
  // are defined by w_shift and h_shift.
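  // Note that only the first (esq_w - 1) entries of 'hordist' and the first
  // (esq_h - 1) entries of 'verdist' are written here; the caller (see
  // prune_tx_2D) fills the last slot of each feature vector with a
  // correlation value instead.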
  unsigned int esq[256];
  const int w_shift = bw <= 8 ? 0 : 1;
  const int h_shift = bh <= 8 ? 0 : 1;
  const int esq_w = bw >> w_shift;
  const int esq_h = bh >> h_shift;
  const int esq_sz = esq_w * esq_h;
  int i, j;
  memset(esq, 0, esq_sz * sizeof(esq[0]));
  if (w_shift) {
    for (i = 0; i < bh; i++) {
      unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
      const int16_t *cur_diff_row = diff + i * stride;
      for (j = 0; j < bw; j += 2) {
        cur_esq_row[j >> 1] += (cur_diff_row[j] * cur_diff_row[j] +
                                cur_diff_row[j + 1] * cur_diff_row[j + 1]);
      }
    }
  } else {
    for (i = 0; i < bh; i++) {
      unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
      const int16_t *cur_diff_row = diff + i * stride;
      for (j = 0; j < bw; j++) {
        cur_esq_row[j] += cur_diff_row[j] * cur_diff_row[j];
      }
    }
  }

  uint64_t total = 0;
  for (i = 0; i < esq_sz; i++) total += esq[i];

  // Output hordist and verdist arrays are normalized 1D projections of esq.
  if (total == 0) {
    float hor_val = 1.0f / esq_w;
    for (j = 0; j < esq_w - 1; j++) hordist[j] = hor_val;
    float ver_val = 1.0f / esq_h;
    for (i = 0; i < esq_h - 1; i++) verdist[i] = ver_val;
    return;
  }

  const float e_recip = 1.0f / (float)total;
  memset(hordist, 0, (esq_w - 1) * sizeof(hordist[0]));
  memset(verdist, 0, (esq_h - 1) * sizeof(verdist[0]));
  const unsigned int *cur_esq_row;
  for (i = 0; i < esq_h - 1; i++) {
    cur_esq_row = esq + i * esq_w;
    for (j = 0; j < esq_w - 1; j++) {
      hordist[j] += (float)cur_esq_row[j];
      verdist[i] += (float)cur_esq_row[j];
    }
    verdist[i] += (float)cur_esq_row[j];
  }
  cur_esq_row = esq + i * esq_w;
  for (j = 0; j < esq_w - 1; j++) hordist[j] += (float)cur_esq_row[j];

  for (j = 0; j < esq_w - 1; j++) hordist[j] *= e_recip;
  for (i = 0; i < esq_h - 1; i++) verdist[i] *= e_recip;
}

// Similar to get_horver_correlation, but also takes into account the first
// row/column when computing horizontal/vertical correlation.
static void get_horver_correlation_full(const int16_t *diff, int stride, int w,
                                        int h, float *hcorr, float *vcorr) {
  const float num_hor = (float)(h * (w - 1));
  const float num_ver = (float)((h - 1) * w);
  int i, j;

  // The following notation is used:
  // x - current pixel
  // y - left neighbor pixel
  // z - top neighbor pixel
  int64_t xy_sum = 0, xz_sum = 0;
  int64_t xhor_sum = 0, xver_sum = 0, y_sum = 0, z_sum = 0;
  int64_t x2hor_sum = 0, x2ver_sum = 0, y2_sum = 0, z2_sum = 0;

  int16_t x, y, z;
  for (j = 1; j < w; ++j) {
    x = diff[j];
    y = diff[j - 1];
    xy_sum += x * y;
    xhor_sum += x;
    y_sum += y;
    x2hor_sum += x * x;
    y2_sum += y * y;
  }
  for (i = 1; i < h; ++i) {
    x = diff[i * stride];
    z = diff[(i - 1) * stride];
    xz_sum += x * z;
    xver_sum += x;
    z_sum += z;
    x2ver_sum += x * x;
    z2_sum += z * z;
    for (j = 1; j < w; ++j) {
      x = diff[i * stride + j];
      y = diff[i * stride + j - 1];
      z = diff[(i - 1) * stride + j];
      xy_sum += x * y;
      xz_sum += x * z;
      xhor_sum += x;
      xver_sum += x;
      y_sum += y;
      z_sum += z;
      x2hor_sum += x * x;
      x2ver_sum += x * x;
      y2_sum += y * y;
      z2_sum += z * z;
    }
  }
  const float xhor_var_n = x2hor_sum - (xhor_sum * xhor_sum) / num_hor;
  const float y_var_n = y2_sum - (y_sum * y_sum) / num_hor;
  const float xy_var_n = xy_sum - (xhor_sum * y_sum) / num_hor;
  const float xver_var_n = x2ver_sum - (xver_sum * xver_sum) / num_ver;
  const float z_var_n = z2_sum - (z_sum * z_sum) / num_ver;
  const float xz_var_n = xz_sum - (xver_sum * z_sum) / num_ver;

  *hcorr = *vcorr = 1;
  if (xhor_var_n > 0 && y_var_n > 0) {
    *hcorr = xy_var_n / sqrtf(xhor_var_n * y_var_n);
    *hcorr = *hcorr < 0 ? 0 : *hcorr;
  }
  if (xver_var_n > 0 && z_var_n > 0) {
    *vcorr = xz_var_n / sqrtf(xver_var_n * z_var_n);
    *vcorr = *vcorr < 0 ? 0 : *vcorr;
  }
}

// Transforms raw scores into a probability distribution across 16 TX types.
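// The mapping is
//   scores_2D[i] <- max(scores_2D[i] + shift, 0)^8,
// followed by normalization so that the 16 entries sum to 1 (a softmax-like
// sharpening that emphasizes the highest-scoring TX types).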
static void score_2D_transform_pow8(float *scores_2D, float shift) {
  float sum = 0.0f;
  int i;

  for (i = 0; i < 16; i++) {
    float v, v2, v4;
    v = AOMMAX(scores_2D[i] + shift, 0.0f);
    v2 = v * v;
    v4 = v2 * v2;
    scores_2D[i] = v4 * v4;
    sum += scores_2D[i];
  }
  for (i = 0; i < 16; i++) scores_2D[i] /= sum;
}

// These thresholds were calibrated to provide a certain number of TX types
// pruned by the model on average, i.e., selecting a threshold with index i
// will lead to pruning i+1 TX types on average.
static const float *prune_2D_adaptive_thresholds[] = {
  // TX_4X4
  (float[]){ 0.00549f, 0.01306f, 0.02039f, 0.02747f, 0.03406f, 0.04065f,
             0.04724f, 0.05383f, 0.06067f, 0.06799f, 0.07605f, 0.08533f,
             0.09778f, 0.11780f },
  // TX_8X8
  (float[]){ 0.00037f, 0.00183f, 0.00525f, 0.01038f, 0.01697f, 0.02502f,
             0.03381f, 0.04333f, 0.05286f, 0.06287f, 0.07434f, 0.08850f,
             0.10803f, 0.14124f },
  // TX_16X16
  (float[]){ 0.01404f, 0.02820f, 0.04211f, 0.05164f, 0.05798f, 0.06335f,
             0.06897f, 0.07629f, 0.08875f, 0.11169f },
  // TX_32X32
  NULL,
  // TX_64X64
  NULL,
  // TX_4X8
  (float[]){ 0.00183f, 0.00745f, 0.01428f, 0.02185f, 0.02966f, 0.03723f,
             0.04456f, 0.05188f, 0.05920f, 0.06702f, 0.07605f, 0.08704f,
             0.10168f, 0.12585f },
  // TX_8X4
  (float[]){ 0.00085f, 0.00476f, 0.01135f, 0.01892f, 0.02698f, 0.03528f,
             0.04358f, 0.05164f, 0.05994f, 0.06848f, 0.07849f, 0.09021f,
             0.10583f, 0.13123f },
  // TX_8X16
  (float[]){ 0.00037f, 0.00232f, 0.00671f, 0.01257f, 0.01965f, 0.02722f,
             0.03552f, 0.04382f, 0.05237f, 0.06189f, 0.07336f, 0.08728f,
             0.10730f, 0.14221f },
  // TX_16X8
  (float[]){ 0.00061f, 0.00330f, 0.00818f, 0.01453f, 0.02185f, 0.02966f,
             0.03772f, 0.04578f, 0.05383f, 0.06262f, 0.07288f, 0.08582f,
             0.10339f, 0.13464f },
  // TX_16X32
  NULL,
  // TX_32X16
  NULL,
  // TX_32X64
  NULL,
  // TX_64X32
  NULL,
  // TX_4X16
  (float[]){ 0.00232f, 0.00671f, 0.01257f, 0.01941f, 0.02673f, 0.03430f,
             0.04211f, 0.04968f, 0.05750f, 0.06580f, 0.07507f, 0.08655f,
             0.10242f, 0.12878f },
  // TX_16X4
  (float[]){ 0.00110f, 0.00525f, 0.01208f, 0.01990f, 0.02795f, 0.03601f,
             0.04358f, 0.05115f, 0.05896f, 0.06702f, 0.07629f, 0.08752f,
             0.10217f, 0.12610f },
  // TX_8X32
  NULL,
  // TX_32X8
  NULL,
  // TX_16X64
  NULL,
  // TX_64X16
  NULL,
};

static uint16_t prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
                            int blk_row, int blk_col, TxSetType tx_set_type,
                            TX_TYPE_PRUNE_MODE prune_mode) {
  static const int tx_type_table_2D[16] = {
    DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
    ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
    FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
    H_DCT,        H_ADST,        H_FLIPADST,        IDTX
  };
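  // The table layout matches the outer product computed below: entry
  // (4 * i + j) is the 2D TX type whose vertical 1D transform corresponds to
  // vscores[i] and whose horizontal 1D transform corresponds to hscores[j].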
  if (tx_set_type != EXT_TX_SET_ALL16 &&
      tx_set_type != EXT_TX_SET_DTT9_IDTX_1DDCT)
    return 0;
  const NN_CONFIG *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
  const NN_CONFIG *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
  if (!nn_config_hor || !nn_config_ver) return 0;  // Model not established yet.

  aom_clear_system_state();
  float hfeatures[16], vfeatures[16];
  float hscores[4], vscores[4];
  float scores_2D[16];
  const int bw = tx_size_wide[tx_size];
  const int bh = tx_size_high[tx_size];
  const int hfeatures_num = bw <= 8 ? bw : bw / 2;
  const int vfeatures_num = bh <= 8 ? bh : bh / 2;
  assert(hfeatures_num <= 16);
  assert(vfeatures_num <= 16);

  const struct macroblock_plane *const p = &x->plane[0];
  const int diff_stride = block_size_wide[bsize];
  const int16_t *diff = p->src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
  get_energy_distribution_finer(diff, diff_stride, bw, bh, hfeatures,
                                vfeatures);
  get_horver_correlation_full(diff, diff_stride, bw, bh,
                              &hfeatures[hfeatures_num - 1],
                              &vfeatures[vfeatures_num - 1]);
  av1_nn_predict(hfeatures, nn_config_hor, hscores);
  av1_nn_predict(vfeatures, nn_config_ver, vscores);

  float score_2D_average = 0.0f;
  for (int i = 0; i < 4; i++) {
    float *cur_scores_2D = scores_2D + i * 4;
    cur_scores_2D[0] = vscores[i] * hscores[0];
    cur_scores_2D[1] = vscores[i] * hscores[1];
    cur_scores_2D[2] = vscores[i] * hscores[2];
    cur_scores_2D[3] = vscores[i] * hscores[3];
    score_2D_average += cur_scores_2D[0] + cur_scores_2D[1] +
                        cur_scores_2D[2] + cur_scores_2D[3];
  }
  score_2D_average /= 16;

  const int prune_aggr_table[2][2] = { { 6, 4 }, { 10, 7 } };
  int pruning_aggressiveness = 1;
  if (tx_set_type == EXT_TX_SET_ALL16) {
    score_2D_transform_pow8(scores_2D, (10 - score_2D_average));
    pruning_aggressiveness =
        prune_aggr_table[prune_mode - PRUNE_2D_ACCURATE][0];
  } else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT) {
    score_2D_transform_pow8(scores_2D, (20 - score_2D_average));
    pruning_aggressiveness =
        prune_aggr_table[prune_mode - PRUNE_2D_ACCURATE][1];
  }

  // Always keep the TX type with the highest score, and prune all others
  // with a score below score_thresh.
  int max_score_i = 0;
  float max_score = 0.0f;
  for (int i = 0; i < 16; i++) {
    if (scores_2D[i] > max_score &&
        av1_ext_tx_used[tx_set_type][tx_type_table_2D[i]]) {
      max_score = scores_2D[i];
      max_score_i = i;
    }
  }

  const float score_thresh =
      prune_2D_adaptive_thresholds[tx_size][pruning_aggressiveness - 1];

  uint16_t prune_bitmask = 0;
  for (int i = 0; i < 16; i++) {
    if (scores_2D[i] < score_thresh && i != max_score_i)
      prune_bitmask |= (1 << tx_type_table_2D[i]);
  }
  return prune_bitmask;
}

// ((prune >> vtx_tab[tx_type]) & 1)
static const uint16_t prune_v_mask[] = {
  0x0000, 0x0425, 0x108a, 0x14af, 0x4150, 0x4575, 0x51da, 0x55ff,
  0xaa00, 0xae25, 0xba8a, 0xbeaf, 0xeb50, 0xef75, 0xfbda, 0xffff,
};

// ((prune >> (htx_tab[tx_type] + 8)) & 1)
static const uint16_t prune_h_mask[] = {
  0x0000, 0x0813, 0x210c, 0x291f, 0x80e0, 0x88f3, 0xa1ec, 0xa9ff,
  0x5600, 0x5e13, 0x770c, 0x7f1f, 0xd6e0, 0xdef3, 0xf7ec, 0xffff,
};

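// Converts the per-1D-type prune bits produced above (vertical 1D types in
// the low nibble, horizontal 1D types in bits 8..11) into a per-2D-TX-type
// mask. Because the two lookups are ANDed, a 2D TX type is pruned only if
// both its vertical and its horizontal 1D component were pruned.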
static INLINE uint16_t gen_tx_search_prune_mask(int tx_search_prune) {
  uint8_t prune_v = tx_search_prune & 0x0F;
  uint8_t prune_h = (tx_search_prune >> 8) & 0x0F;
  return (prune_v_mask[prune_v] & prune_h_mask[prune_h]);
}

static void prune_tx(const AV1_COMP *cpi, BLOCK_SIZE bsize, MACROBLOCK *x,
                     const MACROBLOCKD *const xd, int tx_set_type) {
  x->tx_search_prune[tx_set_type] = 0;
  x->tx_split_prune_flag = 0;
  const MB_MODE_INFO *mbmi = xd->mi[0];
  if (!is_inter_block(mbmi) || cpi->sf.tx_type_search.prune_mode == NO_PRUNE ||
      x->use_default_inter_tx_type || xd->lossless[mbmi->segment_id] ||
      x->cb_partition_scan)
    return;
  int tx_set = ext_tx_set_index[1][tx_set_type];
  assert(tx_set >= 0);
  const int *tx_set_1D = ext_tx_used_inter_1D[tx_set];
  int prune = 0;
  switch (cpi->sf.tx_type_search.prune_mode) {
    case NO_PRUNE: return;
    case PRUNE_ONE:
      if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) return;
      prune = prune_one_for_sby(cpi, bsize, x, xd);
      x->tx_search_prune[tx_set_type] = gen_tx_search_prune_mask(prune);
      break;
    case PRUNE_TWO:
      if (!(tx_set_1D[FLIPADST_1D] & tx_set_1D[ADST_1D])) {
        if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) return;
        prune = prune_two_for_sby(cpi, bsize, x, xd, 0, 1);
      } else if (!(tx_set_1D[DCT_1D] & tx_set_1D[IDTX_1D])) {
        prune = prune_two_for_sby(cpi, bsize, x, xd, 1, 0);
      } else {
        prune = prune_two_for_sby(cpi, bsize, x, xd, 1, 1);
      }
      x->tx_search_prune[tx_set_type] = gen_tx_search_prune_mask(prune);
      break;
    case PRUNE_2D_ACCURATE:
    case PRUNE_2D_FAST: break;
    default: assert(0);
  }
}

static void model_rd_from_sse(const AV1_COMP *const cpi,
                              const MACROBLOCK *const x, BLOCK_SIZE plane_bsize,
                              int plane, int64_t sse, int num_samples,
                              int *rate, int64_t *dist) {
  (void)num_samples;
  const MACROBLOCKD *const xd = &x->e_mbd;
  const struct macroblockd_plane *const pd = &xd->plane[plane];
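  // dequant_Q3 values are in Q3 units, i.e. 8x the effective quantizer step,
  // hence the shift of 3 below for 8-bit content; for high bitdepth,
  // bd - 5 = 3 + (bd - 8) also brings the step back to an 8-bit scale.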
  const int dequant_shift =
      (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd - 5 : 3;

  // Fast approximation of the modelling function.
  if (cpi->sf.simple_model_rd_from_var) {
    const int64_t square_error = sse;
    int quantizer = pd->dequant_Q3[1] >> dequant_shift;
    if (quantizer < 120)
      *rate = (int)AOMMIN(
          (square_error * (280 - quantizer)) >> (16 - AV1_PROB_COST_SHIFT),
          INT_MAX);
    else
      *rate = 0;
    assert(*rate >= 0);
    *dist = (square_error * quantizer) >> 8;
  } else {
    av1_model_rd_from_var_lapndz(sse, num_pels_log2_lookup[plane_bsize],
                                 pd->dequant_Q3[1] >> dequant_shift, rate,
                                 dist);
  }
  *dist <<= 4;
}

#if CONFIG_COLLECT_INTER_MODE_RD_STATS
static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x) {
  const AV1_COMMON *cm = &cpi->common;
  const int num_planes = av1_num_planes(cm);
  const MACROBLOCKD *xd = &x->e_mbd;
  const MB_MODE_INFO *mbmi = xd->mi[0];
  int64_t total_sse = 0;
  for (int plane = 0; plane < num_planes; ++plane) {
    const struct macroblock_plane *const p = &x->plane[plane];
    const struct macroblockd_plane *const pd = &xd->plane[plane];
    const BLOCK_SIZE bs = get_plane_block_size(mbmi->sb_type, pd->subsampling_x,
                                               pd->subsampling_y);
    unsigned int sse;

    if (x->skip_chroma_rd && plane) continue;

    cpi->fn_ptr[bs].vf(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
                       &sse);
    total_sse += sse;
  }
  total_sse <<= 4;
  return total_sse;
}
#endif

static void model_rd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bsize,
                            MACROBLOCK *x, MACROBLOCKD *xd, int plane_from,
                            int plane_to, int mi_row, int mi_col,
                            int *out_rate_sum, int64_t *out_dist_sum,
                            int *skip_txfm_sb, int64_t *skip_sse_sb,
                            int *plane_rate, int64_t *plane_sse,
                            int64_t *plane_dist) {
  // Note: our transform coefficients are 8 times those of an orthogonal
  // transform, so the quantizer step is also 8 times larger. To get the
  // effective quantizer, we need to divide by 8 before sending to the
  // modeling function.
  int plane;
  (void)mi_row;
  (void)mi_col;
  const int ref = xd->mi[0]->ref_frame[0];

  int64_t rate_sum = 0;
  int64_t dist_sum = 0;
  int64_t total_sse = 0;

  for (plane = plane_from; plane <= plane_to; ++plane) {
    struct macroblock_plane *const p = &x->plane[plane];
    struct macroblockd_plane *const pd = &xd->plane[plane];
    const BLOCK_SIZE plane_bsize =
        get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
    const int bw = block_size_wide[plane_bsize];
    const int bh = block_size_high[plane_bsize];
    int64_t sse;
    int rate;
    int64_t dist;

    if (x->skip_chroma_rd && plane) continue;

    if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
      sse = aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf,
                           pd->dst.stride, bw, bh);
    } else {
      sse = aom_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride, bw,
                    bh);
    }
    sse = ROUND_POWER_OF_TWO(sse, (xd->bd - 8) * 2);

    model_rd_from_sse(cpi, x, plane_bsize, plane, sse, bw * bh, &rate, &dist);

    if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX);

    total_sse += sse;
    rate_sum += rate;
    dist_sum += dist;
    if (plane_rate) plane_rate[plane] = rate;
    if (plane_sse) plane_sse[plane] = sse;
    if (plane_dist) plane_dist[plane] = dist;
    assert(rate_sum >= 0);
  }

  if (skip_txfm_sb) *skip_txfm_sb = total_sse == 0;
  if (skip_sse_sb) *skip_sse_sb = total_sse << 4;
  rate_sum = AOMMIN(rate_sum, INT_MAX);
  *out_rate_sum = (int)rate_sum;
  *out_dist_sum = dist_sum;
}

static void check_block_skip(const AV1_COMP *const cpi, BLOCK_SIZE bsize,
                             MACROBLOCK *x, MACROBLOCKD *xd, int plane_from,
                             int plane_to, int *skip_txfm_sb) {
  *skip_txfm_sb = 1;
  for (int plane = plane_from; plane <= plane_to; ++plane) {
    struct macroblock_plane *const p = &x->plane[plane];
    struct macroblockd_plane *const pd = &xd->plane[plane];
    const BLOCK_SIZE bs =
        get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
    unsigned int sse;

    if (x->skip_chroma_rd && plane) continue;

    // Since the fast HBD variance functions scale down sse by 4 bits, we first
    // use the fast vf implementation to rule out blocks with non-zero scaled
    // sse. Then, only if the source is HBD and the scaled sse is 0, an
    // accurate sse computation is applied to determine if the sse is really 0.
    // This step is necessary for HBD lossless coding.
    cpi->fn_ptr[bs].vf(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
                       &sse);
    if (sse) {
      *skip_txfm_sb = 0;
      return;
    } else if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
      uint64_t sse64 = aom_highbd_sse_odd_size(
          p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
          block_size_wide[bs], block_size_high[bs]);

      if (sse64) {
        *skip_txfm_sb = 0;
        return;
      }
    }
  }
  return;
}

int64_t av1_block_error_c(const tran_low_t *coeff, const tran_low_t *dqcoeff,
                          intptr_t block_size, int64_t *ssz) {
  int i;
  int64_t error = 0, sqcoeff = 0;

  for (i = 0; i < block_size; i++) {
    const int diff = coeff[i] - dqcoeff[i];
    error += diff * diff;
    sqcoeff += coeff[i] * coeff[i];
  }

  *ssz = sqcoeff;
  return error;
}

int64_t av1_highbd_block_error_c(const tran_low_t *coeff,
                                 const tran_low_t *dqcoeff, intptr_t block_size,
                                 int64_t *ssz, int bd) {
  int i;
  int64_t error = 0, sqcoeff = 0;
  int shift = 2 * (bd - 8);
  int rounding = shift > 0 ? 1 << (shift - 1) : 0;
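  // 'shift' of 2 * (bd - 8) normalizes the accumulated squared-error sums
  // back to an 8-bit scale; 'rounding' implements round-to-nearest for the
  // right shift applied after the loop below.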

  for (i = 0; i < block_size; i++) {
    const int64_t diff = coeff[i] - dqcoeff[i];
    error += diff * diff;
    sqcoeff += (int64_t)coeff[i] * (int64_t)coeff[i];
  }
  assert(error >= 0 && sqcoeff >= 0);
  error = (error + rounding) >> shift;
  sqcoeff = (sqcoeff + rounding) >> shift;

  *ssz = sqcoeff;
  return error;
}

// Get transform block visible dimensions cropped to the MI units.
static void get_txb_dimensions(const MACROBLOCKD *xd, int plane,
                               BLOCK_SIZE plane_bsize, int blk_row, int blk_col,
                               BLOCK_SIZE tx_bsize, int *width, int *height,
                               int *visible_width, int *visible_height) {
  assert(tx_bsize <= plane_bsize);
  int txb_height = block_size_high[tx_bsize];
  int txb_width = block_size_wide[tx_bsize];
  const int block_height = block_size_high[plane_bsize];
  const int block_width = block_size_wide[plane_bsize];
  const struct macroblockd_plane *const pd = &xd->plane[plane];
  // TODO(aconverse@google.com): Investigate using crop_width/height here rather
  // than the MI size
  const int block_rows =
      (xd->mb_to_bottom_edge >= 0)
          ? block_height
          : (xd->mb_to_bottom_edge >> (3 + pd->subsampling_y)) + block_height;
  const int block_cols =
      (xd->mb_to_right_edge >= 0)
          ? block_width
          : (xd->mb_to_right_edge >> (3 + pd->subsampling_x)) + block_width;
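  // Note: mb_to_bottom_edge/mb_to_right_edge are negative only when the block
  // extends past the frame edge, and are kept in 1/8-pel units, so the
  // '>> (3 + subsampling)' above converts them to plane pixels before adding.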
  const int tx_unit_size = tx_size_wide_log2[0];
  if (width) *width = txb_width;
  if (height) *height = txb_height;
  *visible_width = clamp(block_cols - (blk_col << tx_unit_size), 0, txb_width);
  *visible_height =
      clamp(block_rows - (blk_row << tx_unit_size), 0, txb_height);
}

// Compute the pixel domain distortion from src and dst on all visible 4x4s in
// the transform block.
static unsigned pixel_dist(const AV1_COMP *const cpi, const MACROBLOCK *x,
                           int plane, const uint8_t *src, const int src_stride,
                           const uint8_t *dst, const int dst_stride,
                           int blk_row, int blk_col,
                           const BLOCK_SIZE plane_bsize,
                           const BLOCK_SIZE tx_bsize) {
  int txb_rows, txb_cols, visible_rows, visible_cols;
  const MACROBLOCKD *xd = &x->e_mbd;

  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize,
                     &txb_cols, &txb_rows, &visible_cols, &visible_rows);
  assert(visible_rows > 0);
  assert(visible_cols > 0);

#if CONFIG_DIST_8X8
  if (x->using_dist_8x8 && plane == 0)
    return (unsigned)av1_dist_8x8(cpi, x, src, src_stride, dst, dst_stride,
                                  tx_bsize, txb_cols, txb_rows, visible_cols,
                                  visible_rows, x->qindex);
#endif  // CONFIG_DIST_8X8

  unsigned sse = pixel_dist_visible_only(cpi, x, src, src_stride, dst,
                                         dst_stride, tx_bsize, txb_rows,
                                         txb_cols, visible_rows, visible_cols);

  return sse;
}

// Compute the pixel domain distortion from diff on all visible 4x4s in the
// transform block.
static INLINE int64_t pixel_diff_dist(const MACROBLOCK *x, int plane,
                                      int blk_row, int blk_col,
                                      const BLOCK_SIZE plane_bsize,
                                      const BLOCK_SIZE tx_bsize) {
  int visible_rows, visible_cols;
  const MACROBLOCKD *xd = &x->e_mbd;
  get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
                     NULL, &visible_cols, &visible_rows);
  const int diff_stride = block_size_wide[plane_bsize];
  const int16_t *diff = x->plane[plane].src_diff;
#if CONFIG_DIST_8X8
  int txb_height = block_size_high[tx_bsize];
  int txb_width = block_size_wide[tx_bsize];
  if (x->using_dist_8x8 && plane == 0) {
    const int src_stride = x->plane[plane].src.stride;
    const int src_idx = (blk_row * src_stride + blk_col)
                        << tx_size_wide_log2[0];
    const int diff_idx = (blk_row * diff_stride + blk_col)
                         << tx_size_wide_log2[0];
    const uint8_t *src = &x->plane[plane].src.buf[src_idx];
    return dist_8x8_diff(x, src, src_stride, diff + diff_idx, diff_stride,
                         txb_width, txb_height, visible_cols, visible_rows,
                         x->qindex);
  }
#endif
  diff += ((blk_row * diff_stride + blk_col) << tx_size_wide_log2[0]);
  return aom_sum_squares_2d_i16(diff, diff_stride, visible_cols, visible_rows);
}

int av1_count_colors(const uint8_t *src, int stride, int rows, int cols,
                     int *val_count) {
  const int max_pix_val = 1 << 8;
  memset(val_count, 0, max_pix_val * sizeof(val_count[0]));
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      const int this_val = src[r * stride + c];
      assert(this_val < max_pix_val);
      ++val_count[this_val];
    }
  }
  int n = 0;
  for (int i = 0; i < max_pix_val; ++i) {
    if (val_count[i]) ++n;
  }
  return n;
}

int av1_count_colors_highbd(const uint8_t *src8, int stride, int rows, int cols,
                            int bit_depth, int *val_count) {
  assert(bit_depth <= 12);
  const int max_pix_val = 1 << bit_depth;
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  memset(val_count, 0, max_pix_val * sizeof(val_count[0]));
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      const int this_val = src[r * stride + c];
      assert(this_val < max_pix_val);
      if (this_val >= max_pix_val) return 0;
      ++val_count[this_val];
    }
  }
  int n = 0;
  for (int i = 0; i < max_pix_val; ++i) {
    if (val_count[i]) ++n;
  }
  return n;
}

static void inverse_transform_block_facade(MACROBLOCKD *xd, int plane,
                                           int block, int blk_row, int blk_col,
                                           int eob, int reduced_tx_set) {
  struct macroblockd_plane *const pd = &xd->plane[plane];
  tran_low_t *dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
  const PLANE_TYPE plane_type = get_plane_type(plane);
  const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
  const TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, blk_row, blk_col,
                                          tx_size, reduced_tx_set);
  const int dst_stride = pd->dst.stride;
  uint8_t *dst =
      &pd->dst.buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]];
  av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, dst,
                              dst_stride, eob, reduced_tx_set);
}

static int find_tx_size_rd_info(TXB_RD_RECORD *cur_record, const uint32_t hash);

static uint32_t get_intra_txb_hash(MACROBLOCK *x, int plane, int blk_row,
                                   int blk_col, BLOCK_SIZE plane_bsize,
                                   TX_SIZE tx_size) {
  int16_t tmp_data[64 * 64];
  const int diff_stride = block_size_wide[plane_bsize];
  const int16_t *diff = x->plane[plane].src_diff;
  const int16_t *cur_diff_row = diff + 4 * blk_row * diff_stride + 4 * blk_col;
  const int txb_w = tx_size_wide[tx_size];
  const int txb_h = tx_size_high[tx_size];
  uint8_t *hash_data = (uint8_t *)cur_diff_row;
  if (txb_w != diff_stride) {
    int16_t *cur_hash_row = tmp_data;
    for (int i = 0; i < txb_h; i++) {
      memcpy(cur_hash_row, cur_diff_row, sizeof(*diff) * txb_w);
      cur_hash_row += txb_w;
      cur_diff_row += diff_stride;
    }
    hash_data = (uint8_t *)tmp_data;
  }
  CRC32C *crc = &x->mb_rd_record.crc_calculator;
  const uint32_t hash = av1_get_crc32c_value(crc, hash_data, 2 * txb_w * txb_h);
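  // Fold the TX size into the low 5 bits of the key so that otherwise
  // identical residue blocks hashed at different TX sizes do not collide.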
2183   return (hash << 5) + tx_size;
2184 }
2185 
dist_block_tx_domain(MACROBLOCK * x,int plane,int block,TX_SIZE tx_size,int64_t * out_dist,int64_t * out_sse)2186 static INLINE void dist_block_tx_domain(MACROBLOCK *x, int plane, int block,
2187                                         TX_SIZE tx_size, int64_t *out_dist,
2188                                         int64_t *out_sse) {
2189   MACROBLOCKD *const xd = &x->e_mbd;
2190   const struct macroblock_plane *const p = &x->plane[plane];
2191   const struct macroblockd_plane *const pd = &xd->plane[plane];
2192   // Transform domain distortion computation is more efficient as it does
2193   // not involve an inverse transform, but it is less accurate.
2194   const int buffer_length = av1_get_max_eob(tx_size);
2195   int64_t this_sse;
2196   // TX-domain results need to shift down to Q2/D10 to match pixel
2197   // domain distortion values which are in Q2^2
2198   int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size)) * 2;
2199   tran_low_t *const coeff = BLOCK_OFFSET(p->coeff, block);
2200   tran_low_t *const dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
2201 
2202   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
2203     *out_dist = av1_highbd_block_error(coeff, dqcoeff, buffer_length, &this_sse,
2204                                        xd->bd);
2205   else
2206     *out_dist = av1_block_error(coeff, dqcoeff, buffer_length, &this_sse);
2207 
2208   *out_dist = RIGHT_SIGNED_SHIFT(*out_dist, shift);
2209   *out_sse = RIGHT_SIGNED_SHIFT(this_sse, shift);
2210 }
2211 
dist_block_px_domain(const AV1_COMP * cpi,MACROBLOCK * x,int plane,BLOCK_SIZE plane_bsize,int block,int blk_row,int blk_col,TX_SIZE tx_size)2212 static INLINE int64_t dist_block_px_domain(const AV1_COMP *cpi, MACROBLOCK *x,
2213                                            int plane, BLOCK_SIZE plane_bsize,
2214                                            int block, int blk_row, int blk_col,
2215                                            TX_SIZE tx_size) {
2216   MACROBLOCKD *const xd = &x->e_mbd;
2217   const struct macroblock_plane *const p = &x->plane[plane];
2218   const struct macroblockd_plane *const pd = &xd->plane[plane];
2219   const uint16_t eob = p->eobs[block];
2220   const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
2221   const int bsw = block_size_wide[tx_bsize];
2222   const int bsh = block_size_high[tx_bsize];
2223   const int src_stride = x->plane[plane].src.stride;
2224   const int dst_stride = xd->plane[plane].dst.stride;
2225   // Scale the transform block index to pixel unit.
2226   const int src_idx = (blk_row * src_stride + blk_col) << tx_size_wide_log2[0];
2227   const int dst_idx = (blk_row * dst_stride + blk_col) << tx_size_wide_log2[0];
2228   const uint8_t *src = &x->plane[plane].src.buf[src_idx];
2229   const uint8_t *dst = &xd->plane[plane].dst.buf[dst_idx];
2230   const tran_low_t *dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
2231 
2232   assert(cpi != NULL);
2233   assert(tx_size_wide_log2[0] == tx_size_high_log2[0]);
2234 
2235   uint8_t *recon;
2236   DECLARE_ALIGNED(16, uint16_t, recon16[MAX_TX_SQUARE]);
2237 
2238   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
2239     recon = CONVERT_TO_BYTEPTR(recon16);
2240     av1_highbd_convolve_2d_copy_sr(CONVERT_TO_SHORTPTR(dst), dst_stride,
2241                                    CONVERT_TO_SHORTPTR(recon), MAX_TX_SIZE, bsw,
2242                                    bsh, NULL, NULL, 0, 0, NULL, xd->bd);
2243   } else {
2244     recon = (uint8_t *)recon16;
2245     av1_convolve_2d_copy_sr(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh, NULL,
2246                             NULL, 0, 0, NULL);
2247   }
2248 
2249   const PLANE_TYPE plane_type = get_plane_type(plane);
2250   TX_TYPE tx_type = av1_get_tx_type(plane_type, xd, blk_row, blk_col, tx_size,
2251                                     cpi->common.reduced_tx_set_used);
2252   av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, recon,
2253                               MAX_TX_SIZE, eob,
2254                               cpi->common.reduced_tx_set_used);
2255 
2256   return 16 * pixel_dist(cpi, x, plane, src, src_stride, recon, MAX_TX_SIZE,
2257                          blk_row, blk_col, plane_bsize, tx_bsize);
2258 }
2259 
get_mean(const int16_t * diff,int stride,int w,int h)2260 static double get_mean(const int16_t *diff, int stride, int w, int h) {
2261   double sum = 0.0;
2262   for (int j = 0; j < h; ++j) {
2263     for (int i = 0; i < w; ++i) {
2264       sum += diff[j * stride + i];
2265     }
2266   }
2267   assert(w > 0 && h > 0);
2268   return sum / (w * h);
2269 }
2270 
get_sse_norm(const int16_t * diff,int stride,int w,int h)2271 static double get_sse_norm(const int16_t *diff, int stride, int w, int h) {
2272   double sum = 0.0;
2273   for (int j = 0; j < h; ++j) {
2274     for (int i = 0; i < w; ++i) {
2275       const int err = diff[j * stride + i];
2276       sum += err * err;
2277     }
2278   }
2279   assert(w > 0 && h > 0);
2280   return sum / (w * h);
2281 }
2282 
get_sad_norm(const int16_t * diff,int stride,int w,int h)2283 static double get_sad_norm(const int16_t *diff, int stride, int w, int h) {
2284   double sum = 0.0;
2285   for (int j = 0; j < h; ++j) {
2286     for (int i = 0; i < w; ++i) {
2287       sum += abs(diff[j * stride + i]);
2288     }
2289   }
2290   assert(w > 0 && h > 0);
2291   return sum / (w * h);
2292 }
2293 
get_2x2_normalized_sses_and_sads(const AV1_COMP * const cpi,BLOCK_SIZE tx_bsize,const uint8_t * const src,int src_stride,const uint8_t * const dst,int dst_stride,const int16_t * const src_diff,int diff_stride,double * const sse_norm_arr,double * const sad_norm_arr)2294 static void get_2x2_normalized_sses_and_sads(
2295     const AV1_COMP *const cpi, BLOCK_SIZE tx_bsize, const uint8_t *const src,
2296     int src_stride, const uint8_t *const dst, int dst_stride,
2297     const int16_t *const src_diff, int diff_stride, double *const sse_norm_arr,
2298     double *const sad_norm_arr) {
2299   const BLOCK_SIZE tx_bsize_half =
2300       get_partition_subsize(tx_bsize, PARTITION_SPLIT);
2301   if (tx_bsize_half == BLOCK_INVALID) {  // manually calculate stats
2302     const int half_width = block_size_wide[tx_bsize] / 2;
2303     const int half_height = block_size_high[tx_bsize] / 2;
2304     for (int row = 0; row < 2; ++row) {
2305       for (int col = 0; col < 2; ++col) {
2306         const int16_t *const this_src_diff =
2307             src_diff + row * half_height * diff_stride + col * half_width;
2308         if (sse_norm_arr) {
2309           sse_norm_arr[row * 2 + col] =
2310               get_sse_norm(this_src_diff, diff_stride, half_width, half_height);
2311         }
2312         if (sad_norm_arr) {
2313           sad_norm_arr[row * 2 + col] =
2314               get_sad_norm(this_src_diff, diff_stride, half_width, half_height);
2315         }
2316       }
2317     }
2318   } else {  // use function pointers to calculate stats
2319     const int half_width = block_size_wide[tx_bsize_half];
2320     const int half_height = block_size_high[tx_bsize_half];
2321     const int num_samples_half = half_width * half_height;
2322     for (int row = 0; row < 2; ++row) {
2323       for (int col = 0; col < 2; ++col) {
2324         const uint8_t *const this_src =
2325             src + row * half_height * src_stride + col * half_width;
2326         const uint8_t *const this_dst =
2327             dst + row * half_height * dst_stride + col * half_width;
2328 
2329         if (sse_norm_arr) {
2330           unsigned int this_sse;
2331           cpi->fn_ptr[tx_bsize_half].vf(this_src, src_stride, this_dst,
2332                                         dst_stride, &this_sse);
2333           sse_norm_arr[row * 2 + col] = (double)this_sse / num_samples_half;
2334         }
2335 
2336         if (sad_norm_arr) {
2337           const unsigned int this_sad = cpi->fn_ptr[tx_bsize_half].sdf(
2338               this_src, src_stride, this_dst, dst_stride);
2339           sad_norm_arr[row * 2 + col] = (double)this_sad / num_samples_half;
2340         }
2341       }
2342     }
2343   }
2344 }
2345 
2346 // NOTE: CONFIG_COLLECT_RD_STATS has 3 possible values
2347 // 0: Do not collect any RD stats
2348 // 1: Collect RD stats for transform units
2349 // 2: Collect RD stats for partition units
2350 #if CONFIG_COLLECT_RD_STATS
2351 
2352 #if CONFIG_COLLECT_RD_STATS == 1
PrintTransformUnitStats(const AV1_COMP * const cpi,MACROBLOCK * x,const RD_STATS * const rd_stats,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,TX_TYPE tx_type,int64_t rd)2353 static void PrintTransformUnitStats(const AV1_COMP *const cpi, MACROBLOCK *x,
2354                                     const RD_STATS *const rd_stats, int blk_row,
2355                                     int blk_col, BLOCK_SIZE plane_bsize,
2356                                     TX_SIZE tx_size, TX_TYPE tx_type,
2357                                     int64_t rd) {
2358   if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;
2359 
2360   // Generate small sample to restrict output size.
2361   static unsigned int seed = 21743;
2362   if (lcg_rand16(&seed) % 256 > 0) return;
2363 
2364   const char output_file[] = "tu_stats.txt";
2365   FILE *fout = fopen(output_file, "a");
2366   if (!fout) return;
2367 
2368   const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
2369   const MACROBLOCKD *const xd = &x->e_mbd;
2370   const int plane = 0;
2371   struct macroblock_plane *const p = &x->plane[plane];
2372   const struct macroblockd_plane *const pd = &xd->plane[plane];
2373   const int txw = tx_size_wide[tx_size];
2374   const int txh = tx_size_high[tx_size];
2375   const int dequant_shift =
2376       (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd - 5 : 3;
2377   const int q_step = pd->dequant_Q3[1] >> dequant_shift;
2378   const double num_samples = txw * txh;
2379 
2380   const double rate_norm = (double)rd_stats->rate / num_samples;
2381   const double dist_norm = (double)rd_stats->dist / num_samples;
2382 
2383   fprintf(fout, "%g %g", rate_norm, dist_norm);
2384 
2385   const int src_stride = p->src.stride;
2386   const uint8_t *const src =
2387       &p->src.buf[(blk_row * src_stride + blk_col) << tx_size_wide_log2[0]];
2388   const int dst_stride = pd->dst.stride;
2389   const uint8_t *const dst =
2390       &pd->dst.buf[(blk_row * dst_stride + blk_col) << tx_size_wide_log2[0]];
2391   unsigned int sse;
2392   cpi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
2393   const double sse_norm = (double)sse / num_samples;
2394 
2395   const unsigned int sad =
2396       cpi->fn_ptr[tx_bsize].sdf(src, src_stride, dst, dst_stride);
2397   const double sad_norm = (double)sad / num_samples;
2398 
2399   fprintf(fout, " %g %g", sse_norm, sad_norm);
2400 
2401   const int diff_stride = block_size_wide[plane_bsize];
2402   const int16_t *const src_diff =
2403       &p->src_diff[(blk_row * diff_stride + blk_col) << tx_size_wide_log2[0]];
2404 
2405   double sse_norm_arr[4], sad_norm_arr[4];
2406   get_2x2_normalized_sses_and_sads(cpi, tx_bsize, src, src_stride, dst,
2407                                    dst_stride, src_diff, diff_stride,
2408                                    sse_norm_arr, sad_norm_arr);
2409   for (int i = 0; i < 4; ++i) {
2410     fprintf(fout, " %g", sse_norm_arr[i]);
2411   }
2412   for (int i = 0; i < 4; ++i) {
2413     fprintf(fout, " %g", sad_norm_arr[i]);
2414   }
2415 
2416   const TX_TYPE_1D tx_type_1d_row = htx_tab[tx_type];
2417   const TX_TYPE_1D tx_type_1d_col = vtx_tab[tx_type];
2418 
2419   fprintf(fout, " %d %d %d %d %d", q_step, tx_size_wide[tx_size],
2420           tx_size_high[tx_size], tx_type_1d_row, tx_type_1d_col);
2421 
2422   int model_rate;
2423   int64_t model_dist;
2424   model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, tx_bsize, plane, sse, num_samples,
2425                                    &model_rate, &model_dist);
2426   const double model_rate_norm = (double)model_rate / num_samples;
2427   const double model_dist_norm = (double)model_dist / num_samples;
2428   fprintf(fout, " %g %g", model_rate_norm, model_dist_norm);
2429 
2430   const double mean = get_mean(src_diff, diff_stride, txw, txh);
2431   double hor_corr, vert_corr;
2432   get_horver_correlation(src_diff, diff_stride, txw, txh, &hor_corr,
2433                          &vert_corr);
2434   fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);
2435 
2436   double hdist[4] = { 0 }, vdist[4] = { 0 };
2437   get_energy_distribution_fine(cpi, tx_bsize, src, src_stride, dst, dst_stride,
2438                                1, hdist, vdist);
2439   fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
2440           hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);
2441 
2442   fprintf(fout, " %d %" PRId64, x->rdmult, rd);
2443 
2444   fprintf(fout, "\n");
2445   fclose(fout);
2446 }
2447 #endif  // CONFIG_COLLECT_RD_STATS == 1
2448 
2449 #if CONFIG_COLLECT_RD_STATS >= 2
PrintPredictionUnitStats(const AV1_COMP * const cpi,MACROBLOCK * x,const RD_STATS * const rd_stats,BLOCK_SIZE plane_bsize)2450 static void PrintPredictionUnitStats(const AV1_COMP *const cpi, MACROBLOCK *x,
2451                                      const RD_STATS *const rd_stats,
2452                                      BLOCK_SIZE plane_bsize) {
2453   if (rd_stats->invalid_rate) return;
2454   if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;
2455 
2456   // Generate small sample to restrict output size.
2457   static unsigned int seed = 95014;
2458   if (lcg_rand16(&seed) % 256 > 0) return;
2459 
2460   const char output_file[] = "pu_stats.txt";
2461   FILE *fout = fopen(output_file, "a");
2462   if (!fout) return;
2463 
2464   const MACROBLOCKD *const xd = &x->e_mbd;
2465   const int plane = 0;
2466   struct macroblock_plane *const p = &x->plane[plane];
2467   const struct macroblockd_plane *const pd = &xd->plane[plane];
2468   const int diff_stride = block_size_wide[plane_bsize];
2469   int bw, bh;
2470   get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
2471                      &bh);
2472   const int num_samples = bw * bh;
2473   const int dequant_shift =
2474       (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd - 5 : 3;
2475   const int q_step = pd->dequant_Q3[1] >> dequant_shift;
2476 
2477   const double rate_norm = (double)rd_stats->rate / num_samples;
2478   const double dist_norm = (double)rd_stats->dist / num_samples;
2479   const double rdcost_norm =
2480       (double)RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) / num_samples;
2481 
2482   fprintf(fout, "%g %g %g", rate_norm, dist_norm, rdcost_norm);
2483 
2484   const int src_stride = p->src.stride;
2485   const uint8_t *const src = p->src.buf;
2486   const int dst_stride = pd->dst.stride;
2487   const uint8_t *const dst = pd->dst.buf;
2488   const int16_t *const src_diff = p->src_diff;
2489   const int shift = (xd->bd - 8);
2490 
2491   int64_t sse = aom_sum_squares_2d_i16(src_diff, diff_stride, bw, bh);
2492   sse = ROUND_POWER_OF_TWO(sse, shift * 2);
2493   const double sse_norm = (double)sse / num_samples;
2494 
2495   const unsigned int sad =
2496       cpi->fn_ptr[plane_bsize].sdf(src, src_stride, dst, dst_stride);
2497   const double sad_norm =
2498       (double)sad / (1 << num_pels_log2_lookup[plane_bsize]);
2499 
2500   fprintf(fout, " %g %g", sse_norm, sad_norm);
2501 
2502   double sse_norm_arr[4], sad_norm_arr[4];
2503   get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
2504                                    dst_stride, src_diff, diff_stride,
2505                                    sse_norm_arr, sad_norm_arr);
2506   if (shift) {
2507     for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift));
2508     for (int k = 0; k < 4; ++k) sad_norm_arr[k] /= (1 << shift);
2509   }
2510   for (int i = 0; i < 4; ++i) {
2511     fprintf(fout, " %g", sse_norm_arr[i]);
2512   }
2513   for (int i = 0; i < 4; ++i) {
2514     fprintf(fout, " %g", sad_norm_arr[i]);
2515   }
2516 
2517   fprintf(fout, " %d %d %d %d", q_step, x->rdmult, bw, bh);
2518 
2519   int model_rate;
2520   int64_t model_dist;
2521   model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, plane_bsize, plane, sse, num_samples,
2522                                    &model_rate, &model_dist);
2523   const double model_rdcost_norm =
2524       (double)RDCOST(x->rdmult, model_rate, model_dist) / num_samples;
2525   const double model_rate_norm = (double)model_rate / num_samples;
2526   const double model_dist_norm = (double)model_dist / num_samples;
2527   fprintf(fout, " %g %g %g", model_rate_norm, model_dist_norm,
2528           model_rdcost_norm);
2529 
2530   double mean = get_mean(src_diff, diff_stride, bw, bh);
2531   mean /= (1 << shift);
2532   double hor_corr, vert_corr;
2533   get_horver_correlation(src_diff, diff_stride, bw, bh, &hor_corr, &vert_corr);
2534   fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);
2535 
2536   double hdist[4] = { 0 }, vdist[4] = { 0 };
2537   get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
2538                                dst_stride, 1, hdist, vdist);
2539   fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
2540           hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);
2541 
2542   fprintf(fout, "\n");
2543   fclose(fout);
2544 }
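// For reference when parsing pu_stats.txt: each line written above is a
// single whitespace-separated record with 31 columns, in this order:
//   rate_norm dist_norm rdcost_norm sse_norm sad_norm
//   sse_norm_arr[0..3] sad_norm_arr[0..3] q_step rdmult bw bh
//   model_rate_norm model_dist_norm model_rdcost_norm
//   mean hor_corr vert_corr hdist[0..3] vdist[0..3]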
2545 #endif  // CONFIG_COLLECT_RD_STATS >= 2
2546 #endif  // CONFIG_COLLECT_RD_STATS
2547 
2548 static void model_rd_with_dnn(const AV1_COMP *const cpi,
2549                               const MACROBLOCK *const x, BLOCK_SIZE plane_bsize,
2550                               int plane, int64_t sse, int num_samples,
2551                               int *rate, int64_t *dist) {
2552   const MACROBLOCKD *const xd = &x->e_mbd;
2553   const struct macroblockd_plane *const pd = &xd->plane[plane];
2554   const int log_numpels = num_pels_log2_lookup[plane_bsize];
2555 
2556   const int dequant_shift =
2557       (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd - 5 : 3;
2558   const int q_step = AOMMAX(pd->dequant_Q3[1] >> dequant_shift, 1);
2559 
2560   const struct macroblock_plane *const p = &x->plane[plane];
2561   int bw, bh;
2562   get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
2563                      &bh);
2564   const int src_stride = p->src.stride;
2565   const uint8_t *const src = p->src.buf;
2566   const int dst_stride = pd->dst.stride;
2567   const uint8_t *const dst = pd->dst.buf;
2568   const int16_t *const src_diff = p->src_diff;
2569   const int diff_stride = block_size_wide[plane_bsize];
2570   const int shift = (xd->bd - 8);
2571 
2572   if (sse == 0) {
2573     if (rate) *rate = 0;
2574     if (dist) *dist = 0;
2575     return;
2576   }
2577   if (plane) {
2578     int model_rate;
2579     int64_t model_dist;
2580     model_rd_with_curvfit(cpi, x, plane_bsize, plane, sse, num_samples,
2581                           &model_rate, &model_dist);
2582     if (rate) *rate = model_rate;
2583     if (dist) *dist = model_dist;
2584     return;
2585   }
2586 
2587   aom_clear_system_state();
2588   const double sse_norm = (double)sse / num_samples;
2589 
2590   double sse_norm_arr[4];
2591   get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
2592                                    dst_stride, src_diff, diff_stride,
2593                                    sse_norm_arr, NULL);
2594   double mean = get_mean(src_diff, diff_stride, bw, bh);
2595   if (shift) {
2596     for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift));
2597     mean /= (1 << shift);
2598   }
2599   double sse_norm_sum = 0.0, sse_frac_arr[3];
2600   for (int k = 0; k < 4; ++k) sse_norm_sum += sse_norm_arr[k];
2601   for (int k = 0; k < 3; ++k)
2602     sse_frac_arr[k] =
2603         sse_norm_sum > 0.0 ? sse_norm_arr[k] / sse_norm_sum : 0.25;
2604   const double q_sqr = (double)(q_step * q_step);
2605   const double q_sqr_by_sse_norm = q_sqr / (sse_norm + 1.0);
2606   const double mean_sqr_by_sse_norm = mean * mean / (sse_norm + 1.0);
2607   double hor_corr, vert_corr;
2608   get_horver_correlation(src_diff, diff_stride, bw, bh, &hor_corr, &vert_corr);
2609 
2610   float features[NUM_FEATURES_PUSTATS];
2611   features[0] = (float)hor_corr;
2612   features[1] = (float)log_numpels;
2613   features[2] = (float)mean_sqr_by_sse_norm;
2614   features[3] = (float)q_sqr_by_sse_norm;
2615   features[4] = (float)sse_frac_arr[0];
2616   features[5] = (float)sse_frac_arr[1];
2617   features[6] = (float)sse_frac_arr[2];
2618   features[7] = (float)vert_corr;
2619 
2620   float rate_f, dist_by_sse_norm_f;
2621   av1_nn_predict(features, &av1_pustats_dist_nnconfig, &dist_by_sse_norm_f);
2622   av1_nn_predict(features, &av1_pustats_rate_nnconfig, &rate_f);
2623   const float dist_f = (float)((double)dist_by_sse_norm_f * (1.0 + sse_norm));
2624   int rate_i = (int)(AOMMAX(0.0, rate_f * num_samples) + 0.5);
2625   int64_t dist_i = (int64_t)(AOMMAX(0.0, dist_f * num_samples) + 0.5);
2626   aom_clear_system_state();
2627 
2628   // Check if skip is better
2629   if (rate_i == 0) {
2630     dist_i = sse << 4;
2631   } else if (RDCOST(x->rdmult, rate_i, dist_i) >=
2632              RDCOST(x->rdmult, 0, sse << 4)) {
2633     rate_i = 0;
2634     dist_i = sse << 4;
2635   }
2636 
2637   if (rate) *rate = rate_i;
2638   if (dist) *dist = dist_i;
2639   return;
2640 }
2641 
2642 static void model_rd_for_sb_with_dnn(
2643     const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
2644     int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
2645     int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
2646     int *plane_rate, int64_t *plane_sse, int64_t *plane_dist) {
2647   (void)mi_row;
2648   (void)mi_col;
2649   // Note: our transform coefficients are 8 times those of an orthogonal
2650   // transform, so the quantizer step is also 8 times larger. Divide by 8
2651   // to get the effective quantizer before using it in the modeling function.
2652   const int ref = xd->mi[0]->ref_frame[0];
2653 
2654   int64_t rate_sum = 0;
2655   int64_t dist_sum = 0;
2656   int64_t total_sse = 0;
2657 
2658   for (int plane = plane_from; plane <= plane_to; ++plane) {
2659     struct macroblockd_plane *const pd = &xd->plane[plane];
2660     const BLOCK_SIZE plane_bsize =
2661         get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
2662     int64_t dist, sse;
2663     int rate;
2664 
2665     if (x->skip_chroma_rd && plane) continue;
2666 
2667     const struct macroblock_plane *const p = &x->plane[plane];
2668     const int shift = (xd->bd - 8);
2669     int bw, bh;
2670     get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL,
2671                        &bw, &bh);
2672     if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
2673       sse = aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf,
2674                            pd->dst.stride, bw, bh);
2675     } else {
2676       sse = aom_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride, bw,
2677                     bh);
2678     }
2679     sse = ROUND_POWER_OF_TWO(sse, shift * 2);
2680 
2681     model_rd_with_dnn(cpi, x, plane_bsize, plane, sse, bw * bh, &rate, &dist);
2682 
2683     if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX);
2684 
2685     total_sse += sse;
2686     rate_sum += rate;
2687     dist_sum += dist;
2688 
2689     if (plane_rate) plane_rate[plane] = rate;
2690     if (plane_sse) plane_sse[plane] = sse;
2691     if (plane_dist) plane_dist[plane] = dist;
2692   }
2693 
2694   if (skip_txfm_sb) *skip_txfm_sb = total_sse == 0;
2695   if (skip_sse_sb) *skip_sse_sb = total_sse << 4;
2696   *out_rate_sum = (int)rate_sum;
2697   *out_dist_sum = dist_sum;
2698 }
2699 
2700 // Fits a surface for rate and distortion using
2701 // log2(sse_norm + 1) and log2(sse_norm / qstep^2) as features.
2702 static void model_rd_with_surffit(const AV1_COMP *const cpi,
2703                                   const MACROBLOCK *const x,
2704                                   BLOCK_SIZE plane_bsize, int plane,
2705                                   int64_t sse, int num_samples, int *rate,
2706                                   int64_t *dist) {
2707   (void)cpi;
2708   (void)plane_bsize;
2709   const MACROBLOCKD *const xd = &x->e_mbd;
2710   const struct macroblockd_plane *const pd = &xd->plane[plane];
2711   const int dequant_shift =
2712       (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd - 5 : 3;
2713   const int qstep = AOMMAX(pd->dequant_Q3[1] >> dequant_shift, 1);
2714   if (sse == 0) {
2715     if (rate) *rate = 0;
2716     if (dist) *dist = 0;
2717     return;
2718   }
2719   aom_clear_system_state();
2720   const double sse_norm = (double)sse / num_samples;
2721   const double qstepsqr = (double)qstep * qstep;
2722   const double xm = log(sse_norm + 1.0) / log(2.0);
2723   const double yl = log(sse_norm / qstepsqr) / log(2.0);
2724   double rate_f, dist_by_sse_norm_f;
2725 
2726   av1_model_rd_surffit(xm, yl, &rate_f, &dist_by_sse_norm_f);
2727 
2728   const double dist_f = dist_by_sse_norm_f * sse_norm;
2729   int rate_i = (int)(AOMMAX(0.0, rate_f * num_samples) + 0.5);
2730   int64_t dist_i = (int64_t)(AOMMAX(0.0, dist_f * num_samples) + 0.5);
2731   aom_clear_system_state();
2732 
2733   // Check if skip is better
2734   if (rate_i == 0) {
2735     dist_i = sse << 4;
2736   } else if (RDCOST(x->rdmult, rate_i, dist_i) >=
2737              RDCOST(x->rdmult, 0, sse << 4)) {
2738     rate_i = 0;
2739     dist_i = sse << 4;
2740   }
2741 
2742   if (rate) *rate = rate_i;
2743   if (dist) *dist = dist_i;
2744 }
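// For illustration only (made-up numbers): with qstep = 8 and
// sse_norm = 63.0, xm = log2(63 + 1) = 6.0 and yl = log2(63 / 64) ~= -0.02.
// av1_model_rd_surffit() evaluates the fitted surface at (xm, yl) and
// returns rate_f (rate per sample) and dist_by_sse_norm_f (distortion as a
// fraction of sse_norm), which are scaled back up by num_samples above.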
2745 
2746 static void model_rd_for_sb_with_surffit(
2747     const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
2748     int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
2749     int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
2750     int *plane_rate, int64_t *plane_sse, int64_t *plane_dist) {
2751   (void)mi_row;
2752   (void)mi_col;
2753   // Note: our transform coefficients are 8 times those of an orthogonal
2754   // transform, so the quantizer step is also 8 times larger. Divide by 8
2755   // to get the effective quantizer before using it in the modeling function.
2756   const int ref = xd->mi[0]->ref_frame[0];
2757 
2758   int64_t rate_sum = 0;
2759   int64_t dist_sum = 0;
2760   int64_t total_sse = 0;
2761 
2762   for (int plane = plane_from; plane <= plane_to; ++plane) {
2763     struct macroblockd_plane *const pd = &xd->plane[plane];
2764     const BLOCK_SIZE plane_bsize =
2765         get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
2766     int64_t dist, sse;
2767     int rate;
2768 
2769     if (x->skip_chroma_rd && plane) continue;
2770 
2771     int bw, bh;
2772     const struct macroblock_plane *const p = &x->plane[plane];
2773     const int shift = (xd->bd - 8);
2774     get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL,
2775                        &bw, &bh);
2776     if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
2777       sse = aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf,
2778                            pd->dst.stride, bw, bh);
2779     } else {
2780       sse = aom_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride, bw,
2781                     bh);
2782     }
2783     sse = ROUND_POWER_OF_TWO(sse, shift * 2);
2784 
2785     model_rd_with_surffit(cpi, x, plane_bsize, plane, sse, bw * bh, &rate,
2786                           &dist);
2787 
2788     if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX);
2789 
2790     total_sse += sse;
2791     rate_sum += rate;
2792     dist_sum += dist;
2793 
2794     if (plane_rate) plane_rate[plane] = rate;
2795     if (plane_sse) plane_sse[plane] = sse;
2796     if (plane_dist) plane_dist[plane] = dist;
2797   }
2798 
2799   if (skip_txfm_sb) *skip_txfm_sb = total_sse == 0;
2800   if (skip_sse_sb) *skip_sse_sb = total_sse << 4;
2801   *out_rate_sum = (int)rate_sum;
2802   *out_dist_sum = dist_sum;
2803 }
2804 
2805 // Fits a curve for rate and distortion using
2806 // log2(sse_norm / qstep^2) as the feature.
2807 static void model_rd_with_curvfit(const AV1_COMP *const cpi,
2808                                   const MACROBLOCK *const x,
2809                                   BLOCK_SIZE plane_bsize, int plane,
2810                                   int64_t sse, int num_samples, int *rate,
2811                                   int64_t *dist) {
2812   (void)cpi;
2813   (void)plane_bsize;
2814   const MACROBLOCKD *const xd = &x->e_mbd;
2815   const struct macroblockd_plane *const pd = &xd->plane[plane];
2816   const int dequant_shift =
2817       (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? xd->bd - 5 : 3;
2818   const int qstep = AOMMAX(pd->dequant_Q3[1] >> dequant_shift, 1);
2819 
2820   if (sse == 0) {
2821     if (rate) *rate = 0;
2822     if (dist) *dist = 0;
2823     return;
2824   }
2825   aom_clear_system_state();
2826   const double sse_norm = (double)sse / num_samples;
2827   const double qstepsqr = (double)qstep * qstep;
2828   const double xqr = log(sse_norm / qstepsqr) / log(2.0);
2829 
2830   double rate_f, dist_by_sse_norm_f;
2831   av1_model_rd_curvfit(xqr, &rate_f, &dist_by_sse_norm_f);
2832 
2833   const double dist_f = dist_by_sse_norm_f * sse_norm;
2834   int rate_i = (int)(AOMMAX(0.0, rate_f * num_samples) + 0.5);
2835   int64_t dist_i = (int64_t)(AOMMAX(0.0, dist_f * num_samples) + 0.5);
2836   aom_clear_system_state();
2837 
2838   // Check if skip is better
2839   if (rate_i == 0) {
2840     dist_i = sse << 4;
2841   } else if (RDCOST(x->rdmult, rate_i, dist_i) >=
2842              RDCOST(x->rdmult, 0, sse << 4)) {
2843     rate_i = 0;
2844     dist_i = sse << 4;
2845   }
2846 
2847   if (rate) *rate = rate_i;
2848   if (dist) *dist = dist_i;
2849 }
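// For illustration only (made-up numbers): with qstep = 4 and
// sse_norm = 256.0, xqr = log2(256 / 16) = 4.0. av1_model_rd_curvfit()
// evaluates the fitted curve at xqr, and the outputs are rescaled exactly
// as in the surface fit above.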
2850 
2851 static void model_rd_for_sb_with_curvfit(
2852     const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
2853     int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
2854     int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
2855     int *plane_rate, int64_t *plane_sse, int64_t *plane_dist) {
2856   (void)mi_row;
2857   (void)mi_col;
2858   // Note: our transform coefficients are 8 times those of an orthogonal
2859   // transform, so the quantizer step is also 8 times larger. Divide by 8
2860   // to get the effective quantizer before using it in the modeling function.
2861   const int ref = xd->mi[0]->ref_frame[0];
2862 
2863   int64_t rate_sum = 0;
2864   int64_t dist_sum = 0;
2865   int64_t total_sse = 0;
2866 
2867   for (int plane = plane_from; plane <= plane_to; ++plane) {
2868     struct macroblockd_plane *const pd = &xd->plane[plane];
2869     const BLOCK_SIZE plane_bsize =
2870         get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
2871     int64_t dist, sse;
2872     int rate;
2873 
2874     if (x->skip_chroma_rd && plane) continue;
2875 
2876     int bw, bh;
2877     const struct macroblock_plane *const p = &x->plane[plane];
2878     const int shift = (xd->bd - 8);
2879     get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL,
2880                        &bw, &bh);
2881 
2882     if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
2883       sse = aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf,
2884                            pd->dst.stride, bw, bh);
2885     } else {
2886       sse = aom_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride, bw,
2887                     bh);
2888     }
2889 
2890     sse = ROUND_POWER_OF_TWO(sse, shift * 2);
2891     model_rd_with_curvfit(cpi, x, plane_bsize, plane, sse, bw * bh, &rate,
2892                           &dist);
2893 
2894     if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX);
2895 
2896     total_sse += sse;
2897     rate_sum += rate;
2898     dist_sum += dist;
2899 
2900     if (plane_rate) plane_rate[plane] = rate;
2901     if (plane_sse) plane_sse[plane] = sse;
2902     if (plane_dist) plane_dist[plane] = dist;
2903   }
2904 
2905   if (skip_txfm_sb) *skip_txfm_sb = total_sse == 0;
2906   if (skip_sse_sb) *skip_sse_sb = total_sse << 4;
2907   *out_rate_sum = (int)rate_sum;
2908   *out_dist_sum = dist_sum;
2909 }
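// The model_rd_for_sb_with_{dnn,surffit,curvfit} wrappers above share the
// same per-plane structure: compute the (bit-depth normalized) SSE, call
// the per-plane model, accumulate rate/dist/SSE, and fill the optional
// per-plane outputs. Only the per-plane model invoked differs.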
2910 
2911 static void model_rd_for_sb_with_fullrdy(
2912     const AV1_COMP *const cpi, BLOCK_SIZE bsize, MACROBLOCK *x, MACROBLOCKD *xd,
2913     int plane_from, int plane_to, int mi_row, int mi_col, int *out_rate_sum,
2914     int64_t *out_dist_sum, int *skip_txfm_sb, int64_t *skip_sse_sb,
2915     int *plane_rate, int64_t *plane_sse, int64_t *plane_dist) {
2916   const int ref = xd->mi[0]->ref_frame[0];
2917 
2918   int64_t rate_sum = 0;
2919   int64_t dist_sum = 0;
2920   int64_t total_sse = 0;
2921 
2922   for (int plane = plane_from; plane <= plane_to; ++plane) {
2923     struct macroblock_plane *const p = &x->plane[plane];
2924     struct macroblockd_plane *const pd = &xd->plane[plane];
2925     const BLOCK_SIZE plane_bsize =
2926         get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
2927     const int bw = block_size_wide[plane_bsize];
2928     const int bh = block_size_high[plane_bsize];
2929     int64_t sse;
2930     int rate;
2931     int64_t dist;
2932 
2933     if (x->skip_chroma_rd && plane) continue;
2934 
2935     if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
2936       sse = aom_highbd_sse(p->src.buf, p->src.stride, pd->dst.buf,
2937                            pd->dst.stride, bw, bh);
2938     } else {
2939       sse = aom_sse(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride, bw,
2940                     bh);
2941     }
2942     sse = ROUND_POWER_OF_TWO(sse, (xd->bd - 8) * 2);
2943 
2944     RD_STATS rd_stats;
2945     if (plane == 0) {
2946       select_tx_type_yrd(cpi, x, &rd_stats, bsize, mi_row, mi_col, INT64_MAX);
2947       if (rd_stats.invalid_rate) {
2948         rate = 0;
2949         dist = sse << 4;
2950       } else {
2951         rate = rd_stats.rate;
2952         dist = rd_stats.dist;
2953       }
2954     } else {
2955       model_rd_with_curvfit(cpi, x, plane_bsize, plane, sse, bw * bh, &rate,
2956                             &dist);
2957     }
2958 
2959     if (plane == 0) x->pred_sse[ref] = (unsigned int)AOMMIN(sse, UINT_MAX);
2960 
2961     total_sse += sse;
2962     rate_sum += rate;
2963     dist_sum += dist;
2964 
2965     if (plane_rate) plane_rate[plane] = rate;
2966     if (plane_sse) plane_sse[plane] = sse;
2967     if (plane_dist) plane_dist[plane] = dist;
2968   }
2969 
2970   if (skip_txfm_sb) *skip_txfm_sb = total_sse == 0;
2971   if (skip_sse_sb) *skip_sse_sb = total_sse << 4;
2972   *out_rate_sum = (int)rate_sum;
2973   *out_dist_sum = dist_sum;
2974 }
2975 
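// Searches over the allowed transform types for one transform block and
// returns the best rdcost. The flow below is: (1) optionally reuse a cached
// result via the intra txb hash; (2) derive [txk_start, txk_end] and
// allowed_tx_mask from speed features, lossless/tx-size restrictions, the
// extended tx set and 2D pruning; (3) for each allowed tx_type, transform,
// quantize (optionally trellis-optimize), cost the coefficients and measure
// distortion, keeping the best candidate via a dqcoeff buffer swap;
// (4) optionally recompute the final distortion in the pixel domain and,
// for intra blocks, reconstruct so later blocks can predict from the result.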
2976 static int64_t search_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
2977                                int block, int blk_row, int blk_col,
2978                                BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
2979                                const TXB_CTX *const txb_ctx,
2980                                FAST_TX_SEARCH_MODE ftxs_mode,
2981                                int use_fast_coef_costing, int64_t ref_best_rd,
2982                                RD_STATS *best_rd_stats) {
2983   const AV1_COMMON *cm = &cpi->common;
2984   MACROBLOCKD *xd = &x->e_mbd;
2985   struct macroblockd_plane *const pd = &xd->plane[plane];
2986   MB_MODE_INFO *mbmi = xd->mi[0];
2987   const int is_inter = is_inter_block(mbmi);
2988   int64_t best_rd = INT64_MAX;
2989   uint16_t best_eob = 0;
2990   TX_TYPE best_tx_type = DCT_DCT;
2991   TX_TYPE last_tx_type = TX_TYPES;
2992   const int fast_tx_search = ftxs_mode & FTXS_DCT_AND_1D_DCT_ONLY;
2993   // Buffer used to swap out dqcoeff in macroblockd_plane so that we can
2994   // keep the dqcoeff of the best tx_type.
2995   DECLARE_ALIGNED(32, tran_low_t, this_dqcoeff[MAX_SB_SQUARE]);
2996   tran_low_t *orig_dqcoeff = pd->dqcoeff;
2997   tran_low_t *best_dqcoeff = this_dqcoeff;
2998   const int txk_type_idx =
2999       av1_get_txk_type_index(plane_bsize, blk_row, blk_col);
3000   av1_invalid_rd_stats(best_rd_stats);
3001 
3002   TXB_RD_INFO *intra_txb_rd_info = NULL;
3003   uint16_t cur_joint_ctx = 0;
3004   const int mi_row = -xd->mb_to_top_edge >> (3 + MI_SIZE_LOG2);
3005   const int mi_col = -xd->mb_to_left_edge >> (3 + MI_SIZE_LOG2);
3006   const int within_border =
3007       mi_row >= xd->tile.mi_row_start &&
3008       (mi_row + mi_size_high[plane_bsize] < xd->tile.mi_row_end) &&
3009       mi_col >= xd->tile.mi_col_start &&
3010       (mi_col + mi_size_wide[plane_bsize] < xd->tile.mi_col_end);
3011   if (within_border && cpi->sf.use_intra_txb_hash && frame_is_intra_only(cm) &&
3012       !is_inter && plane == 0 &&
3013       tx_size_wide[tx_size] == tx_size_high[tx_size]) {
3014     const uint32_t intra_hash =
3015         get_intra_txb_hash(x, plane, blk_row, blk_col, plane_bsize, tx_size);
3016     const int intra_hash_idx =
3017         find_tx_size_rd_info(&x->txb_rd_record_intra, intra_hash);
3018     intra_txb_rd_info = &x->txb_rd_record_intra.tx_rd_info[intra_hash_idx];
3019 
3020     cur_joint_ctx = (txb_ctx->dc_sign_ctx << 8) + txb_ctx->txb_skip_ctx;
3021     if (intra_txb_rd_info->entropy_context == cur_joint_ctx &&
3022         x->txb_rd_record_intra.tx_rd_info[intra_hash_idx].valid) {
3023       mbmi->txk_type[txk_type_idx] = intra_txb_rd_info->tx_type;
3024       const TX_TYPE ref_tx_type =
3025           av1_get_tx_type(get_plane_type(plane), &x->e_mbd, blk_row, blk_col,
3026                           tx_size, cpi->common.reduced_tx_set_used);
3027       if (ref_tx_type == intra_txb_rd_info->tx_type) {
3028         best_rd_stats->rate = intra_txb_rd_info->rate;
3029         best_rd_stats->dist = intra_txb_rd_info->dist;
3030         best_rd_stats->sse = intra_txb_rd_info->sse;
3031         best_rd_stats->skip = intra_txb_rd_info->eob == 0;
3032         x->plane[plane].eobs[block] = intra_txb_rd_info->eob;
3033         x->plane[plane].txb_entropy_ctx[block] =
3034             intra_txb_rd_info->txb_entropy_ctx;
3035         best_rd = RDCOST(x->rdmult, best_rd_stats->rate, best_rd_stats->dist);
3036         best_eob = intra_txb_rd_info->eob;
3037         best_tx_type = intra_txb_rd_info->tx_type;
3038         update_txk_array(mbmi->txk_type, plane_bsize, blk_row, blk_col, tx_size,
3039                          best_tx_type);
3040         goto RECON_INTRA;
3041       }
3042     }
3043   }
3044 
3045   int rate_cost = 0;
3046   TX_TYPE txk_start = DCT_DCT;
3047   TX_TYPE txk_end = TX_TYPES - 1;
3048   if ((!is_inter && x->use_default_intra_tx_type) ||
3049       (is_inter && x->use_default_inter_tx_type)) {
3050     txk_start = txk_end = get_default_tx_type(0, xd, tx_size);
3051   } else if (x->rd_model == LOW_TXFM_RD || x->cb_partition_scan) {
3052     if (plane == 0) txk_end = DCT_DCT;
3053   }
3054 
3055   uint8_t best_txb_ctx = 0;
3056   const TxSetType tx_set_type =
3057       av1_get_ext_tx_set_type(tx_size, is_inter, cm->reduced_tx_set_used);
3058 
3059   TX_TYPE uv_tx_type = DCT_DCT;
3060   if (plane) {
3061     // tx_type of PLANE_TYPE_UV should be the same as PLANE_TYPE_Y
3062     uv_tx_type = txk_start = txk_end =
3063         av1_get_tx_type(get_plane_type(plane), xd, blk_row, blk_col, tx_size,
3064                         cm->reduced_tx_set_used);
3065   }
3066   const uint16_t ext_tx_used_flag = av1_ext_tx_used_flag[tx_set_type];
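  // Bit k of ext_tx_used_flag corresponds to tx_type k, so a value of
  // 0x0001 means only DCT_DCT (tx_type 0) is enabled in this set.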
3067   if (xd->lossless[mbmi->segment_id] || txsize_sqr_up_map[tx_size] > TX_32X32 ||
3068       ext_tx_used_flag == 0x0001) {
3069     txk_start = txk_end = DCT_DCT;
3070   }
3071   uint16_t allowed_tx_mask = 0;  // 1: allow; 0: skip.
3072   if (txk_start == txk_end) {
3073     allowed_tx_mask = 1 << txk_start;
3074     allowed_tx_mask &= ext_tx_used_flag;
3075   } else if (fast_tx_search) {
3076     allowed_tx_mask = 0x0c01;  // V_DCT, H_DCT, DCT_DCT
3077     allowed_tx_mask &= ext_tx_used_flag;
3078   } else {
3079     assert(plane == 0);
3080     allowed_tx_mask = ext_tx_used_flag;
3081     // !fast_tx_search && txk_end != txk_start && plane == 0
3082     const int do_prune = cpi->sf.tx_type_search.prune_mode > NO_PRUNE;
3083     if (do_prune && is_inter) {
3084       if (cpi->sf.tx_type_search.prune_mode >= PRUNE_2D_ACCURATE) {
3085         const uint16_t prune =
3086             prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
3087                         cpi->sf.tx_type_search.prune_mode);
3088         allowed_tx_mask &= (~prune);
3089       } else {
3090         allowed_tx_mask &= (~x->tx_search_prune[tx_set_type]);
3091       }
3092     }
3093   }
3094   // Need to have at least one transform type allowed.
3095   if (allowed_tx_mask == 0) {
3096     txk_start = txk_end = (plane ? uv_tx_type : DCT_DCT);
3097     allowed_tx_mask = (1 << txk_start);
3098   }
3099 
3100   int use_transform_domain_distortion =
3101       (cpi->sf.use_transform_domain_distortion > 0) &&
3102       // Any 64-pt transform only preserves half the coefficients.
3103       // Therefore transform domain distortion is not valid for these
3104       // transform sizes.
3105       txsize_sqr_up_map[tx_size] != TX_64X64;
3106 #if CONFIG_DIST_8X8
3107   if (x->using_dist_8x8) use_transform_domain_distortion = 0;
3108 #endif
3109   int calc_pixel_domain_distortion_final =
3110       cpi->sf.use_transform_domain_distortion == 1 &&
3111       use_transform_domain_distortion && x->rd_model != LOW_TXFM_RD &&
3112       !x->cb_partition_scan;
3113   if (calc_pixel_domain_distortion_final &&
3114       (txk_start == txk_end || allowed_tx_mask == 0x0001))
3115     calc_pixel_domain_distortion_final = use_transform_domain_distortion = 0;
3116 
3117   const uint16_t *eobs_ptr = x->plane[plane].eobs;
3118 
3119   const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
3120   int64_t block_sse =
3121       pixel_diff_dist(x, plane, blk_row, blk_col, plane_bsize, tx_bsize);
3122   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
3123     block_sse = ROUND_POWER_OF_TWO(block_sse, (xd->bd - 8) * 2);
3124   block_sse *= 16;
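  // The multiply by 16 keeps block_sse on the same scale as the distortion
  // values used throughout this file (note the matching sse << 4 used for
  // the skip costs above).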
3125 
3126   for (TX_TYPE tx_type = txk_start; tx_type <= txk_end; ++tx_type) {
3127     if (!(allowed_tx_mask & (1 << tx_type))) continue;
3128     if (plane == 0) mbmi->txk_type[txk_type_idx] = tx_type;
3129     RD_STATS this_rd_stats;
3130     av1_invalid_rd_stats(&this_rd_stats);
3131 
3132     if (!cpi->optimize_seg_arr[mbmi->segment_id]) {
3133       av1_xform_quant(
3134           cm, x, plane, block, blk_row, blk_col, plane_bsize, tx_size, tx_type,
3135           USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP);
3136       rate_cost = av1_cost_coeffs(cm, x, plane, block, tx_size, tx_type,
3137                                   txb_ctx, use_fast_coef_costing);
3138     } else {
3139       av1_xform_quant(cm, x, plane, block, blk_row, blk_col, plane_bsize,
3140                       tx_size, tx_type, AV1_XFORM_QUANT_FP);
3141       if (cpi->sf.optimize_b_precheck && best_rd < INT64_MAX &&
3142           eobs_ptr[block] >= 4) {
3143         // Calculate distortion quickly in transform domain.
3144         dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist,
3145                              &this_rd_stats.sse);
3146 
3147         const int64_t best_rd_ = AOMMIN(best_rd, ref_best_rd);
3148         const int64_t dist_cost_estimate =
3149             RDCOST(x->rdmult, 0, AOMMIN(this_rd_stats.dist, this_rd_stats.sse));
3150         if (dist_cost_estimate - (dist_cost_estimate >> 3) > best_rd_) continue;
3151 
3152         rate_cost = av1_cost_coeffs(cm, x, plane, block, tx_size, tx_type,
3153                                     txb_ctx, use_fast_coef_costing);
3154         const int64_t rd_estimate =
3155             AOMMIN(RDCOST(x->rdmult, rate_cost, this_rd_stats.dist),
3156                    RDCOST(x->rdmult, 0, this_rd_stats.sse));
3157         if (rd_estimate - (rd_estimate >> 3) > best_rd_) continue;
3158       }
3159       av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, txb_ctx, 1,
3160                      &rate_cost);
3161     }
3162     if (eobs_ptr[block] == 0) {
3163       // When eob is 0, pixel domain distortion is more efficient and accurate.
3164       this_rd_stats.dist = this_rd_stats.sse = block_sse;
3165     } else if (use_transform_domain_distortion) {
3166       dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist,
3167                            &this_rd_stats.sse);
3168     } else {
3169       this_rd_stats.dist = dist_block_px_domain(
3170           cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
3171       this_rd_stats.sse = block_sse;
3172     }
3173 
3174     this_rd_stats.rate = rate_cost;
3175 
3176     const int64_t rd =
3177         RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
3178 
3179     if (rd < best_rd) {
3180       best_rd = rd;
3181       *best_rd_stats = this_rd_stats;
3182       best_tx_type = tx_type;
3183       best_txb_ctx = x->plane[plane].txb_entropy_ctx[block];
3184       best_eob = x->plane[plane].eobs[block];
3185       last_tx_type = best_tx_type;
3186 
3187       // Swap qcoeff and dqcoeff buffers
3188       tran_low_t *const tmp_dqcoeff = best_dqcoeff;
3189       best_dqcoeff = pd->dqcoeff;
3190       pd->dqcoeff = tmp_dqcoeff;
3191     }
3192 
3193 #if CONFIG_COLLECT_RD_STATS == 1
3194     if (plane == 0) {
3195       PrintTransformUnitStats(cpi, x, &this_rd_stats, blk_row, blk_col,
3196                               plane_bsize, tx_size, tx_type, rd);
3197     }
3198 #endif  // CONFIG_COLLECT_RD_STATS == 1
3199 
3200     if (cpi->sf.adaptive_txb_search_level) {
3201       if ((best_rd - (best_rd >> cpi->sf.adaptive_txb_search_level)) >
3202           ref_best_rd) {
3203         break;
3204       }
3205     }
3206 
3207     // Stop the transform type search early once the block has been quantized
3208     // to all zeros and that choice also has the best rdcost so far.
3209     if (cpi->sf.tx_type_search.skip_tx_search && !best_eob) break;
3210   }
3211 
3212   assert(best_rd != INT64_MAX);
3213 
3214   best_rd_stats->skip = best_eob == 0;
3215   if (plane == 0) {
3216     update_txk_array(mbmi->txk_type, plane_bsize, blk_row, blk_col, tx_size,
3217                      best_tx_type);
3218   }
3219   x->plane[plane].txb_entropy_ctx[block] = best_txb_ctx;
3220   x->plane[plane].eobs[block] = best_eob;
3221 
3222   pd->dqcoeff = best_dqcoeff;
3223 
3224   if (calc_pixel_domain_distortion_final && best_eob) {
3225     best_rd_stats->dist = dist_block_px_domain(
3226         cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
3227     best_rd_stats->sse = block_sse;
3228   }
3229 
3230   if (intra_txb_rd_info != NULL) {
3231     intra_txb_rd_info->valid = 1;
3232     intra_txb_rd_info->entropy_context = cur_joint_ctx;
3233     intra_txb_rd_info->rate = best_rd_stats->rate;
3234     intra_txb_rd_info->dist = best_rd_stats->dist;
3235     intra_txb_rd_info->sse = best_rd_stats->sse;
3236     intra_txb_rd_info->eob = best_eob;
3237     intra_txb_rd_info->txb_entropy_ctx = best_txb_ctx;
3238     if (plane == 0) intra_txb_rd_info->tx_type = best_tx_type;
3239   }
3240 
3241 RECON_INTRA:
3242   if (!is_inter && best_eob &&
3243       (blk_row + tx_size_high_unit[tx_size] < mi_size_high[plane_bsize] ||
3244        blk_col + tx_size_wide_unit[tx_size] < mi_size_wide[plane_bsize])) {
3245     // Intra mode needs the decoded result so that the next transform block
3246     // can use it for prediction.
3247     // If the tx_type searched last is already the best tx_type, we don't
3248     // need to redo the transform.
3249     if (best_tx_type != last_tx_type) {
3250       if (!cpi->optimize_seg_arr[mbmi->segment_id]) {
3251         av1_xform_quant(
3252             cm, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
3253             best_tx_type,
3254             USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP);
3255       } else {
3256         av1_xform_quant(cm, x, plane, block, blk_row, blk_col, plane_bsize,
3257                         tx_size, best_tx_type, AV1_XFORM_QUANT_FP);
3258         av1_optimize_b(cpi, x, plane, block, tx_size, best_tx_type, txb_ctx, 1,
3259                        &rate_cost);
3260       }
3261     }
3262 
3263     inverse_transform_block_facade(xd, plane, block, blk_row, blk_col,
3264                                    x->plane[plane].eobs[block],
3265                                    cm->reduced_tx_set_used);
3266 
3267     // This may happen because of a hash collision: the eob stored in the
3268     // hash table is non-zero, but the real eob is zero. Make sure tx_type
3269     // is DCT_DCT in this case.
3270     if (plane == 0 && x->plane[plane].eobs[block] == 0 &&
3271         best_tx_type != DCT_DCT) {
3272       update_txk_array(mbmi->txk_type, plane_bsize, blk_row, blk_col, tx_size,
3273                        DCT_DCT);
3274     }
3275   }
3276   pd->dqcoeff = orig_dqcoeff;
3277 
3278   return best_rd;
3279 }
3280 
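// Per-transform-block callback invoked by txfm_rd_in_plane() through
// av1_foreach_transformed_block_in_plane(). It rd-searches one block via
// search_txk_type(), updates entropy contexts and skip flags, merges the
// stats into args->rd_stats, and sets exit_early once the accumulated rd
// exceeds args->best_rd.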
3281 static void block_rd_txfm(int plane, int block, int blk_row, int blk_col,
3282                           BLOCK_SIZE plane_bsize, TX_SIZE tx_size, void *arg) {
3283   struct rdcost_block_args *args = arg;
3284   MACROBLOCK *const x = args->x;
3285   MACROBLOCKD *const xd = &x->e_mbd;
3286   const MB_MODE_INFO *const mbmi = xd->mi[0];
3287   const AV1_COMP *cpi = args->cpi;
3288   ENTROPY_CONTEXT *a = args->t_above + blk_col;
3289   ENTROPY_CONTEXT *l = args->t_left + blk_row;
3290   const AV1_COMMON *cm = &cpi->common;
3291   int64_t rd1, rd2, rd;
3292   RD_STATS this_rd_stats;
3293 
3294   av1_init_rd_stats(&this_rd_stats);
3295 
3296   if (args->exit_early) {
3297     args->incomplete_exit = 1;
3298     return;
3299   }
3300 
3301   if (!is_inter_block(mbmi)) {
3302     av1_predict_intra_block_facade(cm, xd, plane, blk_col, blk_row, tx_size);
3303     av1_subtract_txb(x, plane, plane_bsize, blk_col, blk_row, tx_size);
3304   }
3305   TXB_CTX txb_ctx;
3306   get_txb_ctx(plane_bsize, tx_size, plane, a, l, &txb_ctx);
3307   search_txk_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
3308                   &txb_ctx, args->ftxs_mode, args->use_fast_coef_costing,
3309                   args->best_rd - args->this_rd, &this_rd_stats);
3310 
3311   if (plane == AOM_PLANE_Y && xd->cfl.store_y) {
3312     assert(!is_inter_block(mbmi) || plane_bsize < BLOCK_8X8);
3313     cfl_store_tx(xd, blk_row, blk_col, tx_size, plane_bsize);
3314   }
3315 
3316 #if CONFIG_RD_DEBUG
3317   av1_update_txb_coeff_cost(&this_rd_stats, plane, tx_size, blk_row, blk_col,
3318                             this_rd_stats.rate);
3319 #endif  // CONFIG_RD_DEBUG
3320   av1_set_txb_context(x, plane, block, tx_size, a, l);
3321 
3322   const int blk_idx =
3323       blk_row * (block_size_wide[plane_bsize] >> tx_size_wide_log2[0]) +
3324       blk_col;
3325 
3326   if (plane == 0)
3327     set_blk_skip(x, plane, blk_idx, x->plane[plane].eobs[block] == 0);
3328   else
3329     set_blk_skip(x, plane, blk_idx, 0);
3330 
3331   rd1 = RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
3332   rd2 = RDCOST(x->rdmult, 0, this_rd_stats.sse);
3333 
3334   // TODO(jingning): temporarily enabled only for luma component
3335   rd = AOMMIN(rd1, rd2);
3336 
3337   this_rd_stats.skip &= !x->plane[plane].eobs[block];
3338 
3339   av1_merge_rd_stats(&args->rd_stats, &this_rd_stats);
3340 
3341   args->this_rd += rd;
3342 
3343   if (args->this_rd > args->best_rd) {
3344     args->exit_early = 1;
3345     return;
3346   }
3347 }
3348 
3349 static void txfm_rd_in_plane(MACROBLOCK *x, const AV1_COMP *cpi,
3350                              RD_STATS *rd_stats, int64_t ref_best_rd, int plane,
3351                              BLOCK_SIZE bsize, TX_SIZE tx_size,
3352                              int use_fast_coef_costing,
3353                              FAST_TX_SEARCH_MODE ftxs_mode) {
3354   MACROBLOCKD *const xd = &x->e_mbd;
3355   const struct macroblockd_plane *const pd = &xd->plane[plane];
3356   struct rdcost_block_args args;
3357   av1_zero(args);
3358   args.x = x;
3359   args.cpi = cpi;
3360   args.best_rd = ref_best_rd;
3361   args.use_fast_coef_costing = use_fast_coef_costing;
3362   args.ftxs_mode = ftxs_mode;
3363   av1_init_rd_stats(&args.rd_stats);
3364 
3365   if (plane == 0) xd->mi[0]->tx_size = tx_size;
3366 
3367   av1_get_entropy_contexts(bsize, pd, args.t_above, args.t_left);
3368 
3369   av1_foreach_transformed_block_in_plane(xd, bsize, plane, block_rd_txfm,
3370                                          &args);
3371 
3372   MB_MODE_INFO *const mbmi = xd->mi[0];
3373   const int is_inter = is_inter_block(mbmi);
3374   const int invalid_rd = is_inter ? args.incomplete_exit : args.exit_early;
3375 
3376   if (invalid_rd) {
3377     av1_invalid_rd_stats(rd_stats);
3378   } else {
3379     *rd_stats = args.rd_stats;
3380   }
3381 }
3382 
3383 static int tx_size_cost(const AV1_COMMON *const cm, const MACROBLOCK *const x,
3384                         BLOCK_SIZE bsize, TX_SIZE tx_size) {
3385   const MACROBLOCKD *const xd = &x->e_mbd;
3386   const MB_MODE_INFO *const mbmi = xd->mi[0];
3387 
3388   if (cm->tx_mode == TX_MODE_SELECT && block_signals_txsize(mbmi->sb_type)) {
3389     const int32_t tx_size_cat = bsize_to_tx_size_cat(bsize);
3390     const int depth = tx_size_to_depth(tx_size, bsize);
3391     const int tx_size_ctx = get_tx_size_context(xd);
3392     int r_tx_size = x->tx_size_cost[tx_size_cat][tx_size_ctx][depth];
3393     return r_tx_size;
3394   } else {
3395     return 0;
3396   }
3397 }
3398 
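// Computes the luma rdcost for a fixed tx_size: runs txfm_rd_in_plane() and
// then folds in the skip-flag cost (s0/s1) and, when the tx size is
// signaled (TX_MODE_SELECT), the tx-size/txfm-partition signaling cost.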
3399 static int64_t txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
3400                         RD_STATS *rd_stats, int64_t ref_best_rd, BLOCK_SIZE bs,
3401                         TX_SIZE tx_size, FAST_TX_SEARCH_MODE ftxs_mode) {
3402   const AV1_COMMON *const cm = &cpi->common;
3403   MACROBLOCKD *const xd = &x->e_mbd;
3404   MB_MODE_INFO *const mbmi = xd->mi[0];
3405   int64_t rd = INT64_MAX;
3406   const int skip_ctx = av1_get_skip_context(xd);
3407   int s0, s1;
3408   const int is_inter = is_inter_block(mbmi);
3409   const int tx_select =
3410       cm->tx_mode == TX_MODE_SELECT && block_signals_txsize(mbmi->sb_type);
3411   int ctx = txfm_partition_context(
3412       xd->above_txfm_context, xd->left_txfm_context, mbmi->sb_type, tx_size);
3413   const int r_tx_size = is_inter ? x->txfm_partition_cost[ctx][0]
3414                                  : tx_size_cost(cm, x, bs, tx_size);
3415 
3416   assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed_bsize(bs)));
3417 
3418   s0 = x->skip_cost[skip_ctx][0];
3419   s1 = x->skip_cost[skip_ctx][1];
3420 
3421   mbmi->tx_size = tx_size;
3422   txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, AOM_PLANE_Y, bs, tx_size,
3423                    cpi->sf.use_fast_coef_costing, ftxs_mode);
3424   if (rd_stats->rate == INT_MAX) return INT64_MAX;
3425 
3426   if (rd_stats->skip) {
3427     if (is_inter) {
3428       rd = RDCOST(x->rdmult, s1, rd_stats->sse);
3429     } else {
3430       rd = RDCOST(x->rdmult, s1 + r_tx_size * tx_select, rd_stats->sse);
3431     }
3432   } else {
3433     rd = RDCOST(x->rdmult, rd_stats->rate + s0 + r_tx_size * tx_select,
3434                 rd_stats->dist);
3435   }
3436 
3437   if (tx_select) rd_stats->rate += r_tx_size;
3438 
3439   if (is_inter && !xd->lossless[xd->mi[0]->segment_id] && !(rd_stats->skip))
3440     rd = AOMMIN(rd, RDCOST(x->rdmult, s1, rd_stats->sse));
3441 
3442   return rd;
3443 }
3444 
3445 static int64_t estimate_yrd_for_sb(const AV1_COMP *const cpi, BLOCK_SIZE bs,
3446                                    MACROBLOCK *x, int *r, int64_t *d, int *s,
3447                                    int64_t *sse, int64_t ref_best_rd) {
3448   RD_STATS rd_stats;
3449   av1_subtract_plane(x, bs, 0);
3450   x->rd_model = LOW_TXFM_RD;
3451   int64_t rd = txfm_yrd(cpi, x, &rd_stats, ref_best_rd, bs,
3452                         max_txsize_rect_lookup[bs], FTXS_NONE);
3453   x->rd_model = FULL_TXFM_RD;
3454   *r = rd_stats.rate;
3455   *d = rd_stats.dist;
3456   *s = rd_stats.skip;
3457   *sse = rd_stats.sse;
3458   return rd;
3459 }
3460 
3461 static void choose_largest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x,
3462                                    RD_STATS *rd_stats, int64_t ref_best_rd,
3463                                    BLOCK_SIZE bs) {
3464   const AV1_COMMON *const cm = &cpi->common;
3465   MACROBLOCKD *const xd = &x->e_mbd;
3466   MB_MODE_INFO *const mbmi = xd->mi[0];
3467   const int is_inter = is_inter_block(mbmi);
3468   mbmi->tx_size = tx_size_from_tx_mode(bs, cm->tx_mode);
3469   const TxSetType tx_set_type =
3470       av1_get_ext_tx_set_type(mbmi->tx_size, is_inter, cm->reduced_tx_set_used);
3471   prune_tx(cpi, bs, x, xd, tx_set_type);
3472   txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, AOM_PLANE_Y, bs,
3473                    mbmi->tx_size, cpi->sf.use_fast_coef_costing, FTXS_NONE);
3474   // Reset the pruning flags.
3475   av1_zero(x->tx_search_prune);
3476   x->tx_split_prune_flag = 0;
3477 }
3478 
3479 static void choose_smallest_tx_size(const AV1_COMP *const cpi, MACROBLOCK *x,
3480                                     RD_STATS *rd_stats, int64_t ref_best_rd,
3481                                     BLOCK_SIZE bs) {
3482   MACROBLOCKD *const xd = &x->e_mbd;
3483   MB_MODE_INFO *const mbmi = xd->mi[0];
3484 
3485   mbmi->tx_size = TX_4X4;
3486   txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, 0, bs, mbmi->tx_size,
3487                    cpi->sf.use_fast_coef_costing, FTXS_NONE);
3488 }
3489 
3490 static INLINE int bsize_to_num_blk(BLOCK_SIZE bsize) {
3491   int num_blk = 1 << (num_pels_log2_lookup[bsize] - 2 * tx_size_wide_log2[0]);
3492   return num_blk;
3493 }
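// Worked example for the formula above: for BLOCK_16X16,
// num_pels_log2_lookup[] gives 8 and tx_size_wide_log2[0] (the 4x4
// transform) gives 2, so num_blk = 1 << (8 - 4) = 16, i.e. the number of
// 4x4 blocks in the block.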
3494 
3495 static int get_search_init_depth(int mi_width, int mi_height, int is_inter,
3496                                  const SPEED_FEATURES *sf) {
3497   if (sf->tx_size_search_method == USE_LARGESTALL) return MAX_VARTX_DEPTH;
3498 
3499   if (sf->tx_size_search_lgr_block) {
3500     if (mi_width > mi_size_wide[BLOCK_64X64] ||
3501         mi_height > mi_size_high[BLOCK_64X64])
3502       return MAX_VARTX_DEPTH;
3503   }
3504 
3505   if (is_inter) {
3506     return (mi_height != mi_width) ? sf->inter_tx_size_search_init_depth_rect
3507                                    : sf->inter_tx_size_search_init_depth_sqr;
3508   } else {
3509     return (mi_height != mi_width) ? sf->intra_tx_size_search_init_depth_rect
3510                                    : sf->intra_tx_size_search_init_depth_sqr;
3511   }
3512 }
3513 
3514 static void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
3515                                         MACROBLOCK *x, RD_STATS *rd_stats,
3516                                         int64_t ref_best_rd, BLOCK_SIZE bs) {
3517   const AV1_COMMON *const cm = &cpi->common;
3518   MACROBLOCKD *const xd = &x->e_mbd;
3519   MB_MODE_INFO *const mbmi = xd->mi[0];
3520   int64_t rd = INT64_MAX;
3521   int n;
3522   int start_tx;
3523   int depth;
3524   int64_t best_rd = INT64_MAX;
3525   const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bs];
3526   TX_SIZE best_tx_size = max_rect_tx_size;
3527   TX_TYPE best_txk_type[TXK_TYPE_BUF_LEN];
3528   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
3529   const int n4 = bsize_to_num_blk(bs);
3530   const int tx_select = cm->tx_mode == TX_MODE_SELECT;
3531 
3532   av1_invalid_rd_stats(rd_stats);
3533 
3534   if (tx_select) {
3535     start_tx = max_rect_tx_size;
3536     depth = get_search_init_depth(mi_size_wide[bs], mi_size_high[bs],
3537                                   is_inter_block(mbmi), &cpi->sf);
3538   } else {
3539     const TX_SIZE chosen_tx_size = tx_size_from_tx_mode(bs, cm->tx_mode);
3540     start_tx = chosen_tx_size;
3541     depth = MAX_TX_DEPTH;
3542   }
3543 
3544   prune_tx(cpi, bs, x, xd, EXT_TX_SET_ALL16);
3545 
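  // Walk down the transform size chain from start_tx: each iteration moves
  // to the next smaller size via sub_tx_size_map (e.g. TX_32X32 ->
  // TX_16X16 -> TX_8X8), until depth exceeds MAX_TX_DEPTH or TX_4X4 is
  // reached.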
3546   for (n = start_tx; depth <= MAX_TX_DEPTH; depth++, n = sub_tx_size_map[n]) {
3547 #if CONFIG_DIST_8X8
3548     if (x->using_dist_8x8) {
3549       if (tx_size_wide[n] < 8 || tx_size_high[n] < 8) continue;
3550     }
3551 #endif
3552     RD_STATS this_rd_stats;
3553     if (mbmi->ref_mv_idx > 0) x->rd_model = LOW_TXFM_RD;
3554     rd = txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs, n, FTXS_NONE);
3555     x->rd_model = FULL_TXFM_RD;
3556 
3557     if (rd < best_rd) {
3558       memcpy(best_txk_type, mbmi->txk_type,
3559              sizeof(best_txk_type[0]) * TXK_TYPE_BUF_LEN);
3560       memcpy(best_blk_skip, x->blk_skip, sizeof(best_blk_skip[0]) * n4);
3561       best_tx_size = n;
3562       best_rd = rd;
3563       *rd_stats = this_rd_stats;
3564     }
3565     if (n == TX_4X4) break;
3566   }
3567 
3568   if (rd_stats->rate != INT_MAX) {
3569     mbmi->tx_size = best_tx_size;
3570     memcpy(mbmi->txk_type, best_txk_type,
3571            sizeof(best_txk_type[0]) * TXK_TYPE_BUF_LEN);
3572     memcpy(x->blk_skip, best_blk_skip, sizeof(best_blk_skip[0]) * n4);
3573   }
3574 
3575   // Reset the pruning flags.
3576   av1_zero(x->tx_search_prune);
3577   x->tx_split_prune_flag = 0;
3578 }
3579 
3580 static void super_block_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
3581                             RD_STATS *rd_stats, BLOCK_SIZE bs,
3582                             int64_t ref_best_rd) {
3583   MACROBLOCKD *xd = &x->e_mbd;
3584   av1_init_rd_stats(rd_stats);
3585 
3586   assert(bs == xd->mi[0]->sb_type);
3587 
3588   if (xd->lossless[xd->mi[0]->segment_id]) {
3589     choose_smallest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
3590   } else if (cpi->sf.tx_size_search_method == USE_LARGESTALL) {
3591     choose_largest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
3592   } else {
3593     choose_tx_size_type_from_rd(cpi, x, rd_stats, ref_best_rd, bs);
3594   }
3595 }
3596 
3597 // Return the rate cost for luma prediction mode info. of intra blocks.
3598 static int intra_mode_info_cost_y(const AV1_COMP *cpi, const MACROBLOCK *x,
3599                                   const MB_MODE_INFO *mbmi, BLOCK_SIZE bsize,
3600                                   int mode_cost) {
3601   int total_rate = mode_cost;
3602   const int use_palette = mbmi->palette_mode_info.palette_size[0] > 0;
3603   const int use_filter_intra = mbmi->filter_intra_mode_info.use_filter_intra;
3604   const int use_intrabc = mbmi->use_intrabc;
3605   // Can only activate one mode.
3606   assert(((mbmi->mode != DC_PRED) + use_palette + use_intrabc +
3607           use_filter_intra) <= 1);
3608   const int try_palette =
3609       av1_allow_palette(cpi->common.allow_screen_content_tools, mbmi->sb_type);
3610   if (try_palette && mbmi->mode == DC_PRED) {
3611     const MACROBLOCKD *xd = &x->e_mbd;
3612     const int bsize_ctx = av1_get_palette_bsize_ctx(bsize);
3613     const int mode_ctx = av1_get_palette_mode_ctx(xd);
3614     total_rate += x->palette_y_mode_cost[bsize_ctx][mode_ctx][use_palette];
3615     if (use_palette) {
3616       const uint8_t *const color_map = xd->plane[0].color_index_map;
3617       int block_width, block_height, rows, cols;
3618       av1_get_block_dimensions(bsize, 0, xd, &block_width, &block_height, &rows,
3619                                &cols);
3620       const int plt_size = mbmi->palette_mode_info.palette_size[0];
3621       int palette_mode_cost =
3622           x->palette_y_size_cost[bsize_ctx][plt_size - PALETTE_MIN_SIZE] +
3623           write_uniform_cost(plt_size, color_map[0]);
3624       uint16_t color_cache[2 * PALETTE_MAX_SIZE];
3625       const int n_cache = av1_get_palette_cache(xd, 0, color_cache);
3626       palette_mode_cost +=
3627           av1_palette_color_cost_y(&mbmi->palette_mode_info, color_cache,
3628                                    n_cache, cpi->common.seq_params.bit_depth);
3629       palette_mode_cost +=
3630           av1_cost_color_map(x, 0, bsize, mbmi->tx_size, PALETTE_MAP);
3631       total_rate += palette_mode_cost;
3632     }
3633   }
3634   if (av1_filter_intra_allowed(&cpi->common, mbmi)) {
3635     total_rate += x->filter_intra_cost[mbmi->sb_type][use_filter_intra];
3636     if (use_filter_intra) {
3637       total_rate += x->filter_intra_mode_cost[mbmi->filter_intra_mode_info
3638                                                   .filter_intra_mode];
3639     }
3640   }
3641   if (av1_is_directional_mode(mbmi->mode)) {
3642     if (av1_use_angle_delta(bsize)) {
3643       total_rate += x->angle_delta_cost[mbmi->mode - V_PRED]
3644                                        [MAX_ANGLE_DELTA +
3645                                         mbmi->angle_delta[PLANE_TYPE_Y]];
3646     }
3647   }
3648   if (av1_allow_intrabc(&cpi->common))
3649     total_rate += x->intrabc_cost[use_intrabc];
3650   return total_rate;
3651 }
3652 
3653 // Return the rate cost for chroma prediction mode info. of intra blocks.
3654 static int intra_mode_info_cost_uv(const AV1_COMP *cpi, const MACROBLOCK *x,
3655                                    const MB_MODE_INFO *mbmi, BLOCK_SIZE bsize,
3656                                    int mode_cost) {
3657   int total_rate = mode_cost;
3658   const int use_palette = mbmi->palette_mode_info.palette_size[1] > 0;
3659   const UV_PREDICTION_MODE mode = mbmi->uv_mode;
3660   // Can only activate one mode.
3661   assert(((mode != UV_DC_PRED) + use_palette + mbmi->use_intrabc) <= 1);
3662 
3663   const int try_palette =
3664       av1_allow_palette(cpi->common.allow_screen_content_tools, mbmi->sb_type);
3665   if (try_palette && mode == UV_DC_PRED) {
3666     const PALETTE_MODE_INFO *pmi = &mbmi->palette_mode_info;
3667     total_rate +=
3668         x->palette_uv_mode_cost[pmi->palette_size[0] > 0][use_palette];
3669     if (use_palette) {
3670       const int bsize_ctx = av1_get_palette_bsize_ctx(bsize);
3671       const int plt_size = pmi->palette_size[1];
3672       const MACROBLOCKD *xd = &x->e_mbd;
3673       const uint8_t *const color_map = xd->plane[1].color_index_map;
3674       int palette_mode_cost =
3675           x->palette_uv_size_cost[bsize_ctx][plt_size - PALETTE_MIN_SIZE] +
3676           write_uniform_cost(plt_size, color_map[0]);
3677       uint16_t color_cache[2 * PALETTE_MAX_SIZE];
3678       const int n_cache = av1_get_palette_cache(xd, 1, color_cache);
3679       palette_mode_cost += av1_palette_color_cost_uv(
3680           pmi, color_cache, n_cache, cpi->common.seq_params.bit_depth);
3681       palette_mode_cost +=
3682           av1_cost_color_map(x, 1, bsize, mbmi->tx_size, PALETTE_MAP);
3683       total_rate += palette_mode_cost;
3684     }
3685   }
3686   if (av1_is_directional_mode(get_uv_mode(mode))) {
3687     if (av1_use_angle_delta(bsize)) {
3688       total_rate +=
3689           x->angle_delta_cost[mode - V_PRED][mbmi->angle_delta[PLANE_TYPE_UV] +
3690                                              MAX_ANGLE_DELTA];
3691     }
3692   }
3693   return total_rate;
3694 }
3695 
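// Prunes one of the finer diagonal intra modes (D67/D113/D157/D203) unless
// an adjacent base direction (V, H, D45 or D135) is already the best intra
// mode found so far; returns 1 when the mode can be skipped.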
3696 static int conditional_skipintra(PREDICTION_MODE mode,
3697                                  PREDICTION_MODE best_intra_mode) {
3698   if (mode == D113_PRED && best_intra_mode != V_PRED &&
3699       best_intra_mode != D135_PRED)
3700     return 1;
3701   if (mode == D67_PRED && best_intra_mode != V_PRED &&
3702       best_intra_mode != D45_PRED)
3703     return 1;
3704   if (mode == D203_PRED && best_intra_mode != H_PRED &&
3705       best_intra_mode != D45_PRED)
3706     return 1;
3707   if (mode == D157_PRED && best_intra_mode != H_PRED &&
3708       best_intra_mode != D135_PRED)
3709     return 1;
3710   return 0;
3711 }
3712 
3713 // Model based RD estimation for luma intra blocks.
3714 static int64_t intra_model_yrd(const AV1_COMP *const cpi, MACROBLOCK *const x,
3715                                BLOCK_SIZE bsize, int mode_cost, int mi_row,
3716                                int mi_col) {
3717   const AV1_COMMON *cm = &cpi->common;
3718   MACROBLOCKD *const xd = &x->e_mbd;
3719   MB_MODE_INFO *const mbmi = xd->mi[0];
3720   assert(!is_inter_block(mbmi));
3721   RD_STATS this_rd_stats;
3722   int row, col;
3723   int64_t temp_sse, this_rd;
3724   TX_SIZE tx_size = tx_size_from_tx_mode(bsize, cm->tx_mode);
3725   const int stepr = tx_size_high_unit[tx_size];
3726   const int stepc = tx_size_wide_unit[tx_size];
3727   const int max_blocks_wide = max_block_wide(xd, bsize, 0);
3728   const int max_blocks_high = max_block_high(xd, bsize, 0);
3729   mbmi->tx_size = tx_size;
3730   // Prediction.
3731   for (row = 0; row < max_blocks_high; row += stepr) {
3732     for (col = 0; col < max_blocks_wide; col += stepc) {
3733       av1_predict_intra_block_facade(cm, xd, 0, col, row, tx_size);
3734     }
3735   }
3736   // RD estimation.
3737   model_rd_sb_fn[MODELRD_TYPE_INTRA](
3738       cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &this_rd_stats.rate,
3739       &this_rd_stats.dist, &this_rd_stats.skip, &temp_sse, NULL, NULL, NULL);
3740   if (av1_is_directional_mode(mbmi->mode) && av1_use_angle_delta(bsize)) {
3741     mode_cost +=
3742         x->angle_delta_cost[mbmi->mode - V_PRED]
3743                            [MAX_ANGLE_DELTA + mbmi->angle_delta[PLANE_TYPE_Y]];
3744   }
3745   if (mbmi->mode == DC_PRED &&
3746       av1_filter_intra_allowed_bsize(cm, mbmi->sb_type)) {
3747     if (mbmi->filter_intra_mode_info.use_filter_intra) {
3748       const int mode = mbmi->filter_intra_mode_info.filter_intra_mode;
3749       mode_cost += x->filter_intra_cost[mbmi->sb_type][1] +
3750                    x->filter_intra_mode_cost[mode];
3751     } else {
3752       mode_cost += x->filter_intra_cost[mbmi->sb_type][0];
3753     }
3754   }
3755   this_rd =
3756       RDCOST(x->rdmult, this_rd_stats.rate + mode_cost, this_rd_stats.dist);
3757   return this_rd;
3758 }
3759 
3760 // Extends 'color_map' array from 'orig_width x orig_height' to 'new_width x
3761 // new_height'. Extra rows and columns are filled in by copying last valid
3762 // row/column.
3763 static void extend_palette_color_map(uint8_t *const color_map, int orig_width,
3764                                      int orig_height, int new_width,
3765                                      int new_height) {
3766   int j;
3767   assert(new_width >= orig_width);
3768   assert(new_height >= orig_height);
3769   if (new_width == orig_width && new_height == orig_height) return;
3770 
3771   for (j = orig_height - 1; j >= 0; --j) {
3772     memmove(color_map + j * new_width, color_map + j * orig_width, orig_width);
3773     // Copy last column to extra columns.
3774     memset(color_map + j * new_width + orig_width,
3775            color_map[j * new_width + orig_width - 1], new_width - orig_width);
3776   }
3777   // Copy last row to extra rows.
3778   for (j = orig_height; j < new_height; ++j) {
3779     memcpy(color_map + j * new_width, color_map + (orig_height - 1) * new_width,
3780            new_width);
3781   }
3782 }
3783 
3784 // Bias toward using colors in the cache.
3785 // TODO(huisu): Try other schemes to improve compression.
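// E.g. with color_cache = {63, 130}, a centroid of 129 is snapped to the
// cached 130 (difference 1), while a centroid of 120 stays unchanged.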
3786 static void optimize_palette_colors(uint16_t *color_cache, int n_cache,
3787                                     int n_colors, int stride, int *centroids) {
3788   if (n_cache <= 0) return;
3789   for (int i = 0; i < n_colors * stride; i += stride) {
3790     int min_diff = abs(centroids[i] - (int)color_cache[0]);
3791     int idx = 0;
3792     for (int j = 1; j < n_cache; ++j) {
3793       const int this_diff = abs(centroids[i] - color_cache[j]);
3794       if (this_diff < min_diff) {
3795         min_diff = this_diff;
3796         idx = j;
3797       }
3798     }
3799     if (min_diff <= 1) centroids[i] = color_cache[idx];
3800   }
3801 }
3802 
3803 // Given the base colors as specified in centroids[], calculate the RD cost
3804 // of palette mode.
3805 static void palette_rd_y(const AV1_COMP *const cpi, MACROBLOCK *x,
3806                          MB_MODE_INFO *mbmi, BLOCK_SIZE bsize, int mi_row,
3807                          int mi_col, int dc_mode_cost, const int *data,
3808                          int *centroids, int n, uint16_t *color_cache,
3809                          int n_cache, MB_MODE_INFO *best_mbmi,
3810                          uint8_t *best_palette_color_map, int64_t *best_rd,
3811                          int64_t *best_model_rd, int *rate, int *rate_tokenonly,
3812                          int *rate_overhead, int64_t *distortion,
3813                          int *skippable, PICK_MODE_CONTEXT *ctx,
3814                          uint8_t *blk_skip) {
3815   optimize_palette_colors(color_cache, n_cache, n, 1, centroids);
3816   int k = av1_remove_duplicates(centroids, n);
3817   if (k < PALETTE_MIN_SIZE) {
3818     // Too few unique colors to create a palette, and DC_PRED will work
3819     // well for that case anyway, so skip palette mode.
3820     return;
3821   }
3822   PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
3823   if (cpi->common.seq_params.use_highbitdepth)
3824     for (int i = 0; i < k; ++i)
3825       pmi->palette_colors[i] = clip_pixel_highbd(
3826           (int)centroids[i], cpi->common.seq_params.bit_depth);
3827   else
3828     for (int i = 0; i < k; ++i)
3829       pmi->palette_colors[i] = clip_pixel(centroids[i]);
3830   pmi->palette_size[0] = k;
3831   MACROBLOCKD *const xd = &x->e_mbd;
3832   uint8_t *const color_map = xd->plane[0].color_index_map;
3833   int block_width, block_height, rows, cols;
3834   av1_get_block_dimensions(bsize, 0, xd, &block_width, &block_height, &rows,
3835                            &cols);
3836   av1_calc_indices(data, centroids, color_map, rows * cols, k, 1);
3837   extend_palette_color_map(color_map, cols, rows, block_width, block_height);
3838   const int palette_mode_cost =
3839       intra_mode_info_cost_y(cpi, x, mbmi, bsize, dc_mode_cost);
3840   int64_t this_model_rd =
3841       intra_model_yrd(cpi, x, bsize, palette_mode_cost, mi_row, mi_col);
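  // Prune this palette size when the model-based RD estimate exceeds the
  // best estimate so far by more than 50% (best + best / 2).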
3842   if (*best_model_rd != INT64_MAX &&
3843       this_model_rd > *best_model_rd + (*best_model_rd >> 1))
3844     return;
3845   if (this_model_rd < *best_model_rd) *best_model_rd = this_model_rd;
3846   RD_STATS tokenonly_rd_stats;
3847   super_block_yrd(cpi, x, &tokenonly_rd_stats, bsize, *best_rd);
3848   if (tokenonly_rd_stats.rate == INT_MAX) return;
3849   int this_rate = tokenonly_rd_stats.rate + palette_mode_cost;
3850   int64_t this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
3851   if (!xd->lossless[mbmi->segment_id] && block_signals_txsize(mbmi->sb_type)) {
3852     tokenonly_rd_stats.rate -=
3853         tx_size_cost(&cpi->common, x, bsize, mbmi->tx_size);
3854   }
3855   if (this_rd < *best_rd) {
3856     *best_rd = this_rd;
3857     memcpy(best_palette_color_map, color_map,
3858            block_width * block_height * sizeof(color_map[0]));
3859     *best_mbmi = *mbmi;
3860     memcpy(blk_skip, x->blk_skip, sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
3861     *rate_overhead = this_rate - tokenonly_rd_stats.rate;
3862     if (rate) *rate = this_rate;
3863     if (rate_tokenonly) *rate_tokenonly = tokenonly_rd_stats.rate;
3864     if (distortion) *distortion = tokenonly_rd_stats.dist;
3865     if (skippable) *skippable = tokenonly_rd_stats.skip;
3866   }
3867 }
3868 
3869 static int rd_pick_palette_intra_sby(
3870     const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int mi_row,
3871     int mi_col, int dc_mode_cost, MB_MODE_INFO *best_mbmi,
3872     uint8_t *best_palette_color_map, int64_t *best_rd, int64_t *best_model_rd,
3873     int *rate, int *rate_tokenonly, int64_t *distortion, int *skippable,
3874     PICK_MODE_CONTEXT *ctx, uint8_t *best_blk_skip) {
3875   int rate_overhead = 0;
3876   MACROBLOCKD *const xd = &x->e_mbd;
3877   MB_MODE_INFO *const mbmi = xd->mi[0];
3878   assert(!is_inter_block(mbmi));
3879   assert(av1_allow_palette(cpi->common.allow_screen_content_tools, bsize));
3880   const SequenceHeader *const seq_params = &cpi->common.seq_params;
3881   int colors, n;
3882   const int src_stride = x->plane[0].src.stride;
3883   const uint8_t *const src = x->plane[0].src.buf;
3884   uint8_t *const color_map = xd->plane[0].color_index_map;
3885   int block_width, block_height, rows, cols;
3886   av1_get_block_dimensions(bsize, 0, xd, &block_width, &block_height, &rows,
3887                            &cols);
3888 
3889   int count_buf[1 << 12];  // Maximum (1 << 12) color levels.
3890   if (seq_params->use_highbitdepth)
3891     colors = av1_count_colors_highbd(src, src_stride, rows, cols,
3892                                      seq_params->bit_depth, count_buf);
3893   else
3894     colors = av1_count_colors(src, src_stride, rows, cols, count_buf);
3895   mbmi->filter_intra_mode_info.use_filter_intra = 0;
3896 
3897   if (colors > 1 && colors <= 64) {
3898     int r, c, i;
3899     const int max_itr = 50;
3900     int *const data = x->palette_buffer->kmeans_data_buf;
3901     int centroids[PALETTE_MAX_SIZE];
3902     int lb, ub, val;
3903     uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
3904     if (seq_params->use_highbitdepth)
3905       lb = ub = src16[0];
3906     else
3907       lb = ub = src[0];
3908 
3909     if (seq_params->use_highbitdepth) {
3910       for (r = 0; r < rows; ++r) {
3911         for (c = 0; c < cols; ++c) {
3912           val = src16[r * src_stride + c];
3913           data[r * cols + c] = val;
3914           if (val < lb)
3915             lb = val;
3916           else if (val > ub)
3917             ub = val;
3918         }
3919       }
3920     } else {
3921       for (r = 0; r < rows; ++r) {
3922         for (c = 0; c < cols; ++c) {
3923           val = src[r * src_stride + c];
3924           data[r * cols + c] = val;
3925           if (val < lb)
3926             lb = val;
3927           else if (val > ub)
3928             ub = val;
3929         }
3930       }
3931     }
3932 
3933     mbmi->mode = DC_PRED;
3934     mbmi->filter_intra_mode_info.use_filter_intra = 0;
3935 
3936     uint16_t color_cache[2 * PALETTE_MAX_SIZE];
3937     const int n_cache = av1_get_palette_cache(xd, 0, color_cache);
3938 
3939     // Find the dominant colors, stored in top_colors[].
3940     int top_colors[PALETTE_MAX_SIZE] = { 0 };
3941     for (i = 0; i < AOMMIN(colors, PALETTE_MAX_SIZE); ++i) {
3942       int max_count = 0;
3943       for (int j = 0; j < (1 << seq_params->bit_depth); ++j) {
3944         if (count_buf[j] > max_count) {
3945           max_count = count_buf[j];
3946           top_colors[i] = j;
3947         }
3948       }
3949       assert(max_count > 0);
3950       count_buf[top_colors[i]] = 0;
3951     }
3952 
3953     // Try the dominant colors directly.
3954     // TODO(huisu@google.com): Try to avoid duplicate computation in cases
3955     // where the dominant colors and the k-means results are similar.
3956     for (n = AOMMIN(colors, PALETTE_MAX_SIZE); n >= 2; --n) {
3957       for (i = 0; i < n; ++i) centroids[i] = top_colors[i];
3958       palette_rd_y(cpi, x, mbmi, bsize, mi_row, mi_col, dc_mode_cost, data,
3959                    centroids, n, color_cache, n_cache, best_mbmi,
3960                    best_palette_color_map, best_rd, best_model_rd, rate,
3961                    rate_tokenonly, &rate_overhead, distortion, skippable, ctx,
3962                    best_blk_skip);
3963     }
3964 
3965     // K-means clustering.
3966     for (n = AOMMIN(colors, PALETTE_MAX_SIZE); n >= 2; --n) {
3967       if (colors == PALETTE_MIN_SIZE) {
3968         // Special case: These colors automatically become the centroids.
3969         assert(colors == n);
3970         assert(colors == 2);
3971         centroids[0] = lb;
3972         centroids[1] = ub;
3973       } else {
3974         for (i = 0; i < n; ++i) {
3975           centroids[i] = lb + (2 * i + 1) * (ub - lb) / n / 2;
3976         }
3977         av1_k_means(data, centroids, color_map, rows * cols, n, 1, max_itr);
3978       }
3979       palette_rd_y(cpi, x, mbmi, bsize, mi_row, mi_col, dc_mode_cost, data,
3980                    centroids, n, color_cache, n_cache, best_mbmi,
3981                    best_palette_color_map, best_rd, best_model_rd, rate,
3982                    rate_tokenonly, &rate_overhead, distortion, skippable, ctx,
3983                    best_blk_skip);
3984     }
3985   }
3986 
3987   if (best_mbmi->palette_mode_info.palette_size[0] > 0) {
3988     memcpy(color_map, best_palette_color_map,
3989            block_width * block_height * sizeof(best_palette_color_map[0]));
3990   }
3991   *mbmi = *best_mbmi;
3992   return rate_overhead;
3993 }
3994 
3995 // Return 1 if a filter intra mode is selected; return 0 otherwise.
3996 static int rd_pick_filter_intra_sby(const AV1_COMP *const cpi, MACROBLOCK *x,
3997                                     int mi_row, int mi_col, int *rate,
3998                                     int *rate_tokenonly, int64_t *distortion,
3999                                     int *skippable, BLOCK_SIZE bsize,
4000                                     int mode_cost, int64_t *best_rd,
4001                                     int64_t *best_model_rd,
4002                                     PICK_MODE_CONTEXT *ctx) {
4003   MACROBLOCKD *const xd = &x->e_mbd;
4004   MB_MODE_INFO *mbmi = xd->mi[0];
4005   int filter_intra_selected_flag = 0;
4006   FILTER_INTRA_MODE mode;
4007   TX_SIZE best_tx_size = TX_8X8;
4008   FILTER_INTRA_MODE_INFO filter_intra_mode_info;
4009   TX_TYPE best_txk_type[TXK_TYPE_BUF_LEN];
4010   (void)ctx;
4011   av1_zero(filter_intra_mode_info);
4012   mbmi->filter_intra_mode_info.use_filter_intra = 1;
4013   mbmi->mode = DC_PRED;
4014   mbmi->palette_mode_info.palette_size[0] = 0;
4015 
4016   for (mode = 0; mode < FILTER_INTRA_MODES; ++mode) {
4017     int64_t this_rd, this_model_rd;
4018     RD_STATS tokenonly_rd_stats;
4019     mbmi->filter_intra_mode_info.filter_intra_mode = mode;
4020     this_model_rd = intra_model_yrd(cpi, x, bsize, mode_cost, mi_row, mi_col);
4021     if (*best_model_rd != INT64_MAX &&
4022         this_model_rd > *best_model_rd + (*best_model_rd >> 1))
4023       continue;
4024     if (this_model_rd < *best_model_rd) *best_model_rd = this_model_rd;
4025     super_block_yrd(cpi, x, &tokenonly_rd_stats, bsize, *best_rd);
4026     if (tokenonly_rd_stats.rate == INT_MAX) continue;
4027     const int this_rate =
4028         tokenonly_rd_stats.rate +
4029         intra_mode_info_cost_y(cpi, x, mbmi, bsize, mode_cost);
4030     this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
4031 
4032     if (this_rd < *best_rd) {
4033       *best_rd = this_rd;
4034       best_tx_size = mbmi->tx_size;
4035       filter_intra_mode_info = mbmi->filter_intra_mode_info;
4036       memcpy(best_txk_type, mbmi->txk_type,
4037              sizeof(best_txk_type[0]) * TXK_TYPE_BUF_LEN);
4038       memcpy(ctx->blk_skip, x->blk_skip,
4039              sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
4040       *rate = this_rate;
4041       *rate_tokenonly = tokenonly_rd_stats.rate;
4042       *distortion = tokenonly_rd_stats.dist;
4043       *skippable = tokenonly_rd_stats.skip;
4044       filter_intra_selected_flag = 1;
4045     }
4046   }
4047 
4048   if (filter_intra_selected_flag) {
4049     mbmi->mode = DC_PRED;
4050     mbmi->tx_size = best_tx_size;
4051     mbmi->filter_intra_mode_info = filter_intra_mode_info;
4052     memcpy(mbmi->txk_type, best_txk_type,
4053            sizeof(best_txk_type[0]) * TXK_TYPE_BUF_LEN);
4054     return 1;
4055   } else {
4056     return 0;
4057   }
4058 }
4059 
4060 // Run RD calculation with the given luma intra prediction angle, and return
4061 // the RD cost. Update the best mode info if the RD cost is the best so far.
4062 static int64_t calc_rd_given_intra_angle(
4063     const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int mi_row,
4064     int mi_col, int mode_cost, int64_t best_rd_in, int8_t angle_delta,
4065     int max_angle_delta, int *rate, RD_STATS *rd_stats, int *best_angle_delta,
4066     TX_SIZE *best_tx_size, int64_t *best_rd, int64_t *best_model_rd,
4067     TX_TYPE *best_txk_type, uint8_t *best_blk_skip) {
4068   RD_STATS tokenonly_rd_stats;
4069   int64_t this_rd, this_model_rd;
4070   MB_MODE_INFO *mbmi = x->e_mbd.mi[0];
4071   const int n4 = bsize_to_num_blk(bsize);
4072   assert(!is_inter_block(mbmi));
4073   mbmi->angle_delta[PLANE_TYPE_Y] = angle_delta;
4074   this_model_rd = intra_model_yrd(cpi, x, bsize, mode_cost, mi_row, mi_col);
4075   if (*best_model_rd != INT64_MAX &&
4076       this_model_rd > *best_model_rd + (*best_model_rd >> 1))
4077     return INT64_MAX;
4078   if (this_model_rd < *best_model_rd) *best_model_rd = this_model_rd;
4079   super_block_yrd(cpi, x, &tokenonly_rd_stats, bsize, best_rd_in);
4080   if (tokenonly_rd_stats.rate == INT_MAX) return INT64_MAX;
4081 
4082   int this_rate =
4083       mode_cost + tokenonly_rd_stats.rate +
4084       x->angle_delta_cost[mbmi->mode - V_PRED][max_angle_delta + angle_delta];
4085   this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
4086 
4087   if (this_rd < *best_rd) {
4088     memcpy(best_txk_type, mbmi->txk_type,
4089            sizeof(*best_txk_type) * TXK_TYPE_BUF_LEN);
4090     memcpy(best_blk_skip, x->blk_skip, sizeof(best_blk_skip[0]) * n4);
4091     *best_rd = this_rd;
4092     *best_angle_delta = mbmi->angle_delta[PLANE_TYPE_Y];
4093     *best_tx_size = mbmi->tx_size;
4094     *rate = this_rate;
4095     rd_stats->rate = tokenonly_rd_stats.rate;
4096     rd_stats->dist = tokenonly_rd_stats.dist;
4097     rd_stats->skip = tokenonly_rd_stats.skip;
4098   }
4099   return this_rd;
4100 }
4101 
4102 // With a given luma directional intra prediction mode, pick the best angle
4103 // delta. Return the RD cost corresponding to the best angle delta.
4104 static int64_t rd_pick_intra_angle_sby(const AV1_COMP *const cpi, MACROBLOCK *x,
4105                                        int mi_row, int mi_col, int *rate,
4106                                        RD_STATS *rd_stats, BLOCK_SIZE bsize,
4107                                        int mode_cost, int64_t best_rd,
4108                                        int64_t *best_model_rd) {
4109   MB_MODE_INFO *mbmi = x->e_mbd.mi[0];
4110   assert(!is_inter_block(mbmi));
4111 
4112   int best_angle_delta = 0;
4113   int64_t rd_cost[2 * (MAX_ANGLE_DELTA + 2)];
4114   TX_SIZE best_tx_size = mbmi->tx_size;
4115   TX_TYPE best_txk_type[TXK_TYPE_BUF_LEN];
4116   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
4117 
4118   for (int i = 0; i < 2 * (MAX_ANGLE_DELTA + 2); ++i) rd_cost[i] = INT64_MAX;
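  // Two-stage search: evaluate the even angle deltas (0, +/-2) first, then
  // revisit the odd deltas (+/-1, +/-3) below, skipping an odd delta when
  // both of its even neighbors already exceeded the RD threshold.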
4119 
4120   int first_try = 1;
4121   for (int angle_delta = 0; angle_delta <= MAX_ANGLE_DELTA; angle_delta += 2) {
4122     for (int i = 0; i < 2; ++i) {
4123       const int64_t best_rd_in =
4124           (best_rd == INT64_MAX) ? INT64_MAX
4125                                  : (best_rd + (best_rd >> (first_try ? 3 : 5)));
4126       const int64_t this_rd = calc_rd_given_intra_angle(
4127           cpi, x, bsize, mi_row, mi_col, mode_cost, best_rd_in,
4128           (1 - 2 * i) * angle_delta, MAX_ANGLE_DELTA, rate, rd_stats,
4129           &best_angle_delta, &best_tx_size, &best_rd, best_model_rd,
4130           best_txk_type, best_blk_skip);
4131       rd_cost[2 * angle_delta + i] = this_rd;
4132       if (first_try && this_rd == INT64_MAX) return best_rd;
4133       first_try = 0;
4134       if (angle_delta == 0) {
4135         rd_cost[1] = this_rd;
4136         break;
4137       }
4138     }
4139   }
4140 
4141   assert(best_rd != INT64_MAX);
4142   for (int angle_delta = 1; angle_delta <= MAX_ANGLE_DELTA; angle_delta += 2) {
4143     for (int i = 0; i < 2; ++i) {
4144       int skip_search = 0;
4145       const int64_t rd_thresh = best_rd + (best_rd >> 5);
4146       if (rd_cost[2 * (angle_delta + 1) + i] > rd_thresh &&
4147           rd_cost[2 * (angle_delta - 1) + i] > rd_thresh)
4148         skip_search = 1;
4149       if (!skip_search) {
4150         calc_rd_given_intra_angle(cpi, x, bsize, mi_row, mi_col, mode_cost,
4151                                   best_rd, (1 - 2 * i) * angle_delta,
4152                                   MAX_ANGLE_DELTA, rate, rd_stats,
4153                                   &best_angle_delta, &best_tx_size, &best_rd,
4154                                   best_model_rd, best_txk_type, best_blk_skip);
4155       }
4156     }
4157   }
4158 
4159   if (rd_stats->rate != INT_MAX) {
4160     mbmi->tx_size = best_tx_size;
4161     mbmi->angle_delta[PLANE_TYPE_Y] = best_angle_delta;
4162     memcpy(mbmi->txk_type, best_txk_type,
4163            sizeof(*best_txk_type) * TXK_TYPE_BUF_LEN);
4164     memcpy(x->blk_skip, best_blk_skip,
4165            sizeof(best_blk_skip[0]) * bsize_to_num_blk(bsize));
4166   }
4167   return best_rd;
4168 }
4169 
4170 // Indices are sign, integer, and fractional part of the gradient value
4171 static const uint8_t gradient_to_angle_bin[2][7][16] = {
4172   {
4173       { 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 0, 0, 0, 0 },
4174       { 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1 },
4175       { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 },
4176       { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 },
4177       { 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1 },
4178       { 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 },
4179       { 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 },
4180   },
4181   {
4182       { 6, 6, 6, 6, 5, 5, 5, 5, 5, 5, 5, 5, 4, 4, 4, 4 },
4183       { 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3 },
4184       { 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3 },
4185       { 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3 },
4186       { 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3 },
4187       { 3, 3, 3, 3, 3, 3, 3, 2, 2, 2, 2, 2, 2, 2, 2, 2 },
4188       { 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2 },
4189   },
4190 };
4191 
4192 /* clang-format off */
4193 static const uint8_t mode_to_angle_bin[INTRA_MODES] = {
4194   0, 2, 6, 0, 4, 3, 5, 7, 1, 0,
4195   0,
4196 };
4197 /* clang-format on */
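// Example of the gradient binning used below: for dx = 3, dy = 8 (both
// positive, so sn = 0), quot = 3 / 8 = 0 and remd = (3 % 8) * 16 / 8 = 6,
// which selects gradient_to_angle_bin[0][0][6] = 7.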
4198 
4199 static void angle_estimation(const uint8_t *src, int src_stride, int rows,
4200                              int cols, BLOCK_SIZE bsize,
4201                              uint8_t *directional_mode_skip_mask) {
4202   memset(directional_mode_skip_mask, 0,
4203          INTRA_MODES * sizeof(*directional_mode_skip_mask));
4204   // Check if angle_delta is used
4205   if (!av1_use_angle_delta(bsize)) return;
4206   uint64_t hist[DIRECTIONAL_MODES];
4207   memset(hist, 0, DIRECTIONAL_MODES * sizeof(hist[0]));
4208   src += src_stride;
4209   int r, c, dx, dy;
4210   for (r = 1; r < rows; ++r) {
4211     for (c = 1; c < cols; ++c) {
4212       dx = src[c] - src[c - 1];
4213       dy = src[c] - src[c - src_stride];
4214       int index;
4215       const int temp = dx * dx + dy * dy;
4216       if (dy == 0) {
4217         index = 2;
4218       } else {
4219         const int sn = (dx > 0) ^ (dy > 0);
4220         dx = abs(dx);
4221         dy = abs(dy);
4222         const int remd = (dx % dy) * 16 / dy;
4223         const int quot = dx / dy;
4224         index = gradient_to_angle_bin[sn][AOMMIN(quot, 6)][AOMMIN(remd, 15)];
4225       }
4226       hist[index] += temp;
4227     }
4228     src += src_stride;
4229   }
4230 
4231   int i;
4232   uint64_t hist_sum = 0;
4233   for (i = 0; i < DIRECTIONAL_MODES; ++i) hist_sum += hist[i];
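  // Prune a directional mode when its share of the gradient histogram is
  // small. score is the center bin weighted by 2 plus its one-step neighbors
  // weighted by 1, so the test below amounts to
  // score / weight < hist_sum / ANGLE_SKIP_THRESH, cross-multiplied to stay
  // in integer arithmetic.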
4234   for (i = 0; i < INTRA_MODES; ++i) {
4235     if (av1_is_directional_mode(i)) {
4236       const uint8_t angle_bin = mode_to_angle_bin[i];
4237       uint64_t score = 2 * hist[angle_bin];
4238       int weight = 2;
4239       if (angle_bin > 0) {
4240         score += hist[angle_bin - 1];
4241         ++weight;
4242       }
4243       if (angle_bin < DIRECTIONAL_MODES - 1) {
4244         score += hist[angle_bin + 1];
4245         ++weight;
4246       }
4247       if (score * ANGLE_SKIP_THRESH < hist_sum * weight)
4248         directional_mode_skip_mask[i] = 1;
4249     }
4250   }
4251 }
4252 
4253 static void highbd_angle_estimation(const uint8_t *src8, int src_stride,
4254                                     int rows, int cols, BLOCK_SIZE bsize,
4255                                     uint8_t *directional_mode_skip_mask) {
4256   memset(directional_mode_skip_mask, 0,
4257          INTRA_MODES * sizeof(*directional_mode_skip_mask));
4258   // Check if angle_delta is used
4259   if (!av1_use_angle_delta(bsize)) return;
4260   uint16_t *src = CONVERT_TO_SHORTPTR(src8);
4261   uint64_t hist[DIRECTIONAL_MODES];
4262   memset(hist, 0, DIRECTIONAL_MODES * sizeof(hist[0]));
4263   src += src_stride;
4264   int r, c, dx, dy;
4265   for (r = 1; r < rows; ++r) {
4266     for (c = 1; c < cols; ++c) {
4267       dx = src[c] - src[c - 1];
4268       dy = src[c] - src[c - src_stride];
4269       int index;
4270       const int temp = dx * dx + dy * dy;
4271       if (dy == 0) {
4272         index = 2;
4273       } else {
4274         const int sn = (dx > 0) ^ (dy > 0);
4275         dx = abs(dx);
4276         dy = abs(dy);
4277         const int remd = (dx % dy) * 16 / dy;
4278         const int quot = dx / dy;
4279         index = gradient_to_angle_bin[sn][AOMMIN(quot, 6)][AOMMIN(remd, 15)];
4280       }
4281       hist[index] += temp;
4282     }
4283     src += src_stride;
4284   }
4285 
4286   int i;
4287   uint64_t hist_sum = 0;
4288   for (i = 0; i < DIRECTIONAL_MODES; ++i) hist_sum += hist[i];
4289   for (i = 0; i < INTRA_MODES; ++i) {
4290     if (av1_is_directional_mode(i)) {
4291       const uint8_t angle_bin = mode_to_angle_bin[i];
4292       uint64_t score = 2 * hist[angle_bin];
4293       int weight = 2;
4294       if (angle_bin > 0) {
4295         score += hist[angle_bin - 1];
4296         ++weight;
4297       }
4298       if (angle_bin < DIRECTIONAL_MODES - 1) {
4299         score += hist[angle_bin + 1];
4300         ++weight;
4301       }
4302       if (score * ANGLE_SKIP_THRESH < hist_sum * weight)
4303         directional_mode_skip_mask[i] = 1;
4304     }
4305   }
4306 }
4307 
4308 // Given the selected prediction mode, search for the best tx type and size.
4309 static void intra_block_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
4310                             BLOCK_SIZE bsize, const int *bmode_costs,
4311                             int64_t *best_rd, int *rate, int *rate_tokenonly,
4312                             int64_t *distortion, int *skippable,
4313                             MB_MODE_INFO *best_mbmi, PICK_MODE_CONTEXT *ctx) {
4314   MACROBLOCKD *const xd = &x->e_mbd;
4315   MB_MODE_INFO *const mbmi = xd->mi[0];
4316   RD_STATS rd_stats;
4317   super_block_yrd(cpi, x, &rd_stats, bsize, *best_rd);
4318   if (rd_stats.rate == INT_MAX) return;
4319   int this_rate_tokenonly = rd_stats.rate;
4320   if (!xd->lossless[mbmi->segment_id] && block_signals_txsize(mbmi->sb_type)) {
4321     // super_block_yrd above includes the cost of the tx_size in the
4322     // tokenonly rate, but for intra blocks, tx_size is always coded
4323     // (prediction granularity), so we account for it in the full rate,
4324     // not the tokenonly rate.
4325     this_rate_tokenonly -= tx_size_cost(&cpi->common, x, bsize, mbmi->tx_size);
4326   }
4327   const int this_rate =
4328       rd_stats.rate +
4329       intra_mode_info_cost_y(cpi, x, mbmi, bsize, bmode_costs[mbmi->mode]);
4330   const int64_t this_rd = RDCOST(x->rdmult, this_rate, rd_stats.dist);
4331   if (this_rd < *best_rd) {
4332     *best_mbmi = *mbmi;
4333     *best_rd = this_rd;
4334     *rate = this_rate;
4335     *rate_tokenonly = this_rate_tokenonly;
4336     *distortion = rd_stats.dist;
4337     *skippable = rd_stats.skip;
4338     memcpy(ctx->blk_skip, x->blk_skip,
4339            sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
4340   }
4341 }
4342 
4343 // This function is used only for intra_only frames.
4344 static int64_t rd_pick_intra_sby_mode(const AV1_COMP *const cpi, MACROBLOCK *x,
4345                                       int mi_row, int mi_col, int *rate,
4346                                       int *rate_tokenonly, int64_t *distortion,
4347                                       int *skippable, BLOCK_SIZE bsize,
4348                                       int64_t best_rd, PICK_MODE_CONTEXT *ctx) {
4349   MACROBLOCKD *const xd = &x->e_mbd;
4350   MB_MODE_INFO *const mbmi = xd->mi[0];
4351   assert(!is_inter_block(mbmi));
4352   int64_t best_model_rd = INT64_MAX;
4353   const int rows = block_size_high[bsize];
4354   const int cols = block_size_wide[bsize];
4355   int is_directional_mode;
4356   uint8_t directional_mode_skip_mask[INTRA_MODES];
4357   const int src_stride = x->plane[0].src.stride;
4358   const uint8_t *src = x->plane[0].src.buf;
4359   int beat_best_rd = 0;
4360   const int *bmode_costs;
4361   PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
4362   const int try_palette =
4363       av1_allow_palette(cpi->common.allow_screen_content_tools, mbmi->sb_type);
4364   uint8_t *best_palette_color_map =
4365       try_palette ? x->palette_buffer->best_palette_color_map : NULL;
4366   const MB_MODE_INFO *above_mi = xd->above_mbmi;
4367   const MB_MODE_INFO *left_mi = xd->left_mbmi;
4368   const PREDICTION_MODE A = av1_above_block_mode(above_mi);
4369   const PREDICTION_MODE L = av1_left_block_mode(left_mi);
4370   const int above_ctx = intra_mode_context[A];
4371   const int left_ctx = intra_mode_context[L];
4372   bmode_costs = x->y_mode_costs[above_ctx][left_ctx];
4373 
4374   mbmi->angle_delta[PLANE_TYPE_Y] = 0;
4375   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
4376     highbd_angle_estimation(src, src_stride, rows, cols, bsize,
4377                             directional_mode_skip_mask);
4378   else
4379     angle_estimation(src, src_stride, rows, cols, bsize,
4380                      directional_mode_skip_mask);
4381   mbmi->filter_intra_mode_info.use_filter_intra = 0;
4382   pmi->palette_size[0] = 0;
4383 
4384   if (cpi->sf.tx_type_search.fast_intra_tx_type_search)
4385     x->use_default_intra_tx_type = 1;
4386   else
4387     x->use_default_intra_tx_type = 0;
4388 
4389   MB_MODE_INFO best_mbmi = *mbmi;
4390   /* Y Search for intra prediction mode */
4391   for (int mode_idx = INTRA_MODE_START; mode_idx < INTRA_MODE_END; ++mode_idx) {
4392     RD_STATS this_rd_stats;
4393     int this_rate, this_rate_tokenonly, s;
4394     int64_t this_distortion, this_rd, this_model_rd;
4395     mbmi->mode = intra_rd_search_mode_order[mode_idx];
4396     mbmi->angle_delta[PLANE_TYPE_Y] = 0;
4397     this_model_rd =
4398         intra_model_yrd(cpi, x, bsize, bmode_costs[mbmi->mode], mi_row, mi_col);
4399     if (best_model_rd != INT64_MAX &&
4400         this_model_rd > best_model_rd + (best_model_rd >> 1))
4401       continue;
4402     if (this_model_rd < best_model_rd) best_model_rd = this_model_rd;
4403     is_directional_mode = av1_is_directional_mode(mbmi->mode);
4404     if (is_directional_mode && directional_mode_skip_mask[mbmi->mode]) continue;
4405     if (is_directional_mode && av1_use_angle_delta(bsize)) {
4406       this_rd_stats.rate = INT_MAX;
4407       rd_pick_intra_angle_sby(cpi, x, mi_row, mi_col, &this_rate,
4408                               &this_rd_stats, bsize, bmode_costs[mbmi->mode],
4409                               best_rd, &best_model_rd);
4410     } else {
4411       super_block_yrd(cpi, x, &this_rd_stats, bsize, best_rd);
4412     }
4413     this_rate_tokenonly = this_rd_stats.rate;
4414     this_distortion = this_rd_stats.dist;
4415     s = this_rd_stats.skip;
4416 
4417     if (this_rate_tokenonly == INT_MAX) continue;
4418 
4419     if (!xd->lossless[mbmi->segment_id] &&
4420         block_signals_txsize(mbmi->sb_type)) {
4421       // super_block_yrd above includes the cost of the tx_size in the
4422       // tokenonly rate, but for intra blocks, tx_size is always coded
4423       // (prediction granularity), so we account for it in the full rate,
4424       // not the tokenonly rate.
4425       this_rate_tokenonly -=
4426           tx_size_cost(&cpi->common, x, bsize, mbmi->tx_size);
4427     }
4428     this_rate =
4429         this_rd_stats.rate +
4430         intra_mode_info_cost_y(cpi, x, mbmi, bsize, bmode_costs[mbmi->mode]);
4431     this_rd = RDCOST(x->rdmult, this_rate, this_distortion);
4432     if (this_rd < best_rd) {
4433       best_mbmi = *mbmi;
4434       best_rd = this_rd;
4435       beat_best_rd = 1;
4436       *rate = this_rate;
4437       *rate_tokenonly = this_rate_tokenonly;
4438       *distortion = this_distortion;
4439       *skippable = s;
4440       memcpy(ctx->blk_skip, x->blk_skip,
4441              sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
4442     }
4443   }
4444 
4445   if (try_palette) {
4446     rd_pick_palette_intra_sby(
4447         cpi, x, bsize, mi_row, mi_col, bmode_costs[DC_PRED], &best_mbmi,
4448         best_palette_color_map, &best_rd, &best_model_rd, rate, rate_tokenonly,
4449         distortion, skippable, ctx, ctx->blk_skip);
4450   }
4451 
4452   if (beat_best_rd && av1_filter_intra_allowed_bsize(&cpi->common, bsize)) {
4453     if (rd_pick_filter_intra_sby(
4454             cpi, x, mi_row, mi_col, rate, rate_tokenonly, distortion, skippable,
4455             bsize, bmode_costs[DC_PRED], &best_rd, &best_model_rd, ctx)) {
4456       best_mbmi = *mbmi;
4457     }
4458   }
4459 
4460 // If the previous searches used only the default tx type, do an extra search
4461 // for the best tx type.
4462   if (x->use_default_intra_tx_type) {
4463     *mbmi = best_mbmi;
4464     x->use_default_intra_tx_type = 0;
4465     intra_block_yrd(cpi, x, bsize, bmode_costs, &best_rd, rate, rate_tokenonly,
4466                     distortion, skippable, &best_mbmi, ctx);
4467   }
4468 
4469   *mbmi = best_mbmi;
4470   return best_rd;
4471 }
4472 
4473 // Return value 0: early termination triggered, no valid rd cost available;
4474 //              1: rd cost values are valid.
4475 static int super_block_uvrd(const AV1_COMP *const cpi, MACROBLOCK *x,
4476                             RD_STATS *rd_stats, BLOCK_SIZE bsize,
4477                             int64_t ref_best_rd) {
4478   MACROBLOCKD *const xd = &x->e_mbd;
4479   MB_MODE_INFO *const mbmi = xd->mi[0];
4480   struct macroblockd_plane *const pd = &xd->plane[AOM_PLANE_U];
4481   const TX_SIZE uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);
4482   int plane;
4483   int is_cost_valid = 1;
4484   av1_init_rd_stats(rd_stats);
4485 
4486   if (ref_best_rd < 0) is_cost_valid = 0;
4487 
4488   if (x->skip_chroma_rd) return is_cost_valid;
4489 
4490   bsize = scale_chroma_bsize(bsize, pd->subsampling_x, pd->subsampling_y);
4491 
4492   if (is_inter_block(mbmi) && is_cost_valid) {
4493     for (plane = 1; plane < MAX_MB_PLANE; ++plane)
4494       av1_subtract_plane(x, bsize, plane);
4495   }
4496 
4497   if (is_cost_valid) {
4498     for (plane = 1; plane < MAX_MB_PLANE; ++plane) {
4499       RD_STATS pn_rd_stats;
4500       txfm_rd_in_plane(x, cpi, &pn_rd_stats, ref_best_rd, plane, bsize,
4501                        uv_tx_size, cpi->sf.use_fast_coef_costing, FTXS_NONE);
4502       if (pn_rd_stats.rate == INT_MAX) {
4503         is_cost_valid = 0;
4504         break;
4505       }
4506       av1_merge_rd_stats(rd_stats, &pn_rd_stats);
4507       if (RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) > ref_best_rd &&
4508           RDCOST(x->rdmult, 0, rd_stats->sse) > ref_best_rd) {
4509         is_cost_valid = 0;
4510         break;
4511       }
4512     }
4513   }
4514 
4515   if (!is_cost_valid) {
4516     // reset cost value
4517     av1_invalid_rd_stats(rd_stats);
4518   }
4519 
4520   return is_cost_valid;
4521 }
4522 
4523 static void tx_block_rd_b(const AV1_COMP *cpi, MACROBLOCK *x, TX_SIZE tx_size,
4524                           int blk_row, int blk_col, int plane, int block,
4525                           int plane_bsize, TXB_CTX *txb_ctx, RD_STATS *rd_stats,
4526                           FAST_TX_SEARCH_MODE ftxs_mode, int64_t ref_rdcost,
4527                           TXB_RD_INFO *rd_info_array) {
4528   const struct macroblock_plane *const p = &x->plane[plane];
4529   const uint16_t cur_joint_ctx =
4530       (txb_ctx->dc_sign_ctx << 8) + txb_ctx->txb_skip_ctx;
4531   const int txk_type_idx =
4532       av1_get_txk_type_index(plane_bsize, blk_row, blk_col);
4533   // Look up the RD cost and terminate early if we have already processed
4534   // exactly the same residual with exactly the same entropy context.
4535   if (rd_info_array != NULL && rd_info_array->valid &&
4536       rd_info_array->entropy_context == cur_joint_ctx) {
4537     if (plane == 0)
4538       x->e_mbd.mi[0]->txk_type[txk_type_idx] = rd_info_array->tx_type;
4539     const TX_TYPE ref_tx_type =
4540         av1_get_tx_type(get_plane_type(plane), &x->e_mbd, blk_row, blk_col,
4541                         tx_size, cpi->common.reduced_tx_set_used);
4542     if (ref_tx_type == rd_info_array->tx_type) {
4543       rd_stats->rate += rd_info_array->rate;
4544       rd_stats->dist += rd_info_array->dist;
4545       rd_stats->sse += rd_info_array->sse;
4546       rd_stats->skip &= rd_info_array->eob == 0;
4547       p->eobs[block] = rd_info_array->eob;
4548       p->txb_entropy_ctx[block] = rd_info_array->txb_entropy_ctx;
4549       return;
4550     }
4551   }
4552 
4553   RD_STATS this_rd_stats;
4554   search_txk_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
4555                   txb_ctx, ftxs_mode, 0, ref_rdcost, &this_rd_stats);
4556 
4557   av1_merge_rd_stats(rd_stats, &this_rd_stats);
4558 
4559   // Save the RD results for possible future reuse.
4560   if (rd_info_array != NULL) {
4561     rd_info_array->valid = 1;
4562     rd_info_array->entropy_context = cur_joint_ctx;
4563     rd_info_array->rate = this_rd_stats.rate;
4564     rd_info_array->dist = this_rd_stats.dist;
4565     rd_info_array->sse = this_rd_stats.sse;
4566     rd_info_array->eob = p->eobs[block];
4567     rd_info_array->txb_entropy_ctx = p->txb_entropy_ctx[block];
4568     if (plane == 0) {
4569       rd_info_array->tx_type = x->e_mbd.mi[0]->txk_type[txk_type_idx];
4570     }
4571   }
4572 }
4573 
4574 static void get_mean_and_dev(const int16_t *data, int stride, int bw, int bh,
4575                              float *mean, float *dev) {
4576   int x_sum = 0;
4577   uint64_t x2_sum = 0;
4578   for (int i = 0; i < bh; ++i) {
4579     for (int j = 0; j < bw; ++j) {
4580       const int val = data[j];
4581       x_sum += val;
4582       x2_sum += val * val;
4583     }
4584     data += stride;
4585   }
4586 
4587   const int num = bw * bh;
4588   const float e_x = (float)x_sum / num;
4589   const float e_x2 = (float)((double)x2_sum / num);
4590   const float diff = e_x2 - e_x * e_x;
4591   *dev = (diff > 0) ? sqrtf(diff) : 0;
4592   *mean = e_x;
4593 }
4594 
4595 static void get_mean_and_dev_float(const float *data, int stride, int bw,
4596                                    int bh, float *mean, float *dev) {
4597   float x_sum = 0;
4598   float x2_sum = 0;
4599   for (int i = 0; i < bh; ++i) {
4600     for (int j = 0; j < bw; ++j) {
4601       const float val = data[j];
4602       x_sum += val;
4603       x2_sum += val * val;
4604     }
4605     data += stride;
4606   }
4607 
4608   const int num = bw * bh;
4609   const float e_x = x_sum / num;
4610   const float e_x2 = x2_sum / num;
4611   const float diff = e_x2 - e_x * e_x;
4612   *dev = (diff > 0) ? sqrtf(diff) : 0;
4613   *mean = e_x;
4614 }
4615 
4616 // Features used by the model to predict the tx split: the mean and standard
4617 // deviation values of the block and its sub-blocks.
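// E.g. a 16x16 block with levels = 2 yields 12 features: mean and deviation
// of the whole block (2), mean and deviation of each 8x8 quadrant (8), plus
// the deviation of the four means and the mean of the four deviations (2).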
4618 static void get_mean_dev_features(const int16_t *data, int stride, int bw,
4619                                   int bh, int levels, float *feature) {
4620   int feature_idx = 0;
4621   int width = bw;
4622   int height = bh;
4623   const int16_t *const data_ptr = &data[0];
4624   for (int lv = 0; lv < levels; ++lv) {
4625     if (width < 2 || height < 2) break;
4626     float mean_buf[16];
4627     float dev_buf[16];
4628     int blk_idx = 0;
4629     for (int row = 0; row < bh; row += height) {
4630       for (int col = 0; col < bw; col += width) {
4631         float mean, dev;
4632         get_mean_and_dev(data_ptr + row * stride + col, stride, width, height,
4633                          &mean, &dev);
4634         feature[feature_idx++] = mean;
4635         feature[feature_idx++] = dev;
4636         mean_buf[blk_idx] = mean;
4637         dev_buf[blk_idx++] = dev;
4638       }
4639     }
4640     if (blk_idx > 1) {
4641       float mean, dev;
4642       // Deviation of means.
4643       get_mean_and_dev_float(mean_buf, 1, 1, blk_idx, &mean, &dev);
4644       feature[feature_idx++] = dev;
4645       // Mean of deviations.
4646       get_mean_and_dev_float(dev_buf, 1, 1, blk_idx, &mean, &dev);
4647       feature[feature_idx++] = mean;
4648     }
4649     // Reduce the block size when proceeding to the next level.
4650     if (height == width) {
4651       height = height >> 1;
4652       width = width >> 1;
4653     } else if (height > width) {
4654       height = height >> 1;
4655     } else {
4656       width = width >> 1;
4657     }
4658   }
4659 }
4660 
4661 static int ml_predict_tx_split(MACROBLOCK *x, BLOCK_SIZE bsize, int blk_row,
4662                                int blk_col, TX_SIZE tx_size) {
4663   const NN_CONFIG *nn_config = av1_tx_split_nnconfig_map[tx_size];
4664   if (!nn_config) return -1;
4665 
4666   const int diff_stride = block_size_wide[bsize];
4667   const int16_t *diff =
4668       x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
4669   const int bw = tx_size_wide[tx_size];
4670   const int bh = tx_size_high[tx_size];
4671   aom_clear_system_state();
4672 
4673   float features[64] = { 0.0f };
4674   get_mean_dev_features(diff, diff_stride, bw, bh, 2, features);
4675 
4676   float score = 0.0f;
4677   av1_nn_predict(features, nn_config, &score);
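  // Map the raw network output to a 0-100 score through a logistic sigmoid;
  // |score| > 8 saturates to 100 or 0, and a raw score of 0 maps to 50.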
4678   if (score > 8.0f) return 100;
4679   if (score < -8.0f) return 0;
4680   score = 1.0f / (1.0f + (float)exp(-score));
4681   return (int)(score * 100);
4682 }
4683 
4684 typedef struct {
4685   int64_t rd;
4686   int txb_entropy_ctx;
4687   TX_TYPE tx_type;
4688 } TxCandidateInfo;
4689 
4690 static void try_tx_block_no_split(
4691     const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
4692     TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize,
4693     const ENTROPY_CONTEXT *ta, const ENTROPY_CONTEXT *tl,
4694     int txfm_partition_ctx, RD_STATS *rd_stats, int64_t ref_best_rd,
4695     FAST_TX_SEARCH_MODE ftxs_mode, TXB_RD_INFO_NODE *rd_info_node,
4696     TxCandidateInfo *no_split) {
4697   MACROBLOCKD *const xd = &x->e_mbd;
4698   MB_MODE_INFO *const mbmi = xd->mi[0];
4699   struct macroblock_plane *const p = &x->plane[0];
4700   const int bw = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
4701 
4702   no_split->rd = INT64_MAX;
4703   no_split->txb_entropy_ctx = 0;
4704   no_split->tx_type = TX_TYPES;
4705 
4706   const ENTROPY_CONTEXT *const pta = ta + blk_col;
4707   const ENTROPY_CONTEXT *const ptl = tl + blk_row;
4708 
4709   const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
4710   TXB_CTX txb_ctx;
4711   get_txb_ctx(plane_bsize, tx_size, 0, pta, ptl, &txb_ctx);
4712   const int zero_blk_rate = x->coeff_costs[txs_ctx][PLANE_TYPE_Y]
4713                                 .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
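  // zero_blk_rate is the cost of coding this transform block as all-zero
  // (txb_skip = 1); it serves below as the baseline for forcing the block
  // to be skipped.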
4714 
4715   rd_stats->ref_rdcost = ref_best_rd;
4716   rd_stats->zero_rate = zero_blk_rate;
4717   const int index = av1_get_txb_size_index(plane_bsize, blk_row, blk_col);
4718   mbmi->inter_tx_size[index] = tx_size;
4719   tx_block_rd_b(cpi, x, tx_size, blk_row, blk_col, 0, block, plane_bsize,
4720                 &txb_ctx, rd_stats, ftxs_mode, ref_best_rd,
4721                 rd_info_node != NULL ? rd_info_node->rd_info_array : NULL);
4722   assert(rd_stats->rate < INT_MAX);
4723 
4724   if ((RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
4725            RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
4726        rd_stats->skip == 1) &&
4727       !xd->lossless[mbmi->segment_id]) {
4728 #if CONFIG_RD_DEBUG
4729     av1_update_txb_coeff_cost(rd_stats, 0, tx_size, blk_row, blk_col,
4730                               zero_blk_rate - rd_stats->rate);
4731 #endif  // CONFIG_RD_DEBUG
4732     rd_stats->rate = zero_blk_rate;
4733     rd_stats->dist = rd_stats->sse;
4734     rd_stats->skip = 1;
4735     set_blk_skip(x, 0, blk_row * bw + blk_col, 1);
4736     p->eobs[block] = 0;
4737     update_txk_array(mbmi->txk_type, plane_bsize, blk_row, blk_col, tx_size,
4738                      DCT_DCT);
4739   } else {
4740     set_blk_skip(x, 0, blk_row * bw + blk_col, 0);
4741     rd_stats->skip = 0;
4742   }
4743 
4744   if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
4745     rd_stats->rate += x->txfm_partition_cost[txfm_partition_ctx][0];
4746 
4747   no_split->rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
4748   no_split->txb_entropy_ctx = p->txb_entropy_ctx[block];
4749   const int txk_type_idx =
4750       av1_get_txk_type_index(plane_bsize, blk_row, blk_col);
4751   no_split->tx_type = mbmi->txk_type[txk_type_idx];
4752 }
4753 
4754 static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
4755                             int blk_col, int block, TX_SIZE tx_size, int depth,
4756                             BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
4757                             ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above,
4758                             TXFM_CONTEXT *tx_left, RD_STATS *rd_stats,
4759                             int64_t ref_best_rd, int *is_cost_valid,
4760                             FAST_TX_SEARCH_MODE ftxs_mode,
4761                             TXB_RD_INFO_NODE *rd_info_node);
4762 
4763 static void try_tx_block_split(
4764     const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
4765     TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
4766     ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
4767     int txfm_partition_ctx, int64_t no_split_rd, int64_t ref_best_rd,
4768     FAST_TX_SEARCH_MODE ftxs_mode, TXB_RD_INFO_NODE *rd_info_node,
4769     RD_STATS *split_rd_stats, int64_t *split_rd) {
4770   MACROBLOCKD *const xd = &x->e_mbd;
4771   const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
4772   const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
4773   const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
4774   const int bsw = tx_size_wide_unit[sub_txs];
4775   const int bsh = tx_size_high_unit[sub_txs];
4776   const int sub_step = bsw * bsh;
4777   RD_STATS this_rd_stats;
4778   int this_cost_valid = 1;
4779   int64_t tmp_rd = 0;
4780 
4781   split_rd_stats->rate = x->txfm_partition_cost[txfm_partition_ctx][1];
4782 
4783   assert(tx_size < TX_SIZES_ALL);
4784 
4785   int blk_idx = 0;
4786   for (int r = 0; r < tx_size_high_unit[tx_size]; r += bsh) {
4787     for (int c = 0; c < tx_size_wide_unit[tx_size]; c += bsw, ++blk_idx) {
4788       const int offsetr = blk_row + r;
4789       const int offsetc = blk_col + c;
4790       if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
4791       assert(blk_idx < 4);
4792       select_tx_block(
4793           cpi, x, offsetr, offsetc, block, sub_txs, depth + 1, plane_bsize, ta,
4794           tl, tx_above, tx_left, &this_rd_stats, ref_best_rd - tmp_rd,
4795           &this_cost_valid, ftxs_mode,
4796           (rd_info_node != NULL) ? rd_info_node->children[blk_idx] : NULL);
4797 
4798       if (!this_cost_valid) goto LOOP_EXIT;
4799 
4800       av1_merge_rd_stats(split_rd_stats, &this_rd_stats);
4801 
4802       tmp_rd = RDCOST(x->rdmult, split_rd_stats->rate, split_rd_stats->dist);
4803 
4804       if (no_split_rd < tmp_rd) {
4805         this_cost_valid = 0;
4806         goto LOOP_EXIT;
4807       }
4808       block += sub_step;
4809     }
4810   }
4811 
4812 LOOP_EXIT: {}
4813 
4814   if (this_cost_valid) *split_rd = tmp_rd;
4815 }
4816 
4817 // Search for the best tx partition/type for a given luma block.
4818 static void select_tx_block(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
4819                             int blk_col, int block, TX_SIZE tx_size, int depth,
4820                             BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
4821                             ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above,
4822                             TXFM_CONTEXT *tx_left, RD_STATS *rd_stats,
4823                             int64_t ref_best_rd, int *is_cost_valid,
4824                             FAST_TX_SEARCH_MODE ftxs_mode,
4825                             TXB_RD_INFO_NODE *rd_info_node) {
4826   assert(tx_size < TX_SIZES_ALL);
4827   av1_init_rd_stats(rd_stats);
4828   if (ref_best_rd < 0) {
4829     *is_cost_valid = 0;
4830     return;
4831   }
4832 
4833   MACROBLOCKD *const xd = &x->e_mbd;
4834   const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
4835   const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
4836   if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;
4837 
4838   const int bw = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
4839   MB_MODE_INFO *const mbmi = xd->mi[0];
4840   const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
4841                                          mbmi->sb_type, tx_size);
4842   struct macroblock_plane *const p = &x->plane[0];
4843 
4844   const int try_no_split = 1;
4845   int try_split = tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH;
4846 #if CONFIG_DIST_8X8
4847   if (x->using_dist_8x8)
4848     try_split &= tx_size_wide[tx_size] >= 16 && tx_size_high[tx_size] >= 16;
4849 #endif
4850   TxCandidateInfo no_split = { INT64_MAX, 0, TX_TYPES };
4851 
4852   // TX no split
4853   if (try_no_split) {
4854     try_tx_block_no_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
4855                           plane_bsize, ta, tl, ctx, rd_stats, ref_best_rd,
4856                           ftxs_mode, rd_info_node, &no_split);
4857 
4858     if (cpi->sf.adaptive_txb_search_level &&
4859         (no_split.rd -
4860          (no_split.rd >> (1 + cpi->sf.adaptive_txb_search_level))) >
4861             ref_best_rd) {
4862       *is_cost_valid = 0;
4863       return;
4864     }
4865 
4866     if (cpi->sf.txb_split_cap) {
4867       if (p->eobs[block] == 0) try_split = 0;
4868     }
4869   }
4870 
4871   if (x->e_mbd.bd == 8 && !x->cb_partition_scan && try_split) {
4872     const int threshold = cpi->sf.tx_type_search.ml_tx_split_thresh;
4873     if (threshold >= 0) {
4874       const int split_score =
4875           ml_predict_tx_split(x, plane_bsize, blk_row, blk_col, tx_size);
4876       if (split_score >= 0 && split_score < threshold) try_split = 0;
4877     }
4878   }
4879 
4880   // TX split
4881   int64_t split_rd = INT64_MAX;
4882   RD_STATS split_rd_stats;
4883   av1_init_rd_stats(&split_rd_stats);
4884   if (try_split) {
4885     try_tx_block_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
4886                        plane_bsize, ta, tl, tx_above, tx_left, ctx, no_split.rd,
4887                        AOMMIN(no_split.rd, ref_best_rd), ftxs_mode,
4888                        rd_info_node, &split_rd_stats, &split_rd);
4889   }
4890 
4891   if (no_split.rd < split_rd) {
4892     ENTROPY_CONTEXT *pta = ta + blk_col;
4893     ENTROPY_CONTEXT *ptl = tl + blk_row;
4894     const TX_SIZE tx_size_selected = tx_size;
4895     p->txb_entropy_ctx[block] = no_split.txb_entropy_ctx;
4896     av1_set_txb_context(x, 0, block, tx_size_selected, pta, ptl);
4897     txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
4898                           tx_size);
4899     for (int idy = 0; idy < tx_size_high_unit[tx_size]; ++idy) {
4900       for (int idx = 0; idx < tx_size_wide_unit[tx_size]; ++idx) {
4901         const int index =
4902             av1_get_txb_size_index(plane_bsize, blk_row + idy, blk_col + idx);
4903         mbmi->inter_tx_size[index] = tx_size_selected;
4904       }
4905     }
4906     mbmi->tx_size = tx_size_selected;
4907     update_txk_array(mbmi->txk_type, plane_bsize, blk_row, blk_col, tx_size,
4908                      no_split.tx_type);
4909     set_blk_skip(x, 0, blk_row * bw + blk_col, rd_stats->skip);
4910   } else {
4911     *rd_stats = split_rd_stats;
4912     if (split_rd == INT64_MAX) *is_cost_valid = 0;
4913   }
4914 }
4915 
4916 static void select_inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
4917                                    RD_STATS *rd_stats, BLOCK_SIZE bsize,
4918                                    int64_t ref_best_rd,
4919                                    FAST_TX_SEARCH_MODE ftxs_mode,
4920                                    TXB_RD_INFO_NODE *rd_info_tree) {
4921   MACROBLOCKD *const xd = &x->e_mbd;
4922   int is_cost_valid = 1;
4923   int64_t this_rd = 0, skip_rd = 0;
4924 
4925   if (ref_best_rd < 0) is_cost_valid = 0;
4926 
4927   av1_init_rd_stats(rd_stats);
4928 
4929   if (is_cost_valid) {
4930     const struct macroblockd_plane *const pd = &xd->plane[0];
4931     const BLOCK_SIZE plane_bsize =
4932         get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
4933     const int mi_width = mi_size_wide[plane_bsize];
4934     const int mi_height = mi_size_high[plane_bsize];
4935     const TX_SIZE max_tx_size = max_txsize_rect_lookup[plane_bsize];
4936     const int bh = tx_size_high_unit[max_tx_size];
4937     const int bw = tx_size_wide_unit[max_tx_size];
4938     int idx, idy;
4939     int block = 0;
4940     int step = tx_size_wide_unit[max_tx_size] * tx_size_high_unit[max_tx_size];
4941     ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
4942     ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
4943     TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
4944     TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
4945 
4946     RD_STATS pn_rd_stats;
4947     const int init_depth =
4948         get_search_init_depth(mi_width, mi_height, 1, &cpi->sf);
4949     av1_init_rd_stats(&pn_rd_stats);
4950 
4951     av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
4952     memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
4953     memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
4954     const int skip_ctx = av1_get_skip_context(xd);
4955     const int s0 = x->skip_cost[skip_ctx][0];
4956     const int s1 = x->skip_cost[skip_ctx][1];
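    // s0 / s1 are the bit costs of signaling skip = 0 / skip = 1. The running
    // this_rd vs. skip_rd comparison after the loop decides whether coding
    // the residual beats skipping the block outright.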
4957 
4958     skip_rd = RDCOST(x->rdmult, s1, 0);
4959     this_rd = RDCOST(x->rdmult, s0, 0);
4960     for (idy = 0; idy < mi_height; idy += bh) {
4961       for (idx = 0; idx < mi_width; idx += bw) {
4962         int64_t best_rd_sofar = (ref_best_rd - (AOMMIN(skip_rd, this_rd)));
4963         select_tx_block(cpi, x, idy, idx, block, max_tx_size, init_depth,
4964                         plane_bsize, ctxa, ctxl, tx_above, tx_left,
4965                         &pn_rd_stats, best_rd_sofar, &is_cost_valid, ftxs_mode,
4966                         rd_info_tree);
4967         if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) {
4968           av1_invalid_rd_stats(rd_stats);
4969           return;
4970         }
4971         av1_merge_rd_stats(rd_stats, &pn_rd_stats);
4972         skip_rd = RDCOST(x->rdmult, s1, rd_stats->sse);
4973         this_rd = RDCOST(x->rdmult, rd_stats->rate + s0, rd_stats->dist);
4974         block += step;
4975         if (rd_info_tree != NULL) rd_info_tree += 1;
4976       }
4977     }
4978     if (skip_rd <= this_rd) {
4979       rd_stats->rate = 0;
4980       rd_stats->dist = rd_stats->sse;
4981       rd_stats->skip = 1;
4982     } else {
4983       rd_stats->skip = 0;
4984     }
4985   }
4986 
4987   if (!is_cost_valid) {
4988     // reset cost value
4989     av1_invalid_rd_stats(rd_stats);
4990   }
4991 }

static int64_t select_tx_size_fix_type(const AV1_COMP *cpi, MACROBLOCK *x,
                                       RD_STATS *rd_stats, BLOCK_SIZE bsize,
                                       int64_t ref_best_rd,
                                       TXB_RD_INFO_NODE *rd_info_tree) {
  const int fast_tx_search = cpi->sf.tx_size_search_method > USE_FULL_RD;
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int is_inter = is_inter_block(mbmi);
  const int skip_ctx = av1_get_skip_context(xd);
  int s0 = x->skip_cost[skip_ctx][0];
  int s1 = x->skip_cost[skip_ctx][1];
  int64_t rd;

  // TODO(debargha): enable this as a speed feature where the
  // select_inter_block_yrd() function above will use a simplified search
  // such as not using full optimize, but the inter_block_yrd() function
  // will use more complex search given that the transform partitions have
  // already been decided.

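  // When the fast transform-type search is used, widen the RD threshold by
  // 1/8 (with an int64 overflow guard) so near-best candidates survive the
  // coarse DCT-only pass and can be re-evaluated in the refinement below.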
  int64_t rd_thresh = ref_best_rd;
  if (fast_tx_search && rd_thresh < INT64_MAX) {
    if (INT64_MAX - rd_thresh > (rd_thresh >> 3)) rd_thresh += (rd_thresh >> 3);
  }
  assert(rd_thresh > 0);

  FAST_TX_SEARCH_MODE ftxs_mode =
      fast_tx_search ? FTXS_DCT_AND_1D_DCT_ONLY : FTXS_NONE;
  select_inter_block_yrd(cpi, x, rd_stats, bsize, rd_thresh, ftxs_mode,
                         rd_info_tree);
  if (rd_stats->rate == INT_MAX) return INT64_MAX;

  // If fast_tx_search is true, only DCT and 1D DCT were tested in
  // select_inter_block_yrd() above. Do a better search for tx type with
  // tx sizes already decided.
  if (fast_tx_search) {
    if (!inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, FTXS_NONE))
      return INT64_MAX;
  }

  if (rd_stats->skip)
    rd = RDCOST(x->rdmult, s1, rd_stats->sse);
  else
    rd = RDCOST(x->rdmult, rd_stats->rate + s0, rd_stats->dist);

  if (is_inter && !xd->lossless[xd->mi[0]->segment_id] && !(rd_stats->skip))
    rd = AOMMIN(rd, RDCOST(x->rdmult, s1, rd_stats->sse));

  return rd;
}

// Finds rd cost for a y block, given the transform size partitions
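// (i.e. mbmi->inter_tx_size has already been decided): recurses from the
// given tx_size down to the leaf transform size recorded for each sub-block.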
static void tx_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
                         int blk_col, int block, TX_SIZE tx_size,
                         BLOCK_SIZE plane_bsize, int depth,
                         ENTROPY_CONTEXT *above_ctx, ENTROPY_CONTEXT *left_ctx,
                         TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
                         int64_t ref_best_rd, RD_STATS *rd_stats,
                         FAST_TX_SEARCH_MODE ftxs_mode) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
  const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);

  assert(tx_size < TX_SIZES_ALL);

  if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;

  const TX_SIZE plane_tx_size = mbmi->inter_tx_size[av1_get_txb_size_index(
      plane_bsize, blk_row, blk_col)];

  int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
                                   mbmi->sb_type, tx_size);

  av1_init_rd_stats(rd_stats);
  if (tx_size == plane_tx_size) {
    ENTROPY_CONTEXT *ta = above_ctx + blk_col;
    ENTROPY_CONTEXT *tl = left_ctx + blk_row;
    const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
    TXB_CTX txb_ctx;
    get_txb_ctx(plane_bsize, tx_size, 0, ta, tl, &txb_ctx);

    const int zero_blk_rate = x->coeff_costs[txs_ctx][get_plane_type(0)]
                                  .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
    rd_stats->zero_rate = zero_blk_rate;
    rd_stats->ref_rdcost = ref_best_rd;
    tx_block_rd_b(cpi, x, tx_size, blk_row, blk_col, 0, block, plane_bsize,
                  &txb_ctx, rd_stats, ftxs_mode, ref_best_rd, NULL);
    const int mi_width = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
    if (RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
            RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
        rd_stats->skip == 1) {
      rd_stats->rate = zero_blk_rate;
      rd_stats->dist = rd_stats->sse;
      rd_stats->skip = 1;
      set_blk_skip(x, 0, blk_row * mi_width + blk_col, 1);
      x->plane[0].eobs[block] = 0;
      x->plane[0].txb_entropy_ctx[block] = 0;
      update_txk_array(mbmi->txk_type, plane_bsize, blk_row, blk_col, tx_size,
                       DCT_DCT);
    } else {
      rd_stats->skip = 0;
      set_blk_skip(x, 0, blk_row * mi_width + blk_col, 0);
    }
    if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
      rd_stats->rate += x->txfm_partition_cost[ctx][0];
    av1_set_txb_context(x, 0, block, tx_size, ta, tl);
    txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
                          tx_size);
  } else {
    const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
    const int bsw = tx_size_wide_unit[sub_txs];
    const int bsh = tx_size_high_unit[sub_txs];
    const int step = bsh * bsw;
    RD_STATS pn_rd_stats;
    int64_t this_rd = 0;
    assert(bsw > 0 && bsh > 0);

    for (int row = 0; row < tx_size_high_unit[tx_size]; row += bsh) {
      for (int col = 0; col < tx_size_wide_unit[tx_size]; col += bsw) {
        const int offsetr = blk_row + row;
        const int offsetc = blk_col + col;

        if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;

        av1_init_rd_stats(&pn_rd_stats);
        tx_block_yrd(cpi, x, offsetr, offsetc, block, sub_txs, plane_bsize,
                     depth + 1, above_ctx, left_ctx, tx_above, tx_left,
                     ref_best_rd - this_rd, &pn_rd_stats, ftxs_mode);
        if (pn_rd_stats.rate == INT_MAX) {
          av1_invalid_rd_stats(rd_stats);
          return;
        }
        av1_merge_rd_stats(rd_stats, &pn_rd_stats);
        this_rd += RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist);
        block += step;
      }
    }

    if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
      rd_stats->rate += x->txfm_partition_cost[ctx][1];
  }
}

// Return value 0: early termination triggered, no valid rd cost available;
//              1: rd cost values are valid.
static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
                           RD_STATS *rd_stats, BLOCK_SIZE bsize,
                           int64_t ref_best_rd, FAST_TX_SEARCH_MODE ftxs_mode) {
  MACROBLOCKD *const xd = &x->e_mbd;
  int is_cost_valid = 1;
  int64_t this_rd = 0;

  if (ref_best_rd < 0) is_cost_valid = 0;

  av1_init_rd_stats(rd_stats);

  if (is_cost_valid) {
    const struct macroblockd_plane *const pd = &xd->plane[0];
    const BLOCK_SIZE plane_bsize =
        get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
    const int mi_width = mi_size_wide[plane_bsize];
    const int mi_height = mi_size_high[plane_bsize];
    const TX_SIZE max_tx_size = get_vartx_max_txsize(xd, plane_bsize, 0);
    const int bh = tx_size_high_unit[max_tx_size];
    const int bw = tx_size_wide_unit[max_tx_size];
    const int init_depth =
        get_search_init_depth(mi_width, mi_height, 1, &cpi->sf);
    int idx, idy;
    int block = 0;
    int step = tx_size_wide_unit[max_tx_size] * tx_size_high_unit[max_tx_size];
    ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
    ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
    TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
    TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
    RD_STATS pn_rd_stats;

    av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
    memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
    memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);

    for (idy = 0; idy < mi_height; idy += bh) {
      for (idx = 0; idx < mi_width; idx += bw) {
        av1_init_rd_stats(&pn_rd_stats);
        tx_block_yrd(cpi, x, idy, idx, block, max_tx_size, plane_bsize,
                     init_depth, ctxa, ctxl, tx_above, tx_left,
                     ref_best_rd - this_rd, &pn_rd_stats, ftxs_mode);
        if (pn_rd_stats.rate == INT_MAX) {
          av1_invalid_rd_stats(rd_stats);
          return 0;
        }
        av1_merge_rd_stats(rd_stats, &pn_rd_stats);
        this_rd +=
            AOMMIN(RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist),
                   RDCOST(x->rdmult, pn_rd_stats.zero_rate, pn_rd_stats.sse));
        block += step;
      }
    }
  }

  const int skip_ctx = av1_get_skip_context(xd);
  const int s0 = x->skip_cost[skip_ctx][0];
  const int s1 = x->skip_cost[skip_ctx][1];
  int64_t skip_rd = RDCOST(x->rdmult, s1, rd_stats->sse);
  this_rd = RDCOST(x->rdmult, rd_stats->rate + s0, rd_stats->dist);
  if (skip_rd < this_rd) {
    this_rd = skip_rd;
    rd_stats->rate = 0;
    rd_stats->dist = rd_stats->sse;
    rd_stats->skip = 1;
  }
  if (this_rd > ref_best_rd) is_cost_valid = 0;

  if (!is_cost_valid) {
    // reset cost value
    av1_invalid_rd_stats(rd_stats);
  }
  return is_cost_valid;
}

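// Hash the whole-block luma residual with CRC32C (the int16_t diff buffer
// spans 2 * rows * cols bytes) and fold the block size into the low 5 bits,
// so identical residuals of different block sizes do not collide.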
static INLINE uint32_t get_block_residue_hash(MACROBLOCK *x, BLOCK_SIZE bsize) {
  const int rows = block_size_high[bsize];
  const int cols = block_size_wide[bsize];
  const int16_t *diff = x->plane[0].src_diff;
  const uint32_t hash = av1_get_crc32c_value(&x->mb_rd_record.crc_calculator,
                                             (uint8_t *)diff, 2 * rows * cols);
  return (hash << 5) + bsize;
}

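// Append a whole-block RD record to the circular buffer: while there is room,
// write just past the newest entry; once full, overwrite the oldest entry and
// advance index_start.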
static void save_tx_rd_info(int n4, uint32_t hash, const MACROBLOCK *const x,
                            const RD_STATS *const rd_stats,
                            MB_RD_RECORD *tx_rd_record) {
  int index;
  if (tx_rd_record->num < RD_RECORD_BUFFER_LEN) {
    index =
        (tx_rd_record->index_start + tx_rd_record->num) % RD_RECORD_BUFFER_LEN;
    ++tx_rd_record->num;
  } else {
    index = tx_rd_record->index_start;
    tx_rd_record->index_start =
        (tx_rd_record->index_start + 1) % RD_RECORD_BUFFER_LEN;
  }
  MB_RD_INFO *const tx_rd_info = &tx_rd_record->tx_rd_info[index];
  const MACROBLOCKD *const xd = &x->e_mbd;
  const MB_MODE_INFO *const mbmi = xd->mi[0];
  tx_rd_info->hash_value = hash;
  tx_rd_info->tx_size = mbmi->tx_size;
  memcpy(tx_rd_info->blk_skip, x->blk_skip,
         sizeof(tx_rd_info->blk_skip[0]) * n4);
  av1_copy(tx_rd_info->inter_tx_size, mbmi->inter_tx_size);
  av1_copy(tx_rd_info->txk_type, mbmi->txk_type);
  tx_rd_info->rd_stats = *rd_stats;
}

static void fetch_tx_rd_info(int n4, const MB_RD_INFO *const tx_rd_info,
                             RD_STATS *const rd_stats, MACROBLOCK *const x) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  mbmi->tx_size = tx_rd_info->tx_size;
  memcpy(x->blk_skip, tx_rd_info->blk_skip,
         sizeof(tx_rd_info->blk_skip[0]) * n4);
  av1_copy(mbmi->inter_tx_size, tx_rd_info->inter_tx_size);
  av1_copy(mbmi->txk_type, tx_rd_info->txk_type);
  *rd_stats = tx_rd_info->rd_stats;
}

static int find_tx_size_rd_info(TXB_RD_RECORD *cur_record,
                                const uint32_t hash) {
  // Linear search through the circular buffer to find matching hash.
  for (int i = cur_record->index_start - 1; i >= 0; i--) {
    if (cur_record->hash_vals[i] == hash) return i;
  }
  for (int i = cur_record->num - 1; i >= cur_record->index_start; i--) {
    if (cur_record->hash_vals[i] == hash) return i;
  }
  int index;
  // If not found - add new RD info into the buffer and return its index
  if (cur_record->num < TX_SIZE_RD_RECORD_BUFFER_LEN) {
    index = (cur_record->index_start + cur_record->num) %
            TX_SIZE_RD_RECORD_BUFFER_LEN;
    cur_record->num++;
  } else {
    index = cur_record->index_start;
    cur_record->index_start =
        (cur_record->index_start + 1) % TX_SIZE_RD_RECORD_BUFFER_LEN;
  }

  cur_record->hash_vals[index] = hash;
  av1_zero(cur_record->tx_rd_info[index]);
  return index;
}

typedef struct {
  int leaf;
  int8_t children[4];
} RD_RECORD_IDX_NODE;

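// Static index trees describing, per block size, how the TX-block RD records
// nest across transform depths: a zero leaf flag means children[] holds the
// indices of the node's quadtree children in the flattened array (values <= 0
// mean no child, matching the idx > 0 check in init_rd_record_tree()).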
static const RD_RECORD_IDX_NODE rd_record_tree_8x8[] = {
  { 1, { 0 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_8x16[] = {
  { 0, { 1, 2, -1, -1 } },
  { 1, { 0, 0, 0, 0 } },
  { 1, { 0, 0, 0, 0 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_16x8[] = {
  { 0, { 1, 2, -1, -1 } },
  { 1, { 0 } },
  { 1, { 0 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_16x16[] = {
  { 0, { 1, 2, 3, 4 } }, { 1, { 0 } }, { 1, { 0 } }, { 1, { 0 } }, { 1, { 0 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_1_2[] = {
  { 0, { 1, 2, -1, -1 } },
  { 0, { 3, 4, 5, 6 } },
  { 0, { 7, 8, 9, 10 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_2_1[] = {
  { 0, { 1, 2, -1, -1 } },
  { 0, { 3, 4, 7, 8 } },
  { 0, { 5, 6, 9, 10 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_sqr[] = {
  { 0, { 1, 2, 3, 4 } },     { 0, { 5, 6, 9, 10 } },    { 0, { 7, 8, 11, 12 } },
  { 0, { 13, 14, 17, 18 } }, { 0, { 15, 16, 19, 20 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_64x128[] = {
  { 0, { 2, 3, 4, 5 } },     { 0, { 6, 7, 8, 9 } },
  { 0, { 10, 11, 14, 15 } }, { 0, { 12, 13, 16, 17 } },
  { 0, { 18, 19, 22, 23 } }, { 0, { 20, 21, 24, 25 } },
  { 0, { 26, 27, 30, 31 } }, { 0, { 28, 29, 32, 33 } },
  { 0, { 34, 35, 38, 39 } }, { 0, { 36, 37, 40, 41 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_128x64[] = {
  { 0, { 2, 3, 6, 7 } },     { 0, { 4, 5, 8, 9 } },
  { 0, { 10, 11, 18, 19 } }, { 0, { 12, 13, 20, 21 } },
  { 0, { 14, 15, 22, 23 } }, { 0, { 16, 17, 24, 25 } },
  { 0, { 26, 27, 34, 35 } }, { 0, { 28, 29, 36, 37 } },
  { 0, { 30, 31, 38, 39 } }, { 0, { 32, 33, 40, 41 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_128x128[] = {
  { 0, { 4, 5, 8, 9 } },     { 0, { 6, 7, 10, 11 } },
  { 0, { 12, 13, 16, 17 } }, { 0, { 14, 15, 18, 19 } },
  { 0, { 20, 21, 28, 29 } }, { 0, { 22, 23, 30, 31 } },
  { 0, { 24, 25, 32, 33 } }, { 0, { 26, 27, 34, 35 } },
  { 0, { 36, 37, 44, 45 } }, { 0, { 38, 39, 46, 47 } },
  { 0, { 40, 41, 48, 49 } }, { 0, { 42, 43, 50, 51 } },
  { 0, { 52, 53, 60, 61 } }, { 0, { 54, 55, 62, 63 } },
  { 0, { 56, 57, 64, 65 } }, { 0, { 58, 59, 66, 67 } },
  { 0, { 68, 69, 76, 77 } }, { 0, { 70, 71, 78, 79 } },
  { 0, { 72, 73, 80, 81 } }, { 0, { 74, 75, 82, 83 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_1_4[] = {
  { 0, { 1, -1, 2, -1 } },
  { 0, { 3, 4, -1, -1 } },
  { 0, { 5, 6, -1, -1 } },
};

static const RD_RECORD_IDX_NODE rd_record_tree_4_1[] = {
  { 0, { 1, 2, -1, -1 } },
  { 0, { 3, 4, -1, -1 } },
  { 0, { 5, 6, -1, -1 } },
};

static const RD_RECORD_IDX_NODE *rd_record_tree[BLOCK_SIZES_ALL] = {
  NULL,                    // BLOCK_4X4
  NULL,                    // BLOCK_4X8
  NULL,                    // BLOCK_8X4
  rd_record_tree_8x8,      // BLOCK_8X8
  rd_record_tree_8x16,     // BLOCK_8X16
  rd_record_tree_16x8,     // BLOCK_16X8
  rd_record_tree_16x16,    // BLOCK_16X16
  rd_record_tree_1_2,      // BLOCK_16X32
  rd_record_tree_2_1,      // BLOCK_32X16
  rd_record_tree_sqr,      // BLOCK_32X32
  rd_record_tree_1_2,      // BLOCK_32X64
  rd_record_tree_2_1,      // BLOCK_64X32
  rd_record_tree_sqr,      // BLOCK_64X64
  rd_record_tree_64x128,   // BLOCK_64X128
  rd_record_tree_128x64,   // BLOCK_128X64
  rd_record_tree_128x128,  // BLOCK_128X128
  NULL,                    // BLOCK_4X16
  NULL,                    // BLOCK_16X4
  rd_record_tree_1_4,      // BLOCK_8X32
  rd_record_tree_4_1,      // BLOCK_32X8
  rd_record_tree_1_4,      // BLOCK_16X64
  rd_record_tree_4_1,      // BLOCK_64X16
};

static const int rd_record_tree_size[BLOCK_SIZES_ALL] = {
  0,                                                            // BLOCK_4X4
  0,                                                            // BLOCK_4X8
  0,                                                            // BLOCK_8X4
  sizeof(rd_record_tree_8x8) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_8X8
  sizeof(rd_record_tree_8x16) / sizeof(RD_RECORD_IDX_NODE),     // BLOCK_8X16
  sizeof(rd_record_tree_16x8) / sizeof(RD_RECORD_IDX_NODE),     // BLOCK_16X8
  sizeof(rd_record_tree_16x16) / sizeof(RD_RECORD_IDX_NODE),    // BLOCK_16X16
  sizeof(rd_record_tree_1_2) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_16X32
  sizeof(rd_record_tree_2_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X16
  sizeof(rd_record_tree_sqr) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X32
  sizeof(rd_record_tree_1_2) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X64
  sizeof(rd_record_tree_2_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_64X32
  sizeof(rd_record_tree_sqr) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_64X64
  sizeof(rd_record_tree_64x128) / sizeof(RD_RECORD_IDX_NODE),   // BLOCK_64X128
  sizeof(rd_record_tree_128x64) / sizeof(RD_RECORD_IDX_NODE),   // BLOCK_128X64
  sizeof(rd_record_tree_128x128) / sizeof(RD_RECORD_IDX_NODE),  // BLOCK_128X128
  0,                                                            // BLOCK_4X16
  0,                                                            // BLOCK_16X4
  sizeof(rd_record_tree_1_4) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_8X32
  sizeof(rd_record_tree_4_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X8
  sizeof(rd_record_tree_1_4) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_16X64
  sizeof(rd_record_tree_4_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_64X16
};

static INLINE void init_rd_record_tree(TXB_RD_INFO_NODE *tree,
                                       BLOCK_SIZE bsize) {
  const RD_RECORD_IDX_NODE *rd_record = rd_record_tree[bsize];
  const int size = rd_record_tree_size[bsize];
  for (int i = 0; i < size; ++i) {
    if (rd_record[i].leaf) {
      av1_zero(tree[i].children);
    } else {
      for (int j = 0; j < 4; ++j) {
        const int8_t idx = rd_record[i].children[j];
        tree[i].children[j] = idx > 0 ? &tree[idx] : NULL;
      }
    }
  }
}

// Go through all TX blocks that could be used in TX size search, compute
// residual hash values for them and find matching RD info that stores previous
// RD search results for these TX blocks. The idea is to prevent repeated
// rate/distortion computations that happen because of the combination of
// partition and TX size search. The resulting RD info records are returned in
// the form of a quadtree for easier access in actual TX size search.
static int find_tx_size_rd_records(MACROBLOCK *x, BLOCK_SIZE bsize, int mi_row,
                                   int mi_col, TXB_RD_INFO_NODE *dst_rd_info) {
  TXB_RD_RECORD *rd_records_table[4] = { x->txb_rd_record_8X8,
                                         x->txb_rd_record_16X16,
                                         x->txb_rd_record_32X32,
                                         x->txb_rd_record_64X64 };
  const TX_SIZE max_square_tx_size = max_txsize_lookup[bsize];
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];

  // Hashing is performed only for square TX sizes larger than TX_4X4
  if (max_square_tx_size < TX_8X8) return 0;
  const int diff_stride = bw;
  const struct macroblock_plane *const p = &x->plane[0];
  const int16_t *diff = &p->src_diff[0];
  init_rd_record_tree(dst_rd_info, bsize);
  // Coordinates of the top-left corner of current block within the superblock
  // measured in pixels:
  const int mi_row_in_sb = (mi_row % MAX_MIB_SIZE) << MI_SIZE_LOG2;
  const int mi_col_in_sb = (mi_col % MAX_MIB_SIZE) << MI_SIZE_LOG2;
  int cur_rd_info_idx = 0;
  int cur_tx_depth = 0;
  TX_SIZE cur_tx_size = max_txsize_rect_lookup[bsize];
  while (cur_tx_depth <= MAX_VARTX_DEPTH) {
    const int cur_tx_bw = tx_size_wide[cur_tx_size];
    const int cur_tx_bh = tx_size_high[cur_tx_size];
    if (cur_tx_bw < 8 || cur_tx_bh < 8) break;
    const TX_SIZE next_tx_size = sub_tx_size_map[cur_tx_size];
    const int tx_size_idx = cur_tx_size - TX_8X8;
    for (int row = 0; row < bh; row += cur_tx_bh) {
      for (int col = 0; col < bw; col += cur_tx_bw) {
        if (cur_tx_bw != cur_tx_bh) {
          // Use dummy nodes for all rectangular transforms within the
          // TX size search tree.
          dst_rd_info[cur_rd_info_idx].rd_info_array = NULL;
        } else {
          // Get spatial location of this TX block within the superblock
          // (measured in cur_tx_bsize units).
          const int row_in_sb = (mi_row_in_sb + row) / cur_tx_bh;
          const int col_in_sb = (mi_col_in_sb + col) / cur_tx_bw;

          int16_t hash_data[MAX_SB_SQUARE];
          int16_t *cur_hash_row = hash_data;
          const int16_t *cur_diff_row = diff + row * diff_stride + col;
          for (int i = 0; i < cur_tx_bh; i++) {
            memcpy(cur_hash_row, cur_diff_row, sizeof(*hash_data) * cur_tx_bw);
            cur_hash_row += cur_tx_bw;
            cur_diff_row += diff_stride;
          }
          const int hash = av1_get_crc32c_value(&x->mb_rd_record.crc_calculator,
                                                (uint8_t *)hash_data,
                                                2 * cur_tx_bw * cur_tx_bh);
          // Find corresponding RD info based on the hash value.
          const int record_idx =
              row_in_sb * (MAX_MIB_SIZE >> (tx_size_idx + 1)) + col_in_sb;
          TXB_RD_RECORD *records = &rd_records_table[tx_size_idx][record_idx];
          int idx = find_tx_size_rd_info(records, hash);
          dst_rd_info[cur_rd_info_idx].rd_info_array =
              &records->tx_rd_info[idx];
        }
        ++cur_rd_info_idx;
      }
    }
    cur_tx_size = next_tx_size;
    ++cur_tx_depth;
  }
  return 1;
}

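// Per-bit-depth (8/10/12) and per-block-size thresholds on quantized
// coefficient magnitude used by predict_skip_flag() below; as the note says,
// each entry is the original tuning threshold scaled by 128 / 100.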
// origin_threshold * 128 / 100
static const uint32_t skip_pred_threshold[3][BLOCK_SIZES_ALL] = {
  {
      64, 64, 64, 70, 60, 60, 68, 68, 68, 68, 68,
      68, 68, 68, 68, 68, 64, 64, 70, 70, 68, 68,
  },
  {
      88, 88, 88, 86, 87, 87, 68, 68, 68, 68, 68,
      68, 68, 68, 68, 68, 88, 88, 86, 86, 68, 68,
  },
  {
      90, 93, 93, 90, 93, 93, 74, 74, 74, 74, 74,
      74, 74, 74, 74, 74, 90, 90, 90, 90, 74, 74,
  },
};

// lookup table for predict_skip_flag
// int max_tx_size = max_txsize_rect_lookup[bsize];
// if (tx_size_high[max_tx_size] > 16 || tx_size_wide[max_tx_size] > 16)
//   max_tx_size = AOMMIN(max_txsize_lookup[bsize], TX_16X16);
static const TX_SIZE max_predict_sf_tx_size[BLOCK_SIZES_ALL] = {
  TX_4X4,   TX_4X8,   TX_8X4,   TX_8X8,   TX_8X16,  TX_16X8,
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16,
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_4X16,  TX_16X4,
  TX_8X8,   TX_8X8,   TX_16X16, TX_16X16,
};

// Uses simple features on top of DCT coefficients to quickly predict
// whether the optimal RD decision is to skip encoding the residual.
// The sse value is stored in dist.
static int predict_skip_flag(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist,
                             int reduced_tx_set) {
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  const MACROBLOCKD *xd = &x->e_mbd;
  const int16_t dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd);

  *dist = pixel_diff_dist(x, 0, 0, 0, bsize, bsize);
  const int64_t mse = *dist / bw / bh;
  // Normalized quantizer takes the transform upscaling factor (8 for tx size
  // smaller than 32) into account.
  const int16_t normalized_dc_q = dc_q >> 3;
  const int64_t mse_thresh = (int64_t)normalized_dc_q * normalized_dc_q / 8;
  // Predict not to skip when mse is larger than threshold.
  if (mse > mse_thresh) return 0;

  const int max_tx_size = max_predict_sf_tx_size[bsize];
  const int tx_h = tx_size_high[max_tx_size];
  const int tx_w = tx_size_wide[max_tx_size];
  DECLARE_ALIGNED(32, tran_low_t, coefs[32 * 32]);
  TxfmParam param;
  param.tx_type = DCT_DCT;
  param.tx_size = max_tx_size;
  param.bd = xd->bd;
  param.is_hbd = get_bitdepth_data_path_index(xd);
  param.lossless = 0;
  param.tx_set_type = av1_get_ext_tx_set_type(
      param.tx_size, is_inter_block(xd->mi[0]), reduced_tx_set);
  const int bd_idx = (xd->bd == 8) ? 0 : ((xd->bd == 10) ? 1 : 2);
  const uint32_t max_qcoef_thresh = skip_pred_threshold[bd_idx][bsize];
  const int16_t *src_diff = x->plane[0].src_diff;
  const int n_coeff = tx_w * tx_h;
  const int16_t ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
  const uint32_t dc_thresh = max_qcoef_thresh * dc_q;
  const uint32_t ac_thresh = max_qcoef_thresh * ac_q;
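  // Forward-transform each max-size sub-block and predict skip only if every
  // DC and AC coefficient, scaled (<< 7) to the quantizers' precision, stays
  // strictly below its threshold; any large coefficient means "do not skip".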
  for (int row = 0; row < bh; row += tx_h) {
    for (int col = 0; col < bw; col += tx_w) {
      av1_fwd_txfm(src_diff + col, coefs, bw, &param);
      // Operating on TX domain, not pixels; we want the QTX quantizers
      const uint32_t dc_coef = (((uint32_t)abs(coefs[0])) << 7);
      if (dc_coef >= dc_thresh) return 0;
      for (int i = 1; i < n_coeff; ++i) {
        const uint32_t ac_coef = (((uint32_t)abs(coefs[i])) << 7);
        if (ac_coef >= ac_thresh) return 0;
      }
    }
    src_diff += tx_h * bw;
  }
  return 1;
}

// Used to set proper context for early termination with skip = 1.
static void set_skip_flag(MACROBLOCK *x, RD_STATS *rd_stats, int bsize,
                          int64_t dist) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int n4 = bsize_to_num_blk(bsize);
  const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
  memset(mbmi->txk_type, DCT_DCT, sizeof(mbmi->txk_type[0]) * TXK_TYPE_BUF_LEN);
  memset(mbmi->inter_tx_size, tx_size, sizeof(mbmi->inter_tx_size));
  mbmi->tx_size = tx_size;
  for (int i = 0; i < n4; ++i) set_blk_skip(x, 0, i, 1);
  rd_stats->skip = 1;
  rd_stats->rate = 0;
  if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
    dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2);
  rd_stats->dist = rd_stats->sse = (dist << 4);
}

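// Top-level luma transform search for a block. Tries up to three early exits
// (each gated by a speed feature and/or tile-border check) before the full
// search: a model-based RD prune, a hit in the whole-block residual-hash
// record, and the skip-flag predictor above; only then runs
// select_tx_size_fix_type() and caches the result keyed by the residual hash.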
static void select_tx_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
                               RD_STATS *rd_stats, BLOCK_SIZE bsize, int mi_row,
                               int mi_col, int64_t ref_best_rd) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  int64_t rd = INT64_MAX;
  int64_t best_rd = INT64_MAX;
  const int is_inter = is_inter_block(mbmi);
  const int n4 = bsize_to_num_blk(bsize);
  // Get the tx_size 1 level down
  const TX_SIZE min_tx_size = sub_tx_size_map[max_txsize_rect_lookup[bsize]];
  const TxSetType tx_set_type =
      av1_get_ext_tx_set_type(min_tx_size, is_inter, cm->reduced_tx_set_used);
  const int within_border =
      mi_row >= xd->tile.mi_row_start &&
      (mi_row + mi_size_high[bsize] < xd->tile.mi_row_end) &&
      mi_col >= xd->tile.mi_col_start &&
      (mi_col + mi_size_wide[bsize] < xd->tile.mi_col_end);

  av1_invalid_rd_stats(rd_stats);

  if (cpi->sf.model_based_prune_tx_search_level && ref_best_rd != INT64_MAX) {
    int model_rate;
    int64_t model_dist;
    int model_skip;
    model_rd_sb_fn[MODELRD_TYPE_TX_SEARCH_PRUNE](
        cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &model_rate, &model_dist,
        &model_skip, NULL, NULL, NULL, NULL);
    const int64_t model_rd = RDCOST(x->rdmult, model_rate, model_dist);
    // If the modeled rd is a lot worse than the best so far, breakout.
    // TODO(debargha, urvang): Improve the model and make the check below
    // tighter.
    assert(cpi->sf.model_based_prune_tx_search_level >= 0 &&
           cpi->sf.model_based_prune_tx_search_level <= 2);
    static const int prune_factor_by8[] = { 2 + MODELRD_TYPE_TX_SEARCH_PRUNE,
                                            4 + MODELRD_TYPE_TX_SEARCH_PRUNE };
    if (!model_skip &&
        ((model_rd *
          prune_factor_by8[cpi->sf.model_based_prune_tx_search_level - 1]) >>
         3) > ref_best_rd)
      return;
  }

  const uint32_t hash = get_block_residue_hash(x, bsize);
  MB_RD_RECORD *mb_rd_record = &x->mb_rd_record;

  if (ref_best_rd != INT64_MAX && within_border && cpi->sf.use_mb_rd_hash) {
    for (int i = 0; i < mb_rd_record->num; ++i) {
      const int index = (mb_rd_record->index_start + i) % RD_RECORD_BUFFER_LEN;
      // If there is a match in the tx_rd_record, fetch the RD decision and
      // terminate early.
      if (mb_rd_record->tx_rd_info[index].hash_value == hash) {
        MB_RD_INFO *tx_rd_info = &mb_rd_record->tx_rd_info[index];
        fetch_tx_rd_info(n4, tx_rd_info, rd_stats, x);
        return;
      }
    }
  }

  // If we predict that skip is the optimal RD decision - set the respective
  // context and terminate early.
  int64_t dist;
  if (is_inter && cpi->sf.tx_type_search.use_skip_flag_prediction &&
      predict_skip_flag(x, bsize, &dist, cm->reduced_tx_set_used)) {
    set_skip_flag(x, rd_stats, bsize, dist);
    // Save the RD search results into tx_rd_record.
    if (within_border) save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
    return;
  }

  // Precompute residual hashes and find existing or add new RD records to
  // store and reuse rate and distortion values to speed up TX size search.
  TXB_RD_INFO_NODE matched_rd_info[4 + 16 + 64];
  int found_rd_info = 0;
  if (ref_best_rd != INT64_MAX && within_border && cpi->sf.use_inter_txb_hash) {
    found_rd_info =
        find_tx_size_rd_records(x, bsize, mi_row, mi_col, matched_rd_info);
  }

  prune_tx(cpi, bsize, x, xd, tx_set_type);

  int found = 0;

  RD_STATS this_rd_stats;
  av1_init_rd_stats(&this_rd_stats);

  rd = select_tx_size_fix_type(cpi, x, &this_rd_stats, bsize, ref_best_rd,
                               found_rd_info ? matched_rd_info : NULL);
  assert(IMPLIES(this_rd_stats.skip && !this_rd_stats.invalid_rate,
                 this_rd_stats.rate == 0));

  ref_best_rd = AOMMIN(rd, ref_best_rd);
  if (rd < best_rd) {
    *rd_stats = this_rd_stats;
    found = 1;
  }

  // Reset the pruning flags.
  av1_zero(x->tx_search_prune);
  x->tx_split_prune_flag = 0;

  // We should always find at least one candidate unless ref_best_rd is less
  // than INT64_MAX (in which case, all the calls to select_tx_size_fix_type
  // might have failed to find something better)
  assert(IMPLIES(!found, ref_best_rd != INT64_MAX));
  if (!found) return;

  // Save the RD search results into tx_rd_record.
  if (within_border && cpi->sf.use_mb_rd_hash)
    save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
}

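// Computes the RD cost of one chroma transform block in place, substituting
// the cheaper all-zero (txb-skip) coding whenever it beats the coded cost,
// except in lossless mode where the residual must be preserved exactly.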
static void tx_block_uvrd(const AV1_COMP *cpi, MACROBLOCK *x, int blk_row,
                          int blk_col, int plane, int block, TX_SIZE tx_size,
                          BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *above_ctx,
                          ENTROPY_CONTEXT *left_ctx, RD_STATS *rd_stats,
                          FAST_TX_SEARCH_MODE ftxs_mode) {
  assert(plane > 0);
  assert(tx_size < TX_SIZES_ALL);
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int max_blocks_high = max_block_high(xd, plane_bsize, plane);
  const int max_blocks_wide = max_block_wide(xd, plane_bsize, plane);
  if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;

  ENTROPY_CONTEXT *ta = above_ctx + blk_col;
  ENTROPY_CONTEXT *tl = left_ctx + blk_row;
  TXB_CTX txb_ctx;
  get_txb_ctx(plane_bsize, tx_size, plane, ta, tl, &txb_ctx);
  const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
  const int zero_blk_rate = x->coeff_costs[txs_ctx][PLANE_TYPE_UV]
                                .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
  tx_block_rd_b(cpi, x, tx_size, blk_row, blk_col, plane, block, plane_bsize,
                &txb_ctx, rd_stats, ftxs_mode, INT64_MAX, NULL);

  const int mi_width = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
  const int blk_idx = blk_row * mi_width + blk_col;

  av1_set_txb_context(x, plane, block, tx_size, ta, tl);
  if ((RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
           RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
       rd_stats->skip == 1) &&
      !xd->lossless[mbmi->segment_id]) {
    rd_stats->rate = zero_blk_rate;
    rd_stats->dist = rd_stats->sse;
  }

  // Set chroma blk_skip to 0
  set_blk_skip(x, plane, blk_idx, 0);
}

// Return value 0: early termination triggered, no valid rd cost available;
//              1: rd cost values are valid.
static int inter_block_uvrd(const AV1_COMP *cpi, MACROBLOCK *x,
                            RD_STATS *rd_stats, BLOCK_SIZE bsize,
                            int64_t non_skip_ref_best_rd,
                            int64_t skip_ref_best_rd,
                            FAST_TX_SEARCH_MODE ftxs_mode) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  int plane;
  int is_cost_valid = 1;
  int64_t this_rd = 0;
  int64_t skip_rd = 0;

  if ((non_skip_ref_best_rd < 0) && (skip_ref_best_rd < 0)) is_cost_valid = 0;

  av1_init_rd_stats(rd_stats);

  if (x->skip_chroma_rd) {
    if (!is_cost_valid) av1_invalid_rd_stats(rd_stats);

    return is_cost_valid;
  }

  const BLOCK_SIZE bsizec = scale_chroma_bsize(
      bsize, xd->plane[1].subsampling_x, xd->plane[1].subsampling_y);

  if (is_inter_block(mbmi) && is_cost_valid) {
    for (plane = 1; plane < MAX_MB_PLANE; ++plane)
      av1_subtract_plane(x, bsizec, plane);
  }

  if (is_cost_valid) {
    for (plane = 1; plane < MAX_MB_PLANE; ++plane) {
      const struct macroblockd_plane *const pd = &xd->plane[plane];
      const BLOCK_SIZE plane_bsize =
          get_plane_block_size(bsizec, pd->subsampling_x, pd->subsampling_y);
      const int mi_width = block_size_wide[plane_bsize] >> tx_size_wide_log2[0];
      const int mi_height =
          block_size_high[plane_bsize] >> tx_size_high_log2[0];
      const TX_SIZE max_tx_size = get_vartx_max_txsize(xd, plane_bsize, plane);
      const int bh = tx_size_high_unit[max_tx_size];
      const int bw = tx_size_wide_unit[max_tx_size];
      int idx, idy;
      int block = 0;
      const int step = bh * bw;
      ENTROPY_CONTEXT ta[MAX_MIB_SIZE];
      ENTROPY_CONTEXT tl[MAX_MIB_SIZE];
      av1_get_entropy_contexts(bsizec, pd, ta, tl);

      for (idy = 0; idy < mi_height; idy += bh) {
        for (idx = 0; idx < mi_width; idx += bw) {
          RD_STATS pn_rd_stats;
          av1_init_rd_stats(&pn_rd_stats);
          tx_block_uvrd(cpi, x, idy, idx, plane, block, max_tx_size,
                        plane_bsize, ta, tl, &pn_rd_stats, ftxs_mode);
          if (pn_rd_stats.rate == INT_MAX) {
            av1_invalid_rd_stats(rd_stats);
            return 0;
          }
          av1_merge_rd_stats(rd_stats, &pn_rd_stats);
          this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
          skip_rd = RDCOST(x->rdmult, 0, rd_stats->sse);
          if ((this_rd > non_skip_ref_best_rd) &&
              (skip_rd > skip_ref_best_rd)) {
            av1_invalid_rd_stats(rd_stats);
            return 0;
          }
          block += step;
        }
      }
    }
  } else {
    // reset cost value
    av1_invalid_rd_stats(rd_stats);
  }

  return is_cost_valid;
}

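// Joint U/V palette search: treat each chroma pixel as a 2-D (U, V) sample,
// run k-means for palette sizes from min(colors, PALETTE_MAX_SIZE) down to 2,
// and keep the palette with the best RD cost among those candidates.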
static void rd_pick_palette_intra_sbuv(const AV1_COMP *const cpi, MACROBLOCK *x,
                                       int dc_mode_cost,
                                       uint8_t *best_palette_color_map,
                                       MB_MODE_INFO *const best_mbmi,
                                       int64_t *best_rd, int *rate,
                                       int *rate_tokenonly, int64_t *distortion,
                                       int *skippable) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  assert(!is_inter_block(mbmi));
  assert(
      av1_allow_palette(cpi->common.allow_screen_content_tools, mbmi->sb_type));
  PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
  const BLOCK_SIZE bsize = mbmi->sb_type;
  const SequenceHeader *const seq_params = &cpi->common.seq_params;
  int this_rate;
  int64_t this_rd;
  int colors_u, colors_v, colors;
  const int src_stride = x->plane[1].src.stride;
  const uint8_t *const src_u = x->plane[1].src.buf;
  const uint8_t *const src_v = x->plane[2].src.buf;
  uint8_t *const color_map = xd->plane[1].color_index_map;
  RD_STATS tokenonly_rd_stats;
  int plane_block_width, plane_block_height, rows, cols;
  av1_get_block_dimensions(bsize, 1, xd, &plane_block_width,
                           &plane_block_height, &rows, &cols);

  mbmi->uv_mode = UV_DC_PRED;

  int count_buf[1 << 12];  // Maximum (1 << 12) color levels.
  if (seq_params->use_highbitdepth) {
    colors_u = av1_count_colors_highbd(src_u, src_stride, rows, cols,
                                       seq_params->bit_depth, count_buf);
    colors_v = av1_count_colors_highbd(src_v, src_stride, rows, cols,
                                       seq_params->bit_depth, count_buf);
  } else {
    colors_u = av1_count_colors(src_u, src_stride, rows, cols, count_buf);
    colors_v = av1_count_colors(src_v, src_stride, rows, cols, count_buf);
  }

  uint16_t color_cache[2 * PALETTE_MAX_SIZE];
  const int n_cache = av1_get_palette_cache(xd, 1, color_cache);

  colors = colors_u > colors_v ? colors_u : colors_v;
  if (colors > 1 && colors <= 64) {
    int r, c, n, i, j;
    const int max_itr = 50;
    int lb_u, ub_u, val_u;
    int lb_v, ub_v, val_v;
    int *const data = x->palette_buffer->kmeans_data_buf;
    int centroids[2 * PALETTE_MAX_SIZE];

    uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src_u);
    uint16_t *src_v16 = CONVERT_TO_SHORTPTR(src_v);
    if (seq_params->use_highbitdepth) {
      lb_u = src_u16[0];
      ub_u = src_u16[0];
      lb_v = src_v16[0];
      ub_v = src_v16[0];
    } else {
      lb_u = src_u[0];
      ub_u = src_u[0];
      lb_v = src_v[0];
      ub_v = src_v[0];
    }

    for (r = 0; r < rows; ++r) {
      for (c = 0; c < cols; ++c) {
        if (seq_params->use_highbitdepth) {
          val_u = src_u16[r * src_stride + c];
          val_v = src_v16[r * src_stride + c];
          data[(r * cols + c) * 2] = val_u;
          data[(r * cols + c) * 2 + 1] = val_v;
        } else {
          val_u = src_u[r * src_stride + c];
          val_v = src_v[r * src_stride + c];
          data[(r * cols + c) * 2] = val_u;
          data[(r * cols + c) * 2 + 1] = val_v;
        }
        if (val_u < lb_u)
          lb_u = val_u;
        else if (val_u > ub_u)
          ub_u = val_u;
        if (val_v < lb_v)
          lb_v = val_v;
        else if (val_v > ub_v)
          ub_v = val_v;
      }
    }

    for (n = colors > PALETTE_MAX_SIZE ? PALETTE_MAX_SIZE : colors; n >= 2;
         --n) {
      for (i = 0; i < n; ++i) {
        centroids[i * 2] = lb_u + (2 * i + 1) * (ub_u - lb_u) / n / 2;
        centroids[i * 2 + 1] = lb_v + (2 * i + 1) * (ub_v - lb_v) / n / 2;
      }
      av1_k_means(data, centroids, color_map, rows * cols, n, 2, max_itr);
      optimize_palette_colors(color_cache, n_cache, n, 2, centroids);
      // Sort the U channel colors in ascending order.
      for (i = 0; i < 2 * (n - 1); i += 2) {
        int min_idx = i;
        int min_val = centroids[i];
        for (j = i + 2; j < 2 * n; j += 2)
          if (centroids[j] < min_val) min_val = centroids[j], min_idx = j;
        if (min_idx != i) {
          int temp_u = centroids[i], temp_v = centroids[i + 1];
          centroids[i] = centroids[min_idx];
          centroids[i + 1] = centroids[min_idx + 1];
          centroids[min_idx] = temp_u, centroids[min_idx + 1] = temp_v;
        }
      }
      av1_calc_indices(data, centroids, color_map, rows * cols, n, 2);
      extend_palette_color_map(color_map, cols, rows, plane_block_width,
                               plane_block_height);
      pmi->palette_size[1] = n;
      for (i = 1; i < 3; ++i) {
        for (j = 0; j < n; ++j) {
          if (seq_params->use_highbitdepth)
            pmi->palette_colors[i * PALETTE_MAX_SIZE + j] = clip_pixel_highbd(
                (int)centroids[j * 2 + i - 1], seq_params->bit_depth);
          else
            pmi->palette_colors[i * PALETTE_MAX_SIZE + j] =
                clip_pixel((int)centroids[j * 2 + i - 1]);
        }
      }

      super_block_uvrd(cpi, x, &tokenonly_rd_stats, bsize, *best_rd);
      if (tokenonly_rd_stats.rate == INT_MAX) continue;
      this_rate = tokenonly_rd_stats.rate +
                  intra_mode_info_cost_uv(cpi, x, mbmi, bsize, dc_mode_cost);
      this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
      if (this_rd < *best_rd) {
        *best_rd = this_rd;
        *best_mbmi = *mbmi;
        memcpy(best_palette_color_map, color_map,
               plane_block_width * plane_block_height *
                   sizeof(best_palette_color_map[0]));
        *rate = this_rate;
        *distortion = tokenonly_rd_stats.dist;
        *rate_tokenonly = tokenonly_rd_stats.rate;
        *skippable = tokenonly_rd_stats.skip;
      }
    }
  }
  if (best_mbmi->palette_mode_info.palette_size[1] > 0) {
    memcpy(color_map, best_palette_color_map,
           plane_block_width * plane_block_height *
               sizeof(best_palette_color_map[0]));
  }
}

// Run RD calculation with the given chroma intra prediction angle, and return
// the RD cost. Update the best mode info if the RD cost is the best so far.
static int64_t pick_intra_angle_routine_sbuv(
    const AV1_COMP *const cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
    int rate_overhead, int64_t best_rd_in, int *rate, RD_STATS *rd_stats,
    int *best_angle_delta, int64_t *best_rd) {
  MB_MODE_INFO *mbmi = x->e_mbd.mi[0];
  assert(!is_inter_block(mbmi));
  int this_rate;
  int64_t this_rd;
  RD_STATS tokenonly_rd_stats;

  if (!super_block_uvrd(cpi, x, &tokenonly_rd_stats, bsize, best_rd_in))
    return INT64_MAX;
  this_rate = tokenonly_rd_stats.rate +
              intra_mode_info_cost_uv(cpi, x, mbmi, bsize, rate_overhead);
  this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
  if (this_rd < *best_rd) {
    *best_rd = this_rd;
    *best_angle_delta = mbmi->angle_delta[PLANE_TYPE_UV];
    *rate = this_rate;
    rd_stats->rate = tokenonly_rd_stats.rate;
    rd_stats->dist = tokenonly_rd_stats.dist;
    rd_stats->skip = tokenonly_rd_stats.skip;
  }
  return this_rd;
}

// With the given chroma directional intra prediction mode, pick the best angle
// delta. Return true if an RD cost that is smaller than the input one is found.
static int rd_pick_intra_angle_sbuv(const AV1_COMP *const cpi, MACROBLOCK *x,
                                    BLOCK_SIZE bsize, int rate_overhead,
                                    int64_t best_rd, int *rate,
                                    RD_STATS *rd_stats) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  assert(!is_inter_block(mbmi));
  int i, angle_delta, best_angle_delta = 0;
  int64_t this_rd, best_rd_in, rd_cost[2 * (MAX_ANGLE_DELTA + 2)];

  rd_stats->rate = INT_MAX;
  rd_stats->skip = 0;
  rd_stats->dist = INT64_MAX;
  for (i = 0; i < 2 * (MAX_ANGLE_DELTA + 2); ++i) rd_cost[i] = INT64_MAX;

  for (angle_delta = 0; angle_delta <= MAX_ANGLE_DELTA; angle_delta += 2) {
    for (i = 0; i < 2; ++i) {
      best_rd_in = (best_rd == INT64_MAX)
                       ? INT64_MAX
                       : (best_rd + (best_rd >> ((angle_delta == 0) ? 3 : 5)));
      mbmi->angle_delta[PLANE_TYPE_UV] = (1 - 2 * i) * angle_delta;
      this_rd = pick_intra_angle_routine_sbuv(cpi, x, bsize, rate_overhead,
                                              best_rd_in, rate, rd_stats,
                                              &best_angle_delta, &best_rd);
      rd_cost[2 * angle_delta + i] = this_rd;
      if (angle_delta == 0) {
        if (this_rd == INT64_MAX) return 0;
        rd_cost[1] = this_rd;
        break;
      }
    }
  }

  assert(best_rd != INT64_MAX);
  for (angle_delta = 1; angle_delta <= MAX_ANGLE_DELTA; angle_delta += 2) {
    int64_t rd_thresh;
    for (i = 0; i < 2; ++i) {
      int skip_search = 0;
      rd_thresh = best_rd + (best_rd >> 5);
      if (rd_cost[2 * (angle_delta + 1) + i] > rd_thresh &&
          rd_cost[2 * (angle_delta - 1) + i] > rd_thresh)
        skip_search = 1;
      if (!skip_search) {
        mbmi->angle_delta[PLANE_TYPE_UV] = (1 - 2 * i) * angle_delta;
        pick_intra_angle_routine_sbuv(cpi, x, bsize, rate_overhead, best_rd,
                                      rate, rd_stats, &best_angle_delta,
                                      &best_rd);
      }
    }
  }

  mbmi->angle_delta[PLANE_TYPE_UV] = best_angle_delta;
  return rd_stats->rate != INT_MAX;
}

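// Maps the per-plane CFL alpha signs (a = this plane's sign, b = the other
// plane's sign) to the joint-sign symbol used to index the cfl_cost tables;
// the (zero, zero) combination is invalid (see the note below where
// CFL_SIGN_ZERO is skipped), hence the "- 1".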
6075 #define PLANE_SIGN_TO_JOINT_SIGN(plane, a, b) \
6076   (plane == CFL_PRED_U ? a * CFL_SIGNS + b - 1 : b * CFL_SIGNS + a - 1)
cfl_rd_pick_alpha(MACROBLOCK * const x,const AV1_COMP * const cpi,TX_SIZE tx_size,int64_t best_rd)6077 static int cfl_rd_pick_alpha(MACROBLOCK *const x, const AV1_COMP *const cpi,
6078                              TX_SIZE tx_size, int64_t best_rd) {
6079   MACROBLOCKD *const xd = &x->e_mbd;
6080   MB_MODE_INFO *const mbmi = xd->mi[0];
6081 
6082   const BLOCK_SIZE bsize = mbmi->sb_type;
6083 #if CONFIG_DEBUG
6084   assert(is_cfl_allowed(xd));
6085   const int ssx = xd->plane[AOM_PLANE_U].subsampling_x;
6086   const int ssy = xd->plane[AOM_PLANE_U].subsampling_y;
6087   const BLOCK_SIZE plane_bsize = get_plane_block_size(mbmi->sb_type, ssx, ssy);
6088   (void)plane_bsize;
6089   assert(plane_bsize < BLOCK_SIZES_ALL);
6090   if (!xd->lossless[mbmi->segment_id]) {
6091     assert(block_size_wide[plane_bsize] == tx_size_wide[tx_size]);
6092     assert(block_size_high[plane_bsize] == tx_size_high[tx_size]);
6093   }
6094 #endif  // CONFIG_DEBUG
6095 
6096   xd->cfl.use_dc_pred_cache = 1;
6097   const int64_t mode_rd =
6098       RDCOST(x->rdmult,
6099              x->intra_uv_mode_cost[CFL_ALLOWED][mbmi->mode][UV_CFL_PRED], 0);
6100   int64_t best_rd_uv[CFL_JOINT_SIGNS][CFL_PRED_PLANES];
6101   int best_c[CFL_JOINT_SIGNS][CFL_PRED_PLANES];
6102 #if CONFIG_DEBUG
6103   int best_rate_uv[CFL_JOINT_SIGNS][CFL_PRED_PLANES];
6104 #endif  // CONFIG_DEBUG
6105 
6106   for (int plane = 0; plane < CFL_PRED_PLANES; plane++) {
6107     RD_STATS rd_stats;
6108     av1_init_rd_stats(&rd_stats);
6109     for (int joint_sign = 0; joint_sign < CFL_JOINT_SIGNS; joint_sign++) {
6110       best_rd_uv[joint_sign][plane] = INT64_MAX;
6111       best_c[joint_sign][plane] = 0;
6112     }
6113     // Collect RD stats for an alpha value of zero in this plane.
6114     // Skip i == CFL_SIGN_ZERO as (0, 0) is invalid.
6115     for (int i = CFL_SIGN_NEG; i < CFL_SIGNS; i++) {
6116       const int joint_sign = PLANE_SIGN_TO_JOINT_SIGN(plane, CFL_SIGN_ZERO, i);
6117       if (i == CFL_SIGN_NEG) {
6118         mbmi->cfl_alpha_idx = 0;
6119         mbmi->cfl_alpha_signs = joint_sign;
6120         txfm_rd_in_plane(x, cpi, &rd_stats, best_rd, plane + 1, bsize, tx_size,
6121                          cpi->sf.use_fast_coef_costing, FTXS_NONE);
6122         if (rd_stats.rate == INT_MAX) break;
6123       }
6124       const int alpha_rate = x->cfl_cost[joint_sign][plane][0];
6125       best_rd_uv[joint_sign][plane] =
6126           RDCOST(x->rdmult, rd_stats.rate + alpha_rate, rd_stats.dist);
6127 #if CONFIG_DEBUG
6128       best_rate_uv[joint_sign][plane] = rd_stats.rate;
6129 #endif  // CONFIG_DEBUG
6130     }
6131   }
6132 
6133   int best_joint_sign = -1;
6134 
6135   for (int plane = 0; plane < CFL_PRED_PLANES; plane++) {
6136     for (int pn_sign = CFL_SIGN_NEG; pn_sign < CFL_SIGNS; pn_sign++) {
6137       int progress = 0;
6138       for (int c = 0; c < CFL_ALPHABET_SIZE; c++) {
6139         int flag = 0;
6140         RD_STATS rd_stats;
6141         if (c > 2 && progress < c) break;
6142         av1_init_rd_stats(&rd_stats);
6143         for (int i = 0; i < CFL_SIGNS; i++) {
6144           const int joint_sign = PLANE_SIGN_TO_JOINT_SIGN(plane, pn_sign, i);
6145           if (i == 0) {
6146             mbmi->cfl_alpha_idx = (c << CFL_ALPHABET_SIZE_LOG2) + c;
6147             mbmi->cfl_alpha_signs = joint_sign;
6148             txfm_rd_in_plane(x, cpi, &rd_stats, best_rd, plane + 1, bsize,
6149                              tx_size, cpi->sf.use_fast_coef_costing, FTXS_NONE);
6150             if (rd_stats.rate == INT_MAX) break;
6151           }
6152           const int alpha_rate = x->cfl_cost[joint_sign][plane][c];
6153           int64_t this_rd =
6154               RDCOST(x->rdmult, rd_stats.rate + alpha_rate, rd_stats.dist);
6155           if (this_rd >= best_rd_uv[joint_sign][plane]) continue;
6156           best_rd_uv[joint_sign][plane] = this_rd;
6157           best_c[joint_sign][plane] = c;
6158 #if CONFIG_DEBUG
6159           best_rate_uv[joint_sign][plane] = rd_stats.rate;
6160 #endif  // CONFIG_DEBUG
6161           flag = 2;
6162           if (best_rd_uv[joint_sign][!plane] == INT64_MAX) continue;
6163           this_rd += mode_rd + best_rd_uv[joint_sign][!plane];
6164           if (this_rd >= best_rd) continue;
6165           best_rd = this_rd;
6166           best_joint_sign = joint_sign;
6167         }
6168         progress += flag;
6169       }
6170     }
6171   }
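  // Note on the pruning above: 'flag' is set to 2 whenever an alpha
  // magnitude c improves some joint sign, and 'progress' accumulates it.
  // The early break (c > 2 && progress < c) therefore keeps scanning larger
  // magnitudes only while roughly every other magnitude tried so far has
  // produced an RD improvement.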
6172 
6173   int best_rate_overhead = INT_MAX;
6174   int ind = 0;
6175   if (best_joint_sign >= 0) {
6176     const int u = best_c[best_joint_sign][CFL_PRED_U];
6177     const int v = best_c[best_joint_sign][CFL_PRED_V];
6178     ind = (u << CFL_ALPHABET_SIZE_LOG2) + v;
6179     best_rate_overhead = x->cfl_cost[best_joint_sign][CFL_PRED_U][u] +
6180                          x->cfl_cost[best_joint_sign][CFL_PRED_V][v];
6181 #if CONFIG_DEBUG
6182     xd->cfl.rate = x->intra_uv_mode_cost[CFL_ALLOWED][mbmi->mode][UV_CFL_PRED] +
6183                    best_rate_overhead +
6184                    best_rate_uv[best_joint_sign][CFL_PRED_U] +
6185                    best_rate_uv[best_joint_sign][CFL_PRED_V];
6186 #endif  // CONFIG_DEBUG
6187   } else {
6188     best_joint_sign = 0;
6189   }
6190 
6191   mbmi->cfl_alpha_idx = ind;
6192   mbmi->cfl_alpha_signs = best_joint_sign;
6193   xd->cfl.use_dc_pred_cache = 0;
6194   xd->cfl.dc_pred_is_cached[0] = 0;
6195   xd->cfl.dc_pred_is_cached[1] = 0;
6196   return best_rate_overhead;
6197 }
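// Worked example of the alpha packing used in cfl_rd_pick_alpha: assuming
// CFL_ALPHABET_SIZE_LOG2 == 4 (its usual value), a best U index of 3 and a
// best V index of 5 are stored as
//   ind = (3 << 4) + 5 == 53,
// recoverable as u = ind >> 4 and v = ind & 15.  The signs travel separately
// in cfl_alpha_signs as one of the CFL_JOINT_SIGNS (U, V) combinations;
// (ZERO, ZERO) is excluded, which is why the zero-alpha loop above starts
// at CFL_SIGN_NEG.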
6198 
6199 static void init_sbuv_mode(MB_MODE_INFO *const mbmi) {
6200   mbmi->uv_mode = UV_DC_PRED;
6201   mbmi->palette_mode_info.palette_size[1] = 0;
6202 }
6203 
6204 static int64_t rd_pick_intra_sbuv_mode(const AV1_COMP *const cpi, MACROBLOCK *x,
6205                                        int *rate, int *rate_tokenonly,
6206                                        int64_t *distortion, int *skippable,
6207                                        BLOCK_SIZE bsize, TX_SIZE max_tx_size) {
6208   MACROBLOCKD *xd = &x->e_mbd;
6209   MB_MODE_INFO *mbmi = xd->mi[0];
6210   assert(!is_inter_block(mbmi));
6211   MB_MODE_INFO best_mbmi = *mbmi;
6212   int64_t best_rd = INT64_MAX, this_rd;
6213 
6214   for (int mode_idx = 0; mode_idx < UV_INTRA_MODES; ++mode_idx) {
6215     int this_rate;
6216     RD_STATS tokenonly_rd_stats;
6217     UV_PREDICTION_MODE mode = uv_rd_search_mode_order[mode_idx];
6218     const int is_directional_mode = av1_is_directional_mode(get_uv_mode(mode));
6219     if (!(cpi->sf.intra_uv_mode_mask[txsize_sqr_up_map[max_tx_size]] &
6220           (1 << mode)))
6221       continue;
6222 
6223     mbmi->uv_mode = mode;
6224     int cfl_alpha_rate = 0;
6225     if (mode == UV_CFL_PRED) {
6226       if (!is_cfl_allowed(xd)) continue;
6227       assert(!is_directional_mode);
6228       const TX_SIZE uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);
6229       cfl_alpha_rate = cfl_rd_pick_alpha(x, cpi, uv_tx_size, best_rd);
6230       if (cfl_alpha_rate == INT_MAX) continue;
6231     }
6232     mbmi->angle_delta[PLANE_TYPE_UV] = 0;
6233     if (is_directional_mode && av1_use_angle_delta(mbmi->sb_type)) {
6234       const int rate_overhead =
6235           x->intra_uv_mode_cost[is_cfl_allowed(xd)][mbmi->mode][mode];
6236       if (!rd_pick_intra_angle_sbuv(cpi, x, bsize, rate_overhead, best_rd,
6237                                     &this_rate, &tokenonly_rd_stats))
6238         continue;
6239     } else {
6240       if (!super_block_uvrd(cpi, x, &tokenonly_rd_stats, bsize, best_rd)) {
6241         continue;
6242       }
6243     }
6244     const int mode_cost =
6245         x->intra_uv_mode_cost[is_cfl_allowed(xd)][mbmi->mode][mode] +
6246         cfl_alpha_rate;
6247     this_rate = tokenonly_rd_stats.rate +
6248                 intra_mode_info_cost_uv(cpi, x, mbmi, bsize, mode_cost);
6249     if (mode == UV_CFL_PRED) {
6250       assert(is_cfl_allowed(xd));
6251 #if CONFIG_DEBUG
6252       if (!xd->lossless[mbmi->segment_id])
6253         assert(xd->cfl.rate == tokenonly_rd_stats.rate + mode_cost);
6254 #endif  // CONFIG_DEBUG
6255     }
6256     this_rd = RDCOST(x->rdmult, this_rate, tokenonly_rd_stats.dist);
6257 
6258     if (this_rd < best_rd) {
6259       best_mbmi = *mbmi;
6260       best_rd = this_rd;
6261       *rate = this_rate;
6262       *rate_tokenonly = tokenonly_rd_stats.rate;
6263       *distortion = tokenonly_rd_stats.dist;
6264       *skippable = tokenonly_rd_stats.skip;
6265     }
6266   }
6267 
6268   const int try_palette =
6269       av1_allow_palette(cpi->common.allow_screen_content_tools, mbmi->sb_type);
6270   if (try_palette) {
6271     uint8_t *best_palette_color_map = x->palette_buffer->best_palette_color_map;
6272     rd_pick_palette_intra_sbuv(
6273         cpi, x,
6274         x->intra_uv_mode_cost[is_cfl_allowed(xd)][mbmi->mode][UV_DC_PRED],
6275         best_palette_color_map, &best_mbmi, &best_rd, rate, rate_tokenonly,
6276         distortion, skippable);
6277   }
6278 
6279   *mbmi = best_mbmi;
6280   // Make sure we actually chose a mode
6281   assert(best_rd < INT64_MAX);
6282   return best_rd;
6283 }
6284 
6285 static void choose_intra_uv_mode(const AV1_COMP *const cpi, MACROBLOCK *const x,
6286                                  BLOCK_SIZE bsize, TX_SIZE max_tx_size,
6287                                  int *rate_uv, int *rate_uv_tokenonly,
6288                                  int64_t *dist_uv, int *skip_uv,
6289                                  UV_PREDICTION_MODE *mode_uv) {
6290   const AV1_COMMON *const cm = &cpi->common;
6291   MACROBLOCKD *xd = &x->e_mbd;
6292   MB_MODE_INFO *mbmi = xd->mi[0];
6293   const int mi_row = -xd->mb_to_top_edge >> (3 + MI_SIZE_LOG2);
6294   const int mi_col = -xd->mb_to_left_edge >> (3 + MI_SIZE_LOG2);
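  // The mb_to_*_edge offsets are kept in eighth-pel units, so the shift
  // above first drops the 3 subpel bits and then converts pixels to mi
  // units.  For example (with MI_SIZE == 4, MI_SIZE_LOG2 == 2), a block at
  // mi_row 16 has mb_to_top_edge == -(16 * 4 * 8) == -512, and
  // 512 >> 5 == 16 recovers the row.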
6295   // Start from DC_PRED, and return early with zero rate/distortion when
6296   // chroma RD is skipped for this block.
6297   init_sbuv_mode(mbmi);
6298   if (x->skip_chroma_rd) {
6299     *rate_uv = 0;
6300     *rate_uv_tokenonly = 0;
6301     *dist_uv = 0;
6302     *skip_uv = 1;
6303     *mode_uv = UV_DC_PRED;
6304     return;
6305   }
6306   xd->cfl.is_chroma_reference =
6307       is_chroma_reference(mi_row, mi_col, bsize, cm->seq_params.subsampling_x,
6308                           cm->seq_params.subsampling_y);
6309   bsize = scale_chroma_bsize(bsize, xd->plane[AOM_PLANE_U].subsampling_x,
6310                              xd->plane[AOM_PLANE_U].subsampling_y);
6311   // Only store reconstructed luma when there's chroma RDO. When there's no
6312   // chroma RDO, the reconstructed luma will be stored in encode_superblock().
6313   xd->cfl.store_y = store_cfl_required_rdo(cm, x);
6314   if (xd->cfl.store_y) {
6315     // Re-encode the luma block to restore its reconstruction for CfL.
6316     av1_encode_intra_block_plane(cpi, x, mbmi->sb_type, AOM_PLANE_Y,
6317                                  cpi->optimize_seg_arr[mbmi->segment_id],
6318                                  mi_row, mi_col);
6319     xd->cfl.store_y = 0;
6320   }
6321   rd_pick_intra_sbuv_mode(cpi, x, rate_uv, rate_uv_tokenonly, dist_uv, skip_uv,
6322                           bsize, max_tx_size);
6323   *mode_uv = mbmi->uv_mode;
6324 }
6325 
6326 static int cost_mv_ref(const MACROBLOCK *const x, PREDICTION_MODE mode,
6327                        int16_t mode_context) {
6328   if (is_inter_compound_mode(mode)) {
6329     return x
6330         ->inter_compound_mode_cost[mode_context][INTER_COMPOUND_OFFSET(mode)];
6331   }
6332 
6333   int mode_cost = 0;
6334   int16_t mode_ctx = mode_context & NEWMV_CTX_MASK;
6335 
6336   assert(is_inter_mode(mode));
6337 
6338   if (mode == NEWMV) {
6339     mode_cost = x->newmv_mode_cost[mode_ctx][0];
6340     return mode_cost;
6341   } else {
6342     mode_cost = x->newmv_mode_cost[mode_ctx][1];
6343     mode_ctx = (mode_context >> GLOBALMV_OFFSET) & GLOBALMV_CTX_MASK;
6344 
6345     if (mode == GLOBALMV) {
6346       mode_cost += x->zeromv_mode_cost[mode_ctx][0];
6347       return mode_cost;
6348     } else {
6349       mode_cost += x->zeromv_mode_cost[mode_ctx][1];
6350       mode_ctx = (mode_context >> REFMV_OFFSET) & REFMV_CTX_MASK;
6351       mode_cost += x->refmv_mode_cost[mode_ctx][mode != NEARESTMV];
6352       return mode_cost;
6353     }
6354   }
6355 }
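// mode_context above is a packed bit-field: the low bits give the NEWMV
// context, the bits from GLOBALMV_OFFSET the GLOBALMV context, and the bits
// from REFMV_OFFSET the NEAREST/NEAR context.  The cost mirrors the
// bitstream's prefix signalling, e.g.
//   cost(NEARMV) = newmv_mode_cost[c0][1]    // not NEWMV
//                + zeromv_mode_cost[c1][1]   // not GLOBALMV
//                + refmv_mode_cost[c2][1];   // NEAR rather than NEAREST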
6356 
6357 static int get_interinter_compound_mask_rate(const MACROBLOCK *const x,
6358                                              const MB_MODE_INFO *const mbmi) {
6359   switch (mbmi->interinter_comp.type) {
6360     case COMPOUND_AVERAGE: return 0;
6361     case COMPOUND_WEDGE:
6362       return get_interinter_wedge_bits(mbmi->sb_type) > 0
6363                  ? av1_cost_literal(1) +
6364                        x->wedge_idx_cost[mbmi->sb_type]
6365                                         [mbmi->interinter_comp.wedge_index]
6366                  : 0;
6367     case COMPOUND_DIFFWTD: return av1_cost_literal(1);
6368     default: assert(0); return 0;
6369   }
6370 }
6371 
6372 typedef struct {
6373   int eobs;
6374   int brate;
6375   int byrate;
6376   int64_t bdist;
6377   int64_t bsse;
6378   int64_t brdcost;
6379   int_mv mvs[2];
6380   int_mv pred_mv[2];
6381   int_mv ref_mv[2];
6382 
6383   ENTROPY_CONTEXT ta[2];
6384   ENTROPY_CONTEXT tl[2];
6385 } SEG_RDSTAT;
6386 
6387 typedef struct {
6388   int_mv *ref_mv[2];
6389   int_mv mvp;
6390 
6391   int64_t segment_rd;
6392   int r;
6393   int64_t d;
6394   int64_t sse;
6395   int segment_yrate;
6396   PREDICTION_MODE modes[4];
6397   SEG_RDSTAT rdstat[4][INTER_MODES + INTER_COMPOUND_MODES];
6398   int mvthresh;
6399 } BEST_SEG_INFO;
6400 
6401 static INLINE int mv_check_bounds(const MvLimits *mv_limits, const MV *mv) {
6402   return (mv->row >> 3) < mv_limits->row_min ||
6403          (mv->row >> 3) > mv_limits->row_max ||
6404          (mv->col >> 3) < mv_limits->col_min ||
6405          (mv->col >> 3) > mv_limits->col_max;
6406 }
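// MVs are stored in eighth-pel units while MvLimits are full-pel, hence the
// >> 3 above.  For instance an MV of (33, -17) eighth-pels maps to full-pel
// (4, -3) (with the usual arithmetic shift, rounding toward negative
// infinity), and it is that full-pel position which must lie inside the
// search limits.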
6407 
6408 static INLINE PREDICTION_MODE get_single_mode(PREDICTION_MODE this_mode,
6409                                               int ref_idx, int is_comp_pred) {
6410   PREDICTION_MODE single_mode;
6411   if (is_comp_pred) {
6412     single_mode =
6413         ref_idx ? compound_ref1_mode(this_mode) : compound_ref0_mode(this_mode);
6414   } else {
6415     single_mode = this_mode;
6416   }
6417   return single_mode;
6418 }
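// Example: the compound mode NEAREST_NEWMV decomposes as
//   get_single_mode(NEAREST_NEWMV, /*ref_idx=*/0, 1) == NEARESTMV
//   get_single_mode(NEAREST_NEWMV, /*ref_idx=*/1, 1) == NEWMV
// so each reference can reuse the single-reference search machinery.  For
// single prediction the mode is passed through unchanged.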
6419 
6420 static void joint_motion_search(const AV1_COMP *cpi, MACROBLOCK *x,
6421                                 BLOCK_SIZE bsize, int_mv *cur_mv, int mi_row,
6422                                 int mi_col, int_mv *ref_mv_sub8x8[2],
6423                                 const uint8_t *mask, int mask_stride,
6424                                 int *rate_mv, const int block) {
6425   const AV1_COMMON *const cm = &cpi->common;
6426   const int num_planes = av1_num_planes(cm);
6427   const int pw = block_size_wide[bsize];
6428   const int ph = block_size_high[bsize];
6429   const int plane = 0;
6430   MACROBLOCKD *xd = &x->e_mbd;
6431   MB_MODE_INFO *mbmi = xd->mi[0];
6432   // This function should only ever be called for compound modes
6433   assert(has_second_ref(mbmi));
6434   const int_mv init_mv[2] = { cur_mv[0], cur_mv[1] };
6435   const int refs[2] = { mbmi->ref_frame[0], mbmi->ref_frame[1] };
6436   int_mv ref_mv[2];
6437   int ite, ref;
6438   // ic and ir are the 4x4 coordinates of the sub8x8 at index "block"
6439   const int ic = block & 1;
6440   const int ir = (block - ic) >> 1;
6441   struct macroblockd_plane *const pd = &xd->plane[0];
6442   const int p_col = ((mi_col * MI_SIZE) >> pd->subsampling_x) + 4 * ic;
6443   const int p_row = ((mi_row * MI_SIZE) >> pd->subsampling_y) + 4 * ir;
6444 
6445   ConvolveParams conv_params = get_conv_params(0, plane, xd->bd);
6446   conv_params.use_jnt_comp_avg = 0;
6447   WarpTypesAllowed warp_types[2];
6448   for (ref = 0; ref < 2; ++ref) {
6449     const WarpedMotionParams *const wm =
6450         &xd->global_motion[xd->mi[0]->ref_frame[ref]];
6451     const int is_global = is_global_mv_block(xd->mi[0], wm->wmtype);
6452     warp_types[ref].global_warp_allowed = is_global;
6453     warp_types[ref].local_warp_allowed = mbmi->motion_mode == WARPED_CAUSAL;
6454   }
6455 
6456   // Do joint motion search in compound mode to get more accurate mv.
6457   struct buf_2d backup_yv12[2][MAX_MB_PLANE];
6458   int last_besterr[2] = { INT_MAX, INT_MAX };
6459   const YV12_BUFFER_CONFIG *const scaled_ref_frame[2] = {
6460     av1_get_scaled_ref_frame(cpi, refs[0]),
6461     av1_get_scaled_ref_frame(cpi, refs[1])
6462   };
6463 
6464   // Prediction buffer from second frame.
6465   DECLARE_ALIGNED(16, uint8_t, second_pred16[MAX_SB_SQUARE * sizeof(uint16_t)]);
6466   uint8_t *second_pred = get_buf_by_bd(xd, second_pred16);
6467   (void)ref_mv_sub8x8;
6468 
6469   const int have_nearmv = have_nearmv_in_inter_mode(mbmi->mode);
6470   const int ref_mv_idx = mbmi->ref_mv_idx + (have_nearmv ? 1 : 0);
6471   MV *const best_mv = &x->best_mv.as_mv;
6472   const int search_range = SEARCH_RANGE_8P;
6473   const int sadpb = x->sadperbit16;
6474   // Alternate the joint search between the two reference frames and break
6475   // out of the loop once it stops finding a better mv.
6476   for (ite = 0; ite < 4; ite++) {
6477     struct buf_2d ref_yv12[2];
6478     int bestsme = INT_MAX;
6479     MvLimits tmp_mv_limits = x->mv_limits;
6480     int id = ite % 2;  // Even iterations search in the first reference frame,
6481                        // odd iterations search in the second. The predictor
6482                        // found for the 'other' reference frame is factored in.
6483     if (ite >= 2 && cur_mv[!id].as_int == init_mv[!id].as_int) {
6484       if (cur_mv[id].as_int == init_mv[id].as_int) {
6485         break;
6486       } else {
6487         int_mv cur_int_mv, init_int_mv;
6488         cur_int_mv.as_mv.col = cur_mv[id].as_mv.col >> 3;
6489         cur_int_mv.as_mv.row = cur_mv[id].as_mv.row >> 3;
6490         init_int_mv.as_mv.row = init_mv[id].as_mv.row >> 3;
6491         init_int_mv.as_mv.col = init_mv[id].as_mv.col >> 3;
6492         if (cur_int_mv.as_int == init_int_mv.as_int) {
6493           break;
6494         }
6495       }
6496     }
6497     for (ref = 0; ref < 2; ++ref) {
6498       ref_mv[ref] = av1_get_ref_mv(x, ref);
6499       // Swap out the reference frame for a version that's been scaled to
6500       // match the resolution of the current frame, allowing the existing
6501       // motion search code to be used without additional modifications.
6502       if (scaled_ref_frame[ref]) {
6503         int i;
6504         for (i = 0; i < num_planes; i++)
6505           backup_yv12[ref][i] = xd->plane[i].pre[ref];
6506         av1_setup_pre_planes(xd, ref, scaled_ref_frame[ref], mi_row, mi_col,
6507                              NULL, num_planes);
6508       }
6509     }
6510 
6511     assert(IMPLIES(scaled_ref_frame[0] != NULL,
6512                    cm->width == scaled_ref_frame[0]->y_crop_width &&
6513                        cm->height == scaled_ref_frame[0]->y_crop_height));
6514     assert(IMPLIES(scaled_ref_frame[1] != NULL,
6515                    cm->width == scaled_ref_frame[1]->y_crop_width &&
6516                        cm->height == scaled_ref_frame[1]->y_crop_height));
6517 
6518     // Initialize based on (possibly scaled) prediction buffers.
6519     ref_yv12[0] = xd->plane[plane].pre[0];
6520     ref_yv12[1] = xd->plane[plane].pre[1];
6521 
6522     // Get the prediction block from the 'other' reference frame.
6523     const InterpFilters interp_filters = EIGHTTAP_REGULAR;
6524 
6525     // Since we have scaled the reference frames to match the size of the
6526     // current frame we must use a unit scaling factor during mode selection.
6527     av1_build_inter_predictor(ref_yv12[!id].buf, ref_yv12[!id].stride,
6528                               second_pred, pw, &cur_mv[!id].as_mv,
6529                               &cm->sf_identity, pw, ph, &conv_params,
6530                               interp_filters, &warp_types[!id], p_col, p_row,
6531                               plane, !id, MV_PRECISION_Q3, mi_col * MI_SIZE,
6532                               mi_row * MI_SIZE, xd, cm->allow_warped_motion);
6533 
6534     const int order_idx = id != 0;
6535     av1_jnt_comp_weight_assign(cm, mbmi, order_idx, &xd->jcp_param.fwd_offset,
6536                                &xd->jcp_param.bck_offset,
6537                                &xd->jcp_param.use_jnt_comp_avg, 1);
6538 
6539     // Do full-pixel compound motion search on the current reference frame.
6540     if (id) xd->plane[plane].pre[0] = ref_yv12[id];
6541     av1_set_mv_search_range(&x->mv_limits, &ref_mv[id].as_mv);
6542 
6543     // Use the mv result from the single mode as mv predictor.
6544     *best_mv = cur_mv[id].as_mv;
6545 
6546     best_mv->col >>= 3;
6547     best_mv->row >>= 3;
6548 
6549     av1_set_mvcost(x, id, ref_mv_idx);
6550 
6551     // Small-range full-pixel motion search.
6552     bestsme = av1_refining_search_8p_c(x, sadpb, search_range,
6553                                        &cpi->fn_ptr[bsize], mask, mask_stride,
6554                                        id, &ref_mv[id].as_mv, second_pred);
6555     if (bestsme < INT_MAX) {
6556       if (mask)
6557         bestsme = av1_get_mvpred_mask_var(x, best_mv, &ref_mv[id].as_mv,
6558                                           second_pred, mask, mask_stride, id,
6559                                           &cpi->fn_ptr[bsize], 1);
6560       else
6561         bestsme = av1_get_mvpred_av_var(x, best_mv, &ref_mv[id].as_mv,
6562                                         second_pred, &cpi->fn_ptr[bsize], 1);
6563     }
6564 
6565     x->mv_limits = tmp_mv_limits;
6566 
6567     // Restore the pointer to the first (possibly scaled) prediction buffer.
6568     if (id) xd->plane[plane].pre[0] = ref_yv12[0];
6569 
6570     for (ref = 0; ref < 2; ++ref) {
6571       if (scaled_ref_frame[ref]) {
6572         // Swap back the original buffers for subpel motion search.
6573         for (int i = 0; i < num_planes; i++) {
6574           xd->plane[i].pre[ref] = backup_yv12[ref][i];
6575         }
6576         // Re-initialize based on unscaled prediction buffers.
6577         ref_yv12[ref] = xd->plane[plane].pre[ref];
6578       }
6579     }
6580 
6581     // Do sub-pixel compound motion search on the current reference frame.
6582     if (id) xd->plane[plane].pre[0] = ref_yv12[id];
6583 
6584     if (cpi->common.cur_frame_force_integer_mv) {
6585       x->best_mv.as_mv.row *= 8;
6586       x->best_mv.as_mv.col *= 8;
6587     }
6588     if (bestsme < INT_MAX && cpi->common.cur_frame_force_integer_mv == 0) {
6589       int dis; /* TODO: use dis in distortion calculation later. */
6590       unsigned int sse;
6591       bestsme = cpi->find_fractional_mv_step(
6592           x, cm, mi_row, mi_col, &ref_mv[id].as_mv,
6593           cpi->common.allow_high_precision_mv, x->errorperbit,
6594           &cpi->fn_ptr[bsize], 0, cpi->sf.mv.subpel_iters_per_step, NULL,
6595           x->nmvjointcost, x->mvcost, &dis, &sse, second_pred, mask,
6596           mask_stride, id, pw, ph, cpi->sf.use_accurate_subpel_search);
6597     }
6598 
6599     // Restore the pointer to the first prediction buffer.
6600     if (id) xd->plane[plane].pre[0] = ref_yv12[0];
6601     if (bestsme < last_besterr[id]) {
6602       cur_mv[id].as_mv = *best_mv;
6603       last_besterr[id] = bestsme;
6604     } else {
6605       break;
6606     }
6607   }
6608 
6609   *rate_mv = 0;
6610 
6611   for (ref = 0; ref < 2; ++ref) {
6612     av1_set_mvcost(x, ref, ref_mv_idx);
6613     const int_mv curr_ref_mv = av1_get_ref_mv(x, ref);
6614     *rate_mv += av1_mv_bit_cost(&cur_mv[ref].as_mv, &curr_ref_mv.as_mv,
6615                                 x->nmvjointcost, x->mvcost, MV_COST_WEIGHT);
6616   }
6617 }
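// In short, joint_motion_search alternates references (0, 1, 0, 1) for up
// to four passes: each pass freezes the other reference's MV, renders it
// into second_pred, and refines the current MV against the compound
// prediction error.  It exits early when a pass fails to beat last_besterr,
// or (from the third pass on) when both MVs have settled at their initial
// full-pel positions.  The final MV rate is then re-costed against each
// reference's own predictor.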
6618 
6619 static void estimate_ref_frame_costs(
6620     const AV1_COMMON *cm, const MACROBLOCKD *xd, const MACROBLOCK *x,
6621     int segment_id, unsigned int *ref_costs_single,
6622     unsigned int (*ref_costs_comp)[REF_FRAMES]) {
6623   int seg_ref_active =
6624       segfeature_active(&cm->seg, segment_id, SEG_LVL_REF_FRAME);
6625   if (seg_ref_active) {
6626     memset(ref_costs_single, 0, REF_FRAMES * sizeof(*ref_costs_single));
6627     int ref_frame;
6628     for (ref_frame = 0; ref_frame < REF_FRAMES; ++ref_frame)
6629       memset(ref_costs_comp[ref_frame], 0,
6630              REF_FRAMES * sizeof((*ref_costs_comp)[0]));
6631   } else {
6632     int intra_inter_ctx = av1_get_intra_inter_context(xd);
6633     ref_costs_single[INTRA_FRAME] = x->intra_inter_cost[intra_inter_ctx][0];
6634     unsigned int base_cost = x->intra_inter_cost[intra_inter_ctx][1];
6635 
6636     for (int i = LAST_FRAME; i <= ALTREF_FRAME; ++i)
6637       ref_costs_single[i] = base_cost;
6638 
6639     const int ctx_p1 = av1_get_pred_context_single_ref_p1(xd);
6640     const int ctx_p2 = av1_get_pred_context_single_ref_p2(xd);
6641     const int ctx_p3 = av1_get_pred_context_single_ref_p3(xd);
6642     const int ctx_p4 = av1_get_pred_context_single_ref_p4(xd);
6643     const int ctx_p5 = av1_get_pred_context_single_ref_p5(xd);
6644     const int ctx_p6 = av1_get_pred_context_single_ref_p6(xd);
6645 
6646     // Determine cost of a single ref frame, where frame types are represented
6647     // by a tree:
6648     // Level 0: add cost whether this ref is a forward or backward ref
6649     ref_costs_single[LAST_FRAME] += x->single_ref_cost[ctx_p1][0][0];
6650     ref_costs_single[LAST2_FRAME] += x->single_ref_cost[ctx_p1][0][0];
6651     ref_costs_single[LAST3_FRAME] += x->single_ref_cost[ctx_p1][0][0];
6652     ref_costs_single[GOLDEN_FRAME] += x->single_ref_cost[ctx_p1][0][0];
6653     ref_costs_single[BWDREF_FRAME] += x->single_ref_cost[ctx_p1][0][1];
6654     ref_costs_single[ALTREF2_FRAME] += x->single_ref_cost[ctx_p1][0][1];
6655     ref_costs_single[ALTREF_FRAME] += x->single_ref_cost[ctx_p1][0][1];
6656 
6657     // Level 1: if this ref is forward ref,
6658     // add cost whether it is last/last2 or last3/golden
6659     ref_costs_single[LAST_FRAME] += x->single_ref_cost[ctx_p3][2][0];
6660     ref_costs_single[LAST2_FRAME] += x->single_ref_cost[ctx_p3][2][0];
6661     ref_costs_single[LAST3_FRAME] += x->single_ref_cost[ctx_p3][2][1];
6662     ref_costs_single[GOLDEN_FRAME] += x->single_ref_cost[ctx_p3][2][1];
6663 
6664     // Level 1: if this ref is backward ref
6665     // then add cost whether this ref is altref or backward ref
6666     ref_costs_single[BWDREF_FRAME] += x->single_ref_cost[ctx_p2][1][0];
6667     ref_costs_single[ALTREF2_FRAME] += x->single_ref_cost[ctx_p2][1][0];
6668     ref_costs_single[ALTREF_FRAME] += x->single_ref_cost[ctx_p2][1][1];
6669 
6670     // Level 2: further add cost whether this ref is last or last2
6671     ref_costs_single[LAST_FRAME] += x->single_ref_cost[ctx_p4][3][0];
6672     ref_costs_single[LAST2_FRAME] += x->single_ref_cost[ctx_p4][3][1];
6673 
6674     // Level 2: last3 or golden
6675     ref_costs_single[LAST3_FRAME] += x->single_ref_cost[ctx_p5][4][0];
6676     ref_costs_single[GOLDEN_FRAME] += x->single_ref_cost[ctx_p5][4][1];
6677 
6678     // Level 2: bwdref or altref2
6679     ref_costs_single[BWDREF_FRAME] += x->single_ref_cost[ctx_p6][5][0];
6680     ref_costs_single[ALTREF2_FRAME] += x->single_ref_cost[ctx_p6][5][1];
6681 
6682     if (cm->reference_mode != SINGLE_REFERENCE) {
6683       // Similar to single ref, determine cost of compound ref frames.
6684       // cost_compound_refs = cost_first_ref + cost_second_ref
6685       const int bwdref_comp_ctx_p = av1_get_pred_context_comp_bwdref_p(xd);
6686       const int bwdref_comp_ctx_p1 = av1_get_pred_context_comp_bwdref_p1(xd);
6687       const int ref_comp_ctx_p = av1_get_pred_context_comp_ref_p(xd);
6688       const int ref_comp_ctx_p1 = av1_get_pred_context_comp_ref_p1(xd);
6689       const int ref_comp_ctx_p2 = av1_get_pred_context_comp_ref_p2(xd);
6690 
6691       const int comp_ref_type_ctx = av1_get_comp_reference_type_context(xd);
6692       unsigned int ref_bicomp_costs[REF_FRAMES] = { 0 };
6693 
6694       ref_bicomp_costs[LAST_FRAME] = ref_bicomp_costs[LAST2_FRAME] =
6695           ref_bicomp_costs[LAST3_FRAME] = ref_bicomp_costs[GOLDEN_FRAME] =
6696               base_cost + x->comp_ref_type_cost[comp_ref_type_ctx][1];
6697       ref_bicomp_costs[BWDREF_FRAME] = ref_bicomp_costs[ALTREF2_FRAME] = 0;
6698       ref_bicomp_costs[ALTREF_FRAME] = 0;
6699 
6700       // cost of first ref frame
6701       ref_bicomp_costs[LAST_FRAME] += x->comp_ref_cost[ref_comp_ctx_p][0][0];
6702       ref_bicomp_costs[LAST2_FRAME] += x->comp_ref_cost[ref_comp_ctx_p][0][0];
6703       ref_bicomp_costs[LAST3_FRAME] += x->comp_ref_cost[ref_comp_ctx_p][0][1];
6704       ref_bicomp_costs[GOLDEN_FRAME] += x->comp_ref_cost[ref_comp_ctx_p][0][1];
6705 
6706       ref_bicomp_costs[LAST_FRAME] += x->comp_ref_cost[ref_comp_ctx_p1][1][0];
6707       ref_bicomp_costs[LAST2_FRAME] += x->comp_ref_cost[ref_comp_ctx_p1][1][1];
6708 
6709       ref_bicomp_costs[LAST3_FRAME] += x->comp_ref_cost[ref_comp_ctx_p2][2][0];
6710       ref_bicomp_costs[GOLDEN_FRAME] += x->comp_ref_cost[ref_comp_ctx_p2][2][1];
6711 
6712       // cost of second ref frame
6713       ref_bicomp_costs[BWDREF_FRAME] +=
6714           x->comp_bwdref_cost[bwdref_comp_ctx_p][0][0];
6715       ref_bicomp_costs[ALTREF2_FRAME] +=
6716           x->comp_bwdref_cost[bwdref_comp_ctx_p][0][0];
6717       ref_bicomp_costs[ALTREF_FRAME] +=
6718           x->comp_bwdref_cost[bwdref_comp_ctx_p][0][1];
6719 
6720       ref_bicomp_costs[BWDREF_FRAME] +=
6721           x->comp_bwdref_cost[bwdref_comp_ctx_p1][1][0];
6722       ref_bicomp_costs[ALTREF2_FRAME] +=
6723           x->comp_bwdref_cost[bwdref_comp_ctx_p1][1][1];
6724 
6725       // cost: if one ref frame is forward ref, the other ref is backward ref
6726       int ref0, ref1;
6727       for (ref0 = LAST_FRAME; ref0 <= GOLDEN_FRAME; ++ref0) {
6728         for (ref1 = BWDREF_FRAME; ref1 <= ALTREF_FRAME; ++ref1) {
6729           ref_costs_comp[ref0][ref1] =
6730               ref_bicomp_costs[ref0] + ref_bicomp_costs[ref1];
6731         }
6732       }
6733 
6734       // cost: if both ref frames are the same side.
6735       const int uni_comp_ref_ctx_p = av1_get_pred_context_uni_comp_ref_p(xd);
6736       const int uni_comp_ref_ctx_p1 = av1_get_pred_context_uni_comp_ref_p1(xd);
6737       const int uni_comp_ref_ctx_p2 = av1_get_pred_context_uni_comp_ref_p2(xd);
6738       ref_costs_comp[LAST_FRAME][LAST2_FRAME] =
6739           base_cost + x->comp_ref_type_cost[comp_ref_type_ctx][0] +
6740           x->uni_comp_ref_cost[uni_comp_ref_ctx_p][0][0] +
6741           x->uni_comp_ref_cost[uni_comp_ref_ctx_p1][1][0];
6742       ref_costs_comp[LAST_FRAME][LAST3_FRAME] =
6743           base_cost + x->comp_ref_type_cost[comp_ref_type_ctx][0] +
6744           x->uni_comp_ref_cost[uni_comp_ref_ctx_p][0][0] +
6745           x->uni_comp_ref_cost[uni_comp_ref_ctx_p1][1][1] +
6746           x->uni_comp_ref_cost[uni_comp_ref_ctx_p2][2][0];
6747       ref_costs_comp[LAST_FRAME][GOLDEN_FRAME] =
6748           base_cost + x->comp_ref_type_cost[comp_ref_type_ctx][0] +
6749           x->uni_comp_ref_cost[uni_comp_ref_ctx_p][0][0] +
6750           x->uni_comp_ref_cost[uni_comp_ref_ctx_p1][1][1] +
6751           x->uni_comp_ref_cost[uni_comp_ref_ctx_p2][2][1];
6752       ref_costs_comp[BWDREF_FRAME][ALTREF_FRAME] =
6753           base_cost + x->comp_ref_type_cost[comp_ref_type_ctx][0] +
6754           x->uni_comp_ref_cost[uni_comp_ref_ctx_p][0][1];
6755     } else {
6756       int ref0, ref1;
6757       for (ref0 = LAST_FRAME; ref0 <= GOLDEN_FRAME; ++ref0) {
6758         for (ref1 = BWDREF_FRAME; ref1 <= ALTREF_FRAME; ++ref1)
6759           ref_costs_comp[ref0][ref1] = 512;
6760       }
6761       ref_costs_comp[LAST_FRAME][LAST2_FRAME] = 512;
6762       ref_costs_comp[LAST_FRAME][LAST3_FRAME] = 512;
6763       ref_costs_comp[LAST_FRAME][GOLDEN_FRAME] = 512;
6764       ref_costs_comp[BWDREF_FRAME][ALTREF_FRAME] = 512;
6765     }
6766   }
6767 }
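// Worked example of the single-reference tree above: the signalling cost of
// LAST_FRAME accumulates one binary decision per level,
//   ref_costs_single[LAST_FRAME] = base_cost        // inter block
//       + single_ref_cost[ctx_p1][0][0]             // forward group
//       + single_ref_cost[ctx_p3][2][0]             // LAST/LAST2 subgroup
//       + single_ref_cost[ctx_p4][3][0];            // LAST itself
// and each decision carries its own spatial context.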
6768 
6769 static void store_coding_context(MACROBLOCK *x, PICK_MODE_CONTEXT *ctx,
6770                                  int mode_index,
6771                                  int64_t comp_pred_diff[REFERENCE_MODES],
6772                                  int skippable) {
6773   MACROBLOCKD *const xd = &x->e_mbd;
6774 
6775   // Take a snapshot of the coding context so it can be
6776   // restored if we decide to encode this way
6777   ctx->skip = x->skip;
6778   ctx->skippable = skippable;
6779   ctx->best_mode_index = mode_index;
6780   ctx->mic = *xd->mi[0];
6781   ctx->mbmi_ext = *x->mbmi_ext;
6782   ctx->single_pred_diff = (int)comp_pred_diff[SINGLE_REFERENCE];
6783   ctx->comp_pred_diff = (int)comp_pred_diff[COMPOUND_REFERENCE];
6784   ctx->hybrid_pred_diff = (int)comp_pred_diff[REFERENCE_MODE_SELECT];
6785 }
6786 
6787 static void setup_buffer_ref_mvs_inter(
6788     const AV1_COMP *const cpi, MACROBLOCK *x, MV_REFERENCE_FRAME ref_frame,
6789     BLOCK_SIZE block_size, int mi_row, int mi_col,
6790     struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE]) {
6791   const AV1_COMMON *cm = &cpi->common;
6792   const int num_planes = av1_num_planes(cm);
6793   const YV12_BUFFER_CONFIG *yv12 = get_ref_frame_buffer(cpi, ref_frame);
6794   MACROBLOCKD *const xd = &x->e_mbd;
6795   MB_MODE_INFO *const mbmi = xd->mi[0];
6796   const struct scale_factors *const sf = &cm->frame_refs[ref_frame - 1].sf;
6797   MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
6798 
6799   assert(yv12 != NULL);
6800 
6801   // TODO(jkoleszar): Is the UV buffer ever used here? If so, need to make this
6802   // use the UV scaling factors.
6803   av1_setup_pred_block(xd, yv12_mb[ref_frame], yv12, mi_row, mi_col, sf, sf,
6804                        num_planes);
6805 
6806   // Gets an initial list of candidate vectors from neighbours and orders them
6807   av1_find_mv_refs(cm, xd, mbmi, ref_frame, mbmi_ext->ref_mv_count,
6808                    mbmi_ext->ref_mv_stack, NULL, mbmi_ext->global_mvs, mi_row,
6809                    mi_col, mbmi_ext->mode_context);
6810 
6811   // Further refinement that is encode side only to test the top few candidates
6812   // in full and choose the best as the centre point for subsequent searches.
6813   // The current implementation doesn't support scaling.
6815   av1_mv_pred(cpi, x, yv12_mb[ref_frame][0].buf, yv12->y_stride, ref_frame,
6816               block_size);
6817 }
6818 
6819 static void single_motion_search(const AV1_COMP *const cpi, MACROBLOCK *x,
6820                                  BLOCK_SIZE bsize, int mi_row, int mi_col,
6821                                  int ref_idx, int *rate_mv) {
6822   MACROBLOCKD *xd = &x->e_mbd;
6823   const AV1_COMMON *cm = &cpi->common;
6824   const int num_planes = av1_num_planes(cm);
6825   MB_MODE_INFO *mbmi = xd->mi[0];
6826   struct buf_2d backup_yv12[MAX_MB_PLANE] = { { 0, 0, 0, 0, 0 } };
6827   int bestsme = INT_MAX;
6828   int step_param;
6829   int sadpb = x->sadperbit16;
6830   MV mvp_full;
6831   int ref = mbmi->ref_frame[ref_idx];
6832   MV ref_mv = av1_get_ref_mv(x, ref_idx).as_mv;
6833 
6834   MvLimits tmp_mv_limits = x->mv_limits;
6835   int cost_list[5];
6836 
6837   const YV12_BUFFER_CONFIG *scaled_ref_frame =
6838       av1_get_scaled_ref_frame(cpi, ref);
6839 
6840   if (scaled_ref_frame) {
6841     // Swap out the reference frame for a version that's been scaled to
6842     // match the resolution of the current frame, allowing the existing
6843     // full-pixel motion search code to be used without additional
6844     // modifications.
6845     for (int i = 0; i < num_planes; i++) {
6846       backup_yv12[i] = xd->plane[i].pre[ref_idx];
6847     }
6848     av1_setup_pre_planes(xd, ref_idx, scaled_ref_frame, mi_row, mi_col, NULL,
6849                          num_planes);
6850   }
6851 
6852   av1_set_mvcost(
6853       x, ref_idx,
6854       mbmi->ref_mv_idx + (have_nearmv_in_inter_mode(mbmi->mode) ? 1 : 0));
6855 
6856   // Work out the size of the first step in the mv step search.
6857   // 0 here is the maximum-length first step; 1 is the maximum >> 1, etc.
6858   if (cpi->sf.mv.auto_mv_step_size && cm->show_frame) {
6859     // Take the weighted average of the step_params based on the last frame's
6860     // max mv magnitude and that based on the best ref mvs of the current
6861     // block for the given reference.
6862     step_param =
6863         (av1_init_search_range(x->max_mv_context[ref]) + cpi->mv_step_param) /
6864         2;
6865   } else {
6866     step_param = cpi->mv_step_param;
6867   }
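  // Example (hypothetical values): if av1_init_search_range() yields 7 for
  // this reference's observed MV magnitudes and cpi->mv_step_param is 3,
  // the search starts at step_param (7 + 3) / 2 == 5, i.e. the maximum
  // first step shrunk by five halvings per the comment above.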
6868 
6869   if (cpi->sf.adaptive_motion_search && bsize < cm->seq_params.sb_size) {
6870     int boffset =
6871         2 * (mi_size_wide_log2[cm->seq_params.sb_size] -
6872              AOMMIN(mi_size_high_log2[bsize], mi_size_wide_log2[bsize]));
6873     step_param = AOMMAX(step_param, boffset);
6874   }
6875 
6876   if (cpi->sf.adaptive_motion_search) {
6877     int bwl = mi_size_wide_log2[bsize];
6878     int bhl = mi_size_high_log2[bsize];
6879     int tlevel = x->pred_mv_sad[ref] >> (bwl + bhl + 4);
6880 
6881     if (tlevel < 5) {
6882       step_param += 2;
6883       step_param = AOMMIN(step_param, MAX_MVSEARCH_STEPS - 1);
6884     }
6885 
6886     // pred_mv_sad is not set up for dynamically scaled frames.
6887     if (cpi->oxcf.resize_mode != RESIZE_RANDOM) {
6888       int i;
6889       for (i = LAST_FRAME; i <= ALTREF_FRAME && cm->show_frame; ++i) {
6890         if ((x->pred_mv_sad[ref] >> 3) > x->pred_mv_sad[i]) {
6891           x->pred_mv[ref].row = 0;
6892           x->pred_mv[ref].col = 0;
6893           x->best_mv.as_int = INVALID_MV;
6894 
6895           if (scaled_ref_frame) {
6896             // Swap back the original buffers before returning.
6897             for (int j = 0; j < num_planes; ++j)
6898               xd->plane[j].pre[ref_idx] = backup_yv12[j];
6899           }
6900           return;
6901         }
6902       }
6903     }
6904   }
6905 
6906   // Note: MV limits are modified here. Always restore the original values
6907   // after full-pixel motion search.
6908   av1_set_mv_search_range(&x->mv_limits, &ref_mv);
6909 
6910   if (mbmi->motion_mode != SIMPLE_TRANSLATION)
6911     mvp_full = mbmi->mv[0].as_mv;
6912   else
6913     mvp_full = ref_mv;
6914 
6915   mvp_full.col >>= 3;
6916   mvp_full.row >>= 3;
6917 
6918   x->best_mv.as_int = x->second_best_mv.as_int = INVALID_MV;
6919 
6920   switch (mbmi->motion_mode) {
6921     case SIMPLE_TRANSLATION:
6922       bestsme = av1_full_pixel_search(
6923           cpi, x, bsize, &mvp_full, step_param, cpi->sf.mv.search_method, 0,
6924           sadpb, cond_cost_list(cpi, cost_list), &ref_mv, INT_MAX, 1,
6925           (MI_SIZE * mi_col), (MI_SIZE * mi_row), 0);
6926       break;
6927     case OBMC_CAUSAL:
6928       bestsme = av1_obmc_full_pixel_search(cpi, x, &mvp_full, step_param, sadpb,
6929                                            MAX_MVSEARCH_STEPS - 1 - step_param,
6930                                            1, &cpi->fn_ptr[bsize], &ref_mv,
6931                                            &(x->best_mv.as_mv), 0);
6932       break;
6933     default: assert(0 && "Invalid motion mode!\n");
6934   }
6935 
6936   if (scaled_ref_frame) {
6937     // Swap back the original buffers for subpel motion search.
6938     for (int i = 0; i < num_planes; i++) {
6939       xd->plane[i].pre[ref_idx] = backup_yv12[i];
6940     }
6941   }
6942 
6943   x->mv_limits = tmp_mv_limits;
6944 
6945   if (cpi->common.cur_frame_force_integer_mv) {
6946     x->best_mv.as_mv.row *= 8;
6947     x->best_mv.as_mv.col *= 8;
6948   }
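  // When integer MVs are forced, the full-pel winner is only rescaled into
  // the eighth-pel domain the rest of the encoder expects, e.g. a best
  // full-pel MV of (3, -2) becomes (24, -16); the fractional refinement
  // below is then skipped via use_fractional_mv.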
6949   const int use_fractional_mv =
6950       bestsme < INT_MAX && cpi->common.cur_frame_force_integer_mv == 0;
6951   if (use_fractional_mv) {
6952     int dis; /* TODO: use dis in distortion calculation later. */
6953     switch (mbmi->motion_mode) {
6954       case SIMPLE_TRANSLATION:
6955         if (cpi->sf.use_accurate_subpel_search) {
6956           int best_mv_var;
6957           const int try_second = x->second_best_mv.as_int != INVALID_MV &&
6958                                  x->second_best_mv.as_int != x->best_mv.as_int;
6959           const int pw = block_size_wide[bsize];
6960           const int ph = block_size_high[bsize];
6961 
6962           best_mv_var = cpi->find_fractional_mv_step(
6963               x, cm, mi_row, mi_col, &ref_mv, cm->allow_high_precision_mv,
6964               x->errorperbit, &cpi->fn_ptr[bsize], cpi->sf.mv.subpel_force_stop,
6965               cpi->sf.mv.subpel_iters_per_step, cond_cost_list(cpi, cost_list),
6966               x->nmvjointcost, x->mvcost, &dis, &x->pred_sse[ref], NULL, NULL,
6967               0, 0, pw, ph, 1);
6968 
6969           if (try_second) {
6970             const int minc =
6971                 AOMMAX(x->mv_limits.col_min * 8, ref_mv.col - MV_MAX);
6972             const int maxc =
6973                 AOMMIN(x->mv_limits.col_max * 8, ref_mv.col + MV_MAX);
6974             const int minr =
6975                 AOMMAX(x->mv_limits.row_min * 8, ref_mv.row - MV_MAX);
6976             const int maxr =
6977                 AOMMIN(x->mv_limits.row_max * 8, ref_mv.row + MV_MAX);
6978             int this_var;
6979             MV best_mv = x->best_mv.as_mv;
6980 
6981             x->best_mv = x->second_best_mv;
6982             if (x->best_mv.as_mv.row * 8 <= maxr &&
6983                 x->best_mv.as_mv.row * 8 >= minr &&
6984                 x->best_mv.as_mv.col * 8 <= maxc &&
6985                 x->best_mv.as_mv.col * 8 >= minc) {
6986               this_var = cpi->find_fractional_mv_step(
6987                   x, cm, mi_row, mi_col, &ref_mv, cm->allow_high_precision_mv,
6988                   x->errorperbit, &cpi->fn_ptr[bsize],
6989                   cpi->sf.mv.subpel_force_stop,
6990                   cpi->sf.mv.subpel_iters_per_step,
6991                   cond_cost_list(cpi, cost_list), x->nmvjointcost, x->mvcost,
6992                   &dis, &x->pred_sse[ref], NULL, NULL, 0, 0, pw, ph, 1);
6993               if (this_var < best_mv_var) best_mv = x->best_mv.as_mv;
6994               x->best_mv.as_mv = best_mv;
6995             }
6996           }
6997         } else {
6998           cpi->find_fractional_mv_step(
6999               x, cm, mi_row, mi_col, &ref_mv, cm->allow_high_precision_mv,
7000               x->errorperbit, &cpi->fn_ptr[bsize], cpi->sf.mv.subpel_force_stop,
7001               cpi->sf.mv.subpel_iters_per_step, cond_cost_list(cpi, cost_list),
7002               x->nmvjointcost, x->mvcost, &dis, &x->pred_sse[ref], NULL, NULL,
7003               0, 0, 0, 0, 0);
7004         }
7005         break;
7006       case OBMC_CAUSAL:
7007         av1_find_best_obmc_sub_pixel_tree_up(
7008             x, cm, mi_row, mi_col, &x->best_mv.as_mv, &ref_mv,
7009             cm->allow_high_precision_mv, x->errorperbit, &cpi->fn_ptr[bsize],
7010             cpi->sf.mv.subpel_force_stop, cpi->sf.mv.subpel_iters_per_step,
7011             x->nmvjointcost, x->mvcost, &dis, &x->pred_sse[ref], 0,
7012             cpi->sf.use_accurate_subpel_search);
7013         break;
7014       default: assert(0 && "Invalid motion mode!\n");
7015     }
7016   }
7017   *rate_mv = av1_mv_bit_cost(&x->best_mv.as_mv, &ref_mv, x->nmvjointcost,
7018                              x->mvcost, MV_COST_WEIGHT);
7019 
7020   if (cpi->sf.adaptive_motion_search && mbmi->motion_mode == SIMPLE_TRANSLATION)
7021     x->pred_mv[ref] = x->best_mv.as_mv;
7022 }
7023 
7024 static INLINE void restore_dst_buf(MACROBLOCKD *xd, BUFFER_SET dst,
7025                                    const int num_planes) {
7026   int i;
7027   for (i = 0; i < num_planes; i++) {
7028     xd->plane[i].dst.buf = dst.plane[i];
7029     xd->plane[i].dst.stride = dst.stride[i];
7030   }
7031 }
7032 
7033 static void build_second_inter_pred(const AV1_COMP *cpi, MACROBLOCK *x,
7034                                     BLOCK_SIZE bsize, const MV *other_mv,
7035                                     int mi_row, int mi_col, const int block,
7036                                     int ref_idx, uint8_t *second_pred) {
7037   const AV1_COMMON *const cm = &cpi->common;
7038   const int pw = block_size_wide[bsize];
7039   const int ph = block_size_high[bsize];
7040   MACROBLOCKD *xd = &x->e_mbd;
7041   MB_MODE_INFO *mbmi = xd->mi[0];
7042   const int other_ref = mbmi->ref_frame[!ref_idx];
7043   struct macroblockd_plane *const pd = &xd->plane[0];
7044   // ic and ir are the 4x4 coordinates of the sub8x8 at index "block"
7045   const int ic = block & 1;
7046   const int ir = (block - ic) >> 1;
7047   const int p_col = ((mi_col * MI_SIZE) >> pd->subsampling_x) + 4 * ic;
7048   const int p_row = ((mi_row * MI_SIZE) >> pd->subsampling_y) + 4 * ir;
7049   const WarpedMotionParams *const wm = &xd->global_motion[other_ref];
7050   int is_global = is_global_mv_block(xd->mi[0], wm->wmtype);
7051 
7052   // This function should only ever be called for compound modes
7053   assert(has_second_ref(mbmi));
7054 
7055   const int plane = 0;
7056   struct buf_2d ref_yv12 = xd->plane[plane].pre[!ref_idx];
7057 
7058   struct scale_factors sf;
7059   av1_setup_scale_factors_for_frame(&sf, ref_yv12.width, ref_yv12.height,
7060                                     cm->width, cm->height);
7061 
7062   ConvolveParams conv_params = get_conv_params(0, plane, xd->bd);
7063   WarpTypesAllowed warp_types;
7064   warp_types.global_warp_allowed = is_global;
7065   warp_types.local_warp_allowed = mbmi->motion_mode == WARPED_CAUSAL;
7066 
7067   // Get the prediction block from the 'other' reference frame.
7068   av1_build_inter_predictor(ref_yv12.buf, ref_yv12.stride, second_pred, pw,
7069                             other_mv, &sf, pw, ph, &conv_params,
7070                             mbmi->interp_filters, &warp_types, p_col, p_row,
7071                             plane, !ref_idx, MV_PRECISION_Q3, mi_col * MI_SIZE,
7072                             mi_row * MI_SIZE, xd, cm->allow_warped_motion);
7073 
7074   av1_jnt_comp_weight_assign(cm, mbmi, 0, &xd->jcp_param.fwd_offset,
7075                              &xd->jcp_param.bck_offset,
7076                              &xd->jcp_param.use_jnt_comp_avg, 1);
7077 }
7078 
7079 // Search for the best mv for one component of a compound,
7080 // given that the other component is fixed.
7081 static void compound_single_motion_search(const AV1_COMP *cpi, MACROBLOCK *x,
7082                                           BLOCK_SIZE bsize, MV *this_mv,
7083                                           int mi_row, int mi_col,
7084                                           const uint8_t *second_pred,
7085                                           const uint8_t *mask, int mask_stride,
7086                                           int *rate_mv, int ref_idx) {
7087   const AV1_COMMON *const cm = &cpi->common;
7088   const int num_planes = av1_num_planes(cm);
7089   const int pw = block_size_wide[bsize];
7090   const int ph = block_size_high[bsize];
7091   MACROBLOCKD *xd = &x->e_mbd;
7092   MB_MODE_INFO *mbmi = xd->mi[0];
7093   const int ref = mbmi->ref_frame[ref_idx];
7094   const int_mv ref_mv = av1_get_ref_mv(x, ref_idx);
7095   struct macroblockd_plane *const pd = &xd->plane[0];
7096 
7097   struct buf_2d backup_yv12[MAX_MB_PLANE];
7098   const YV12_BUFFER_CONFIG *const scaled_ref_frame =
7099       av1_get_scaled_ref_frame(cpi, ref);
7100 
7101   // Check that this is either an interinter or an interintra block
7102   assert(has_second_ref(mbmi) || (ref_idx == 0 && is_interintra_mode(mbmi)));
7103 
7104   // Store the first prediction buffer.
7105   struct buf_2d orig_yv12;
7106   if (ref_idx) {
7107     orig_yv12 = pd->pre[0];
7108     pd->pre[0] = pd->pre[ref_idx];
7109   }
7110 
7111   if (scaled_ref_frame) {
7112     int i;
7113     // Swap out the reference frame for a version that's been scaled to
7114     // match the resolution of the current frame, allowing the existing
7115     // full-pixel motion search code to be used without additional
7116     // modifications.
7117     for (i = 0; i < num_planes; i++) backup_yv12[i] = xd->plane[i].pre[ref_idx];
7118     av1_setup_pre_planes(xd, ref_idx, scaled_ref_frame, mi_row, mi_col, NULL,
7119                          num_planes);
7120   }
7121 
7122   int bestsme = INT_MAX;
7123   int sadpb = x->sadperbit16;
7124   MV *const best_mv = &x->best_mv.as_mv;
7125   int search_range = SEARCH_RANGE_8P;
7126 
7127   MvLimits tmp_mv_limits = x->mv_limits;
7128 
7129   // Do compound motion search on the current reference frame.
7130   av1_set_mv_search_range(&x->mv_limits, &ref_mv.as_mv);
7131 
7132   // Use the mv result from the single mode as mv predictor.
7133   *best_mv = *this_mv;
7134 
7135   best_mv->col >>= 3;
7136   best_mv->row >>= 3;
7137 
7138   av1_set_mvcost(
7139       x, ref_idx,
7140       mbmi->ref_mv_idx + (have_nearmv_in_inter_mode(mbmi->mode) ? 1 : 0));
7141 
7142   // Small-range full-pixel motion search.
7143   bestsme = av1_refining_search_8p_c(x, sadpb, search_range,
7144                                      &cpi->fn_ptr[bsize], mask, mask_stride,
7145                                      ref_idx, &ref_mv.as_mv, second_pred);
7146   if (bestsme < INT_MAX) {
7147     if (mask)
7148       bestsme =
7149           av1_get_mvpred_mask_var(x, best_mv, &ref_mv.as_mv, second_pred, mask,
7150                                   mask_stride, ref_idx, &cpi->fn_ptr[bsize], 1);
7151     else
7152       bestsme = av1_get_mvpred_av_var(x, best_mv, &ref_mv.as_mv, second_pred,
7153                                       &cpi->fn_ptr[bsize], 1);
7154   }
7155 
7156   x->mv_limits = tmp_mv_limits;
7157 
7158   if (scaled_ref_frame) {
7159     // Swap back the original buffers for subpel motion search.
7160     for (int i = 0; i < num_planes; i++) {
7161       xd->plane[i].pre[ref_idx] = backup_yv12[i];
7162     }
7163   }
7164 
7165   if (cpi->common.cur_frame_force_integer_mv) {
7166     x->best_mv.as_mv.row *= 8;
7167     x->best_mv.as_mv.col *= 8;
7168   }
7169   const int use_fractional_mv =
7170       bestsme < INT_MAX && cpi->common.cur_frame_force_integer_mv == 0;
7171   if (use_fractional_mv) {
7172     int dis; /* TODO: use dis in distortion calculation later. */
7173     unsigned int sse;
7174     bestsme = cpi->find_fractional_mv_step(
7175         x, cm, mi_row, mi_col, &ref_mv.as_mv,
7176         cpi->common.allow_high_precision_mv, x->errorperbit,
7177         &cpi->fn_ptr[bsize], 0, cpi->sf.mv.subpel_iters_per_step, NULL,
7178         x->nmvjointcost, x->mvcost, &dis, &sse, second_pred, mask, mask_stride,
7179         ref_idx, pw, ph, cpi->sf.use_accurate_subpel_search);
7180   }
7181 
7182   // Restore the pointer to the first unscaled prediction buffer.
7183   if (ref_idx) pd->pre[0] = orig_yv12;
7184 
7185   if (bestsme < INT_MAX) *this_mv = *best_mv;
7186 
7187   *rate_mv = 0;
7188 
7189   av1_set_mvcost(
7190       x, ref_idx,
7191       mbmi->ref_mv_idx + (have_nearmv_in_inter_mode(mbmi->mode) ? 1 : 0));
7192   *rate_mv += av1_mv_bit_cost(this_mv, &ref_mv.as_mv, x->nmvjointcost,
7193                               x->mvcost, MV_COST_WEIGHT);
7194 }
7195 
7196 // Wrapper for compound_single_motion_search, for the common case
7197 // where the second prediction is also an inter mode.
7198 static void compound_single_motion_search_interinter(
7199     const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize, int_mv *cur_mv,
7200     int mi_row, int mi_col, const uint8_t *mask, int mask_stride, int *rate_mv,
7201     const int block, int ref_idx) {
7202   MACROBLOCKD *xd = &x->e_mbd;
7203   // This function should only ever be called for compound modes
7204   assert(has_second_ref(xd->mi[0]));
7205 
7206   // Prediction buffer from second frame.
7207   DECLARE_ALIGNED(16, uint16_t, second_pred_alloc_16[MAX_SB_SQUARE]);
7208   uint8_t *second_pred;
7209   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
7210     second_pred = CONVERT_TO_BYTEPTR(second_pred_alloc_16);
7211   else
7212     second_pred = (uint8_t *)second_pred_alloc_16;
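  // For high-bitdepth frames the samples are uint16_t, but the prediction
  // helpers take uint8_t pointers for both paths: CONVERT_TO_BYTEPTR() tags
  // the 16-bit buffer so callees can recover it with CONVERT_TO_SHORTPTR(),
  // letting one MAX_SB_SQUARE allocation serve either bit depth.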
7213 
7214   MV *this_mv = &cur_mv[ref_idx].as_mv;
7215   const MV *other_mv = &cur_mv[!ref_idx].as_mv;
7216 
7217   build_second_inter_pred(cpi, x, bsize, other_mv, mi_row, mi_col, block,
7218                           ref_idx, second_pred);
7219 
7220   compound_single_motion_search(cpi, x, bsize, this_mv, mi_row, mi_col,
7221                                 second_pred, mask, mask_stride, rate_mv,
7222                                 ref_idx);
7223 }
7224 
7225 static void do_masked_motion_search_indexed(
7226     const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
7227     const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE bsize,
7228     int mi_row, int mi_col, int_mv *tmp_mv, int *rate_mv, int which) {
7229   // NOTE: 'which' selects the MV(s) to refine: 0 = MV 0 only, 1 = MV 1 only, 2 = both.
7230   MACROBLOCKD *xd = &x->e_mbd;
7231   MB_MODE_INFO *mbmi = xd->mi[0];
7232   BLOCK_SIZE sb_type = mbmi->sb_type;
7233   const uint8_t *mask;
7234   const int mask_stride = block_size_wide[bsize];
7235 
7236   mask = av1_get_compound_type_mask(comp_data, sb_type);
7237 
7238   tmp_mv[0].as_int = cur_mv[0].as_int;
7239   tmp_mv[1].as_int = cur_mv[1].as_int;
7240   if (which == 0 || which == 1) {
7241     compound_single_motion_search_interinter(cpi, x, bsize, tmp_mv, mi_row,
7242                                              mi_col, mask, mask_stride, rate_mv,
7243                                              0, which);
7244   } else if (which == 2) {
7245     joint_motion_search(cpi, x, bsize, tmp_mv, mi_row, mi_col, NULL, mask,
7246                         mask_stride, rate_mv, 0);
7247   }
7248 }
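// Hypothetical usage sketch: to re-search only the second component of a
// masked compound while holding the first fixed, a caller could do
//   int_mv tmp_mv[2];
//   int rate_mv;
//   do_masked_motion_search_indexed(cpi, x, cur_mv, &mbmi->interinter_comp,
//                                   bsize, mi_row, mi_col, tmp_mv, &rate_mv,
//                                   /*which=*/1);
// which leaves tmp_mv[0] == cur_mv[0] and updates tmp_mv[1] and the MV rate.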
7249 
7250 #define USE_DISCOUNT_NEWMV_TEST 0
7251 #if USE_DISCOUNT_NEWMV_TEST
7252 // In some situations we want to discount the apparent cost of a new motion
7253 // vector. Where there is a subtle motion field and especially where there is
7254 // low spatial complexity then it can be hard to cover the cost of a new motion
7255 // vector in a single block, even if that motion vector reduces distortion.
7256 // However, once established that vector may be usable through the nearest and
7257 // near mv modes to reduce distortion in subsequent blocks and also improve
7258 // visual quality.
7259 #define NEW_MV_DISCOUNT_FACTOR 8
7260 static INLINE void get_this_mv(int_mv *this_mv, PREDICTION_MODE this_mode,
7261                                int ref_idx, int ref_mv_idx,
7262                                const MV_REFERENCE_FRAME *ref_frame,
7263                                const MB_MODE_INFO_EXT *mbmi_ext);
7264 static int discount_newmv_test(const AV1_COMP *const cpi, const MACROBLOCK *x,
7265                                PREDICTION_MODE this_mode, int_mv this_mv) {
7266   if (this_mode == NEWMV && this_mv.as_int != 0 &&
7267       !cpi->rc.is_src_frame_alt_ref) {
7268     // Only discount new_mv when nearest_mv and all near_mv are zero, and
7269     // the new_mv is not equal to global_mv.
7270     const AV1_COMMON *const cm = &cpi->common;
7271     const MACROBLOCKD *const xd = &x->e_mbd;
7272     const MB_MODE_INFO *const mbmi = xd->mi[0];
7273     const MV_REFERENCE_FRAME tmp_ref_frames[2] = { mbmi->ref_frame[0],
7274                                                    NONE_FRAME };
7275     const uint8_t ref_frame_type = av1_ref_frame_type(tmp_ref_frames);
7276     int_mv nearest_mv;
7277     get_this_mv(&nearest_mv, NEARESTMV, 0, 0, tmp_ref_frames, x->mbmi_ext);
7278     int ret = nearest_mv.as_int == 0;
7279     for (int ref_mv_idx = 0;
7280          ref_mv_idx < x->mbmi_ext->ref_mv_count[ref_frame_type]; ++ref_mv_idx) {
7281       int_mv near_mv;
7282       get_this_mv(&near_mv, NEARMV, 0, ref_mv_idx, tmp_ref_frames, x->mbmi_ext);
7283       ret &= near_mv.as_int == 0;
7284     }
7285     if (cm->global_motion[tmp_ref_frames[0]].wmtype <= TRANSLATION) {
7286       int_mv global_mv;
7287       get_this_mv(&global_mv, GLOBALMV, 0, 0, tmp_ref_frames, x->mbmi_ext);
7288       ret &= global_mv.as_int != this_mv.as_int;
7289     }
7290     return ret;
7291   }
7292   return 0;
7293 }
7294 #endif
7295 
7296 #define LEFT_TOP_MARGIN ((AOM_BORDER_IN_PIXELS - AOM_INTERP_EXTEND) << 3)
7297 #define RIGHT_BOTTOM_MARGIN ((AOM_BORDER_IN_PIXELS - AOM_INTERP_EXTEND) << 3)
7298 
7299 // TODO(jingning): this mv clamping function should be block size dependent.
7300 static INLINE void clamp_mv2(MV *mv, const MACROBLOCKD *xd) {
7301   clamp_mv(mv, xd->mb_to_left_edge - LEFT_TOP_MARGIN,
7302            xd->mb_to_right_edge + RIGHT_BOTTOM_MARGIN,
7303            xd->mb_to_top_edge - LEFT_TOP_MARGIN,
7304            xd->mb_to_bottom_edge + RIGHT_BOTTOM_MARGIN);
7305 }
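// Example of the margin arithmetic (assuming AOM_BORDER_IN_PIXELS == 288
// and AOM_INTERP_EXTEND == 4): the clamp lets an MV reach
// (288 - 4) << 3 == 2272 eighth-pel units, i.e. 284 pixels, past the frame
// edge, which keeps the interpolation filter footprint inside the padded
// border.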
7306 
7307 static int estimate_wedge_sign(const AV1_COMP *cpi, const MACROBLOCK *x,
7308                                const BLOCK_SIZE bsize, const uint8_t *pred0,
7309                                int stride0, const uint8_t *pred1, int stride1) {
7310   static const BLOCK_SIZE split_qtr[BLOCK_SIZES_ALL] = {
7311     //                            4X4
7312     BLOCK_INVALID,
7313     // 4X8,        8X4,           8X8
7314     BLOCK_INVALID, BLOCK_INVALID, BLOCK_4X4,
7315     // 8X16,       16X8,          16X16
7316     BLOCK_4X8, BLOCK_8X4, BLOCK_8X8,
7317     // 16X32,      32X16,         32X32
7318     BLOCK_8X16, BLOCK_16X8, BLOCK_16X16,
7319     // 32X64,      64X32,         64X64
7320     BLOCK_16X32, BLOCK_32X16, BLOCK_32X32,
7321     // 64x128,     128x64,        128x128
7322     BLOCK_32X64, BLOCK_64X32, BLOCK_64X64,
7323     // 4X16,       16X4,          8X32
7324     BLOCK_INVALID, BLOCK_INVALID, BLOCK_4X16,
7325     // 32X8,       16X64,         64X16
7326     BLOCK_16X4, BLOCK_8X32, BLOCK_32X8
7327   };
7328   const struct macroblock_plane *const p = &x->plane[0];
7329   const uint8_t *src = p->src.buf;
7330   int src_stride = p->src.stride;
7331   const int bw = block_size_wide[bsize];
7332   const int bh = block_size_high[bsize];
7333   uint32_t esq[2][4];
7334   int64_t tl, br;
7335 
7336   const BLOCK_SIZE f_index = split_qtr[bsize];
7337   assert(f_index != BLOCK_INVALID);
7338 
7339   if (x->e_mbd.cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
7340     pred0 = CONVERT_TO_BYTEPTR(pred0);
7341     pred1 = CONVERT_TO_BYTEPTR(pred1);
7342   }
7343 
7344   cpi->fn_ptr[f_index].vf(src, src_stride, pred0, stride0, &esq[0][0]);
7345   cpi->fn_ptr[f_index].vf(src + bw / 2, src_stride, pred0 + bw / 2, stride0,
7346                           &esq[0][1]);
7347   cpi->fn_ptr[f_index].vf(src + bh / 2 * src_stride, src_stride,
7348                           pred0 + bh / 2 * stride0, stride0, &esq[0][2]);
7349   cpi->fn_ptr[f_index].vf(src + bh / 2 * src_stride + bw / 2, src_stride,
7350                           pred0 + bh / 2 * stride0 + bw / 2, stride0,
7351                           &esq[0][3]);
7352   cpi->fn_ptr[f_index].vf(src, src_stride, pred1, stride1, &esq[1][0]);
7353   cpi->fn_ptr[f_index].vf(src + bw / 2, src_stride, pred1 + bw / 2, stride1,
7354                           &esq[1][1]);
7355   cpi->fn_ptr[f_index].vf(src + bh / 2 * src_stride, src_stride,
7356                           pred1 + bh / 2 * stride1, stride0, &esq[1][2]);
7357   cpi->fn_ptr[f_index].vf(src + bh / 2 * src_stride + bw / 2, src_stride,
7358                           pred1 + bh / 2 * stride1 + bw / 2, stride0,
7359                           &esq[1][3]);
7360 
7361   tl = ((int64_t)esq[0][0] + esq[0][1] + esq[0][2]) -
7362        ((int64_t)esq[1][0] + esq[1][1] + esq[1][2]);
7363   br = ((int64_t)esq[1][3] + esq[1][1] + esq[1][2]) -
7364        ((int64_t)esq[0][3] + esq[0][1] + esq[0][2]);
7365   return (tl + br > 0);
7366 }

// Choose the best wedge index and sign
static int64_t pick_wedge(const AV1_COMP *const cpi, const MACROBLOCK *const x,
                          const BLOCK_SIZE bsize, const uint8_t *const p0,
                          const int16_t *const residual1,
                          const int16_t *const diff10,
                          int *const best_wedge_sign,
                          int *const best_wedge_index) {
  const MACROBLOCKD *const xd = &x->e_mbd;
  const struct buf_2d *const src = &x->plane[0].src;
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  const int N = bw * bh;
  assert(N >= 64);
  int rate;
  int64_t dist;
  int64_t rd, best_rd = INT64_MAX;
  int wedge_index;
  int wedge_sign;
  int wedge_types = (1 << get_wedge_bits_lookup(bsize));
  const uint8_t *mask;
  uint64_t sse;
  const int hbd = xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH;
  const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;

  DECLARE_ALIGNED(32, int16_t, residual0[MAX_SB_SQUARE]);  // src - pred0
  if (hbd) {
    aom_highbd_subtract_block(bh, bw, residual0, bw, src->buf, src->stride,
                              CONVERT_TO_BYTEPTR(p0), bw, xd->bd);
  } else {
    aom_subtract_block(bh, bw, residual0, bw, src->buf, src->stride, p0, bw);
  }

  int64_t sign_limit = ((int64_t)aom_sum_squares_i16(residual0, N) -
                        (int64_t)aom_sum_squares_i16(residual1, N)) *
                       (1 << WEDGE_WEIGHT_BITS) / 2;
  int16_t *ds = residual0;

  av1_wedge_compute_delta_squares(ds, residual0, residual1, N);

  for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
    mask = av1_get_contiguous_soft_mask(wedge_index, 0, bsize);

    wedge_sign = av1_wedge_sign_from_residuals(ds, mask, N, sign_limit);

    mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize);
    sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N);
    sse = ROUND_POWER_OF_TWO(sse, bd_round);

    model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
                                                  &rate, &dist);

    rate += x->wedge_idx_cost[bsize][wedge_index];
    rd = RDCOST(x->rdmult, rate, dist);

    if (rd < best_rd) {
      *best_wedge_index = wedge_index;
      *best_wedge_sign = wedge_sign;
      best_rd = rd;
    }
  }

  return best_rd -
         RDCOST(x->rdmult, x->wedge_idx_cost[bsize][*best_wedge_index], 0);
}

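// Note on the wedge pickers (pick_wedge above, pick_wedge_fixed_sign below):
// the returned rd deliberately excludes the wedge index signaling cost. The
// caller re-adds the complete mask rate (see
// get_interinter_compound_mask_rate) when comparing compound types, so
// keeping the index cost here would count it twice.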
// Choose the best wedge index for the specified sign
static int64_t pick_wedge_fixed_sign(const AV1_COMP *const cpi,
                                     const MACROBLOCK *const x,
                                     const BLOCK_SIZE bsize,
                                     const int16_t *const residual1,
                                     const int16_t *const diff10,
                                     const int wedge_sign,
                                     int *const best_wedge_index) {
  const MACROBLOCKD *const xd = &x->e_mbd;

  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  const int N = bw * bh;
  assert(N >= 64);
  int rate;
  int64_t dist;
  int64_t rd, best_rd = INT64_MAX;
  int wedge_index;
  int wedge_types = (1 << get_wedge_bits_lookup(bsize));
  const uint8_t *mask;
  uint64_t sse;
  const int hbd = xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH;
  const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
  for (wedge_index = 0; wedge_index < wedge_types; ++wedge_index) {
    mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, bsize);
    sse = av1_wedge_sse_from_residuals(residual1, diff10, mask, N);
    sse = ROUND_POWER_OF_TWO(sse, bd_round);

    model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
                                                  &rate, &dist);
    rate += x->wedge_idx_cost[bsize][wedge_index];
    rd = RDCOST(x->rdmult, rate, dist);

    if (rd < best_rd) {
      *best_wedge_index = wedge_index;
      best_rd = rd;
    }
  }
  return best_rd -
         RDCOST(x->rdmult, x->wedge_idx_cost[bsize][*best_wedge_index], 0);
}

static int64_t pick_interinter_wedge(
    const AV1_COMP *const cpi, MACROBLOCK *const x, const BLOCK_SIZE bsize,
    const uint8_t *const p0, const uint8_t *const p1,
    const int16_t *const residual1, const int16_t *const diff10) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int bw = block_size_wide[bsize];

  int64_t rd;
  int wedge_index = -1;
  int wedge_sign = 0;

  assert(is_interinter_compound_used(COMPOUND_WEDGE, bsize));
  assert(cpi->common.seq_params.enable_masked_compound);

  if (cpi->sf.fast_wedge_sign_estimate) {
    wedge_sign = estimate_wedge_sign(cpi, x, bsize, p0, bw, p1, bw);
    rd = pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, wedge_sign,
                               &wedge_index);
  } else {
    rd = pick_wedge(cpi, x, bsize, p0, residual1, diff10, &wedge_sign,
                    &wedge_index);
  }

  mbmi->interinter_comp.wedge_sign = wedge_sign;
  mbmi->interinter_comp.wedge_index = wedge_index;
  return rd;
}

static int64_t pick_interinter_seg(const AV1_COMP *const cpi,
                                   MACROBLOCK *const x, const BLOCK_SIZE bsize,
                                   const uint8_t *const p0,
                                   const uint8_t *const p1,
                                   const int16_t *const residual1,
                                   const int16_t *const diff10) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  const int N = 1 << num_pels_log2_lookup[bsize];
  int rate;
  int64_t dist;
  DIFFWTD_MASK_TYPE cur_mask_type;
  int64_t best_rd = INT64_MAX;
  DIFFWTD_MASK_TYPE best_mask_type = 0;
  const int hbd = xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH;
  const int bd_round = hbd ? (xd->bd - 8) * 2 : 0;
  DECLARE_ALIGNED(16, uint8_t, seg_mask[2 * MAX_SB_SQUARE]);
  uint8_t *tmp_mask[2] = { xd->seg_mask, seg_mask };
  // Try each mask type and its inverse. The first candidate is built directly
  // in xd->seg_mask; the inverse goes to the local scratch buffer, so a copy
  // back is only needed if the inverse mask wins.
  for (cur_mask_type = 0; cur_mask_type < DIFFWTD_MASK_TYPES; cur_mask_type++) {
    // build the mask for the current type
    if (hbd)
      av1_build_compound_diffwtd_mask_highbd(
          tmp_mask[cur_mask_type], cur_mask_type, CONVERT_TO_BYTEPTR(p0), bw,
          CONVERT_TO_BYTEPTR(p1), bw, bh, bw, xd->bd);
    else
      av1_build_compound_diffwtd_mask(tmp_mask[cur_mask_type], cur_mask_type,
                                      p0, bw, p1, bw, bh, bw);

    // compute rd for mask
    uint64_t sse = av1_wedge_sse_from_residuals(residual1, diff10,
                                                tmp_mask[cur_mask_type], N);
    sse = ROUND_POWER_OF_TWO(sse, bd_round);

    model_rd_sse_fn[MODELRD_TYPE_MASKED_COMPOUND](cpi, x, bsize, 0, sse, N,
                                                  &rate, &dist);
    const int64_t rd0 = RDCOST(x->rdmult, rate, dist);

    if (rd0 < best_rd) {
      best_mask_type = cur_mask_type;
      best_rd = rd0;
    }
  }
  mbmi->interinter_comp.mask_type = best_mask_type;
  if (best_mask_type == DIFFWTD_38_INV) {
    memcpy(xd->seg_mask, seg_mask, N * 2);
  }
  return best_rd;
}

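// Inter-intra wedge: only the wedge index is searched. The sign is fixed to
// zero because the inter-intra wedge sign is not signaled in the bitstream.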
static int64_t pick_interintra_wedge(const AV1_COMP *const cpi,
                                     const MACROBLOCK *const x,
                                     const BLOCK_SIZE bsize,
                                     const uint8_t *const p0,
                                     const uint8_t *const p1) {
  const MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  assert(is_interintra_wedge_used(bsize));
  assert(cpi->common.seq_params.enable_interintra_compound);

  const struct buf_2d *const src = &x->plane[0].src;
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  DECLARE_ALIGNED(32, int16_t, residual1[MAX_SB_SQUARE]);  // src - pred1
  DECLARE_ALIGNED(32, int16_t, diff10[MAX_SB_SQUARE]);     // pred1 - pred0
  if (get_bitdepth_data_path_index(xd)) {
    aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride,
                              CONVERT_TO_BYTEPTR(p1), bw, xd->bd);
    aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(p1), bw,
                              CONVERT_TO_BYTEPTR(p0), bw, xd->bd);
  } else {
    aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, p1, bw);
    aom_subtract_block(bh, bw, diff10, bw, p1, bw, p0, bw);
  }
  int wedge_index = -1;
  int64_t rd =
      pick_wedge_fixed_sign(cpi, x, bsize, residual1, diff10, 0, &wedge_index);

  mbmi->interintra_wedge_sign = 0;
  mbmi->interintra_wedge_index = wedge_index;
  return rd;
}

static int64_t pick_interinter_mask(const AV1_COMP *const cpi, MACROBLOCK *x,
                                    const BLOCK_SIZE bsize,
                                    const uint8_t *const p0,
                                    const uint8_t *const p1,
                                    const int16_t *const residual1,
                                    const int16_t *const diff10) {
  const COMPOUND_TYPE compound_type = x->e_mbd.mi[0]->interinter_comp.type;
  switch (compound_type) {
    case COMPOUND_WEDGE:
      return pick_interinter_wedge(cpi, x, bsize, p0, p1, residual1, diff10);
    case COMPOUND_DIFFWTD:
      return pick_interinter_seg(cpi, x, bsize, p0, p1, residual1, diff10);
    default: assert(0); return 0;
  }
}

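// Refine the new MV(s) of a masked compound mode with a mask-aware motion
// search. The last argument of do_masked_motion_search_indexed selects which
// side is refined: 0 = first MV only, 1 = second MV only, 2 = both, matching
// which of the two references uses NEWMV.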
static int interinter_compound_motion_search(const AV1_COMP *const cpi,
                                             MACROBLOCK *x,
                                             const int_mv *const cur_mv,
                                             const BLOCK_SIZE bsize,
                                             const PREDICTION_MODE this_mode,
                                             int mi_row, int mi_col) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  int_mv tmp_mv[2];
  int tmp_rate_mv = 0;
  mbmi->interinter_comp.seg_mask = xd->seg_mask;
  const INTERINTER_COMPOUND_DATA *compound_data = &mbmi->interinter_comp;

  if (this_mode == NEW_NEWMV) {
    do_masked_motion_search_indexed(cpi, x, cur_mv, compound_data, bsize,
                                    mi_row, mi_col, tmp_mv, &tmp_rate_mv, 2);
    mbmi->mv[0].as_int = tmp_mv[0].as_int;
    mbmi->mv[1].as_int = tmp_mv[1].as_int;
  } else if (this_mode == NEW_NEARESTMV || this_mode == NEW_NEARMV) {
    do_masked_motion_search_indexed(cpi, x, cur_mv, compound_data, bsize,
                                    mi_row, mi_col, tmp_mv, &tmp_rate_mv, 0);
    mbmi->mv[0].as_int = tmp_mv[0].as_int;
  } else if (this_mode == NEAREST_NEWMV || this_mode == NEAR_NEWMV) {
    do_masked_motion_search_indexed(cpi, x, cur_mv, compound_data, bsize,
                                    mi_row, mi_col, tmp_mv, &tmp_rate_mv, 1);
    mbmi->mv[1].as_int = tmp_mv[1].as_int;
  }
  return tmp_rate_mv;
}

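// Build the two single-reference predictions that masked compound modes
// blend, and precompute residual1 (src - pred1) and diff10 (pred1 - pred0).
// The mask search needs only these two buffers to evaluate any mask (see
// av1_wedge_sse_from_residuals).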
static void get_inter_predictors_masked_compound(
    const AV1_COMP *const cpi, MACROBLOCK *x, const BLOCK_SIZE bsize,
    int mi_row, int mi_col, uint8_t **preds0, uint8_t **preds1,
    int16_t *residual1, int16_t *diff10, int *strides) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  const int bw = block_size_wide[bsize];
  const int bh = block_size_high[bsize];
  int can_use_previous = cm->allow_warped_motion;
  // get inter predictors to use for masked compound modes
  av1_build_inter_predictors_for_planes_single_buf(
      xd, bsize, 0, 0, mi_row, mi_col, 0, preds0, strides, can_use_previous);
  av1_build_inter_predictors_for_planes_single_buf(
      xd, bsize, 0, 0, mi_row, mi_col, 1, preds1, strides, can_use_previous);
  const struct buf_2d *const src = &x->plane[0].src;
  if (get_bitdepth_data_path_index(xd)) {
    aom_highbd_subtract_block(bh, bw, residual1, bw, src->buf, src->stride,
                              CONVERT_TO_BYTEPTR(*preds1), bw, xd->bd);
    aom_highbd_subtract_block(bh, bw, diff10, bw, CONVERT_TO_BYTEPTR(*preds1),
                              bw, CONVERT_TO_BYTEPTR(*preds0), bw, xd->bd);
  } else {
    aom_subtract_block(bh, bw, residual1, bw, src->buf, src->stride, *preds1,
                       bw);
    aom_subtract_block(bh, bw, diff10, bw, *preds1, bw, *preds0, bw);
  }
}

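// Evaluate one masked compound type (wedge or diffwtd) for the current block:
// pick the best mask for the type, early-out if the header cost alone already
// exceeds ref_best_rd, optionally refine the new MVs for wedge modes, then
// estimate the luma rd of the resulting prediction.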
static int64_t build_and_cost_compound_type(
    const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv,
    const BLOCK_SIZE bsize, const PREDICTION_MODE this_mode, int *rs2,
    int rate_mv, BUFFER_SET *ctx, int *out_rate_mv, uint8_t **preds0,
    uint8_t **preds1, int16_t *residual1, int16_t *diff10, int *strides,
    int mi_row, int mi_col, int mode_rate, int64_t ref_best_rd,
    int *calc_pred_masked_compound) {
  const AV1_COMMON *const cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  int rate_sum;
  int64_t dist_sum;
  int64_t best_rd_cur = INT64_MAX;
  int64_t rd = INT64_MAX;
  int tmp_skip_txfm_sb;
  int64_t tmp_skip_sse_sb;
  const COMPOUND_TYPE compound_type = mbmi->interinter_comp.type;

  if (*calc_pred_masked_compound) {
    get_inter_predictors_masked_compound(cpi, x, bsize, mi_row, mi_col, preds0,
                                         preds1, residual1, diff10, strides);
    *calc_pred_masked_compound = 0;
  }

  best_rd_cur =
      pick_interinter_mask(cpi, x, bsize, *preds0, *preds1, residual1, diff10);
  *rs2 += get_interinter_compound_mask_rate(x, mbmi);
  best_rd_cur += RDCOST(x->rdmult, *rs2 + rate_mv, 0);

  // Although the true rate_mv may differ after the motion search below, a
  // mode whose header cost already exceeds ref_best_rd is unlikely to become
  // the best once the transform rd cost and other mode overheads are added.
  int64_t mode_rd = RDCOST(x->rdmult, *rs2 + mode_rate, 0);
  if (mode_rd > ref_best_rd) return INT64_MAX;

  if (have_newmv_in_inter_mode(this_mode) && compound_type == COMPOUND_WEDGE) {
    *out_rate_mv = interinter_compound_motion_search(cpi, x, cur_mv, bsize,
                                                     this_mode, mi_row, mi_col);
    av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, ctx, bsize);
    model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
        cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &rate_sum, &dist_sum,
        &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
    rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + rate_sum, dist_sum);
    if (rd >= best_rd_cur) {
      // The refined MVs did not help; fall back to the unrefined ones.
      mbmi->mv[0].as_int = cur_mv[0].as_int;
      mbmi->mv[1].as_int = cur_mv[1].as_int;
      *out_rate_mv = rate_mv;
      av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides,
                                               preds1, strides);
    }
    rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
                             &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
    if (rd != INT64_MAX)
      rd = RDCOST(x->rdmult, *rs2 + *out_rate_mv + rate_sum, dist_sum);
    best_rd_cur = rd;
  } else {
    av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides,
                                             preds1, strides);
    rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
                             &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
    if (rd != INT64_MAX)
      rd = RDCOST(x->rdmult, *rs2 + rate_mv + rate_sum, dist_sum);
    best_rd_cur = rd;
  }
  return best_rd_cur;
}

typedef struct {
  // OBMC secondary prediction buffers and respective strides
  uint8_t *above_pred_buf[MAX_MB_PLANE];
  int above_pred_stride[MAX_MB_PLANE];
  uint8_t *left_pred_buf[MAX_MB_PLANE];
  int left_pred_stride[MAX_MB_PLANE];
  int_mv (*single_newmv)[REF_FRAMES];
  // Pointer to an array of motion vector rates per ref frame; should point to
  // the first of the two arrays in a 2D array
  int (*single_newmv_rate)[REF_FRAMES];
  int (*single_newmv_valid)[REF_FRAMES];
  // Pointer to an array of predicted rate-distortion values; should point to
  // the first of the two arrays in a 2D array
  int64_t (*modelled_rd)[MAX_REF_MV_SERCH][REF_FRAMES];
  InterpFilter single_filter[MB_MODE_COUNT][REF_FRAMES];
  int ref_frame_cost;
  int single_comp_cost;
  int64_t (*simple_rd)[MAX_REF_MV_SERCH][REF_FRAMES];
  int skip_motion_mode;
  INTERINTRA_MODE *inter_intra_mode;
} HandleInterModeArgs;

/* If the current mode shares the same mv with another mode that has a higher
 * cost, skip this mode. */
static int skip_repeated_mv(const AV1_COMMON *const cm,
                            const MACROBLOCK *const x,
                            PREDICTION_MODE this_mode,
                            const MV_REFERENCE_FRAME ref_frames[2],
                            InterModeSearchState *search_state) {
  const int is_comp_pred = ref_frames[1] > INTRA_FRAME;
  const uint8_t ref_frame_type = av1_ref_frame_type(ref_frames);
  const MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
  const int ref_mv_count = mbmi_ext->ref_mv_count[ref_frame_type];
  PREDICTION_MODE compare_mode = MB_MODE_COUNT;
  if (!is_comp_pred) {
    if (this_mode == NEARMV) {
      if (ref_mv_count == 0) {
        // NEARMV has the same motion vector as NEARESTMV
        compare_mode = NEARESTMV;
      }
      if (ref_mv_count == 1 &&
          cm->global_motion[ref_frames[0]].wmtype <= TRANSLATION) {
        // NEARMV has the same motion vector as GLOBALMV
        compare_mode = GLOBALMV;
      }
    }
    if (this_mode == GLOBALMV) {
      if (ref_mv_count == 0 &&
          cm->global_motion[ref_frames[0]].wmtype <= TRANSLATION) {
        // GLOBALMV has the same motion vector as NEARESTMV
        compare_mode = NEARESTMV;
      }
      if (ref_mv_count == 1) {
        // GLOBALMV has the same motion vector as NEARMV
        compare_mode = NEARMV;
      }
    }

    if (compare_mode != MB_MODE_COUNT) {
      // Use modelled_rd to check whether the compare mode was searched
      if (search_state->modelled_rd[compare_mode][0][ref_frames[0]] !=
          INT64_MAX) {
        const int16_t mode_ctx =
            av1_mode_context_analyzer(mbmi_ext->mode_context, ref_frames);
        const int compare_cost = cost_mv_ref(x, compare_mode, mode_ctx);
        const int this_cost = cost_mv_ref(x, this_mode, mode_ctx);

        // Only skip if this mode's cost is larger than the compare mode's cost
        if (this_cost > compare_cost) {
          search_state->modelled_rd[this_mode][0][ref_frames[0]] =
              search_state->modelled_rd[compare_mode][0][ref_frames[0]];
          return 1;
        }
      }
    }
  }
  return 0;
}

static INLINE int clamp_and_check_mv(int_mv *out_mv, int_mv in_mv,
                                     const AV1_COMMON *cm,
                                     const MACROBLOCK *x) {
  const MACROBLOCKD *const xd = &x->e_mbd;
  *out_mv = in_mv;
  lower_mv_precision(&out_mv->as_mv, cm->allow_high_precision_mv,
                     cm->cur_frame_force_integer_mv);
  clamp_mv2(&out_mv->as_mv, xd);
  return !mv_check_bounds(&x->mv_limits, &out_mv->as_mv);
}

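// Resolve the new MV(s) for a NEWMV-class mode. For compound modes this
// either runs a joint/compound motion search or reuses the cached
// single-reference results and just prices the MV bits; for single-reference
// modes it runs a fresh motion search and caches the result in args.
// Returns INT64_MAX if no valid MV is found, 0 otherwise.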
static int64_t handle_newmv(const AV1_COMP *const cpi, MACROBLOCK *const x,
                            const BLOCK_SIZE bsize, int_mv *cur_mv,
                            const int mi_row, const int mi_col,
                            int *const rate_mv,
                            HandleInterModeArgs *const args) {
  const MACROBLOCKD *const xd = &x->e_mbd;
  const MB_MODE_INFO *const mbmi = xd->mi[0];
  const int is_comp_pred = has_second_ref(mbmi);
  const PREDICTION_MODE this_mode = mbmi->mode;
  const int refs[2] = { mbmi->ref_frame[0],
                        mbmi->ref_frame[1] < 0 ? 0 : mbmi->ref_frame[1] };
  const int ref_mv_idx = mbmi->ref_mv_idx;
  int i;

  (void)args;

  if (is_comp_pred) {
    if (this_mode == NEW_NEWMV) {
      cur_mv[0].as_int = args->single_newmv[ref_mv_idx][refs[0]].as_int;
      cur_mv[1].as_int = args->single_newmv[ref_mv_idx][refs[1]].as_int;

      if (cpi->sf.comp_inter_joint_search_thresh <= bsize) {
        joint_motion_search(cpi, x, bsize, cur_mv, mi_row, mi_col, NULL, NULL,
                            0, rate_mv, 0);
      } else {
        *rate_mv = 0;
        for (i = 0; i < 2; ++i) {
          const int_mv ref_mv = av1_get_ref_mv(x, i);
          av1_set_mvcost(x, i, mbmi->ref_mv_idx);
          *rate_mv +=
              av1_mv_bit_cost(&cur_mv[i].as_mv, &ref_mv.as_mv, x->nmvjointcost,
                              x->mvcost, MV_COST_WEIGHT);
        }
      }
    } else if (this_mode == NEAREST_NEWMV || this_mode == NEAR_NEWMV) {
      cur_mv[1].as_int = args->single_newmv[ref_mv_idx][refs[1]].as_int;
      if (cpi->sf.comp_inter_joint_search_thresh <= bsize) {
        compound_single_motion_search_interinter(
            cpi, x, bsize, cur_mv, mi_row, mi_col, NULL, 0, rate_mv, 0, 1);
      } else {
        av1_set_mvcost(x, 1,
                       mbmi->ref_mv_idx + (this_mode == NEAR_NEWMV ? 1 : 0));
        const int_mv ref_mv = av1_get_ref_mv(x, 1);
        *rate_mv = av1_mv_bit_cost(&cur_mv[1].as_mv, &ref_mv.as_mv,
                                   x->nmvjointcost, x->mvcost, MV_COST_WEIGHT);
      }
    } else {
      assert(this_mode == NEW_NEARESTMV || this_mode == NEW_NEARMV);
      cur_mv[0].as_int = args->single_newmv[ref_mv_idx][refs[0]].as_int;
      if (cpi->sf.comp_inter_joint_search_thresh <= bsize) {
        compound_single_motion_search_interinter(
            cpi, x, bsize, cur_mv, mi_row, mi_col, NULL, 0, rate_mv, 0, 0);
      } else {
        const int_mv ref_mv = av1_get_ref_mv(x, 0);
        av1_set_mvcost(x, 0,
                       mbmi->ref_mv_idx + (this_mode == NEW_NEARMV ? 1 : 0));
        *rate_mv = av1_mv_bit_cost(&cur_mv[0].as_mv, &ref_mv.as_mv,
                                   x->nmvjointcost, x->mvcost, MV_COST_WEIGHT);
      }
    }
  } else {
    single_motion_search(cpi, x, bsize, mi_row, mi_col, 0, rate_mv);
    if (x->best_mv.as_int == INVALID_MV) return INT64_MAX;

    args->single_newmv[ref_mv_idx][refs[0]] = x->best_mv;
    args->single_newmv_rate[ref_mv_idx][refs[0]] = *rate_mv;
    args->single_newmv_valid[ref_mv_idx][refs[0]] = 1;

    cur_mv[0].as_int = x->best_mv.as_int;

#if USE_DISCOUNT_NEWMV_TEST
    // Estimate the rate implications of a new mv but discount this
    // under certain circumstances where we want to help initiate a weak
    // motion field, where the distortion gain for a single block may not
    // be enough to overcome the cost of a new mv.
    if (discount_newmv_test(cpi, x, this_mode, x->best_mv)) {
      *rate_mv = AOMMAX(*rate_mv / NEW_MV_DISCOUNT_FACTOR, 1);
    }
#endif
  }

  return 0;
}

static INLINE void swap_dst_buf(MACROBLOCKD *xd, const BUFFER_SET *dst_bufs[2],
                                int num_planes) {
  const BUFFER_SET *buf0 = dst_bufs[0];
  dst_bufs[0] = dst_bufs[1];
  dst_bufs[1] = buf0;
  restore_dst_buf(xd, *dst_bufs[0], num_planes);
}

static INLINE int get_switchable_rate(MACROBLOCK *const x,
                                      const InterpFilters filters,
                                      const int ctx[2]) {
  int inter_filter_cost;
  const InterpFilter filter0 = av1_extract_interp_filter(filters, 0);
  const InterpFilter filter1 = av1_extract_interp_filter(filters, 1);
  inter_filter_cost = x->switchable_interp_costs[ctx[0]][filter0];
  inter_filter_cost += x->switchable_interp_costs[ctx[1]][filter1];
  return SWITCHABLE_INTERP_RATE_FACTOR * inter_filter_cost;
}

// Calculate the rd cost of the given interpolation filter pair. Returns 1
// (and updates *rd, *switchable_rate and the cached rate/dist/skip data) if
// the candidate beats the current best; otherwise restores the previous
// filters and returns 0. skip_pred controls which plane predictions can be
// reused instead of being rebuilt for this candidate.
static INLINE int64_t interpolation_filter_rd(
    MACROBLOCK *const x, const AV1_COMP *const cpi, BLOCK_SIZE bsize,
    int mi_row, int mi_col, BUFFER_SET *const orig_dst, int64_t *const rd,
    int *const switchable_rate, int *const skip_txfm_sb,
    int64_t *const skip_sse_sb, const BUFFER_SET *dst_bufs[2], int filter_idx,
    const int switchable_ctx[2], const int skip_pred, int *rate,
    int64_t *dist) {
  const AV1_COMMON *cm = &cpi->common;
  const int num_planes = av1_num_planes(cm);
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  int tmp_rate[2], tmp_skip_sb[2] = { 1, 1 };
  int64_t tmp_dist[2], tmp_skip_sse[2] = { 0, 0 };

  const InterpFilters last_best = mbmi->interp_filters;
  mbmi->interp_filters = filter_sets[filter_idx];
  const int tmp_rs =
      get_switchable_rate(x, mbmi->interp_filters, switchable_ctx);

  assert(skip_pred != 2);
  assert((skip_pred >= 0) && (skip_pred <= cpi->default_interp_skip_flags));
  assert(rate[0] >= 0);
  assert(dist[0] >= 0);
  assert((skip_txfm_sb[0] == 0) || (skip_txfm_sb[0] == 1));
  assert(skip_sse_sb[0] >= 0);
  assert(rate[1] >= 0);
  assert(dist[1] >= 0);
  assert((skip_txfm_sb[1] == 0) || (skip_txfm_sb[1] == 1));
  assert(skip_sse_sb[1] >= 0);

  if (skip_pred != cpi->default_interp_skip_flags) {
    if (skip_pred != DEFAULT_LUMA_INTERP_SKIP_FLAG) {
      av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, orig_dst, bsize);
#if CONFIG_COLLECT_RD_STATS == 3
      RD_STATS rd_stats_y;
      select_tx_type_yrd(cpi, x, &rd_stats_y, bsize, mi_row, mi_col, INT64_MAX);
      PrintPredictionUnitStats(cpi, x, &rd_stats_y, bsize);
#endif  // CONFIG_COLLECT_RD_STATS == 3
      model_rd_sb_fn[MODELRD_TYPE_INTERP_FILTER](
          cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &tmp_rate[0], &tmp_dist[0],
          &tmp_skip_sb[0], &tmp_skip_sse[0], NULL, NULL, NULL);
      tmp_rate[1] = tmp_rate[0];
      tmp_dist[1] = tmp_dist[0];
    } else {
      // only luma MC is skipped
      tmp_rate[1] = rate[0];
      tmp_dist[1] = dist[0];
    }
    if (num_planes > 1) {
      for (int plane = 1; plane < num_planes; ++plane) {
        int tmp_rate_uv, tmp_skip_sb_uv;
        int64_t tmp_dist_uv, tmp_skip_sse_uv;
        int64_t tmp_rd = RDCOST(x->rdmult, tmp_rs + tmp_rate[1], tmp_dist[1]);
        if (tmp_rd >= *rd) {
          mbmi->interp_filters = last_best;
          return 0;
        }
        av1_build_inter_predictors_sbp(cm, xd, mi_row, mi_col, orig_dst, bsize,
                                       plane);
        model_rd_sb_fn[MODELRD_TYPE_INTERP_FILTER](
            cpi, bsize, x, xd, plane, plane, mi_row, mi_col, &tmp_rate_uv,
            &tmp_dist_uv, &tmp_skip_sb_uv, &tmp_skip_sse_uv, NULL, NULL, NULL);
        tmp_rate[1] =
            (int)AOMMIN(((int64_t)tmp_rate[1] + (int64_t)tmp_rate_uv), INT_MAX);
        tmp_dist[1] += tmp_dist_uv;
        tmp_skip_sb[1] &= tmp_skip_sb_uv;
        tmp_skip_sse[1] += tmp_skip_sse_uv;
      }
    }
  } else {
    // both luma and chroma MC are skipped
    tmp_rate[1] = rate[1];
    tmp_dist[1] = dist[1];
  }
  int64_t tmp_rd = RDCOST(x->rdmult, tmp_rs + tmp_rate[1], tmp_dist[1]);

  if (tmp_rd < *rd) {
    *rd = tmp_rd;
    *switchable_rate = tmp_rs;
    if (skip_pred != cpi->default_interp_skip_flags) {
      if (skip_pred == 0) {
        // Overwrite the data as the current filter is the best one
        tmp_skip_sb[1] = tmp_skip_sb[0] & tmp_skip_sb[1];
        tmp_skip_sse[1] = tmp_skip_sse[0] + tmp_skip_sse[1];
        memcpy(rate, tmp_rate, sizeof(*rate) * 2);
        memcpy(dist, tmp_dist, sizeof(*dist) * 2);
        memcpy(skip_txfm_sb, tmp_skip_sb, sizeof(*skip_txfm_sb) * 2);
        memcpy(skip_sse_sb, tmp_skip_sse, sizeof(*skip_sse_sb) * 2);
        // As luma MC data is computed, no need to recompute after the search
        x->recalc_luma_mc_data = 0;
      } else if (skip_pred == DEFAULT_LUMA_INTERP_SKIP_FLAG) {
        // As luma MC data is not computed, the update of luma data can be
        // skipped
        rate[1] = tmp_rate[1];
        dist[1] = tmp_dist[1];
        skip_txfm_sb[1] = skip_txfm_sb[0] & tmp_skip_sb[1];
        skip_sse_sb[1] = skip_sse_sb[0] + tmp_skip_sse[1];
        // As luma MC data is not recomputed and the current filter is the
        // best, toggle to indicate whether the buffer's luma MC data needs to
        // be recomputed after the search
        x->recalc_luma_mc_data ^= 1;
      }
      swap_dst_buf(xd, dst_bufs, num_planes);
    }
    return 1;
  }
  mbmi->interp_filters = last_best;
  return 0;
}

// Find the best rd filter in the horizontal direction
static INLINE int find_best_horiz_interp_filter_rd(
    MACROBLOCK *const x, const AV1_COMP *const cpi, BLOCK_SIZE bsize,
    int mi_row, int mi_col, BUFFER_SET *const orig_dst, int64_t *const rd,
    int *const switchable_rate, int *const skip_txfm_sb,
    int64_t *const skip_sse_sb, const BUFFER_SET *dst_bufs[2],
    const int switchable_ctx[2], const int skip_hor, int *rate, int64_t *dist,
    int best_dual_mode) {
  int i;
  const int bw = block_size_wide[bsize];
  assert(best_dual_mode == 0);
  if ((bw <= 4) && (skip_hor != cpi->default_interp_skip_flags)) {
    int skip_pred = cpi->default_interp_skip_flags;
    // Process the filters in reverse order to enable reusing the rate and
    // distortion (calculated during EIGHTTAP_REGULAR) for MULTITAP_SHARP
    for (i = (SWITCHABLE_FILTERS - 1); i >= 1; --i) {
      if (interpolation_filter_rd(x, cpi, bsize, mi_row, mi_col, orig_dst, rd,
                                  switchable_rate, skip_txfm_sb, skip_sse_sb,
                                  dst_bufs, i, switchable_ctx, skip_pred, rate,
                                  dist)) {
        best_dual_mode = i;
      }
      skip_pred = skip_hor;
    }
  } else {
    for (i = 1; i < SWITCHABLE_FILTERS; ++i) {
      if (interpolation_filter_rd(x, cpi, bsize, mi_row, mi_col, orig_dst, rd,
                                  switchable_rate, skip_txfm_sb, skip_sse_sb,
                                  dst_bufs, i, switchable_ctx, skip_hor, rate,
                                  dist)) {
        best_dual_mode = i;
      }
    }
  }
  return best_dual_mode;
}

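// Note: the DUAL_FILTER_SET_SIZE-entry filter set is laid out so that the
// horizontal search above scans indices 1..SWITCHABLE_FILTERS-1 while the
// vertical search below steps by SWITCHABLE_FILTERS from the best horizontal
// result, so each pass varies only one filter direction at a time.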
// Find the best rd filter in the vertical direction
static INLINE void find_best_vert_interp_filter_rd(
    MACROBLOCK *const x, const AV1_COMP *const cpi, BLOCK_SIZE bsize,
    int mi_row, int mi_col, BUFFER_SET *const orig_dst, int64_t *const rd,
    int *const switchable_rate, int *const skip_txfm_sb,
    int64_t *const skip_sse_sb, const BUFFER_SET *dst_bufs[2],
    const int switchable_ctx[2], const int skip_ver, int *rate, int64_t *dist,
    int best_dual_mode, int filter_set_size) {
  int i;
  const int bh = block_size_high[bsize];
  if ((bh <= 4) && (skip_ver != cpi->default_interp_skip_flags)) {
    int skip_pred = cpi->default_interp_skip_flags;
    // Process the filters in reverse order to enable reusing the rate and
    // distortion (calculated during EIGHTTAP_REGULAR) for MULTITAP_SHARP
    assert(filter_set_size == DUAL_FILTER_SET_SIZE);
    for (i = (filter_set_size - SWITCHABLE_FILTERS + best_dual_mode);
         i >= (best_dual_mode + SWITCHABLE_FILTERS); i -= SWITCHABLE_FILTERS) {
      interpolation_filter_rd(x, cpi, bsize, mi_row, mi_col, orig_dst, rd,
                              switchable_rate, skip_txfm_sb, skip_sse_sb,
                              dst_bufs, i, switchable_ctx, skip_pred, rate,
                              dist);
      skip_pred = skip_ver;
    }
  } else {
    for (i = best_dual_mode + SWITCHABLE_FILTERS; i < filter_set_size;
         i += SWITCHABLE_FILTERS) {
      interpolation_filter_rd(x, cpi, bsize, mi_row, mi_col, orig_dst, rd,
                              switchable_rate, skip_txfm_sb, skip_sse_sb,
                              dst_bufs, i, switchable_ctx, skip_ver, rate,
                              dist);
    }
  }
}

// Check whether a saved result matches this search
static INLINE int is_interp_filter_match(const INTERPOLATION_FILTER_STATS *st,
                                         MB_MODE_INFO *const mi) {
  for (int i = 0; i < 2; ++i) {
    if ((st->ref_frames[i] != mi->ref_frame[i]) ||
        (st->mv[i].as_int != mi->mv[i].as_int)) {
      return 0;
    }
  }
  if (has_second_ref(mi) && st->comp_type != mi->interinter_comp.type) return 0;
  return 1;
}

static INLINE int find_interp_filter_in_stats(MACROBLOCK *x,
                                              MB_MODE_INFO *const mbmi) {
  const int comp_idx = mbmi->compound_idx;
  const int offset = x->interp_filter_stats_idx[comp_idx];
  for (int j = 0; j < offset; ++j) {
    const INTERPOLATION_FILTER_STATS *st = &x->interp_filter_stats[comp_idx][j];
    if (is_interp_filter_match(st, mbmi)) {
      mbmi->interp_filters = st->filters;
      return j;
    }
  }
  return -1;  // no matching result found
}

static INLINE void save_interp_filter_search_stat(MACROBLOCK *x,
                                                  MB_MODE_INFO *const mbmi) {
  const int comp_idx = mbmi->compound_idx;
  const int offset = x->interp_filter_stats_idx[comp_idx];
  if (offset < MAX_INTERP_FILTER_STATS) {
    INTERPOLATION_FILTER_STATS stat = { mbmi->interp_filters,
                                        { mbmi->mv[0], mbmi->mv[1] },
                                        { mbmi->ref_frame[0],
                                          mbmi->ref_frame[1] },
                                        mbmi->interinter_comp.type };
    x->interp_filter_stats[comp_idx][offset] = stat;
    x->interp_filter_stats_idx[comp_idx]++;
  }
}

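// Top-level switchable interpolation filter search. The rd of the default
// EIGHTTAP_REGULAR pair is computed first; if the frame filter is not
// SWITCHABLE, a cached result matches, or no search is needed, that result
// stands. Otherwise the remaining filter pairs are evaluated, either with the
// fast dual-filter search (horizontal then vertical) or exhaustively, using
// per-direction skip flags to avoid redundant sub-pel MC when the MV has no
// sub-pel component in that direction.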
static int64_t interpolation_filter_search(
    MACROBLOCK *const x, const AV1_COMP *const cpi, BLOCK_SIZE bsize,
    int mi_row, int mi_col, const BUFFER_SET *const tmp_dst,
    BUFFER_SET *const orig_dst, InterpFilter (*const single_filter)[REF_FRAMES],
    int64_t *const rd, int *const switchable_rate, int *const skip_txfm_sb,
    int64_t *const skip_sse_sb, const int skip_build_pred,
    HandleInterModeArgs *args, int64_t ref_best_rd) {
  const AV1_COMMON *cm = &cpi->common;
  const int num_planes = av1_num_planes(cm);
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int need_search =
      av1_is_interp_needed(xd) && av1_is_interp_search_needed(xd);
  int i;
  // Index 0 corresponds to luma rd data and index 1 corresponds to cumulative
  // data of all planes
  int tmp_rate[2] = { 0, 0 };
  int64_t tmp_dist[2] = { 0, 0 };
  int best_skip_txfm_sb[2] = { 1, 1 };
  int64_t best_skip_sse_sb[2] = { 0, 0 };
  const int ref_frame = xd->mi[0]->ref_frame[0];

  (void)single_filter;
  int match_found = -1;
  const InterpFilter assign_filter = cm->interp_filter;
  if (cpi->sf.skip_repeat_interpolation_filter_search && need_search) {
    match_found = find_interp_filter_in_stats(x, mbmi);
  }
  if (!need_search || match_found == -1) {
    set_default_interp_filters(mbmi, assign_filter);
  }
  int switchable_ctx[2];
  switchable_ctx[0] = av1_get_pred_context_switchable_interp(xd, 0);
  switchable_ctx[1] = av1_get_pred_context_switchable_interp(xd, 1);
  *switchable_rate =
      get_switchable_rate(x, mbmi->interp_filters, switchable_ctx);
  if (!skip_build_pred)
    av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, orig_dst, bsize);

#if CONFIG_COLLECT_RD_STATS == 3
  RD_STATS rd_stats_y;
  select_tx_type_yrd(cpi, x, &rd_stats_y, bsize, mi_row, mi_col, INT64_MAX);
  PrintPredictionUnitStats(cpi, x, &rd_stats_y, bsize);
#endif  // CONFIG_COLLECT_RD_STATS == 3
  model_rd_sb_fn[MODELRD_TYPE_INTERP_FILTER](
      cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &tmp_rate[0], &tmp_dist[0],
      &best_skip_txfm_sb[0], &best_skip_sse_sb[0], NULL, NULL, NULL);
  if (num_planes > 1)
    model_rd_sb_fn[MODELRD_TYPE_INTERP_FILTER](
        cpi, bsize, x, xd, 1, num_planes - 1, mi_row, mi_col, &tmp_rate[1],
        &tmp_dist[1], &best_skip_txfm_sb[1], &best_skip_sse_sb[1], NULL, NULL,
        NULL);
  tmp_rate[1] =
      (int)AOMMIN((int64_t)tmp_rate[0] + (int64_t)tmp_rate[1], INT_MAX);
  assert(tmp_rate[1] >= 0);
  tmp_dist[1] = tmp_dist[0] + tmp_dist[1];
  best_skip_txfm_sb[1] = best_skip_txfm_sb[0] & best_skip_txfm_sb[1];
  best_skip_sse_sb[1] = best_skip_sse_sb[0] + best_skip_sse_sb[1];
  *rd = RDCOST(x->rdmult, (*switchable_rate + tmp_rate[1]), tmp_dist[1]);
  *skip_txfm_sb = best_skip_txfm_sb[1];
  *skip_sse_sb = best_skip_sse_sb[1];
  x->pred_sse[ref_frame] = (unsigned int)(best_skip_sse_sb[0] >> 4);

  if (assign_filter != SWITCHABLE || match_found != -1) {
    return 0;
  }
  if (!need_search) {
    assert(mbmi->interp_filters ==
           av1_broadcast_interp_filter(EIGHTTAP_REGULAR));
    return 0;
  }
  if (args->modelled_rd != NULL) {
    if (has_second_ref(mbmi)) {
      const int ref_mv_idx = mbmi->ref_mv_idx;
      int refs[2] = { mbmi->ref_frame[0],
                      (mbmi->ref_frame[1] < 0 ? 0 : mbmi->ref_frame[1]) };
      const int mode0 = compound_ref0_mode(mbmi->mode);
      const int mode1 = compound_ref1_mode(mbmi->mode);
      const int64_t mrd = AOMMIN(args->modelled_rd[mode0][ref_mv_idx][refs[0]],
                                 args->modelled_rd[mode1][ref_mv_idx][refs[1]]);
      if ((*rd >> 1) > mrd && ref_best_rd < INT64_MAX) {
        return INT64_MAX;
      }
    }
  }

  x->recalc_luma_mc_data = 0;
  // skip_flag=xx (in binary form)
  // Setting the 0th bit corresponds to skipping luma MC and setting the 1st
  // bit corresponds to skipping chroma MC. skip_flag=0 corresponds to "Don't
  // skip luma and chroma MC", skip_flag=1 corresponds to "Skip luma MC only",
  // skip_flag=2 is not a valid case, and skip_flag=3 corresponds to "Skip
  // both luma and chroma MC".
  int skip_hor = cpi->default_interp_skip_flags;
  int skip_ver = cpi->default_interp_skip_flags;
  const int is_compound = has_second_ref(mbmi);
  assert(is_intrabc_block(mbmi) == 0);
  for (int j = 0; j < 1 + is_compound; ++j) {
    const RefBuffer *ref_buf = &cm->frame_refs[mbmi->ref_frame[j] - LAST_FRAME];
    const struct scale_factors *const sf = &ref_buf->sf;
    // TODO(any): Refine skip flag calculation considering scaling
    if (av1_is_scaled(sf)) {
      skip_hor = 0;
      skip_ver = 0;
      break;
    }
    const MV mv = mbmi->mv[j].as_mv;
    int skip_hor_plane = 0;
    int skip_ver_plane = 0;
    for (int k = 0; k < AOMMAX(1, (num_planes - 1)); ++k) {
      struct macroblockd_plane *const pd = &xd->plane[k];
      const int bw = pd->width;
      const int bh = pd->height;
      const MV mv_q4 = clamp_mv_to_umv_border_sb(
          xd, &mv, bw, bh, pd->subsampling_x, pd->subsampling_y);
      const int sub_x = (mv_q4.col & SUBPEL_MASK) << SCALE_EXTRA_BITS;
      const int sub_y = (mv_q4.row & SUBPEL_MASK) << SCALE_EXTRA_BITS;
      skip_hor_plane |= ((sub_x == 0) << k);
      skip_ver_plane |= ((sub_y == 0) << k);
    }
    skip_hor = skip_hor & skip_hor_plane;
    skip_ver = skip_ver & skip_ver_plane;
    // It is not valid that "luma MV is sub-pel, whereas chroma MV is not"
    assert(skip_hor != 2);
    assert(skip_ver != 2);
  }
  // When the compound prediction type is compound segment wedge, luma MC and
  // chroma MC need to go hand in hand as the mask generated during luma MC is
  // required for chroma MC. If skip_hor = 0 and skip_ver = 1, the mask used
  // for chroma MC during the vertical filter decision may be incorrect as the
  // temporary MC evaluation overwrites it. Set skip_ver to 0 in this case so
  // that the mask is populated during luma MC.
  if (is_compound && mbmi->compound_idx == 1 &&
      mbmi->interinter_comp.type == COMPOUND_DIFFWTD) {
    assert(mbmi->comp_group_idx == 1);
    if (skip_hor == 0 && skip_ver == 1) skip_ver = 0;
  }
  // do interp_filter search
  const int filter_set_size = DUAL_FILTER_SET_SIZE;
  restore_dst_buf(xd, *tmp_dst, num_planes);
  const BUFFER_SET *dst_bufs[2] = { tmp_dst, orig_dst };
  if (cpi->sf.use_fast_interpolation_filter_search &&
      cm->seq_params.enable_dual_filter) {
    // default to (R,R): EIGHTTAP_REGULARxEIGHTTAP_REGULAR
    int best_dual_mode = 0;
    // Find best of {R}x{R,Sm,Sh}
    // EIGHTTAP_REGULAR mode is calculated beforehand
    best_dual_mode = find_best_horiz_interp_filter_rd(
        x, cpi, bsize, mi_row, mi_col, orig_dst, rd, switchable_rate,
        best_skip_txfm_sb, best_skip_sse_sb, dst_bufs, switchable_ctx, skip_hor,
        tmp_rate, tmp_dist, best_dual_mode);

    // From the best horizontal EIGHTTAP_REGULAR mode, check vertical modes
    find_best_vert_interp_filter_rd(
        x, cpi, bsize, mi_row, mi_col, orig_dst, rd, switchable_rate,
        best_skip_txfm_sb, best_skip_sse_sb, dst_bufs, switchable_ctx, skip_ver,
        tmp_rate, tmp_dist, best_dual_mode, filter_set_size);
  } else {
    // EIGHTTAP_REGULAR mode is calculated beforehand
    for (i = 1; i < filter_set_size; ++i) {
      if (cm->seq_params.enable_dual_filter == 0) {
        const int16_t filter_y = filter_sets[i] & 0xffff;
        const int16_t filter_x = filter_sets[i] >> 16;
        if (filter_x != filter_y) continue;
      }
      interpolation_filter_rd(x, cpi, bsize, mi_row, mi_col, orig_dst, rd,
                              switchable_rate, best_skip_txfm_sb,
                              best_skip_sse_sb, dst_bufs, i, switchable_ctx, 0,
                              tmp_rate, tmp_dist);
      assert(x->recalc_luma_mc_data == 0);
    }
  }
  swap_dst_buf(xd, dst_bufs, num_planes);
  // Recompute final MC data if required
  if (x->recalc_luma_mc_data == 1) {
    // Recomputing the final luma MC data is required only if it was skipped
    // in either of the directions. The condition below is necessary, but not
    // sufficient.
    assert((skip_hor == 1) || (skip_ver == 1));
    av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, orig_dst, bsize);
  }
  *skip_txfm_sb = best_skip_txfm_sb[1];
  *skip_sse_sb = best_skip_sse_sb[1];
  x->pred_sse[ref_frame] = (unsigned int)(best_skip_sse_sb[0] >> 4);

  // save search results
  if (cpi->sf.skip_repeat_interpolation_filter_search) {
    assert(match_found == -1);
    save_interp_filter_search_stat(x, mbmi);
  }
  return 0;
}

static int txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
                       int mi_row, int mi_col, RD_STATS *rd_stats,
                       RD_STATS *rd_stats_y, RD_STATS *rd_stats_uv,
                       int mode_rate, int64_t ref_best_rd) {
  /*
   * This function combines the y and uv planes' transform search processes,
   * once the prediction is generated. It first does subtraction to obtain
   * the prediction error. Then it calls
   * select_tx_type_yrd/super_block_yrd and inter_block_uvrd sequentially and
   * handles the early terminations that happen in those functions. At the
   * end, it computes rd_stats/_y/_uv accordingly.
   */
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  int skip_txfm_sb = 0;
  const int num_planes = av1_num_planes(cm);
  const int ref_frame_1 = mbmi->ref_frame[1];
  const int64_t mode_rd = RDCOST(x->rdmult, mode_rate, 0);
  const int64_t rd_thresh =
      ref_best_rd == INT64_MAX ? INT64_MAX : ref_best_rd - mode_rd;
  const int skip_ctx = av1_get_skip_context(xd);
  const int64_t min_header_rate =
      mode_rate + AOMMIN(x->skip_cost[skip_ctx][0], x->skip_cost[skip_ctx][1]);
  // Account for the minimum skip and non_skip rd.
  // Eventually one of the two will be added to mode_rate
  const int64_t min_header_rd_possible = RDCOST(x->rdmult, min_header_rate, 0);

  if (min_header_rd_possible > ref_best_rd) {
    av1_invalid_rd_stats(rd_stats_y);
    av1_invalid_rd_stats(rd_stats);
    return 0;
  }

  av1_init_rd_stats(rd_stats);
  av1_init_rd_stats(rd_stats_y);
  av1_init_rd_stats(rd_stats_uv);
  rd_stats->rate = mode_rate;

  if (!cpi->common.all_lossless)
    check_block_skip(cpi, bsize, x, xd, 0, num_planes - 1, &skip_txfm_sb);
  if (!skip_txfm_sb) {
    int64_t non_skip_rdcosty = INT64_MAX;
    int64_t skip_rdcosty = INT64_MAX;
    int64_t min_rdcosty = INT64_MAX;
    int is_cost_valid_uv = 0;

    // cost and distortion
    av1_subtract_plane(x, bsize, 0);
    if (cm->tx_mode == TX_MODE_SELECT && !xd->lossless[mbmi->segment_id]) {
      // Motion mode
      select_tx_type_yrd(cpi, x, rd_stats_y, bsize, mi_row, mi_col, rd_thresh);
#if CONFIG_COLLECT_RD_STATS == 2
      PrintPredictionUnitStats(cpi, x, rd_stats_y, bsize);
#endif  // CONFIG_COLLECT_RD_STATS == 2
    } else {
      super_block_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
      memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
      for (int i = 0; i < xd->n4_h * xd->n4_w; ++i)
        set_blk_skip(x, 0, i, rd_stats_y->skip);
    }

    if (rd_stats_y->rate == INT_MAX) {
      av1_invalid_rd_stats(rd_stats);
      // TODO(angiebird): check if we need this
      // restore_dst_buf(xd, *orig_dst, num_planes);
      mbmi->ref_frame[1] = ref_frame_1;
      return 0;
    }

    av1_merge_rd_stats(rd_stats, rd_stats_y);

    non_skip_rdcosty = RDCOST(
        x->rdmult, rd_stats->rate + x->skip_cost[skip_ctx][0], rd_stats->dist);
    skip_rdcosty =
        RDCOST(x->rdmult, mode_rate + x->skip_cost[skip_ctx][1], rd_stats->sse);
    min_rdcosty = AOMMIN(non_skip_rdcosty, skip_rdcosty);

    if (min_rdcosty > ref_best_rd) {
      int64_t tokenonly_rdy =
          AOMMIN(RDCOST(x->rdmult, rd_stats_y->rate, rd_stats_y->dist),
                 RDCOST(x->rdmult, 0, rd_stats_y->sse));
      // Invalidate rd_stats_y to skip the rest of the motion modes search
      if (tokenonly_rdy - (tokenonly_rdy >> cpi->sf.adaptive_txb_search_level) >
          rd_thresh)
        av1_invalid_rd_stats(rd_stats_y);
      mbmi->ref_frame[1] = ref_frame_1;
      return 0;
    }

    if (num_planes > 1) {
      /* clang-format off */
      is_cost_valid_uv =
          inter_block_uvrd(cpi, x, rd_stats_uv, bsize,
                           ref_best_rd - non_skip_rdcosty,
                           ref_best_rd - skip_rdcosty, FTXS_NONE);
      if (!is_cost_valid_uv) {
        mbmi->ref_frame[1] = ref_frame_1;
        return 0;
      }
      /* clang-format on */
      av1_merge_rd_stats(rd_stats, rd_stats_uv);
    } else {
      av1_init_rd_stats(rd_stats_uv);
    }
    if (rd_stats->skip) {
      rd_stats->rate -= rd_stats_uv->rate + rd_stats_y->rate;
      rd_stats_y->rate = 0;
      rd_stats_uv->rate = 0;
      rd_stats->rate += x->skip_cost[skip_ctx][1];
      // Here mbmi->skip temporarily plays the role of this_skip2
      mbmi->skip = 0;

      int64_t tmprd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
      if (tmprd > ref_best_rd) {
        mbmi->ref_frame[1] = ref_frame_1;
        return 0;
      }
    } else if (!xd->lossless[mbmi->segment_id] &&
               (RDCOST(x->rdmult,
                       rd_stats_y->rate + rd_stats_uv->rate +
                           x->skip_cost[skip_ctx][0],
                       rd_stats->dist) >=
                RDCOST(x->rdmult, x->skip_cost[skip_ctx][1], rd_stats->sse))) {
      rd_stats->rate -= rd_stats_uv->rate + rd_stats_y->rate;
      rd_stats->rate += x->skip_cost[skip_ctx][1];
      rd_stats->dist = rd_stats->sse;
      rd_stats_y->rate = 0;
      rd_stats_uv->rate = 0;
      mbmi->skip = 1;
    } else {
      rd_stats->rate += x->skip_cost[skip_ctx][0];
      mbmi->skip = 0;
    }
  } else {
    x->skip = 1;
    mbmi->tx_size = tx_size_from_tx_mode(bsize, cm->tx_mode);
    // The cost of the skip bit needs to be added.
    mbmi->skip = 0;
    rd_stats->rate += x->skip_cost[skip_ctx][1];

    rd_stats->dist = 0;
    rd_stats->sse = 0;
    rd_stats_y->rate = 0;
    rd_stats_uv->rate = 0;
    rd_stats->skip = 1;
    int64_t tmprd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
    if (tmprd > ref_best_rd) {
      mbmi->ref_frame[1] = ref_frame_1;
      return 0;
    }
  }
  return 1;
}

static int handle_inter_intra_mode(const AV1_COMP *const cpi,
                                   MACROBLOCK *const x, BLOCK_SIZE bsize,
                                   int mi_row, int mi_col, MB_MODE_INFO *mbmi,
                                   HandleInterModeArgs *args,
                                   int64_t ref_best_rd, int *rate_mv,
                                   int *tmp_rate2, BUFFER_SET *orig_dst) {
  const AV1_COMMON *const cm = &cpi->common;
  const int num_planes = av1_num_planes(cm);
  MACROBLOCKD *xd = &x->e_mbd;

  INTERINTRA_MODE best_interintra_mode = II_DC_PRED;
  int64_t rd, best_interintra_rd = INT64_MAX;
  int rmode, rate_sum;
  int64_t dist_sum;
  int tmp_rate_mv = 0;
  int tmp_skip_txfm_sb;
  int bw = block_size_wide[bsize];
  int64_t tmp_skip_sse_sb;
  DECLARE_ALIGNED(16, uint8_t, tmp_buf_[2 * MAX_INTERINTRA_SB_SQUARE]);
  DECLARE_ALIGNED(16, uint8_t, intrapred_[2 * MAX_INTERINTRA_SB_SQUARE]);
  uint8_t *tmp_buf = get_buf_by_bd(xd, tmp_buf_);
  uint8_t *intrapred = get_buf_by_bd(xd, intrapred_);
  const int *const interintra_mode_cost =
      x->interintra_mode_cost[size_group_lookup[bsize]];
  const int_mv mv0 = mbmi->mv[0];
  const int is_wedge_used = is_interintra_wedge_used(bsize);
  int rwedge = is_wedge_used ? x->wedge_interintra_cost[bsize][0] : 0;
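  // Build the single-reference (luma-only) inter prediction into tmp_buf
  // first; the intra predictor is then blended with it for each candidate
  // interintra mode below.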
  mbmi->ref_frame[1] = NONE_FRAME;
  xd->plane[0].dst.buf = tmp_buf;
  xd->plane[0].dst.stride = bw;
  av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, NULL, bsize);

  restore_dst_buf(xd, *orig_dst, num_planes);
  mbmi->ref_frame[1] = INTRA_FRAME;
  mbmi->use_wedge_interintra = 0;
  best_interintra_mode = args->inter_intra_mode[mbmi->ref_frame[0]];
  int j = 0;
  if (cpi->sf.reuse_inter_intra_mode == 0 ||
      best_interintra_mode == INTERINTRA_MODES) {
    for (j = 0; j < INTERINTRA_MODES; ++j) {
      mbmi->interintra_mode = (INTERINTRA_MODE)j;
      rmode = interintra_mode_cost[mbmi->interintra_mode];
      av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                                intrapred, bw);
      av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
      model_rd_sb_fn[MODELRD_TYPE_INTERINTRA](
          cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &rate_sum, &dist_sum,
          &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
      rd = RDCOST(x->rdmult, tmp_rate_mv + rate_sum + rmode, dist_sum);
      if (rd < best_interintra_rd) {
        best_interintra_rd = rd;
        best_interintra_mode = mbmi->interintra_mode;
      }
    }
    args->inter_intra_mode[mbmi->ref_frame[0]] = best_interintra_mode;
  }
  if (j == 0 || best_interintra_mode != II_SMOOTH_PRED) {
    mbmi->interintra_mode = best_interintra_mode;
    rmode = interintra_mode_cost[mbmi->interintra_mode];
    av1_build_intra_predictors_for_interintra(cm, xd, bsize, 0, orig_dst,
                                              intrapred, bw);
    av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
  }
  rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
                           &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
  if (rd != INT64_MAX)
    rd = RDCOST(x->rdmult, *rate_mv + rmode + rate_sum + rwedge, dist_sum);
  best_interintra_rd = rd;
  if (ref_best_rd < INT64_MAX && (best_interintra_rd >> 1) > ref_best_rd) {
    return -1;
  }
  if (is_wedge_used) {
    int64_t best_interintra_rd_nowedge = rd;
    int64_t best_interintra_rd_wedge = INT64_MAX;
    int_mv tmp_mv;
    // Disable wedge search if source variance is small
    if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh) {
      mbmi->use_wedge_interintra = 1;

      rwedge = av1_cost_literal(get_interintra_wedge_bits(bsize)) +
               x->wedge_interintra_cost[bsize][1];

      best_interintra_rd_wedge =
          pick_interintra_wedge(cpi, x, bsize, intrapred_, tmp_buf_);

      best_interintra_rd_wedge +=
          RDCOST(x->rdmult, rmode + *rate_mv + rwedge, 0);
      rd = INT64_MAX;
      // Refine motion vector.
      if (have_newmv_in_inter_mode(mbmi->mode)) {
        // get negative of mask
        const uint8_t *mask = av1_get_contiguous_soft_mask(
            mbmi->interintra_wedge_index, 1, bsize);
        tmp_mv = mbmi->mv[0];
        compound_single_motion_search(cpi, x, bsize, &tmp_mv.as_mv, mi_row,
                                      mi_col, intrapred, mask, bw, &tmp_rate_mv,
                                      0);
        if (mbmi->mv[0].as_int != tmp_mv.as_int) {
          mbmi->mv[0].as_int = tmp_mv.as_int;
          av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, orig_dst,
                                         bsize);
          model_rd_sb_fn[MODELRD_TYPE_MASKED_COMPOUND](
              cpi, bsize, x, xd, 0, 0, mi_row, mi_col, &rate_sum, &dist_sum,
              &tmp_skip_txfm_sb, &tmp_skip_sse_sb, NULL, NULL, NULL);
          rd = RDCOST(x->rdmult, tmp_rate_mv + rmode + rate_sum + rwedge,
                      dist_sum);
        }
      }
      if (rd >= best_interintra_rd_wedge) {
        tmp_mv.as_int = mv0.as_int;
        tmp_rate_mv = *rate_mv;
        av1_combine_interintra(xd, bsize, 0, tmp_buf, bw, intrapred, bw);
      }
      // Evaluate closer to true rd
      rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
                               &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
      if (rd != INT64_MAX)
        rd = RDCOST(x->rdmult, rmode + tmp_rate_mv + rwedge + rate_sum,
                    dist_sum);
      best_interintra_rd_wedge = rd;
      if (best_interintra_rd_wedge < best_interintra_rd_nowedge) {
        mbmi->use_wedge_interintra = 1;
        mbmi->mv[0].as_int = tmp_mv.as_int;
        *tmp_rate2 += tmp_rate_mv - *rate_mv;
        *rate_mv = tmp_rate_mv;
      } else {
        mbmi->use_wedge_interintra = 0;
        mbmi->mv[0].as_int = mv0.as_int;
        av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, orig_dst, bsize);
      }
    } else {
      mbmi->use_wedge_interintra = 0;
    }
  }  // if (is_interintra_wedge_used(bsize))
  if (num_planes > 1) {
    av1_build_inter_predictors_sbuv(cm, xd, mi_row, mi_col, orig_dst, bsize);
  }
  return 0;
}

// TODO(afergs): Refactor the MBMI references in here - there's four
// TODO(afergs): Refactor optional args - add them to a struct or remove
static int64_t motion_mode_rd(const AV1_COMP *const cpi, MACROBLOCK *const x,
                              BLOCK_SIZE bsize, RD_STATS *rd_stats,
                              RD_STATS *rd_stats_y, RD_STATS *rd_stats_uv,
                              int *disable_skip, int mi_row, int mi_col,
                              HandleInterModeArgs *const args,
                              int64_t ref_best_rd, const int *refs,
                              int *rate_mv, BUFFER_SET *orig_dst
#if CONFIG_COLLECT_INTER_MODE_RD_STATS
                              ,
                              TileDataEnc *tile_data, int64_t *best_est_rd,
                              int do_tx_search, InterModesInfo *inter_modes_info
#endif
) {
  const AV1_COMMON *const cm = &cpi->common;
  const int num_planes = av1_num_planes(cm);
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  const int is_comp_pred = has_second_ref(mbmi);
  const PREDICTION_MODE this_mode = mbmi->mode;
  const int rate2_nocoeff = rd_stats->rate;
  int best_xskip, best_disable_skip = 0;
  RD_STATS best_rd_stats, best_rd_stats_y, best_rd_stats_uv;
  MB_MODE_INFO base_mbmi, best_mbmi;
  uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
  const int rate_mv0 = *rate_mv;

  int interintra_allowed = cm->seq_params.enable_interintra_compound &&
                           is_interintra_allowed(mbmi) && mbmi->compound_idx;
  int pts0[SAMPLES_ARRAY_SIZE], pts_inref0[SAMPLES_ARRAY_SIZE];

  assert(mbmi->ref_frame[1] != INTRA_FRAME);
  const MV_REFERENCE_FRAME ref_frame_1 = mbmi->ref_frame[1];
  av1_invalid_rd_stats(&best_rd_stats);
  aom_clear_system_state();
  mbmi->num_proj_ref = 1;  // assume num_proj_ref >=1
  MOTION_MODE last_motion_mode_allowed = SIMPLE_TRANSLATION;
  if (cm->switchable_motion_mode) {
    last_motion_mode_allowed = motion_mode_allowed(xd->global_motion, xd, mbmi,
                                                   cm->allow_warped_motion);
  }
  if (last_motion_mode_allowed == WARPED_CAUSAL) {
    mbmi->num_proj_ref = findSamples(cm, xd, mi_row, mi_col, pts0, pts_inref0);
  }
  int total_samples = mbmi->num_proj_ref;
  if (total_samples == 0) {
    last_motion_mode_allowed = OBMC_CAUSAL;
  }
  base_mbmi = *mbmi;

  const int switchable_rate =
      av1_is_interp_needed(xd) ? av1_get_switchable_rate(cm, x, xd) : 0;
  int64_t best_rd = INT64_MAX;
  int best_rate_mv = rate_mv0;
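  // mode_index walks through the allowed motion modes (SIMPLE_TRANSLATION,
  // OBMC_CAUSAL, WARPED_CAUSAL); when interintra is allowed, one extra index
  // past last_motion_mode_allowed evaluates the interintra mode.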
  for (int mode_index = (int)SIMPLE_TRANSLATION;
       mode_index <= (int)last_motion_mode_allowed + interintra_allowed;
       mode_index++) {
    if (args->skip_motion_mode && mode_index) continue;
    int64_t tmp_rd = INT64_MAX;
    int tmp_rate2 = rate2_nocoeff;
    int is_interintra_mode = mode_index > (int)last_motion_mode_allowed;
    int skip_txfm_sb = 0;
    int tmp_rate_mv = rate_mv0;

    *mbmi = base_mbmi;
    if (is_interintra_mode) {
      mbmi->motion_mode = SIMPLE_TRANSLATION;
    } else {
      mbmi->motion_mode = (MOTION_MODE)mode_index;
      assert(mbmi->ref_frame[1] != INTRA_FRAME);
    }

    if (mbmi->motion_mode == SIMPLE_TRANSLATION && !is_interintra_mode) {
      // SIMPLE_TRANSLATION mode: no need to recalculate.
      // The prediction is calculated before motion_mode_rd() is called in
      // handle_inter_mode()
    } else if (mbmi->motion_mode == OBMC_CAUSAL) {
      uint32_t cur_mv = mbmi->mv[0].as_int;
      assert(!is_comp_pred);
      if (have_newmv_in_inter_mode(this_mode)) {
        single_motion_search(cpi, x, bsize, mi_row, mi_col, 0, &tmp_rate_mv);
        mbmi->mv[0].as_int = x->best_mv.as_int;
#if USE_DISCOUNT_NEWMV_TEST
        if (discount_newmv_test(cpi, x, this_mode, mbmi->mv[0])) {
          tmp_rate_mv = AOMMAX((tmp_rate_mv / NEW_MV_DISCOUNT_FACTOR), 1);
        }
#endif
        tmp_rate2 = rate2_nocoeff - rate_mv0 + tmp_rate_mv;
      }
      if (mbmi->mv[0].as_int != cur_mv) {
        av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, orig_dst, bsize);
      }
      av1_build_obmc_inter_prediction(
          cm, xd, mi_row, mi_col, args->above_pred_buf, args->above_pred_stride,
          args->left_pred_buf, args->left_pred_stride);
    } else if (mbmi->motion_mode == WARPED_CAUSAL) {
      int pts[SAMPLES_ARRAY_SIZE], pts_inref[SAMPLES_ARRAY_SIZE];
      mbmi->motion_mode = WARPED_CAUSAL;
      mbmi->wm_params.wmtype = DEFAULT_WMTYPE;
      mbmi->interp_filters = av1_broadcast_interp_filter(
          av1_unswitchable_filter(cm->interp_filter));

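      // Each projection sample is an (x, y) pair, hence the factor of 2 in
      // the copy sizes below.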
      memcpy(pts, pts0, total_samples * 2 * sizeof(*pts0));
      memcpy(pts_inref, pts_inref0, total_samples * 2 * sizeof(*pts_inref0));
      // Select the samples according to motion vector difference
      if (mbmi->num_proj_ref > 1) {
        mbmi->num_proj_ref = selectSamples(&mbmi->mv[0].as_mv, pts, pts_inref,
                                           mbmi->num_proj_ref, bsize);
      }

      if (!find_projection(mbmi->num_proj_ref, pts, pts_inref, bsize,
                           mbmi->mv[0].as_mv.row, mbmi->mv[0].as_mv.col,
                           &mbmi->wm_params, mi_row, mi_col)) {
        // Refine MV for NEWMV mode
        assert(!is_comp_pred);
        if (have_newmv_in_inter_mode(this_mode)) {
          const int_mv mv0 = mbmi->mv[0];
          const WarpedMotionParams wm_params0 = mbmi->wm_params;
          int num_proj_ref0 = mbmi->num_proj_ref;

          // Refine MV in a small range.
          av1_refine_warped_mv(cpi, x, bsize, mi_row, mi_col, pts0, pts_inref0,
                               total_samples);

          // Keep the refined MV and WM parameters.
          if (mv0.as_int != mbmi->mv[0].as_int) {
            const int ref = refs[0];
            const int_mv ref_mv = av1_get_ref_mv(x, 0);
            tmp_rate_mv =
                av1_mv_bit_cost(&mbmi->mv[0].as_mv, &ref_mv.as_mv,
                                x->nmvjointcost, x->mvcost, MV_COST_WEIGHT);

            if (cpi->sf.adaptive_motion_search)
              x->pred_mv[ref] = mbmi->mv[0].as_mv;

#if USE_DISCOUNT_NEWMV_TEST
            if (discount_newmv_test(cpi, x, this_mode, mbmi->mv[0])) {
              tmp_rate_mv = AOMMAX((tmp_rate_mv / NEW_MV_DISCOUNT_FACTOR), 1);
            }
#endif
            tmp_rate2 = rate2_nocoeff - rate_mv0 + tmp_rate_mv;
          } else {
            // Restore the old MV and WM parameters.
            mbmi->mv[0] = mv0;
            mbmi->wm_params = wm_params0;
            mbmi->num_proj_ref = num_proj_ref0;
          }
        }

        av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, NULL, bsize);
      } else {
        continue;
      }
    } else if (is_interintra_mode) {
      const int ret = handle_inter_intra_mode(
          cpi, x, bsize, mi_row, mi_col, mbmi, args, ref_best_rd, &tmp_rate_mv,
          &tmp_rate2, orig_dst);
      if (ret < 0) continue;
    }

    if (!cpi->common.all_lossless)
      check_block_skip(cpi, bsize, x, xd, 0, num_planes - 1, &skip_txfm_sb);

    x->skip = 0;

    rd_stats->dist = 0;
    rd_stats->sse = 0;
    rd_stats->skip = 1;
    rd_stats->rate = tmp_rate2;
    if (mbmi->motion_mode != WARPED_CAUSAL) rd_stats->rate += switchable_rate;
    if (interintra_allowed) {
      rd_stats->rate += x->interintra_cost[size_group_lookup[bsize]]
                                          [mbmi->ref_frame[1] == INTRA_FRAME];
      if (mbmi->ref_frame[1] == INTRA_FRAME) {
        rd_stats->rate += x->interintra_mode_cost[size_group_lookup[bsize]]
                                                 [mbmi->interintra_mode];
        if (is_interintra_wedge_used(bsize)) {
          rd_stats->rate +=
              x->wedge_interintra_cost[bsize][mbmi->use_wedge_interintra];
          if (mbmi->use_wedge_interintra) {
            rd_stats->rate +=
                av1_cost_literal(get_interintra_wedge_bits(bsize));
          }
        }
      }
    }
    if ((last_motion_mode_allowed > SIMPLE_TRANSLATION) &&
        (mbmi->ref_frame[1] != INTRA_FRAME)) {
      if (last_motion_mode_allowed == WARPED_CAUSAL) {
        rd_stats->rate += x->motion_mode_cost[bsize][mbmi->motion_mode];
      } else {
        rd_stats->rate += x->motion_mode_cost1[bsize][mbmi->motion_mode];
      }
    }

    if (!skip_txfm_sb) {
#if CONFIG_COLLECT_INTER_MODE_RD_STATS
      int64_t est_rd = 0;
      int est_skip = 0;
      if (cpi->sf.inter_mode_rd_model_estimation && cm->tile_cols == 1 &&
          cm->tile_rows == 1) {
        InterModeRdModel *md = &tile_data->inter_mode_rd_models[mbmi->sb_type];
        if (md->ready) {
          const int64_t curr_sse = get_sse(cpi, x);
          est_rd = get_est_rd(tile_data, mbmi->sb_type, x->rdmult, curr_sse,
                              rd_stats->rate);
          est_skip = est_rd * 0.8 > *best_est_rd;
          if (est_skip) {
            mbmi->ref_frame[1] = ref_frame_1;
            continue;
          } else {
            if (est_rd < *best_est_rd) {
              *best_est_rd = est_rd;
            }
          }
        }
      }
#endif  // CONFIG_COLLECT_INTER_MODE_RD_STATS
    }

#if CONFIG_COLLECT_INTER_MODE_RD_STATS
    if (!do_tx_search) {
      const int64_t curr_sse = get_sse(cpi, x);
      int est_residue_cost = 0;
      int64_t est_dist = 0;
      const int has_est_rd = get_est_rate_dist(tile_data, bsize, curr_sse,
                                               &est_residue_cost, &est_dist);
      (void)has_est_rd;
      assert(has_est_rd);
      const int mode_rate = rd_stats->rate;
      rd_stats->rate += est_residue_cost;
      rd_stats->dist = est_dist;
      rd_stats->rdcost = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
      if (cm->reference_mode == SINGLE_REFERENCE) {
        if (!is_comp_pred) {
          inter_modes_info_push(inter_modes_info, mode_rate, curr_sse,
                                rd_stats->rdcost, mbmi);
        }
      } else {
        inter_modes_info_push(inter_modes_info, mode_rate, curr_sse,
                              rd_stats->rdcost, mbmi);
      }
    } else {
#endif
      int mode_rate = rd_stats->rate;
      if (!txfm_search(cpi, x, bsize, mi_row, mi_col, rd_stats, rd_stats_y,
                       rd_stats_uv, mode_rate, ref_best_rd)) {
        if (rd_stats_y->rate == INT_MAX && mode_index == 0) {
          return INT64_MAX;
        }
        continue;
      }
      if (!skip_txfm_sb) {
        const int64_t curr_rd =
            RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
        if (curr_rd < ref_best_rd) {
          ref_best_rd = curr_rd;
        }
        *disable_skip = 0;
#if CONFIG_COLLECT_INTER_MODE_RD_STATS
        if (cpi->sf.inter_mode_rd_model_estimation) {
          const int skip_ctx = av1_get_skip_context(xd);
          inter_mode_data_push(tile_data, mbmi->sb_type, rd_stats->sse,
                               rd_stats->dist,
                               rd_stats_y->rate + rd_stats_uv->rate +
                                   x->skip_cost[skip_ctx][mbmi->skip]);
        }
#endif  // CONFIG_COLLECT_INTER_MODE_RD_STATS
      } else {
        *disable_skip = 1;
      }
#if CONFIG_COLLECT_INTER_MODE_RD_STATS
    }
#endif

    if (this_mode == GLOBALMV || this_mode == GLOBAL_GLOBALMV) {
      if (is_nontrans_global_motion(xd, xd->mi[0])) {
        mbmi->interp_filters = av1_broadcast_interp_filter(
            av1_unswitchable_filter(cm->interp_filter));
      }
    }

    tmp_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
    if (mode_index == 0)
      args->simple_rd[this_mode][mbmi->ref_mv_idx][mbmi->ref_frame[0]] = tmp_rd;
    if ((mode_index == 0) || (tmp_rd < best_rd)) {
      best_mbmi = *mbmi;
      best_rd = tmp_rd;
      best_rd_stats = *rd_stats;
      best_rd_stats_y = *rd_stats_y;
      best_rate_mv = tmp_rate_mv;
      if (num_planes > 1) best_rd_stats_uv = *rd_stats_uv;
      memcpy(best_blk_skip, x->blk_skip,
             sizeof(x->blk_skip[0]) * xd->n4_h * xd->n4_w);
      best_xskip = x->skip;
      best_disable_skip = *disable_skip;
      if (best_xskip) break;
    }
  }
  mbmi->ref_frame[1] = ref_frame_1;
  *rate_mv = best_rate_mv;
  if (best_rd == INT64_MAX) {
    av1_invalid_rd_stats(rd_stats);
    restore_dst_buf(xd, *orig_dst, num_planes);
    return INT64_MAX;
  }
  *mbmi = best_mbmi;
  *rd_stats = best_rd_stats;
  *rd_stats_y = best_rd_stats_y;
  if (num_planes > 1) *rd_stats_uv = best_rd_stats_uv;
  memcpy(x->blk_skip, best_blk_skip,
         sizeof(x->blk_skip[0]) * xd->n4_h * xd->n4_w);
  x->skip = best_xskip;
  *disable_skip = best_disable_skip;

  restore_dst_buf(xd, *orig_dst, num_planes);
  return 0;
}

static int64_t skip_mode_rd(RD_STATS *rd_stats, const AV1_COMP *const cpi,
                            MACROBLOCK *const x, BLOCK_SIZE bsize, int mi_row,
                            int mi_col, BUFFER_SET *const orig_dst) {
  const AV1_COMMON *cm = &cpi->common;
  const int num_planes = av1_num_planes(cm);
  MACROBLOCKD *const xd = &x->e_mbd;
  av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, orig_dst, bsize);

  int64_t total_sse = 0;
  for (int plane = 0; plane < num_planes; ++plane) {
    const struct macroblock_plane *const p = &x->plane[plane];
    const struct macroblockd_plane *const pd = &xd->plane[plane];
    const BLOCK_SIZE plane_bsize =
        get_plane_block_size(bsize, pd->subsampling_x, pd->subsampling_y);
    const int bw = block_size_wide[plane_bsize];
    const int bh = block_size_high[plane_bsize];

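    // Scale the plane SSE by 16 (<< 4) so that it matches the 4-bit scaled
    // distortion units that RDCOST uses elsewhere in the encoder.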
    av1_subtract_plane(x, bsize, plane);
    int64_t sse = aom_sum_squares_2d_i16(p->src_diff, bw, bw, bh);
    sse = sse << 4;
    total_sse += sse;
  }
  const int skip_mode_ctx = av1_get_skip_mode_context(xd);
  rd_stats->dist = rd_stats->sse = total_sse;
  rd_stats->rate = x->skip_mode_cost[skip_mode_ctx][1];
  rd_stats->rdcost = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);

  restore_dst_buf(xd, *orig_dst, num_planes);
  return 0;
}

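// Maps a single-reference mode to its offset into the reference MV stack:
// NEARESTMV uses entry 0, NEARMV uses entry ref_mv_idx + 1, and other modes
// (e.g. GLOBALMV) return -1 since they do not read the stack.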
static INLINE int get_ref_mv_offset(PREDICTION_MODE single_mode,
                                    uint8_t ref_mv_idx) {
  assert(is_inter_singleref_mode(single_mode));
  int ref_mv_offset;
  if (single_mode == NEARESTMV) {
    ref_mv_offset = 0;
  } else if (single_mode == NEARMV) {
    ref_mv_offset = ref_mv_idx + 1;
  } else {
    ref_mv_offset = -1;
  }
  return ref_mv_offset;
}

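// Derives the MV for one reference of the current mode: NEWMV is left as
// INVALID_MV (to be filled in by motion search), GLOBALMV comes from the
// frame's global motion, and NEAREST/NEAR read the reference MV stack,
// falling back to the global MV when the stack has too few entries.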
static INLINE void get_this_mv(int_mv *this_mv, PREDICTION_MODE this_mode,
                               int ref_idx, int ref_mv_idx,
                               const MV_REFERENCE_FRAME *ref_frame,
                               const MB_MODE_INFO_EXT *mbmi_ext) {
  const uint8_t ref_frame_type = av1_ref_frame_type(ref_frame);
  const int is_comp_pred = ref_frame[1] > INTRA_FRAME;
  const PREDICTION_MODE single_mode =
      get_single_mode(this_mode, ref_idx, is_comp_pred);
  assert(is_inter_singleref_mode(single_mode));
  if (single_mode == NEWMV) {
    this_mv->as_int = INVALID_MV;
  } else if (single_mode == GLOBALMV) {
    *this_mv = mbmi_ext->global_mvs[ref_frame[ref_idx]];
  } else {
    assert(single_mode == NEARMV || single_mode == NEARESTMV);
    const int ref_mv_offset = get_ref_mv_offset(single_mode, ref_mv_idx);
    if (ref_mv_offset < mbmi_ext->ref_mv_count[ref_frame_type]) {
      assert(ref_mv_offset >= 0);
      if (ref_idx == 0) {
        *this_mv =
            mbmi_ext->ref_mv_stack[ref_frame_type][ref_mv_offset].this_mv;
      } else {
        *this_mv =
            mbmi_ext->ref_mv_stack[ref_frame_type][ref_mv_offset].comp_mv;
      }
    } else {
      *this_mv = mbmi_ext->global_mvs[ref_frame[ref_idx]];
    }
  }
}

// This function updates the non-new mv for the current prediction mode
static INLINE int build_cur_mv(int_mv *cur_mv, PREDICTION_MODE this_mode,
                               const AV1_COMMON *cm, const MACROBLOCK *x) {
  const MACROBLOCKD *xd = &x->e_mbd;
  const MB_MODE_INFO *mbmi = xd->mi[0];
  const int is_comp_pred = has_second_ref(mbmi);
  int ret = 1;
  for (int i = 0; i < is_comp_pred + 1; ++i) {
    int_mv this_mv;
    get_this_mv(&this_mv, this_mode, i, mbmi->ref_mv_idx, mbmi->ref_frame,
                x->mbmi_ext);
    const PREDICTION_MODE single_mode =
        get_single_mode(this_mode, i, is_comp_pred);
    if (single_mode == NEWMV) {
      cur_mv[i] = this_mv;
    } else {
      ret &= clamp_and_check_mv(cur_mv + i, this_mv, cm, x);
    }
  }
  return ret;
}

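// Returns the rate cost of signaling the dynamic reference list (DRL) index:
// NEWMV-style modes code ref_mv_idx directly against stack entries 0..1,
// while NEAR modes code it with an offset of 1 against entries 1..2.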
static INLINE int get_drl_cost(const MB_MODE_INFO *mbmi,
                               const MB_MODE_INFO_EXT *mbmi_ext,
                               int (*drl_mode_cost0)[2],
                               int8_t ref_frame_type) {
  int cost = 0;
  if (mbmi->mode == NEWMV || mbmi->mode == NEW_NEWMV) {
    for (int idx = 0; idx < 2; ++idx) {
      if (mbmi_ext->ref_mv_count[ref_frame_type] > idx + 1) {
        uint8_t drl_ctx =
            av1_drl_ctx(mbmi_ext->ref_mv_stack[ref_frame_type], idx);
        cost += drl_mode_cost0[drl_ctx][mbmi->ref_mv_idx != idx];
        if (mbmi->ref_mv_idx == idx) return cost;
      }
    }
    return cost;
  }

  if (have_nearmv_in_inter_mode(mbmi->mode)) {
    for (int idx = 1; idx < 3; ++idx) {
      if (mbmi_ext->ref_mv_count[ref_frame_type] > idx + 1) {
        uint8_t drl_ctx =
            av1_drl_ctx(mbmi_ext->ref_mv_stack[ref_frame_type], idx);
        cost += drl_mode_cost0[drl_ctx][mbmi->ref_mv_idx != (idx - 1)];
        if (mbmi->ref_mv_idx == (idx - 1)) return cost;
      }
    }
    return cost;
  }
  return cost;
}

// Struct for buffers used by compound_type_rd() function.
// For sizes and alignment of these arrays, refer to
// alloc_compound_type_rd_buffers() function.
typedef struct {
  uint8_t *pred0;
  uint8_t *pred1;
  int16_t *residual1;          // src - pred1
  int16_t *diff10;             // pred1 - pred0
  uint8_t *tmp_best_mask_buf;  // backup of the best segmentation mask
} CompoundTypeRdBuffers;

static int compound_type_rd(const AV1_COMP *const cpi, MACROBLOCK *x,
                            BLOCK_SIZE bsize, int mi_col, int mi_row,
                            int_mv *cur_mv, int masked_compound_used,
                            BUFFER_SET *orig_dst, const BUFFER_SET *tmp_dst,
                            CompoundTypeRdBuffers *buffers, int *rate_mv,
                            int64_t *rd, RD_STATS *rd_stats,
                            int64_t ref_best_rd) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  const PREDICTION_MODE this_mode = mbmi->mode;
  const int bw = block_size_wide[bsize];
  int rate_sum, rs2;
  int64_t dist_sum;

  int_mv best_mv[2];
  int best_tmp_rate_mv = *rate_mv;
  int tmp_skip_txfm_sb;
  int64_t tmp_skip_sse_sb;
  INTERINTER_COMPOUND_DATA best_compound_data;
  best_compound_data.type = COMPOUND_AVERAGE;
  uint8_t *preds0[1] = { buffers->pred0 };
  uint8_t *preds1[1] = { buffers->pred1 };
  int strides[1] = { bw };
  int tmp_rate_mv;
  const int num_pix = 1 << num_pels_log2_lookup[bsize];
  const int mask_len = 2 * num_pix * sizeof(uint8_t);
  COMPOUND_TYPE cur_type;
  int best_compmode_interinter_cost = 0;
  int calc_pred_masked_compound = 1;

  best_mv[0].as_int = cur_mv[0].as_int;
  best_mv[1].as_int = cur_mv[1].as_int;
  *rd = INT64_MAX;
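  // Evaluate each allowed compound type in turn: COMPOUND_AVERAGE first,
  // then the masked types (wedge / difference-weighted) when enabled.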
  for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) {
    if (cur_type != COMPOUND_AVERAGE && !masked_compound_used) break;
    if (!is_interinter_compound_used(cur_type, bsize)) continue;
    tmp_rate_mv = *rate_mv;
    int64_t best_rd_cur = INT64_MAX;
    mbmi->interinter_comp.type = cur_type;
    int masked_type_cost = 0;

    const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
    const int comp_index_ctx = get_comp_index_context(cm, xd);
    mbmi->compound_idx = 1;
    if (cur_type == COMPOUND_AVERAGE) {
      mbmi->comp_group_idx = 0;
      if (masked_compound_used) {
        masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][0];
      }
      masked_type_cost += x->comp_idx_cost[comp_index_ctx][1];
      rs2 = masked_type_cost;
      const int64_t mode_rd = RDCOST(x->rdmult, rs2 + rd_stats->rate, 0);
      if (mode_rd < ref_best_rd) {
        av1_build_inter_predictors_sby(cm, xd, mi_row, mi_col, orig_dst, bsize);
        int64_t est_rd =
            estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum,
                                &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX);
        if (est_rd != INT64_MAX)
          best_rd_cur = RDCOST(x->rdmult, rs2 + *rate_mv + rate_sum, dist_sum);
      }
      // Use the spare buffer for the compound types tried next.
      restore_dst_buf(xd, *tmp_dst, 1);
    } else {
      mbmi->comp_group_idx = 1;
      masked_type_cost += x->comp_group_idx_cost[comp_group_idx_ctx][1];
      masked_type_cost += x->compound_type_cost[bsize][cur_type - 1];
      rs2 = masked_type_cost;
      if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh &&
          *rd / 3 < ref_best_rd) {
        best_rd_cur = build_and_cost_compound_type(
            cpi, x, cur_mv, bsize, this_mode, &rs2, *rate_mv, orig_dst,
            &tmp_rate_mv, preds0, preds1, buffers->residual1, buffers->diff10,
            strides, mi_row, mi_col, rd_stats->rate, ref_best_rd,
            &calc_pred_masked_compound);
      }
    }
    if (best_rd_cur < *rd) {
      *rd = best_rd_cur;
      best_compound_data = mbmi->interinter_comp;
      if (masked_compound_used && cur_type != COMPOUND_TYPES - 1) {
        memcpy(buffers->tmp_best_mask_buf, xd->seg_mask, mask_len);
      }
      best_compmode_interinter_cost = rs2;
      if (have_newmv_in_inter_mode(this_mode)) {
        if (cur_type == COMPOUND_WEDGE) {
          best_tmp_rate_mv = tmp_rate_mv;
          best_mv[0].as_int = mbmi->mv[0].as_int;
          best_mv[1].as_int = mbmi->mv[1].as_int;
        } else {
          best_mv[0].as_int = cur_mv[0].as_int;
          best_mv[1].as_int = cur_mv[1].as_int;
        }
      }
    }
    // reset to original mvs for next iteration
    mbmi->mv[0].as_int = cur_mv[0].as_int;
    mbmi->mv[1].as_int = cur_mv[1].as_int;
  }
  if (mbmi->interinter_comp.type != best_compound_data.type) {
    mbmi->comp_group_idx =
        (best_compound_data.type == COMPOUND_AVERAGE) ? 0 : 1;
    mbmi->interinter_comp = best_compound_data;
    memcpy(xd->seg_mask, buffers->tmp_best_mask_buf, mask_len);
  }
  if (have_newmv_in_inter_mode(this_mode)) {
    mbmi->mv[0].as_int = best_mv[0].as_int;
    mbmi->mv[1].as_int = best_mv[1].as_int;
    if (mbmi->interinter_comp.type == COMPOUND_WEDGE) {
      rd_stats->rate += best_tmp_rate_mv - *rate_mv;
      *rate_mv = best_tmp_rate_mv;
    }
  }
  restore_dst_buf(xd, *orig_dst, 1);
  return best_compmode_interinter_cost;
}

static INLINE int is_single_newmv_valid(HandleInterModeArgs *args,
                                        MB_MODE_INFO *mbmi,
                                        PREDICTION_MODE this_mode) {
  for (int ref_idx = 0; ref_idx < 2; ++ref_idx) {
    const PREDICTION_MODE single_mode = get_single_mode(this_mode, ref_idx, 1);
    const MV_REFERENCE_FRAME ref = mbmi->ref_frame[ref_idx];
    if (single_mode == NEWMV &&
        args->single_newmv_valid[mbmi->ref_mv_idx][ref] == 0) {
      return 0;
    }
  }
  return 1;
}

static int get_drl_refmv_count(const MACROBLOCK *const x,
                               const MV_REFERENCE_FRAME *ref_frame,
                               PREDICTION_MODE mode) {
  MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
  const int8_t ref_frame_type = av1_ref_frame_type(ref_frame);
  const int has_nearmv = have_nearmv_in_inter_mode(mode) ? 1 : 0;
  const int ref_mv_count = mbmi_ext->ref_mv_count[ref_frame_type];
  const int only_newmv = (mode == NEWMV || mode == NEW_NEWMV);
  const int has_drl =
      (has_nearmv && ref_mv_count > 2) || (only_newmv && ref_mv_count > 1);
  const int ref_set =
      has_drl ? AOMMIN(MAX_REF_MV_SERCH, ref_mv_count - has_nearmv) : 1;

  return ref_set;
}

typedef struct {
  int64_t rd;
  int drl_cost;
  int rate_mv;
  int_mv mv;
} inter_mode_info;

static int64_t handle_inter_mode(const AV1_COMP *const cpi, MACROBLOCK *x,
                                 BLOCK_SIZE bsize, RD_STATS *rd_stats,
                                 RD_STATS *rd_stats_y, RD_STATS *rd_stats_uv,
                                 int *disable_skip, int mi_row, int mi_col,
                                 HandleInterModeArgs *args, int64_t ref_best_rd,
                                 uint8_t *const tmp_buf,
                                 CompoundTypeRdBuffers *rd_buffers
#if CONFIG_COLLECT_INTER_MODE_RD_STATS
                                 ,
                                 TileDataEnc *tile_data, int64_t *best_est_rd,
                                 const int do_tx_search,
                                 InterModesInfo *inter_modes_info
#endif
) {
  const AV1_COMMON *cm = &cpi->common;
  const int num_planes = av1_num_planes(cm);
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = xd->mi[0];
  MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
  const int is_comp_pred = has_second_ref(mbmi);
  const PREDICTION_MODE this_mode = mbmi->mode;
  int i;
  int refs[2] = { mbmi->ref_frame[0],
                  (mbmi->ref_frame[1] < 0 ? 0 : mbmi->ref_frame[1]) };
  int rate_mv = 0;
  int64_t rd = INT64_MAX;

  // do first prediction into the destination buffer. Do the next
  // prediction into a temporary buffer. Then keep track of which one
  // of these currently holds the best predictor, and use the other
  // one for future predictions. In the end, copy from tmp_buf to
  // dst if necessary.
  struct macroblockd_plane *p = xd->plane;
  BUFFER_SET orig_dst = {
    { p[0].dst.buf, p[1].dst.buf, p[2].dst.buf },
    { p[0].dst.stride, p[1].dst.stride, p[2].dst.stride },
  };
  const BUFFER_SET tmp_dst = { { tmp_buf, tmp_buf + 1 * MAX_SB_SQUARE,
                                 tmp_buf + 2 * MAX_SB_SQUARE },
                               { MAX_SB_SIZE, MAX_SB_SIZE, MAX_SB_SIZE } };

  int skip_txfm_sb = 0;
  int64_t skip_sse_sb = INT64_MAX;
  int16_t mode_ctx;
  const int masked_compound_used = is_any_masked_compound_used(bsize) &&
                                   cm->seq_params.enable_masked_compound;
  int64_t ret_val = INT64_MAX;
  const int8_t ref_frame_type = av1_ref_frame_type(mbmi->ref_frame);
  RD_STATS best_rd_stats, best_rd_stats_y, best_rd_stats_uv;
  int64_t best_rd = INT64_MAX;
  uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
  MB_MODE_INFO best_mbmi = *mbmi;
  int best_disable_skip;
  int best_xskip;
  int64_t newmv_ret_val = INT64_MAX;
  int_mv backup_mv[2] = { { 0 } };
  int backup_rate_mv = 0;
  inter_mode_info mode_info[MAX_REF_MV_SERCH];

  int comp_idx;
  const int search_jnt_comp = is_comp_pred & cm->seq_params.enable_jnt_comp &
                              (mbmi->mode != GLOBAL_GLOBALMV);

  // TODO(jingning): This should be deprecated shortly.
  const int has_nearmv = have_nearmv_in_inter_mode(mbmi->mode) ? 1 : 0;
  const int ref_set = get_drl_refmv_count(x, mbmi->ref_frame, this_mode);

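  // Loop over the candidate reference MVs from the dynamic reference MV list
  // (DRL); ref_mv_idx selects which stack entry seeds this mode evaluation.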
  for (int ref_mv_idx = 0; ref_mv_idx < ref_set; ++ref_mv_idx) {
    mode_info[ref_mv_idx].mv.as_int = INVALID_MV;
    mode_info[ref_mv_idx].rd = INT64_MAX;

    if (cpi->sf.reduce_inter_modes && ref_mv_idx > 0) {
      if (mbmi->ref_frame[0] == LAST2_FRAME ||
          mbmi->ref_frame[0] == LAST3_FRAME ||
          mbmi->ref_frame[1] == LAST2_FRAME ||
          mbmi->ref_frame[1] == LAST3_FRAME) {
        if (mbmi_ext->ref_mv_stack[ref_frame_type][ref_mv_idx + has_nearmv]
                .weight < REF_CAT_LEVEL) {
          continue;
        }
      }
    }

    av1_init_rd_stats(rd_stats);

    mbmi->interinter_comp.type = COMPOUND_AVERAGE;
    mbmi->comp_group_idx = 0;
    mbmi->compound_idx = 1;
    if (mbmi->ref_frame[1] == INTRA_FRAME) mbmi->ref_frame[1] = NONE_FRAME;

    mode_ctx =
        av1_mode_context_analyzer(mbmi_ext->mode_context, mbmi->ref_frame);

    mbmi->num_proj_ref = 0;
    mbmi->motion_mode = SIMPLE_TRANSLATION;
    mbmi->ref_mv_idx = ref_mv_idx;

    if (is_comp_pred && (!is_single_newmv_valid(args, mbmi, this_mode))) {
      continue;
    }

    rd_stats->rate += args->ref_frame_cost + args->single_comp_cost;
    const int drl_cost =
        get_drl_cost(mbmi, mbmi_ext, x->drl_mode_cost0, ref_frame_type);
    rd_stats->rate += drl_cost;
    mode_info[ref_mv_idx].drl_cost = drl_cost;

    if (RDCOST(x->rdmult, rd_stats->rate, 0) > ref_best_rd &&
        mbmi->mode != NEARESTMV && mbmi->mode != NEAREST_NEARESTMV) {
      continue;
    }

    int64_t best_rd2 = INT64_MAX;

    const RD_STATS backup_rd_stats = *rd_stats;
    // If !search_jnt_comp, we need to force mbmi->compound_idx = 1.
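    // comp_idx == 1 is the equal-weight average compound; comp_idx == 0 is
    // the distance-weighted (jnt_comp) variant, tried only when
    // search_jnt_comp is set.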
    for (comp_idx = 1; comp_idx >= !search_jnt_comp; --comp_idx) {
      int rs = 0;
      int compmode_interinter_cost = 0;
      mbmi->compound_idx = comp_idx;
      if (is_comp_pred && comp_idx == 0) {
        *rd_stats = backup_rd_stats;
        mbmi->interinter_comp.type = COMPOUND_AVERAGE;
        if (mbmi->ref_frame[1] == INTRA_FRAME) mbmi->ref_frame[1] = NONE_FRAME;
        mbmi->num_proj_ref = 0;
        mbmi->motion_mode = SIMPLE_TRANSLATION;
        mbmi->comp_group_idx = 0;

        const int comp_group_idx_ctx = get_comp_group_idx_context(xd);
        const int comp_index_ctx = get_comp_index_context(cm, xd);
        if (masked_compound_used) {
          compmode_interinter_cost +=
              x->comp_group_idx_cost[comp_group_idx_ctx][0];
        }
        compmode_interinter_cost += x->comp_idx_cost[comp_index_ctx][0];
      }

      int_mv cur_mv[2];
      if (!build_cur_mv(cur_mv, this_mode, cm, x)) {
        continue;
      }
      if (have_newmv_in_inter_mode(this_mode)) {
        if (comp_idx == 0) {
          cur_mv[0] = backup_mv[0];
          cur_mv[1] = backup_mv[1];
          rate_mv = backup_rate_mv;
        }

        // When the jnt_comp_skip_mv_search flag is on, the new MV is searched
        // only once (for comp_idx == 1) and reused for comp_idx == 0.
        if (!(search_jnt_comp && cpi->sf.jnt_comp_skip_mv_search &&
              comp_idx == 0)) {
          newmv_ret_val = handle_newmv(cpi, x, bsize, cur_mv, mi_row, mi_col,
                                       &rate_mv, args);

          // Store cur_mv and rate_mv so that they can be restored in the next
          // iteration of the loop
          backup_mv[0] = cur_mv[0];
          backup_mv[1] = cur_mv[1];
          backup_rate_mv = rate_mv;
        }

        if (newmv_ret_val != 0) {
          continue;
        } else {
          rd_stats->rate += rate_mv;
        }

        if (cpi->sf.skip_repeated_newmv) {
          if (!is_comp_pred && this_mode == NEWMV && ref_mv_idx > 0) {
            int skip = 0;
            int this_rate_mv = 0;
            for (i = 0; i < ref_mv_idx; ++i) {
              // Check if the motion search result is the same as a previous
              // result
              if (cur_mv[0].as_int == args->single_newmv[i][refs[0]].as_int) {
                // If the compared mode has no valid rd, it is unlikely this
                // mode will be the best mode
                if (mode_info[i].rd == INT64_MAX) {
                  skip = 1;
                  break;
                }
                // Compare the cost difference including drl cost and mv cost
                if (mode_info[i].mv.as_int != INVALID_MV) {
                  const int compare_cost =
                      mode_info[i].rate_mv + mode_info[i].drl_cost;
                  const int_mv ref_mv = av1_get_ref_mv(x, 0);
                  this_rate_mv = av1_mv_bit_cost(&mode_info[i].mv.as_mv,
                                                 &ref_mv.as_mv, x->nmvjointcost,
                                                 x->mvcost, MV_COST_WEIGHT);
                  const int this_cost = this_rate_mv + drl_cost;

                  if (compare_cost < this_cost) {
                    skip = 1;
                    break;
                  } else {
                    // If the cost is less than current best result, make this
                    // the best and update corresponding variables
                    if (best_mbmi.ref_mv_idx == i) {
                      assert(best_rd != INT64_MAX);
                      best_mbmi.ref_mv_idx = ref_mv_idx;
                      best_rd_stats.rate += this_cost - compare_cost;
                      best_rd = RDCOST(x->rdmult, best_rd_stats.rate,
                                       best_rd_stats.dist);
                      if (best_rd < ref_best_rd) ref_best_rd = best_rd;

                      skip = 1;
                      break;
                    }
                  }
                }
              }
            }
            if (skip) {
              args->modelled_rd[this_mode][ref_mv_idx][refs[0]] =
                  args->modelled_rd[this_mode][i][refs[0]];
              args->simple_rd[this_mode][ref_mv_idx][refs[0]] =
                  args->simple_rd[this_mode][i][refs[0]];
              mode_info[ref_mv_idx].rd = mode_info[i].rd;
              mode_info[ref_mv_idx].rate_mv = this_rate_mv;
              mode_info[ref_mv_idx].mv.as_int = mode_info[i].mv.as_int;

              restore_dst_buf(xd, orig_dst, num_planes);
              continue;
            }
          }
        }
      }
      for (i = 0; i < is_comp_pred + 1; ++i) {
        mbmi->mv[i].as_int = cur_mv[i].as_int;
      }
      const int ref_mv_cost = cost_mv_ref(x, this_mode, mode_ctx);
#if USE_DISCOUNT_NEWMV_TEST
      // We don't include the cost of the second reference here, because there
      // are only three options: Last/Golden, ARF/Last or Golden/ARF, or in
      // other words if you present them in that order, the second one is always
      // known if the first is known.
      //
      // Under some circumstances we discount the cost of new mv mode to
      // encourage initiation of a motion field.
      if (discount_newmv_test(cpi, x, this_mode, mbmi->mv[0])) {
        // discount_newmv_test only applies discount on NEWMV mode.
        assert(this_mode == NEWMV);
        rd_stats->rate += AOMMIN(cost_mv_ref(x, this_mode, mode_ctx),
                                 cost_mv_ref(x, NEARESTMV, mode_ctx));
      } else {
        rd_stats->rate += ref_mv_cost;
      }
#else
      rd_stats->rate += ref_mv_cost;
#endif

      if (RDCOST(x->rdmult, rd_stats->rate, 0) > ref_best_rd &&
          mbmi->mode != NEARESTMV && mbmi->mode != NEAREST_NEARESTMV) {
        continue;
      }

      int skip_build_pred = 0;
      if (is_comp_pred && comp_idx) {
        // Find matching interp filter or set to default interp filter
        const int need_search =
            av1_is_interp_needed(xd) && av1_is_interp_search_needed(xd);
        int match_found = -1;
        const InterpFilter assign_filter = cm->interp_filter;
        if (cpi->sf.skip_repeat_interpolation_filter_search && need_search) {
          match_found = find_interp_filter_in_stats(x, mbmi);
        }
        if (!need_search || match_found == -1) {
          set_default_interp_filters(mbmi, assign_filter);
        }

        int64_t best_rd_compound;
        compmode_interinter_cost = compound_type_rd(
            cpi, x, bsize, mi_col, mi_row, cur_mv, masked_compound_used,
            &orig_dst, &tmp_dst, rd_buffers, &rate_mv, &best_rd_compound,
            rd_stats, ref_best_rd);
        if (ref_best_rd < INT64_MAX && best_rd_compound / 3 > ref_best_rd) {
          restore_dst_buf(xd, orig_dst, num_planes);
          continue;
        }
        // No need to call av1_build_inter_predictors_sby if
        // COMPOUND_AVERAGE is selected: it is the first candidate in
        // compound_type_rd, and the search over the remaining compound
        // types uses the tmp_dst buffer.
        if (mbmi->interinter_comp.type == COMPOUND_AVERAGE) {
          if (num_planes > 1)
            av1_build_inter_predictors_sbuv(cm, xd, mi_row, mi_col, &orig_dst,
                                            bsize);
          skip_build_pred = 1;
        }
      }

      ret_val = interpolation_filter_search(
          x, cpi, bsize, mi_row, mi_col, &tmp_dst, &orig_dst,
          args->single_filter, &rd, &rs, &skip_txfm_sb, &skip_sse_sb,
          skip_build_pred, args, ref_best_rd);
      if (args->modelled_rd != NULL && !is_comp_pred) {
        args->modelled_rd[this_mode][ref_mv_idx][refs[0]] = rd;
      }
      if (ret_val != 0) {
        restore_dst_buf(xd, orig_dst, num_planes);
        continue;
      } else if (cpi->sf.model_based_post_interp_filter_breakout &&
                 ref_best_rd != INT64_MAX && (rd >> 3) * 3 > ref_best_rd) {
        restore_dst_buf(xd, orig_dst, num_planes);
        if ((rd >> 3) * 2 > ref_best_rd) break;
        continue;
      }

      if (search_jnt_comp) {
        // If half of the model rd already exceeds the best rd so far, skip
        // this jnt_comp candidate to save the additional search.
9565         if ((rd >> 3) * 4 > best_rd) {
9566           restore_dst_buf(xd, orig_dst, num_planes);
9567           continue;
9568         }
9569       }
9570 
9571       if (!is_comp_pred)
9572         args->single_filter[this_mode][refs[0]] =
9573             av1_extract_interp_filter(mbmi->interp_filters, 0);
9574 
9575       if (args->modelled_rd != NULL) {
9576         if (is_comp_pred) {
9577           const int mode0 = compound_ref0_mode(this_mode);
9578           const int mode1 = compound_ref1_mode(this_mode);
9579           const int64_t mrd =
9580               AOMMIN(args->modelled_rd[mode0][ref_mv_idx][refs[0]],
9581                      args->modelled_rd[mode1][ref_mv_idx][refs[1]]);
9582           if ((rd >> 3) * 6 > mrd && ref_best_rd < INT64_MAX) {
9583             restore_dst_buf(xd, orig_dst, num_planes);
9584             continue;
9585           }
9586         }
9587       }
9588       rd_stats->rate += compmode_interinter_cost;
9589 
9590       if (search_jnt_comp && cpi->sf.jnt_comp_fast_tx_search && comp_idx == 0) {
9591         // TODO(chengchen): this speed feature introduces big loss.
9592         // Need better estimation of rate distortion.
9593         int dummy_rate;
9594         int64_t dummy_dist;
9595         int plane_rate[MAX_MB_PLANE] = { 0 };
9596         int64_t plane_sse[MAX_MB_PLANE] = { 0 };
9597         int64_t plane_dist[MAX_MB_PLANE] = { 0 };
9598 
9599         model_rd_sb_fn[MODELRD_TYPE_JNT_COMPOUND](
9600             cpi, bsize, x, xd, 0, num_planes - 1, mi_row, mi_col, &dummy_rate,
9601             &dummy_dist, &skip_txfm_sb, &skip_sse_sb, plane_rate, plane_sse,
9602             plane_dist);
9603 
9604         rd_stats->rate += rs;
9605         rd_stats->rate += plane_rate[0] + plane_rate[1] + plane_rate[2];
9606         rd_stats_y->rate = plane_rate[0];
9607         rd_stats_uv->rate = plane_rate[1] + plane_rate[2];
9608         rd_stats->sse = plane_sse[0] + plane_sse[1] + plane_sse[2];
9609         rd_stats_y->sse = plane_sse[0];
9610         rd_stats_uv->sse = plane_sse[1] + plane_sse[2];
9611         rd_stats->dist = plane_dist[0] + plane_dist[1] + plane_dist[2];
9612         rd_stats_y->dist = plane_dist[0];
9613         rd_stats_uv->dist = plane_dist[1] + plane_dist[2];
9614       } else {
9615 #if CONFIG_COLLECT_INTER_MODE_RD_STATS
9616         ret_val = motion_mode_rd(
9617             cpi, x, bsize, rd_stats, rd_stats_y, rd_stats_uv, disable_skip,
9618             mi_row, mi_col, args, ref_best_rd, refs, &rate_mv, &orig_dst,
9619             tile_data, best_est_rd, do_tx_search, inter_modes_info);
9620 #else
9621         ret_val = motion_mode_rd(cpi, x, bsize, rd_stats, rd_stats_y,
9622                                  rd_stats_uv, disable_skip, mi_row, mi_col,
9623                                  args, ref_best_rd, refs, &rate_mv, &orig_dst);
9624 #endif
9625       }
9626       mode_info[ref_mv_idx].mv.as_int = mbmi->mv[0].as_int;
9627       mode_info[ref_mv_idx].rate_mv = rate_mv;
9628       if (ret_val != INT64_MAX) {
9629         int64_t tmp_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
9630         mode_info[ref_mv_idx].rd = tmp_rd;
9631         if (tmp_rd < best_rd) {
9632           best_rd_stats = *rd_stats;
9633           best_rd_stats_y = *rd_stats_y;
9634           best_rd_stats_uv = *rd_stats_uv;
9635           best_rd = tmp_rd;
9636           best_mbmi = *mbmi;
9637           best_disable_skip = *disable_skip;
9638           best_xskip = x->skip;
9639           memcpy(best_blk_skip, x->blk_skip,
9640                  sizeof(best_blk_skip[0]) * xd->n4_h * xd->n4_w);
9641         }
9642 
9643         if (tmp_rd < best_rd2) {
9644           best_rd2 = tmp_rd;
9645         }
9646 
9647         if (tmp_rd < ref_best_rd) {
9648           ref_best_rd = tmp_rd;
9649         }
9650       }
9651       restore_dst_buf(xd, orig_dst, num_planes);
9652     }
9653   }
9654 
9655   if (best_rd == INT64_MAX) return INT64_MAX;
9656 
9657   // re-instate status of the best choice
9658   *rd_stats = best_rd_stats;
9659   *rd_stats_y = best_rd_stats_y;
9660   *rd_stats_uv = best_rd_stats_uv;
9661   *mbmi = best_mbmi;
9662   *disable_skip = best_disable_skip;
9663   x->skip = best_xskip;
9664   assert(IMPLIES(mbmi->comp_group_idx == 1,
9665                  mbmi->interinter_comp.type != COMPOUND_AVERAGE));
9666   memcpy(x->blk_skip, best_blk_skip,
9667          sizeof(best_blk_skip[0]) * xd->n4_h * xd->n4_w);
9668 
9669   return RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
9670 }
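
// Illustrative note on the cost metric used throughout this search: RDCOST()
// folds rate and distortion into a single Lagrangian value, approximately
//   RDCOST(rdmult, rate, dist)
//     = ROUND_POWER_OF_TWO((int64_t)rate * rdmult, AV1_PROB_COST_SHIFT)
//       + (dist << RDDIV_BITS);
// e.g. with rdmult = 512, rate = 1000 and dist = 100 (assumed values) this
// yields (1000 * 512 >> 9) + (100 << 7) = 1000 + 12800 = 13800, and the loop
// above keeps a ref_mv_idx only when its tmp_rd beats best_rd.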
9671 
9672 static int64_t rd_pick_intrabc_mode_sb(const AV1_COMP *cpi, MACROBLOCK *x,
9673                                        RD_STATS *rd_cost, BLOCK_SIZE bsize,
9674                                        int64_t best_rd) {
9675   const AV1_COMMON *const cm = &cpi->common;
9676   if (!av1_allow_intrabc(cm)) return INT64_MAX;
9677   const int num_planes = av1_num_planes(cm);
9678 
9679   MACROBLOCKD *const xd = &x->e_mbd;
9680   const TileInfo *tile = &xd->tile;
9681   MB_MODE_INFO *mbmi = xd->mi[0];
9682   const int mi_row = -xd->mb_to_top_edge / (8 * MI_SIZE);
9683   const int mi_col = -xd->mb_to_left_edge / (8 * MI_SIZE);
9684   const int w = block_size_wide[bsize];
9685   const int h = block_size_high[bsize];
9686   const int sb_row = mi_row >> cm->seq_params.mib_size_log2;
9687   const int sb_col = mi_col >> cm->seq_params.mib_size_log2;
9688 
9689   MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
9690   MV_REFERENCE_FRAME ref_frame = INTRA_FRAME;
9691   av1_find_mv_refs(cm, xd, mbmi, ref_frame, mbmi_ext->ref_mv_count,
9692                    mbmi_ext->ref_mv_stack, NULL, mbmi_ext->global_mvs, mi_row,
9693                    mi_col, mbmi_ext->mode_context);
9694 
9695   int_mv nearestmv, nearmv;
9696   av1_find_best_ref_mvs_from_stack(0, mbmi_ext, ref_frame, &nearestmv, &nearmv,
9697                                    0);
9698 
9699   if (nearestmv.as_int == INVALID_MV) {
9700     nearestmv.as_int = 0;
9701   }
9702   if (nearmv.as_int == INVALID_MV) {
9703     nearmv.as_int = 0;
9704   }
9705 
9706   int_mv dv_ref = nearestmv.as_int == 0 ? nearmv : nearestmv;
9707   if (dv_ref.as_int == 0)
9708     av1_find_ref_dv(&dv_ref, tile, cm->seq_params.mib_size, mi_row, mi_col);
9709   // Ref DV should not have sub-pel.
9710   assert((dv_ref.as_mv.col & 7) == 0);
9711   assert((dv_ref.as_mv.row & 7) == 0);
9712   mbmi_ext->ref_mv_stack[INTRA_FRAME][0].this_mv = dv_ref;
9713 
9714   struct buf_2d yv12_mb[MAX_MB_PLANE];
9715   av1_setup_pred_block(xd, yv12_mb, xd->cur_buf, mi_row, mi_col, NULL, NULL,
9716                        num_planes);
9717   for (int i = 0; i < num_planes; ++i) {
9718     xd->plane[i].pre[0] = yv12_mb[i];
9719   }
9720 
9721   enum IntrabcMotionDirection {
9722     IBC_MOTION_ABOVE,
9723     IBC_MOTION_LEFT,
9724     IBC_MOTION_DIRECTIONS
9725   };
9726 
9727   MB_MODE_INFO best_mbmi = *mbmi;
9728   RD_STATS best_rdcost = *rd_cost;
9729   int best_skip = x->skip;
9730 
9731   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE] = { 0 };
9732   for (enum IntrabcMotionDirection dir = IBC_MOTION_ABOVE;
9733        dir < IBC_MOTION_DIRECTIONS; ++dir) {
9734     const MvLimits tmp_mv_limits = x->mv_limits;
9735     switch (dir) {
9736       case IBC_MOTION_ABOVE:
9737         x->mv_limits.col_min = (tile->mi_col_start - mi_col) * MI_SIZE;
9738         x->mv_limits.col_max = (tile->mi_col_end - mi_col) * MI_SIZE - w;
9739         x->mv_limits.row_min = (tile->mi_row_start - mi_row) * MI_SIZE;
9740         x->mv_limits.row_max =
9741             (sb_row * cm->seq_params.mib_size - mi_row) * MI_SIZE - h;
9742         break;
9743       case IBC_MOTION_LEFT:
9744         x->mv_limits.col_min = (tile->mi_col_start - mi_col) * MI_SIZE;
9745         x->mv_limits.col_max =
9746             (sb_col * cm->seq_params.mib_size - mi_col) * MI_SIZE - w;
9747         // TODO(aconverse@google.com): Minimize the overlap between above and
9748         // left areas.
9749         x->mv_limits.row_min = (tile->mi_row_start - mi_row) * MI_SIZE;
9750         int bottom_coded_mi_edge =
9751             AOMMIN((sb_row + 1) * cm->seq_params.mib_size, tile->mi_row_end);
9752         x->mv_limits.row_max = (bottom_coded_mi_edge - mi_row) * MI_SIZE - h;
9753         break;
9754       default: assert(0);
9755     }
9756     assert(x->mv_limits.col_min >= tmp_mv_limits.col_min);
9757     assert(x->mv_limits.col_max <= tmp_mv_limits.col_max);
9758     assert(x->mv_limits.row_min >= tmp_mv_limits.row_min);
9759     assert(x->mv_limits.row_max <= tmp_mv_limits.row_max);
9760     av1_set_mv_search_range(&x->mv_limits, &dv_ref.as_mv);
9761 
9762     if (x->mv_limits.col_max < x->mv_limits.col_min ||
9763         x->mv_limits.row_max < x->mv_limits.row_min) {
9764       x->mv_limits = tmp_mv_limits;
9765       continue;
9766     }
9767 
9768     int step_param = cpi->mv_step_param;
9769     MV mvp_full = dv_ref.as_mv;
9770     mvp_full.col >>= 3;
9771     mvp_full.row >>= 3;
9772     int sadpb = x->sadperbit16;
9773     int cost_list[5];
9774     int bestsme = av1_full_pixel_search(
9775         cpi, x, bsize, &mvp_full, step_param, cpi->sf.mv.search_method, 0,
9776         sadpb, cond_cost_list(cpi, cost_list), &dv_ref.as_mv, INT_MAX, 1,
9777         (MI_SIZE * mi_col), (MI_SIZE * mi_row), 1);
9778 
9779     x->mv_limits = tmp_mv_limits;
9780     if (bestsme == INT_MAX) continue;
9781     mvp_full = x->best_mv.as_mv;
9782     MV dv = { .row = mvp_full.row * 8, .col = mvp_full.col * 8 };
9783     if (mv_check_bounds(&x->mv_limits, &dv)) continue;
9784     if (!av1_is_dv_valid(dv, cm, xd, mi_row, mi_col, bsize,
9785                          cm->seq_params.mib_size_log2))
9786       continue;
9787 
9788     // DV should not have sub-pel.
9789     assert((dv.col & 7) == 0);
9790     assert((dv.row & 7) == 0);
9791     memset(&mbmi->palette_mode_info, 0, sizeof(mbmi->palette_mode_info));
9792     mbmi->filter_intra_mode_info.use_filter_intra = 0;
9793     mbmi->use_intrabc = 1;
9794     mbmi->mode = DC_PRED;
9795     mbmi->uv_mode = UV_DC_PRED;
9796     mbmi->motion_mode = SIMPLE_TRANSLATION;
9797     mbmi->mv[0].as_mv = dv;
9798     mbmi->interp_filters = av1_broadcast_interp_filter(BILINEAR);
9799     mbmi->skip = 0;
9800     x->skip = 0;
9801     av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, NULL, bsize);
9802 
9803     int *dvcost[2] = { (int *)&cpi->dv_cost[0][MV_MAX],
9804                        (int *)&cpi->dv_cost[1][MV_MAX] };
9805     // TODO(aconverse@google.com): The full motion field defining discount
9806     // in MV_COST_WEIGHT is too large. Explore other values.
9807     int rate_mv = av1_mv_bit_cost(&dv, &dv_ref.as_mv, cpi->dv_joint_cost,
9808                                   dvcost, MV_COST_WEIGHT_SUB);
9809     const int rate_mode = x->intrabc_cost[1];
9810     RD_STATS rd_stats, rd_stats_uv;
9811     av1_subtract_plane(x, bsize, 0);
9812     if (cm->tx_mode == TX_MODE_SELECT && !xd->lossless[mbmi->segment_id]) {
9813       // Intrabc
9814       select_tx_type_yrd(cpi, x, &rd_stats, bsize, mi_row, mi_col, INT64_MAX);
9815     } else {
9816       super_block_yrd(cpi, x, &rd_stats, bsize, INT64_MAX);
9817       memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
9818       for (int i = 0; i < xd->n4_h * xd->n4_w; ++i)
9819         set_blk_skip(x, 0, i, rd_stats.skip);
9820     }
9821     if (num_planes > 1) {
9822       super_block_uvrd(cpi, x, &rd_stats_uv, bsize, INT64_MAX);
9823       av1_merge_rd_stats(&rd_stats, &rd_stats_uv);
9824     }
9825 #if CONFIG_RD_DEBUG
9826     mbmi->rd_stats = rd_stats;
9827 #endif
9828 
9829     const int skip_ctx = av1_get_skip_context(xd);
9830 
9831     RD_STATS rdc_noskip;
9832     av1_init_rd_stats(&rdc_noskip);
9833     rdc_noskip.rate =
9834         rate_mode + rate_mv + rd_stats.rate + x->skip_cost[skip_ctx][0];
9835     rdc_noskip.dist = rd_stats.dist;
9836     rdc_noskip.rdcost = RDCOST(x->rdmult, rdc_noskip.rate, rdc_noskip.dist);
9837     if (rdc_noskip.rdcost < best_rd) {
9838       best_rd = rdc_noskip.rdcost;
9839       best_mbmi = *mbmi;
9840       best_skip = x->skip;
9841       best_rdcost = rdc_noskip;
9842       memcpy(best_blk_skip, x->blk_skip,
9843              sizeof(x->blk_skip[0]) * xd->n4_h * xd->n4_w);
9844     }
9845 
9846     if (!xd->lossless[mbmi->segment_id]) {
9847       x->skip = 1;
9848       mbmi->skip = 1;
9849       RD_STATS rdc_skip;
9850       av1_init_rd_stats(&rdc_skip);
9851       rdc_skip.rate = rate_mode + rate_mv + x->skip_cost[skip_ctx][1];
9852       rdc_skip.dist = rd_stats.sse;
9853       rdc_skip.rdcost = RDCOST(x->rdmult, rdc_skip.rate, rdc_skip.dist);
9854       if (rdc_skip.rdcost < best_rd) {
9855         best_rd = rdc_skip.rdcost;
9856         best_mbmi = *mbmi;
9857         best_skip = x->skip;
9858         best_rdcost = rdc_skip;
9859         memcpy(best_blk_skip, x->blk_skip,
9860                sizeof(x->blk_skip[0]) * xd->n4_h * xd->n4_w);
9861       }
9862     }
9863   }
9864   *mbmi = best_mbmi;
9865   *rd_cost = best_rdcost;
9866   x->skip = best_skip;
9867   memcpy(x->blk_skip, best_blk_skip,
9868          sizeof(x->blk_skip[0]) * xd->n4_h * xd->n4_w);
9869   return best_rd;
9870 }
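
// Sketch of the IntraBC constraint enforced above (numbers are illustrative):
// the displacement vector (DV) must point at already-reconstructed pixels in
// the same tile. IBC_MOTION_ABOVE allows the full tile width but only
// superblock rows strictly above the current one; IBC_MOTION_LEFT extends
// down into the current superblock row but only to columns left of the
// current superblock. For a 16x16 block at mi_row = 20, mi_col = 12 in a
// 128x128 superblock (mib_size = 32, MI_SIZE = 4):
//   above: row_max = (0 * 32 - 20) * 4 - 16 = -96, so DV rows <= -96;
//   left:  col_max = (0 * 32 - 12) * 4 - 16 = -64, so DV cols <= -64.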
9871 
9872 void av1_rd_pick_intra_mode_sb(const AV1_COMP *cpi, MACROBLOCK *x, int mi_row,
9873                                int mi_col, RD_STATS *rd_cost, BLOCK_SIZE bsize,
9874                                PICK_MODE_CONTEXT *ctx, int64_t best_rd) {
9875   const AV1_COMMON *const cm = &cpi->common;
9876   MACROBLOCKD *const xd = &x->e_mbd;
9877   MB_MODE_INFO *const mbmi = xd->mi[0];
9878   const int num_planes = av1_num_planes(cm);
9879   int rate_y = 0, rate_uv = 0, rate_y_tokenonly = 0, rate_uv_tokenonly = 0;
9880   int y_skip = 0, uv_skip = 0;
9881   int64_t dist_y = 0, dist_uv = 0;
9882   TX_SIZE max_uv_tx_size;
9883 
9884   ctx->skip = 0;
9885   mbmi->ref_frame[0] = INTRA_FRAME;
9886   mbmi->ref_frame[1] = NONE_FRAME;
9887   mbmi->use_intrabc = 0;
9888   mbmi->mv[0].as_int = 0;
9889 
9890   const int64_t intra_yrd =
9891       rd_pick_intra_sby_mode(cpi, x, mi_row, mi_col, &rate_y, &rate_y_tokenonly,
9892                              &dist_y, &y_skip, bsize, best_rd, ctx);
9893 
9894   if (intra_yrd < best_rd) {
9895     // Only store reconstructed luma when there's chroma RDO. When there's no
9896     // chroma RDO, the reconstructed luma will be stored in encode_superblock().
9897     xd->cfl.is_chroma_reference =
9898         is_chroma_reference(mi_row, mi_col, bsize, cm->seq_params.subsampling_x,
9899                             cm->seq_params.subsampling_y);
9900     xd->cfl.store_y = store_cfl_required_rdo(cm, x);
9901     if (xd->cfl.store_y) {
9902       // Restore reconstructed luma values.
9903       memcpy(x->blk_skip, ctx->blk_skip,
9904              sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
9905       av1_encode_intra_block_plane(cpi, x, bsize, AOM_PLANE_Y,
9906                                    cpi->optimize_seg_arr[mbmi->segment_id],
9907                                    mi_row, mi_col);
9908       xd->cfl.store_y = 0;
9909     }
9910     if (num_planes > 1) {
9911       max_uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);
9912       init_sbuv_mode(mbmi);
9913       if (!x->skip_chroma_rd)
9914         rd_pick_intra_sbuv_mode(cpi, x, &rate_uv, &rate_uv_tokenonly, &dist_uv,
9915                                 &uv_skip, bsize, max_uv_tx_size);
9916     }
9917 
9918     if (y_skip && (uv_skip || x->skip_chroma_rd)) {
9919       rd_cost->rate = rate_y + rate_uv - rate_y_tokenonly - rate_uv_tokenonly +
9920                       x->skip_cost[av1_get_skip_context(xd)][1];
9921       rd_cost->dist = dist_y + dist_uv;
9922     } else {
9923       rd_cost->rate =
9924           rate_y + rate_uv + x->skip_cost[av1_get_skip_context(xd)][0];
9925       rd_cost->dist = dist_y + dist_uv;
9926     }
9927     rd_cost->rdcost = RDCOST(x->rdmult, rd_cost->rate, rd_cost->dist);
9928   } else {
9929     rd_cost->rate = INT_MAX;
9930   }
9931 
9932   if (rd_cost->rate != INT_MAX && rd_cost->rdcost < best_rd)
9933     best_rd = rd_cost->rdcost;
9934   if (rd_pick_intrabc_mode_sb(cpi, x, rd_cost, bsize, best_rd) < best_rd) {
9935     ctx->skip = x->skip;
9936     memcpy(ctx->blk_skip, x->blk_skip,
9937            sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
9938     assert(rd_cost->rate != INT_MAX);
9939   }
9940   if (rd_cost->rate == INT_MAX) return;
9941 
9942   ctx->mic = *xd->mi[0];
9943   ctx->mbmi_ext = *x->mbmi_ext;
9944 }
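
// Usage note: av1_rd_pick_intra_mode_sb() first evaluates regular intra
// prediction and then lets rd_pick_intrabc_mode_sb() compete against that
// result under the same best_rd threshold, so IntraBC is only adopted when
// it strictly improves on the best intra RD cost found so far.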
9945 
9946 static void restore_uv_color_map(const AV1_COMP *const cpi, MACROBLOCK *x) {
9947   MACROBLOCKD *const xd = &x->e_mbd;
9948   MB_MODE_INFO *const mbmi = xd->mi[0];
9949   PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
9950   const BLOCK_SIZE bsize = mbmi->sb_type;
9951   int src_stride = x->plane[1].src.stride;
9952   const uint8_t *const src_u = x->plane[1].src.buf;
9953   const uint8_t *const src_v = x->plane[2].src.buf;
9954   int *const data = x->palette_buffer->kmeans_data_buf;
9955   int centroids[2 * PALETTE_MAX_SIZE];
9956   uint8_t *const color_map = xd->plane[1].color_index_map;
9957   int r, c;
9958   const uint16_t *const src_u16 = CONVERT_TO_SHORTPTR(src_u);
9959   const uint16_t *const src_v16 = CONVERT_TO_SHORTPTR(src_v);
9960   int plane_block_width, plane_block_height, rows, cols;
9961   av1_get_block_dimensions(bsize, 1, xd, &plane_block_width,
9962                            &plane_block_height, &rows, &cols);
9963 
9964   for (r = 0; r < rows; ++r) {
9965     for (c = 0; c < cols; ++c) {
9966       if (cpi->common.seq_params.use_highbitdepth) {
9967         data[(r * cols + c) * 2] = src_u16[r * src_stride + c];
9968         data[(r * cols + c) * 2 + 1] = src_v16[r * src_stride + c];
9969       } else {
9970         data[(r * cols + c) * 2] = src_u[r * src_stride + c];
9971         data[(r * cols + c) * 2 + 1] = src_v[r * src_stride + c];
9972       }
9973     }
9974   }
9975 
9976   for (r = 1; r < 3; ++r) {
9977     for (c = 0; c < pmi->palette_size[1]; ++c) {
9978       centroids[c * 2 + r - 1] = pmi->palette_colors[r * PALETTE_MAX_SIZE + c];
9979     }
9980   }
9981 
9982   av1_calc_indices(data, centroids, color_map, rows * cols,
9983                    pmi->palette_size[1], 2);
9984   extend_palette_color_map(color_map, cols, rows, plane_block_width,
9985                            plane_block_height);
9986 }
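
// Data-layout sketch for the av1_calc_indices() call above (example values
// are illustrative): U and V samples are interleaved pairwise, so a 2x2
// chroma block with U = { 10, 12, 14, 16 } and V = { 200, 201, 202, 203 }
// fills kmeans_data_buf as { 10, 200, 12, 201, 14, 202, 16, 203 }; the
// centroids array interleaves the stored U and V palette colors the same
// way, which is why the dimensionality argument is 2.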
9987 
9988 static void calc_target_weighted_pred(const AV1_COMMON *cm, const MACROBLOCK *x,
9989                                       const MACROBLOCKD *xd, int mi_row,
9990                                       int mi_col, const uint8_t *above,
9991                                       int above_stride, const uint8_t *left,
9992                                       int left_stride);
9993 
9994 static const int ref_frame_flag_list[REF_FRAMES] = { 0,
9995                                                      AOM_LAST_FLAG,
9996                                                      AOM_LAST2_FLAG,
9997                                                      AOM_LAST3_FLAG,
9998                                                      AOM_GOLD_FLAG,
9999                                                      AOM_BWD_FLAG,
10000                                                      AOM_ALT2_FLAG,
10001                                                      AOM_ALT_FLAG };
10002 
10003 static void rd_pick_skip_mode(RD_STATS *rd_cost,
10004                               InterModeSearchState *search_state,
10005                               const AV1_COMP *const cpi, MACROBLOCK *const x,
10006                               BLOCK_SIZE bsize, int mi_row, int mi_col,
10007                               struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE]) {
10008   const AV1_COMMON *const cm = &cpi->common;
10009   const int num_planes = av1_num_planes(cm);
10010   MACROBLOCKD *const xd = &x->e_mbd;
10011   MB_MODE_INFO *const mbmi = xd->mi[0];
10012 
10013   x->compound_idx = 1;  // COMPOUND_AVERAGE
10014   RD_STATS skip_mode_rd_stats;
10015   av1_invalid_rd_stats(&skip_mode_rd_stats);
10016 
10017   if (cm->ref_frame_idx_0 == INVALID_IDX ||
10018       cm->ref_frame_idx_1 == INVALID_IDX) {
10019     return;
10020   }
10021 
10022   const MV_REFERENCE_FRAME ref_frame = LAST_FRAME + cm->ref_frame_idx_0;
10023   const MV_REFERENCE_FRAME second_ref_frame = LAST_FRAME + cm->ref_frame_idx_1;
10024   const PREDICTION_MODE this_mode = NEAREST_NEARESTMV;
10025   const int mode_index =
10026       get_prediction_mode_idx(this_mode, ref_frame, second_ref_frame);
10027 
10028   if (mode_index == -1) {
10029     return;
10030   }
10031 
10032   mbmi->mode = this_mode;
10033   mbmi->uv_mode = UV_DC_PRED;
10034   mbmi->ref_frame[0] = ref_frame;
10035   mbmi->ref_frame[1] = second_ref_frame;
10036   const uint8_t ref_frame_type = av1_ref_frame_type(mbmi->ref_frame);
10037   if (x->mbmi_ext->ref_mv_count[ref_frame_type] == UINT8_MAX) {
10038     if (x->mbmi_ext->ref_mv_count[ref_frame] == UINT8_MAX ||
10039         x->mbmi_ext->ref_mv_count[second_ref_frame] == UINT8_MAX) {
10040       return;
10041     }
10042     MB_MODE_INFO_EXT *mbmi_ext = x->mbmi_ext;
10043     av1_find_mv_refs(cm, xd, mbmi, ref_frame_type, mbmi_ext->ref_mv_count,
10044                      mbmi_ext->ref_mv_stack, NULL, mbmi_ext->global_mvs, mi_row,
10045                      mi_col, mbmi_ext->mode_context);
10046   }
10047 
10048   assert(this_mode == NEAREST_NEARESTMV);
10049   if (!build_cur_mv(mbmi->mv, this_mode, cm, x)) {
10050     return;
10051   }
10052 
10053   mbmi->filter_intra_mode_info.use_filter_intra = 0;
10054   mbmi->interintra_mode = (INTERINTRA_MODE)(II_DC_PRED - 1);
10055   mbmi->comp_group_idx = 0;
10056   mbmi->compound_idx = x->compound_idx;
10057   mbmi->interinter_comp.type = COMPOUND_AVERAGE;
10058   mbmi->motion_mode = SIMPLE_TRANSLATION;
10059   mbmi->ref_mv_idx = 0;
10060   mbmi->skip_mode = mbmi->skip = 1;
10061 
10062   set_default_interp_filters(mbmi, cm->interp_filter);
10063 
10064   set_ref_ptrs(cm, xd, mbmi->ref_frame[0], mbmi->ref_frame[1]);
10065   for (int i = 0; i < num_planes; i++) {
10066     xd->plane[i].pre[0] = yv12_mb[mbmi->ref_frame[0]][i];
10067     xd->plane[i].pre[1] = yv12_mb[mbmi->ref_frame[1]][i];
10068   }
10069 
10070   BUFFER_SET orig_dst;
10071   for (int i = 0; i < num_planes; i++) {
10072     orig_dst.plane[i] = xd->plane[i].dst.buf;
10073     orig_dst.stride[i] = xd->plane[i].dst.stride;
10074   }
10075 
10076   // Obtain the rdcost for skip_mode.
10077   skip_mode_rd(&skip_mode_rd_stats, cpi, x, bsize, mi_row, mi_col, &orig_dst);
10078 
10079   // Compare the use of skip_mode with the best intra/inter mode obtained.
10080   const int skip_mode_ctx = av1_get_skip_mode_context(xd);
10081   const int64_t best_intra_inter_mode_cost =
10082       (rd_cost->dist < INT64_MAX && rd_cost->rate < INT32_MAX)
10083           ? RDCOST(x->rdmult,
10084                    rd_cost->rate + x->skip_mode_cost[skip_mode_ctx][0],
10085                    rd_cost->dist)
10086           : INT64_MAX;
10087 
10088   if (skip_mode_rd_stats.rdcost <= best_intra_inter_mode_cost) {
10089     assert(mode_index != -1);
10090     search_state->best_mbmode.skip_mode = 1;
10091     search_state->best_mbmode = *mbmi;
10092 
10093     search_state->best_mbmode.skip_mode = search_state->best_mbmode.skip = 1;
10094     search_state->best_mbmode.mode = NEAREST_NEARESTMV;
10095     search_state->best_mbmode.ref_frame[0] = mbmi->ref_frame[0];
10096     search_state->best_mbmode.ref_frame[1] = mbmi->ref_frame[1];
10097     search_state->best_mbmode.mv[0].as_int = mbmi->mv[0].as_int;
10098     search_state->best_mbmode.mv[1].as_int = mbmi->mv[1].as_int;
10099     search_state->best_mbmode.ref_mv_idx = 0;
10100 
10101     // Set up tx_size related variables for skip-specific loop filtering.
10102     search_state->best_mbmode.tx_size =
10103         block_signals_txsize(bsize) ? tx_size_from_tx_mode(bsize, cm->tx_mode)
10104                                     : max_txsize_rect_lookup[bsize];
10105     memset(search_state->best_mbmode.inter_tx_size,
10106            search_state->best_mbmode.tx_size,
10107            sizeof(search_state->best_mbmode.inter_tx_size));
10108     set_txfm_ctxs(search_state->best_mbmode.tx_size, xd->n4_w, xd->n4_h,
10109                   search_state->best_mbmode.skip && is_inter_block(mbmi), xd);
10110 
10111     // Set up color-related variables for skip mode.
10112     search_state->best_mbmode.uv_mode = UV_DC_PRED;
10113     search_state->best_mbmode.palette_mode_info.palette_size[0] = 0;
10114     search_state->best_mbmode.palette_mode_info.palette_size[1] = 0;
10115 
10116     search_state->best_mbmode.comp_group_idx = 0;
10117     search_state->best_mbmode.compound_idx = x->compound_idx;
10118     search_state->best_mbmode.interinter_comp.type = COMPOUND_AVERAGE;
10119     search_state->best_mbmode.motion_mode = SIMPLE_TRANSLATION;
10120 
10121     search_state->best_mbmode.interintra_mode =
10122         (INTERINTRA_MODE)(II_DC_PRED - 1);
10123     search_state->best_mbmode.filter_intra_mode_info.use_filter_intra = 0;
10124 
10125     set_default_interp_filters(&search_state->best_mbmode, cm->interp_filter);
10126 
10127     search_state->best_mode_index = mode_index;
10128 
10129     // Update rd_cost
10130     rd_cost->rate = skip_mode_rd_stats.rate;
10131     rd_cost->dist = rd_cost->sse = skip_mode_rd_stats.dist;
10132     rd_cost->rdcost = skip_mode_rd_stats.rdcost;
10133 
10134     search_state->best_rd = rd_cost->rdcost;
10135     search_state->best_skip2 = 1;
10136     search_state->best_mode_skippable = (skip_mode_rd_stats.sse == 0);
10137 
10138     x->skip = 1;
10139   }
10140 }
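
// Worked comparison for the skip_mode decision above: skip_mode replaces the
// incumbent best mode when
//   RDCOST(rdmult, skip_mode_rate, skip_mode_dist)
//     <= RDCOST(rdmult, best_rate + skip_mode_cost[ctx][0], best_dist);
// note that the cost of signaling skip_mode = 0 is charged to the incumbent,
// since keeping it implies coding that flag.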
10141 
10142 // Speed feature: fast intra/inter transform type search.
10143 // Used for speed >= 2.
10144 // When this speed feature is on, only DCT is used during the RD mode search.
10145 // After the mode is determined, this function is called to select the
10146 // transform types and obtain an accurate rdcost.
10147 static void sf_refine_fast_tx_type_search(
10148     const AV1_COMP *cpi, MACROBLOCK *x, int mi_row, int mi_col,
10149     RD_STATS *rd_cost, BLOCK_SIZE bsize, PICK_MODE_CONTEXT *ctx,
10150     int best_mode_index, MB_MODE_INFO *best_mbmode,
10151     struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE], int best_rate_y,
10152     int best_rate_uv, int *best_skip2) {
10153   const AV1_COMMON *const cm = &cpi->common;
10154   const SPEED_FEATURES *const sf = &cpi->sf;
10155   MACROBLOCKD *const xd = &x->e_mbd;
10156   MB_MODE_INFO *const mbmi = xd->mi[0];
10157   const int num_planes = av1_num_planes(cm);
10158 
10159   if (xd->lossless[mbmi->segment_id] == 0 && best_mode_index >= 0 &&
10160       ((sf->tx_type_search.fast_inter_tx_type_search == 1 &&
10161         is_inter_mode(best_mbmode->mode)) ||
10162        (sf->tx_type_search.fast_intra_tx_type_search == 1 &&
10163         !is_inter_mode(best_mbmode->mode)))) {
10164     int skip_blk = 0;
10165     RD_STATS rd_stats_y, rd_stats_uv;
10166 
10167     x->use_default_inter_tx_type = 0;
10168     x->use_default_intra_tx_type = 0;
10169 
10170     *mbmi = *best_mbmode;
10171 
10172     set_ref_ptrs(cm, xd, mbmi->ref_frame[0], mbmi->ref_frame[1]);
10173 
10174     // Select prediction reference frames.
10175     for (int i = 0; i < num_planes; i++) {
10176       xd->plane[i].pre[0] = yv12_mb[mbmi->ref_frame[0]][i];
10177       if (has_second_ref(mbmi))
10178         xd->plane[i].pre[1] = yv12_mb[mbmi->ref_frame[1]][i];
10179     }
10180 
10181     if (is_inter_mode(mbmi->mode)) {
10182       av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, NULL, bsize);
10183       if (mbmi->motion_mode == OBMC_CAUSAL)
10184         av1_build_obmc_inter_predictors_sb(cm, xd, mi_row, mi_col);
10185 
10186       av1_subtract_plane(x, bsize, 0);
10187       if (cm->tx_mode == TX_MODE_SELECT && !xd->lossless[mbmi->segment_id]) {
10188         // av1_rd_pick_inter_mode_sb
10189         select_tx_type_yrd(cpi, x, &rd_stats_y, bsize, mi_row, mi_col,
10190                            INT64_MAX);
10191         assert(rd_stats_y.rate != INT_MAX);
10192       } else {
10193         super_block_yrd(cpi, x, &rd_stats_y, bsize, INT64_MAX);
10194         memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
10195         for (int i = 0; i < xd->n4_h * xd->n4_w; ++i)
10196           set_blk_skip(x, 0, i, rd_stats_y.skip);
10197       }
10198       if (num_planes > 1) {
10199         inter_block_uvrd(cpi, x, &rd_stats_uv, bsize, INT64_MAX, INT64_MAX,
10200                          FTXS_NONE);
10201       } else {
10202         av1_init_rd_stats(&rd_stats_uv);
10203       }
10204     } else {
10205       super_block_yrd(cpi, x, &rd_stats_y, bsize, INT64_MAX);
10206       if (num_planes > 1) {
10207         super_block_uvrd(cpi, x, &rd_stats_uv, bsize, INT64_MAX);
10208       } else {
10209         av1_init_rd_stats(&rd_stats_uv);
10210       }
10211     }
10212 
10213     if (RDCOST(x->rdmult, rd_stats_y.rate + rd_stats_uv.rate,
10214                (rd_stats_y.dist + rd_stats_uv.dist)) >
10215         RDCOST(x->rdmult, 0, (rd_stats_y.sse + rd_stats_uv.sse))) {
10216       skip_blk = 1;
10217       rd_stats_y.rate = x->skip_cost[av1_get_skip_context(xd)][1];
10218       rd_stats_uv.rate = 0;
10219       rd_stats_y.dist = rd_stats_y.sse;
10220       rd_stats_uv.dist = rd_stats_uv.sse;
10221     } else {
10222       skip_blk = 0;
10223       rd_stats_y.rate += x->skip_cost[av1_get_skip_context(xd)][0];
10224     }
10225 
10226     if (RDCOST(x->rdmult, best_rate_y + best_rate_uv, rd_cost->dist) >
10227         RDCOST(x->rdmult, rd_stats_y.rate + rd_stats_uv.rate,
10228                (rd_stats_y.dist + rd_stats_uv.dist))) {
10229       best_mbmode->tx_size = mbmi->tx_size;
10230       av1_copy(best_mbmode->inter_tx_size, mbmi->inter_tx_size);
10231       memcpy(ctx->blk_skip, x->blk_skip,
10232              sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
10233       av1_copy(best_mbmode->txk_type, mbmi->txk_type);
10234       rd_cost->rate +=
10235           (rd_stats_y.rate + rd_stats_uv.rate - best_rate_y - best_rate_uv);
10236       rd_cost->dist = rd_stats_y.dist + rd_stats_uv.dist;
10237       rd_cost->rdcost = RDCOST(x->rdmult, rd_cost->rate, rd_cost->dist);
10238       *best_skip2 = skip_blk;
10239     }
10240   }
10241 }
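
// Skip-decision sketch used above: the block is forced to skip when coding
// the residue costs more than it saves, i.e. when
//   RDCOST(rdmult, rate_y + rate_uv, dist_y + dist_uv)
//     > RDCOST(rdmult, 0, sse_y + sse_uv);
// in that case the rate collapses to the skip-flag cost and the distortion
// becomes the prediction SSE.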
10242 
10243 // Please add or modify parameter settings in this function, keeping it
10244 // consistent and easy to read and maintain.
10245 static void set_params_rd_pick_inter_mode(
10246     const AV1_COMP *cpi, MACROBLOCK *x, HandleInterModeArgs *args,
10247     BLOCK_SIZE bsize, int mi_row, int mi_col, uint16_t ref_frame_skip_mask[2],
10248     uint32_t mode_skip_mask[REF_FRAMES], int skip_ref_frame_mask,
10249     unsigned int ref_costs_single[REF_FRAMES],
10250     unsigned int ref_costs_comp[REF_FRAMES][REF_FRAMES],
10251     struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE]) {
10252   const AV1_COMMON *const cm = &cpi->common;
10253   const int num_planes = av1_num_planes(cm);
10254   MACROBLOCKD *const xd = &x->e_mbd;
10255   MB_MODE_INFO *const mbmi = xd->mi[0];
10256   MB_MODE_INFO_EXT *const mbmi_ext = x->mbmi_ext;
10257   const struct segmentation *const seg = &cm->seg;
10258   const SPEED_FEATURES *const sf = &cpi->sf;
10259   unsigned char segment_id = mbmi->segment_id;
10260   int dst_width1[MAX_MB_PLANE] = { MAX_SB_SIZE, MAX_SB_SIZE, MAX_SB_SIZE };
10261   int dst_width2[MAX_MB_PLANE] = { MAX_SB_SIZE >> 1, MAX_SB_SIZE >> 1,
10262                                    MAX_SB_SIZE >> 1 };
10263   int dst_height1[MAX_MB_PLANE] = { MAX_SB_SIZE >> 1, MAX_SB_SIZE >> 1,
10264                                     MAX_SB_SIZE >> 1 };
10265   int dst_height2[MAX_MB_PLANE] = { MAX_SB_SIZE, MAX_SB_SIZE, MAX_SB_SIZE };
10266 
10267   for (int i = 0; i < MB_MODE_COUNT; ++i)
10268     for (int k = 0; k < REF_FRAMES; ++k) args->single_filter[i][k] = SWITCHABLE;
10269 
10270   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) {
10271     int len = sizeof(uint16_t);
10272     args->above_pred_buf[0] = CONVERT_TO_BYTEPTR(x->above_pred_buf);
10273     args->above_pred_buf[1] =
10274         CONVERT_TO_BYTEPTR(x->above_pred_buf + (MAX_SB_SQUARE >> 1) * len);
10275     args->above_pred_buf[2] =
10276         CONVERT_TO_BYTEPTR(x->above_pred_buf + MAX_SB_SQUARE * len);
10277     args->left_pred_buf[0] = CONVERT_TO_BYTEPTR(x->left_pred_buf);
10278     args->left_pred_buf[1] =
10279         CONVERT_TO_BYTEPTR(x->left_pred_buf + (MAX_SB_SQUARE >> 1) * len);
10280     args->left_pred_buf[2] =
10281         CONVERT_TO_BYTEPTR(x->left_pred_buf + MAX_SB_SQUARE * len);
10282   } else {
10283     args->above_pred_buf[0] = x->above_pred_buf;
10284     args->above_pred_buf[1] = x->above_pred_buf + (MAX_SB_SQUARE >> 1);
10285     args->above_pred_buf[2] = x->above_pred_buf + MAX_SB_SQUARE;
10286     args->left_pred_buf[0] = x->left_pred_buf;
10287     args->left_pred_buf[1] = x->left_pred_buf + (MAX_SB_SQUARE >> 1);
10288     args->left_pred_buf[2] = x->left_pred_buf + MAX_SB_SQUARE;
10289   }
10290 
10291   av1_collect_neighbors_ref_counts(xd);
10292 
10293   estimate_ref_frame_costs(cm, xd, x, segment_id, ref_costs_single,
10294                            ref_costs_comp);
10295 
10296   MV_REFERENCE_FRAME ref_frame;
10297   for (ref_frame = LAST_FRAME; ref_frame <= ALTREF_FRAME; ++ref_frame) {
10298     x->pred_mv_sad[ref_frame] = INT_MAX;
10299     x->mbmi_ext->mode_context[ref_frame] = 0;
10300     x->mbmi_ext->compound_mode_context[ref_frame] = 0;
10301     mbmi_ext->ref_mv_count[ref_frame] = UINT8_MAX;
10302     if (cpi->ref_frame_flags & ref_frame_flag_list[ref_frame]) {
10303       if (mbmi->partition != PARTITION_NONE &&
10304           mbmi->partition != PARTITION_SPLIT) {
10305         if (skip_ref_frame_mask & (1 << ref_frame)) {
10306           int skip = 1;
10307           for (int r = ALTREF_FRAME + 1; r < MODE_CTX_REF_FRAMES; ++r) {
10308             if (!(skip_ref_frame_mask & (1 << r))) {
10309               const MV_REFERENCE_FRAME *rf = ref_frame_map[r - REF_FRAMES];
10310               if (rf[0] == ref_frame || rf[1] == ref_frame) {
10311                 skip = 0;
10312                 break;
10313               }
10314             }
10315           }
10316           if (skip) continue;
10317         }
10318       }
10319       assert(get_ref_frame_buffer(cpi, ref_frame) != NULL);
10320       setup_buffer_ref_mvs_inter(cpi, x, ref_frame, bsize, mi_row, mi_col,
10321                                  yv12_mb);
10322     }
10323   }
10324   // ref_frame = ALTREF_FRAME
10325   for (; ref_frame < MODE_CTX_REF_FRAMES; ++ref_frame) {
10326     x->mbmi_ext->mode_context[ref_frame] = 0;
10327     mbmi_ext->ref_mv_count[ref_frame] = UINT8_MAX;
10328     const MV_REFERENCE_FRAME *rf = ref_frame_map[ref_frame - REF_FRAMES];
10329     if (!((cpi->ref_frame_flags & ref_frame_flag_list[rf[0]]) &&
10330           (cpi->ref_frame_flags & ref_frame_flag_list[rf[1]]))) {
10331       continue;
10332     }
10333 
10334     if (mbmi->partition != PARTITION_NONE &&
10335         mbmi->partition != PARTITION_SPLIT) {
10336       if (skip_ref_frame_mask & (1 << ref_frame)) {
10337         continue;
10338       }
10339     }
10340     av1_find_mv_refs(cm, xd, mbmi, ref_frame, mbmi_ext->ref_mv_count,
10341                      mbmi_ext->ref_mv_stack, NULL, mbmi_ext->global_mvs, mi_row,
10342                      mi_col, mbmi_ext->mode_context);
10343   }
10344 
10345   av1_count_overlappable_neighbors(cm, xd, mi_row, mi_col);
10346 
10347   if (check_num_overlappable_neighbors(mbmi) &&
10348       is_motion_variation_allowed_bsize(bsize)) {
10349     av1_build_prediction_by_above_preds(cm, xd, mi_row, mi_col,
10350                                         args->above_pred_buf, dst_width1,
10351                                         dst_height1, args->above_pred_stride);
10352     av1_build_prediction_by_left_preds(cm, xd, mi_row, mi_col,
10353                                        args->left_pred_buf, dst_width2,
10354                                        dst_height2, args->left_pred_stride);
10355     av1_setup_dst_planes(xd->plane, bsize, get_frame_new_buffer(cm), mi_row,
10356                          mi_col, 0, num_planes);
10357     calc_target_weighted_pred(
10358         cm, x, xd, mi_row, mi_col, args->above_pred_buf[0],
10359         args->above_pred_stride[0], args->left_pred_buf[0],
10360         args->left_pred_stride[0]);
10361   }
10362 
10363   int min_pred_mv_sad = INT_MAX;
10364   for (ref_frame = LAST_FRAME; ref_frame <= ALTREF_FRAME; ++ref_frame)
10365     min_pred_mv_sad = AOMMIN(min_pred_mv_sad, x->pred_mv_sad[ref_frame]);
10366 
10367   for (int i = 0; i < 2; ++i) {
10368     ref_frame_skip_mask[i] = 0;
10369   }
10370   memset(mode_skip_mask, 0, REF_FRAMES * sizeof(*mode_skip_mask));
10371   for (ref_frame = LAST_FRAME; ref_frame <= ALTREF_FRAME; ++ref_frame) {
10372     if (!(cpi->ref_frame_flags & ref_frame_flag_list[ref_frame])) {
10373       // Skip checking missing references in both single and compound reference
10374       // modes. Note that a mode will be skipped iff both reference frames
10375       // are masked out.
10376       ref_frame_skip_mask[0] |= (1 << ref_frame);
10377       ref_frame_skip_mask[1] |= SECOND_REF_FRAME_MASK;
10378     } else {
10379       // Skip fixed mv modes for poor references
10380       if ((x->pred_mv_sad[ref_frame] >> 2) > min_pred_mv_sad) {
10381         mode_skip_mask[ref_frame] |= INTER_NEAREST_NEAR_ZERO;
10382       }
10383     }
10384     // If the segment-level reference frame feature is enabled, mask out
10385     // any ref frame that is not allowed for this segment.
10386     if (segfeature_active(seg, segment_id, SEG_LVL_REF_FRAME) &&
10387         get_segdata(seg, segment_id, SEG_LVL_REF_FRAME) != (int)ref_frame) {
10388       ref_frame_skip_mask[0] |= (1 << ref_frame);
10389       ref_frame_skip_mask[1] |= SECOND_REF_FRAME_MASK;
10390     }
10391   }
10392 
10393   // Disable this drop-out case if the segment-level reference frame
10394   // feature is enabled for this segment, to prevent the possibility that
10395   // we end up unable to pick any mode.
10396   if (!segfeature_active(seg, segment_id, SEG_LVL_REF_FRAME)) {
10397     // Only consider GLOBALMV/ALTREF_FRAME for alt ref frame,
10398     // unless ARNR filtering is enabled in which case we want
10399     // an unfiltered alternative. We allow near/nearest as well
10400     // because they may result in zero-zero MVs but be cheaper.
10401     if (cpi->rc.is_src_frame_alt_ref && (cpi->oxcf.arnr_max_frames == 0)) {
10402       ref_frame_skip_mask[0] = (1 << LAST_FRAME) | (1 << LAST2_FRAME) |
10403                                (1 << LAST3_FRAME) | (1 << BWDREF_FRAME) |
10404                                (1 << ALTREF2_FRAME) | (1 << GOLDEN_FRAME);
10405       ref_frame_skip_mask[1] = SECOND_REF_FRAME_MASK;
10406       // TODO(zoeliu): Further explore whether the following needs to be done
10407       //               for BWDREF_FRAME as well.
10408       mode_skip_mask[ALTREF_FRAME] = ~INTER_NEAREST_NEAR_ZERO;
10409       const MV_REFERENCE_FRAME tmp_ref_frames[2] = { ALTREF_FRAME, NONE_FRAME };
10410       int_mv near_mv, nearest_mv, global_mv;
10411       get_this_mv(&nearest_mv, NEARESTMV, 0, 0, tmp_ref_frames, x->mbmi_ext);
10412       get_this_mv(&near_mv, NEARMV, 0, 0, tmp_ref_frames, x->mbmi_ext);
10413       get_this_mv(&global_mv, GLOBALMV, 0, 0, tmp_ref_frames, x->mbmi_ext);
10414 
10415       if (near_mv.as_int != global_mv.as_int)
10416         mode_skip_mask[ALTREF_FRAME] |= (1 << NEARMV);
10417       if (nearest_mv.as_int != global_mv.as_int)
10418         mode_skip_mask[ALTREF_FRAME] |= (1 << NEARESTMV);
10419     }
10420   }
10421 
10422   if (cpi->rc.is_src_frame_alt_ref) {
10423     if (sf->alt_ref_search_fp) {
10424       assert(cpi->ref_frame_flags & ref_frame_flag_list[ALTREF_FRAME]);
10425       mode_skip_mask[ALTREF_FRAME] = 0;
10426       ref_frame_skip_mask[0] = ~(1 << ALTREF_FRAME);
10427       ref_frame_skip_mask[1] = SECOND_REF_FRAME_MASK;
10428     }
10429   }
10430 
10431   if (sf->alt_ref_search_fp)
10432     if (!cm->show_frame && x->pred_mv_sad[GOLDEN_FRAME] < INT_MAX)
10433       if (x->pred_mv_sad[ALTREF_FRAME] > (x->pred_mv_sad[GOLDEN_FRAME] << 1))
10434         mode_skip_mask[ALTREF_FRAME] |= INTER_ALL;
10435 
10436   if (sf->adaptive_mode_search) {
10437     if (cm->show_frame && !cpi->rc.is_src_frame_alt_ref &&
10438         cpi->rc.frames_since_golden >= 3)
10439       if ((x->pred_mv_sad[GOLDEN_FRAME] >> 1) > x->pred_mv_sad[LAST_FRAME])
10440         mode_skip_mask[GOLDEN_FRAME] |= INTER_ALL;
10441   }
10442 
10443   if (bsize > sf->max_intra_bsize) {
10444     ref_frame_skip_mask[0] |= (1 << INTRA_FRAME);
10445     ref_frame_skip_mask[1] |= (1 << INTRA_FRAME);
10446   }
10447 
10448   mode_skip_mask[INTRA_FRAME] |=
10449       ~(sf->intra_y_mode_mask[max_txsize_lookup[bsize]]);
10450 
10451   if (cpi->sf.tx_type_search.fast_intra_tx_type_search)
10452     x->use_default_intra_tx_type = 1;
10453   else
10454     x->use_default_intra_tx_type = 0;
10455 
10456   if (cpi->sf.tx_type_search.fast_inter_tx_type_search)
10457     x->use_default_inter_tx_type = 1;
10458   else
10459     x->use_default_inter_tx_type = 0;
10460   if (cpi->sf.skip_repeat_interpolation_filter_search) {
10461     x->interp_filter_stats_idx[0] = 0;
10462     x->interp_filter_stats_idx[1] = 0;
10463   }
10464 }
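
// Mask-layout sketch (matching the usage later in this file):
// ref_frame_skip_mask[0] is a bitmask over the first reference,
// ref_frame_skip_mask[1] over the second, and mode_skip_mask[ref] is a
// per-reference bitmask over prediction modes. A candidate survives only if
//   !(ref_frame_skip_mask[0] & (1 << ref_frame[0])) ||
//   !(ref_frame_skip_mask[1] & (1 << AOMMAX(0, ref_frame[1])))
// and !(mode_skip_mask[ref_frame[0]] & (1 << this_mode)).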
10465 
10466 static void search_palette_mode(const AV1_COMP *cpi, MACROBLOCK *x, int mi_row,
10467                                 int mi_col, RD_STATS *rd_cost,
10468                                 PICK_MODE_CONTEXT *ctx, BLOCK_SIZE bsize,
10469                                 MB_MODE_INFO *const mbmi,
10470                                 PALETTE_MODE_INFO *const pmi,
10471                                 unsigned int *ref_costs_single,
10472                                 InterModeSearchState *search_state) {
10473   const AV1_COMMON *const cm = &cpi->common;
10474   const int num_planes = av1_num_planes(cm);
10475   MACROBLOCKD *const xd = &x->e_mbd;
10476   int rate2 = 0;
10477   int64_t distortion2 = 0, best_rd_palette = search_state->best_rd, this_rd,
10478           best_model_rd_palette = INT64_MAX;
10479   int skippable = 0, rate_overhead_palette = 0;
10480   RD_STATS rd_stats_y;
10481   TX_SIZE uv_tx = TX_4X4;
10482   uint8_t *const best_palette_color_map =
10483       x->palette_buffer->best_palette_color_map;
10484   uint8_t *const color_map = xd->plane[0].color_index_map;
10485   MB_MODE_INFO best_mbmi_palette = *mbmi;
10486   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
10487   const int *const intra_mode_cost = x->mbmode_cost[size_group_lookup[bsize]];
10488   const int rows = block_size_high[bsize];
10489   const int cols = block_size_wide[bsize];
10490 
10491   mbmi->mode = DC_PRED;
10492   mbmi->uv_mode = UV_DC_PRED;
10493   mbmi->ref_frame[0] = INTRA_FRAME;
10494   mbmi->ref_frame[1] = NONE_FRAME;
10495   rate_overhead_palette = rd_pick_palette_intra_sby(
10496       cpi, x, bsize, mi_row, mi_col, intra_mode_cost[DC_PRED],
10497       &best_mbmi_palette, best_palette_color_map, &best_rd_palette,
10498       &best_model_rd_palette, NULL, NULL, NULL, NULL, ctx, best_blk_skip);
10499   if (pmi->palette_size[0] == 0) return;
10500 
10501   memcpy(x->blk_skip, best_blk_skip,
10502          sizeof(best_blk_skip[0]) * bsize_to_num_blk(bsize));
10503 
10504   memcpy(color_map, best_palette_color_map,
10505          rows * cols * sizeof(best_palette_color_map[0]));
10506   super_block_yrd(cpi, x, &rd_stats_y, bsize, search_state->best_rd);
10507   if (rd_stats_y.rate == INT_MAX) return;
10508 
10509   skippable = rd_stats_y.skip;
10510   distortion2 = rd_stats_y.dist;
10511   rate2 = rd_stats_y.rate + rate_overhead_palette;
10512   rate2 += ref_costs_single[INTRA_FRAME];
10513   if (num_planes > 1) {
10514     uv_tx = av1_get_tx_size(AOM_PLANE_U, xd);
10515     if (search_state->rate_uv_intra[uv_tx] == INT_MAX) {
10516       choose_intra_uv_mode(
10517           cpi, x, bsize, uv_tx, &search_state->rate_uv_intra[uv_tx],
10518           &search_state->rate_uv_tokenonly[uv_tx],
10519           &search_state->dist_uvs[uv_tx], &search_state->skip_uvs[uv_tx],
10520           &search_state->mode_uv[uv_tx]);
10521       search_state->pmi_uv[uv_tx] = *pmi;
10522       search_state->uv_angle_delta[uv_tx] = mbmi->angle_delta[PLANE_TYPE_UV];
10523     }
10524     mbmi->uv_mode = search_state->mode_uv[uv_tx];
10525     pmi->palette_size[1] = search_state->pmi_uv[uv_tx].palette_size[1];
10526     if (pmi->palette_size[1] > 0) {
10527       memcpy(pmi->palette_colors + PALETTE_MAX_SIZE,
10528              search_state->pmi_uv[uv_tx].palette_colors + PALETTE_MAX_SIZE,
10529              2 * PALETTE_MAX_SIZE * sizeof(pmi->palette_colors[0]));
10530     }
10531     mbmi->angle_delta[PLANE_TYPE_UV] = search_state->uv_angle_delta[uv_tx];
10532     skippable = skippable && search_state->skip_uvs[uv_tx];
10533     distortion2 += search_state->dist_uvs[uv_tx];
10534     rate2 += search_state->rate_uv_intra[uv_tx];
10535   }
10536 
10537   if (skippable) {
10538     rate2 -= rd_stats_y.rate;
10539     if (num_planes > 1) rate2 -= search_state->rate_uv_tokenonly[uv_tx];
10540     rate2 += x->skip_cost[av1_get_skip_context(xd)][1];
10541   } else {
10542     rate2 += x->skip_cost[av1_get_skip_context(xd)][0];
10543   }
10544   this_rd = RDCOST(x->rdmult, rate2, distortion2);
10545   if (this_rd < search_state->best_rd) {
10546     search_state->best_mode_index = 3;
10547     mbmi->mv[0].as_int = 0;
10548     rd_cost->rate = rate2;
10549     rd_cost->dist = distortion2;
10550     rd_cost->rdcost = this_rd;
10551     search_state->best_rd = this_rd;
10552     search_state->best_mbmode = *mbmi;
10553     search_state->best_skip2 = 0;
10554     search_state->best_mode_skippable = skippable;
10555     memcpy(ctx->blk_skip, x->blk_skip,
10556            sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
10557   }
10558 }
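
// Rate-accounting sketch for the palette decision above: when the whole
// block is skippable, the token rates are backed out and replaced by the
// skip-flag cost, leaving roughly
//   rate = rate_overhead_palette + ref_cost + (uv mode bits) + skip_cost[ctx][1]
// instead of keeping rate_y/rate_uv and adding skip_cost[ctx][0].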
10559 
10560 static void init_inter_mode_search_state(InterModeSearchState *search_state,
10561                                          const AV1_COMP *cpi,
10562                                          const TileDataEnc *tile_data,
10563                                          const MACROBLOCK *x, BLOCK_SIZE bsize,
10564                                          int64_t best_rd_so_far) {
10565   search_state->best_rd = best_rd_so_far;
10566 
10567   av1_zero(search_state->best_mbmode);
10568 
10569   search_state->best_rate_y = INT_MAX;
10570 
10571   search_state->best_rate_uv = INT_MAX;
10572 
10573   search_state->best_mode_skippable = 0;
10574 
10575   search_state->best_skip2 = 0;
10576 
10577   search_state->best_mode_index = -1;
10578 
10579   const MACROBLOCKD *const xd = &x->e_mbd;
10580   const MB_MODE_INFO *const mbmi = xd->mi[0];
10581   const unsigned char segment_id = mbmi->segment_id;
10582 
10583   search_state->skip_intra_modes = 0;
10584 
10585   search_state->num_available_refs = 0;
10586   memset(search_state->dist_refs, -1, sizeof(search_state->dist_refs));
10587   memset(search_state->dist_order_refs, -1,
10588          sizeof(search_state->dist_order_refs));
10589 
10590   for (int i = 0; i <= LAST_NEW_MV_INDEX; ++i)
10591     search_state->mode_threshold[i] = 0;
10592   const int *const rd_threshes = cpi->rd.threshes[segment_id][bsize];
10593   for (int i = LAST_NEW_MV_INDEX + 1; i < MAX_MODES; ++i)
10594     search_state->mode_threshold[i] =
10595         ((int64_t)rd_threshes[i] * tile_data->thresh_freq_fact[bsize][i]) >> 5;
10596 
10597   search_state->best_intra_mode = DC_PRED;
10598   search_state->best_intra_rd = INT64_MAX;
10599 
10600   search_state->angle_stats_ready = 0;
10601 
10602   search_state->best_pred_sse = UINT_MAX;
10603 
10604   for (int i = 0; i < TX_SIZES_ALL; i++)
10605     search_state->rate_uv_intra[i] = INT_MAX;
10606 
10607   av1_zero(search_state->pmi_uv);
10608 
10609   for (int i = 0; i < REFERENCE_MODES; ++i)
10610     search_state->best_pred_rd[i] = INT64_MAX;
10611 
10612   av1_zero(search_state->single_newmv);
10613   av1_zero(search_state->single_newmv_rate);
10614   av1_zero(search_state->single_newmv_valid);
10615   for (int i = 0; i < MB_MODE_COUNT; ++i) {
10616     for (int j = 0; j < MAX_REF_MV_SERCH; ++j) {
10617       for (int ref_frame = 0; ref_frame < REF_FRAMES; ++ref_frame) {
10618         search_state->modelled_rd[i][j][ref_frame] = INT64_MAX;
10619         search_state->simple_rd[i][j][ref_frame] = INT64_MAX;
10620       }
10621     }
10622   }
10623 
10624   for (int dir = 0; dir < 2; ++dir) {
10625     for (int mode = 0; mode < SINGLE_INTER_MODE_NUM; ++mode) {
10626       for (int ref_frame = 0; ref_frame < FWD_REFS; ++ref_frame) {
10627         SingleInterModeState *state;
10628 
10629         state = &search_state->single_state[dir][mode][ref_frame];
10630         state->ref_frame = NONE_FRAME;
10631         state->rd = INT64_MAX;
10632 
10633         state = &search_state->single_state_modelled[dir][mode][ref_frame];
10634         state->ref_frame = NONE_FRAME;
10635         state->rd = INT64_MAX;
10636       }
10637     }
10638   }
10639   for (int dir = 0; dir < 2; ++dir) {
10640     for (int mode = 0; mode < SINGLE_INTER_MODE_NUM; ++mode) {
10641       for (int ref_frame = 0; ref_frame < FWD_REFS; ++ref_frame) {
10642         search_state->single_rd_order[dir][mode][ref_frame] = NONE_FRAME;
10643       }
10644     }
10645   }
10646   av1_zero(search_state->single_state_cnt);
10647   av1_zero(search_state->single_state_modelled_cnt);
10648 }
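
// Pruning-threshold sketch for the initialization above (numbers are
// illustrative): each mode's threshold scales a static per-segment threshold
// by an adaptive per-tile frequency factor,
//   mode_threshold[i] = (rd_threshes[i] * thresh_freq_fact[bsize][i]) >> 5,
// so rd_threshes[i] = 1000 with a factor of 40 gives a threshold of 1250.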
10649 
10650 // Case 1: returns 0, meaning do not skip this mode.
10651 // Case 2: returns 1, meaning skip this mode completely.
10652 // Case 3: returns 2, meaning skip compound only, but still try single motion modes.
10653 static int inter_mode_search_order_independent_skip(
10654     const AV1_COMP *cpi, const PICK_MODE_CONTEXT *ctx, const MACROBLOCK *x,
10655     BLOCK_SIZE bsize, int mode_index, int mi_row, int mi_col,
10656     uint32_t *mode_skip_mask, uint16_t *ref_frame_skip_mask,
10657     InterModeSearchState *search_state) {
10658   const SPEED_FEATURES *const sf = &cpi->sf;
10659   const AV1_COMMON *const cm = &cpi->common;
10660   const struct segmentation *const seg = &cm->seg;
10661   const MACROBLOCKD *const xd = &x->e_mbd;
10662   const MB_MODE_INFO *const mbmi = xd->mi[0];
10663   const unsigned char segment_id = mbmi->segment_id;
10664   const MV_REFERENCE_FRAME *ref_frame = av1_mode_order[mode_index].ref_frame;
10665   const PREDICTION_MODE this_mode = av1_mode_order[mode_index].mode;
10666   int skip_motion_mode = 0;
10667   if (mbmi->partition != PARTITION_NONE && mbmi->partition != PARTITION_SPLIT) {
10668     const int ref_type = av1_ref_frame_type(ref_frame);
10669     int skip_ref = ctx->skip_ref_frame_mask & (1 << ref_type);
10670     if (ref_type <= ALTREF_FRAME && skip_ref) {
10671       // Since the compound ref modes depend on the motion estimation results
10672       // of two single ref modes (the best MVs of the single ref modes are
10673       // used as the start points), if the current single ref mode is marked
10674       // skip, we need to check whether it will be used in compound ref modes.
10675       for (int r = ALTREF_FRAME + 1; r < MODE_CTX_REF_FRAMES; ++r) {
10676         if (!(ctx->skip_ref_frame_mask & (1 << r))) {
10677           const MV_REFERENCE_FRAME *rf = ref_frame_map[r - REF_FRAMES];
10678           if (rf[0] == ref_type || rf[1] == ref_type) {
10679             // Found a non-skipped compound ref mode which contains the
10680             // current single ref, so this single ref can't be skipped
10681             // completely. Just skip its motion mode search and still try
10682             // its simple translation mode.
10683             skip_motion_mode = 1;
10684             skip_ref = 0;
10685             break;
10686           }
10687         }
10688       }
10689     }
10690     if (skip_ref) return 1;
10691   }
10692 
10693   if (cpi->sf.mode_pruning_based_on_two_pass_partition_search &&
10694       !x->cb_partition_scan) {
10695     const int mi_width = mi_size_wide[bsize];
10696     const int mi_height = mi_size_high[bsize];
10697     int found = 0;
10698     // Search in the stats table to see if the ref frames have been used in the
10699     // first pass of partition search.
10700     for (int row = mi_row; row < mi_row + mi_height && !found;
10701          row += FIRST_PARTITION_PASS_SAMPLE_REGION) {
10702       for (int col = mi_col; col < mi_col + mi_width && !found;
10703            col += FIRST_PARTITION_PASS_SAMPLE_REGION) {
10704         const int index = av1_first_partition_pass_stats_index(row, col);
10705         const FIRST_PARTITION_PASS_STATS *const stats =
10706             &x->first_partition_pass_stats[index];
10707         if (stats->ref0_counts[ref_frame[0]] &&
10708             (ref_frame[1] < 0 || stats->ref1_counts[ref_frame[1]])) {
10709           found = 1;
10710           break;
10711         }
10712       }
10713     }
10714     if (!found) return 1;
10715   }
10716 
10717   if (ref_frame[0] > INTRA_FRAME && ref_frame[1] == INTRA_FRAME) {
10718     // Mode must be compatible.
10719     if (!is_interintra_allowed_mode(this_mode)) return 1;
10720     if (!is_interintra_allowed_bsize(bsize)) return 1;
10721   }
10722 
10723   // This is only used in the motion vector unit test.
10724   if (cpi->oxcf.motion_vector_unit_test && ref_frame[0] == INTRA_FRAME)
10725     return 1;
10726 
10727   if (ref_frame[0] == INTRA_FRAME) {
10728     if (this_mode != DC_PRED) {
10729       // Disable intra modes other than DC_PRED for blocks with low variance
10730       // Threshold for intra skipping based on source variance
10731       // TODO(debargha): Specialize the threshold for super block sizes
10732       const unsigned int skip_intra_var_thresh = 64;
10733       if ((sf->mode_search_skip_flags & FLAG_SKIP_INTRA_LOWVAR) &&
10734           x->source_variance < skip_intra_var_thresh)
10735         return 1;
10736     }
10737   } else {
10738     if (!is_comp_ref_allowed(bsize) && ref_frame[1] > INTRA_FRAME) return 1;
10739   }
10740 
10741   const int comp_pred = ref_frame[1] > INTRA_FRAME;
10742   if (comp_pred) {
10743     if (!cpi->allow_comp_inter_inter) return 1;
10744 
10745     if (cm->reference_mode == SINGLE_REFERENCE) return 1;
10746 
10747     // Skip compound inter modes if ARF is not available.
10748     if (!(cpi->ref_frame_flags & ref_frame_flag_list[ref_frame[1]])) return 1;
10749 
10750     // Do not allow compound prediction if the segment level reference frame
10751     // feature is in use as in this case there can only be one reference.
10752     if (segfeature_active(seg, segment_id, SEG_LVL_REF_FRAME)) return 1;
10753   }
10754 
10755   if (sf->selective_ref_frame) {
10756     if (sf->selective_ref_frame >= 2 || x->cb_partition_scan) {
10757       if (ref_frame[0] == ALTREF2_FRAME || ref_frame[1] == ALTREF2_FRAME)
10758         if (get_relative_dist(
10759                 cm, cm->cur_frame->ref_frame_offset[ALTREF2_FRAME - LAST_FRAME],
10760                 cm->frame_offset) < 0)
10761           return 1;
10762       if (ref_frame[0] == BWDREF_FRAME || ref_frame[1] == BWDREF_FRAME)
10763         if (get_relative_dist(
10764                 cm, cm->cur_frame->ref_frame_offset[BWDREF_FRAME - LAST_FRAME],
10765                 cm->frame_offset) < 0)
10766           return 1;
10767     }
10768     if (ref_frame[0] == LAST3_FRAME || ref_frame[1] == LAST3_FRAME)
10769       if (get_relative_dist(
10770               cm, cm->cur_frame->ref_frame_offset[LAST3_FRAME - LAST_FRAME],
10771               cm->cur_frame->ref_frame_offset[GOLDEN_FRAME - LAST_FRAME]) <= 0)
10772         return 1;
10773     if (ref_frame[0] == LAST2_FRAME || ref_frame[1] == LAST2_FRAME)
10774       if (get_relative_dist(
10775               cm, cm->cur_frame->ref_frame_offset[LAST2_FRAME - LAST_FRAME],
10776               cm->cur_frame->ref_frame_offset[GOLDEN_FRAME - LAST_FRAME]) <= 0)
10777         return 1;
10778   }
10779 
10780   // One-sided compound is used only when all reference frames are one-sided.
10781   if (sf->selective_ref_frame && comp_pred && !cpi->all_one_sided_refs) {
10782     unsigned int ref_offsets[2];
10783     for (int i = 0; i < 2; ++i) {
10784       const int buf_idx = cm->frame_refs[ref_frame[i] - LAST_FRAME].idx;
10785       assert(buf_idx >= 0);
10786       ref_offsets[i] = cm->buffer_pool->frame_bufs[buf_idx].cur_frame_offset;
10787     }
10788     if ((get_relative_dist(cm, ref_offsets[0], cm->frame_offset) <= 0 &&
10789          get_relative_dist(cm, ref_offsets[1], cm->frame_offset) <= 0) ||
10790         (get_relative_dist(cm, ref_offsets[0], cm->frame_offset) > 0 &&
10791          get_relative_dist(cm, ref_offsets[1], cm->frame_offset) > 0))
10792       return 1;
10793   }
10794 
10795   if (mode_skip_mask[ref_frame[0]] & (1 << this_mode)) {
10796     return 1;
10797   }
10798 
10799   if ((ref_frame_skip_mask[0] & (1 << ref_frame[0])) &&
10800       (ref_frame_skip_mask[1] & (1 << AOMMAX(0, ref_frame[1])))) {
10801     return 1;
10802   }
10803 
10804   if (skip_repeated_mv(cm, x, this_mode, ref_frame, search_state)) {
10805     return 1;
10806   }
10807   if (skip_motion_mode) {
10808     return 2;
10809   }
10810   return 0;
10811 }
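
// Caller-contract sketch for the return codes documented above (variable
// names here are illustrative, not from this file):
//   const int ret = inter_mode_search_order_independent_skip(...);
//   if (ret == 1) continue;              // drop the mode entirely
//   if (ret == 2) skip_motion_mode = 1;  // still try simple translation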
10812 
10813 static INLINE void init_mbmi(MB_MODE_INFO *mbmi, int mode_index,
10814                              const AV1_COMMON *cm) {
10815   PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
10816   PREDICTION_MODE this_mode = av1_mode_order[mode_index].mode;
10817   mbmi->ref_mv_idx = 0;
10818   mbmi->mode = this_mode;
10819   mbmi->uv_mode = UV_DC_PRED;
10820   mbmi->ref_frame[0] = av1_mode_order[mode_index].ref_frame[0];
10821   mbmi->ref_frame[1] = av1_mode_order[mode_index].ref_frame[1];
10822   pmi->palette_size[0] = 0;
10823   pmi->palette_size[1] = 0;
10824   mbmi->filter_intra_mode_info.use_filter_intra = 0;
10825   mbmi->mv[0].as_int = mbmi->mv[1].as_int = 0;
10826   mbmi->motion_mode = SIMPLE_TRANSLATION;
10827   mbmi->interintra_mode = (INTERINTRA_MODE)(II_DC_PRED - 1);
10828   set_default_interp_filters(mbmi, cm->interp_filter);
10829 }
10830 
10831 static int64_t handle_intra_mode(InterModeSearchState *search_state,
10832                                  const AV1_COMP *cpi, MACROBLOCK *x,
10833                                  BLOCK_SIZE bsize, int mi_row, int mi_col,
10834                                  int ref_frame_cost,
10835                                  const PICK_MODE_CONTEXT *ctx, int disable_skip,
10836                                  RD_STATS *rd_stats, RD_STATS *rd_stats_y,
10837                                  RD_STATS *rd_stats_uv) {
10838   const AV1_COMMON *cm = &cpi->common;
10839   const SPEED_FEATURES *const sf = &cpi->sf;
10840   MACROBLOCKD *const xd = &x->e_mbd;
10841   MB_MODE_INFO *const mbmi = xd->mi[0];
10842   assert(mbmi->ref_frame[0] == INTRA_FRAME);
10843   PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
10844   const int try_palette =
10845       av1_allow_palette(cm->allow_screen_content_tools, mbmi->sb_type);
10846   const int *const intra_mode_cost = x->mbmode_cost[size_group_lookup[bsize]];
10847   const int intra_cost_penalty = av1_get_intra_cost_penalty(
10848       cm->base_qindex, cm->y_dc_delta_q, cm->seq_params.bit_depth);
10849   const int rows = block_size_high[bsize];
10850   const int cols = block_size_wide[bsize];
10851   const int num_planes = av1_num_planes(cm);
10852   const int skip_ctx = av1_get_skip_context(xd);
10853 
10854   int known_rate = intra_mode_cost[mbmi->mode];
10855   known_rate += ref_frame_cost;
10856   if (mbmi->mode != DC_PRED && mbmi->mode != PAETH_PRED)
10857     known_rate += intra_cost_penalty;
10858   known_rate += AOMMIN(x->skip_cost[skip_ctx][0], x->skip_cost[skip_ctx][1]);
10859   const int64_t known_rd = RDCOST(x->rdmult, known_rate, 0);
10860   if (known_rd > search_state->best_rd) {
10861     search_state->skip_intra_modes = 1;
10862     return INT64_MAX;
10863   }
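  // The early exit above is a lower-bound argument: known_rate counts only
  // signaling bits this intra mode must pay regardless of the residual, and
  // distortion can never be negative, so RDCOST(x->rdmult, known_rate, 0) is
  // a floor on the mode's final rd cost. If even the floor exceeds best_rd,
  // no transform search can make this mode win, and the remaining intra
  // modes are skipped via skip_intra_modes.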
10864 
10865   TX_SIZE uv_tx;
10866   int is_directional_mode = av1_is_directional_mode(mbmi->mode);
10867   if (is_directional_mode && av1_use_angle_delta(bsize)) {
10868     int rate_dummy;
10869     int64_t model_rd = INT64_MAX;
10870     if (!search_state->angle_stats_ready) {
10871       const int src_stride = x->plane[0].src.stride;
10872       const uint8_t *src = x->plane[0].src.buf;
10873       if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
10874         highbd_angle_estimation(src, src_stride, rows, cols, bsize,
10875                                 search_state->directional_mode_skip_mask);
10876       else
10877         angle_estimation(src, src_stride, rows, cols, bsize,
10878                          search_state->directional_mode_skip_mask);
10879       search_state->angle_stats_ready = 1;
10880     }
10881     if (search_state->directional_mode_skip_mask[mbmi->mode]) return INT64_MAX;
10882     av1_init_rd_stats(rd_stats_y);
10883     rd_stats_y->rate = INT_MAX;
10884     rd_pick_intra_angle_sby(cpi, x, mi_row, mi_col, &rate_dummy, rd_stats_y,
10885                             bsize, intra_mode_cost[mbmi->mode],
10886                             search_state->best_rd, &model_rd);
10887   } else {
10888     av1_init_rd_stats(rd_stats_y);
10889     mbmi->angle_delta[PLANE_TYPE_Y] = 0;
10890     super_block_yrd(cpi, x, rd_stats_y, bsize, search_state->best_rd);
10891   }
10892   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
10893   memcpy(best_blk_skip, x->blk_skip,
10894          sizeof(best_blk_skip[0]) * ctx->num_4x4_blk);
10895   int try_filter_intra = 0;
10896   int64_t best_rd_tmp = INT64_MAX;
10897   if (mbmi->mode == DC_PRED && av1_filter_intra_allowed_bsize(cm, bsize)) {
10898     if (rd_stats_y->rate != INT_MAX) {
10899       const int tmp_rate = rd_stats_y->rate + x->filter_intra_cost[bsize][0] +
10900                            intra_mode_cost[mbmi->mode];
10901       best_rd_tmp = RDCOST(x->rdmult, tmp_rate, rd_stats_y->dist);
10902       try_filter_intra = !((best_rd_tmp / 2) > search_state->best_rd);
10903     } else {
10904       try_filter_intra = !(search_state->best_mbmode.skip);
10905     }
10906   }
10907   if (try_filter_intra) {
10908     RD_STATS rd_stats_y_fi;
10909     int filter_intra_selected_flag = 0;
10910     TX_SIZE best_tx_size = mbmi->tx_size;
10911     TX_TYPE best_txk_type[TXK_TYPE_BUF_LEN];
10912     memcpy(best_txk_type, mbmi->txk_type,
10913            sizeof(*best_txk_type) * TXK_TYPE_BUF_LEN);
10914     FILTER_INTRA_MODE best_fi_mode = FILTER_DC_PRED;
10915 
10916     mbmi->filter_intra_mode_info.use_filter_intra = 1;
10917     for (FILTER_INTRA_MODE fi_mode = FILTER_DC_PRED;
10918          fi_mode < FILTER_INTRA_MODES; ++fi_mode) {
10919       int64_t this_rd_tmp;
10920       mbmi->filter_intra_mode_info.filter_intra_mode = fi_mode;
10921       super_block_yrd(cpi, x, &rd_stats_y_fi, bsize, search_state->best_rd);
10922       if (rd_stats_y_fi.rate == INT_MAX) {
10923         continue;
10924       }
10925       const int this_rate_tmp =
10926           rd_stats_y_fi.rate +
10927           intra_mode_info_cost_y(cpi, x, mbmi, bsize,
10928                                  intra_mode_cost[mbmi->mode]);
10929       this_rd_tmp = RDCOST(x->rdmult, this_rate_tmp, rd_stats_y_fi.dist);
10930 
10931       if (this_rd_tmp != INT64_MAX && this_rd_tmp / 2 > search_state->best_rd) {
10932         break;
10933       }
10934       if (this_rd_tmp < best_rd_tmp) {
10935         best_tx_size = mbmi->tx_size;
10936         memcpy(best_txk_type, mbmi->txk_type,
10937                sizeof(*best_txk_type) * TXK_TYPE_BUF_LEN);
10938         memcpy(best_blk_skip, x->blk_skip,
10939                sizeof(best_blk_skip[0]) * ctx->num_4x4_blk);
10940         best_fi_mode = fi_mode;
10941         *rd_stats_y = rd_stats_y_fi;
10942         filter_intra_selected_flag = 1;
10943         best_rd_tmp = this_rd_tmp;
10944       }
10945     }
10946 
10947     mbmi->tx_size = best_tx_size;
10948     memcpy(mbmi->txk_type, best_txk_type,
10949            sizeof(*best_txk_type) * TXK_TYPE_BUF_LEN);
10950     memcpy(x->blk_skip, best_blk_skip,
10951            sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
10952 
10953     if (filter_intra_selected_flag) {
10954       mbmi->filter_intra_mode_info.use_filter_intra = 1;
10955       mbmi->filter_intra_mode_info.filter_intra_mode = best_fi_mode;
10956     } else {
10957       mbmi->filter_intra_mode_info.use_filter_intra = 0;
10958     }
10959   }
10960   if (rd_stats_y->rate == INT_MAX) return INT64_MAX;
10961   const int mode_cost_y =
10962       intra_mode_info_cost_y(cpi, x, mbmi, bsize, intra_mode_cost[mbmi->mode]);
10963   av1_init_rd_stats(rd_stats);
10964   av1_init_rd_stats(rd_stats_uv);
10965   if (num_planes > 1) {
10966     uv_tx = av1_get_tx_size(AOM_PLANE_U, xd);
10967     if (search_state->rate_uv_intra[uv_tx] == INT_MAX) {
10968       int rate_y =
10969           rd_stats_y->skip ? x->skip_cost[skip_ctx][1] : rd_stats_y->rate;
10970       const int64_t rdy =
10971           RDCOST(x->rdmult, rate_y + mode_cost_y, rd_stats_y->dist);
10972       if (search_state->best_rd < (INT64_MAX / 2) &&
10973           rdy > (search_state->best_rd + (search_state->best_rd >> 2))) {
10974         search_state->skip_intra_modes = 1;
10975         return INT64_MAX;
10976       }
10977       choose_intra_uv_mode(
10978           cpi, x, bsize, uv_tx, &search_state->rate_uv_intra[uv_tx],
10979           &search_state->rate_uv_tokenonly[uv_tx],
10980           &search_state->dist_uvs[uv_tx], &search_state->skip_uvs[uv_tx],
10981           &search_state->mode_uv[uv_tx]);
10982       if (try_palette) search_state->pmi_uv[uv_tx] = *pmi;
10983       search_state->uv_angle_delta[uv_tx] = mbmi->angle_delta[PLANE_TYPE_UV];
10984 
10985       const int uv_rate = search_state->rate_uv_tokenonly[uv_tx];
10986       const int64_t uv_dist = search_state->dist_uvs[uv_tx];
10987       const int64_t uv_rd = RDCOST(x->rdmult, uv_rate, uv_dist);
10988       if (uv_rd > search_state->best_rd) {
10989         search_state->skip_intra_modes = 1;
10990         return INT64_MAX;
10991       }
10992     }
10993 
10994     rd_stats_uv->rate = search_state->rate_uv_tokenonly[uv_tx];
10995     rd_stats_uv->dist = search_state->dist_uvs[uv_tx];
10996     rd_stats_uv->skip = search_state->skip_uvs[uv_tx];
10997     rd_stats->skip = rd_stats_y->skip && rd_stats_uv->skip;
10998     mbmi->uv_mode = search_state->mode_uv[uv_tx];
10999     if (try_palette) {
11000       pmi->palette_size[1] = search_state->pmi_uv[uv_tx].palette_size[1];
11001       memcpy(pmi->palette_colors + PALETTE_MAX_SIZE,
11002              search_state->pmi_uv[uv_tx].palette_colors + PALETTE_MAX_SIZE,
11003              2 * PALETTE_MAX_SIZE * sizeof(pmi->palette_colors[0]));
11004     }
11005     mbmi->angle_delta[PLANE_TYPE_UV] = search_state->uv_angle_delta[uv_tx];
11006   }
11007   rd_stats->rate = rd_stats_y->rate + mode_cost_y;
11008   if (!xd->lossless[mbmi->segment_id] && block_signals_txsize(bsize)) {
11009     // super_block_yrd above includes the cost of the tx_size in the
11010     // tokenonly rate, but for intra blocks, tx_size is always coded
11011     // (prediction granularity), so we account for it in the full rate,
11012     // not the tokenonly rate.
11013     rd_stats_y->rate -= tx_size_cost(cm, x, bsize, mbmi->tx_size);
11014   }
11015   if (num_planes > 1 && !x->skip_chroma_rd) {
11016     const int uv_mode_cost =
11017         x->intra_uv_mode_cost[is_cfl_allowed(xd)][mbmi->mode][mbmi->uv_mode];
11018     rd_stats->rate +=
11019         rd_stats_uv->rate +
11020         intra_mode_info_cost_uv(cpi, x, mbmi, bsize, uv_mode_cost);
11021   }
11022   if (mbmi->mode != DC_PRED && mbmi->mode != PAETH_PRED)
11023     rd_stats->rate += intra_cost_penalty;
11024   rd_stats->dist = rd_stats_y->dist + rd_stats_uv->dist;
11025 
11026   // Estimate the reference frame signaling cost and add it
11027   // to the rolling cost variable.
11028   rd_stats->rate += ref_frame_cost;
11029   if (rd_stats->skip) {
11030     // Back out the coefficient coding costs
11031     rd_stats->rate -= (rd_stats_y->rate + rd_stats_uv->rate);
11032     rd_stats_y->rate = 0;
11033     rd_stats_uv->rate = 0;
11034     // Cost the skip mb case
11035     rd_stats->rate += x->skip_cost[skip_ctx][1];
11036   } else {
11037     // Add in the cost of the no skip flag.
11038     rd_stats->rate += x->skip_cost[skip_ctx][0];
11039   }
11040   // Calculate the final RD estimate for this mode.
11041   const int64_t this_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
11042   // Keep record of best intra rd
11043   if (this_rd < search_state->best_intra_rd) {
11044     search_state->best_intra_rd = this_rd;
11045     search_state->best_intra_mode = mbmi->mode;
11046   }
11047 
11048   if (sf->skip_intra_in_interframe) {
11049     if (search_state->best_rd < (INT64_MAX / 2) &&
11050         this_rd > (search_state->best_rd + (search_state->best_rd >> 1)))
11051       search_state->skip_intra_modes = 1;
11052   }
11053 
11054   if (!disable_skip) {
11055     for (int i = 0; i < REFERENCE_MODES; ++i)
11056       search_state->best_pred_rd[i] =
11057           AOMMIN(search_state->best_pred_rd[i], this_rd);
11058   }
11059   return this_rd;
11060 }
11061 
11062 static void collect_single_states(MACROBLOCK *x,
11063                                   InterModeSearchState *search_state,
11064                                   const MB_MODE_INFO *const mbmi) {
11065   int i, j;
11066   const MV_REFERENCE_FRAME ref_frame = mbmi->ref_frame[0];
11067   const PREDICTION_MODE this_mode = mbmi->mode;
11068   const int dir = ref_frame <= GOLDEN_FRAME ? 0 : 1;
11069   const int mode_offset = INTER_OFFSET(this_mode);
11070   const int ref_set = get_drl_refmv_count(x, mbmi->ref_frame, this_mode);
11071 
11072   // Simple rd
11073   int64_t simple_rd = search_state->simple_rd[this_mode][0][ref_frame];
11074   for (int ref_mv_idx = 1; ref_mv_idx < ref_set; ++ref_mv_idx) {
11075     int64_t rd = search_state->simple_rd[this_mode][ref_mv_idx][ref_frame];
11076     if (rd < simple_rd) simple_rd = rd;
11077   }
11078 
11079   // Insertion sort of single_state
11080   SingleInterModeState this_state_s = { simple_rd, ref_frame, 1 };
11081   SingleInterModeState *state_s = search_state->single_state[dir][mode_offset];
11082   i = search_state->single_state_cnt[dir][mode_offset];
11083   for (j = i; j > 0 && state_s[j - 1].rd > this_state_s.rd; --j)
11084     state_s[j] = state_s[j - 1];
11085   state_s[j] = this_state_s;
11086   search_state->single_state_cnt[dir][mode_offset]++;
11087 
11088   // Modelled rd
11089   int64_t modelled_rd = search_state->modelled_rd[this_mode][0][ref_frame];
11090   for (int ref_mv_idx = 1; ref_mv_idx < ref_set; ++ref_mv_idx) {
11091     int64_t rd = search_state->modelled_rd[this_mode][ref_mv_idx][ref_frame];
11092     if (rd < modelled_rd) modelled_rd = rd;
11093   }
11094 
11095   // Insertion sort of single_state_modelled
11096   SingleInterModeState this_state_m = { modelled_rd, ref_frame, 1 };
11097   SingleInterModeState *state_m =
11098       search_state->single_state_modelled[dir][mode_offset];
11099   i = search_state->single_state_modelled_cnt[dir][mode_offset];
11100   for (j = i; j > 0 && state_m[j - 1].rd > this_state_m.rd; --j)
11101     state_m[j] = state_m[j - 1];
11102   state_m[j] = this_state_m;
11103   search_state->single_state_modelled_cnt[dir][mode_offset]++;
11104 }
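// Invariant maintained by the two insertion sorts above: after each call,
// single_state[dir][mode_offset][0 .. cnt - 1] (and its modelled twin) stays
// sorted in ascending order of .rd, so index 0 always holds the cheapest
// reference frame recorded for this (direction, mode) pair. For example,
// inserting rd values 900, 400, 700 in that order yields { 400, 700, 900 }
// (illustrative numbers, not from the source).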
11105 
11106 static void analyze_single_states(const AV1_COMP *cpi,
11107                                   InterModeSearchState *search_state) {
11108   int i, j, dir, mode;
11109   if (cpi->sf.prune_comp_search_by_single_result >= 1) {
11110     for (dir = 0; dir < 2; ++dir) {
11111       int64_t best_rd;
11112       SingleInterModeState(*state)[FWD_REFS];
11113 
11114       // Use the best rd of GLOBALMV or NEWMV to prune unlikely reference
11115       // frames for all the modes (NEARESTMV and NEARMV may not have the
11116       // same motion vectors). Always keep the best entry of each mode,
11117       // since it might form the best possible combination with another mode.
11118       state = search_state->single_state[dir];
11119       best_rd = AOMMIN(state[INTER_OFFSET(NEWMV)][0].rd,
11120                        state[INTER_OFFSET(GLOBALMV)][0].rd);
11121       for (mode = 0; mode < SINGLE_INTER_MODE_NUM; ++mode) {
11122         for (i = 1; i < search_state->single_state_cnt[dir][mode]; ++i) {
11123           if (state[mode][i].rd != INT64_MAX &&
11124               (state[mode][i].rd >> 1) > best_rd) {
11125             state[mode][i].valid = 0;
11126           }
11127         }
11128       }
11129 
11130       state = search_state->single_state_modelled[dir];
11131       best_rd = AOMMIN(state[INTER_OFFSET(NEWMV)][0].rd,
11132                        state[INTER_OFFSET(GLOBALMV)][0].rd);
11133       for (mode = 0; mode < SINGLE_INTER_MODE_NUM; ++mode) {
11134         for (i = 1; i < search_state->single_state_modelled_cnt[dir][mode];
11135              ++i) {
11136           if (state[mode][i].rd != INT64_MAX &&
11137               (state[mode][i].rd >> 1) > best_rd) {
11138             state[mode][i].valid = 0;
11139           }
11140         }
11141       }
11142     }
11143   }
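  // Illustrative example of the (rd >> 1) > best_rd rule above (numbers
  // invented): with best_rd = 1000 from the cheaper of NEWMV/GLOBALMV, a
  // NEARMV entry with rd = 2500 is invalidated since 2500 >> 1 = 1250 > 1000,
  // while an entry with rd = 1900 survives (1900 >> 1 = 950 <= 1000). The
  // loops start at i = 1, so the best entry of each mode is never pruned.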
11144 
11145   // Ordering by simple rd first, then by modelled rd
11146   for (dir = 0; dir < 2; ++dir) {
11147     for (mode = 0; mode < SINGLE_INTER_MODE_NUM; ++mode) {
11148       const int state_cnt_s = search_state->single_state_cnt[dir][mode];
11149       const int state_cnt_m =
11150           search_state->single_state_modelled_cnt[dir][mode];
11151       SingleInterModeState *state_s = search_state->single_state[dir][mode];
11152       SingleInterModeState *state_m =
11153           search_state->single_state_modelled[dir][mode];
11154       int count = 0;
11155       const int max_candidates = AOMMAX(state_cnt_s, state_cnt_m);
11156       for (i = 0; i < state_cnt_s; ++i) {
11157         if (state_s[i].rd == INT64_MAX) break;
11158         if (state_s[i].valid)
11159           search_state->single_rd_order[dir][mode][count++] =
11160               state_s[i].ref_frame;
11161       }
11162       if (count < max_candidates) {
11163         for (i = 0; i < state_cnt_m; ++i) {
11164           if (state_m[i].rd == INT64_MAX) break;
11165           if (state_m[i].valid) {
11166             int ref_frame = state_m[i].ref_frame;
11167             int match = 0;
11168             // Skip if this ref_frame is already in the list
11169             for (j = 0; j < count; ++j) {
11170               if (search_state->single_rd_order[dir][mode][j] == ref_frame) {
11171                 match = 1;
11172                 break;
11173               }
11174             }
11175             if (!match) {
11176               // Check whether this ref_frame was invalidated in the simple-rd list
11177               int valid = 1;
11178               for (j = 0; j < state_cnt_s; j++) {
11179                 if (ref_frame == state_s[j].ref_frame && !state_s[j].valid) {
11180                   valid = 0;
11181                   break;
11182                 }
11183               }
11184               if (valid)
11185                 search_state->single_rd_order[dir][mode][count++] = ref_frame;
11186             }
11187             if (count >= max_candidates) break;
11188           }
11189         }
11190       }
11191     }
11192   }
11193 }
11194 
11195 static int compound_skip_get_candidates(
11196     const AV1_COMP *cpi, const InterModeSearchState *search_state,
11197     const int dir, const PREDICTION_MODE mode) {
11198   const int mode_offset = INTER_OFFSET(mode);
11199   const SingleInterModeState *state =
11200       search_state->single_state[dir][mode_offset];
11201   const SingleInterModeState *state_modelled =
11202       search_state->single_state_modelled[dir][mode_offset];
11203   int max_candidates = 0;
11204   int candidates;
11205 
11206   for (int i = 0; i < FWD_REFS; ++i) {
11207     if (search_state->single_rd_order[dir][mode_offset][i] == NONE_FRAME) break;
11208     max_candidates++;
11209   }
11210 
11211   candidates = max_candidates;
11212   if (cpi->sf.prune_comp_search_by_single_result >= 2) {
11213     candidates = AOMMIN(2, max_candidates);
11214   }
11215   if (cpi->sf.prune_comp_search_by_single_result >= 3) {
11216     if (state[0].rd != INT64_MAX && state_modelled[0].rd != INT64_MAX &&
11217         state[0].ref_frame == state_modelled[0].ref_frame)
11218       candidates = 1;
11219     if (mode == NEARMV || mode == GLOBALMV) candidates = 1;
11220   }
11221   return candidates;
11222 }
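// Hedged summary of the thresholds above: at
// prune_comp_search_by_single_result == 1, every reference listed in
// single_rd_order is a candidate; level 2 caps the list at 2; level 3 drops
// it to 1 when the simple-rd and modelled-rd rankings agree on the best
// reference frame, and always to 1 for NEARMV and GLOBALMV.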
11223 
11224 static int compound_skip_by_single_states(
11225     const AV1_COMP *cpi, const InterModeSearchState *search_state,
11226     const PREDICTION_MODE this_mode, const MV_REFERENCE_FRAME ref_frame,
11227     const MV_REFERENCE_FRAME second_ref_frame, const MACROBLOCK *x) {
11228   const MV_REFERENCE_FRAME refs[2] = { ref_frame, second_ref_frame };
11229   const int mode[2] = { compound_ref0_mode(this_mode),
11230                         compound_ref1_mode(this_mode) };
11231   const int mode_offset[2] = { INTER_OFFSET(mode[0]), INTER_OFFSET(mode[1]) };
11232   const int mode_dir[2] = { refs[0] <= GOLDEN_FRAME ? 0 : 1,
11233                             refs[1] <= GOLDEN_FRAME ? 0 : 1 };
11234   int ref_searched[2] = { 0, 0 };
11235   int ref_mv_match[2] = { 1, 1 };
11236   int i, j;
11237 
11238   for (i = 0; i < 2; ++i) {
11239     const SingleInterModeState *state =
11240         search_state->single_state[mode_dir[i]][mode_offset[i]];
11241     const int state_cnt =
11242         search_state->single_state_cnt[mode_dir[i]][mode_offset[i]];
11243     for (j = 0; j < state_cnt; ++j) {
11244       if (state[j].ref_frame == refs[i]) {
11245         ref_searched[i] = 1;
11246         break;
11247       }
11248     }
11249   }
11250 
11251   const int ref_set = get_drl_refmv_count(x, refs, this_mode);
11252   for (i = 0; i < 2; ++i) {
11253     if (mode[i] == NEARESTMV || mode[i] == NEARMV) {
11254       const MV_REFERENCE_FRAME single_refs[2] = { refs[i], NONE_FRAME };
11255       int identical = 1;
11256       for (int ref_mv_idx = 0; ref_mv_idx < ref_set; ref_mv_idx++) {
11257         int_mv single_mv;
11258         int_mv comp_mv;
11259         get_this_mv(&single_mv, mode[i], 0, ref_mv_idx, single_refs,
11260                     x->mbmi_ext);
11261         get_this_mv(&comp_mv, this_mode, i, ref_mv_idx, refs, x->mbmi_ext);
11262 
11263         identical &= (single_mv.as_int == comp_mv.as_int);
11264         if (!identical) {
11265           ref_mv_match[i] = 0;
11266           break;
11267         }
11268       }
11269     }
11270   }
11271 
11272   for (i = 0; i < 2; ++i) {
11273     if (ref_searched[i] && ref_mv_match[i]) {
11274       const int candidates =
11275           compound_skip_get_candidates(cpi, search_state, mode_dir[i], mode[i]);
11276       const MV_REFERENCE_FRAME *ref_order =
11277           search_state->single_rd_order[mode_dir[i]][mode_offset[i]];
11278       int match = 0;
11279       for (j = 0; j < candidates; ++j) {
11280         if (refs[i] == ref_order[j]) {
11281           match = 1;
11282           break;
11283         }
11284       }
11285       if (!match) return 1;
11286     }
11287   }
11288 
11289   return 0;
11290 }
11291 
11292 static INLINE int sf_check_is_drop_ref(const MODE_DEFINITION *mode,
11293                                        InterModeSearchState *search_state) {
11294   const MV_REFERENCE_FRAME ref_frame = mode->ref_frame[0];
11295   const MV_REFERENCE_FRAME second_ref_frame = mode->ref_frame[1];
11296   if (search_state->num_available_refs > 2) {
11297     if ((ref_frame == search_state->dist_order_refs[0] &&
11298          second_ref_frame == search_state->dist_order_refs[1]) ||
11299         (ref_frame == search_state->dist_order_refs[1] &&
11300          second_ref_frame == search_state->dist_order_refs[0]))
11301       return 1;  // drop this pair of refs
11302   }
11303   return 0;
11304 }
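// Example of the rule above (hypothetical state): if the two largest
// single-ref distortions were recorded for ALTREF_FRAME and GOLDEN_FRAME,
// so dist_order_refs = { ALTREF_FRAME, GOLDEN_FRAME, ... }, then the
// compound pairs (ALTREF, GOLDEN) and (GOLDEN, ALTREF) are both dropped
// whenever more than two references are available.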
11305 
11306 static INLINE void sf_drop_ref_analyze(InterModeSearchState *search_state,
11307                                        const MODE_DEFINITION *mode,
11308                                        int64_t distortion2) {
11309   const PREDICTION_MODE this_mode = mode->mode;
11310   MV_REFERENCE_FRAME ref_frame = mode->ref_frame[0];
11311   const int idx = ref_frame - LAST_FRAME;
11312   if (idx && distortion2 > search_state->dist_refs[idx]) {
11313     search_state->dist_refs[idx] = distortion2;
11314     search_state->dist_order_refs[idx] = ref_frame;
11315   }
11316 
11317   // Reached the last single-ref prediction mode
11318   if (ref_frame == ALTREF_FRAME && this_mode == GLOBALMV) {
11319     // Bubble-sort dist_refs in descending order, along with its order index
11320     for (int i = 0; i < REF_FRAMES; ++i) {
11321       for (int k = i + 1; k < REF_FRAMES; ++k) {
11322         if (search_state->dist_refs[i] < search_state->dist_refs[k]) {
11323           int64_t tmp_dist = search_state->dist_refs[i];
11324           search_state->dist_refs[i] = search_state->dist_refs[k];
11325           search_state->dist_refs[k] = tmp_dist;
11326 
11327           int tmp_idx = search_state->dist_order_refs[i];
11328           search_state->dist_order_refs[i] = search_state->dist_order_refs[k];
11329           search_state->dist_order_refs[k] = tmp_idx;
11330         }
11331       }
11332     }
11333     for (int i = 0; i < REF_FRAMES; ++i) {
11334       if (search_state->dist_refs[i] == -1) break;
11335       search_state->num_available_refs = i;
11336     }
11337     search_state->num_available_refs++;
11338   }
11339 }
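// After the bubble sort above, dist_refs is in descending order and
// dist_order_refs mirrors it, so entries 0 and 1 name the two references
// with the largest single-ref distortion -- exactly the pair rejected later
// by sf_check_is_drop_ref(). num_available_refs ends up as the number of
// leading entries that were actually populated (dist != -1).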
11340 
11341 static void alloc_compound_type_rd_buffers(AV1_COMMON *const cm,
11342                                            CompoundTypeRdBuffers *const bufs) {
11343   CHECK_MEM_ERROR(
11344       cm, bufs->pred0,
11345       (uint8_t *)aom_memalign(16, 2 * MAX_SB_SQUARE * sizeof(*bufs->pred0)));
11346   CHECK_MEM_ERROR(
11347       cm, bufs->pred1,
11348       (uint8_t *)aom_memalign(16, 2 * MAX_SB_SQUARE * sizeof(*bufs->pred1)));
11349   CHECK_MEM_ERROR(
11350       cm, bufs->residual1,
11351       (int16_t *)aom_memalign(32, MAX_SB_SQUARE * sizeof(*bufs->residual1)));
11352   CHECK_MEM_ERROR(
11353       cm, bufs->diff10,
11354       (int16_t *)aom_memalign(32, MAX_SB_SQUARE * sizeof(*bufs->diff10)));
11355   CHECK_MEM_ERROR(cm, bufs->tmp_best_mask_buf,
11356                   (uint8_t *)aom_malloc(2 * MAX_SB_SQUARE *
11357                                         sizeof(*bufs->tmp_best_mask_buf)));
11358 }
11359 
11360 static void release_compound_type_rd_buffers(
11361     CompoundTypeRdBuffers *const bufs) {
11362   aom_free(bufs->pred0);
11363   aom_free(bufs->pred1);
11364   aom_free(bufs->residual1);
11365   aom_free(bufs->diff10);
11366   aom_free(bufs->tmp_best_mask_buf);
11367   av1_zero(*bufs);  // Set all pointers to NULL for safety.
11368 }
11369 
11370 void av1_rd_pick_inter_mode_sb(AV1_COMP *cpi, TileDataEnc *tile_data,
11371                                MACROBLOCK *x, int mi_row, int mi_col,
11372                                RD_STATS *rd_cost, BLOCK_SIZE bsize,
11373                                PICK_MODE_CONTEXT *ctx, int64_t best_rd_so_far) {
11374   AV1_COMMON *const cm = &cpi->common;
11375   const int num_planes = av1_num_planes(cm);
11376   const SPEED_FEATURES *const sf = &cpi->sf;
11377   MACROBLOCKD *const xd = &x->e_mbd;
11378   MB_MODE_INFO *const mbmi = xd->mi[0];
11379   const int try_palette =
11380       av1_allow_palette(cm->allow_screen_content_tools, mbmi->sb_type);
11381   PALETTE_MODE_INFO *const pmi = &mbmi->palette_mode_info;
11382   const struct segmentation *const seg = &cm->seg;
11383   PREDICTION_MODE this_mode;
11384   unsigned char segment_id = mbmi->segment_id;
11385   int i;
11386   struct buf_2d yv12_mb[REF_FRAMES][MAX_MB_PLANE];
11387   unsigned int ref_costs_single[REF_FRAMES];
11388   unsigned int ref_costs_comp[REF_FRAMES][REF_FRAMES];
11389   int *comp_inter_cost = x->comp_inter_cost[av1_get_reference_mode_context(xd)];
11390   int *mode_map = tile_data->mode_map[bsize];
11391   uint32_t mode_skip_mask[REF_FRAMES];
11392   uint16_t ref_frame_skip_mask[2];
11393 
11394   InterModeSearchState search_state;
11395   init_inter_mode_search_state(&search_state, cpi, tile_data, x, bsize,
11396                                best_rd_so_far);
11397   INTERINTRA_MODE interintra_modes[REF_FRAMES] = {
11398     INTERINTRA_MODES, INTERINTRA_MODES, INTERINTRA_MODES, INTERINTRA_MODES,
11399     INTERINTRA_MODES, INTERINTRA_MODES, INTERINTRA_MODES, INTERINTRA_MODES
11400   };
11401   HandleInterModeArgs args = {
11402     { NULL },  { MAX_SB_SIZE, MAX_SB_SIZE, MAX_SB_SIZE },
11403     { NULL },  { MAX_SB_SIZE >> 1, MAX_SB_SIZE >> 1, MAX_SB_SIZE >> 1 },
11404     NULL,      NULL,
11405     NULL,      search_state.modelled_rd,
11406     { { 0 } }, INT_MAX,
11407     INT_MAX,   search_state.simple_rd,
11408     0,         interintra_modes
11409   };
11410   for (i = 0; i < REF_FRAMES; ++i) x->pred_sse[i] = INT_MAX;
11411 
11412   av1_invalid_rd_stats(rd_cost);
11413 
11414   // init params, set frame modes, speed features
11415   set_params_rd_pick_inter_mode(
11416       cpi, x, &args, bsize, mi_row, mi_col, ref_frame_skip_mask, mode_skip_mask,
11417       ctx->skip_ref_frame_mask, ref_costs_single, ref_costs_comp, yv12_mb);
11418 
11419 #if CONFIG_COLLECT_INTER_MODE_RD_STATS
11420   int64_t best_est_rd = INT64_MAX;
11421   // TODO(angiebird): Turn this on when this speed feature is well tested
11422 #if 1
11423   const InterModeRdModel *md = &tile_data->inter_mode_rd_models[bsize];
11424   const int do_tx_search = !md->ready;
11425 #else
11426   const int do_tx_search = 1;
11427 #endif
11428   InterModesInfo *inter_modes_info = &tile_data->inter_modes_info;
11429   inter_modes_info->num = 0;
11430 #endif
11431 
11432   int intra_mode_num = 0;
11433   int intra_mode_idx_ls[MAX_MODES];
11434   int reach_first_comp_mode = 0;
11435 
11436   // Temporary buffers used by handle_inter_mode().
11437   // We allocate them once and reuse them in every call to that function.
11438   // Note: they must be allocated on the heap due to the large size of the arrays.
11439   uint8_t *tmp_buf_orig;
11440   CHECK_MEM_ERROR(
11441       cm, tmp_buf_orig,
11442       (uint8_t *)aom_memalign(32, 2 * MAX_MB_PLANE * MAX_SB_SQUARE));
11443   uint8_t *const tmp_buf = get_buf_by_bd(xd, tmp_buf_orig);
11444 
11445   CompoundTypeRdBuffers rd_buffers;
11446   alloc_compound_type_rd_buffers(cm, &rd_buffers);
11447 
11448   for (int midx = 0; midx < MAX_MODES; ++midx) {
11449     int mode_index = mode_map[midx];
11450     int64_t this_rd = INT64_MAX;
11451     int disable_skip = 0;
11452     int rate2 = 0, rate_y = 0, rate_uv = 0;
11453     int64_t distortion2 = 0;
11454     int skippable = 0;
11455     int this_skip2 = 0;
11456     const MODE_DEFINITION *mode_order = &av1_mode_order[mode_index];
11457     const MV_REFERENCE_FRAME ref_frame = mode_order->ref_frame[0];
11458     const MV_REFERENCE_FRAME second_ref_frame = mode_order->ref_frame[1];
11459     const int comp_pred = second_ref_frame > INTRA_FRAME;
11460     this_mode = mode_order->mode;
11461 
11462     init_mbmi(mbmi, mode_index, cm);
11463 
11464     x->skip = 0;
11465     set_ref_ptrs(cm, xd, ref_frame, second_ref_frame);
11466 
11467     // Reached the first compound prediction mode
11468     if (sf->prune_comp_search_by_single_result > 0 && comp_pred &&
11469         reach_first_comp_mode == 0) {
11470       analyze_single_states(cpi, &search_state);
11471       reach_first_comp_mode = 1;
11472     }
11473     const int ret = inter_mode_search_order_independent_skip(
11474         cpi, ctx, x, bsize, mode_index, mi_row, mi_col, mode_skip_mask,
11475         ref_frame_skip_mask, &search_state);
11476     if (ret == 1) continue;
11477     args.skip_motion_mode = (ret == 2);
11478 
11479     if (sf->drop_ref && comp_pred) {
11480       if (sf_check_is_drop_ref(mode_order, &search_state)) {
11481         continue;
11482       }
11483     }
11484 
11485     if (search_state.best_rd < search_state.mode_threshold[mode_index])
11486       continue;
11487 
11488     if (sf->prune_comp_search_by_single_result > 0 && comp_pred) {
11489       if (compound_skip_by_single_states(cpi, &search_state, this_mode,
11490                                          ref_frame, second_ref_frame, x))
11491         continue;
11492     }
11493 
11494     const int ref_frame_cost = comp_pred
11495                                    ? ref_costs_comp[ref_frame][second_ref_frame]
11496                                    : ref_costs_single[ref_frame];
11497     const int compmode_cost =
11498         is_comp_ref_allowed(mbmi->sb_type) ? comp_inter_cost[comp_pred] : 0;
11499     const int real_compmode_cost =
11500         cm->reference_mode == REFERENCE_MODE_SELECT ? compmode_cost : 0;
11501 
11502     if (comp_pred) {
11503       if ((sf->mode_search_skip_flags & FLAG_SKIP_COMP_BESTINTRA) &&
11504           search_state.best_mode_index >= 0 &&
11505           search_state.best_mbmode.ref_frame[0] == INTRA_FRAME)
11506         continue;
11507     }
11508 
11509     if (ref_frame == INTRA_FRAME) {
11510       if (sf->adaptive_mode_search)
11511         if ((x->source_variance << num_pels_log2_lookup[bsize]) >
11512             search_state.best_pred_sse)
11513           continue;
11514 
11515       if (this_mode != DC_PRED) {
11516         // Only search the oblique modes if the best so far is
11517         // one of the neighboring directional modes
11518         if ((sf->mode_search_skip_flags & FLAG_SKIP_INTRA_BESTINTER) &&
11519             (this_mode >= D45_PRED && this_mode <= PAETH_PRED)) {
11520           if (search_state.best_mode_index >= 0 &&
11521               search_state.best_mbmode.ref_frame[0] > INTRA_FRAME)
11522             continue;
11523         }
11524         if (sf->mode_search_skip_flags & FLAG_SKIP_INTRA_DIRMISMATCH) {
11525           if (conditional_skipintra(this_mode, search_state.best_intra_mode))
11526             continue;
11527         }
11528       }
11529     }
11530 
11531     // Select prediction reference frames.
11532     for (i = 0; i < num_planes; i++) {
11533       xd->plane[i].pre[0] = yv12_mb[ref_frame][i];
11534       if (comp_pred) xd->plane[i].pre[1] = yv12_mb[second_ref_frame][i];
11535     }
11536 
11537     if (ref_frame == INTRA_FRAME) {
11538       intra_mode_idx_ls[intra_mode_num++] = mode_index;
11539       continue;
11540     } else {
11541       mbmi->angle_delta[PLANE_TYPE_Y] = 0;
11542       mbmi->angle_delta[PLANE_TYPE_UV] = 0;
11543       mbmi->filter_intra_mode_info.use_filter_intra = 0;
11544       mbmi->ref_mv_idx = 0;
11545       int64_t ref_best_rd = search_state.best_rd;
11546       {
11547         RD_STATS rd_stats, rd_stats_y, rd_stats_uv;
11548         av1_init_rd_stats(&rd_stats);
11549         rd_stats.rate = rate2;
11550 
11551         // Point to variables that are maintained between loop iterations
11552         args.single_newmv = search_state.single_newmv;
11553         args.single_newmv_rate = search_state.single_newmv_rate;
11554         args.single_newmv_valid = search_state.single_newmv_valid;
11555         args.single_comp_cost = real_compmode_cost;
11556         args.ref_frame_cost = ref_frame_cost;
11557 #if CONFIG_COLLECT_INTER_MODE_RD_STATS
11558         this_rd = handle_inter_mode(
11559             cpi, x, bsize, &rd_stats, &rd_stats_y, &rd_stats_uv, &disable_skip,
11560             mi_row, mi_col, &args, ref_best_rd, tmp_buf, &rd_buffers, tile_data,
11561             &best_est_rd, do_tx_search, inter_modes_info);
11562 #else
11563         this_rd = handle_inter_mode(cpi, x, bsize, &rd_stats, &rd_stats_y,
11564                                     &rd_stats_uv, &disable_skip, mi_row, mi_col,
11565                                     &args, ref_best_rd, tmp_buf, &rd_buffers);
11566 #endif
11567         rate2 = rd_stats.rate;
11568         skippable = rd_stats.skip;
11569         distortion2 = rd_stats.dist;
11570         rate_y = rd_stats_y.rate;
11571         rate_uv = rd_stats_uv.rate;
11572       }
11573 
11574       if (sf->prune_comp_search_by_single_result > 0 &&
11575           is_inter_singleref_mode(this_mode)) {
11576         collect_single_states(x, &search_state, mbmi);
11577       }
11578 
11579       if (this_rd == INT64_MAX) continue;
11580 
11581       this_skip2 = mbmi->skip;
11582       this_rd = RDCOST(x->rdmult, rate2, distortion2);
11583       if (this_skip2) {
11584         rate_y = 0;
11585         rate_uv = 0;
11586       }
11587     }
11588 
11589     // Did this mode help, i.e. is it the new best mode?
11590     if (this_rd < search_state.best_rd || x->skip) {
11591       int mode_excluded = 0;
11592       if (comp_pred) {
11593         mode_excluded = cm->reference_mode == SINGLE_REFERENCE;
11594       }
11595       if (!mode_excluded) {
11596         // Note index of best mode so far
11597         search_state.best_mode_index = mode_index;
11598 
11599         if (ref_frame == INTRA_FRAME) {
11600           /* required for left and above block mv */
11601           mbmi->mv[0].as_int = 0;
11602         } else {
11603           search_state.best_pred_sse = x->pred_sse[ref_frame];
11604         }
11605 
11606         rd_cost->rate = rate2;
11607         rd_cost->dist = distortion2;
11608         rd_cost->rdcost = this_rd;
11609         search_state.best_rd = this_rd;
11610         search_state.best_mbmode = *mbmi;
11611         search_state.best_skip2 = this_skip2;
11612         search_state.best_mode_skippable = skippable;
11613 #if CONFIG_COLLECT_INTER_MODE_RD_STATS
11614         if (do_tx_search) {
11615           // When do_tx_search == 0, handle_inter_mode won't provide the
11616           // correct rate_y and rate_uv because the txfm_search process is
11617           // replaced by an rd estimate.
11618           // Therefore, we should avoid updating best_rate_y and best_rate_uv
11619           // here; these two values will be updated when txfm_search is called.
11620           search_state.best_rate_y =
11621               rate_y +
11622               x->skip_cost[av1_get_skip_context(xd)][this_skip2 || skippable];
11623           search_state.best_rate_uv = rate_uv;
11624         }
11625 #else   // CONFIG_COLLECT_INTER_MODE_RD_STATS
11626         search_state.best_rate_y =
11627             rate_y +
11628             x->skip_cost[av1_get_skip_context(xd)][this_skip2 || skippable];
11629         search_state.best_rate_uv = rate_uv;
11630 #endif  // CONFIG_COLLECT_INTER_MODE_RD_STATS
11631         memcpy(ctx->blk_skip, x->blk_skip,
11632                sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
11633       }
11634     }
11635 
11636     /* keep record of best compound/single-only prediction */
11637     if (!disable_skip && ref_frame != INTRA_FRAME) {
11638       int64_t single_rd, hybrid_rd, single_rate, hybrid_rate;
11639 
11640       if (cm->reference_mode == REFERENCE_MODE_SELECT) {
11641         single_rate = rate2 - compmode_cost;
11642         hybrid_rate = rate2;
11643       } else {
11644         single_rate = rate2;
11645         hybrid_rate = rate2 + compmode_cost;
11646       }
11647 
11648       single_rd = RDCOST(x->rdmult, single_rate, distortion2);
11649       hybrid_rd = RDCOST(x->rdmult, hybrid_rate, distortion2);
11650 
11651       if (!comp_pred) {
11652         if (single_rd < search_state.best_pred_rd[SINGLE_REFERENCE])
11653           search_state.best_pred_rd[SINGLE_REFERENCE] = single_rd;
11654       } else {
11655         if (single_rd < search_state.best_pred_rd[COMPOUND_REFERENCE])
11656           search_state.best_pred_rd[COMPOUND_REFERENCE] = single_rd;
11657       }
11658       if (hybrid_rd < search_state.best_pred_rd[REFERENCE_MODE_SELECT])
11659         search_state.best_pred_rd[REFERENCE_MODE_SELECT] = hybrid_rd;
11660     }
11661     if (sf->drop_ref && second_ref_frame == NONE_FRAME) {
11662       // Collect data from single-ref modes and analyze it.
11663       sf_drop_ref_analyze(&search_state, mode_order, distortion2);
11664     }
11665 
11666     if (x->skip && !comp_pred) break;
11667   }
11668 
11669   aom_free(tmp_buf_orig);
11670   tmp_buf_orig = NULL;
11671   release_compound_type_rd_buffers(&rd_buffers);
11672 
11673 #if CONFIG_COLLECT_INTER_MODE_RD_STATS
11674   if (!do_tx_search) {
11675     inter_modes_info_sort(inter_modes_info, inter_modes_info->rd_idx_pair_arr);
11676     search_state.best_rd = INT64_MAX;
11677 
11678     int64_t top_est_rd =
11679         inter_modes_info->est_rd_arr[inter_modes_info->rd_idx_pair_arr[0].idx];
11680     for (int j = 0; j < inter_modes_info->num; ++j) {
11681       const int data_idx = inter_modes_info->rd_idx_pair_arr[j].idx;
11682       *mbmi = inter_modes_info->mbmi_arr[data_idx];
11683       int64_t curr_est_rd = inter_modes_info->est_rd_arr[data_idx];
11684       if (curr_est_rd * 0.9 > top_est_rd) {
11685         continue;
11686       }
11687       const int mode_rate = inter_modes_info->mode_rate_arr[data_idx];
11688 
11689       x->skip = 0;
11690       set_ref_ptrs(cm, xd, mbmi->ref_frame[0], mbmi->ref_frame[1]);
11691 
11692       // Select prediction reference frames.
11693       const int is_comp_pred = mbmi->ref_frame[1] > INTRA_FRAME;
11694       for (i = 0; i < num_planes; i++) {
11695         xd->plane[i].pre[0] = yv12_mb[mbmi->ref_frame[0]][i];
11696         if (is_comp_pred) xd->plane[i].pre[1] = yv12_mb[mbmi->ref_frame[1]][i];
11697       }
11698 
11699       RD_STATS rd_stats;
11700       RD_STATS rd_stats_y;
11701       RD_STATS rd_stats_uv;
11702 
11703       av1_build_inter_predictors_sb(cm, xd, mi_row, mi_col, NULL, bsize);
11704       if (mbmi->motion_mode == OBMC_CAUSAL)
11705         av1_build_obmc_inter_predictors_sb(cm, xd, mi_row, mi_col);
11706 
11707       if (!txfm_search(cpi, x, bsize, mi_row, mi_col, &rd_stats, &rd_stats_y,
11708                        &rd_stats_uv, mode_rate, search_state.best_rd)) {
11709         continue;
11710       } else {
11711         const int skip_ctx = av1_get_skip_context(xd);
11712         inter_mode_data_push(tile_data, mbmi->sb_type, rd_stats.sse,
11713                              rd_stats.dist,
11714                              rd_stats_y.rate + rd_stats_uv.rate +
11715                                  x->skip_cost[skip_ctx][mbmi->skip]);
11716       }
11717       rd_stats.rdcost = RDCOST(x->rdmult, rd_stats.rate, rd_stats.dist);
11718 
11719       if (rd_stats.rdcost < search_state.best_rd) {
11720         search_state.best_rd = rd_stats.rdcost;
11721         // Note index of best mode so far
11722         const int mode_index = get_prediction_mode_idx(
11723             mbmi->mode, mbmi->ref_frame[0], mbmi->ref_frame[1]);
11724         search_state.best_mode_index = mode_index;
11725         *rd_cost = rd_stats;
11726         search_state.best_rd = rd_stats.rdcost;
11727         search_state.best_mbmode = *mbmi;
11728         search_state.best_skip2 = mbmi->skip;
11729         search_state.best_mode_skippable = rd_stats.skip;
11730         search_state.best_rate_y =
11731             rd_stats_y.rate +
11732             x->skip_cost[av1_get_skip_context(xd)][rd_stats.skip || mbmi->skip];
11733         search_state.best_rate_uv = rd_stats_uv.rate;
11734         memcpy(ctx->blk_skip, x->blk_skip,
11735                sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
11736       }
11737     }
11738   }
11739 #endif
11740 
11741   for (int j = 0; j < intra_mode_num; ++j) {
11742     const int mode_index = intra_mode_idx_ls[j];
11743     const MV_REFERENCE_FRAME ref_frame =
11744         av1_mode_order[mode_index].ref_frame[0];
11745     assert(av1_mode_order[mode_index].ref_frame[1] == NONE_FRAME);
11746     assert(ref_frame == INTRA_FRAME);
11747     if (sf->skip_intra_in_interframe && search_state.skip_intra_modes) break;
11748     init_mbmi(mbmi, mode_index, cm);
11749     x->skip = 0;
11750     set_ref_ptrs(cm, xd, INTRA_FRAME, NONE_FRAME);
11751 
11752     // Select prediction reference frames.
11753     for (i = 0; i < num_planes; i++) {
11754       xd->plane[i].pre[0] = yv12_mb[ref_frame][i];
11755     }
11756 
11757     RD_STATS intra_rd_stats, intra_rd_stats_y, intra_rd_stats_uv;
11758 
11759     const int ref_frame_cost = ref_costs_single[ref_frame];
11760     intra_rd_stats.rdcost = handle_intra_mode(
11761         &search_state, cpi, x, bsize, mi_row, mi_col, ref_frame_cost, ctx, 0,
11762         &intra_rd_stats, &intra_rd_stats_y, &intra_rd_stats_uv);
11763     if (intra_rd_stats.rdcost < search_state.best_rd) {
11764       search_state.best_rd = intra_rd_stats.rdcost;
11765       // Note index of best mode so far
11766       search_state.best_mode_index = mode_index;
11767       *rd_cost = intra_rd_stats;
11768       search_state.best_rd = intra_rd_stats.rdcost;
11769       search_state.best_mbmode = *mbmi;
11770       search_state.best_skip2 = 0;
11771       search_state.best_mode_skippable = intra_rd_stats.skip;
11772       search_state.best_rate_y =
11773           intra_rd_stats_y.rate +
11774           x->skip_cost[av1_get_skip_context(xd)][intra_rd_stats.skip];
11775       search_state.best_rate_uv = intra_rd_stats_uv.rate;
11776       memcpy(ctx->blk_skip, x->blk_skip,
11777              sizeof(x->blk_skip[0]) * ctx->num_4x4_blk);
11778     }
11779   }
11780 
11781   // In effect only when speed >= 2.
11782   sf_refine_fast_tx_type_search(
11783       cpi, x, mi_row, mi_col, rd_cost, bsize, ctx, search_state.best_mode_index,
11784       &search_state.best_mbmode, yv12_mb, search_state.best_rate_y,
11785       search_state.best_rate_uv, &search_state.best_skip2);
11786 
11787   // Only try palette mode when the best mode so far is an intra mode.
11788   if (try_palette && !is_inter_mode(search_state.best_mbmode.mode)) {
11789     search_palette_mode(cpi, x, mi_row, mi_col, rd_cost, ctx, bsize, mbmi, pmi,
11790                         ref_costs_single, &search_state);
11791   }
11792 
11793   search_state.best_mbmode.skip_mode = 0;
11794   if (cm->skip_mode_flag &&
11795       !segfeature_active(seg, segment_id, SEG_LVL_REF_FRAME) &&
11796       is_comp_ref_allowed(bsize)) {
11797     rd_pick_skip_mode(rd_cost, &search_state, cpi, x, bsize, mi_row, mi_col,
11798                       yv12_mb);
11799   }
11800 
11801   // Make sure that the ref_mv_idx is only nonzero when we're
11802   // using a mode which can support ref_mv_idx
11803   if (search_state.best_mbmode.ref_mv_idx != 0 &&
11804       !(search_state.best_mbmode.mode == NEWMV ||
11805         search_state.best_mbmode.mode == NEW_NEWMV ||
11806         have_nearmv_in_inter_mode(search_state.best_mbmode.mode))) {
11807     search_state.best_mbmode.ref_mv_idx = 0;
11808   }
11809 
11810   if (search_state.best_mode_index < 0 ||
11811       search_state.best_rd >= best_rd_so_far) {
11812     rd_cost->rate = INT_MAX;
11813     rd_cost->rdcost = INT64_MAX;
11814     return;
11815   }
11816 
11817   assert(
11818       (cm->interp_filter == SWITCHABLE) ||
11819       (cm->interp_filter ==
11820        av1_extract_interp_filter(search_state.best_mbmode.interp_filters, 0)) ||
11821       !is_inter_block(&search_state.best_mbmode));
11822   assert(
11823       (cm->interp_filter == SWITCHABLE) ||
11824       (cm->interp_filter ==
11825        av1_extract_interp_filter(search_state.best_mbmode.interp_filters, 1)) ||
11826       !is_inter_block(&search_state.best_mbmode));
11827 
11828   if (!cpi->rc.is_src_frame_alt_ref)
11829     av1_update_rd_thresh_fact(cm, tile_data->thresh_freq_fact,
11830                               sf->adaptive_rd_thresh, bsize,
11831                               search_state.best_mode_index);
11832 
11833   // macroblock modes
11834   *mbmi = search_state.best_mbmode;
11835   x->skip |= search_state.best_skip2;
11836 
11837   // Note: this section is needed since the mode may have been forced to
11838   // GLOBALMV by the all-zero mode handling of ref-mv.
11839   if (mbmi->mode == GLOBALMV || mbmi->mode == GLOBAL_GLOBALMV) {
11840     // Correct the interp filters for GLOBALMV
11841     if (is_nontrans_global_motion(xd, xd->mi[0])) {
11842       assert(mbmi->interp_filters ==
11843              av1_broadcast_interp_filter(
11844                  av1_unswitchable_filter(cm->interp_filter)));
11845     }
11846   }
11847 
11848   for (i = 0; i < REFERENCE_MODES; ++i) {
11849     if (search_state.best_pred_rd[i] == INT64_MAX)
11850       search_state.best_pred_diff[i] = INT_MIN;
11851     else
11852       search_state.best_pred_diff[i] =
11853           search_state.best_rd - search_state.best_pred_rd[i];
11854   }
11855 
11856   x->skip |= search_state.best_mode_skippable;
11857 
11858   assert(search_state.best_mode_index >= 0);
11859 
11860   store_coding_context(x, ctx, search_state.best_mode_index,
11861                        search_state.best_pred_diff,
11862                        search_state.best_mode_skippable);
11863 
11864   if (pmi->palette_size[1] > 0) {
11865     assert(try_palette);
11866     restore_uv_color_map(cpi, x);
11867   }
11868 }
11869 
11870 void av1_rd_pick_inter_mode_sb_seg_skip(const AV1_COMP *cpi,
11871                                         TileDataEnc *tile_data, MACROBLOCK *x,
11872                                         int mi_row, int mi_col,
11873                                         RD_STATS *rd_cost, BLOCK_SIZE bsize,
11874                                         PICK_MODE_CONTEXT *ctx,
11875                                         int64_t best_rd_so_far) {
11876   const AV1_COMMON *const cm = &cpi->common;
11877   MACROBLOCKD *const xd = &x->e_mbd;
11878   MB_MODE_INFO *const mbmi = xd->mi[0];
11879   unsigned char segment_id = mbmi->segment_id;
11880   const int comp_pred = 0;
11881   int i;
11882   int64_t best_pred_diff[REFERENCE_MODES];
11883   unsigned int ref_costs_single[REF_FRAMES];
11884   unsigned int ref_costs_comp[REF_FRAMES][REF_FRAMES];
11885   int *comp_inter_cost = x->comp_inter_cost[av1_get_reference_mode_context(xd)];
11886   InterpFilter best_filter = SWITCHABLE;
11887   int64_t this_rd = INT64_MAX;
11888   int rate2 = 0;
11889   const int64_t distortion2 = 0;
11890   (void)mi_row;
11891   (void)mi_col;
11892 
11893   av1_collect_neighbors_ref_counts(xd);
11894 
11895   estimate_ref_frame_costs(cm, xd, x, segment_id, ref_costs_single,
11896                            ref_costs_comp);
11897 
11898   for (i = 0; i < REF_FRAMES; ++i) x->pred_sse[i] = INT_MAX;
11899   for (i = LAST_FRAME; i < REF_FRAMES; ++i) x->pred_mv_sad[i] = INT_MAX;
11900 
11901   rd_cost->rate = INT_MAX;
11902 
11903   assert(segfeature_active(&cm->seg, segment_id, SEG_LVL_SKIP));
11904 
11905   mbmi->palette_mode_info.palette_size[0] = 0;
11906   mbmi->palette_mode_info.palette_size[1] = 0;
11907   mbmi->filter_intra_mode_info.use_filter_intra = 0;
11908   mbmi->mode = GLOBALMV;
11909   mbmi->motion_mode = SIMPLE_TRANSLATION;
11910   mbmi->uv_mode = UV_DC_PRED;
11911   if (segfeature_active(&cm->seg, segment_id, SEG_LVL_REF_FRAME))
11912     mbmi->ref_frame[0] = get_segdata(&cm->seg, segment_id, SEG_LVL_REF_FRAME);
11913   else
11914     mbmi->ref_frame[0] = LAST_FRAME;
11915   mbmi->ref_frame[1] = NONE_FRAME;
11916   mbmi->mv[0].as_int =
11917       gm_get_motion_vector(&cm->global_motion[mbmi->ref_frame[0]],
11918                            cm->allow_high_precision_mv, bsize, mi_col, mi_row,
11919                            cm->cur_frame_force_integer_mv)
11920           .as_int;
11921   mbmi->tx_size = max_txsize_lookup[bsize];
11922   x->skip = 1;
11923 
11924   mbmi->ref_mv_idx = 0;
11925 
11926   mbmi->motion_mode = SIMPLE_TRANSLATION;
11927   av1_count_overlappable_neighbors(cm, xd, mi_row, mi_col);
11928   if (is_motion_variation_allowed_bsize(bsize) && !has_second_ref(mbmi)) {
11929     int pts[SAMPLES_ARRAY_SIZE], pts_inref[SAMPLES_ARRAY_SIZE];
11930     mbmi->num_proj_ref = findSamples(cm, xd, mi_row, mi_col, pts, pts_inref);
11931     // Select the samples according to motion vector difference
11932     if (mbmi->num_proj_ref > 1)
11933       mbmi->num_proj_ref = selectSamples(&mbmi->mv[0].as_mv, pts, pts_inref,
11934                                          mbmi->num_proj_ref, bsize);
11935   }
11936 
11937   set_default_interp_filters(mbmi, cm->interp_filter);
11938 
11939   if (cm->interp_filter != SWITCHABLE) {
11940     best_filter = cm->interp_filter;
11941   } else {
11942     best_filter = EIGHTTAP_REGULAR;
11943     if (av1_is_interp_needed(xd) && av1_is_interp_search_needed(xd) &&
11944         x->source_variance >= cpi->sf.disable_filter_search_var_thresh) {
11945       int rs;
11946       int best_rs = INT_MAX;
11947       for (i = 0; i < SWITCHABLE_FILTERS; ++i) {
11948         mbmi->interp_filters = av1_broadcast_interp_filter(i);
11949         rs = av1_get_switchable_rate(cm, x, xd);
11950         if (rs < best_rs) {
11951           best_rs = rs;
11952           best_filter = av1_extract_interp_filter(mbmi->interp_filters, 0);
11953         }
11954       }
11955     }
11956   }
11957   // Set the appropriate filter
11958   mbmi->interp_filters = av1_broadcast_interp_filter(best_filter);
11959   rate2 += av1_get_switchable_rate(cm, x, xd);
11960 
11961   if (cm->reference_mode == REFERENCE_MODE_SELECT)
11962     rate2 += comp_inter_cost[comp_pred];
11963 
11964   // Estimate the reference frame signaling cost and add it
11965   // to the rolling cost variable.
11966   rate2 += ref_costs_single[LAST_FRAME];
11967   this_rd = RDCOST(x->rdmult, rate2, distortion2);
11968 
11969   rd_cost->rate = rate2;
11970   rd_cost->dist = distortion2;
11971   rd_cost->rdcost = this_rd;
11972 
11973   if (this_rd >= best_rd_so_far) {
11974     rd_cost->rate = INT_MAX;
11975     rd_cost->rdcost = INT64_MAX;
11976     return;
11977   }
11978 
11979   assert((cm->interp_filter == SWITCHABLE) ||
11980          (cm->interp_filter ==
11981           av1_extract_interp_filter(mbmi->interp_filters, 0)));
11982 
11983   av1_update_rd_thresh_fact(cm, tile_data->thresh_freq_fact,
11984                             cpi->sf.adaptive_rd_thresh, bsize, THR_GLOBALMV);
11985 
11986   av1_zero(best_pred_diff);
11987 
11988   store_coding_context(x, ctx, THR_GLOBALMV, best_pred_diff, 0);
11989 }
11990 
11991 struct calc_target_weighted_pred_ctxt {
11992   const MACROBLOCK *x;
11993   const uint8_t *tmp;
11994   int tmp_stride;
11995   int overlap;
11996 };
11997 
11998 static INLINE void calc_target_weighted_pred_above(
11999     MACROBLOCKD *xd, int rel_mi_col, uint8_t nb_mi_width, MB_MODE_INFO *nb_mi,
12000     void *fun_ctxt, const int num_planes) {
12001   (void)nb_mi;
12002   (void)num_planes;
12003 
12004   struct calc_target_weighted_pred_ctxt *ctxt =
12005       (struct calc_target_weighted_pred_ctxt *)fun_ctxt;
12006 
12007   const int bw = xd->n4_w << MI_SIZE_LOG2;
12008   const uint8_t *const mask1d = av1_get_obmc_mask(ctxt->overlap);
12009 
12010   int32_t *wsrc = ctxt->x->wsrc_buf + (rel_mi_col * MI_SIZE);
12011   int32_t *mask = ctxt->x->mask_buf + (rel_mi_col * MI_SIZE);
12012   const uint8_t *tmp = ctxt->tmp + rel_mi_col * MI_SIZE;
12013   const int is_hbd = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? 1 : 0;
12014 
12015   if (!is_hbd) {
12016     for (int row = 0; row < ctxt->overlap; ++row) {
12017       const uint8_t m0 = mask1d[row];
12018       const uint8_t m1 = AOM_BLEND_A64_MAX_ALPHA - m0;
12019       for (int col = 0; col < nb_mi_width * MI_SIZE; ++col) {
12020         wsrc[col] = m1 * tmp[col];
12021         mask[col] = m0;
12022       }
12023       wsrc += bw;
12024       mask += bw;
12025       tmp += ctxt->tmp_stride;
12026     }
12027   } else {
12028     const uint16_t *tmp16 = CONVERT_TO_SHORTPTR(tmp);
12029 
12030     for (int row = 0; row < ctxt->overlap; ++row) {
12031       const uint8_t m0 = mask1d[row];
12032       const uint8_t m1 = AOM_BLEND_A64_MAX_ALPHA - m0;
12033       for (int col = 0; col < nb_mi_width * MI_SIZE; ++col) {
12034         wsrc[col] = m1 * tmp16[col];
12035         mask[col] = m0;
12036       }
12037       wsrc += bw;
12038       mask += bw;
12039       tmp16 += ctxt->tmp_stride;
12040     }
12041   }
12042 }
12043 
12044 static INLINE void calc_target_weighted_pred_left(
12045     MACROBLOCKD *xd, int rel_mi_row, uint8_t nb_mi_height, MB_MODE_INFO *nb_mi,
12046     void *fun_ctxt, const int num_planes) {
12047   (void)nb_mi;
12048   (void)num_planes;
12049 
12050   struct calc_target_weighted_pred_ctxt *ctxt =
12051       (struct calc_target_weighted_pred_ctxt *)fun_ctxt;
12052 
12053   const int bw = xd->n4_w << MI_SIZE_LOG2;
12054   const uint8_t *const mask1d = av1_get_obmc_mask(ctxt->overlap);
12055 
12056   int32_t *wsrc = ctxt->x->wsrc_buf + (rel_mi_row * MI_SIZE * bw);
12057   int32_t *mask = ctxt->x->mask_buf + (rel_mi_row * MI_SIZE * bw);
12058   const uint8_t *tmp = ctxt->tmp + (rel_mi_row * MI_SIZE * ctxt->tmp_stride);
12059   const int is_hbd = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? 1 : 0;
12060 
12061   if (!is_hbd) {
12062     for (int row = 0; row < nb_mi_height * MI_SIZE; ++row) {
12063       for (int col = 0; col < ctxt->overlap; ++col) {
12064         const uint8_t m0 = mask1d[col];
12065         const uint8_t m1 = AOM_BLEND_A64_MAX_ALPHA - m0;
12066         wsrc[col] = (wsrc[col] >> AOM_BLEND_A64_ROUND_BITS) * m0 +
12067                     (tmp[col] << AOM_BLEND_A64_ROUND_BITS) * m1;
12068         mask[col] = (mask[col] >> AOM_BLEND_A64_ROUND_BITS) * m0;
12069       }
12070       wsrc += bw;
12071       mask += bw;
12072       tmp += ctxt->tmp_stride;
12073     }
12074   } else {
12075     const uint16_t *tmp16 = CONVERT_TO_SHORTPTR(tmp);
12076 
12077     for (int row = 0; row < nb_mi_height * MI_SIZE; ++row) {
12078       for (int col = 0; col < ctxt->overlap; ++col) {
12079         const uint8_t m0 = mask1d[col];
12080         const uint8_t m1 = AOM_BLEND_A64_MAX_ALPHA - m0;
12081         wsrc[col] = (wsrc[col] >> AOM_BLEND_A64_ROUND_BITS) * m0 +
12082                     (tmp16[col] << AOM_BLEND_A64_ROUND_BITS) * m1;
12083         mask[col] = (mask[col] >> AOM_BLEND_A64_ROUND_BITS) * m0;
12084       }
12085       wsrc += bw;
12086       mask += bw;
12087       tmp16 += ctxt->tmp_stride;
12088     }
12089   }
12090 }
12091 
// This function has a structure similar to av1_build_obmc_inter_prediction
//
// The OBMC predictor is computed as:
//
//  PObmc(x,y) =
//    AOM_BLEND_A64(Mh(x),
//                  AOM_BLEND_A64(Mv(y), P(x,y), PAbove(x,y)),
//                  PLeft(x, y))
//
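// where, per aom_dsp/blend.h, AOM_BLEND_A64(m, a, b) computes
// ROUND_POWER_OF_TWO(m * a + (AOM_BLEND_A64_MAX_ALPHA - m) * b,
//                    AOM_BLEND_A64_ROUND_BITS)
// with AOM_BLEND_A64_MAX_ALPHA == 64 and AOM_BLEND_A64_ROUND_BITS == 6.
//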
// Scaling up by AOM_BLEND_A64_MAX_ALPHA ** 2 and omitting the intermediate
// rounding, this can be written as:
//
//  AOM_BLEND_A64_MAX_ALPHA * AOM_BLEND_A64_MAX_ALPHA * PObmc(x,y) =
//    Mh(x) * Mv(y) * P(x,y) +
//      Mh(x) * Cv(y) * PAbove(x,y) +
//      AOM_BLEND_A64_MAX_ALPHA * Ch(x) * PLeft(x, y)
//
// Where:
//
//  Cv(y) = AOM_BLEND_A64_MAX_ALPHA - Mv(y)
//  Ch(x) = AOM_BLEND_A64_MAX_ALPHA - Mh(x)
//
// This function computes 'wsrc' and 'mask' as:
//
//  wsrc(x, y) =
//    AOM_BLEND_A64_MAX_ALPHA * AOM_BLEND_A64_MAX_ALPHA * src(x, y) -
//      Mh(x) * Cv(y) * PAbove(x,y) -
//      AOM_BLEND_A64_MAX_ALPHA * Ch(x) * PLeft(x, y)
//
//  mask(x, y) = Mh(x) * Mv(y)
//
// These can then be used to efficiently approximate the error for any
// predictor P in the context of the provided neighbouring predictors by
// computing:
//
//  error(x, y) =
//    (wsrc(x, y) - mask(x, y) * P(x, y)) / (AOM_BLEND_A64_MAX_ALPHA ** 2)
//
// (An illustrative sketch of this error computation, added editorially,
// follows calc_target_weighted_pred below.)
static void calc_target_weighted_pred(const AV1_COMMON *cm, const MACROBLOCK *x,
                                      const MACROBLOCKD *xd, int mi_row,
                                      int mi_col, const uint8_t *above,
                                      int above_stride, const uint8_t *left,
                                      int left_stride) {
  const BLOCK_SIZE bsize = xd->mi[0]->sb_type;
  const int bw = xd->n4_w << MI_SIZE_LOG2;
  const int bh = xd->n4_h << MI_SIZE_LOG2;
  int32_t *mask_buf = x->mask_buf;
  int32_t *wsrc_buf = x->wsrc_buf;

  const int is_hbd = (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) ? 1 : 0;
  const int src_scale = AOM_BLEND_A64_MAX_ALPHA * AOM_BLEND_A64_MAX_ALPHA;

  // plane 0 should not be subsampled
  assert(xd->plane[0].subsampling_x == 0);
  assert(xd->plane[0].subsampling_y == 0);

  av1_zero_array(wsrc_buf, bw * bh);
  for (int i = 0; i < bw * bh; ++i) mask_buf[i] = AOM_BLEND_A64_MAX_ALPHA;

  // handle above row
  if (xd->up_available) {
    const int overlap =
        AOMMIN(block_size_high[bsize], block_size_high[BLOCK_64X64]) >> 1;
    struct calc_target_weighted_pred_ctxt ctxt = { x, above, above_stride,
                                                   overlap };
    foreach_overlappable_nb_above(cm, (MACROBLOCKD *)xd, mi_col,
                                  max_neighbor_obmc[mi_size_wide_log2[bsize]],
                                  calc_target_weighted_pred_above, &ctxt);
  }

  for (int i = 0; i < bw * bh; ++i) {
    wsrc_buf[i] *= AOM_BLEND_A64_MAX_ALPHA;
    mask_buf[i] *= AOM_BLEND_A64_MAX_ALPHA;
  }

  // handle left column
  if (xd->left_available) {
    const int overlap =
        AOMMIN(block_size_wide[bsize], block_size_wide[BLOCK_64X64]) >> 1;
    struct calc_target_weighted_pred_ctxt ctxt = { x, left, left_stride,
                                                   overlap };
    foreach_overlappable_nb_left(cm, (MACROBLOCKD *)xd, mi_row,
                                 max_neighbor_obmc[mi_size_high_log2[bsize]],
                                 calc_target_weighted_pred_left, &ctxt);
  }

  if (!is_hbd) {
    const uint8_t *src = x->plane[0].src.buf;

    for (int row = 0; row < bh; ++row) {
      for (int col = 0; col < bw; ++col) {
        wsrc_buf[col] = src[col] * src_scale - wsrc_buf[col];
      }
      wsrc_buf += bw;
      src += x->plane[0].src.stride;
    }
  } else {
    const uint16_t *src = CONVERT_TO_SHORTPTR(x->plane[0].src.buf);

    for (int row = 0; row < bh; ++row) {
      for (int col = 0; col < bw; ++col) {
        wsrc_buf[col] = src[col] * src_scale - wsrc_buf[col];
      }
      wsrc_buf += bw;
      src += x->plane[0].src.stride;
    }
  }
}
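
// Illustrative sketch (editorial addition, not part of the upstream encoder):
// given the 'wsrc' and 'mask' buffers produced by calc_target_weighted_pred,
// the OBMC-context error of a candidate predictor follows the error(x, y)
// formula documented above. The function name and the squared-error metric
// are hypothetical; the encoder itself evaluates this through its optimized
// aom_obmc_sad* / aom_obmc_variance* kernels.
static INLINE int64_t obmc_error_sketch(const int32_t *wsrc,
                                        const int32_t *mask,
                                        const uint8_t *pred, int pred_stride,
                                        int bw, int bh) {
  int64_t sse = 0;
  for (int row = 0; row < bh; ++row) {
    for (int col = 0; col < bw; ++col) {
      // Residual scaled by AOM_BLEND_A64_MAX_ALPHA ** 2; dividing through is
      // unnecessary when only comparing candidates against each other.
      const int64_t diff = (int64_t)wsrc[col] - (int64_t)mask[col] * pred[col];
      sse += diff * diff;
    }
    // wsrc and mask are laid out with stride bw, the block width in pixels.
    wsrc += bw;
    mask += bw;
    pred += pred_stride;
  }
  return sse;
}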