1 /*
2  * Copyright (c) 2020, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include "av1/common/cfl.h"
13 #include "av1/common/reconintra.h"
14 #include "av1/encoder/block.h"
15 #include "av1/encoder/hybrid_fwd_txfm.h"
16 #include "av1/common/idct.h"
17 #include "av1/encoder/model_rd.h"
18 #include "av1/encoder/random.h"
19 #include "av1/encoder/rdopt_utils.h"
20 #include "av1/encoder/tx_prune_model_weights.h"
21 #include "av1/encoder/tx_search.h"
22 #include "av1/encoder/txb_rdopt.h"
23 
// State shared across the per-transform-block RD cost evaluation of one block.
// NOTE(review): the code that reads/writes this struct is outside this chunk;
// the field notes below are inferred from names and types — confirm against
// the callers.
struct rdcost_block_args {
  const AV1_COMP *cpi;
  MACROBLOCK *x;
  // Working copies of the above/left entropy contexts (MAX_MIB_SIZE entries,
  // presumably one per MI unit) that the search can update locally.
  ENTROPY_CONTEXT t_above[MAX_MIB_SIZE];
  ENTROPY_CONTEXT t_left[MAX_MIB_SIZE];
  // Accumulated rate/distortion statistics.
  RD_STATS rd_stats;
  // Running RD cost and the budget it is presumably compared against.
  int64_t current_rd;
  int64_t best_rd;
  // Early-termination flags (presumably set when best_rd is exceeded).
  int exit_early;
  int incomplete_exit;
  FAST_TX_SEARCH_MODE ftxs_mode;
  int skip_trellis;
};
37 
// Groups an RD cost with the transform type and TXB entropy context it
// corresponds to (presumably one candidate of a transform-type search; its
// users are outside this chunk).
typedef struct {
  int64_t rd;
  int txb_entropy_ctx;
  TX_TYPE tx_type;
} TxCandidateInfo;

// Static description of one node of a TX-size search tree (see the
// rd_record_tree_* tables). `children` holds indices of child nodes within the
// same table; an index <= 0 means "no child" (init_rd_record_tree() turns it
// into NULL). When `leaf` is non-zero, `children` is ignored and all child
// pointers are cleared.
typedef struct {
  int leaf;
  int8_t children[4];
} RD_RECORD_IDX_NODE;

// Runtime TX-size search tree node: child pointers are wired by
// init_rd_record_tree(), and rd_info_array is set by
// find_tx_size_rd_records() (NULL for rectangular transform sizes).
typedef struct tx_size_rd_info_node {
  TXB_RD_INFO *rd_info_array;  // Points to array of size TX_TYPES.
  struct tx_size_rd_info_node *children[4];
} TXB_RD_INFO_NODE;
53 
// Coefficient thresholds used by predict_skip_txfm(), indexed by
// [bd_idx][bsize], where bd_idx is 0 for 8-bit, 1 for 10-bit and 2 for any
// other bit depth.
// Stored values are origin_threshold * 128 / 100, i.e. a percentage of the
// quantizer step rescaled to Q7. This matches predict_skip_txfm(), which
// compares (abs(coef) << 7) against threshold * quantizer.
static const uint32_t skip_pred_threshold[3][BLOCK_SIZES_ALL] = {
  {
      64, 64, 64, 70, 60, 60, 68, 68, 68, 68, 68,
      68, 68, 68, 68, 68, 64, 64, 70, 70, 68, 68,
  },
  {
      88, 88, 88, 86, 87, 87, 68, 68, 68, 68, 68,
      68, 68, 68, 68, 68, 88, 88, 86, 86, 68, 68,
  },
  {
      90, 93, 93, 90, 93, 93, 74, 74, 74, 74, 74,
      74, 74, 74, 74, 74, 90, 90, 90, 90, 74, 74,
  },
};
69 
// Lookup table for predict_skip_txfm(): the transform size used to tile the
// block for its DCT-based skip check. Precomputed equivalent of
// int max_tx_size = max_txsize_rect_lookup[bsize];
// if (tx_size_high[max_tx_size] > 16 || tx_size_wide[max_tx_size] > 16)
//   max_tx_size = AOMMIN(max_txsize_lookup[bsize], TX_16X16);
// so every entry is at most 16 samples in each dimension.
static const TX_SIZE max_predict_sf_tx_size[BLOCK_SIZES_ALL] = {
  TX_4X4,   TX_4X8,   TX_8X4,   TX_8X8,   TX_8X16,  TX_16X8,
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_16X16,
  TX_16X16, TX_16X16, TX_16X16, TX_16X16, TX_4X16,  TX_16X4,
  TX_8X8,   TX_8X8,   TX_16X16, TX_16X16,
};
80 
// Look-up table, indexed by TX_SIZE, of the square root of the number of
// pixels in a transform block, rounded up to the nearest integer.
// NOTE: for sizes with a 64-sample dimension the value is computed as if
// that dimension were 32. Assuming the usual TX_SIZE order, TX_64X64 and
// TX_32X64/TX_64X32 map to 32, and TX_16X64/TX_64X16 map to 23, rather
// than 64, 46 and 32. Presumably this is because only the top-left 32x32
// region of 64-point transforms carries coefficients.
static const int sqrt_tx_pixels_2d[TX_SIZES_ALL] = { 4,  8,  16, 32, 32, 6,  6,
                                                     12, 12, 23, 23, 32, 32, 8,
                                                     8,  16, 16, 23, 23 };
86 
find_tx_size_rd_info(TXB_RD_RECORD * cur_record,const uint32_t hash)87 static int find_tx_size_rd_info(TXB_RD_RECORD *cur_record,
88                                 const uint32_t hash) {
89   // Linear search through the circular buffer to find matching hash.
90   for (int i = cur_record->index_start - 1; i >= 0; i--) {
91     if (cur_record->hash_vals[i] == hash) return i;
92   }
93   for (int i = cur_record->num - 1; i >= cur_record->index_start; i--) {
94     if (cur_record->hash_vals[i] == hash) return i;
95   }
96   int index;
97   // If not found - add new RD info into the buffer and return its index
98   if (cur_record->num < TX_SIZE_RD_RECORD_BUFFER_LEN) {
99     index = (cur_record->index_start + cur_record->num) %
100             TX_SIZE_RD_RECORD_BUFFER_LEN;
101     cur_record->num++;
102   } else {
103     index = cur_record->index_start;
104     cur_record->index_start =
105         (cur_record->index_start + 1) % TX_SIZE_RD_RECORD_BUFFER_LEN;
106   }
107 
108   cur_record->hash_vals[index] = hash;
109   av1_zero(cur_record->tx_rd_info[index]);
110   return index;
111 }
112 
// Static shapes of the TX-size search trees, one per block-size class
// (selected through rd_record_tree[]).
// Entry i describes node i of the TXB_RD_INFO_NODE array filled by
// find_tx_size_rd_records(). That function numbers nodes depth by depth
// (depth 0 = largest rectangular TX size of the block) and in raster order
// within a depth.
// `children` holds indices of child nodes in the same array. The slots
// appear to follow the raster order of the sub-blocks (see
// rd_record_tree_2_1). Indices <= 0 mean "no child".
// Nodes with `leaf` set get all-NULL children, e.g. 8x8 nodes, whose 4x4
// sub-blocks are never hashed.
// Nodes past the end of a table are not initialized by
// init_rd_record_tree(). Presumably they sit at the maximum search depth and
// are never split — confirm against select_tx_block().

// BLOCK_8X8: a single leaf node.
static const RD_RECORD_IDX_NODE rd_record_tree_8x8[] = {
  { 1, { 0 } },
};

// BLOCK_8X16: root plus two leaf halves (nodes 1-2).
static const RD_RECORD_IDX_NODE rd_record_tree_8x16[] = {
  { 0, { 1, 2, -1, -1 } },
  { 1, { 0, 0, 0, 0 } },
  { 1, { 0, 0, 0, 0 } },
};

// BLOCK_16X8: root plus two leaf halves (nodes 1-2).
static const RD_RECORD_IDX_NODE rd_record_tree_16x8[] = {
  { 0, { 1, 2, -1, -1 } },
  { 1, { 0 } },
  { 1, { 0 } },
};

// BLOCK_16X16: root plus four leaf quadrants (nodes 1-4).
static const RD_RECORD_IDX_NODE rd_record_tree_16x16[] = {
  { 0, { 1, 2, 3, 4 } }, { 1, { 0 } }, { 1, { 0 } }, { 1, { 0 } }, { 1, { 0 } },
};

// 1:2 blocks (BLOCK_16X32, BLOCK_32X64): root, top/bottom halves (1-2) and
// eight depth-2 squares (3-10).
static const RD_RECORD_IDX_NODE rd_record_tree_1_2[] = {
  { 0, { 1, 2, -1, -1 } },
  { 0, { 3, 4, 5, 6 } },
  { 0, { 7, 8, 9, 10 } },
};

// 2:1 blocks (BLOCK_32X16, BLOCK_64X32): root, left/right halves (1-2) and
// eight depth-2 squares (3-10, raster order across the whole block).
static const RD_RECORD_IDX_NODE rd_record_tree_2_1[] = {
  { 0, { 1, 2, -1, -1 } },
  { 0, { 3, 4, 7, 8 } },
  { 0, { 5, 6, 9, 10 } },
};

// Square blocks (BLOCK_32X32, BLOCK_64X64): root, four quadrants (1-4) and
// sixteen depth-2 squares (5-20).
static const RD_RECORD_IDX_NODE rd_record_tree_sqr[] = {
  { 0, { 1, 2, 3, 4 } },     { 0, { 5, 6, 9, 10 } },    { 0, { 7, 8, 11, 12 } },
  { 0, { 13, 14, 17, 18 } }, { 0, { 15, 16, 19, 20 } },
};

// BLOCK_64X128: two depth-0 nodes (0-1), 8 at depth 1 (2-9), 32 at depth 2
// (10-41).
static const RD_RECORD_IDX_NODE rd_record_tree_64x128[] = {
  { 0, { 2, 3, 4, 5 } },     { 0, { 6, 7, 8, 9 } },
  { 0, { 10, 11, 14, 15 } }, { 0, { 12, 13, 16, 17 } },
  { 0, { 18, 19, 22, 23 } }, { 0, { 20, 21, 24, 25 } },
  { 0, { 26, 27, 30, 31 } }, { 0, { 28, 29, 32, 33 } },
  { 0, { 34, 35, 38, 39 } }, { 0, { 36, 37, 40, 41 } },
};

// BLOCK_128X64: two depth-0 nodes (0-1), 8 at depth 1 (2-9), 32 at depth 2
// (10-41).
static const RD_RECORD_IDX_NODE rd_record_tree_128x64[] = {
  { 0, { 2, 3, 6, 7 } },     { 0, { 4, 5, 8, 9 } },
  { 0, { 10, 11, 18, 19 } }, { 0, { 12, 13, 20, 21 } },
  { 0, { 14, 15, 22, 23 } }, { 0, { 16, 17, 24, 25 } },
  { 0, { 26, 27, 34, 35 } }, { 0, { 28, 29, 36, 37 } },
  { 0, { 30, 31, 38, 39 } }, { 0, { 32, 33, 40, 41 } },
};

// BLOCK_128X128: four depth-0 nodes (0-3), 16 at depth 1 (4-19), 64 at
// depth 2 (20-83).
static const RD_RECORD_IDX_NODE rd_record_tree_128x128[] = {
  { 0, { 4, 5, 8, 9 } },     { 0, { 6, 7, 10, 11 } },
  { 0, { 12, 13, 16, 17 } }, { 0, { 14, 15, 18, 19 } },
  { 0, { 20, 21, 28, 29 } }, { 0, { 22, 23, 30, 31 } },
  { 0, { 24, 25, 32, 33 } }, { 0, { 26, 27, 34, 35 } },
  { 0, { 36, 37, 44, 45 } }, { 0, { 38, 39, 46, 47 } },
  { 0, { 40, 41, 48, 49 } }, { 0, { 42, 43, 50, 51 } },
  { 0, { 52, 53, 60, 61 } }, { 0, { 54, 55, 62, 63 } },
  { 0, { 56, 57, 64, 65 } }, { 0, { 58, 59, 66, 67 } },
  { 0, { 68, 69, 76, 77 } }, { 0, { 70, 71, 78, 79 } },
  { 0, { 72, 73, 80, 81 } }, { 0, { 74, 75, 82, 83 } },
};
178 
179 static const RD_RECORD_IDX_NODE rd_record_tree_1_4[] = {
180   { 0, { 1, -1, 2, -1 } },
181   { 0, { 3, 4, -1, -1 } },
182   { 0, { 5, 6, -1, -1 } },
183 };
184 
// Tree for 4:1 (wide) blocks: BLOCK_32X8 and BLOCK_64X16. Node 0 covers the
// whole block, nodes 1-2 its two halves and nodes 3-6 the four quarters, in
// the raster order used by find_tx_size_rd_records(). Each 2-way split
// stores its children in slots 0 and 1.
static const RD_RECORD_IDX_NODE rd_record_tree_4_1[] = {
  { 0, { 1, 2, -1, -1 } },
  { 0, { 3, 4, -1, -1 } },
  { 0, { 5, 6, -1, -1 } },
};
190 
// Static TX-size search tree shape for each block size, consumed by
// init_rd_record_tree(). The NULL entries are presumably the block sizes whose
// largest square TX size is below TX_8X8. For those,
// find_tx_size_rd_records() returns before touching the tree.
static const RD_RECORD_IDX_NODE *rd_record_tree[BLOCK_SIZES_ALL] = {
  NULL,                    // BLOCK_4X4
  NULL,                    // BLOCK_4X8
  NULL,                    // BLOCK_8X4
  rd_record_tree_8x8,      // BLOCK_8X8
  rd_record_tree_8x16,     // BLOCK_8X16
  rd_record_tree_16x8,     // BLOCK_16X8
  rd_record_tree_16x16,    // BLOCK_16X16
  rd_record_tree_1_2,      // BLOCK_16X32
  rd_record_tree_2_1,      // BLOCK_32X16
  rd_record_tree_sqr,      // BLOCK_32X32
  rd_record_tree_1_2,      // BLOCK_32X64
  rd_record_tree_2_1,      // BLOCK_64X32
  rd_record_tree_sqr,      // BLOCK_64X64
  rd_record_tree_64x128,   // BLOCK_64X128
  rd_record_tree_128x64,   // BLOCK_128X64
  rd_record_tree_128x128,  // BLOCK_128X128
  NULL,                    // BLOCK_4X16
  NULL,                    // BLOCK_16X4
  rd_record_tree_1_4,      // BLOCK_8X32
  rd_record_tree_4_1,      // BLOCK_32X8
  rd_record_tree_1_4,      // BLOCK_16X64
  rd_record_tree_4_1,      // BLOCK_64X16
};

// Number of entries in the rd_record_tree[] table of each block size (0 where
// that table is NULL). init_rd_record_tree() initializes only this many nodes.
static const int rd_record_tree_size[BLOCK_SIZES_ALL] = {
  0,                                                            // BLOCK_4X4
  0,                                                            // BLOCK_4X8
  0,                                                            // BLOCK_8X4
  sizeof(rd_record_tree_8x8) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_8X8
  sizeof(rd_record_tree_8x16) / sizeof(RD_RECORD_IDX_NODE),     // BLOCK_8X16
  sizeof(rd_record_tree_16x8) / sizeof(RD_RECORD_IDX_NODE),     // BLOCK_16X8
  sizeof(rd_record_tree_16x16) / sizeof(RD_RECORD_IDX_NODE),    // BLOCK_16X16
  sizeof(rd_record_tree_1_2) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_16X32
  sizeof(rd_record_tree_2_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X16
  sizeof(rd_record_tree_sqr) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X32
  sizeof(rd_record_tree_1_2) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X64
  sizeof(rd_record_tree_2_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_64X32
  sizeof(rd_record_tree_sqr) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_64X64
  sizeof(rd_record_tree_64x128) / sizeof(RD_RECORD_IDX_NODE),   // BLOCK_64X128
  sizeof(rd_record_tree_128x64) / sizeof(RD_RECORD_IDX_NODE),   // BLOCK_128X64
  sizeof(rd_record_tree_128x128) / sizeof(RD_RECORD_IDX_NODE),  // BLOCK_128X128
  0,                                                            // BLOCK_4X16
  0,                                                            // BLOCK_16X4
  sizeof(rd_record_tree_1_4) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_8X32
  sizeof(rd_record_tree_4_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_32X8
  sizeof(rd_record_tree_1_4) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_16X64
  sizeof(rd_record_tree_4_1) / sizeof(RD_RECORD_IDX_NODE),      // BLOCK_64X16
};
240 
init_rd_record_tree(TXB_RD_INFO_NODE * tree,BLOCK_SIZE bsize)241 static INLINE void init_rd_record_tree(TXB_RD_INFO_NODE *tree,
242                                        BLOCK_SIZE bsize) {
243   const RD_RECORD_IDX_NODE *rd_record = rd_record_tree[bsize];
244   const int size = rd_record_tree_size[bsize];
245   for (int i = 0; i < size; ++i) {
246     if (rd_record[i].leaf) {
247       av1_zero(tree[i].children);
248     } else {
249       for (int j = 0; j < 4; ++j) {
250         const int8_t idx = rd_record[i].children[j];
251         tree[i].children[j] = idx > 0 ? &tree[idx] : NULL;
252       }
253     }
254   }
255 }
256 
257 // Go through all TX blocks that could be used in TX size search, compute
258 // residual hash values for them and find matching RD info that stores previous
259 // RD search results for these TX blocks. The idea is to prevent repeated
260 // rate/distortion computations that happen because of the combination of
261 // partition and TX size search. The resulting RD info records are returned in
262 // the form of a quadtree for easier access in actual TX size search.
find_tx_size_rd_records(MACROBLOCK * x,BLOCK_SIZE bsize,TXB_RD_INFO_NODE * dst_rd_info)263 static int find_tx_size_rd_records(MACROBLOCK *x, BLOCK_SIZE bsize,
264                                    TXB_RD_INFO_NODE *dst_rd_info) {
265   TxfmSearchInfo *txfm_info = &x->txfm_search_info;
266   TXB_RD_RECORD *rd_records_table[4] = {
267     txfm_info->txb_rd_records->txb_rd_record_8X8,
268     txfm_info->txb_rd_records->txb_rd_record_16X16,
269     txfm_info->txb_rd_records->txb_rd_record_32X32,
270     txfm_info->txb_rd_records->txb_rd_record_64X64
271   };
272   const TX_SIZE max_square_tx_size = max_txsize_lookup[bsize];
273   const int bw = block_size_wide[bsize];
274   const int bh = block_size_high[bsize];
275 
276   // Hashing is performed only for square TX sizes larger than TX_4X4
277   if (max_square_tx_size < TX_8X8) return 0;
278   const int diff_stride = bw;
279   const struct macroblock_plane *const p = &x->plane[0];
280   const int16_t *diff = &p->src_diff[0];
281   init_rd_record_tree(dst_rd_info, bsize);
282   // Coordinates of the top-left corner of current block within the superblock
283   // measured in pixels:
284   const int mi_row = x->e_mbd.mi_row;
285   const int mi_col = x->e_mbd.mi_col;
286   const int mi_row_in_sb = (mi_row % MAX_MIB_SIZE) << MI_SIZE_LOG2;
287   const int mi_col_in_sb = (mi_col % MAX_MIB_SIZE) << MI_SIZE_LOG2;
288   int cur_rd_info_idx = 0;
289   int cur_tx_depth = 0;
290   TX_SIZE cur_tx_size = max_txsize_rect_lookup[bsize];
291   while (cur_tx_depth <= MAX_VARTX_DEPTH) {
292     const int cur_tx_bw = tx_size_wide[cur_tx_size];
293     const int cur_tx_bh = tx_size_high[cur_tx_size];
294     if (cur_tx_bw < 8 || cur_tx_bh < 8) break;
295     const TX_SIZE next_tx_size = sub_tx_size_map[cur_tx_size];
296     const int tx_size_idx = cur_tx_size - TX_8X8;
297     for (int row = 0; row < bh; row += cur_tx_bh) {
298       for (int col = 0; col < bw; col += cur_tx_bw) {
299         if (cur_tx_bw != cur_tx_bh) {
300           // Use dummy nodes for all rectangular transforms within the
301           // TX size search tree.
302           dst_rd_info[cur_rd_info_idx].rd_info_array = NULL;
303         } else {
304           // Get spatial location of this TX block within the superblock
305           // (measured in cur_tx_bsize units).
306           const int row_in_sb = (mi_row_in_sb + row) / cur_tx_bh;
307           const int col_in_sb = (mi_col_in_sb + col) / cur_tx_bw;
308 
309           int16_t hash_data[MAX_SB_SQUARE];
310           int16_t *cur_hash_row = hash_data;
311           const int16_t *cur_diff_row = diff + row * diff_stride + col;
312           for (int i = 0; i < cur_tx_bh; i++) {
313             memcpy(cur_hash_row, cur_diff_row, sizeof(*hash_data) * cur_tx_bw);
314             cur_hash_row += cur_tx_bw;
315             cur_diff_row += diff_stride;
316           }
317           const int hash = av1_get_crc32c_value(
318               &txfm_info->txb_rd_records->mb_rd_record.crc_calculator,
319               (uint8_t *)hash_data, 2 * cur_tx_bw * cur_tx_bh);
320           // Find corresponding RD info based on the hash value.
321           const int record_idx =
322               row_in_sb * (MAX_MIB_SIZE >> (tx_size_idx + 1)) + col_in_sb;
323           TXB_RD_RECORD *records = &rd_records_table[tx_size_idx][record_idx];
324           int idx = find_tx_size_rd_info(records, hash);
325           dst_rd_info[cur_rd_info_idx].rd_info_array =
326               &records->tx_rd_info[idx];
327         }
328         ++cur_rd_info_idx;
329       }
330     }
331     cur_tx_size = next_tx_size;
332     ++cur_tx_depth;
333   }
334   return 1;
335 }
336 
get_block_residue_hash(MACROBLOCK * x,BLOCK_SIZE bsize)337 static INLINE uint32_t get_block_residue_hash(MACROBLOCK *x, BLOCK_SIZE bsize) {
338   const int rows = block_size_high[bsize];
339   const int cols = block_size_wide[bsize];
340   const int16_t *diff = x->plane[0].src_diff;
341   const uint32_t hash = av1_get_crc32c_value(
342       &x->txfm_search_info.txb_rd_records->mb_rd_record.crc_calculator,
343       (uint8_t *)diff, 2 * rows * cols);
344   return (hash << 5) + bsize;
345 }
346 
find_mb_rd_info(const MB_RD_RECORD * const mb_rd_record,const int64_t ref_best_rd,const uint32_t hash)347 static INLINE int32_t find_mb_rd_info(const MB_RD_RECORD *const mb_rd_record,
348                                       const int64_t ref_best_rd,
349                                       const uint32_t hash) {
350   int32_t match_index = -1;
351   if (ref_best_rd != INT64_MAX) {
352     for (int i = 0; i < mb_rd_record->num; ++i) {
353       const int index = (mb_rd_record->index_start + i) % RD_RECORD_BUFFER_LEN;
354       // If there is a match in the tx_rd_record, fetch the RD decision and
355       // terminate early.
356       if (mb_rd_record->tx_rd_info[index].hash_value == hash) {
357         match_index = index;
358         break;
359       }
360     }
361   }
362   return match_index;
363 }
364 
fetch_tx_rd_info(int n4,const MB_RD_INFO * const tx_rd_info,RD_STATS * const rd_stats,MACROBLOCK * const x)365 static AOM_INLINE void fetch_tx_rd_info(int n4,
366                                         const MB_RD_INFO *const tx_rd_info,
367                                         RD_STATS *const rd_stats,
368                                         MACROBLOCK *const x) {
369   MACROBLOCKD *const xd = &x->e_mbd;
370   MB_MODE_INFO *const mbmi = xd->mi[0];
371   mbmi->tx_size = tx_rd_info->tx_size;
372   memcpy(x->txfm_search_info.blk_skip, tx_rd_info->blk_skip,
373          sizeof(tx_rd_info->blk_skip[0]) * n4);
374   av1_copy(mbmi->inter_tx_size, tx_rd_info->inter_tx_size);
375   av1_copy_array(xd->tx_type_map, tx_rd_info->tx_type_map, n4);
376   *rd_stats = tx_rd_info->rd_stats;
377 }
378 
379 // Compute the pixel domain distortion from diff on all visible 4x4s in the
380 // transform block.
pixel_diff_dist(const MACROBLOCK * x,int plane,int blk_row,int blk_col,const BLOCK_SIZE plane_bsize,const BLOCK_SIZE tx_bsize,unsigned int * block_mse_q8)381 static INLINE int64_t pixel_diff_dist(const MACROBLOCK *x, int plane,
382                                       int blk_row, int blk_col,
383                                       const BLOCK_SIZE plane_bsize,
384                                       const BLOCK_SIZE tx_bsize,
385                                       unsigned int *block_mse_q8) {
386   int visible_rows, visible_cols;
387   const MACROBLOCKD *xd = &x->e_mbd;
388   get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
389                      NULL, &visible_cols, &visible_rows);
390   const int diff_stride = block_size_wide[plane_bsize];
391   const int16_t *diff = x->plane[plane].src_diff;
392 
393   diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
394   uint64_t sse =
395       aom_sum_squares_2d_i16(diff, diff_stride, visible_cols, visible_rows);
396   if (block_mse_q8 != NULL) {
397     if (visible_cols > 0 && visible_rows > 0)
398       *block_mse_q8 =
399           (unsigned int)((256 * sse) / (visible_cols * visible_rows));
400     else
401       *block_mse_q8 = UINT_MAX;
402   }
403   return sse;
404 }
405 
406 // Computes the residual block's SSE and mean on all visible 4x4s in the
407 // transform block
pixel_diff_stats(MACROBLOCK * x,int plane,int blk_row,int blk_col,const BLOCK_SIZE plane_bsize,const BLOCK_SIZE tx_bsize,unsigned int * block_mse_q8,int64_t * per_px_mean,uint64_t * block_var)408 static INLINE int64_t pixel_diff_stats(
409     MACROBLOCK *x, int plane, int blk_row, int blk_col,
410     const BLOCK_SIZE plane_bsize, const BLOCK_SIZE tx_bsize,
411     unsigned int *block_mse_q8, int64_t *per_px_mean, uint64_t *block_var) {
412   int visible_rows, visible_cols;
413   const MACROBLOCKD *xd = &x->e_mbd;
414   get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize, NULL,
415                      NULL, &visible_cols, &visible_rows);
416   const int diff_stride = block_size_wide[plane_bsize];
417   const int16_t *diff = x->plane[plane].src_diff;
418 
419   diff += ((blk_row * diff_stride + blk_col) << MI_SIZE_LOG2);
420   uint64_t sse = 0;
421   int sum = 0;
422   sse = aom_sum_sse_2d_i16(diff, diff_stride, visible_cols, visible_rows, &sum);
423   if (visible_cols > 0 && visible_rows > 0) {
424     aom_clear_system_state();
425     double norm_factor = 1.0 / (visible_cols * visible_rows);
426     int sign_sum = sum > 0 ? 1 : -1;
427     // Conversion to transform domain
428     *per_px_mean = (int64_t)(norm_factor * abs(sum)) << 7;
429     *per_px_mean = sign_sum * (*per_px_mean);
430     *block_mse_q8 = (unsigned int)(norm_factor * (256 * sse));
431     *block_var = (uint64_t)(sse - (uint64_t)(norm_factor * sum * sum));
432   } else {
433     *block_mse_q8 = UINT_MAX;
434   }
435   return sse;
436 }
437 
438 // Uses simple features on top of DCT coefficients to quickly predict
439 // whether optimal RD decision is to skip encoding the residual.
440 // The sse value is stored in dist.
predict_skip_txfm(MACROBLOCK * x,BLOCK_SIZE bsize,int64_t * dist,int reduced_tx_set)441 static int predict_skip_txfm(MACROBLOCK *x, BLOCK_SIZE bsize, int64_t *dist,
442                              int reduced_tx_set) {
443   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
444   const int bw = block_size_wide[bsize];
445   const int bh = block_size_high[bsize];
446   const MACROBLOCKD *xd = &x->e_mbd;
447   const int16_t dc_q = av1_dc_quant_QTX(x->qindex, 0, xd->bd);
448 
449   *dist = pixel_diff_dist(x, 0, 0, 0, bsize, bsize, NULL);
450 
451   const int64_t mse = *dist / bw / bh;
452   // Normalized quantizer takes the transform upscaling factor (8 for tx size
453   // smaller than 32) into account.
454   const int16_t normalized_dc_q = dc_q >> 3;
455   const int64_t mse_thresh = (int64_t)normalized_dc_q * normalized_dc_q / 8;
456   // For faster early skip decision, use dist to compare against threshold so
457   // that quality risk is less for the skip=1 decision. Otherwise, use mse
458   // since the fwd_txfm coeff checks will take care of quality
459   // TODO(any): Use dist to return 0 when skip_txfm_level is 1
460   int64_t pred_err = (txfm_params->skip_txfm_level >= 2) ? *dist : mse;
461   // Predict not to skip when error is larger than threshold.
462   if (pred_err > mse_thresh) return 0;
463   // Return as skip otherwise for aggressive early skip
464   else if (txfm_params->skip_txfm_level >= 2)
465     return 1;
466 
467   const int max_tx_size = max_predict_sf_tx_size[bsize];
468   const int tx_h = tx_size_high[max_tx_size];
469   const int tx_w = tx_size_wide[max_tx_size];
470   DECLARE_ALIGNED(32, tran_low_t, coefs[32 * 32]);
471   TxfmParam param;
472   param.tx_type = DCT_DCT;
473   param.tx_size = max_tx_size;
474   param.bd = xd->bd;
475   param.is_hbd = is_cur_buf_hbd(xd);
476   param.lossless = 0;
477   param.tx_set_type = av1_get_ext_tx_set_type(
478       param.tx_size, is_inter_block(xd->mi[0]), reduced_tx_set);
479   const int bd_idx = (xd->bd == 8) ? 0 : ((xd->bd == 10) ? 1 : 2);
480   const uint32_t max_qcoef_thresh = skip_pred_threshold[bd_idx][bsize];
481   const int16_t *src_diff = x->plane[0].src_diff;
482   const int n_coeff = tx_w * tx_h;
483   const int16_t ac_q = av1_ac_quant_QTX(x->qindex, 0, xd->bd);
484   const uint32_t dc_thresh = max_qcoef_thresh * dc_q;
485   const uint32_t ac_thresh = max_qcoef_thresh * ac_q;
486   for (int row = 0; row < bh; row += tx_h) {
487     for (int col = 0; col < bw; col += tx_w) {
488       av1_fwd_txfm(src_diff + col, coefs, bw, &param);
489       // Operating on TX domain, not pixels; we want the QTX quantizers
490       const uint32_t dc_coef = (((uint32_t)abs(coefs[0])) << 7);
491       if (dc_coef >= dc_thresh) return 0;
492       for (int i = 1; i < n_coeff; ++i) {
493         const uint32_t ac_coef = (((uint32_t)abs(coefs[i])) << 7);
494         if (ac_coef >= ac_thresh) return 0;
495       }
496     }
497     src_diff += tx_h * bw;
498   }
499   return 1;
500 }
501 
502 // Used to set proper context for early termination with skip = 1.
set_skip_txfm(MACROBLOCK * x,RD_STATS * rd_stats,int bsize,int64_t dist)503 static AOM_INLINE void set_skip_txfm(MACROBLOCK *x, RD_STATS *rd_stats,
504                                      int bsize, int64_t dist) {
505   MACROBLOCKD *const xd = &x->e_mbd;
506   MB_MODE_INFO *const mbmi = xd->mi[0];
507   const int n4 = bsize_to_num_blk(bsize);
508   const TX_SIZE tx_size = max_txsize_rect_lookup[bsize];
509   memset(xd->tx_type_map, DCT_DCT, sizeof(xd->tx_type_map[0]) * n4);
510   memset(mbmi->inter_tx_size, tx_size, sizeof(mbmi->inter_tx_size));
511   mbmi->tx_size = tx_size;
512   for (int i = 0; i < n4; ++i)
513     set_blk_skip(x->txfm_search_info.blk_skip, 0, i, 1);
514   rd_stats->skip_txfm = 1;
515   if (is_cur_buf_hbd(xd)) dist = ROUND_POWER_OF_TWO(dist, (xd->bd - 8) * 2);
516   rd_stats->dist = rd_stats->sse = (dist << 4);
517   // Though decision is to make the block as skip based on luma stats,
518   // it is possible that block becomes non skip after chroma rd. In addition
519   // intermediate non skip costs calculated by caller function will be
520   // incorrect, if rate is set as  zero (i.e., if zero_blk_rate is not
521   // accounted). Hence intermediate rate is populated to code the luma tx blks
522   // as skip, the caller function based on final rd decision (i.e., skip vs
523   // non-skip) sets the final rate accordingly. Here the rate populated
524   // corresponds to coding all the tx blocks with zero_blk_rate (based on max tx
525   // size possible) in the current block. Eg: For 128*128 block, rate would be
526   // 4 * zero_blk_rate where zero_blk_rate corresponds to coding of one 64x64 tx
527   // block as 'all zeros'
528   ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
529   ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
530   av1_get_entropy_contexts(bsize, &xd->plane[0], ctxa, ctxl);
531   ENTROPY_CONTEXT *ta = ctxa;
532   ENTROPY_CONTEXT *tl = ctxl;
533   const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
534   TXB_CTX txb_ctx;
535   get_txb_ctx(bsize, tx_size, 0, ta, tl, &txb_ctx);
536   const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][PLANE_TYPE_Y]
537                                 .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
538   rd_stats->rate = zero_blk_rate *
539                    (block_size_wide[bsize] >> tx_size_wide_log2[tx_size]) *
540                    (block_size_high[bsize] >> tx_size_high_log2[tx_size]);
541 }
542 
save_tx_rd_info(int n4,uint32_t hash,const MACROBLOCK * const x,const RD_STATS * const rd_stats,MB_RD_RECORD * tx_rd_record)543 static AOM_INLINE void save_tx_rd_info(int n4, uint32_t hash,
544                                        const MACROBLOCK *const x,
545                                        const RD_STATS *const rd_stats,
546                                        MB_RD_RECORD *tx_rd_record) {
547   int index;
548   if (tx_rd_record->num < RD_RECORD_BUFFER_LEN) {
549     index =
550         (tx_rd_record->index_start + tx_rd_record->num) % RD_RECORD_BUFFER_LEN;
551     ++tx_rd_record->num;
552   } else {
553     index = tx_rd_record->index_start;
554     tx_rd_record->index_start =
555         (tx_rd_record->index_start + 1) % RD_RECORD_BUFFER_LEN;
556   }
557   MB_RD_INFO *const tx_rd_info = &tx_rd_record->tx_rd_info[index];
558   const MACROBLOCKD *const xd = &x->e_mbd;
559   const MB_MODE_INFO *const mbmi = xd->mi[0];
560   tx_rd_info->hash_value = hash;
561   tx_rd_info->tx_size = mbmi->tx_size;
562   memcpy(tx_rd_info->blk_skip, x->txfm_search_info.blk_skip,
563          sizeof(tx_rd_info->blk_skip[0]) * n4);
564   av1_copy(tx_rd_info->inter_tx_size, mbmi->inter_tx_size);
565   av1_copy_array(tx_rd_info->tx_type_map, xd->tx_type_map, n4);
566   tx_rd_info->rd_stats = *rd_stats;
567 }
568 
get_search_init_depth(int mi_width,int mi_height,int is_inter,const SPEED_FEATURES * sf,int tx_size_search_method)569 static int get_search_init_depth(int mi_width, int mi_height, int is_inter,
570                                  const SPEED_FEATURES *sf,
571                                  int tx_size_search_method) {
572   if (tx_size_search_method == USE_LARGESTALL) return MAX_VARTX_DEPTH;
573 
574   if (sf->tx_sf.tx_size_search_lgr_block) {
575     if (mi_width > mi_size_wide[BLOCK_64X64] ||
576         mi_height > mi_size_high[BLOCK_64X64])
577       return MAX_VARTX_DEPTH;
578   }
579 
580   if (is_inter) {
581     return (mi_height != mi_width)
582                ? sf->tx_sf.inter_tx_size_search_init_depth_rect
583                : sf->tx_sf.inter_tx_size_search_init_depth_sqr;
584   } else {
585     return (mi_height != mi_width)
586                ? sf->tx_sf.intra_tx_size_search_init_depth_rect
587                : sf->tx_sf.intra_tx_size_search_init_depth_sqr;
588   }
589 }
590 
// Forward declaration: select_tx_block() is defined later in this file (not
// shown in this chunk). It is declared here so it can be referenced before its
// definition, presumably for mutual recursion with the TX split search.
// rd_info_node is a node of the tree built by find_tx_size_rd_records(), or
// NULL.
static AOM_INLINE void select_tx_block(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
    ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
    RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
    int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode,
    TXB_RD_INFO_NODE *rd_info_node);
598 
599 // NOTE: CONFIG_COLLECT_RD_STATS has 3 possible values
600 // 0: Do not collect any RD stats
601 // 1: Collect RD stats for transform units
602 // 2: Collect RD stats for partition units
603 #if CONFIG_COLLECT_RD_STATS
604 
get_energy_distribution_fine(const AV1_COMP * cpi,BLOCK_SIZE bsize,const uint8_t * src,int src_stride,const uint8_t * dst,int dst_stride,int need_4th,double * hordist,double * verdist)605 static AOM_INLINE void get_energy_distribution_fine(
606     const AV1_COMP *cpi, BLOCK_SIZE bsize, const uint8_t *src, int src_stride,
607     const uint8_t *dst, int dst_stride, int need_4th, double *hordist,
608     double *verdist) {
609   const int bw = block_size_wide[bsize];
610   const int bh = block_size_high[bsize];
611   unsigned int esq[16] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 };
612 
613   if (bsize < BLOCK_16X16 || (bsize >= BLOCK_4X16 && bsize <= BLOCK_32X8)) {
614     // Special cases: calculate 'esq' values manually, as we don't have 'vf'
615     // functions for the 16 (very small) sub-blocks of this block.
616     const int w_shift = (bw == 4) ? 0 : (bw == 8) ? 1 : (bw == 16) ? 2 : 3;
617     const int h_shift = (bh == 4) ? 0 : (bh == 8) ? 1 : (bh == 16) ? 2 : 3;
618     assert(bw <= 32);
619     assert(bh <= 32);
620     assert(((bw - 1) >> w_shift) + (((bh - 1) >> h_shift) << 2) == 15);
621     if (cpi->common.seq_params.use_highbitdepth) {
622       const uint16_t *src16 = CONVERT_TO_SHORTPTR(src);
623       const uint16_t *dst16 = CONVERT_TO_SHORTPTR(dst);
624       for (int i = 0; i < bh; ++i)
625         for (int j = 0; j < bw; ++j) {
626           const int index = (j >> w_shift) + ((i >> h_shift) << 2);
627           esq[index] +=
628               (src16[j + i * src_stride] - dst16[j + i * dst_stride]) *
629               (src16[j + i * src_stride] - dst16[j + i * dst_stride]);
630         }
631     } else {
632       for (int i = 0; i < bh; ++i)
633         for (int j = 0; j < bw; ++j) {
634           const int index = (j >> w_shift) + ((i >> h_shift) << 2);
635           esq[index] += (src[j + i * src_stride] - dst[j + i * dst_stride]) *
636                         (src[j + i * src_stride] - dst[j + i * dst_stride]);
637         }
638     }
639   } else {  // Calculate 'esq' values using 'vf' functions on the 16 sub-blocks.
640     const int f_index =
641         (bsize < BLOCK_SIZES) ? bsize - BLOCK_16X16 : bsize - BLOCK_8X16;
642     assert(f_index >= 0 && f_index < BLOCK_SIZES_ALL);
643     const BLOCK_SIZE subsize = (BLOCK_SIZE)f_index;
644     assert(block_size_wide[bsize] == 4 * block_size_wide[subsize]);
645     assert(block_size_high[bsize] == 4 * block_size_high[subsize]);
646     cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[0]);
647     cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
648                             &esq[1]);
649     cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
650                             &esq[2]);
651     cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
652                             dst_stride, &esq[3]);
653     src += bh / 4 * src_stride;
654     dst += bh / 4 * dst_stride;
655 
656     cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[4]);
657     cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
658                             &esq[5]);
659     cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
660                             &esq[6]);
661     cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
662                             dst_stride, &esq[7]);
663     src += bh / 4 * src_stride;
664     dst += bh / 4 * dst_stride;
665 
666     cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[8]);
667     cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
668                             &esq[9]);
669     cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
670                             &esq[10]);
671     cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
672                             dst_stride, &esq[11]);
673     src += bh / 4 * src_stride;
674     dst += bh / 4 * dst_stride;
675 
676     cpi->fn_ptr[subsize].vf(src, src_stride, dst, dst_stride, &esq[12]);
677     cpi->fn_ptr[subsize].vf(src + bw / 4, src_stride, dst + bw / 4, dst_stride,
678                             &esq[13]);
679     cpi->fn_ptr[subsize].vf(src + bw / 2, src_stride, dst + bw / 2, dst_stride,
680                             &esq[14]);
681     cpi->fn_ptr[subsize].vf(src + 3 * bw / 4, src_stride, dst + 3 * bw / 4,
682                             dst_stride, &esq[15]);
683   }
684 
685   double total = (double)esq[0] + esq[1] + esq[2] + esq[3] + esq[4] + esq[5] +
686                  esq[6] + esq[7] + esq[8] + esq[9] + esq[10] + esq[11] +
687                  esq[12] + esq[13] + esq[14] + esq[15];
688   if (total > 0) {
689     const double e_recip = 1.0 / total;
690     hordist[0] = ((double)esq[0] + esq[4] + esq[8] + esq[12]) * e_recip;
691     hordist[1] = ((double)esq[1] + esq[5] + esq[9] + esq[13]) * e_recip;
692     hordist[2] = ((double)esq[2] + esq[6] + esq[10] + esq[14]) * e_recip;
693     if (need_4th) {
694       hordist[3] = ((double)esq[3] + esq[7] + esq[11] + esq[15]) * e_recip;
695     }
696     verdist[0] = ((double)esq[0] + esq[1] + esq[2] + esq[3]) * e_recip;
697     verdist[1] = ((double)esq[4] + esq[5] + esq[6] + esq[7]) * e_recip;
698     verdist[2] = ((double)esq[8] + esq[9] + esq[10] + esq[11]) * e_recip;
699     if (need_4th) {
700       verdist[3] = ((double)esq[12] + esq[13] + esq[14] + esq[15]) * e_recip;
701     }
702   } else {
703     hordist[0] = verdist[0] = 0.25;
704     hordist[1] = verdist[1] = 0.25;
705     hordist[2] = verdist[2] = 0.25;
706     if (need_4th) {
707       hordist[3] = verdist[3] = 0.25;
708     }
709   }
710 }
711 
get_sse_norm(const int16_t * diff,int stride,int w,int h)712 static double get_sse_norm(const int16_t *diff, int stride, int w, int h) {
713   double sum = 0.0;
714   for (int j = 0; j < h; ++j) {
715     for (int i = 0; i < w; ++i) {
716       const int err = diff[j * stride + i];
717       sum += err * err;
718     }
719   }
720   assert(w > 0 && h > 0);
721   return sum / (w * h);
722 }
723 
get_sad_norm(const int16_t * diff,int stride,int w,int h)724 static double get_sad_norm(const int16_t *diff, int stride, int w, int h) {
725   double sum = 0.0;
726   for (int j = 0; j < h; ++j) {
727     for (int i = 0; i < w; ++i) {
728       sum += abs(diff[j * stride + i]);
729     }
730   }
731   assert(w > 0 && h > 0);
732   return sum / (w * h);
733 }
734 
get_2x2_normalized_sses_and_sads(const AV1_COMP * const cpi,BLOCK_SIZE tx_bsize,const uint8_t * const src,int src_stride,const uint8_t * const dst,int dst_stride,const int16_t * const src_diff,int diff_stride,double * const sse_norm_arr,double * const sad_norm_arr)735 static AOM_INLINE void get_2x2_normalized_sses_and_sads(
736     const AV1_COMP *const cpi, BLOCK_SIZE tx_bsize, const uint8_t *const src,
737     int src_stride, const uint8_t *const dst, int dst_stride,
738     const int16_t *const src_diff, int diff_stride, double *const sse_norm_arr,
739     double *const sad_norm_arr) {
740   const BLOCK_SIZE tx_bsize_half =
741       get_partition_subsize(tx_bsize, PARTITION_SPLIT);
742   if (tx_bsize_half == BLOCK_INVALID) {  // manually calculate stats
743     const int half_width = block_size_wide[tx_bsize] / 2;
744     const int half_height = block_size_high[tx_bsize] / 2;
745     for (int row = 0; row < 2; ++row) {
746       for (int col = 0; col < 2; ++col) {
747         const int16_t *const this_src_diff =
748             src_diff + row * half_height * diff_stride + col * half_width;
749         if (sse_norm_arr) {
750           sse_norm_arr[row * 2 + col] =
751               get_sse_norm(this_src_diff, diff_stride, half_width, half_height);
752         }
753         if (sad_norm_arr) {
754           sad_norm_arr[row * 2 + col] =
755               get_sad_norm(this_src_diff, diff_stride, half_width, half_height);
756         }
757       }
758     }
759   } else {  // use function pointers to calculate stats
760     const int half_width = block_size_wide[tx_bsize_half];
761     const int half_height = block_size_high[tx_bsize_half];
762     const int num_samples_half = half_width * half_height;
763     for (int row = 0; row < 2; ++row) {
764       for (int col = 0; col < 2; ++col) {
765         const uint8_t *const this_src =
766             src + row * half_height * src_stride + col * half_width;
767         const uint8_t *const this_dst =
768             dst + row * half_height * dst_stride + col * half_width;
769 
770         if (sse_norm_arr) {
771           unsigned int this_sse;
772           cpi->fn_ptr[tx_bsize_half].vf(this_src, src_stride, this_dst,
773                                         dst_stride, &this_sse);
774           sse_norm_arr[row * 2 + col] = (double)this_sse / num_samples_half;
775         }
776 
777         if (sad_norm_arr) {
778           const unsigned int this_sad = cpi->fn_ptr[tx_bsize_half].sdf(
779               this_src, src_stride, this_dst, dst_stride);
780           sad_norm_arr[row * 2 + col] = (double)this_sad / num_samples_half;
781         }
782       }
783     }
784   }
785 }
786 
787 #if CONFIG_COLLECT_RD_STATS == 1
get_mean(const int16_t * diff,int stride,int w,int h)788 static double get_mean(const int16_t *diff, int stride, int w, int h) {
789   double sum = 0.0;
790   for (int j = 0; j < h; ++j) {
791     for (int i = 0; i < w; ++i) {
792       sum += diff[j * stride + i];
793     }
794   }
795   assert(w > 0 && h > 0);
796   return sum / (w * h);
797 }
PrintTransformUnitStats(const AV1_COMP * const cpi,MACROBLOCK * x,const RD_STATS * const rd_stats,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,TX_TYPE tx_type,int64_t rd)798 static AOM_INLINE void PrintTransformUnitStats(
799     const AV1_COMP *const cpi, MACROBLOCK *x, const RD_STATS *const rd_stats,
800     int blk_row, int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
801     TX_TYPE tx_type, int64_t rd) {
802   if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;
803 
804   // Generate small sample to restrict output size.
805   static unsigned int seed = 21743;
806   if (lcg_rand16(&seed) % 256 > 0) return;
807 
808   const char output_file[] = "tu_stats.txt";
809   FILE *fout = fopen(output_file, "a");
810   if (!fout) return;
811 
812   const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
813   const MACROBLOCKD *const xd = &x->e_mbd;
814   const int plane = 0;
815   struct macroblock_plane *const p = &x->plane[plane];
816   const struct macroblockd_plane *const pd = &xd->plane[plane];
817   const int txw = tx_size_wide[tx_size];
818   const int txh = tx_size_high[tx_size];
819   const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
820   const int q_step = p->dequant_QTX[1] >> dequant_shift;
821   const int num_samples = txw * txh;
822 
823   const double rate_norm = (double)rd_stats->rate / num_samples;
824   const double dist_norm = (double)rd_stats->dist / num_samples;
825 
826   fprintf(fout, "%g %g", rate_norm, dist_norm);
827 
828   const int src_stride = p->src.stride;
829   const uint8_t *const src =
830       &p->src.buf[(blk_row * src_stride + blk_col) << MI_SIZE_LOG2];
831   const int dst_stride = pd->dst.stride;
832   const uint8_t *const dst =
833       &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
834   unsigned int sse;
835   cpi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
836   const double sse_norm = (double)sse / num_samples;
837 
838   const unsigned int sad =
839       cpi->fn_ptr[tx_bsize].sdf(src, src_stride, dst, dst_stride);
840   const double sad_norm = (double)sad / num_samples;
841 
842   fprintf(fout, " %g %g", sse_norm, sad_norm);
843 
844   const int diff_stride = block_size_wide[plane_bsize];
845   const int16_t *const src_diff =
846       &p->src_diff[(blk_row * diff_stride + blk_col) << MI_SIZE_LOG2];
847 
848   double sse_norm_arr[4], sad_norm_arr[4];
849   get_2x2_normalized_sses_and_sads(cpi, tx_bsize, src, src_stride, dst,
850                                    dst_stride, src_diff, diff_stride,
851                                    sse_norm_arr, sad_norm_arr);
852   for (int i = 0; i < 4; ++i) {
853     fprintf(fout, " %g", sse_norm_arr[i]);
854   }
855   for (int i = 0; i < 4; ++i) {
856     fprintf(fout, " %g", sad_norm_arr[i]);
857   }
858 
859   const TX_TYPE_1D tx_type_1d_row = htx_tab[tx_type];
860   const TX_TYPE_1D tx_type_1d_col = vtx_tab[tx_type];
861 
862   fprintf(fout, " %d %d %d %d %d", q_step, tx_size_wide[tx_size],
863           tx_size_high[tx_size], tx_type_1d_row, tx_type_1d_col);
864 
865   int model_rate;
866   int64_t model_dist;
867   model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, tx_bsize, plane, sse, num_samples,
868                                    &model_rate, &model_dist);
869   const double model_rate_norm = (double)model_rate / num_samples;
870   const double model_dist_norm = (double)model_dist / num_samples;
871   fprintf(fout, " %g %g", model_rate_norm, model_dist_norm);
872 
873   const double mean = get_mean(src_diff, diff_stride, txw, txh);
874   float hor_corr, vert_corr;
875   av1_get_horver_correlation_full(src_diff, diff_stride, txw, txh, &hor_corr,
876                                   &vert_corr);
877   fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);
878 
879   double hdist[4] = { 0 }, vdist[4] = { 0 };
880   get_energy_distribution_fine(cpi, tx_bsize, src, src_stride, dst, dst_stride,
881                                1, hdist, vdist);
882   fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
883           hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);
884 
885   fprintf(fout, " %d %" PRId64, x->rdmult, rd);
886 
887   fprintf(fout, "\n");
888   fclose(fout);
889 }
890 #endif  // CONFIG_COLLECT_RD_STATS == 1
891 
892 #if CONFIG_COLLECT_RD_STATS >= 2
get_sse(const AV1_COMP * cpi,const MACROBLOCK * x)893 static int64_t get_sse(const AV1_COMP *cpi, const MACROBLOCK *x) {
894   const AV1_COMMON *cm = &cpi->common;
895   const int num_planes = av1_num_planes(cm);
896   const MACROBLOCKD *xd = &x->e_mbd;
897   const MB_MODE_INFO *mbmi = xd->mi[0];
898   int64_t total_sse = 0;
899   for (int plane = 0; plane < num_planes; ++plane) {
900     const struct macroblock_plane *const p = &x->plane[plane];
901     const struct macroblockd_plane *const pd = &xd->plane[plane];
902     const BLOCK_SIZE bs =
903         get_plane_block_size(mbmi->bsize, pd->subsampling_x, pd->subsampling_y);
904     unsigned int sse;
905 
906     if (x->skip_chroma_rd && plane) continue;
907 
908     cpi->fn_ptr[bs].vf(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
909                        &sse);
910     total_sse += sse;
911   }
912   total_sse <<= 4;
913   return total_sse;
914 }
915 
get_est_rate_dist(const TileDataEnc * tile_data,BLOCK_SIZE bsize,int64_t sse,int * est_residue_cost,int64_t * est_dist)916 static int get_est_rate_dist(const TileDataEnc *tile_data, BLOCK_SIZE bsize,
917                              int64_t sse, int *est_residue_cost,
918                              int64_t *est_dist) {
919   aom_clear_system_state();
920   const InterModeRdModel *md = &tile_data->inter_mode_rd_models[bsize];
921   if (md->ready) {
922     if (sse < md->dist_mean) {
923       *est_residue_cost = 0;
924       *est_dist = sse;
925     } else {
926       *est_dist = (int64_t)round(md->dist_mean);
927       const double est_ld = md->a * sse + md->b;
928       // Clamp estimated rate cost by INT_MAX / 2.
929       // TODO(angiebird@google.com): find better solution than clamping.
930       if (fabs(est_ld) < 1e-2) {
931         *est_residue_cost = INT_MAX / 2;
932       } else {
933         double est_residue_cost_dbl = ((sse - md->dist_mean) / est_ld);
934         if (est_residue_cost_dbl < 0) {
935           *est_residue_cost = 0;
936         } else {
937           *est_residue_cost =
938               (int)AOMMIN((int64_t)round(est_residue_cost_dbl), INT_MAX / 2);
939         }
940       }
941       if (*est_residue_cost <= 0) {
942         *est_residue_cost = 0;
943         *est_dist = sse;
944       }
945     }
946     return 1;
947   }
948   return 0;
949 }
950 
get_highbd_diff_mean(const uint8_t * src8,int src_stride,const uint8_t * dst8,int dst_stride,int w,int h)951 static double get_highbd_diff_mean(const uint8_t *src8, int src_stride,
952                                    const uint8_t *dst8, int dst_stride, int w,
953                                    int h) {
954   const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
955   const uint16_t *dst = CONVERT_TO_SHORTPTR(dst8);
956   double sum = 0.0;
957   for (int j = 0; j < h; ++j) {
958     for (int i = 0; i < w; ++i) {
959       const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
960       sum += diff;
961     }
962   }
963   assert(w > 0 && h > 0);
964   return sum / (w * h);
965 }
966 
get_diff_mean(const uint8_t * src,int src_stride,const uint8_t * dst,int dst_stride,int w,int h)967 static double get_diff_mean(const uint8_t *src, int src_stride,
968                             const uint8_t *dst, int dst_stride, int w, int h) {
969   double sum = 0.0;
970   for (int j = 0; j < h; ++j) {
971     for (int i = 0; i < w; ++i) {
972       const int diff = src[j * src_stride + i] - dst[j * dst_stride + i];
973       sum += diff;
974     }
975   }
976   assert(w > 0 && h > 0);
977   return sum / (w * h);
978 }
979 
PrintPredictionUnitStats(const AV1_COMP * const cpi,const TileDataEnc * tile_data,MACROBLOCK * x,const RD_STATS * const rd_stats,BLOCK_SIZE plane_bsize)980 static AOM_INLINE void PrintPredictionUnitStats(const AV1_COMP *const cpi,
981                                                 const TileDataEnc *tile_data,
982                                                 MACROBLOCK *x,
983                                                 const RD_STATS *const rd_stats,
984                                                 BLOCK_SIZE plane_bsize) {
985   if (rd_stats->rate == INT_MAX || rd_stats->dist == INT64_MAX) return;
986 
987   if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1 &&
988       (tile_data == NULL ||
989        !tile_data->inter_mode_rd_models[plane_bsize].ready))
990     return;
991   (void)tile_data;
992   // Generate small sample to restrict output size.
993   static unsigned int seed = 95014;
994 
995   if ((lcg_rand16(&seed) % (1 << (14 - num_pels_log2_lookup[plane_bsize]))) !=
996       1)
997     return;
998 
999   const char output_file[] = "pu_stats.txt";
1000   FILE *fout = fopen(output_file, "a");
1001   if (!fout) return;
1002 
1003   MACROBLOCKD *const xd = &x->e_mbd;
1004   const int plane = 0;
1005   struct macroblock_plane *const p = &x->plane[plane];
1006   struct macroblockd_plane *pd = &xd->plane[plane];
1007   const int diff_stride = block_size_wide[plane_bsize];
1008   int bw, bh;
1009   get_txb_dimensions(xd, plane, plane_bsize, 0, 0, plane_bsize, NULL, NULL, &bw,
1010                      &bh);
1011   const int num_samples = bw * bh;
1012   const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
1013   const int q_step = p->dequant_QTX[1] >> dequant_shift;
1014   const int shift = (xd->bd - 8);
1015 
1016   const double rate_norm = (double)rd_stats->rate / num_samples;
1017   const double dist_norm = (double)rd_stats->dist / num_samples;
1018   const double rdcost_norm =
1019       (double)RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) / num_samples;
1020 
1021   fprintf(fout, "%g %g %g", rate_norm, dist_norm, rdcost_norm);
1022 
1023   const int src_stride = p->src.stride;
1024   const uint8_t *const src = p->src.buf;
1025   const int dst_stride = pd->dst.stride;
1026   const uint8_t *const dst = pd->dst.buf;
1027   const int16_t *const src_diff = p->src_diff;
1028 
1029   int64_t sse = calculate_sse(xd, p, pd, bw, bh);
1030   const double sse_norm = (double)sse / num_samples;
1031 
1032   const unsigned int sad =
1033       cpi->fn_ptr[plane_bsize].sdf(src, src_stride, dst, dst_stride);
1034   const double sad_norm =
1035       (double)sad / (1 << num_pels_log2_lookup[plane_bsize]);
1036 
1037   fprintf(fout, " %g %g", sse_norm, sad_norm);
1038 
1039   double sse_norm_arr[4], sad_norm_arr[4];
1040   get_2x2_normalized_sses_and_sads(cpi, plane_bsize, src, src_stride, dst,
1041                                    dst_stride, src_diff, diff_stride,
1042                                    sse_norm_arr, sad_norm_arr);
1043   if (shift) {
1044     for (int k = 0; k < 4; ++k) sse_norm_arr[k] /= (1 << (2 * shift));
1045     for (int k = 0; k < 4; ++k) sad_norm_arr[k] /= (1 << shift);
1046   }
1047   for (int i = 0; i < 4; ++i) {
1048     fprintf(fout, " %g", sse_norm_arr[i]);
1049   }
1050   for (int i = 0; i < 4; ++i) {
1051     fprintf(fout, " %g", sad_norm_arr[i]);
1052   }
1053 
1054   fprintf(fout, " %d %d %d %d", q_step, x->rdmult, bw, bh);
1055 
1056   int model_rate;
1057   int64_t model_dist;
1058   model_rd_sse_fn[MODELRD_CURVFIT](cpi, x, plane_bsize, plane, sse, num_samples,
1059                                    &model_rate, &model_dist);
1060   const double model_rdcost_norm =
1061       (double)RDCOST(x->rdmult, model_rate, model_dist) / num_samples;
1062   const double model_rate_norm = (double)model_rate / num_samples;
1063   const double model_dist_norm = (double)model_dist / num_samples;
1064   fprintf(fout, " %g %g %g", model_rate_norm, model_dist_norm,
1065           model_rdcost_norm);
1066 
1067   double mean;
1068   if (is_cur_buf_hbd(xd)) {
1069     mean = get_highbd_diff_mean(p->src.buf, p->src.stride, pd->dst.buf,
1070                                 pd->dst.stride, bw, bh);
1071   } else {
1072     mean = get_diff_mean(p->src.buf, p->src.stride, pd->dst.buf, pd->dst.stride,
1073                          bw, bh);
1074   }
1075   mean /= (1 << shift);
1076   float hor_corr, vert_corr;
1077   av1_get_horver_correlation_full(src_diff, diff_stride, bw, bh, &hor_corr,
1078                                   &vert_corr);
1079   fprintf(fout, " %g %g %g", mean, hor_corr, vert_corr);
1080 
1081   double hdist[4] = { 0 }, vdist[4] = { 0 };
1082   get_energy_distribution_fine(cpi, plane_bsize, src, src_stride, dst,
1083                                dst_stride, 1, hdist, vdist);
1084   fprintf(fout, " %g %g %g %g %g %g %g %g", hdist[0], hdist[1], hdist[2],
1085           hdist[3], vdist[0], vdist[1], vdist[2], vdist[3]);
1086 
1087   if (cpi->sf.inter_sf.inter_mode_rd_model_estimation == 1) {
1088     assert(tile_data->inter_mode_rd_models[plane_bsize].ready);
1089     const int64_t overall_sse = get_sse(cpi, x);
1090     int est_residue_cost = 0;
1091     int64_t est_dist = 0;
1092     get_est_rate_dist(tile_data, plane_bsize, overall_sse, &est_residue_cost,
1093                       &est_dist);
1094     const double est_residue_cost_norm = (double)est_residue_cost / num_samples;
1095     const double est_dist_norm = (double)est_dist / num_samples;
1096     const double est_rdcost_norm =
1097         (double)RDCOST(x->rdmult, est_residue_cost, est_dist) / num_samples;
1098     fprintf(fout, " %g %g %g", est_residue_cost_norm, est_dist_norm,
1099             est_rdcost_norm);
1100   }
1101 
1102   fprintf(fout, "\n");
1103   fclose(fout);
1104 }
1105 #endif  // CONFIG_COLLECT_RD_STATS >= 2
1106 #endif  // CONFIG_COLLECT_RD_STATS
1107 
inverse_transform_block_facade(MACROBLOCK * const x,int plane,int block,int blk_row,int blk_col,int eob,int reduced_tx_set)1108 static AOM_INLINE void inverse_transform_block_facade(MACROBLOCK *const x,
1109                                                       int plane, int block,
1110                                                       int blk_row, int blk_col,
1111                                                       int eob,
1112                                                       int reduced_tx_set) {
1113   if (!eob) return;
1114   struct macroblock_plane *const p = &x->plane[plane];
1115   MACROBLOCKD *const xd = &x->e_mbd;
1116   tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);
1117   const PLANE_TYPE plane_type = get_plane_type(plane);
1118   const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
1119   const TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col,
1120                                           tx_size, reduced_tx_set);
1121 
1122   struct macroblockd_plane *const pd = &xd->plane[plane];
1123   const int dst_stride = pd->dst.stride;
1124   uint8_t *dst = &pd->dst.buf[(blk_row * dst_stride + blk_col) << MI_SIZE_LOG2];
1125   av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, dst,
1126                               dst_stride, eob, reduced_tx_set);
1127 }
1128 
recon_intra(const AV1_COMP * cpi,MACROBLOCK * x,int plane,int block,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,const TXB_CTX * const txb_ctx,int skip_trellis,TX_TYPE best_tx_type,int do_quant,int * rate_cost,uint16_t best_eob)1129 static INLINE void recon_intra(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
1130                                int block, int blk_row, int blk_col,
1131                                BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
1132                                const TXB_CTX *const txb_ctx, int skip_trellis,
1133                                TX_TYPE best_tx_type, int do_quant,
1134                                int *rate_cost, uint16_t best_eob) {
1135   const AV1_COMMON *cm = &cpi->common;
1136   MACROBLOCKD *xd = &x->e_mbd;
1137   MB_MODE_INFO *mbmi = xd->mi[0];
1138   const int is_inter = is_inter_block(mbmi);
1139   if (!is_inter && best_eob &&
1140       (blk_row + tx_size_high_unit[tx_size] < mi_size_high[plane_bsize] ||
1141        blk_col + tx_size_wide_unit[tx_size] < mi_size_wide[plane_bsize])) {
1142     // if the quantized coefficients are stored in the dqcoeff buffer, we don't
1143     // need to do transform and quantization again.
1144     if (do_quant) {
1145       TxfmParam txfm_param_intra;
1146       QUANT_PARAM quant_param_intra;
1147       av1_setup_xform(cm, x, tx_size, best_tx_type, &txfm_param_intra);
1148       av1_setup_quant(tx_size, !skip_trellis,
1149                       skip_trellis
1150                           ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
1151                                                     : AV1_XFORM_QUANT_FP)
1152                           : AV1_XFORM_QUANT_FP,
1153                       cpi->oxcf.q_cfg.quant_b_adapt, &quant_param_intra);
1154       av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, best_tx_type,
1155                         &quant_param_intra);
1156       av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize,
1157                       &txfm_param_intra, &quant_param_intra);
1158       if (quant_param_intra.use_optimize_b) {
1159         av1_optimize_b(cpi, x, plane, block, tx_size, best_tx_type, txb_ctx,
1160                        rate_cost);
1161       }
1162     }
1163 
1164     inverse_transform_block_facade(x, plane, block, blk_row, blk_col,
1165                                    x->plane[plane].eobs[block],
1166                                    cm->features.reduced_tx_set_used);
1167 
1168     // This may happen because of hash collision. The eob stored in the hash
1169     // table is non-zero, but the real eob is zero. We need to make sure tx_type
1170     // is DCT_DCT in this case.
1171     if (plane == 0 && x->plane[plane].eobs[block] == 0 &&
1172         best_tx_type != DCT_DCT) {
1173       update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
1174     }
1175   }
1176 }
1177 
pixel_dist_visible_only(const AV1_COMP * const cpi,const MACROBLOCK * x,const uint8_t * src,const int src_stride,const uint8_t * dst,const int dst_stride,const BLOCK_SIZE tx_bsize,int txb_rows,int txb_cols,int visible_rows,int visible_cols)1178 static unsigned pixel_dist_visible_only(
1179     const AV1_COMP *const cpi, const MACROBLOCK *x, const uint8_t *src,
1180     const int src_stride, const uint8_t *dst, const int dst_stride,
1181     const BLOCK_SIZE tx_bsize, int txb_rows, int txb_cols, int visible_rows,
1182     int visible_cols) {
1183   unsigned sse;
1184 
1185   if (txb_rows == visible_rows && txb_cols == visible_cols) {
1186     cpi->fn_ptr[tx_bsize].vf(src, src_stride, dst, dst_stride, &sse);
1187     return sse;
1188   }
1189 
1190 #if CONFIG_AV1_HIGHBITDEPTH
1191   const MACROBLOCKD *xd = &x->e_mbd;
1192   if (is_cur_buf_hbd(xd)) {
1193     uint64_t sse64 = aom_highbd_sse_odd_size(src, src_stride, dst, dst_stride,
1194                                              visible_cols, visible_rows);
1195     return (unsigned int)ROUND_POWER_OF_TWO(sse64, (xd->bd - 8) * 2);
1196   }
1197 #else
1198   (void)x;
1199 #endif
1200   sse = aom_sse_odd_size(src, src_stride, dst, dst_stride, visible_cols,
1201                          visible_rows);
1202   return sse;
1203 }
1204 
1205 // Compute the pixel domain distortion from src and dst on all visible 4x4s in
1206 // the
1207 // transform block.
pixel_dist(const AV1_COMP * const cpi,const MACROBLOCK * x,int plane,const uint8_t * src,const int src_stride,const uint8_t * dst,const int dst_stride,int blk_row,int blk_col,const BLOCK_SIZE plane_bsize,const BLOCK_SIZE tx_bsize)1208 static unsigned pixel_dist(const AV1_COMP *const cpi, const MACROBLOCK *x,
1209                            int plane, const uint8_t *src, const int src_stride,
1210                            const uint8_t *dst, const int dst_stride,
1211                            int blk_row, int blk_col,
1212                            const BLOCK_SIZE plane_bsize,
1213                            const BLOCK_SIZE tx_bsize) {
1214   int txb_rows, txb_cols, visible_rows, visible_cols;
1215   const MACROBLOCKD *xd = &x->e_mbd;
1216 
1217   get_txb_dimensions(xd, plane, plane_bsize, blk_row, blk_col, tx_bsize,
1218                      &txb_cols, &txb_rows, &visible_cols, &visible_rows);
1219   assert(visible_rows > 0);
1220   assert(visible_cols > 0);
1221 
1222   unsigned sse = pixel_dist_visible_only(cpi, x, src, src_stride, dst,
1223                                          dst_stride, tx_bsize, txb_rows,
1224                                          txb_cols, visible_rows, visible_cols);
1225 
1226   return sse;
1227 }
1228 
dist_block_px_domain(const AV1_COMP * cpi,MACROBLOCK * x,int plane,BLOCK_SIZE plane_bsize,int block,int blk_row,int blk_col,TX_SIZE tx_size)1229 static INLINE int64_t dist_block_px_domain(const AV1_COMP *cpi, MACROBLOCK *x,
1230                                            int plane, BLOCK_SIZE plane_bsize,
1231                                            int block, int blk_row, int blk_col,
1232                                            TX_SIZE tx_size) {
1233   MACROBLOCKD *const xd = &x->e_mbd;
1234   const struct macroblock_plane *const p = &x->plane[plane];
1235   const uint16_t eob = p->eobs[block];
1236   const BLOCK_SIZE tx_bsize = txsize_to_bsize[tx_size];
1237   const int bsw = block_size_wide[tx_bsize];
1238   const int bsh = block_size_high[tx_bsize];
1239   const int src_stride = x->plane[plane].src.stride;
1240   const int dst_stride = xd->plane[plane].dst.stride;
1241   // Scale the transform block index to pixel unit.
1242   const int src_idx = (blk_row * src_stride + blk_col) << MI_SIZE_LOG2;
1243   const int dst_idx = (blk_row * dst_stride + blk_col) << MI_SIZE_LOG2;
1244   const uint8_t *src = &x->plane[plane].src.buf[src_idx];
1245   const uint8_t *dst = &xd->plane[plane].dst.buf[dst_idx];
1246   const tran_low_t *dqcoeff = p->dqcoeff + BLOCK_OFFSET(block);
1247 
1248   assert(cpi != NULL);
1249   assert(tx_size_wide_log2[0] == tx_size_high_log2[0]);
1250 
1251   uint8_t *recon;
1252   DECLARE_ALIGNED(16, uint16_t, recon16[MAX_TX_SQUARE]);
1253 
1254 #if CONFIG_AV1_HIGHBITDEPTH
1255   if (is_cur_buf_hbd(xd)) {
1256     recon = CONVERT_TO_BYTEPTR(recon16);
1257     aom_highbd_convolve_copy(CONVERT_TO_SHORTPTR(dst), dst_stride,
1258                              CONVERT_TO_SHORTPTR(recon), MAX_TX_SIZE, bsw, bsh);
1259   } else {
1260     recon = (uint8_t *)recon16;
1261     aom_convolve_copy(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh);
1262   }
1263 #else
1264   recon = (uint8_t *)recon16;
1265   aom_convolve_copy(dst, dst_stride, recon, MAX_TX_SIZE, bsw, bsh);
1266 #endif
1267 
1268   const PLANE_TYPE plane_type = get_plane_type(plane);
1269   TX_TYPE tx_type = av1_get_tx_type(xd, plane_type, blk_row, blk_col, tx_size,
1270                                     cpi->common.features.reduced_tx_set_used);
1271   av1_inverse_transform_block(xd, dqcoeff, plane, tx_type, tx_size, recon,
1272                               MAX_TX_SIZE, eob,
1273                               cpi->common.features.reduced_tx_set_used);
1274 
1275   return 16 * pixel_dist(cpi, x, plane, src, src_stride, recon, MAX_TX_SIZE,
1276                          blk_row, blk_col, plane_bsize, tx_bsize);
1277 }
1278 
get_intra_txb_hash(MACROBLOCK * x,int plane,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size)1279 static uint32_t get_intra_txb_hash(MACROBLOCK *x, int plane, int blk_row,
1280                                    int blk_col, BLOCK_SIZE plane_bsize,
1281                                    TX_SIZE tx_size) {
1282   int16_t tmp_data[64 * 64];
1283   const int diff_stride = block_size_wide[plane_bsize];
1284   const int16_t *diff = x->plane[plane].src_diff;
1285   const int16_t *cur_diff_row = diff + 4 * blk_row * diff_stride + 4 * blk_col;
1286   const int txb_w = tx_size_wide[tx_size];
1287   const int txb_h = tx_size_high[tx_size];
1288   uint8_t *hash_data = (uint8_t *)cur_diff_row;
1289   if (txb_w != diff_stride) {
1290     int16_t *cur_hash_row = tmp_data;
1291     for (int i = 0; i < txb_h; i++) {
1292       memcpy(cur_hash_row, cur_diff_row, sizeof(*diff) * txb_w);
1293       cur_hash_row += txb_w;
1294       cur_diff_row += diff_stride;
1295     }
1296     hash_data = (uint8_t *)tmp_data;
1297   }
1298   CRC32C *crc =
1299       &x->txfm_search_info.txb_rd_records->mb_rd_record.crc_calculator;
1300   const uint32_t hash = av1_get_crc32c_value(crc, hash_data, 2 * txb_w * txb_h);
1301   return (hash << 5) + tx_size;
1302 }
1303 
// Pruning tables shared by prune_txk_type and prune_txk_type_separ. Both have
// five entries; the caller chooses the index.
// prune_factors: thresholds on a scale of 1000.
static const int prune_factors[5] = {
  200, 200, 120, 80, 40,
};
// mul_factors: multipliers on a scale of 100.
static const int mul_factors[5] = {
  80, 80, 70, 50, 30,
};
1307 
is_intra_hash_match(const AV1_COMP * cpi,MACROBLOCK * x,int plane,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,const TXB_CTX * const txb_ctx,TXB_RD_INFO ** intra_txb_rd_info,const int tx_type_map_idx,uint16_t * cur_joint_ctx)1308 static INLINE int is_intra_hash_match(const AV1_COMP *cpi, MACROBLOCK *x,
1309                                       int plane, int blk_row, int blk_col,
1310                                       BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
1311                                       const TXB_CTX *const txb_ctx,
1312                                       TXB_RD_INFO **intra_txb_rd_info,
1313                                       const int tx_type_map_idx,
1314                                       uint16_t *cur_joint_ctx) {
1315   MACROBLOCKD *xd = &x->e_mbd;
1316   TxfmSearchInfo *txfm_info = &x->txfm_search_info;
1317   assert(cpi->sf.tx_sf.use_intra_txb_hash &&
1318          frame_is_intra_only(&cpi->common) && !is_inter_block(xd->mi[0]) &&
1319          plane == 0 && tx_size_wide[tx_size] == tx_size_high[tx_size]);
1320   const uint32_t intra_hash =
1321       get_intra_txb_hash(x, plane, blk_row, blk_col, plane_bsize, tx_size);
1322   const int intra_hash_idx = find_tx_size_rd_info(
1323       &txfm_info->txb_rd_records->txb_rd_record_intra, intra_hash);
1324   *intra_txb_rd_info = &txfm_info->txb_rd_records->txb_rd_record_intra
1325                             .tx_rd_info[intra_hash_idx];
1326   *cur_joint_ctx = (txb_ctx->dc_sign_ctx << 8) + txb_ctx->txb_skip_ctx;
1327   if ((*intra_txb_rd_info)->entropy_context == *cur_joint_ctx &&
1328       txfm_info->txb_rd_records->txb_rd_record_intra.tx_rd_info[intra_hash_idx]
1329           .valid) {
1330     xd->tx_type_map[tx_type_map_idx] = (*intra_txb_rd_info)->tx_type;
1331     const TX_TYPE ref_tx_type =
1332         av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size,
1333                         cpi->common.features.reduced_tx_set_used);
1334     return (ref_tx_type == (*intra_txb_rd_info)->tx_type);
1335   }
1336   return 0;
1337 }
1338 
// Sorts rds[0..len-1] into ascending order in place (stable insertion sort).
// txk[] is permuted in lockstep, so each RD cost keeps its TX type. Entries
// with equal cost keep their original relative order.
static INLINE void sort_rd(int64_t rds[], int txk[], int len) {
  for (int n = 1; n < len; ++n) {
    const int64_t key_rd = rds[n];
    const int key_txk = txk[n];
    int pos = n;
    // Move every strictly larger cost one slot to the right.
    while (pos > 0 && rds[pos - 1] > key_rd) {
      rds[pos] = rds[pos - 1];
      txk[pos] = txk[pos - 1];
      --pos;
    }
    rds[pos] = key_rd;
    txk[pos] = key_txk;
  }
}
1364 
dist_block_tx_domain(MACROBLOCK * x,int plane,int block,TX_SIZE tx_size,int64_t * out_dist,int64_t * out_sse)1365 static INLINE void dist_block_tx_domain(MACROBLOCK *x, int plane, int block,
1366                                         TX_SIZE tx_size, int64_t *out_dist,
1367                                         int64_t *out_sse) {
1368   const struct macroblock_plane *const p = &x->plane[plane];
1369   // Transform domain distortion computation is more efficient as it does
1370   // not involve an inverse transform, but it is less accurate.
1371   const int buffer_length = av1_get_max_eob(tx_size);
1372   int64_t this_sse;
1373   // TX-domain results need to shift down to Q2/D10 to match pixel
1374   // domain distortion values which are in Q2^2
1375   int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size)) * 2;
1376   const int block_offset = BLOCK_OFFSET(block);
1377   tran_low_t *const coeff = p->coeff + block_offset;
1378   tran_low_t *const dqcoeff = p->dqcoeff + block_offset;
1379 #if CONFIG_AV1_HIGHBITDEPTH
1380   MACROBLOCKD *const xd = &x->e_mbd;
1381   if (is_cur_buf_hbd(xd))
1382     *out_dist = av1_highbd_block_error(coeff, dqcoeff, buffer_length, &this_sse,
1383                                        xd->bd);
1384   else
1385 #endif
1386     *out_dist = av1_block_error(coeff, dqcoeff, buffer_length, &this_sse);
1387 
1388   *out_dist = RIGHT_SIGNED_SHIFT(*out_dist, shift);
1389   *out_sse = RIGHT_SIGNED_SHIFT(this_sse, shift);
1390 }
1391 
prune_txk_type_separ(const AV1_COMP * cpi,MACROBLOCK * x,int plane,int block,TX_SIZE tx_size,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,int * txk_map,int16_t allowed_tx_mask,int prune_factor,const TXB_CTX * const txb_ctx,int reduced_tx_set_used,int64_t ref_best_rd,int num_sel)1392 uint16_t prune_txk_type_separ(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
1393                               int block, TX_SIZE tx_size, int blk_row,
1394                               int blk_col, BLOCK_SIZE plane_bsize, int *txk_map,
1395                               int16_t allowed_tx_mask, int prune_factor,
1396                               const TXB_CTX *const txb_ctx,
1397                               int reduced_tx_set_used, int64_t ref_best_rd,
1398                               int num_sel) {
1399   const AV1_COMMON *cm = &cpi->common;
1400 
1401   int idx;
1402 
1403   int64_t rds_v[4];
1404   int64_t rds_h[4];
1405   int idx_v[4] = { 0, 1, 2, 3 };
1406   int idx_h[4] = { 0, 1, 2, 3 };
1407   int skip_v[4] = { 0 };
1408   int skip_h[4] = { 0 };
1409   const int idx_map[16] = {
1410     DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
1411     ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
1412     FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
1413     H_DCT,        H_ADST,        H_FLIPADST,        IDTX
1414   };
1415 
1416   const int sel_pattern_v[16] = {
1417     0, 0, 1, 1, 0, 2, 1, 2, 2, 0, 3, 1, 3, 2, 3, 3
1418   };
1419   const int sel_pattern_h[16] = {
1420     0, 1, 0, 1, 2, 0, 2, 1, 2, 3, 0, 3, 1, 3, 2, 3
1421   };
1422 
1423   QUANT_PARAM quant_param;
1424   TxfmParam txfm_param;
1425   av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
1426   av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
1427                   &quant_param);
1428   int tx_type;
1429   // to ensure we can try ones even outside of ext_tx_set of current block
1430   // this function should only be called for size < 16
1431   assert(txsize_sqr_up_map[tx_size] <= TX_16X16);
1432   txfm_param.tx_set_type = EXT_TX_SET_ALL16;
1433 
1434   int rate_cost = 0;
1435   int64_t dist = 0, sse = 0;
1436   // evaluate horizontal with vertical DCT
1437   for (idx = 0; idx < 4; ++idx) {
1438     tx_type = idx_map[idx];
1439     txfm_param.tx_type = tx_type;
1440 
1441     av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
1442                     &quant_param);
1443 
1444     dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);
1445 
1446     rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
1447                                               txb_ctx, reduced_tx_set_used, 0);
1448 
1449     rds_h[idx] = RDCOST(x->rdmult, rate_cost, dist);
1450 
1451     if ((rds_h[idx] - (rds_h[idx] >> 2)) > ref_best_rd) {
1452       skip_h[idx] = 1;
1453     }
1454   }
1455   sort_rd(rds_h, idx_h, 4);
1456   for (idx = 1; idx < 4; idx++) {
1457     if (rds_h[idx] > rds_h[0] * 1.2) skip_h[idx_h[idx]] = 1;
1458   }
1459 
1460   if (skip_h[idx_h[0]]) return (uint16_t)0xFFFF;
1461 
1462   // evaluate vertical with the best horizontal chosen
1463   rds_v[0] = rds_h[0];
1464   int start_v = 1, end_v = 4;
1465   const int *idx_map_v = idx_map + idx_h[0];
1466 
1467   for (idx = start_v; idx < end_v; ++idx) {
1468     tx_type = idx_map_v[idx_v[idx] * 4];
1469     txfm_param.tx_type = tx_type;
1470 
1471     av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
1472                     &quant_param);
1473 
1474     dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);
1475 
1476     rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
1477                                               txb_ctx, reduced_tx_set_used, 0);
1478 
1479     rds_v[idx] = RDCOST(x->rdmult, rate_cost, dist);
1480 
1481     if ((rds_v[idx] - (rds_v[idx] >> 2)) > ref_best_rd) {
1482       skip_v[idx] = 1;
1483     }
1484   }
1485   sort_rd(rds_v, idx_v, 4);
1486   for (idx = 1; idx < 4; idx++) {
1487     if (rds_v[idx] > rds_v[0] * 1.2) skip_v[idx_v[idx]] = 1;
1488   }
1489 
1490   // combine rd_h and rd_v to prune tx candidates
1491   int i_v, i_h;
1492   int64_t rds[16];
1493   int num_cand = 0, last = TX_TYPES - 1;
1494 
1495   for (int i = 0; i < 16; i++) {
1496     i_v = sel_pattern_v[i];
1497     i_h = sel_pattern_h[i];
1498     tx_type = idx_map[idx_v[i_v] * 4 + idx_h[i_h]];
1499     if (!(allowed_tx_mask & (1 << tx_type)) || skip_h[idx_h[i_h]] ||
1500         skip_v[idx_v[i_v]]) {
1501       txk_map[last] = tx_type;
1502       last--;
1503     } else {
1504       txk_map[num_cand] = tx_type;
1505       rds[num_cand] = rds_v[i_v] + rds_h[i_h];
1506       if (rds[num_cand] == 0) rds[num_cand] = 1;
1507       num_cand++;
1508     }
1509   }
1510   sort_rd(rds, txk_map, num_cand);
1511 
1512   uint16_t prune = (uint16_t)(~(1 << txk_map[0]));
1513   num_sel = AOMMIN(num_sel, num_cand);
1514 
1515   for (int i = 1; i < num_sel; i++) {
1516     int64_t factor = 1800 * (rds[i] - rds[0]) / (rds[0]);
1517     if (factor < (int64_t)prune_factor)
1518       prune &= ~(1 << txk_map[i]);
1519     else
1520       break;
1521   }
1522   return prune;
1523 }
1524 
prune_txk_type(const AV1_COMP * cpi,MACROBLOCK * x,int plane,int block,TX_SIZE tx_size,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,int * txk_map,uint16_t allowed_tx_mask,int prune_factor,const TXB_CTX * const txb_ctx,int reduced_tx_set_used)1525 uint16_t prune_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
1526                         int block, TX_SIZE tx_size, int blk_row, int blk_col,
1527                         BLOCK_SIZE plane_bsize, int *txk_map,
1528                         uint16_t allowed_tx_mask, int prune_factor,
1529                         const TXB_CTX *const txb_ctx, int reduced_tx_set_used) {
1530   const AV1_COMMON *cm = &cpi->common;
1531   int tx_type;
1532 
1533   int64_t rds[TX_TYPES];
1534 
1535   int num_cand = 0;
1536   int last = TX_TYPES - 1;
1537 
1538   TxfmParam txfm_param;
1539   QUANT_PARAM quant_param;
1540   av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
1541   av1_setup_quant(tx_size, 1, AV1_XFORM_QUANT_B, cpi->oxcf.q_cfg.quant_b_adapt,
1542                   &quant_param);
1543 
1544   for (int idx = 0; idx < TX_TYPES; idx++) {
1545     tx_type = idx;
1546     int rate_cost = 0;
1547     int64_t dist = 0, sse = 0;
1548     if (!(allowed_tx_mask & (1 << tx_type))) {
1549       txk_map[last] = tx_type;
1550       last--;
1551       continue;
1552     }
1553     txfm_param.tx_type = tx_type;
1554 
1555     // do txfm and quantization
1556     av1_xform_quant(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param,
1557                     &quant_param);
1558     // estimate rate cost
1559     rate_cost = av1_cost_coeffs_txb_laplacian(x, plane, block, tx_size, tx_type,
1560                                               txb_ctx, reduced_tx_set_used, 0);
1561     // tx domain dist
1562     dist_block_tx_domain(x, plane, block, tx_size, &dist, &sse);
1563 
1564     txk_map[num_cand] = tx_type;
1565     rds[num_cand] = RDCOST(x->rdmult, rate_cost, dist);
1566     if (rds[num_cand] == 0) rds[num_cand] = 1;
1567     num_cand++;
1568   }
1569 
1570   if (num_cand == 0) return (uint16_t)0xFFFF;
1571 
1572   sort_rd(rds, txk_map, num_cand);
1573   uint16_t prune = (uint16_t)(~(1 << txk_map[0]));
1574 
1575   // 0 < prune_factor <= 1000 controls aggressiveness
1576   int64_t factor = 0;
1577   for (int idx = 1; idx < num_cand; idx++) {
1578     factor = 1000 * (rds[idx] - rds[0]) / rds[0];
1579     if (factor < (int64_t)prune_factor)
1580       prune &= ~(1 << txk_map[idx]);
1581     else
1582       break;
1583   }
1584   return prune;
1585 }
1586 
// Score thresholds for prune_tx_2D. The first index is the TX_SIZE; the second
// is the pruning-aggressiveness level chosen in get_adaptive_thresholds.
// These thresholds were calibrated to provide a certain number of TX types
// pruned by the model on average, i.e. selecting a threshold with index i
// will lead to pruning i+1 TX types on average.
// NULL rows mark sizes without calibrated thresholds. get_adaptive_thresholds
// indexes the row unconditionally, so it must never be reached for them.
// Row lengths differ: the TX_16X16 row has only 10 entries (levels 0-9).
static const float *prune_2D_adaptive_thresholds[] = {
  // TX_4X4
  (float[]){ 0.00549f, 0.01306f, 0.02039f, 0.02747f, 0.03406f, 0.04065f,
             0.04724f, 0.05383f, 0.06067f, 0.06799f, 0.07605f, 0.08533f,
             0.09778f, 0.11780f },
  // TX_8X8
  (float[]){ 0.00037f, 0.00183f, 0.00525f, 0.01038f, 0.01697f, 0.02502f,
             0.03381f, 0.04333f, 0.05286f, 0.06287f, 0.07434f, 0.08850f,
             0.10803f, 0.14124f },
  // TX_16X16
  (float[]){ 0.01404f, 0.02000f, 0.04211f, 0.05164f, 0.05798f, 0.06335f,
             0.06897f, 0.07629f, 0.08875f, 0.11169f },
  // TX_32X32
  NULL,
  // TX_64X64
  NULL,
  // TX_4X8
  (float[]){ 0.00183f, 0.00745f, 0.01428f, 0.02185f, 0.02966f, 0.03723f,
             0.04456f, 0.05188f, 0.05920f, 0.06702f, 0.07605f, 0.08704f,
             0.10168f, 0.12585f },
  // TX_8X4
  (float[]){ 0.00085f, 0.00476f, 0.01135f, 0.01892f, 0.02698f, 0.03528f,
             0.04358f, 0.05164f, 0.05994f, 0.06848f, 0.07849f, 0.09021f,
             0.10583f, 0.13123f },
  // TX_8X16
  (float[]){ 0.00037f, 0.00232f, 0.00671f, 0.01257f, 0.01965f, 0.02722f,
             0.03552f, 0.04382f, 0.05237f, 0.06189f, 0.07336f, 0.08728f,
             0.10730f, 0.14221f },
  // TX_16X8
  (float[]){ 0.00061f, 0.00330f, 0.00818f, 0.01453f, 0.02185f, 0.02966f,
             0.03772f, 0.04578f, 0.05383f, 0.06262f, 0.07288f, 0.08582f,
             0.10339f, 0.13464f },
  // TX_16X32
  NULL,
  // TX_32X16
  NULL,
  // TX_32X64
  NULL,
  // TX_64X32
  NULL,
  // TX_4X16
  (float[]){ 0.00232f, 0.00671f, 0.01257f, 0.01941f, 0.02673f, 0.03430f,
             0.04211f, 0.04968f, 0.05750f, 0.06580f, 0.07507f, 0.08655f,
             0.10242f, 0.12878f },
  // TX_16X4
  (float[]){ 0.00110f, 0.00525f, 0.01208f, 0.01990f, 0.02795f, 0.03601f,
             0.04358f, 0.05115f, 0.05896f, 0.06702f, 0.07629f, 0.08752f,
             0.10217f, 0.12610f },
  // TX_8X32
  NULL,
  // TX_32X8
  NULL,
  // TX_16X64
  NULL,
  // TX_64X16
  NULL,
};
1647 
// Sorts prob[0..len-1] into descending order in place (stable insertion sort).
// txk[] is permuted in lockstep, so each probability keeps its TX type. For
// each element, the insertion point is the first earlier slot holding a
// strictly smaller probability. Equal probabilities therefore keep their
// original relative order.
static INLINE void sort_probability(float prob[], int txk[], int len) {
  for (int cur = 1; cur < len; ++cur) {
    int dst = 0;
    while (dst < cur && !(prob[dst] < prob[cur])) ++dst;
    if (dst == cur) continue;  // already in place

    const float moved_prob = prob[cur];
    const int moved_txk = txk[cur];
    for (int k = cur; k > dst; --k) {
      prob[k] = prob[k - 1];
      txk[k] = txk[k - 1];
    }
    prob[dst] = moved_prob;
    txk[dst] = moved_txk;
  }
}
1673 
get_adaptive_thresholds(TX_SIZE tx_size,TxSetType tx_set_type,TX_TYPE_PRUNE_MODE prune_2d_txfm_mode)1674 static INLINE float get_adaptive_thresholds(
1675     TX_SIZE tx_size, TxSetType tx_set_type,
1676     TX_TYPE_PRUNE_MODE prune_2d_txfm_mode) {
1677   const int prune_aggr_table[5][2] = {
1678     { 4, 1 }, { 6, 3 }, { 9, 6 }, { 9, 6 }, { 12, 9 }
1679   };
1680   int pruning_aggressiveness = 0;
1681   if (tx_set_type == EXT_TX_SET_ALL16)
1682     pruning_aggressiveness =
1683         prune_aggr_table[prune_2d_txfm_mode - TX_TYPE_PRUNE_1][0];
1684   else if (tx_set_type == EXT_TX_SET_DTT9_IDTX_1DDCT)
1685     pruning_aggressiveness =
1686         prune_aggr_table[prune_2d_txfm_mode - TX_TYPE_PRUNE_1][1];
1687 
1688   return prune_2D_adaptive_thresholds[tx_size][pruning_aggressiveness];
1689 }
1690 
get_energy_distribution_finer(const int16_t * diff,int stride,int bw,int bh,float * hordist,float * verdist)1691 static AOM_INLINE void get_energy_distribution_finer(const int16_t *diff,
1692                                                      int stride, int bw, int bh,
1693                                                      float *hordist,
1694                                                      float *verdist) {
1695   // First compute downscaled block energy values (esq); downscale factors
1696   // are defined by w_shift and h_shift.
1697   unsigned int esq[256];
1698   const int w_shift = bw <= 8 ? 0 : 1;
1699   const int h_shift = bh <= 8 ? 0 : 1;
1700   const int esq_w = bw >> w_shift;
1701   const int esq_h = bh >> h_shift;
1702   const int esq_sz = esq_w * esq_h;
1703   int i, j;
1704   memset(esq, 0, esq_sz * sizeof(esq[0]));
1705   if (w_shift) {
1706     for (i = 0; i < bh; i++) {
1707       unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
1708       const int16_t *cur_diff_row = diff + i * stride;
1709       for (j = 0; j < bw; j += 2) {
1710         cur_esq_row[j >> 1] += (cur_diff_row[j] * cur_diff_row[j] +
1711                                 cur_diff_row[j + 1] * cur_diff_row[j + 1]);
1712       }
1713     }
1714   } else {
1715     for (i = 0; i < bh; i++) {
1716       unsigned int *cur_esq_row = esq + (i >> h_shift) * esq_w;
1717       const int16_t *cur_diff_row = diff + i * stride;
1718       for (j = 0; j < bw; j++) {
1719         cur_esq_row[j] += cur_diff_row[j] * cur_diff_row[j];
1720       }
1721     }
1722   }
1723 
1724   uint64_t total = 0;
1725   for (i = 0; i < esq_sz; i++) total += esq[i];
1726 
1727   // Output hordist and verdist arrays are normalized 1D projections of esq
1728   if (total == 0) {
1729     float hor_val = 1.0f / esq_w;
1730     for (j = 0; j < esq_w - 1; j++) hordist[j] = hor_val;
1731     float ver_val = 1.0f / esq_h;
1732     for (i = 0; i < esq_h - 1; i++) verdist[i] = ver_val;
1733     return;
1734   }
1735 
1736   const float e_recip = 1.0f / (float)total;
1737   memset(hordist, 0, (esq_w - 1) * sizeof(hordist[0]));
1738   memset(verdist, 0, (esq_h - 1) * sizeof(verdist[0]));
1739   const unsigned int *cur_esq_row;
1740   for (i = 0; i < esq_h - 1; i++) {
1741     cur_esq_row = esq + i * esq_w;
1742     for (j = 0; j < esq_w - 1; j++) {
1743       hordist[j] += (float)cur_esq_row[j];
1744       verdist[i] += (float)cur_esq_row[j];
1745     }
1746     verdist[i] += (float)cur_esq_row[j];
1747   }
1748   cur_esq_row = esq + i * esq_w;
1749   for (j = 0; j < esq_w - 1; j++) hordist[j] += (float)cur_esq_row[j];
1750 
1751   for (j = 0; j < esq_w - 1; j++) hordist[j] *= e_recip;
1752   for (i = 0; i < esq_h - 1; i++) verdist[i] *= e_recip;
1753 }
1754 
prune_tx_2D(MACROBLOCK * x,BLOCK_SIZE bsize,TX_SIZE tx_size,int blk_row,int blk_col,TxSetType tx_set_type,TX_TYPE_PRUNE_MODE prune_2d_txfm_mode,int * txk_map,uint16_t * allowed_tx_mask)1755 static void prune_tx_2D(MACROBLOCK *x, BLOCK_SIZE bsize, TX_SIZE tx_size,
1756                         int blk_row, int blk_col, TxSetType tx_set_type,
1757                         TX_TYPE_PRUNE_MODE prune_2d_txfm_mode, int *txk_map,
1758                         uint16_t *allowed_tx_mask) {
1759   int tx_type_table_2D[16] = {
1760     DCT_DCT,      DCT_ADST,      DCT_FLIPADST,      V_DCT,
1761     ADST_DCT,     ADST_ADST,     ADST_FLIPADST,     V_ADST,
1762     FLIPADST_DCT, FLIPADST_ADST, FLIPADST_FLIPADST, V_FLIPADST,
1763     H_DCT,        H_ADST,        H_FLIPADST,        IDTX
1764   };
1765   if (tx_set_type != EXT_TX_SET_ALL16 &&
1766       tx_set_type != EXT_TX_SET_DTT9_IDTX_1DDCT)
1767     return;
1768 #if CONFIG_NN_V2
1769   NN_CONFIG_V2 *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
1770   NN_CONFIG_V2 *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
1771 #else
1772   const NN_CONFIG *nn_config_hor = av1_tx_type_nnconfig_map_hor[tx_size];
1773   const NN_CONFIG *nn_config_ver = av1_tx_type_nnconfig_map_ver[tx_size];
1774 #endif
1775   if (!nn_config_hor || !nn_config_ver) return;  // Model not established yet.
1776 
1777   aom_clear_system_state();
1778   float hfeatures[16], vfeatures[16];
1779   float hscores[4], vscores[4];
1780   float scores_2D_raw[16];
1781   float scores_2D[16];
1782   const int bw = tx_size_wide[tx_size];
1783   const int bh = tx_size_high[tx_size];
1784   const int hfeatures_num = bw <= 8 ? bw : bw / 2;
1785   const int vfeatures_num = bh <= 8 ? bh : bh / 2;
1786   assert(hfeatures_num <= 16);
1787   assert(vfeatures_num <= 16);
1788 
1789   const struct macroblock_plane *const p = &x->plane[0];
1790   const int diff_stride = block_size_wide[bsize];
1791   const int16_t *diff = p->src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
1792   get_energy_distribution_finer(diff, diff_stride, bw, bh, hfeatures,
1793                                 vfeatures);
1794   av1_get_horver_correlation_full(diff, diff_stride, bw, bh,
1795                                   &hfeatures[hfeatures_num - 1],
1796                                   &vfeatures[vfeatures_num - 1]);
1797   aom_clear_system_state();
1798 #if CONFIG_NN_V2
1799   av1_nn_predict_v2(hfeatures, nn_config_hor, 0, hscores);
1800   av1_nn_predict_v2(vfeatures, nn_config_ver, 0, vscores);
1801 #else
1802   av1_nn_predict(hfeatures, nn_config_hor, 1, hscores);
1803   av1_nn_predict(vfeatures, nn_config_ver, 1, vscores);
1804 #endif
1805   aom_clear_system_state();
1806 
1807   for (int i = 0; i < 4; i++) {
1808     float *cur_scores_2D = scores_2D_raw + i * 4;
1809     cur_scores_2D[0] = vscores[i] * hscores[0];
1810     cur_scores_2D[1] = vscores[i] * hscores[1];
1811     cur_scores_2D[2] = vscores[i] * hscores[2];
1812     cur_scores_2D[3] = vscores[i] * hscores[3];
1813   }
1814 
1815   av1_nn_softmax(scores_2D_raw, scores_2D, 16);
1816 
1817   const float score_thresh =
1818       get_adaptive_thresholds(tx_size, tx_set_type, prune_2d_txfm_mode);
1819 
1820   // Always keep the TX type with the highest score, prune all others with
1821   // score below score_thresh.
1822   int max_score_i = 0;
1823   float max_score = 0.0f;
1824   uint16_t allow_bitmask = 0;
1825   float sum_score = 0.0;
1826   // Calculate sum of allowed tx type score and Populate allow bit mask based
1827   // on score_thresh and allowed_tx_mask
1828   for (int tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
1829     int allow_tx_type = *allowed_tx_mask & (1 << tx_type_table_2D[tx_idx]);
1830     if (scores_2D[tx_idx] > max_score && allow_tx_type) {
1831       max_score = scores_2D[tx_idx];
1832       max_score_i = tx_idx;
1833     }
1834     if (scores_2D[tx_idx] >= score_thresh && allow_tx_type) {
1835       // Set allow mask based on score_thresh
1836       allow_bitmask |= (1 << tx_type_table_2D[tx_idx]);
1837 
1838       // Accumulate score of allowed tx type
1839       sum_score += scores_2D[tx_idx];
1840     }
1841   }
1842   if (!((allow_bitmask >> max_score_i) & 0x01)) {
1843     // Set allow mask based on tx type with max score
1844     allow_bitmask |= (1 << tx_type_table_2D[max_score_i]);
1845     sum_score += scores_2D[max_score_i];
1846   }
1847   // Sort tx type probability of all types
1848   sort_probability(scores_2D, tx_type_table_2D, TX_TYPES);
1849 
1850   // Enable more pruning based on tx type probability and number of allowed tx
1851   // types
1852   if (prune_2d_txfm_mode >= TX_TYPE_PRUNE_4) {
1853     float temp_score = 0.0;
1854     float score_ratio = 0.0;
1855     int tx_idx, tx_count = 0;
1856     const float inv_sum_score = 100 / sum_score;
1857     // Get allowed tx types based on sorted probability score and tx count
1858     for (tx_idx = 0; tx_idx < TX_TYPES; tx_idx++) {
1859       // Skip the tx type which has more than 30% of cumulative
1860       // probability and allowed tx type count is more than 2
1861       if (score_ratio > 30.0 && tx_count >= 2) break;
1862 
1863       // Calculate cumulative probability of allowed tx types
1864       if (allow_bitmask & (1 << tx_type_table_2D[tx_idx])) {
1865         // Calculate cumulative probability
1866         temp_score += scores_2D[tx_idx];
1867 
1868         // Calculate percentage of cumulative probability of allowed tx type
1869         score_ratio = temp_score * inv_sum_score;
1870         tx_count++;
1871       }
1872     }
1873     // Set remaining tx types as pruned
1874     for (; tx_idx < TX_TYPES; tx_idx++)
1875       allow_bitmask &= ~(1 << tx_type_table_2D[tx_idx]);
1876   }
1877   memcpy(txk_map, tx_type_table_2D, sizeof(tx_type_table_2D));
1878   *allowed_tx_mask = allow_bitmask;
1879 }
1880 
get_dev(float mean,double x2_sum,int num)1881 static float get_dev(float mean, double x2_sum, int num) {
1882   const float e_x2 = (float)(x2_sum / num);
1883   const float diff = e_x2 - mean * mean;
1884   const float dev = (diff > 0) ? sqrtf(diff) : 0;
1885   return dev;
1886 }
1887 
1888 // Feature used by the model to predict tx split: the mean and standard
1889 // deviation values of the block and sub-blocks.
get_mean_dev_features(const int16_t * data,int stride,int bw,int bh,float * feature)1890 static AOM_INLINE void get_mean_dev_features(const int16_t *data, int stride,
1891                                              int bw, int bh, float *feature) {
1892   const int16_t *const data_ptr = &data[0];
1893   const int subh = (bh >= bw) ? (bh >> 1) : bh;
1894   const int subw = (bw >= bh) ? (bw >> 1) : bw;
1895   const int num = bw * bh;
1896   const int sub_num = subw * subh;
1897   int feature_idx = 2;
1898   int total_x_sum = 0;
1899   int64_t total_x2_sum = 0;
1900   int blk_idx = 0;
1901   double mean2_sum = 0.0f;
1902   float dev_sum = 0.0f;
1903 
1904   for (int row = 0; row < bh; row += subh) {
1905     for (int col = 0; col < bw; col += subw) {
1906       int x_sum;
1907       int64_t x2_sum;
1908       // TODO(any): Write a SIMD version. Clear registers.
1909       aom_get_blk_sse_sum(data_ptr + row * stride + col, stride, subw, subh,
1910                           &x_sum, &x2_sum);
1911       total_x_sum += x_sum;
1912       total_x2_sum += x2_sum;
1913 
1914       aom_clear_system_state();
1915       const float mean = (float)x_sum / sub_num;
1916       const float dev = get_dev(mean, (double)x2_sum, sub_num);
1917       feature[feature_idx++] = mean;
1918       feature[feature_idx++] = dev;
1919       mean2_sum += (double)(mean * mean);
1920       dev_sum += dev;
1921       blk_idx++;
1922     }
1923   }
1924 
1925   const float lvl0_mean = (float)total_x_sum / num;
1926   feature[0] = lvl0_mean;
1927   feature[1] = get_dev(lvl0_mean, (double)total_x2_sum, num);
1928 
1929   if (blk_idx > 1) {
1930     // Deviation of means.
1931     feature[feature_idx++] = get_dev(lvl0_mean, mean2_sum, blk_idx);
1932     // Mean of deviations.
1933     feature[feature_idx++] = dev_sum / blk_idx;
1934   }
1935 }
1936 
ml_predict_tx_split(MACROBLOCK * x,BLOCK_SIZE bsize,int blk_row,int blk_col,TX_SIZE tx_size)1937 static int ml_predict_tx_split(MACROBLOCK *x, BLOCK_SIZE bsize, int blk_row,
1938                                int blk_col, TX_SIZE tx_size) {
1939   const NN_CONFIG *nn_config = av1_tx_split_nnconfig_map[tx_size];
1940   if (!nn_config) return -1;
1941 
1942   const int diff_stride = block_size_wide[bsize];
1943   const int16_t *diff =
1944       x->plane[0].src_diff + 4 * blk_row * diff_stride + 4 * blk_col;
1945   const int bw = tx_size_wide[tx_size];
1946   const int bh = tx_size_high[tx_size];
1947   aom_clear_system_state();
1948 
1949   float features[64] = { 0.0f };
1950   get_mean_dev_features(diff, diff_stride, bw, bh, features);
1951 
1952   float score = 0.0f;
1953   av1_nn_predict(features, nn_config, 1, &score);
1954   aom_clear_system_state();
1955 
1956   int int_score = (int)(score * 10000);
1957   return clamp(int_score, -80000, 80000);
1958 }
1959 
1960 static INLINE uint16_t
get_tx_mask(const AV1_COMP * cpi,MACROBLOCK * x,int plane,int block,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,const TXB_CTX * const txb_ctx,FAST_TX_SEARCH_MODE ftxs_mode,int64_t ref_best_rd,TX_TYPE * allowed_txk_types,int * txk_map)1961 get_tx_mask(const AV1_COMP *cpi, MACROBLOCK *x, int plane, int block,
1962             int blk_row, int blk_col, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
1963             const TXB_CTX *const txb_ctx, FAST_TX_SEARCH_MODE ftxs_mode,
1964             int64_t ref_best_rd, TX_TYPE *allowed_txk_types, int *txk_map) {
1965   const AV1_COMMON *cm = &cpi->common;
1966   MACROBLOCKD *xd = &x->e_mbd;
1967   MB_MODE_INFO *mbmi = xd->mi[0];
1968   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
1969   const int is_inter = is_inter_block(mbmi);
1970   const int fast_tx_search = ftxs_mode & FTXS_DCT_AND_1D_DCT_ONLY;
1971   // if txk_allowed = TX_TYPES, >1 tx types are allowed, else, if txk_allowed <
1972   // TX_TYPES, only that specific tx type is allowed.
1973   TX_TYPE txk_allowed = TX_TYPES;
1974 
1975   if ((!is_inter && txfm_params->use_default_intra_tx_type) ||
1976       (is_inter && txfm_params->use_default_inter_tx_type)) {
1977     txk_allowed =
1978         get_default_tx_type(0, xd, tx_size, cpi->use_screen_content_tools);
1979   } else if (x->rd_model == LOW_TXFM_RD) {
1980     if (plane == 0) txk_allowed = DCT_DCT;
1981   }
1982 
1983   const TxSetType tx_set_type = av1_get_ext_tx_set_type(
1984       tx_size, is_inter, cm->features.reduced_tx_set_used);
1985 
1986   TX_TYPE uv_tx_type = DCT_DCT;
1987   if (plane) {
1988     // tx_type of PLANE_TYPE_UV should be the same as PLANE_TYPE_Y
1989     uv_tx_type = txk_allowed =
1990         av1_get_tx_type(xd, get_plane_type(plane), blk_row, blk_col, tx_size,
1991                         cm->features.reduced_tx_set_used);
1992   }
1993   PREDICTION_MODE intra_dir =
1994       mbmi->filter_intra_mode_info.use_filter_intra
1995           ? fimode_to_intradir[mbmi->filter_intra_mode_info.filter_intra_mode]
1996           : mbmi->mode;
1997   uint16_t ext_tx_used_flag =
1998       cpi->sf.tx_sf.tx_type_search.use_reduced_intra_txset &&
1999               tx_set_type == EXT_TX_SET_DTT4_IDTX_1DDCT
2000           ? av1_reduced_intra_tx_used_flag[intra_dir]
2001           : av1_ext_tx_used_flag[tx_set_type];
2002   if (xd->lossless[mbmi->segment_id] || txsize_sqr_up_map[tx_size] > TX_32X32 ||
2003       ext_tx_used_flag == 0x0001 ||
2004       (is_inter && cpi->oxcf.txfm_cfg.use_inter_dct_only) ||
2005       (!is_inter && cpi->oxcf.txfm_cfg.use_intra_dct_only)) {
2006     txk_allowed = DCT_DCT;
2007   }
2008 
2009   if (cpi->oxcf.txfm_cfg.enable_flip_idtx == 0)
2010     ext_tx_used_flag &= DCT_ADST_TX_MASK;
2011 
2012   uint16_t allowed_tx_mask = 0;  // 1: allow; 0: skip.
2013   if (txk_allowed < TX_TYPES) {
2014     allowed_tx_mask = 1 << txk_allowed;
2015     allowed_tx_mask &= ext_tx_used_flag;
2016   } else if (fast_tx_search) {
2017     allowed_tx_mask = 0x0c01;  // V_DCT, H_DCT, DCT_DCT
2018     allowed_tx_mask &= ext_tx_used_flag;
2019   } else {
2020     assert(plane == 0);
2021     allowed_tx_mask = ext_tx_used_flag;
2022     int num_allowed = 0;
2023     const FRAME_UPDATE_TYPE update_type = get_frame_update_type(&cpi->gf_group);
2024     const int *tx_type_probs =
2025         cpi->frame_probs.tx_type_probs[update_type][tx_size];
2026     int i;
2027 
2028     if (cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats) {
2029       static const int thresh_arr[2][7] = { { 10, 15, 15, 10, 15, 15, 15 },
2030                                             { 10, 17, 17, 10, 17, 17, 17 } };
2031       const int thresh =
2032           thresh_arr[cpi->sf.tx_sf.tx_type_search.prune_tx_type_using_stats - 1]
2033                     [update_type];
2034       uint16_t prune = 0;
2035       int max_prob = -1;
2036       int max_idx = 0;
2037       for (i = 0; i < TX_TYPES; i++) {
2038         if (tx_type_probs[i] > max_prob && (allowed_tx_mask & (1 << i))) {
2039           max_prob = tx_type_probs[i];
2040           max_idx = i;
2041         }
2042         if (tx_type_probs[i] < thresh) prune |= (1 << i);
2043       }
2044       if ((prune >> max_idx) & 0x01) prune &= ~(1 << max_idx);
2045       allowed_tx_mask &= (~prune);
2046     }
2047     for (i = 0; i < TX_TYPES; i++) {
2048       if (allowed_tx_mask & (1 << i)) num_allowed++;
2049     }
2050     assert(num_allowed > 0);
2051 
2052     if (num_allowed > 2 && cpi->sf.tx_sf.tx_type_search.prune_tx_type_est_rd) {
2053       int pf = prune_factors[txfm_params->prune_2d_txfm_mode];
2054       int mf = mul_factors[txfm_params->prune_2d_txfm_mode];
2055       if (num_allowed <= 7) {
2056         const uint16_t prune =
2057             prune_txk_type(cpi, x, plane, block, tx_size, blk_row, blk_col,
2058                            plane_bsize, txk_map, allowed_tx_mask, pf, txb_ctx,
2059                            cm->features.reduced_tx_set_used);
2060         allowed_tx_mask &= (~prune);
2061       } else {
2062         const int num_sel = (num_allowed * mf + 50) / 100;
2063         const uint16_t prune = prune_txk_type_separ(
2064             cpi, x, plane, block, tx_size, blk_row, blk_col, plane_bsize,
2065             txk_map, allowed_tx_mask, pf, txb_ctx,
2066             cm->features.reduced_tx_set_used, ref_best_rd, num_sel);
2067 
2068         allowed_tx_mask &= (~prune);
2069       }
2070     } else {
2071       assert(num_allowed > 0);
2072       int allowed_tx_count =
2073           (txfm_params->prune_2d_txfm_mode >= TX_TYPE_PRUNE_4) ? 1 : 5;
2074       // !fast_tx_search && txk_end != txk_start && plane == 0
2075       if (txfm_params->prune_2d_txfm_mode >= TX_TYPE_PRUNE_1 && is_inter &&
2076           num_allowed > allowed_tx_count) {
2077         prune_tx_2D(x, plane_bsize, tx_size, blk_row, blk_col, tx_set_type,
2078                     txfm_params->prune_2d_txfm_mode, txk_map, &allowed_tx_mask);
2079       }
2080     }
2081   }
2082 
2083   // Need to have at least one transform type allowed.
2084   if (allowed_tx_mask == 0) {
2085     txk_allowed = (plane ? uv_tx_type : DCT_DCT);
2086     allowed_tx_mask = (1 << txk_allowed);
2087   }
2088 
2089   assert(IMPLIES(txk_allowed < TX_TYPES, allowed_tx_mask == 1 << txk_allowed));
2090   *allowed_txk_types = txk_allowed;
2091   return allowed_tx_mask;
2092 }
2093 
#if CONFIG_RD_DEBUG
// Debug-only bookkeeping of coefficient cost.
// Adds txb_coeff_cost to rd_stats->txb_coeff_cost[plane] and records it in the
// per-plane cost map at (blk_row, blk_col). All other map cells covered by the
// transform block of size tx_size are cleared to 0.
static INLINE void update_txb_coeff_cost(RD_STATS *rd_stats, int plane,
                                         TX_SIZE tx_size, int blk_row,
                                         int blk_col, int txb_coeff_cost) {
  const int txb_h = tx_size_high_unit[tx_size];
  const int txb_w = tx_size_wide_unit[tx_size];
  // Validate the map region *before* writing to it; asserting after the
  // writes (as previously done) cannot catch an out-of-bounds store.
  assert(blk_row < TXB_COEFF_COST_MAP_SIZE);
  assert(blk_col < TXB_COEFF_COST_MAP_SIZE);
  assert(blk_row + txb_h <= TXB_COEFF_COST_MAP_SIZE);
  assert(blk_col + txb_w <= TXB_COEFF_COST_MAP_SIZE);

  rd_stats->txb_coeff_cost[plane] += txb_coeff_cost;

  for (int idy = 0; idy < txb_h; ++idy)
    for (int idx = 0; idx < txb_w; ++idx)
      rd_stats->txb_coeff_cost_map[plane][blk_row + idy][blk_col + idx] = 0;

  // The whole cost is attributed to the top-left cell of the transform block.
  rd_stats->txb_coeff_cost_map[plane][blk_row][blk_col] = txb_coeff_cost;
}
#endif
2117 
cost_coeffs(MACROBLOCK * x,int plane,int block,TX_SIZE tx_size,const TX_TYPE tx_type,const TXB_CTX * const txb_ctx,int reduced_tx_set_used)2118 static INLINE int cost_coeffs(MACROBLOCK *x, int plane, int block,
2119                               TX_SIZE tx_size, const TX_TYPE tx_type,
2120                               const TXB_CTX *const txb_ctx,
2121                               int reduced_tx_set_used) {
2122 #if TXCOEFF_COST_TIMER
2123   struct aom_usec_timer timer;
2124   aom_usec_timer_start(&timer);
2125 #endif
2126   const int cost = av1_cost_coeffs_txb(x, plane, block, tx_size, tx_type,
2127                                        txb_ctx, reduced_tx_set_used);
2128 #if TXCOEFF_COST_TIMER
2129   AV1_COMMON *tmp_cm = (AV1_COMMON *)&cpi->common;
2130   aom_usec_timer_mark(&timer);
2131   const int64_t elapsed_time = aom_usec_timer_elapsed(&timer);
2132   tmp_cm->txcoeff_cost_timer += elapsed_time;
2133   ++tmp_cm->txcoeff_cost_count;
2134 #endif
2135   return cost;
2136 }
2137 
skip_trellis_opt_based_on_satd(MACROBLOCK * x,QUANT_PARAM * quant_param,int plane,int block,TX_SIZE tx_size,int quant_b_adapt,int qstep,unsigned int coeff_opt_satd_threshold,int skip_trellis,int dc_only_blk)2138 static int skip_trellis_opt_based_on_satd(MACROBLOCK *x,
2139                                           QUANT_PARAM *quant_param, int plane,
2140                                           int block, TX_SIZE tx_size,
2141                                           int quant_b_adapt, int qstep,
2142                                           unsigned int coeff_opt_satd_threshold,
2143                                           int skip_trellis, int dc_only_blk) {
2144   if (skip_trellis || (coeff_opt_satd_threshold == UINT_MAX))
2145     return skip_trellis;
2146 
2147   const struct macroblock_plane *const p = &x->plane[plane];
2148   const int block_offset = BLOCK_OFFSET(block);
2149   tran_low_t *const coeff_ptr = p->coeff + block_offset;
2150   const int n_coeffs = av1_get_max_eob(tx_size);
2151   const int shift = (MAX_TX_SCALE - av1_get_tx_scale(tx_size));
2152   int satd = (dc_only_blk) ? abs(coeff_ptr[0]) : aom_satd(coeff_ptr, n_coeffs);
2153   satd = RIGHT_SIGNED_SHIFT(satd, shift);
2154   satd >>= (x->e_mbd.bd - 8);
2155 
2156   const int skip_block_trellis =
2157       ((uint64_t)satd >
2158        (uint64_t)coeff_opt_satd_threshold * qstep * sqrt_tx_pixels_2d[tx_size]);
2159 
2160   av1_setup_quant(
2161       tx_size, !skip_block_trellis,
2162       skip_block_trellis
2163           ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B : AV1_XFORM_QUANT_FP)
2164           : AV1_XFORM_QUANT_FP,
2165       quant_b_adapt, quant_param);
2166 
2167   return skip_block_trellis;
2168 }
2169 
2170 // Predict DC only blocks if the residual variance is below a qstep based
2171 // threshold.For such blocks, transform type search is bypassed.
predict_dc_only_block(MACROBLOCK * x,int plane,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,int block,int blk_row,int blk_col,RD_STATS * best_rd_stats,int64_t * block_sse,unsigned int * block_mse_q8,int64_t * per_px_mean,int * dc_only_blk)2172 static INLINE void predict_dc_only_block(
2173     MACROBLOCK *x, int plane, BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
2174     int block, int blk_row, int blk_col, RD_STATS *best_rd_stats,
2175     int64_t *block_sse, unsigned int *block_mse_q8, int64_t *per_px_mean,
2176     int *dc_only_blk) {
2177   MACROBLOCKD *xd = &x->e_mbd;
2178   MB_MODE_INFO *mbmi = xd->mi[0];
2179   const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
2180   const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift;
2181   uint64_t block_var = UINT64_MAX;
2182   const int dc_qstep = x->plane[plane].dequant_QTX[0] >> 3;
2183   *block_sse = pixel_diff_stats(x, plane, blk_row, blk_col, plane_bsize,
2184                                 txsize_to_bsize[tx_size], block_mse_q8,
2185                                 per_px_mean, &block_var);
2186   assert((*block_mse_q8) != UINT_MAX);
2187   uint64_t var_threshold = (uint64_t)(1.8 * qstep * qstep);
2188   if (is_cur_buf_hbd(xd))
2189     block_var = ROUND_POWER_OF_TWO(block_var, (xd->bd - 8) * 2);
2190   // Early prediction of skip block if residual mean and variance are less
2191   // than qstep based threshold
2192   if (((llabs(*per_px_mean) * dc_coeff_scale[tx_size]) < (dc_qstep << 12)) &&
2193       (block_var < var_threshold)) {
2194     // If the normalized mean of residual block is less than the dc qstep and
2195     // the  normalized block variance is less than ac qstep, then the block is
2196     // assumed to be a skip block and its rdcost is updated accordingly.
2197     best_rd_stats->skip_txfm = 1;
2198 
2199     x->plane[plane].eobs[block] = 0;
2200 
2201     if (is_cur_buf_hbd(xd))
2202       *block_sse = ROUND_POWER_OF_TWO((*block_sse), (xd->bd - 8) * 2);
2203 
2204     best_rd_stats->dist = (*block_sse) << 4;
2205     best_rd_stats->sse = best_rd_stats->dist;
2206 
2207     ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
2208     ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
2209     av1_get_entropy_contexts(plane_bsize, &xd->plane[plane], ctxa, ctxl);
2210     ENTROPY_CONTEXT *ta = ctxa;
2211     ENTROPY_CONTEXT *tl = ctxl;
2212     const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
2213     TXB_CTX txb_ctx_tmp;
2214     const PLANE_TYPE plane_type = get_plane_type(plane);
2215     get_txb_ctx(plane_bsize, tx_size, plane, ta, tl, &txb_ctx_tmp);
2216     const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][plane_type]
2217                                   .txb_skip_cost[txb_ctx_tmp.txb_skip_ctx][1];
2218     best_rd_stats->rate = zero_blk_rate;
2219 
2220     best_rd_stats->rdcost =
2221         RDCOST(x->rdmult, best_rd_stats->rate, best_rd_stats->sse);
2222 
2223     x->plane[plane].txb_entropy_ctx[block] = 0;
2224   } else if (block_var < var_threshold) {
2225     // Predict DC only blocks based on residual variance.
2226     // For chroma plane, this early prediction is disabled for intra blocks.
2227     if ((plane == 0) || (plane > 0 && is_inter_block(mbmi))) *dc_only_blk = 1;
2228   }
2229 }
2230 
2231 // Search for the best transform type for a given transform block.
2232 // This function can be used for both inter and intra, both luma and chroma.
search_tx_type(const AV1_COMP * cpi,MACROBLOCK * x,int plane,int block,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,const TXB_CTX * const txb_ctx,FAST_TX_SEARCH_MODE ftxs_mode,int skip_trellis,int64_t ref_best_rd,RD_STATS * best_rd_stats)2233 static void search_tx_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
2234                            int block, int blk_row, int blk_col,
2235                            BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
2236                            const TXB_CTX *const txb_ctx,
2237                            FAST_TX_SEARCH_MODE ftxs_mode, int skip_trellis,
2238                            int64_t ref_best_rd, RD_STATS *best_rd_stats) {
2239   const AV1_COMMON *cm = &cpi->common;
2240   MACROBLOCKD *xd = &x->e_mbd;
2241   MB_MODE_INFO *mbmi = xd->mi[0];
2242   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
2243   int64_t best_rd = INT64_MAX;
2244   uint16_t best_eob = 0;
2245   TX_TYPE best_tx_type = DCT_DCT;
2246   int rate_cost = 0;
2247   // The buffer used to swap dqcoeff in macroblockd_plane so we can keep dqcoeff
2248   // of the best tx_type
2249   DECLARE_ALIGNED(32, tran_low_t, this_dqcoeff[MAX_SB_SQUARE]);
2250   struct macroblock_plane *const p = &x->plane[plane];
2251   tran_low_t *orig_dqcoeff = p->dqcoeff;
2252   tran_low_t *best_dqcoeff = this_dqcoeff;
2253   const int tx_type_map_idx =
2254       plane ? 0 : blk_row * xd->tx_type_map_stride + blk_col;
2255   av1_invalid_rd_stats(best_rd_stats);
2256 
2257   skip_trellis |= !is_trellis_used(cpi->optimize_seg_arr[xd->mi[0]->segment_id],
2258                                    DRY_RUN_NORMAL);
2259 
2260   // Hashing based speed feature for intra block. If the hash of the residue
2261   // is found in the hash table, use the previous RD search results stored in
2262   // the table and terminate early.
2263   TXB_RD_INFO *intra_txb_rd_info = NULL;
2264   uint16_t cur_joint_ctx = 0;
2265   const int is_inter = is_inter_block(mbmi);
2266   const int use_intra_txb_hash =
2267       cpi->sf.tx_sf.use_intra_txb_hash && frame_is_intra_only(cm) &&
2268       !is_inter && plane == 0 && tx_size_wide[tx_size] == tx_size_high[tx_size];
2269   if (use_intra_txb_hash) {
2270     const int mi_row = xd->mi_row;
2271     const int mi_col = xd->mi_col;
2272     const int within_border =
2273         mi_row >= xd->tile.mi_row_start &&
2274         (mi_row + mi_size_high[plane_bsize] < xd->tile.mi_row_end) &&
2275         mi_col >= xd->tile.mi_col_start &&
2276         (mi_col + mi_size_wide[plane_bsize] < xd->tile.mi_col_end);
2277     if (within_border &&
2278         is_intra_hash_match(cpi, x, plane, blk_row, blk_col, plane_bsize,
2279                             tx_size, txb_ctx, &intra_txb_rd_info,
2280                             tx_type_map_idx, &cur_joint_ctx)) {
2281       best_rd_stats->rate = intra_txb_rd_info->rate;
2282       best_rd_stats->dist = intra_txb_rd_info->dist;
2283       best_rd_stats->sse = intra_txb_rd_info->sse;
2284       best_rd_stats->skip_txfm = intra_txb_rd_info->eob == 0;
2285       x->plane[plane].eobs[block] = intra_txb_rd_info->eob;
2286       x->plane[plane].txb_entropy_ctx[block] =
2287           intra_txb_rd_info->txb_entropy_ctx;
2288       best_eob = intra_txb_rd_info->eob;
2289       best_tx_type = intra_txb_rd_info->tx_type;
2290       skip_trellis |= !intra_txb_rd_info->perform_block_coeff_opt;
2291       update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
2292       recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
2293                   txb_ctx, skip_trellis, best_tx_type, 1, &rate_cost, best_eob);
2294       p->dqcoeff = orig_dqcoeff;
2295       return;
2296     }
2297   }
2298 
2299   uint8_t best_txb_ctx = 0;
2300   // txk_allowed = TX_TYPES: >1 tx types are allowed
2301   // txk_allowed < TX_TYPES: only that specific tx type is allowed.
2302   TX_TYPE txk_allowed = TX_TYPES;
2303   int txk_map[TX_TYPES] = {
2304     0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15
2305   };
2306   const int dequant_shift = (is_cur_buf_hbd(xd)) ? xd->bd - 5 : 3;
2307   const int qstep = x->plane[plane].dequant_QTX[1] >> dequant_shift;
2308 
2309   const uint8_t txw = tx_size_wide[tx_size];
2310   const uint8_t txh = tx_size_high[tx_size];
2311   int64_t block_sse;
2312   unsigned int block_mse_q8;
2313   int dc_only_blk = 0;
2314   const bool predict_dc_block =
2315       txfm_params->predict_dc_level && txw != 64 && txh != 64;
2316   int64_t per_px_mean = INT64_MAX;
2317   if (predict_dc_block) {
2318     predict_dc_only_block(x, plane, plane_bsize, tx_size, block, blk_row,
2319                           blk_col, best_rd_stats, &block_sse, &block_mse_q8,
2320                           &per_px_mean, &dc_only_blk);
2321     if (best_rd_stats->skip_txfm == 1) return;
2322   } else {
2323     block_sse = pixel_diff_dist(x, plane, blk_row, blk_col, plane_bsize,
2324                                 txsize_to_bsize[tx_size], &block_mse_q8);
2325     assert(block_mse_q8 != UINT_MAX);
2326   }
2327 
2328   // Bit mask to indicate which transform types are allowed in the RD search.
2329   uint16_t tx_mask;
2330 
2331   // Use DCT_DCT transform for DC only block.
2332   if (dc_only_blk)
2333     tx_mask = 1 << DCT_DCT;
2334   else
2335     tx_mask = get_tx_mask(cpi, x, plane, block, blk_row, blk_col, plane_bsize,
2336                           tx_size, txb_ctx, ftxs_mode, ref_best_rd,
2337                           &txk_allowed, txk_map);
2338   const uint16_t allowed_tx_mask = tx_mask;
2339 
2340   if (is_cur_buf_hbd(xd)) {
2341     block_sse = ROUND_POWER_OF_TWO(block_sse, (xd->bd - 8) * 2);
2342     block_mse_q8 = ROUND_POWER_OF_TWO(block_mse_q8, (xd->bd - 8) * 2);
2343   }
2344   block_sse *= 16;
2345   // Use mse / qstep^2 based threshold logic to take decision of R-D
2346   // optimization of coeffs. For smaller residuals, coeff optimization
2347   // would be helpful. For larger residuals, R-D optimization may not be
2348   // effective.
2349   // TODO(any): Experiment with variance and mean based thresholds
2350   const int perform_block_coeff_opt =
2351       ((uint64_t)block_mse_q8 <=
2352        (uint64_t)txfm_params->coeff_opt_thresholds[0] * qstep * qstep);
2353   skip_trellis |= !perform_block_coeff_opt;
2354 
2355   // Flag to indicate if distortion should be calculated in transform domain or
2356   // not during iterating through transform type candidates.
2357   // Transform domain distortion is accurate for higher residuals.
2358   // TODO(any): Experiment with variance and mean based thresholds
2359   int use_transform_domain_distortion =
2360       (txfm_params->use_transform_domain_distortion > 0) &&
2361       (block_mse_q8 >= txfm_params->tx_domain_dist_threshold) &&
2362       // Any 64-pt transforms only preserves half the coefficients.
2363       // Therefore transform domain distortion is not valid for these
2364       // transform sizes.
2365       (txsize_sqr_up_map[tx_size] != TX_64X64) &&
2366       // Use pixel domain distortion for DC only blocks
2367       !dc_only_blk;
2368   // Flag to indicate if an extra calculation of distortion in the pixel domain
2369   // should be performed at the end, after the best transform type has been
2370   // decided.
2371   int calc_pixel_domain_distortion_final =
2372       txfm_params->use_transform_domain_distortion == 1 &&
2373       use_transform_domain_distortion && x->rd_model != LOW_TXFM_RD;
2374   if (calc_pixel_domain_distortion_final &&
2375       (txk_allowed < TX_TYPES || allowed_tx_mask == 0x0001))
2376     calc_pixel_domain_distortion_final = use_transform_domain_distortion = 0;
2377 
2378   const uint16_t *eobs_ptr = x->plane[plane].eobs;
2379 
2380   TxfmParam txfm_param;
2381   QUANT_PARAM quant_param;
2382   int skip_trellis_based_on_satd[TX_TYPES] = { 0 };
2383   av1_setup_xform(cm, x, tx_size, DCT_DCT, &txfm_param);
2384   av1_setup_quant(tx_size, !skip_trellis,
2385                   skip_trellis ? (USE_B_QUANT_NO_TRELLIS ? AV1_XFORM_QUANT_B
2386                                                          : AV1_XFORM_QUANT_FP)
2387                                : AV1_XFORM_QUANT_FP,
2388                   cpi->oxcf.q_cfg.quant_b_adapt, &quant_param);
2389 
2390   // Iterate through all transform type candidates.
2391   for (int idx = 0; idx < TX_TYPES; ++idx) {
2392     const TX_TYPE tx_type = (TX_TYPE)txk_map[idx];
2393     if (!(allowed_tx_mask & (1 << tx_type))) continue;
2394     txfm_param.tx_type = tx_type;
2395     if (av1_use_qmatrix(&cm->quant_params, xd, mbmi->segment_id)) {
2396       av1_setup_qmatrix(&cm->quant_params, xd, plane, tx_size, tx_type,
2397                         &quant_param);
2398     }
2399     if (plane == 0) xd->tx_type_map[tx_type_map_idx] = tx_type;
2400     RD_STATS this_rd_stats;
2401     av1_invalid_rd_stats(&this_rd_stats);
2402 
2403     if (!dc_only_blk)
2404       av1_xform(x, plane, block, blk_row, blk_col, plane_bsize, &txfm_param);
2405     else
2406       av1_xform_dc_only(x, plane, block, &txfm_param, per_px_mean);
2407 
2408     skip_trellis_based_on_satd[tx_type] = skip_trellis_opt_based_on_satd(
2409         x, &quant_param, plane, block, tx_size, cpi->oxcf.q_cfg.quant_b_adapt,
2410         qstep, txfm_params->coeff_opt_thresholds[1], skip_trellis, dc_only_blk);
2411 
2412     av1_quant(x, plane, block, &txfm_param, &quant_param);
2413 
2414     // Calculate rate cost of quantized coefficients.
2415     if (quant_param.use_optimize_b) {
2416       av1_optimize_b(cpi, x, plane, block, tx_size, tx_type, txb_ctx,
2417                      &rate_cost);
2418     } else {
2419       rate_cost = cost_coeffs(x, plane, block, tx_size, tx_type, txb_ctx,
2420                               cm->features.reduced_tx_set_used);
2421     }
2422 
2423     // If rd cost based on coeff rate alone is already more than best_rd,
2424     // terminate early.
2425     if (RDCOST(x->rdmult, rate_cost, 0) > best_rd) continue;
2426 
2427     // Calculate distortion.
2428     if (eobs_ptr[block] == 0) {
2429       // When eob is 0, pixel domain distortion is more efficient and accurate.
2430       this_rd_stats.dist = this_rd_stats.sse = block_sse;
2431     } else if (dc_only_blk) {
2432       this_rd_stats.sse = block_sse;
2433       this_rd_stats.dist = dist_block_px_domain(
2434           cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
2435     } else if (use_transform_domain_distortion) {
2436       dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist,
2437                            &this_rd_stats.sse);
2438     } else {
2439       int64_t sse_diff = INT64_MAX;
2440       // high_energy threshold assumes that every pixel within a txfm block
2441       // has a residue energy of at least 25% of the maximum, i.e. 128 * 128
2442       // for 8 bit.
2443       const int64_t high_energy_thresh =
2444           ((int64_t)128 * 128 * tx_size_2d[tx_size]);
2445       const int is_high_energy = (block_sse >= high_energy_thresh);
2446       if (tx_size == TX_64X64 || is_high_energy) {
2447         // Because 3 out 4 quadrants of transform coefficients are forced to
2448         // zero, the inverse transform has a tendency to overflow. sse_diff
2449         // is effectively the energy of those 3 quadrants, here we use it
2450         // to decide if we should do pixel domain distortion. If the energy
2451         // is mostly in first quadrant, then it is unlikely that we have
2452         // overflow issue in inverse transform.
2453         dist_block_tx_domain(x, plane, block, tx_size, &this_rd_stats.dist,
2454                              &this_rd_stats.sse);
2455         sse_diff = block_sse - this_rd_stats.sse;
2456       }
2457       if (tx_size != TX_64X64 || !is_high_energy ||
2458           (sse_diff * 2) < this_rd_stats.sse) {
2459         const int64_t tx_domain_dist = this_rd_stats.dist;
2460         this_rd_stats.dist = dist_block_px_domain(
2461             cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
2462         // For high energy blocks, occasionally, the pixel domain distortion
2463         // can be artificially low due to clamping at reconstruction stage
2464         // even when inverse transform output is hugely different from the
2465         // actual residue.
2466         if (is_high_energy && this_rd_stats.dist < tx_domain_dist)
2467           this_rd_stats.dist = tx_domain_dist;
2468       } else {
2469         assert(sse_diff < INT64_MAX);
2470         this_rd_stats.dist += sse_diff;
2471       }
2472       this_rd_stats.sse = block_sse;
2473     }
2474 
2475     this_rd_stats.rate = rate_cost;
2476 
2477     const int64_t rd =
2478         RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
2479 
2480     if (rd < best_rd) {
2481       best_rd = rd;
2482       *best_rd_stats = this_rd_stats;
2483       best_tx_type = tx_type;
2484       best_txb_ctx = x->plane[plane].txb_entropy_ctx[block];
2485       best_eob = x->plane[plane].eobs[block];
2486       // Swap dqcoeff buffers
2487       tran_low_t *const tmp_dqcoeff = best_dqcoeff;
2488       best_dqcoeff = p->dqcoeff;
2489       p->dqcoeff = tmp_dqcoeff;
2490     }
2491 
2492 #if CONFIG_COLLECT_RD_STATS == 1
2493     if (plane == 0) {
2494       PrintTransformUnitStats(cpi, x, &this_rd_stats, blk_row, blk_col,
2495                               plane_bsize, tx_size, tx_type, rd);
2496     }
2497 #endif  // CONFIG_COLLECT_RD_STATS == 1
2498 
2499 #if COLLECT_TX_SIZE_DATA
2500     // Generate small sample to restrict output size.
2501     static unsigned int seed = 21743;
2502     if (lcg_rand16(&seed) % 200 == 0) {
2503       FILE *fp = NULL;
2504 
2505       if (within_border) {
2506         fp = fopen(av1_tx_size_data_output_file, "a");
2507       }
2508 
2509       if (fp) {
2510         // Transform info and RD
2511         const int txb_w = tx_size_wide[tx_size];
2512         const int txb_h = tx_size_high[tx_size];
2513 
2514         // Residue signal.
2515         const int diff_stride = block_size_wide[plane_bsize];
2516         struct macroblock_plane *const p = &x->plane[plane];
2517         const int16_t *src_diff =
2518             &p->src_diff[(blk_row * diff_stride + blk_col) * 4];
2519 
2520         for (int r = 0; r < txb_h; ++r) {
2521           for (int c = 0; c < txb_w; ++c) {
2522             fprintf(fp, "%d,", src_diff[c]);
2523           }
2524           src_diff += diff_stride;
2525         }
2526 
2527         fprintf(fp, "%d,%d,%d,%" PRId64, txb_w, txb_h, tx_type, rd);
2528         fprintf(fp, "\n");
2529         fclose(fp);
2530       }
2531     }
2532 #endif  // COLLECT_TX_SIZE_DATA
2533 
2534     // If the current best RD cost is much worse than the reference RD cost,
2535     // terminate early.
2536     if (cpi->sf.tx_sf.adaptive_txb_search_level) {
2537       if ((best_rd - (best_rd >> cpi->sf.tx_sf.adaptive_txb_search_level)) >
2538           ref_best_rd) {
2539         break;
2540       }
2541     }
2542 
2543     // Terminate transform type search if the block has been quantized to
2544     // all zero.
2545     if (cpi->sf.tx_sf.tx_type_search.skip_tx_search && !best_eob) break;
2546   }
2547 
2548   assert(best_rd != INT64_MAX);
2549 
2550   best_rd_stats->skip_txfm = best_eob == 0;
2551   if (plane == 0) update_txk_array(xd, blk_row, blk_col, tx_size, best_tx_type);
2552   x->plane[plane].txb_entropy_ctx[block] = best_txb_ctx;
2553   x->plane[plane].eobs[block] = best_eob;
2554   skip_trellis = skip_trellis_based_on_satd[best_tx_type];
2555 
2556   // Point dqcoeff to the quantized coefficients corresponding to the best
2557   // transform type, then we can skip transform and quantization, e.g. in the
2558   // final pixel domain distortion calculation and recon_intra().
2559   p->dqcoeff = best_dqcoeff;
2560 
2561   if (calc_pixel_domain_distortion_final && best_eob) {
2562     best_rd_stats->dist = dist_block_px_domain(
2563         cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size);
2564     best_rd_stats->sse = block_sse;
2565   }
2566 
2567   if (intra_txb_rd_info != NULL) {
2568     intra_txb_rd_info->valid = 1;
2569     intra_txb_rd_info->entropy_context = cur_joint_ctx;
2570     intra_txb_rd_info->rate = best_rd_stats->rate;
2571     intra_txb_rd_info->dist = best_rd_stats->dist;
2572     intra_txb_rd_info->sse = best_rd_stats->sse;
2573     intra_txb_rd_info->eob = best_eob;
2574     intra_txb_rd_info->txb_entropy_ctx = best_txb_ctx;
2575     intra_txb_rd_info->perform_block_coeff_opt = perform_block_coeff_opt;
2576     if (plane == 0) intra_txb_rd_info->tx_type = best_tx_type;
2577   }
2578 
2579   // Intra mode needs decoded pixels such that the next transform block
2580   // can use them for prediction.
2581   recon_intra(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
2582               txb_ctx, skip_trellis, best_tx_type, 0, &rate_cost, best_eob);
2583   p->dqcoeff = orig_dqcoeff;
2584 }
2585 
2586 // Pick transform type for a luma transform block of tx_size. Note this function
2587 // is used only for inter-predicted blocks.
tx_type_rd(const AV1_COMP * cpi,MACROBLOCK * x,TX_SIZE tx_size,int blk_row,int blk_col,int block,int plane_bsize,TXB_CTX * txb_ctx,RD_STATS * rd_stats,FAST_TX_SEARCH_MODE ftxs_mode,int64_t ref_rdcost,TXB_RD_INFO * rd_info_array)2588 static AOM_INLINE void tx_type_rd(const AV1_COMP *cpi, MACROBLOCK *x,
2589                                   TX_SIZE tx_size, int blk_row, int blk_col,
2590                                   int block, int plane_bsize, TXB_CTX *txb_ctx,
2591                                   RD_STATS *rd_stats,
2592                                   FAST_TX_SEARCH_MODE ftxs_mode,
2593                                   int64_t ref_rdcost,
2594                                   TXB_RD_INFO *rd_info_array) {
2595   const struct macroblock_plane *const p = &x->plane[0];
2596   const uint16_t cur_joint_ctx =
2597       (txb_ctx->dc_sign_ctx << 8) + txb_ctx->txb_skip_ctx;
2598   MACROBLOCKD *xd = &x->e_mbd;
2599   assert(is_inter_block(xd->mi[0]));
2600   const int tx_type_map_idx = blk_row * xd->tx_type_map_stride + blk_col;
2601   // Look up RD and terminate early in case when we've already processed exactly
2602   // the same residue with exactly the same entropy context.
2603   if (rd_info_array != NULL && rd_info_array->valid &&
2604       rd_info_array->entropy_context == cur_joint_ctx) {
2605     xd->tx_type_map[tx_type_map_idx] = rd_info_array->tx_type;
2606     const TX_TYPE ref_tx_type =
2607         av1_get_tx_type(&x->e_mbd, get_plane_type(0), blk_row, blk_col, tx_size,
2608                         cpi->common.features.reduced_tx_set_used);
2609     if (ref_tx_type == rd_info_array->tx_type) {
2610       rd_stats->rate += rd_info_array->rate;
2611       rd_stats->dist += rd_info_array->dist;
2612       rd_stats->sse += rd_info_array->sse;
2613       rd_stats->skip_txfm &= rd_info_array->eob == 0;
2614       p->eobs[block] = rd_info_array->eob;
2615       p->txb_entropy_ctx[block] = rd_info_array->txb_entropy_ctx;
2616       return;
2617     }
2618   }
2619 
2620   RD_STATS this_rd_stats;
2621   const int skip_trellis = 0;
2622   search_tx_type(cpi, x, 0, block, blk_row, blk_col, plane_bsize, tx_size,
2623                  txb_ctx, ftxs_mode, skip_trellis, ref_rdcost, &this_rd_stats);
2624 
2625   av1_merge_rd_stats(rd_stats, &this_rd_stats);
2626 
2627   // Save RD results for possible reuse in future.
2628   if (rd_info_array != NULL) {
2629     rd_info_array->valid = 1;
2630     rd_info_array->entropy_context = cur_joint_ctx;
2631     rd_info_array->rate = this_rd_stats.rate;
2632     rd_info_array->dist = this_rd_stats.dist;
2633     rd_info_array->sse = this_rd_stats.sse;
2634     rd_info_array->eob = p->eobs[block];
2635     rd_info_array->txb_entropy_ctx = p->txb_entropy_ctx[block];
2636     rd_info_array->tx_type = xd->tx_type_map[tx_type_map_idx];
2637   }
2638 }
2639 
try_tx_block_no_split(const AV1_COMP * cpi,MACROBLOCK * x,int blk_row,int blk_col,int block,TX_SIZE tx_size,int depth,BLOCK_SIZE plane_bsize,const ENTROPY_CONTEXT * ta,const ENTROPY_CONTEXT * tl,int txfm_partition_ctx,RD_STATS * rd_stats,int64_t ref_best_rd,FAST_TX_SEARCH_MODE ftxs_mode,TXB_RD_INFO_NODE * rd_info_node,TxCandidateInfo * no_split)2640 static AOM_INLINE void try_tx_block_no_split(
2641     const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2642     TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize,
2643     const ENTROPY_CONTEXT *ta, const ENTROPY_CONTEXT *tl,
2644     int txfm_partition_ctx, RD_STATS *rd_stats, int64_t ref_best_rd,
2645     FAST_TX_SEARCH_MODE ftxs_mode, TXB_RD_INFO_NODE *rd_info_node,
2646     TxCandidateInfo *no_split) {
2647   MACROBLOCKD *const xd = &x->e_mbd;
2648   MB_MODE_INFO *const mbmi = xd->mi[0];
2649   struct macroblock_plane *const p = &x->plane[0];
2650   const int bw = mi_size_wide[plane_bsize];
2651   const ENTROPY_CONTEXT *const pta = ta + blk_col;
2652   const ENTROPY_CONTEXT *const ptl = tl + blk_row;
2653   const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
2654   TXB_CTX txb_ctx;
2655   get_txb_ctx(plane_bsize, tx_size, 0, pta, ptl, &txb_ctx);
2656   const int zero_blk_rate = x->coeff_costs.coeff_costs[txs_ctx][PLANE_TYPE_Y]
2657                                 .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
2658   rd_stats->zero_rate = zero_blk_rate;
2659   const int index = av1_get_txb_size_index(plane_bsize, blk_row, blk_col);
2660   mbmi->inter_tx_size[index] = tx_size;
2661   tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
2662              rd_stats, ftxs_mode, ref_best_rd,
2663              rd_info_node != NULL ? rd_info_node->rd_info_array : NULL);
2664   assert(rd_stats->rate < INT_MAX);
2665 
2666   const int pick_skip_txfm =
2667       !xd->lossless[mbmi->segment_id] &&
2668       (rd_stats->skip_txfm == 1 ||
2669        RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
2670            RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse));
2671   if (pick_skip_txfm) {
2672 #if CONFIG_RD_DEBUG
2673     update_txb_coeff_cost(rd_stats, 0, tx_size, blk_row, blk_col,
2674                           zero_blk_rate - rd_stats->rate);
2675 #endif  // CONFIG_RD_DEBUG
2676     rd_stats->rate = zero_blk_rate;
2677     rd_stats->dist = rd_stats->sse;
2678     p->eobs[block] = 0;
2679     update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
2680   }
2681   rd_stats->skip_txfm = pick_skip_txfm;
2682   set_blk_skip(x->txfm_search_info.blk_skip, 0, blk_row * bw + blk_col,
2683                pick_skip_txfm);
2684 
2685   if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
2686     rd_stats->rate += x->mode_costs.txfm_partition_cost[txfm_partition_ctx][0];
2687 
2688   no_split->rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
2689   no_split->txb_entropy_ctx = p->txb_entropy_ctx[block];
2690   no_split->tx_type =
2691       xd->tx_type_map[blk_row * xd->tx_type_map_stride + blk_col];
2692 }
2693 
try_tx_block_split(const AV1_COMP * cpi,MACROBLOCK * x,int blk_row,int blk_col,int block,TX_SIZE tx_size,int depth,BLOCK_SIZE plane_bsize,ENTROPY_CONTEXT * ta,ENTROPY_CONTEXT * tl,TXFM_CONTEXT * tx_above,TXFM_CONTEXT * tx_left,int txfm_partition_ctx,int64_t no_split_rd,int64_t ref_best_rd,FAST_TX_SEARCH_MODE ftxs_mode,TXB_RD_INFO_NODE * rd_info_node,RD_STATS * split_rd_stats)2694 static AOM_INLINE void try_tx_block_split(
2695     const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2696     TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
2697     ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
2698     int txfm_partition_ctx, int64_t no_split_rd, int64_t ref_best_rd,
2699     FAST_TX_SEARCH_MODE ftxs_mode, TXB_RD_INFO_NODE *rd_info_node,
2700     RD_STATS *split_rd_stats) {
2701   assert(tx_size < TX_SIZES_ALL);
2702   MACROBLOCKD *const xd = &x->e_mbd;
2703   const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
2704   const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);
2705   const int txb_width = tx_size_wide_unit[tx_size];
2706   const int txb_height = tx_size_high_unit[tx_size];
2707   // Transform size after splitting current block.
2708   const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
2709   const int sub_txb_width = tx_size_wide_unit[sub_txs];
2710   const int sub_txb_height = tx_size_high_unit[sub_txs];
2711   const int sub_step = sub_txb_width * sub_txb_height;
2712   const int nblks = (txb_height / sub_txb_height) * (txb_width / sub_txb_width);
2713   assert(nblks > 0);
2714   av1_init_rd_stats(split_rd_stats);
2715   split_rd_stats->rate =
2716       x->mode_costs.txfm_partition_cost[txfm_partition_ctx][1];
2717 
2718   for (int r = 0, blk_idx = 0; r < txb_height; r += sub_txb_height) {
2719     for (int c = 0; c < txb_width; c += sub_txb_width, ++blk_idx) {
2720       assert(blk_idx < 4);
2721       const int offsetr = blk_row + r;
2722       const int offsetc = blk_col + c;
2723       if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;
2724 
2725       RD_STATS this_rd_stats;
2726       int this_cost_valid = 1;
2727       select_tx_block(
2728           cpi, x, offsetr, offsetc, block, sub_txs, depth + 1, plane_bsize, ta,
2729           tl, tx_above, tx_left, &this_rd_stats, no_split_rd / nblks,
2730           ref_best_rd - split_rd_stats->rdcost, &this_cost_valid, ftxs_mode,
2731           (rd_info_node != NULL) ? rd_info_node->children[blk_idx] : NULL);
2732       if (!this_cost_valid) {
2733         split_rd_stats->rdcost = INT64_MAX;
2734         return;
2735       }
2736       av1_merge_rd_stats(split_rd_stats, &this_rd_stats);
2737       split_rd_stats->rdcost =
2738           RDCOST(x->rdmult, split_rd_stats->rate, split_rd_stats->dist);
2739       if (split_rd_stats->rdcost > ref_best_rd) {
2740         split_rd_stats->rdcost = INT64_MAX;
2741         return;
2742       }
2743       block += sub_step;
2744     }
2745   }
2746 }
2747 
// Returns the variance of `num` samples given their mean and the sum of their
// squares: E[x^2] - mean^2. E[x^2] is rounded to float before the
// subtraction.
static float get_var(float mean, double x2_sum, int num) {
  const float mean_of_squares = (float)(x2_sum / num);
  return mean_of_squares - mean * mean;
}
2753 
get_blk_var_dev(const int16_t * data,int stride,int bw,int bh,float * dev_of_mean,float * var_of_vars)2754 static AOM_INLINE void get_blk_var_dev(const int16_t *data, int stride, int bw,
2755                                        int bh, float *dev_of_mean,
2756                                        float *var_of_vars) {
2757   const int16_t *const data_ptr = &data[0];
2758   const int subh = (bh >= bw) ? (bh >> 1) : bh;
2759   const int subw = (bw >= bh) ? (bw >> 1) : bw;
2760   const int num = bw * bh;
2761   const int sub_num = subw * subh;
2762   int total_x_sum = 0;
2763   int64_t total_x2_sum = 0;
2764   int blk_idx = 0;
2765   float var_sum = 0.0f;
2766   float mean_sum = 0.0f;
2767   double var2_sum = 0.0f;
2768   double mean2_sum = 0.0f;
2769 
2770   for (int row = 0; row < bh; row += subh) {
2771     for (int col = 0; col < bw; col += subw) {
2772       int x_sum;
2773       int64_t x2_sum;
2774       aom_get_blk_sse_sum(data_ptr + row * stride + col, stride, subw, subh,
2775                           &x_sum, &x2_sum);
2776       total_x_sum += x_sum;
2777       total_x2_sum += x2_sum;
2778 
2779       aom_clear_system_state();
2780       const float mean = (float)x_sum / sub_num;
2781       const float var = get_var(mean, (double)x2_sum, sub_num);
2782       mean_sum += mean;
2783       mean2_sum += (double)(mean * mean);
2784       var_sum += var;
2785       var2_sum += var * var;
2786       blk_idx++;
2787     }
2788   }
2789 
2790   const float lvl0_mean = (float)total_x_sum / num;
2791   const float block_var = get_var(lvl0_mean, (double)total_x2_sum, num);
2792   mean_sum += lvl0_mean;
2793   mean2_sum += (double)(lvl0_mean * lvl0_mean);
2794   var_sum += block_var;
2795   var2_sum += block_var * block_var;
2796   const float av_mean = mean_sum / 5;
2797 
2798   if (blk_idx > 1) {
2799     // Deviation of means.
2800     *dev_of_mean = get_dev(av_mean, mean2_sum, (blk_idx + 1));
2801     // Variance of variances.
2802     const float mean_var = var_sum / (blk_idx + 1);
2803     *var_of_vars = get_var(mean_var, var2_sum, (blk_idx + 1));
2804   }
2805 }
2806 
// Heuristically prunes the split and/or no-split transform partition
// candidates for the luma transform block at (blk_row, blk_col). It compares
// how much the means and variances of the residual's sub-blocks differ
// against the DC/AC dequantization steps:
//   - small spread of both -> *try_split is cleared;
//   - large spread of both -> *try_no_split is cleared.
// pruning_level indexes 4-entry threshold tables, so it must be in [0, 3].
// select_tx_block() only calls this for levels > 0. Larger scale factors make
// both prunings less aggressive.
static void prune_tx_split_no_split(MACROBLOCK *x, BLOCK_SIZE bsize,
                                    int blk_row, int blk_col, TX_SIZE tx_size,
                                    int *try_no_split, int *try_split,
                                    int pruning_level) {
  static const int no_split_thresh_scales[4] = { 0, 24, 8, 8 };
  static const int split_thresh_scales[4] = { 0, 24, 10, 8 };
  const struct macroblock_plane *const luma = &x->plane[0];
  const int stride = block_size_wide[bsize];
  // blk_row/blk_col are scaled by 4 to locate the residual samples.
  const int16_t *const residual =
      luma->src_diff + 4 * blk_row * stride + 4 * blk_col;

  aom_clear_system_state();
  float dev_of_means = 0.0f;
  float var_of_vars = 0.0f;
  // Deviation of means and variance of variances over the block and its
  // sub-blocks.
  get_blk_var_dev(residual, stride, tx_size_wide[tx_size],
                  tx_size_high[tx_size], &dev_of_means, &var_of_vars);

  const int dc_q = luma->dequant_QTX[0] >> 3;
  const int ac_q = luma->dequant_QTX[1] >> 3;
  const int split_scale = split_thresh_scales[pruning_level];
  const int no_split_scale = no_split_thresh_scales[pruning_level];

  // Homogeneous residual: splitting is unlikely to pay off.
  if (dev_of_means <= dc_q && split_scale * var_of_vars <= ac_q * ac_q) {
    *try_split = 0;
  }
  // Strongly heterogeneous residual: a single transform is unlikely to win.
  if (dev_of_means > no_split_scale * dc_q &&
      var_of_vars > no_split_scale * ac_q * ac_q) {
    *try_no_split = 0;
  }
}
2839 
2840 // Search for the best transform partition(recursive)/type for a given
2841 // inter-predicted luma block. The obtained transform selection will be saved
2842 // in xd->mi[0], the corresponding RD stats will be saved in rd_stats.
select_tx_block(const AV1_COMP * cpi,MACROBLOCK * x,int blk_row,int blk_col,int block,TX_SIZE tx_size,int depth,BLOCK_SIZE plane_bsize,ENTROPY_CONTEXT * ta,ENTROPY_CONTEXT * tl,TXFM_CONTEXT * tx_above,TXFM_CONTEXT * tx_left,RD_STATS * rd_stats,int64_t prev_level_rd,int64_t ref_best_rd,int * is_cost_valid,FAST_TX_SEARCH_MODE ftxs_mode,TXB_RD_INFO_NODE * rd_info_node)2843 static AOM_INLINE void select_tx_block(
2844     const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
2845     TX_SIZE tx_size, int depth, BLOCK_SIZE plane_bsize, ENTROPY_CONTEXT *ta,
2846     ENTROPY_CONTEXT *tl, TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left,
2847     RD_STATS *rd_stats, int64_t prev_level_rd, int64_t ref_best_rd,
2848     int *is_cost_valid, FAST_TX_SEARCH_MODE ftxs_mode,
2849     TXB_RD_INFO_NODE *rd_info_node) {
2850   assert(tx_size < TX_SIZES_ALL);
2851   av1_init_rd_stats(rd_stats);
2852   if (ref_best_rd < 0) {
2853     *is_cost_valid = 0;
2854     return;
2855   }
2856 
2857   MACROBLOCKD *const xd = &x->e_mbd;
2858   assert(blk_row < max_block_high(xd, plane_bsize, 0) &&
2859          blk_col < max_block_wide(xd, plane_bsize, 0));
2860   MB_MODE_INFO *const mbmi = xd->mi[0];
2861   const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
2862                                          mbmi->bsize, tx_size);
2863   struct macroblock_plane *const p = &x->plane[0];
2864 
2865   int try_no_split = (cpi->oxcf.txfm_cfg.enable_tx64 ||
2866                       txsize_sqr_up_map[tx_size] != TX_64X64) &&
2867                      (cpi->oxcf.txfm_cfg.enable_rect_tx ||
2868                       tx_size_wide[tx_size] == tx_size_high[tx_size]);
2869   int try_split = tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH;
2870   TxCandidateInfo no_split = { INT64_MAX, 0, TX_TYPES };
2871 
2872   // Prune tx_split and no-split based on sub-block properties.
2873   if (tx_size != TX_4X4 && try_split == 1 && try_no_split == 1 &&
2874       cpi->sf.tx_sf.prune_tx_size_level > 0) {
2875     prune_tx_split_no_split(x, plane_bsize, blk_row, blk_col, tx_size,
2876                             &try_no_split, &try_split,
2877                             cpi->sf.tx_sf.prune_tx_size_level);
2878   }
2879 
2880   // Try using current block as a single transform block without split.
2881   if (try_no_split) {
2882     try_tx_block_no_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
2883                           plane_bsize, ta, tl, ctx, rd_stats, ref_best_rd,
2884                           ftxs_mode, rd_info_node, &no_split);
2885 
2886     // Speed features for early termination.
2887     const int search_level = cpi->sf.tx_sf.adaptive_txb_search_level;
2888     if (search_level) {
2889       if ((no_split.rd - (no_split.rd >> (1 + search_level))) > ref_best_rd) {
2890         *is_cost_valid = 0;
2891         return;
2892       }
2893       if (no_split.rd - (no_split.rd >> (2 + search_level)) > prev_level_rd) {
2894         try_split = 0;
2895       }
2896     }
2897     if (cpi->sf.tx_sf.txb_split_cap) {
2898       if (p->eobs[block] == 0) try_split = 0;
2899     }
2900   }
2901 
2902   // ML based speed feature to skip searching for split transform blocks.
2903   if (x->e_mbd.bd == 8 && try_split &&
2904       !(ref_best_rd == INT64_MAX && no_split.rd == INT64_MAX)) {
2905     const int threshold = cpi->sf.tx_sf.tx_type_search.ml_tx_split_thresh;
2906     if (threshold >= 0) {
2907       const int split_score =
2908           ml_predict_tx_split(x, plane_bsize, blk_row, blk_col, tx_size);
2909       if (split_score < -threshold) try_split = 0;
2910     }
2911   }
2912 
2913   RD_STATS split_rd_stats;
2914   split_rd_stats.rdcost = INT64_MAX;
2915   // Try splitting current block into smaller transform blocks.
2916   if (try_split) {
2917     try_tx_block_split(cpi, x, blk_row, blk_col, block, tx_size, depth,
2918                        plane_bsize, ta, tl, tx_above, tx_left, ctx, no_split.rd,
2919                        AOMMIN(no_split.rd, ref_best_rd), ftxs_mode,
2920                        rd_info_node, &split_rd_stats);
2921   }
2922 
2923   if (no_split.rd < split_rd_stats.rdcost) {
2924     ENTROPY_CONTEXT *pta = ta + blk_col;
2925     ENTROPY_CONTEXT *ptl = tl + blk_row;
2926     p->txb_entropy_ctx[block] = no_split.txb_entropy_ctx;
2927     av1_set_txb_context(x, 0, block, tx_size, pta, ptl);
2928     txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
2929                           tx_size);
2930     for (int idy = 0; idy < tx_size_high_unit[tx_size]; ++idy) {
2931       for (int idx = 0; idx < tx_size_wide_unit[tx_size]; ++idx) {
2932         const int index =
2933             av1_get_txb_size_index(plane_bsize, blk_row + idy, blk_col + idx);
2934         mbmi->inter_tx_size[index] = tx_size;
2935       }
2936     }
2937     mbmi->tx_size = tx_size;
2938     update_txk_array(xd, blk_row, blk_col, tx_size, no_split.tx_type);
2939     const int bw = mi_size_wide[plane_bsize];
2940     set_blk_skip(x->txfm_search_info.blk_skip, 0, blk_row * bw + blk_col,
2941                  rd_stats->skip_txfm);
2942   } else {
2943     *rd_stats = split_rd_stats;
2944     if (split_rd_stats.rdcost == INT64_MAX) *is_cost_valid = 0;
2945   }
2946 }
2947 
choose_largest_tx_size(const AV1_COMP * const cpi,MACROBLOCK * x,RD_STATS * rd_stats,int64_t ref_best_rd,BLOCK_SIZE bs)2948 static AOM_INLINE void choose_largest_tx_size(const AV1_COMP *const cpi,
2949                                               MACROBLOCK *x, RD_STATS *rd_stats,
2950                                               int64_t ref_best_rd,
2951                                               BLOCK_SIZE bs) {
2952   MACROBLOCKD *const xd = &x->e_mbd;
2953   MB_MODE_INFO *const mbmi = xd->mi[0];
2954   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
2955   mbmi->tx_size = tx_size_from_tx_mode(bs, txfm_params->tx_mode_search_type);
2956 
2957   // If tx64 is not enabled, we need to go down to the next available size
2958   if (!cpi->oxcf.txfm_cfg.enable_tx64 && cpi->oxcf.txfm_cfg.enable_rect_tx) {
2959     static const TX_SIZE tx_size_max_32[TX_SIZES_ALL] = {
2960       TX_4X4,    // 4x4 transform
2961       TX_8X8,    // 8x8 transform
2962       TX_16X16,  // 16x16 transform
2963       TX_32X32,  // 32x32 transform
2964       TX_32X32,  // 64x64 transform
2965       TX_4X8,    // 4x8 transform
2966       TX_8X4,    // 8x4 transform
2967       TX_8X16,   // 8x16 transform
2968       TX_16X8,   // 16x8 transform
2969       TX_16X32,  // 16x32 transform
2970       TX_32X16,  // 32x16 transform
2971       TX_32X32,  // 32x64 transform
2972       TX_32X32,  // 64x32 transform
2973       TX_4X16,   // 4x16 transform
2974       TX_16X4,   // 16x4 transform
2975       TX_8X32,   // 8x32 transform
2976       TX_32X8,   // 32x8 transform
2977       TX_16X32,  // 16x64 transform
2978       TX_32X16,  // 64x16 transform
2979     };
2980     mbmi->tx_size = tx_size_max_32[mbmi->tx_size];
2981   } else if (cpi->oxcf.txfm_cfg.enable_tx64 &&
2982              !cpi->oxcf.txfm_cfg.enable_rect_tx) {
2983     static const TX_SIZE tx_size_max_square[TX_SIZES_ALL] = {
2984       TX_4X4,    // 4x4 transform
2985       TX_8X8,    // 8x8 transform
2986       TX_16X16,  // 16x16 transform
2987       TX_32X32,  // 32x32 transform
2988       TX_64X64,  // 64x64 transform
2989       TX_4X4,    // 4x8 transform
2990       TX_4X4,    // 8x4 transform
2991       TX_8X8,    // 8x16 transform
2992       TX_8X8,    // 16x8 transform
2993       TX_16X16,  // 16x32 transform
2994       TX_16X16,  // 32x16 transform
2995       TX_32X32,  // 32x64 transform
2996       TX_32X32,  // 64x32 transform
2997       TX_4X4,    // 4x16 transform
2998       TX_4X4,    // 16x4 transform
2999       TX_8X8,    // 8x32 transform
3000       TX_8X8,    // 32x8 transform
3001       TX_16X16,  // 16x64 transform
3002       TX_16X16,  // 64x16 transform
3003     };
3004     mbmi->tx_size = tx_size_max_square[mbmi->tx_size];
3005   } else if (!cpi->oxcf.txfm_cfg.enable_tx64 &&
3006              !cpi->oxcf.txfm_cfg.enable_rect_tx) {
3007     static const TX_SIZE tx_size_max_32_square[TX_SIZES_ALL] = {
3008       TX_4X4,    // 4x4 transform
3009       TX_8X8,    // 8x8 transform
3010       TX_16X16,  // 16x16 transform
3011       TX_32X32,  // 32x32 transform
3012       TX_32X32,  // 64x64 transform
3013       TX_4X4,    // 4x8 transform
3014       TX_4X4,    // 8x4 transform
3015       TX_8X8,    // 8x16 transform
3016       TX_8X8,    // 16x8 transform
3017       TX_16X16,  // 16x32 transform
3018       TX_16X16,  // 32x16 transform
3019       TX_32X32,  // 32x64 transform
3020       TX_32X32,  // 64x32 transform
3021       TX_4X4,    // 4x16 transform
3022       TX_4X4,    // 16x4 transform
3023       TX_8X8,    // 8x32 transform
3024       TX_8X8,    // 32x8 transform
3025       TX_16X16,  // 16x64 transform
3026       TX_16X16,  // 64x16 transform
3027     };
3028 
3029     mbmi->tx_size = tx_size_max_32_square[mbmi->tx_size];
3030   }
3031 
3032   const int skip_ctx = av1_get_skip_txfm_context(xd);
3033   const int no_skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][0];
3034   const int skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][1];
3035   // Skip RDcost is used only for Inter blocks
3036   const int64_t skip_txfm_rd =
3037       is_inter_block(mbmi) ? RDCOST(x->rdmult, skip_txfm_rate, 0) : INT64_MAX;
3038   const int64_t no_skip_txfm_rd = RDCOST(x->rdmult, no_skip_txfm_rate, 0);
3039   const int skip_trellis = 0;
3040   av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
3041                        AOMMIN(no_skip_txfm_rd, skip_txfm_rd), AOM_PLANE_Y, bs,
3042                        mbmi->tx_size, FTXS_NONE, skip_trellis);
3043 }
3044 
choose_smallest_tx_size(const AV1_COMP * const cpi,MACROBLOCK * x,RD_STATS * rd_stats,int64_t ref_best_rd,BLOCK_SIZE bs)3045 static AOM_INLINE void choose_smallest_tx_size(const AV1_COMP *const cpi,
3046                                                MACROBLOCK *x,
3047                                                RD_STATS *rd_stats,
3048                                                int64_t ref_best_rd,
3049                                                BLOCK_SIZE bs) {
3050   MACROBLOCKD *const xd = &x->e_mbd;
3051   MB_MODE_INFO *const mbmi = xd->mi[0];
3052 
3053   mbmi->tx_size = TX_4X4;
3054   // TODO(any) : Pass this_rd based on skip/non-skip cost
3055   const int skip_trellis = 0;
3056   av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd, 0, 0, bs, mbmi->tx_size,
3057                        FTXS_NONE, skip_trellis);
3058 }
3059 
3060 // Search for the best uniform transform size and type for current coding block.
choose_tx_size_type_from_rd(const AV1_COMP * const cpi,MACROBLOCK * x,RD_STATS * rd_stats,int64_t ref_best_rd,BLOCK_SIZE bs)3061 static AOM_INLINE void choose_tx_size_type_from_rd(const AV1_COMP *const cpi,
3062                                                    MACROBLOCK *x,
3063                                                    RD_STATS *rd_stats,
3064                                                    int64_t ref_best_rd,
3065                                                    BLOCK_SIZE bs) {
3066   av1_invalid_rd_stats(rd_stats);
3067 
3068   MACROBLOCKD *const xd = &x->e_mbd;
3069   MB_MODE_INFO *const mbmi = xd->mi[0];
3070   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3071   const TX_SIZE max_rect_tx_size = max_txsize_rect_lookup[bs];
3072   const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT;
3073   int start_tx;
3074   // The split depth can be at most MAX_TX_DEPTH, so the init_depth controls
3075   // how many times of splitting is allowed during the RD search.
3076   int init_depth;
3077 
3078   if (tx_select) {
3079     start_tx = max_rect_tx_size;
3080     init_depth = get_search_init_depth(mi_size_wide[bs], mi_size_high[bs],
3081                                        is_inter_block(mbmi), &cpi->sf,
3082                                        txfm_params->tx_size_search_method);
3083   } else {
3084     const TX_SIZE chosen_tx_size =
3085         tx_size_from_tx_mode(bs, txfm_params->tx_mode_search_type);
3086     start_tx = chosen_tx_size;
3087     init_depth = MAX_TX_DEPTH;
3088   }
3089 
3090   const int skip_trellis = 0;
3091   uint8_t best_txk_type_map[MAX_MIB_SIZE * MAX_MIB_SIZE];
3092   uint8_t best_blk_skip[MAX_MIB_SIZE * MAX_MIB_SIZE];
3093   TX_SIZE best_tx_size = max_rect_tx_size;
3094   int64_t best_rd = INT64_MAX;
3095   const int num_blks = bsize_to_num_blk(bs);
3096   x->rd_model = FULL_TXFM_RD;
3097   int64_t rd[MAX_TX_DEPTH + 1] = { INT64_MAX, INT64_MAX, INT64_MAX };
3098   TxfmSearchInfo *txfm_info = &x->txfm_search_info;
3099   for (int tx_size = start_tx, depth = init_depth; depth <= MAX_TX_DEPTH;
3100        depth++, tx_size = sub_tx_size_map[tx_size]) {
3101     if ((!cpi->oxcf.txfm_cfg.enable_tx64 &&
3102          txsize_sqr_up_map[tx_size] == TX_64X64) ||
3103         (!cpi->oxcf.txfm_cfg.enable_rect_tx &&
3104          tx_size_wide[tx_size] != tx_size_high[tx_size])) {
3105       continue;
3106     }
3107 
3108     RD_STATS this_rd_stats;
3109     rd[depth] = av1_uniform_txfm_yrd(cpi, x, &this_rd_stats, ref_best_rd, bs,
3110                                      tx_size, FTXS_NONE, skip_trellis);
3111     if (rd[depth] < best_rd) {
3112       av1_copy_array(best_blk_skip, txfm_info->blk_skip, num_blks);
3113       av1_copy_array(best_txk_type_map, xd->tx_type_map, num_blks);
3114       best_tx_size = tx_size;
3115       best_rd = rd[depth];
3116       *rd_stats = this_rd_stats;
3117     }
3118     if (tx_size == TX_4X4) break;
3119     // If we are searching three depths, prune the smallest size depending
3120     // on rd results for the first two depths for low contrast blocks.
3121     if (depth > init_depth && depth != MAX_TX_DEPTH &&
3122         x->source_variance < 256) {
3123       if (rd[depth - 1] != INT64_MAX && rd[depth] > rd[depth - 1]) break;
3124     }
3125   }
3126 
3127   if (rd_stats->rate != INT_MAX) {
3128     mbmi->tx_size = best_tx_size;
3129     av1_copy_array(xd->tx_type_map, best_txk_type_map, num_blks);
3130     av1_copy_array(txfm_info->blk_skip, best_blk_skip, num_blks);
3131   }
3132 }
3133 
3134 // Search for the best transform type for the given transform block in the
3135 // given plane/channel, and calculate the corresponding RD cost.
block_rd_txfm(int plane,int block,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,void * arg)3136 static AOM_INLINE void block_rd_txfm(int plane, int block, int blk_row,
3137                                      int blk_col, BLOCK_SIZE plane_bsize,
3138                                      TX_SIZE tx_size, void *arg) {
3139   struct rdcost_block_args *args = arg;
3140   if (args->exit_early) {
3141     args->incomplete_exit = 1;
3142     return;
3143   }
3144 
3145   MACROBLOCK *const x = args->x;
3146   MACROBLOCKD *const xd = &x->e_mbd;
3147   const int is_inter = is_inter_block(xd->mi[0]);
3148   const AV1_COMP *cpi = args->cpi;
3149   ENTROPY_CONTEXT *a = args->t_above + blk_col;
3150   ENTROPY_CONTEXT *l = args->t_left + blk_row;
3151   const AV1_COMMON *cm = &cpi->common;
3152   RD_STATS this_rd_stats;
3153   av1_init_rd_stats(&this_rd_stats);
3154 
3155   if (!is_inter) {
3156     av1_predict_intra_block_facade(cm, xd, plane, blk_col, blk_row, tx_size);
3157     av1_subtract_txb(x, plane, plane_bsize, blk_col, blk_row, tx_size);
3158   }
3159 
3160   TXB_CTX txb_ctx;
3161   get_txb_ctx(plane_bsize, tx_size, plane, a, l, &txb_ctx);
3162   search_tx_type(cpi, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
3163                  &txb_ctx, args->ftxs_mode, args->skip_trellis,
3164                  args->best_rd - args->current_rd, &this_rd_stats);
3165 
3166   if (plane == AOM_PLANE_Y && xd->cfl.store_y) {
3167     assert(!is_inter || plane_bsize < BLOCK_8X8);
3168     cfl_store_tx(xd, blk_row, blk_col, tx_size, plane_bsize);
3169   }
3170 
3171 #if CONFIG_RD_DEBUG
3172   update_txb_coeff_cost(&this_rd_stats, plane, tx_size, blk_row, blk_col,
3173                         this_rd_stats.rate);
3174 #endif  // CONFIG_RD_DEBUG
3175   av1_set_txb_context(x, plane, block, tx_size, a, l);
3176 
3177   const int blk_idx =
3178       blk_row * (block_size_wide[plane_bsize] >> MI_SIZE_LOG2) + blk_col;
3179 
3180   TxfmSearchInfo *txfm_info = &x->txfm_search_info;
3181   if (plane == 0)
3182     set_blk_skip(txfm_info->blk_skip, plane, blk_idx,
3183                  x->plane[plane].eobs[block] == 0);
3184   else
3185     set_blk_skip(txfm_info->blk_skip, plane, blk_idx, 0);
3186 
3187   int64_t rd;
3188   if (is_inter) {
3189     const int64_t no_skip_txfm_rd =
3190         RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
3191     const int64_t skip_txfm_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse);
3192     rd = AOMMIN(no_skip_txfm_rd, skip_txfm_rd);
3193     this_rd_stats.skip_txfm &= !x->plane[plane].eobs[block];
3194   } else {
3195     // Signal non-skip_txfm for Intra blocks
3196     rd = RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
3197     this_rd_stats.skip_txfm = 0;
3198   }
3199 
3200   av1_merge_rd_stats(&args->rd_stats, &this_rd_stats);
3201 
3202   args->current_rd += rd;
3203   if (args->current_rd > args->best_rd) args->exit_early = 1;
3204 }
3205 
// Quickly estimates the luma RD cost of block size bs when every transform
// block uses tx_size with DCT_DCT. There is no transform-type search and no
// trellis. Blocks are visited in raster order; rate comes from cost_coeffs()
// and distortion/sse from dist_block_tx_domain().
// rd_stats->rate excludes the skip flag cost but includes tx_size_rate when
// coded as non-skip. The returned RD cost includes the skip/non-skip flag
// cost; for inter blocks outside lossless segments, forcing skip is chosen
// when it is no worse.
// Returns INT64_MAX with invalid rd_stats when the accumulated cost exceeds
// ref_best_rd before every transform block has been evaluated.
int64_t av1_estimate_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
                              RD_STATS *rd_stats, int64_t ref_best_rd,
                              BLOCK_SIZE bs, TX_SIZE tx_size) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  const ModeCosts *mode_costs = &x->mode_costs;
  const int is_inter = is_inter_block(mbmi);
  const int tx_select = txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
                        block_signals_txsize(mbmi->bsize);
  int tx_size_rate = 0;
  if (tx_select) {
    const int ctx = txfm_partition_context(
        xd->above_txfm_context, xd->left_txfm_context, mbmi->bsize, tx_size);
    tx_size_rate = mode_costs->txfm_partition_cost[ctx][0];
  }
  const int skip_ctx = av1_get_skip_txfm_context(xd);
  const int no_skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][0];
  const int skip_txfm_rate = mode_costs->skip_txfm_cost[skip_ctx][1];
  const int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_rate, 0);
  const int64_t no_this_rd =
      RDCOST(x->rdmult, no_skip_txfm_rate + tx_size_rate, 0);
  mbmi->tx_size = tx_size;

  const uint8_t txw_unit = tx_size_wide_unit[tx_size];
  const uint8_t txh_unit = tx_size_high_unit[tx_size];
  const int step = txw_unit * txh_unit;
  const int max_blocks_wide = max_block_wide(xd, bs, 0);
  const int max_blocks_high = max_block_high(xd, bs, 0);

  struct rdcost_block_args args;
  av1_zero(args);
  args.x = x;
  args.cpi = cpi;
  args.best_rd = ref_best_rd;
  // Seed the running cost with the cheaper header (skip vs. non-skip).
  args.current_rd = AOMMIN(no_this_rd, skip_txfm_rd);
  av1_init_rd_stats(&args.rd_stats);
  av1_get_entropy_contexts(bs, &xd->plane[0], args.t_above, args.t_left);
  int i = 0;
  for (int blk_row = 0; blk_row < max_blocks_high && !args.incomplete_exit;
       blk_row += txh_unit) {
    for (int blk_col = 0; blk_col < max_blocks_wide; blk_col += txw_unit) {
      // An earlier transform block pushed the cost past ref_best_rd. This
      // block and the ones after it are not evaluated, so the stats are
      // incomplete.
      if (args.exit_early) {
        args.incomplete_exit = 1;
        break;
      }

      RD_STATS this_rd_stats;
      av1_init_rd_stats(&this_rd_stats);

      ENTROPY_CONTEXT *a = args.t_above + blk_col;
      ENTROPY_CONTEXT *l = args.t_left + blk_row;
      TXB_CTX txb_ctx;
      get_txb_ctx(bs, tx_size, 0, a, l, &txb_ctx);

      TxfmParam txfm_param;
      QUANT_PARAM quant_param;
      av1_setup_xform(&cpi->common, x, tx_size, DCT_DCT, &txfm_param);
      av1_setup_quant(tx_size, 0, AV1_XFORM_QUANT_B, 0, &quant_param);

      av1_xform(x, 0, i, blk_row, blk_col, bs, &txfm_param);
      av1_quant(x, 0, i, &txfm_param, &quant_param);

      this_rd_stats.rate =
          cost_coeffs(x, 0, i, tx_size, txfm_param.tx_type, &txb_ctx, 0);

      dist_block_tx_domain(x, 0, i, tx_size, &this_rd_stats.dist,
                           &this_rd_stats.sse);

      const int64_t no_skip_txfm_rd =
          RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);
      const int64_t skip_rd = RDCOST(x->rdmult, 0, this_rd_stats.sse);

      this_rd_stats.skip_txfm &= !x->plane[0].eobs[i];

      av1_merge_rd_stats(&args.rd_stats, &this_rd_stats);
      args.current_rd += AOMMIN(no_skip_txfm_rd, skip_rd);

      av1_set_txb_context(x, 0, i, tx_size, a, l);
      i += step;

      // Defer the bail-out to the next transform block, as block_rd_txfm()
      // does. Breaking here immediately would skip the rest of the last row
      // without ever setting incomplete_exit, so partial stats would be
      // reported as valid.
      if (args.current_rd > ref_best_rd) args.exit_early = 1;
    }
  }

  if (args.incomplete_exit) av1_invalid_rd_stats(&args.rd_stats);

  *rd_stats = args.rd_stats;
  if (rd_stats->rate == INT_MAX) return INT64_MAX;

  int64_t rd;
  // rdstats->rate should include all the rate except skip/non-skip cost as the
  // same is accounted in the caller functions after rd evaluation of all
  // planes. However the decisions should be done after considering the
  // skip/non-skip header cost
  if (rd_stats->skip_txfm && is_inter) {
    rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
  } else {
    // Intra blocks are always signalled as non-skip
    rd = RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate + tx_size_rate,
                rd_stats->dist);
    rd_stats->rate += tx_size_rate;
  }
  // Check if forcing the block to skip transform leads to smaller RD cost.
  if (is_inter && !rd_stats->skip_txfm && !xd->lossless[mbmi->segment_id]) {
    int64_t temp_skip_txfm_rd =
        RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
    if (temp_skip_txfm_rd <= rd) {
      rd = temp_skip_txfm_rd;
      rd_stats->rate = 0;
      rd_stats->dist = rd_stats->sse;
      rd_stats->skip_txfm = 1;
    }
  }

  return rd;
}
3326 
// Computes the luma RD cost of block size bs when every transform block uses
// tx_size, searching transform types via av1_txfm_rd_in_plane().
// rd_stats->rate excludes the skip flag cost, which callers add once all
// planes are evaluated. It does include tx_size_rate when the block is coded
// as non-skip. The returned RD cost does include the skip/non-skip flag cost.
// For inter blocks outside lossless segments, forcing skip is chosen when it
// is no worse. Returns INT64_MAX when av1_txfm_rd_in_plane() reports an
// invalid rate.
int64_t av1_uniform_txfm_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
                             RD_STATS *rd_stats, int64_t ref_best_rd,
                             BLOCK_SIZE bs, TX_SIZE tx_size,
                             FAST_TX_SEARCH_MODE ftxs_mode, int skip_trellis) {
  assert(IMPLIES(is_rect_tx(tx_size), is_rect_tx_allowed_bsize(bs)));
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const ModeCosts *const costs = &x->mode_costs;
  const int inter_blk = is_inter_block(mbmi);

  // The transform size is only signalled under TX_MODE_SELECT for block sizes
  // that carry it; inter blocks use the partition flag, intra the size symbol.
  int tx_size_rate = 0;
  if (x->txfm_search_params.tx_mode_search_type == TX_MODE_SELECT &&
      block_signals_txsize(mbmi->bsize)) {
    const int part_ctx = txfm_partition_context(
        xd->above_txfm_context, xd->left_txfm_context, mbmi->bsize, tx_size);
    tx_size_rate = inter_blk ? costs->txfm_partition_cost[part_ctx][0]
                             : tx_size_cost(x, bs, tx_size);
  }

  const int skip_ctx = av1_get_skip_txfm_context(xd);
  const int no_skip_txfm_rate = costs->skip_txfm_cost[skip_ctx][0];
  const int skip_txfm_rate = costs->skip_txfm_cost[skip_ctx][1];
  const int64_t coded_hdr_rd =
      RDCOST(x->rdmult, no_skip_txfm_rate + tx_size_rate, 0);
  // The skip flag is only an option for inter blocks.
  const int64_t skip_hdr_rd =
      inter_blk ? RDCOST(x->rdmult, skip_txfm_rate, 0) : INT64_MAX;

  mbmi->tx_size = tx_size;
  av1_txfm_rd_in_plane(x, cpi, rd_stats, ref_best_rd,
                       AOMMIN(coded_hdr_rd, skip_hdr_rd), AOM_PLANE_Y, bs,
                       tx_size, ftxs_mode, skip_trellis);
  if (rd_stats->rate == INT_MAX) return INT64_MAX;

  // Decisions below account for the skip/non-skip header cost even though
  // rd_stats->rate itself does not include it.
  const int64_t skip_rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
  if (inter_blk && rd_stats->skip_txfm) return skip_rd;

  // Intra blocks are always signalled as non-skip.
  int64_t rd =
      RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate + tx_size_rate,
             rd_stats->dist);
  rd_stats->rate += tx_size_rate;

  // An inter block that is coded may still be cheaper as a forced skip
  // (never in lossless segments).
  if (inter_blk && !xd->lossless[mbmi->segment_id] && skip_rd <= rd) {
    rd = skip_rd;
    rd_stats->rate = 0;
    rd_stats->dist = rd_stats->sse;
    rd_stats->skip_txfm = 1;
  }
  return rd;
}
3387 
// Search for the best transform type for a luma inter-predicted block, given
// the transform block partitions.
// This function is used only when some speed features are enabled.
//
// Walks the transform partition tree rooted at (blk_row, blk_col). When
// |tx_size| equals the size recorded for this position in
// mbmi->inter_tx_size, this node is a leaf: tx_type_rd() searches its
// transform type and the result is compared against coding the transform
// block as all zeros. Otherwise the node is split into
// sub_tx_size_map[tx_size] sub-blocks, each handled recursively.
//
// Parameters:
//   blk_row, blk_col    Position of this transform block inside plane_bsize,
//                       in the units of tx_size_high_unit/tx_size_wide_unit.
//   block               Index of this transform block; used to address
//                       x->plane[0].eobs and x->plane[0].txb_entropy_ctx.
//   depth               Current partition depth. The partition flag is only
//                       costed while depth < MAX_VARTX_DEPTH.
//   above_ctx, left_ctx Entropy contexts, updated in place.
//   tx_above, tx_left   Transform partition contexts, updated in place.
//   ref_best_rd         RD budget for this subtree. Each child receives the
//                       budget minus the RD already spent on earlier siblings.
//   rd_stats            Output. Left untouched when the block lies outside
//                       the visible area. Set to the invalid state if any
//                       child search fails (rate == INT_MAX).
static AOM_INLINE void tx_block_yrd(
    const AV1_COMP *cpi, MACROBLOCK *x, int blk_row, int blk_col, int block,
    TX_SIZE tx_size, BLOCK_SIZE plane_bsize, int depth,
    ENTROPY_CONTEXT *above_ctx, ENTROPY_CONTEXT *left_ctx,
    TXFM_CONTEXT *tx_above, TXFM_CONTEXT *tx_left, int64_t ref_best_rd,
    RD_STATS *rd_stats, FAST_TX_SEARCH_MODE ftxs_mode) {
  assert(tx_size < TX_SIZES_ALL);
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  assert(is_inter_block(mbmi));
  const int max_blocks_high = max_block_high(xd, plane_bsize, 0);
  const int max_blocks_wide = max_block_wide(xd, plane_bsize, 0);

  // Transform blocks outside the visible area are not searched; rd_stats is
  // left exactly as the caller initialized it.
  if (blk_row >= max_blocks_high || blk_col >= max_blocks_wide) return;

  // Transform size already decided for this position.
  const TX_SIZE plane_tx_size = mbmi->inter_tx_size[av1_get_txb_size_index(
      plane_bsize, blk_row, blk_col)];
  // Context used to cost the transform partition flag at this node.
  const int ctx = txfm_partition_context(tx_above + blk_col, tx_left + blk_row,
                                         mbmi->bsize, tx_size);

  av1_init_rd_stats(rd_stats);
  if (tx_size == plane_tx_size) {
    // Leaf node: search the transform type of this transform block.
    ENTROPY_CONTEXT *ta = above_ctx + blk_col;
    ENTROPY_CONTEXT *tl = left_ctx + blk_row;
    const TX_SIZE txs_ctx = get_txsize_entropy_ctx(tx_size);
    TXB_CTX txb_ctx;
    get_txb_ctx(plane_bsize, tx_size, 0, ta, tl, &txb_ctx);

    // Rate of signalling this transform block as all-zero (txb_skip = 1).
    const int zero_blk_rate =
        x->coeff_costs.coeff_costs[txs_ctx][get_plane_type(0)]
            .txb_skip_cost[txb_ctx.txb_skip_ctx][1];
    rd_stats->zero_rate = zero_blk_rate;
    tx_type_rd(cpi, x, tx_size, blk_row, blk_col, block, plane_bsize, &txb_ctx,
               rd_stats, ftxs_mode, ref_best_rd, NULL);
    const int mi_width = mi_size_wide[plane_bsize];
    TxfmSearchInfo *txfm_info = &x->txfm_search_info;
    // Zero out the block when coding its coefficients is not cheaper than
    // signalling it as all-zero, or when the search already marked it skip.
    // Its eob, entropy context and transform type (DCT_DCT) are then reset to
    // match.
    if (RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist) >=
            RDCOST(x->rdmult, zero_blk_rate, rd_stats->sse) ||
        rd_stats->skip_txfm == 1) {
      rd_stats->rate = zero_blk_rate;
      rd_stats->dist = rd_stats->sse;
      rd_stats->skip_txfm = 1;
      set_blk_skip(txfm_info->blk_skip, 0, blk_row * mi_width + blk_col, 1);
      x->plane[0].eobs[block] = 0;
      x->plane[0].txb_entropy_ctx[block] = 0;
      update_txk_array(xd, blk_row, blk_col, tx_size, DCT_DCT);
    } else {
      rd_stats->skip_txfm = 0;
      set_blk_skip(txfm_info->blk_skip, 0, blk_row * mi_width + blk_col, 0);
    }
    // Partition flag cost for a leaf (index 0). It is only charged when a
    // split is still possible: larger than 4x4 and below the maximum depth.
    if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
      rd_stats->rate += x->mode_costs.txfm_partition_cost[ctx][0];
    // Propagate this block's outcome into the entropy and partition contexts
    // seen by subsequent transform blocks.
    av1_set_txb_context(x, 0, block, tx_size, ta, tl);
    txfm_partition_update(tx_above + blk_col, tx_left + blk_row, tx_size,
                          tx_size);
  } else {
    // Split node: recurse into each sub_tx_size_map[tx_size] sub-block.
    const TX_SIZE sub_txs = sub_tx_size_map[tx_size];
    const int txb_width = tx_size_wide_unit[sub_txs];
    const int txb_height = tx_size_high_unit[sub_txs];
    const int step = txb_height * txb_width;
    RD_STATS pn_rd_stats;
    int64_t this_rd = 0;
    assert(txb_width > 0 && txb_height > 0);

    for (int row = 0; row < tx_size_high_unit[tx_size]; row += txb_height) {
      for (int col = 0; col < tx_size_wide_unit[tx_size]; col += txb_width) {
        const int offsetr = blk_row + row;
        const int offsetc = blk_col + col;
        // Invisible sub-blocks are skipped and do not consume a block index.
        if (offsetr >= max_blocks_high || offsetc >= max_blocks_wide) continue;

        av1_init_rd_stats(&pn_rd_stats);
        // The child gets whatever budget remains after its earlier siblings.
        tx_block_yrd(cpi, x, offsetr, offsetc, block, sub_txs, plane_bsize,
                     depth + 1, above_ctx, left_ctx, tx_above, tx_left,
                     ref_best_rd - this_rd, &pn_rd_stats, ftxs_mode);
        if (pn_rd_stats.rate == INT_MAX) {
          av1_invalid_rd_stats(rd_stats);
          return;
        }
        av1_merge_rd_stats(rd_stats, &pn_rd_stats);
        this_rd += RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist);
        block += step;
      }
    }

    // Partition flag cost for a split node (index 1), under the same
    // size/depth condition as the leaf case.
    if (tx_size > TX_4X4 && depth < MAX_VARTX_DEPTH)
      rd_stats->rate += x->mode_costs.txfm_partition_cost[ctx][1];
  }
}
3479 
3480 // search for tx type with tx sizes already decided for a inter-predicted luma
3481 // partition block. It's used only when some speed features are enabled.
3482 // Return value 0: early termination triggered, no valid rd cost available;
3483 //              1: rd cost values are valid.
inter_block_yrd(const AV1_COMP * cpi,MACROBLOCK * x,RD_STATS * rd_stats,BLOCK_SIZE bsize,int64_t ref_best_rd,FAST_TX_SEARCH_MODE ftxs_mode)3484 static int inter_block_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
3485                            RD_STATS *rd_stats, BLOCK_SIZE bsize,
3486                            int64_t ref_best_rd, FAST_TX_SEARCH_MODE ftxs_mode) {
3487   if (ref_best_rd < 0) {
3488     av1_invalid_rd_stats(rd_stats);
3489     return 0;
3490   }
3491 
3492   av1_init_rd_stats(rd_stats);
3493 
3494   MACROBLOCKD *const xd = &x->e_mbd;
3495   const TxfmSearchParams *txfm_params = &x->txfm_search_params;
3496   const struct macroblockd_plane *const pd = &xd->plane[0];
3497   const int mi_width = mi_size_wide[bsize];
3498   const int mi_height = mi_size_high[bsize];
3499   const TX_SIZE max_tx_size = get_vartx_max_txsize(xd, bsize, 0);
3500   const int bh = tx_size_high_unit[max_tx_size];
3501   const int bw = tx_size_wide_unit[max_tx_size];
3502   const int step = bw * bh;
3503   const int init_depth = get_search_init_depth(
3504       mi_width, mi_height, 1, &cpi->sf, txfm_params->tx_size_search_method);
3505   ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
3506   ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
3507   TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
3508   TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
3509   av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
3510   memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
3511   memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
3512 
3513   int64_t this_rd = 0;
3514   for (int idy = 0, block = 0; idy < mi_height; idy += bh) {
3515     for (int idx = 0; idx < mi_width; idx += bw) {
3516       RD_STATS pn_rd_stats;
3517       av1_init_rd_stats(&pn_rd_stats);
3518       tx_block_yrd(cpi, x, idy, idx, block, max_tx_size, bsize, init_depth,
3519                    ctxa, ctxl, tx_above, tx_left, ref_best_rd - this_rd,
3520                    &pn_rd_stats, ftxs_mode);
3521       if (pn_rd_stats.rate == INT_MAX) {
3522         av1_invalid_rd_stats(rd_stats);
3523         return 0;
3524       }
3525       av1_merge_rd_stats(rd_stats, &pn_rd_stats);
3526       this_rd +=
3527           AOMMIN(RDCOST(x->rdmult, pn_rd_stats.rate, pn_rd_stats.dist),
3528                  RDCOST(x->rdmult, pn_rd_stats.zero_rate, pn_rd_stats.sse));
3529       block += step;
3530     }
3531   }
3532 
3533   const int skip_ctx = av1_get_skip_txfm_context(xd);
3534   const int no_skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][0];
3535   const int skip_txfm_rate = x->mode_costs.skip_txfm_cost[skip_ctx][1];
3536   const int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_rate, rd_stats->sse);
3537   this_rd =
3538       RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_rate, rd_stats->dist);
3539   if (skip_txfm_rd < this_rd) {
3540     this_rd = skip_txfm_rd;
3541     rd_stats->rate = 0;
3542     rd_stats->dist = rd_stats->sse;
3543     rd_stats->skip_txfm = 1;
3544   }
3545 
3546   const int is_cost_valid = this_rd > ref_best_rd;
3547   if (!is_cost_valid) {
3548     // reset cost value
3549     av1_invalid_rd_stats(rd_stats);
3550   }
3551   return is_cost_valid;
3552 }
3553 
// Search for the best transform size and type for current inter-predicted
// luma block with recursive transform block partitioning. The obtained
// transform selection will be saved in xd->mi[0], the corresponding RD stats
// will be saved in rd_stats. The returned value is the corresponding RD cost.
//
// Returns INT64_MAX when no selection is found within the budget. This covers
// a zero ref_best_rd, a failed sub-block search, and a failed refinement pass.
// |rd_info_tree|, when non-NULL, holds previously recorded per-transform-block
// results; it advances by one node per top-level transform block.
static int64_t select_tx_size_and_type(const AV1_COMP *cpi, MACROBLOCK *x,
                                       RD_STATS *rd_stats, BLOCK_SIZE bsize,
                                       int64_t ref_best_rd,
                                       TXB_RD_INFO_NODE *rd_info_tree) {
  MACROBLOCKD *const xd = &x->e_mbd;
  const TxfmSearchParams *txfm_params = &x->txfm_search_params;
  assert(is_inter_block(xd->mi[0]));
  assert(bsize < BLOCK_SIZES_ALL);
  // Any search method beyond USE_FULL_RD restricts the type search (see
  // ftxs_mode below).
  const int fast_tx_search = txfm_params->tx_size_search_method > USE_FULL_RD;
  int64_t rd_thresh = ref_best_rd;
  if (rd_thresh == 0) {
    av1_invalid_rd_stats(rd_stats);
    return INT64_MAX;
  }
  // The fast search gets 1/8 extra budget, unless adding it would overflow
  // int64_t.
  if (fast_tx_search && rd_thresh < INT64_MAX) {
    if (INT64_MAX - rd_thresh > (rd_thresh >> 3)) rd_thresh += (rd_thresh >> 3);
  }
  assert(rd_thresh > 0);
  const FAST_TX_SEARCH_MODE ftxs_mode =
      fast_tx_search ? FTXS_DCT_AND_1D_DCT_ONLY : FTXS_NONE;
  const struct macroblockd_plane *const pd = &xd->plane[0];
  // NOTE(review): duplicates the identical assert a few lines above.
  assert(bsize < BLOCK_SIZES_ALL);
  const int mi_width = mi_size_wide[bsize];
  const int mi_height = mi_size_high[bsize];
  // Local copies of the entropy and transform partition contexts, updated by
  // the per-block search.
  ENTROPY_CONTEXT ctxa[MAX_MIB_SIZE];
  ENTROPY_CONTEXT ctxl[MAX_MIB_SIZE];
  TXFM_CONTEXT tx_above[MAX_MIB_SIZE];
  TXFM_CONTEXT tx_left[MAX_MIB_SIZE];
  av1_get_entropy_contexts(bsize, pd, ctxa, ctxl);
  memcpy(tx_above, xd->above_txfm_context, sizeof(TXFM_CONTEXT) * mi_width);
  memcpy(tx_left, xd->left_txfm_context, sizeof(TXFM_CONTEXT) * mi_height);
  const int init_depth = get_search_init_depth(
      mi_width, mi_height, 1, &cpi->sf, txfm_params->tx_size_search_method);
  const TX_SIZE max_tx_size = max_txsize_rect_lookup[bsize];
  const int bh = tx_size_high_unit[max_tx_size];
  const int bw = tx_size_wide_unit[max_tx_size];
  const int step = bw * bh;
  const int skip_ctx = av1_get_skip_txfm_context(xd);
  const int no_skip_txfm_cost = x->mode_costs.skip_txfm_cost[skip_ctx][0];
  const int skip_txfm_cost = x->mode_costs.skip_txfm_cost[skip_ctx][1];
  // RD of the block-level skip / non-skip alternatives. They start as the
  // flag cost alone and are refreshed after each top-level transform block
  // with the accumulated stats.
  int64_t skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_cost, 0);
  int64_t no_skip_txfm_rd = RDCOST(x->rdmult, no_skip_txfm_cost, 0);
  int block = 0;

  av1_init_rd_stats(rd_stats);
  for (int idy = 0; idy < max_block_high(xd, bsize, 0); idy += bh) {
    for (int idx = 0; idx < max_block_wide(xd, bsize, 0); idx += bw) {
      // Budget for this top-level transform block: the threshold minus the
      // cheaper of the skip/non-skip RD accumulated so far.
      const int64_t best_rd_sofar =
          (rd_thresh == INT64_MAX)
              ? INT64_MAX
              : (rd_thresh - (AOMMIN(skip_txfm_rd, no_skip_txfm_rd)));
      int is_cost_valid = 1;
      RD_STATS pn_rd_stats;
      // Search for the best transform block size and type for the sub-block.
      select_tx_block(cpi, x, idy, idx, block, max_tx_size, init_depth, bsize,
                      ctxa, ctxl, tx_above, tx_left, &pn_rd_stats, INT64_MAX,
                      best_rd_sofar, &is_cost_valid, ftxs_mode, rd_info_tree);
      if (!is_cost_valid || pn_rd_stats.rate == INT_MAX) {
        av1_invalid_rd_stats(rd_stats);
        return INT64_MAX;
      }
      av1_merge_rd_stats(rd_stats, &pn_rd_stats);
      skip_txfm_rd = RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse);
      no_skip_txfm_rd =
          RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_cost, rd_stats->dist);
      block += step;
      if (rd_info_tree != NULL) rd_info_tree += 1;
    }
  }

  if (rd_stats->rate == INT_MAX) return INT64_MAX;

  // Block-level skip wins ties.
  rd_stats->skip_txfm = (skip_txfm_rd <= no_skip_txfm_rd);

  // If fast_tx_search is true, only DCT and 1D DCT were tested in
  // select_tx_block() above. Do a better search for tx type with
  // tx sizes already decided.
  if (fast_tx_search && cpi->sf.tx_sf.refine_fast_tx_search_results) {
    if (!inter_block_yrd(cpi, x, rd_stats, bsize, ref_best_rd, FTXS_NONE))
      return INT64_MAX;
  }

  int64_t final_rd;
  if (rd_stats->skip_txfm) {
    final_rd = RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse);
  } else {
    final_rd =
        RDCOST(x->rdmult, rd_stats->rate + no_skip_txfm_cost, rd_stats->dist);
    // Outside lossless segments, the all-skip alternative bounds the cost.
    if (!xd->lossless[xd->mi[0]->segment_id]) {
      final_rd =
          AOMMIN(final_rd, RDCOST(x->rdmult, skip_txfm_cost, rd_stats->sse));
    }
  }

  return final_rd;
}
3654 
3655 // Return 1 to terminate transform search early. The decision is made based on
3656 // the comparison with the reference RD cost and the model-estimated RD cost.
model_based_tx_search_prune(const AV1_COMP * cpi,MACROBLOCK * x,BLOCK_SIZE bsize,int64_t ref_best_rd)3657 static AOM_INLINE int model_based_tx_search_prune(const AV1_COMP *cpi,
3658                                                   MACROBLOCK *x,
3659                                                   BLOCK_SIZE bsize,
3660                                                   int64_t ref_best_rd) {
3661   const int level = cpi->sf.tx_sf.model_based_prune_tx_search_level;
3662   assert(level >= 0 && level <= 2);
3663   int model_rate;
3664   int64_t model_dist;
3665   int model_skip;
3666   MACROBLOCKD *const xd = &x->e_mbd;
3667   model_rd_sb_fn[MODELRD_TYPE_TX_SEARCH_PRUNE](
3668       cpi, bsize, x, xd, 0, 0, &model_rate, &model_dist, &model_skip, NULL,
3669       NULL, NULL, NULL);
3670   if (model_skip) return 0;
3671   const int64_t model_rd = RDCOST(x->rdmult, model_rate, model_dist);
3672   // TODO(debargha, urvang): Improve the model and make the check below
3673   // tighter.
3674   static const int prune_factor_by8[] = { 3, 5 };
3675   const int factor = prune_factor_by8[level - 1];
3676   return ((model_rd * factor) >> 3) > ref_best_rd;
3677 }
3678 
// Searches recursive (variable) transform partitions and transform types for
// the luma plane of the current inter block.
//
// On return |rd_stats| holds the selected luma statistics, or the invalid
// state when the search was pruned or found nothing within |ref_best_rd|.
// Shortcuts are tried in this order:
// 1. Model-based pruning against a finite |ref_best_rd|.
// 2. Reuse of a block-level residue hash hit.
// 3. A predicted all-skip decision.
// Results produced here are written back to the block-level hash table when
// that table is in use.
void av1_pick_recursive_tx_size_type_yrd(const AV1_COMP *cpi, MACROBLOCK *x,
                                         RD_STATS *rd_stats, BLOCK_SIZE bsize,
                                         int64_t ref_best_rd) {
  MACROBLOCKD *const xd = &x->e_mbd;
  assert(is_inter_block(xd->mi[0]));

  av1_invalid_rd_stats(rd_stats);

  const int ref_rd_is_finite = ref_best_rd != INT64_MAX;

  // Give up right away when the modeled RD cost is far worse than the best.
  if (ref_rd_is_finite && cpi->sf.tx_sf.model_based_prune_tx_search_level &&
      model_based_tx_search_prune(cpi, x, bsize, ref_best_rd)) {
    return;
  }

  // Hash-based reuse requires the block to start inside the tile and to end
  // strictly before the tile's bottom and right edges.
  const int within_border =
      xd->mi_row >= xd->tile.mi_row_start &&
      xd->mi_row + mi_size_high[bsize] < xd->tile.mi_row_end &&
      xd->mi_col >= xd->tile.mi_col_start &&
      xd->mi_col + mi_size_wide[bsize] < xd->tile.mi_col_end;
  const int use_mb_hash = within_border && cpi->sf.rd_sf.use_mb_rd_hash;
  const int n4 = bsize_to_num_blk(bsize);
  uint32_t hash = 0;
  MB_RD_RECORD *mb_rd_record = NULL;
  if (use_mb_hash) {
    hash = get_block_residue_hash(x, bsize);
    mb_rd_record = &x->txfm_search_info.txb_rd_records->mb_rd_record;
    const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
    if (match_index != -1) {
      // An identical residue was searched before: reuse its result.
      fetch_tx_rd_info(n4, &mb_rd_record->tx_rd_info[match_index], rd_stats, x);
      return;
    }
  }

  // When skipping the whole block is predicted to be optimal, fill in the
  // skip statistics and stop.
  int64_t dist;
  if (x->txfm_search_params.skip_txfm_level &&
      predict_skip_txfm(x, bsize, &dist,
                        cpi->common.features.reduced_tx_set_used)) {
    set_skip_txfm(x, rd_stats, bsize, dist);
    if (use_mb_hash) save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
    return;
  }
#if CONFIG_SPEED_STATS
  ++x->txfm_search_info.tx_search_count;
#endif  // CONFIG_SPEED_STATS

  // Look up previously stored per-transform-block RD results for this
  // residue, so the size/type search can reuse rate and distortion values.
  TXB_RD_INFO_NODE matched_rd_info[4 + 16 + 64];
  TXB_RD_INFO_NODE *rd_info_tree = NULL;
  if (ref_rd_is_finite && within_border && cpi->sf.tx_sf.use_inter_txb_hash &&
      find_tx_size_rd_records(x, bsize, matched_rd_info)) {
    rd_info_tree = matched_rd_info;
  }

  if (select_tx_size_and_type(cpi, x, rd_stats, bsize, ref_best_rd,
                              rd_info_tree) == INT64_MAX) {
    // With an unbounded budget select_tx_size_and_type() should always find a
    // candidate. Failure is only expected when a finite ref_best_rd could not
    // be beaten.
    assert(ref_rd_is_finite);
    av1_invalid_rd_stats(rd_stats);
    return;
  }

  // Store the search result for future reuse.
  if (use_mb_hash) {
    assert(mb_rd_record != NULL);
    save_tx_rd_info(n4, hash, x, rd_stats, mb_rd_record);
  }
}
3764 
// Searches one uniform luma transform size and its transform type for the
// current block, as opposed to the recursive partition search.
//
// For inter blocks, a hit in the block-level residue hash table returns the
// stored result directly. Outside lossless segments, a predicted skip may
// replace the search. Otherwise the size is chosen as follows:
// - the smallest size for lossless segments;
// - the largest size for USE_LARGESTALL;
// - a full RD search in all other cases.
// Whenever the hash lookup was attempted, the outcome is saved back into
// the table.
void av1_pick_uniform_tx_size_type_yrd(const AV1_COMP *const cpi, MACROBLOCK *x,
                                       RD_STATS *rd_stats, BLOCK_SIZE bs,
                                       int64_t ref_best_rd) {
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  assert(bs == mbmi->bsize);
  const int is_inter = is_inter_block(mbmi);

  av1_init_rd_stats(rd_stats);

  // The block-level hash is only consulted for inter blocks that start inside
  // the tile and end strictly before its bottom and right edges.
  const int try_mb_hash =
      is_inter && cpi->sf.rd_sf.use_mb_rd_hash &&
      xd->mi_row >= xd->tile.mi_row_start &&
      xd->mi_row + mi_size_high[bs] < xd->tile.mi_row_end &&
      xd->mi_col >= xd->tile.mi_col_start &&
      xd->mi_col + mi_size_wide[bs] < xd->tile.mi_col_end;
  const int num_blks = bsize_to_num_blk(bs);
  uint32_t hash = 0;
  MB_RD_RECORD *mb_rd_record = NULL;
  if (try_mb_hash) {
    hash = get_block_residue_hash(x, bs);
    mb_rd_record = &x->txfm_search_info.txb_rd_records->mb_rd_record;
    const int match_index = find_mb_rd_info(mb_rd_record, ref_best_rd, hash);
    if (match_index != -1) {
      // Reuse the stored result for an identical residue.
      fetch_tx_rd_info(num_blks, &mb_rd_record->tx_rd_info[match_index],
                       rd_stats, x);
      return;
    }
  }

  const int lossless = xd->lossless[mbmi->segment_id];
  int64_t dist;
  if (x->txfm_search_params.skip_txfm_level && is_inter && !lossless &&
      predict_skip_txfm(x, bs, &dist,
                        cpi->common.features.reduced_tx_set_used)) {
    // Predicted skip: populate rd_stats as per the skip decision.
    set_skip_txfm(x, rd_stats, bs, dist);
  } else if (lossless) {
    // Lossless mode can only pick the smallest (4x4) transform size.
    choose_smallest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
  } else if (x->txfm_search_params.tx_size_search_method == USE_LARGESTALL) {
    choose_largest_tx_size(cpi, x, rd_stats, ref_best_rd, bs);
  } else {
    choose_tx_size_type_from_rd(cpi, x, rd_stats, ref_best_rd, bs);
  }

  // Save the result into the hash table for possible reuse in future.
  if (mb_rd_record != NULL)
    save_tx_rd_info(num_blks, hash, x, rd_stats, mb_rd_record);
}
3832 
// Computes the chroma (U and V) rate-distortion statistics of the current
// block, using the chroma transform size returned by av1_get_tx_size().
//
// Returns 1 with the accumulated U+V stats in |rd_stats|. This also covers
// blocks that carry no chroma, in which case |rd_stats| is only initialized.
// Returns 0 immediately, with |rd_stats| only initialized, when |ref_best_rd|
// is negative. Returns 0 with |rd_stats| invalidated when a plane search fails
// or the running cost exceeds |ref_best_rd|.
int av1_txfm_uvrd(const AV1_COMP *const cpi, MACROBLOCK *x, RD_STATS *rd_stats,
                  BLOCK_SIZE bsize, int64_t ref_best_rd) {
  av1_init_rd_stats(rd_stats);
  if (ref_best_rd < 0) return 0;

  MACROBLOCKD *const xd = &x->e_mbd;
  if (!xd->is_chroma_ref) return 1;

  const int is_inter = is_inter_block(xd->mi[0]);
  const struct macroblockd_plane *const u_plane = &xd->plane[AOM_PLANE_U];
  const BLOCK_SIZE plane_bsize = get_plane_block_size(
      bsize, u_plane->subsampling_x, u_plane->subsampling_y);

  // Inter blocks compute their chroma residuals here.
  if (is_inter) {
    for (int p = 1; p < MAX_MB_PLANE; ++p) av1_subtract_plane(x, plane_bsize, p);
  }

  const TX_SIZE uv_tx_size = av1_get_tx_size(AOM_PLANE_U, xd);
  // For inter blocks, refined ref_best_rd is used for early exit.
  // For intra blocks, even though current rd crosses ref_best_rd, early
  // exit is not recommended as current rd is used for gating subsequent
  // modes as well (say, for angular modes).
  // TODO(any): Extend the early exit mechanism for intra modes as well
  const int gate_by_best_rd =
      cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma && is_inter &&
      ref_best_rd != INT64_MAX;
  // min(coded RD, all-zero RD) of the chroma planes searched so far.
  int64_t spent_rd = 0;

  for (int plane = 1; plane < MAX_MB_PLANE; ++plane) {
    const int64_t plane_budget =
        gate_by_best_rd ? ref_best_rd - spent_rd : ref_best_rd;
    RD_STATS plane_stats;
    av1_txfm_rd_in_plane(x, cpi, &plane_stats, plane_budget, 0, plane,
                         plane_bsize, uv_tx_size, FTXS_NONE,
                         /*skip_trellis=*/0);
    if (plane_stats.rate == INT_MAX) {
      av1_invalid_rd_stats(rd_stats);
      return 0;
    }
    av1_merge_rd_stats(rd_stats, &plane_stats);
    const int64_t coded_rd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
    const int64_t zeroed_rd = RDCOST(x->rdmult, 0, rd_stats->sse);
    spent_rd = AOMMIN(coded_rd, zeroed_rd);
    if (spent_rd > ref_best_rd) {
      av1_invalid_rd_stats(rd_stats);
      return 0;
    }
  }

  return 1;
}
3888 
// Computes the rate-distortion statistics of one plane of the current block
// when every transform block in that plane uses |tx_size|.
//
// Each transform block is visited with block_rd_txfm(), which accumulates
// into a local rdcost_block_args. |ref_best_rd| is the budget handed to that
// search, and |current_rd| is the cost already committed before this plane.
//
// |rd_stats| is set to the invalid state in three cases:
// - 64-point transforms are disabled and |tx_size| needs one;
// - |current_rd| already exceeds |ref_best_rd|;
// - the visit ends with the abort flag set (incomplete_exit for inter
//   blocks, exit_early for intra blocks; presumably set by block_rd_txfm()).
// Otherwise |rd_stats| receives the accumulated statistics.
void av1_txfm_rd_in_plane(MACROBLOCK *x, const AV1_COMP *cpi,
                          RD_STATS *rd_stats, int64_t ref_best_rd,
                          int64_t current_rd, int plane, BLOCK_SIZE plane_bsize,
                          TX_SIZE tx_size, FAST_TX_SEARCH_MODE ftxs_mode,
                          int skip_trellis) {
  assert(IMPLIES(plane == 0, x->e_mbd.mi[0]->tx_size == tx_size));

  const int tx64_disallowed = !cpi->oxcf.txfm_cfg.enable_tx64 &&
                              txsize_sqr_up_map[tx_size] == TX_64X64;
  if (tx64_disallowed || current_rd > ref_best_rd) {
    av1_invalid_rd_stats(rd_stats);
    return;
  }

  MACROBLOCKD *const xd = &x->e_mbd;
  struct rdcost_block_args block_args;
  av1_zero(block_args);
  block_args.cpi = cpi;
  block_args.x = x;
  block_args.current_rd = current_rd;
  block_args.best_rd = ref_best_rd;
  block_args.ftxs_mode = ftxs_mode;
  block_args.skip_trellis = skip_trellis;
  av1_init_rd_stats(&block_args.rd_stats);

  av1_get_entropy_contexts(plane_bsize, &xd->plane[plane], block_args.t_above,
                           block_args.t_left);
  av1_foreach_transformed_block_in_plane(xd, plane_bsize, plane, block_rd_txfm,
                                         &block_args);

  // Which abort flag discards the result depends on the prediction type.
  const int search_aborted = is_inter_block(xd->mi[0])
                                 ? block_args.incomplete_exit
                                 : block_args.exit_early;
  if (!search_aborted) {
    *rd_stats = block_args.rd_stats;
    return;
  }
  av1_invalid_rd_stats(rd_stats);
}
3933 
// Runs the luma and chroma transform search for the current block and makes
// the final block-level skip decision.
//
// Parameters:
//   mode_rate    Rate accounted before the transform search (presumably the
//                prediction-mode rate). It is the starting value of
//                rd_stats->rate.
//   ref_best_rd  RD cost to beat.
// Outputs:
//   rd_stats     Totals: mode_rate plus luma/chroma stats plus the cost of the
//                chosen block-level skip flag.
//   rd_stats_y, rd_stats_uv  Luma and chroma stats. When the block is coded
//                as skip, their rates are zeroed and dist is set to sse.
// Also sets mbmi->skip_txfm.
//
// Returns 1 on success. Returns 0 when the block cannot beat ref_best_rd or a
// sub-search fails. Only the first early exit explicitly invalidates
// rd_stats_y; the other 0 paths leave the outputs as they were at that point.
int av1_txfm_search(const AV1_COMP *cpi, MACROBLOCK *x, BLOCK_SIZE bsize,
                    RD_STATS *rd_stats, RD_STATS *rd_stats_y,
                    RD_STATS *rd_stats_uv, int mode_rate, int64_t ref_best_rd) {
  MACROBLOCKD *const xd = &x->e_mbd;
  TxfmSearchParams *txfm_params = &x->txfm_search_params;
  const int skip_ctx = av1_get_skip_txfm_context(xd);
  const int skip_txfm_cost[2] = { x->mode_costs.skip_txfm_cost[skip_ctx][0],
                                  x->mode_costs.skip_txfm_cost[skip_ctx][1] };
  const int64_t min_header_rate =
      mode_rate + AOMMIN(skip_txfm_cost[0], skip_txfm_cost[1]);
  // Account for minimum skip and non_skip rd.
  // Eventually either one of them will be added to mode_rate
  const int64_t min_header_rd_possible = RDCOST(x->rdmult, min_header_rate, 0);
  if (min_header_rd_possible > ref_best_rd) {
    av1_invalid_rd_stats(rd_stats_y);
    return 0;
  }

  const AV1_COMMON *cm = &cpi->common;
  MB_MODE_INFO *const mbmi = xd->mi[0];
  const int64_t mode_rd = RDCOST(x->rdmult, mode_rate, 0);
  // Budget left for the luma search after paying for mode_rate.
  const int64_t rd_thresh =
      ref_best_rd == INT64_MAX ? INT64_MAX : ref_best_rd - mode_rd;
  av1_init_rd_stats(rd_stats);
  av1_init_rd_stats(rd_stats_y);
  rd_stats->rate = mode_rate;

  // cost and distortion
  av1_subtract_plane(x, bsize, 0);
  // Recursive transform partition search under TX_MODE_SELECT (except in
  // lossless segments); otherwise a single uniform transform size.
  if (txfm_params->tx_mode_search_type == TX_MODE_SELECT &&
      !xd->lossless[mbmi->segment_id]) {
    av1_pick_recursive_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
    // NOTE(review): `tile_data` is not a parameter or local of this function,
    // so this branch does not compile with CONFIG_COLLECT_RD_STATS == 2.
#if CONFIG_COLLECT_RD_STATS == 2
    PrintPredictionUnitStats(cpi, tile_data, x, rd_stats_y, bsize);
#endif  // CONFIG_COLLECT_RD_STATS == 2
  } else {
    av1_pick_uniform_tx_size_type_yrd(cpi, x, rd_stats_y, bsize, rd_thresh);
    // Uniform size: replicate it into every inter_tx_size entry, and apply
    // the luma skip flag to all xd->height * xd->width blk_skip entries.
    memset(mbmi->inter_tx_size, mbmi->tx_size, sizeof(mbmi->inter_tx_size));
    for (int i = 0; i < xd->height * xd->width; ++i)
      set_blk_skip(x->txfm_search_info.blk_skip, 0, i, rd_stats_y->skip_txfm);
  }

  if (rd_stats_y->rate == INT_MAX) return 0;

  av1_merge_rd_stats(rd_stats, rd_stats_y);

  // Luma-only RD for coding the residual vs. skipping the whole block.
  const int64_t non_skip_txfm_rdcosty =
      RDCOST(x->rdmult, rd_stats->rate + skip_txfm_cost[0], rd_stats->dist);
  const int64_t skip_txfm_rdcosty =
      RDCOST(x->rdmult, mode_rate + skip_txfm_cost[1], rd_stats->sse);
  const int64_t min_rdcosty = AOMMIN(non_skip_txfm_rdcosty, skip_txfm_rdcosty);
  if (min_rdcosty > ref_best_rd) return 0;

  av1_init_rd_stats(rd_stats_uv);
  const int num_planes = av1_num_planes(cm);
  if (num_planes > 1) {
    int64_t ref_best_chroma_rd = ref_best_rd;
    // Calculate best rd cost possible for chroma
    if (cpi->sf.inter_sf.perform_best_rd_based_gating_for_chroma &&
        (ref_best_chroma_rd != INT64_MAX)) {
      ref_best_chroma_rd = (ref_best_chroma_rd -
                            AOMMIN(non_skip_txfm_rdcosty, skip_txfm_rdcosty));
    }
    const int is_cost_valid_uv =
        av1_txfm_uvrd(cpi, x, rd_stats_uv, bsize, ref_best_chroma_rd);
    if (!is_cost_valid_uv) return 0;
    av1_merge_rd_stats(rd_stats, rd_stats_uv);
  }

  // Final block-level skip decision. Outside lossless segments, skip is
  // forced whenever coding the residual is not strictly cheaper.
  int choose_skip_txfm = rd_stats->skip_txfm;
  if (!choose_skip_txfm && !xd->lossless[mbmi->segment_id]) {
    const int64_t rdcost_no_skip_txfm = RDCOST(
        x->rdmult, rd_stats_y->rate + rd_stats_uv->rate + skip_txfm_cost[0],
        rd_stats->dist);
    const int64_t rdcost_skip_txfm =
        RDCOST(x->rdmult, skip_txfm_cost[1], rd_stats->sse);
    if (rdcost_no_skip_txfm >= rdcost_skip_txfm) choose_skip_txfm = 1;
  }
  if (choose_skip_txfm) {
    rd_stats_y->rate = 0;
    rd_stats_uv->rate = 0;
    rd_stats->rate = mode_rate + skip_txfm_cost[1];
    rd_stats->dist = rd_stats->sse;
    rd_stats_y->dist = rd_stats_y->sse;
    rd_stats_uv->dist = rd_stats_uv->sse;
    mbmi->skip_txfm = 1;
    // When the merged stats already reported skip, the budget was not
    // checked against the final skip cost above; check it now.
    if (rd_stats->skip_txfm) {
      const int64_t tmprd = RDCOST(x->rdmult, rd_stats->rate, rd_stats->dist);
      if (tmprd > ref_best_rd) return 0;
    }
  } else {
    rd_stats->rate += skip_txfm_cost[0];
    mbmi->skip_txfm = 0;
  }

  return 1;
}
4031