1 /*
2  * Copyright (c) 2017, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include "av1/common/scan.h"
13 #include "av1/common/blockd.h"
14 #include "av1/common/idct.h"
15 #include "av1/common/pred_common.h"
16 #include "av1/encoder/bitstream.h"
17 #include "av1/encoder/encodeframe.h"
18 #include "av1/encoder/cost.h"
19 #include "av1/encoder/encodetxb.h"
20 #include "av1/encoder/rdopt.h"
21 #include "av1/encoder/subexp.h"
22 #include "av1/encoder/tokenize.h"
23 
24 #define TEST_OPTIMIZE_TXB 0
25 
// Allocates cpi->coeff_buffer_base as a grid of
// ((mi_rows >> MAX_MIB_SIZE_LOG2) + 1) x ((mi_cols >> MAX_MIB_SIZE_LOG2) + 1)
// CB_COEFF_BUFFER entries (presumably one per largest superblock). This is the
// same row stride that av1_set_coeff_buffer() uses when indexing the buffer.
// Any previously allocated buffer is released first; allocation failure is
// reported through CHECK_MEM_ERROR.
void av1_alloc_txb_buf(AV1_COMP *cpi) {
  AV1_COMMON *cm = &cpi->common;
  const int grid_rows = (cm->mi_rows >> MAX_MIB_SIZE_LOG2) + 1;
  const int grid_cols = (cm->mi_cols >> MAX_MIB_SIZE_LOG2) + 1;
  const int size = grid_rows * grid_cols;

  av1_free_txb_buf(cpi);
  // TODO(jingning): This should be further reduced.
  CHECK_MEM_ERROR(cm, cpi->coeff_buffer_base,
                  aom_malloc(sizeof(*cpi->coeff_buffer_base) * size));
}
51 
// Releases the buffer allocated by av1_alloc_txb_buf().
// The pointer is reset to NULL afterwards, so calling this again, or calling
// av1_alloc_txb_buf() (which frees before allocating), never frees the same
// block twice.
void av1_free_txb_buf(AV1_COMP *cpi) {
  aom_free(cpi->coeff_buffer_base);
  cpi->coeff_buffer_base = NULL;
}
62 
// Points the per-plane coefficient, eob and context arrays of x->mbmi_ext at
// the current block's slot. The slot lives inside the CB_COEFF_BUFFER of the
// grid cell containing (mi_row, mi_col).
// The grid stride matches the sizing in av1_alloc_txb_buf(); x->cb_offset
// selects the position inside the cell.
void av1_set_coeff_buffer(const AV1_COMP *const cpi, MACROBLOCK *const x,
                          int mi_row, int mi_col) {
  const int grid_cols = (cpi->common.mi_cols >> MAX_MIB_SIZE_LOG2) + 1;
  const int grid_row = mi_row >> MAX_MIB_SIZE_LOG2;
  const int grid_col = mi_col >> MAX_MIB_SIZE_LOG2;
  CB_COEFF_BUFFER *const cell =
      cpi->coeff_buffer_base + (grid_row * grid_cols + grid_col);
  // cb_offset counts coefficients. The eob/context arrays appear to hold one
  // entry per minimum-size transform block, hence the division.
  const int txb_offset = x->cb_offset / (TX_SIZE_W_MIN * TX_SIZE_H_MIN);
  int p;

  for (p = 0; p < MAX_MB_PLANE; p++) {
    x->mbmi_ext->tcoeff[p] = cell->tcoeff[p] + x->cb_offset;
    x->mbmi_ext->eobs[p] = cell->eobs[p] + txb_offset;
    x->mbmi_ext->txb_skip_ctx[p] = cell->txb_skip_ctx[p] + txb_offset;
    x->mbmi_ext->dc_sign_ctx[p] = cell->dc_sign_ctx[p] + txb_offset;
  }
}
79 
// Writes `level` as an order-0 Exp-Golomb code. Callers in this file pass a
// non-negative value.
// With value = level + 1 occupying n significant bits, this emits n - 1 zero
// bits followed by the n bits of value, most significant bit first.
static void write_golomb(aom_writer *w, int level) {
  const int value = level + 1;
  int num_bits = 0;
  int bit;

  for (int rest = value; rest != 0; rest >>= 1) ++num_bits;
  assert(num_bits > 0);

  // Unary prefix: one zero per bit following the leading one.
  for (bit = 1; bit < num_bits; ++bit) aom_write_bit(w, 0);
  // Binary suffix, leading one included.
  for (bit = num_bits - 1; bit >= 0; --bit)
    aom_write_bit(w, (value >> bit) & 0x01);
}
95 
write_nz_map(aom_writer * w,const tran_low_t * tcoeff,uint16_t eob,int plane,const int16_t * scan,TX_SIZE tx_size,TX_TYPE tx_type,FRAME_CONTEXT * fc)96 static INLINE void write_nz_map(aom_writer *w, const tran_low_t *tcoeff,
97                                 uint16_t eob, int plane, const int16_t *scan,
98                                 TX_SIZE tx_size, TX_TYPE tx_type,
99                                 FRAME_CONTEXT *fc) {
100   const PLANE_TYPE plane_type = get_plane_type(plane);
101   const TX_SIZE txs_ctx = get_txsize_context(tx_size);
102   const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
103   const int height = tx_size_high[tx_size];
104 #if CONFIG_CTX1D
105   const int width = tx_size_wide[tx_size];
106   const int eob_offset = width + height;
107   const TX_CLASS tx_class = get_tx_class(tx_type);
108   const int seg_eob =
109       (tx_class == TX_CLASS_2D) ? tx_size_2d[tx_size] : eob_offset;
110 #else
111   const int seg_eob = tx_size_2d[tx_size];
112 #endif
113 #if !LV_MAP_PROB
114   aom_prob *nz_map = fc->nz_map[txs_ctx][plane_type];
115   aom_prob *eob_flag = fc->eob_flag[txs_ctx][plane_type];
116 #endif
117 
118   for (int c = 0; c < eob; ++c) {
119     int coeff_ctx = get_nz_map_ctx(tcoeff, c, scan, bwl, height, tx_type);
120     int eob_ctx = get_eob_ctx(tcoeff, scan[c], txs_ctx, tx_type);
121 
122     tran_low_t v = tcoeff[scan[c]];
123     int is_nz = (v != 0);
124 
125     if (c == seg_eob - 1) break;
126 
127 #if LV_MAP_PROB
128     aom_write_bin(w, is_nz, fc->nz_map_cdf[txs_ctx][plane_type][coeff_ctx], 2);
129 #else
130     aom_write(w, is_nz, nz_map[coeff_ctx]);
131 #endif
132 
133     if (is_nz) {
134 #if LV_MAP_PROB
135       aom_write_bin(w, c == (eob - 1),
136                     fc->eob_flag_cdf[txs_ctx][plane_type][eob_ctx], 2);
137 #else
138       aom_write(w, c == (eob - 1), eob_flag[eob_ctx]);
139 #endif
140     }
141   }
142 }
143 
144 #if CONFIG_CTX1D
write_nz_map_vert(aom_writer * w,const tran_low_t * tcoeff,uint16_t eob,int plane,const int16_t * scan,const int16_t * iscan,TX_SIZE tx_size,TX_TYPE tx_type,FRAME_CONTEXT * fc)145 static INLINE void write_nz_map_vert(aom_writer *w, const tran_low_t *tcoeff,
146                                      uint16_t eob, int plane,
147                                      const int16_t *scan, const int16_t *iscan,
148                                      TX_SIZE tx_size, TX_TYPE tx_type,
149                                      FRAME_CONTEXT *fc) {
150   (void)eob;
151   const TX_SIZE txs_ctx = get_txsize_context(tx_size);
152   const PLANE_TYPE plane_type = get_plane_type(plane);
153   const TX_CLASS tx_class = get_tx_class(tx_type);
154   const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
155   const int width = tx_size_wide[tx_size];
156   const int height = tx_size_high[tx_size];
157   int16_t eob_ls[MAX_HVTX_SIZE];
158   get_eob_vert(eob_ls, tcoeff, width, height);
159 #if !LV_MAP_PROB
160   aom_prob *nz_map = fc->nz_map[txs_ctx][plane_type];
161 #endif
162   for (int c = 0; c < width; ++c) {
163     int16_t veob = eob_ls[c];
164     assert(veob <= height);
165     int el_ctx = get_empty_line_ctx(c, eob_ls);
166 #if LV_MAP_PROB
167     aom_write_bin(w, veob == 0,
168                   fc->empty_line_cdf[txs_ctx][plane_type][tx_class][el_ctx], 2);
169 #else
170     aom_write(w, veob == 0,
171               fc->empty_line[txs_ctx][plane_type][tx_class][el_ctx]);
172 #endif
173     if (veob) {
174       for (int r = 0; r < veob; ++r) {
175         if (r + 1 != height) {
176           int coeff_idx = r * width + c;
177           int scan_idx = iscan[coeff_idx];
178           int is_nz = tcoeff[coeff_idx] != 0;
179           int coeff_ctx =
180               get_nz_map_ctx(tcoeff, scan_idx, scan, bwl, height, tx_type);
181 #if LV_MAP_PROB
182           aom_write_bin(w, is_nz,
183                         fc->nz_map_cdf[txs_ctx][plane_type][coeff_ctx], 2);
184 #else
185           aom_write(w, is_nz, nz_map[coeff_ctx]);
186 #endif
187           if (is_nz) {
188             int eob_ctx = get_hv_eob_ctx(c, r, eob_ls);
189 #if LV_MAP_PROB
190             aom_write_bin(
191                 w, r == veob - 1,
192                 fc->hv_eob_cdf[txs_ctx][plane_type][tx_class][eob_ctx], 2);
193 #else
194             aom_write(w, r == veob - 1,
195                       fc->hv_eob[txs_ctx][plane_type][tx_class][eob_ctx]);
196 #endif
197           }
198         }
199       }
200     }
201   }
202 }
203 
write_nz_map_horiz(aom_writer * w,const tran_low_t * tcoeff,uint16_t eob,int plane,const int16_t * scan,const int16_t * iscan,TX_SIZE tx_size,TX_TYPE tx_type,FRAME_CONTEXT * fc)204 static INLINE void write_nz_map_horiz(aom_writer *w, const tran_low_t *tcoeff,
205                                       uint16_t eob, int plane,
206                                       const int16_t *scan, const int16_t *iscan,
207                                       TX_SIZE tx_size, TX_TYPE tx_type,
208                                       FRAME_CONTEXT *fc) {
209   (void)scan;
210   (void)eob;
211   const TX_SIZE txs_ctx = get_txsize_context(tx_size);
212   const PLANE_TYPE plane_type = get_plane_type(plane);
213   const TX_CLASS tx_class = get_tx_class(tx_type);
214   const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
215   const int width = tx_size_wide[tx_size];
216   const int height = tx_size_high[tx_size];
217   int16_t eob_ls[MAX_HVTX_SIZE];
218   get_eob_horiz(eob_ls, tcoeff, width, height);
219 #if !LV_MAP_PROB
220   aom_prob *nz_map = fc->nz_map[txs_ctx][plane_type];
221 #endif
222   for (int r = 0; r < height; ++r) {
223     int16_t heob = eob_ls[r];
224     int el_ctx = get_empty_line_ctx(r, eob_ls);
225 #if LV_MAP_PROB
226     aom_write_bin(w, heob == 0,
227                   fc->empty_line_cdf[txs_ctx][plane_type][tx_class][el_ctx], 2);
228 #else
229     aom_write(w, heob == 0,
230               fc->empty_line[txs_ctx][plane_type][tx_class][el_ctx]);
231 #endif
232     if (heob) {
233       for (int c = 0; c < heob; ++c) {
234         if (c + 1 != width) {
235           int coeff_idx = r * width + c;
236           int scan_idx = iscan[coeff_idx];
237           int is_nz = tcoeff[coeff_idx] != 0;
238           int coeff_ctx =
239               get_nz_map_ctx(tcoeff, scan_idx, scan, bwl, height, tx_type);
240 #if LV_MAP_PROB
241           aom_write_bin(w, is_nz,
242                         fc->nz_map_cdf[txs_ctx][plane_type][coeff_ctx], 2);
243 #else
244           aom_write(w, is_nz, nz_map[coeff_ctx]);
245 #endif
246           if (is_nz) {
247             int eob_ctx = get_hv_eob_ctx(r, c, eob_ls);
248 #if LV_MAP_PROB
249             aom_write_bin(
250                 w, c == heob - 1,
251                 fc->hv_eob_cdf[txs_ctx][plane_type][tx_class][eob_ctx], 2);
252 #else
253             aom_write(w, c == heob - 1,
254                       fc->hv_eob[txs_ctx][plane_type][tx_class][eob_ctx]);
255 #endif
256           }
257         }
258       }
259     }
260   }
261 }
262 #endif
263 
// Writes the level-map coefficient data of one transform block to `w`.
//
// Symbols are emitted in this order:
//   1. The txb skip flag (eob == 0), coded with txb_ctx->txb_skip_ctx.
//      Nothing else is written for an all-zero block.
//   2. With CONFIG_TXK_SEL, the transform type.
//   3. The non-zero / end-of-block map.
//      - 2-D transform classes use write_nz_map().
//      - With CONFIG_CTX1D and a 1-D class, an eob_mode flag
//        (eob > width + height) first selects between write_nz_map() and the
//        per-line write_nz_map_vert()/write_nz_map_horiz().
//   4. Base levels. For each i in [0, NUM_BASE_LEVELS), scanning from position
//      eob - 1 down to 0, every coefficient with |v| > i gets a flag that is 1
//      when |v| == i + 1. A coefficient whose level ends there also gets its
//      sign (context coded at scan position 0, a raw bit elsewhere).
//   5. Higher levels. Coefficients with |v| > NUM_BASE_LEVELS, from position
//      update_eob down to 0, get: the sign, then the base-range symbols, then
//      an order-0 Exp-Golomb remainder (write_golomb()) for levels beyond the
//      base range.
//
// `tcoeff` holds the block's coefficients in raster order; it is addressed
// through the scan order selected for (tx_size, tx_type). `eob` is the number
// of scan positions that are coded.
void av1_write_coeffs_txb(const AV1_COMMON *const cm, MACROBLOCKD *xd,
                          aom_writer *w, int blk_row, int blk_col, int block,
                          int plane, TX_SIZE tx_size, const tran_low_t *tcoeff,
                          uint16_t eob, TXB_CTX *txb_ctx) {
  MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
  const PLANE_TYPE plane_type = get_plane_type(plane);
  const TX_SIZE txs_ctx = get_txsize_context(tx_size);
  const TX_TYPE tx_type =
      av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
  const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
  const int16_t *scan = scan_order->scan;
  int c;
  const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
  const int height = tx_size_high[tx_size];
  // Highest scan position whose level exceeds NUM_BASE_LEVELS (see step 4).
  uint16_t update_eob = 0;
  FRAME_CONTEXT *ec_ctx = xd->tile_ctx;

  // NOTE(review): blk_row/blk_col are already passed to av1_get_tx_type()
  // above, so these casts look redundant.
  (void)blk_row;
  (void)blk_col;

  // Step 1: txb skip flag.
#if LV_MAP_PROB
  aom_write_bin(w, eob == 0,
                ec_ctx->txb_skip_cdf[txs_ctx][txb_ctx->txb_skip_ctx], 2);
#else
  aom_write(w, eob == 0, ec_ctx->txb_skip[txs_ctx][txb_ctx->txb_skip_ctx]);
#endif

  if (eob == 0) return;
  // Step 2: transform type (only when it is selected per transform block).
#if CONFIG_TXK_SEL
  av1_write_tx_type(cm, xd, blk_row, blk_col, block, plane,
                    get_min_tx_size(tx_size), w);
#endif

  // Step 3: non-zero / end-of-block map.
#if CONFIG_CTX1D
  TX_CLASS tx_class = get_tx_class(tx_type);
  if (tx_class == TX_CLASS_2D) {
    write_nz_map(w, tcoeff, eob, plane, scan, tx_size, tx_type, ec_ctx);
  } else {
    const int width = tx_size_wide[tx_size];
    const int eob_offset = width + height;
    // eob_mode 0: ordinary scan-order map (at most width + height positions).
    // eob_mode 1: per-line maps along the 1-D transform direction.
    const int eob_mode = eob > eob_offset;
#if LV_MAP_PROB
    aom_write_bin(w, eob_mode,
                  ec_ctx->eob_mode_cdf[txs_ctx][plane_type][tx_class], 2);
#else
    aom_write(w, eob_mode, ec_ctx->eob_mode[txs_ctx][plane_type][tx_class]);
#endif
    if (eob_mode == 0) {
      write_nz_map(w, tcoeff, eob, plane, scan, tx_size, tx_type, ec_ctx);
    } else {
      const int16_t *iscan = scan_order->iscan;
      assert(tx_class == TX_CLASS_VERT || tx_class == TX_CLASS_HORIZ);
      if (tx_class == TX_CLASS_VERT)
        write_nz_map_vert(w, tcoeff, eob, plane, scan, iscan, tx_size, tx_type,
                          ec_ctx);
      else
        write_nz_map_horiz(w, tcoeff, eob, plane, scan, iscan, tx_size, tx_type,
                           ec_ctx);
    }
  }
#else
  write_nz_map(w, tcoeff, eob, plane, scan, tx_size, tx_type, ec_ctx);
#endif  // CONFIG_CTX1D

  // Step 4: base levels, one pass per level, in reverse scan order.
  int i;
  for (i = 0; i < NUM_BASE_LEVELS; ++i) {
#if !LV_MAP_PROB
    aom_prob *coeff_base = ec_ctx->coeff_base[txs_ctx][plane_type][i];
#endif
    // Reset per pass; after the last pass (i == NUM_BASE_LEVELS - 1) this
    // holds the highest position with level > NUM_BASE_LEVELS, or 0.
    update_eob = 0;
    for (c = eob - 1; c >= 0; --c) {
      tran_low_t v = tcoeff[scan[c]];
      tran_low_t level = abs(v);
      int sign = (v < 0) ? 1 : 0;
      int ctx;

      // Coefficients that already terminated at a lower level are skipped.
      if (level <= i) continue;

      ctx = get_base_ctx(tcoeff, scan[c], bwl, height, i + 1);

      if (level == i + 1) {
        // Level terminates here: flag 1, then the sign.
#if LV_MAP_PROB
        aom_write_bin(w, 1, ec_ctx->coeff_base_cdf[txs_ctx][plane_type][i][ctx],
                      2);
#else
        aom_write(w, 1, coeff_base[ctx]);
#endif
        if (c == 0) {
#if LV_MAP_PROB
          aom_write_bin(w, sign,
                        ec_ctx->dc_sign_cdf[plane_type][txb_ctx->dc_sign_ctx],
                        2);
#else
          aom_write(w, sign, ec_ctx->dc_sign[plane_type][txb_ctx->dc_sign_ctx]);
#endif
        } else {
          aom_write_bit(w, sign);
        }
        continue;
      }

      // Level continues above i + 1: flag 0.
#if LV_MAP_PROB
      aom_write_bin(w, 0, ec_ctx->coeff_base_cdf[txs_ctx][plane_type][i][ctx],
                    2);
#else
      aom_write(w, 0, coeff_base[ctx]);
#endif
      update_eob = AOMMAX(update_eob, c);
    }
  }

  // Step 5: signs and base-range / Golomb coding of levels > NUM_BASE_LEVELS.
  for (c = update_eob; c >= 0; --c) {
    tran_low_t v = tcoeff[scan[c]];
    tran_low_t level = abs(v);
    int sign = (v < 0) ? 1 : 0;
    int idx;
    int ctx;

    if (level <= NUM_BASE_LEVELS) continue;

    if (c == 0) {
#if LV_MAP_PROB
      aom_write_bin(w, sign,
                    ec_ctx->dc_sign_cdf[plane_type][txb_ctx->dc_sign_ctx], 2);
#else
      aom_write(w, sign, ec_ctx->dc_sign[plane_type][txb_ctx->dc_sign_ctx]);
#endif
    } else {
      aom_write_bit(w, sign);
    }

    // level is above 1.
    ctx = get_br_ctx(tcoeff, scan[c], bwl, height);

#if BR_NODE
    // NOTE(review): this branch calls aom_write_bin() unconditionally, so it
    // appears to assume LV_MAP_PROB is enabled.
    // base_range is the offset above the last base level. It is coded as a
    // set index (one flag per set, up to the matching set), then the offset
    // within that set.
    int base_range = level - 1 - NUM_BASE_LEVELS;
    int br_set_idx = 0;
    int br_base = 0;
    int br_offset = 0;

    if (base_range >= COEFF_BASE_RANGE)
      br_set_idx = BASE_RANGE_SETS;
    else
      br_set_idx = coeff_to_br_index[base_range];

    for (idx = 0; idx < BASE_RANGE_SETS; ++idx) {
      aom_write_bin(w, idx == br_set_idx,
                    ec_ctx->coeff_br_cdf[txs_ctx][plane_type][idx][ctx], 2);
      if (idx == br_set_idx) {
        br_base = br_index_to_coeff[br_set_idx];
        br_offset = base_range - br_base;
        int extra_bits = (1 << br_extra_bits[idx]) - 1;
        // Truncated unary: zeros until br_offset, then a 1 unless br_offset
        // equals the largest value extra_bits.
        for (int tok = 0; tok < extra_bits; ++tok) {
          if (tok == br_offset) {
            aom_write_bin(w, 1, ec_ctx->coeff_lps_cdf[txs_ctx][plane_type][ctx],
                          2);
            break;
          }
          aom_write_bin(w, 0, ec_ctx->coeff_lps_cdf[txs_ctx][plane_type][ctx],
                        2);
        }
        // br_offset is coded by the truncated-unary loop above, not as a
        // br_extra_bits[idx]-bit literal.
        break;
      }
    }

    // Levels inside the base range are done; only saturated ones fall through
    // to the Golomb remainder.
    if (br_set_idx < BASE_RANGE_SETS) continue;
#else  // BR_NODE
    // Unary base range: a 0 per step, a terminating 1 unless saturated.
    for (idx = 0; idx < COEFF_BASE_RANGE; ++idx) {
      if (level == (idx + 1 + NUM_BASE_LEVELS)) {
#if LV_MAP_PROB
        aom_write_bin(w, 1, ec_ctx->coeff_lps_cdf[txs_ctx][plane_type][ctx], 2);
#else
        aom_write(w, 1, ec_ctx->coeff_lps[txs_ctx][plane_type][ctx]);
#endif
        break;
      }
#if LV_MAP_PROB
      aom_write_bin(w, 0, ec_ctx->coeff_lps_cdf[txs_ctx][plane_type][ctx], 2);
#else
      aom_write(w, 0, ec_ctx->coeff_lps[txs_ctx][plane_type][ctx]);
#endif
    }
    if (idx < COEFF_BASE_RANGE) continue;
#endif  // BR_NODE

    // use 0-th order Golomb code to handle the residual level.
    write_golomb(w, level - COEFF_BASE_RANGE - 1 - NUM_BASE_LEVELS);
  }
}
454 
// Writes every transform block of `plane` for the current block, in raster
// order of transform blocks.
// It uses the coefficients, eobs and entropy contexts previously stored in
// x->mbmi_ext. `block` advances by the transform block's area measured in
// tx_size_*_unit units.
void av1_write_coeffs_mb(const AV1_COMMON *const cm, MACROBLOCK *x,
                         aom_writer *w, int plane) {
  MACROBLOCKD *xd = &x->e_mbd;
  const BLOCK_SIZE bsize = xd->mi[0]->mbmi.sb_type;
  struct macroblockd_plane *pd = &xd->plane[plane];

#if CONFIG_CHROMA_SUB8X8
  const BLOCK_SIZE plane_bsize =
      AOMMAX(BLOCK_4X4, get_plane_block_size(bsize, pd));
#elif CONFIG_CB4X4
  const BLOCK_SIZE plane_bsize = get_plane_block_size(bsize, pd);
#else
  const BLOCK_SIZE plane_bsize =
      get_plane_block_size(AOMMAX(bsize, BLOCK_8X8), pd);
#endif
  const int blocks_wide = max_block_wide(xd, plane_bsize, plane);
  const int blocks_high = max_block_high(xd, plane_bsize, plane);
  const TX_SIZE tx_size = av1_get_tx_size(plane, xd);
  const int txb_w = tx_size_wide_unit[tx_size];
  const int txb_h = tx_size_high_unit[tx_size];
  const int txb_area = txb_w * txb_h;
  int block = 0;

  for (int blk_row = 0; blk_row < blocks_high; blk_row += txb_h) {
    for (int blk_col = 0; blk_col < blocks_wide;
         blk_col += txb_w, block += txb_area) {
      TXB_CTX txb_ctx = {
        .txb_skip_ctx = x->mbmi_ext->txb_skip_ctx[plane][block],
        .dc_sign_ctx = x->mbmi_ext->dc_sign_ctx[plane][block],
      };
      av1_write_coeffs_txb(cm, xd, w, blk_row, blk_col, block, plane, tx_size,
                           BLOCK_OFFSET(x->mbmi_ext->tcoeff[plane], block),
                           x->mbmi_ext->eobs[plane][block], &txb_ctx);
    }
  }
}
491 
get_base_ctx_set(const tran_low_t * tcoeffs,int c,const int bwl,const int height,int ctx_set[NUM_BASE_LEVELS])492 static INLINE void get_base_ctx_set(const tran_low_t *tcoeffs,
493                                     int c,  // raster order
494                                     const int bwl, const int height,
495                                     int ctx_set[NUM_BASE_LEVELS]) {
496   const int row = c >> bwl;
497   const int col = c - (row << bwl);
498   const int stride = 1 << bwl;
499   int mag[NUM_BASE_LEVELS] = { 0 };
500   int idx;
501   tran_low_t abs_coeff;
502   int i;
503 
504   for (idx = 0; idx < BASE_CONTEXT_POSITION_NUM; ++idx) {
505     int ref_row = row + base_ref_offset[idx][0];
506     int ref_col = col + base_ref_offset[idx][1];
507     int pos = (ref_row << bwl) + ref_col;
508 
509     if (ref_row < 0 || ref_col < 0 || ref_row >= height || ref_col >= stride)
510       continue;
511 
512     abs_coeff = abs(tcoeffs[pos]);
513 
514     for (i = 0; i < NUM_BASE_LEVELS; ++i) {
515       ctx_set[i] += abs_coeff > i;
516       if (base_ref_offset[idx][0] >= 0 && base_ref_offset[idx][1] >= 0)
517         mag[i] |= abs_coeff > (i + 1);
518     }
519   }
520 
521   for (i = 0; i < NUM_BASE_LEVELS; ++i) {
522     ctx_set[i] = get_base_ctx_from_count_mag(row, col, ctx_set[i], mag[i]);
523   }
524   return;
525 }
526 
get_br_cost(tran_low_t abs_qc,int ctx,const int * coeff_lps)527 static INLINE int get_br_cost(tran_low_t abs_qc, int ctx,
528                               const int *coeff_lps) {
529   const tran_low_t min_level = 1 + NUM_BASE_LEVELS;
530   const tran_low_t max_level = 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE;
531   (void)ctx;
532   if (abs_qc >= min_level) {
533 #if BR_NODE
534     if (abs_qc >= max_level)
535       return coeff_lps[COEFF_BASE_RANGE];  // COEFF_BASE_RANGE * cost0;
536     else
537       return coeff_lps[(abs_qc - min_level)];  //  * cost0 + cost1;
538 #else
539     const int cost0 = coeff_lps[0];
540     const int cost1 = coeff_lps[1];
541     if (abs_qc >= max_level)
542       return COEFF_BASE_RANGE * cost0;
543     else
544       return (abs_qc - min_level) * cost0 + cost1;
545 #endif
546   } else {
547     return 0;
548   }
549 }
550 
get_base_cost(tran_low_t abs_qc,int ctx,const int coeff_base[2],int base_idx)551 static INLINE int get_base_cost(tran_low_t abs_qc, int ctx,
552                                 const int coeff_base[2], int base_idx) {
553   const int level = base_idx + 1;
554   (void)ctx;
555   if (abs_qc < level)
556     return 0;
557   else
558     return coeff_base[abs_qc == level];
559 }
560 
// Cost counterpart of write_nz_map(). It sums, over scan positions [0, eob)
// except position seg_eob - 1:
//   - the non-zero flag cost;
//   - for non-zero positions, the end-of-block flag cost (flag == 1 only at
//     eob - 1).
// `plane` is unused; coeff_costs is already per plane type.
// NOTE(review): external linkage without an av1_ prefix; check whether other
// files rely on it before renaming.
int get_nz_eob_map_cost(const LV_MAP_COEFF_COST *coeff_costs,
                        const tran_low_t *qcoeff, uint16_t eob, int plane,
                        const int16_t *scan, TX_SIZE tx_size, TX_TYPE tx_type) {
  (void)plane;
  const TX_SIZE txs_ctx = get_txsize_context(tx_size);
  const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
  const int height = tx_size_high[tx_size];
#if CONFIG_CTX1D
  const int seg_eob = (get_tx_class(tx_type) == TX_CLASS_2D)
                          ? tx_size_2d[tx_size]
                          : tx_size_wide[tx_size] + height;
#else
  const int seg_eob = tx_size_2d[tx_size];
#endif
  const int last_pos = eob - 1;
  int total = 0;

  for (int pos = 0; pos < eob; ++pos) {
    if (pos + 1 == seg_eob) continue;
    const int is_nz = qcoeff[scan[pos]] != 0;
    const int nz_ctx = get_nz_map_ctx(qcoeff, pos, scan, bwl, height, tx_type);
    total += coeff_costs->nz_map_cost[nz_ctx][is_nz];
    if (!is_nz) continue;
    const int eob_ctx = get_eob_ctx(qcoeff, scan[pos], txs_ctx, tx_type);
    total += coeff_costs->eob_cost[eob_ctx][pos == last_pos];
  }
  return total;
}
592 
593 #if CONFIG_CTX1D
get_nz_eob_map_cost_vert(const LV_MAP_COEFF_COST * coeff_costs,const tran_low_t * qcoeff,uint16_t eob,int plane,const int16_t * scan,const int16_t * iscan,TX_SIZE tx_size,TX_TYPE tx_type)594 static INLINE int get_nz_eob_map_cost_vert(const LV_MAP_COEFF_COST *coeff_costs,
595                                            const tran_low_t *qcoeff,
596                                            uint16_t eob, int plane,
597                                            const int16_t *scan,
598                                            const int16_t *iscan,
599                                            TX_SIZE tx_size, TX_TYPE tx_type) {
600   (void)tx_size;
601   (void)scan;
602   (void)eob;
603   (void)plane;
604   const TX_CLASS tx_class = get_tx_class(tx_type);
605   const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
606   const int width = tx_size_wide[tx_size];
607   const int height = tx_size_high[tx_size];
608   int16_t eob_ls[MAX_HVTX_SIZE];
609   get_eob_vert(eob_ls, qcoeff, width, height);
610   int cost = 0;
611   for (int c = 0; c < width; ++c) {
612     int16_t veob = eob_ls[c];
613     assert(veob <= height);
614     int el_ctx = get_empty_line_ctx(c, eob_ls);
615     cost += coeff_costs->empty_line_cost[tx_class][el_ctx][veob == 0];
616     if (veob) {
617       for (int r = 0; r < veob; ++r) {
618         if (r + 1 != height) {
619           int coeff_idx = r * width + c;
620           int scan_idx = iscan[coeff_idx];
621           int is_nz = qcoeff[coeff_idx] != 0;
622           int coeff_ctx =
623               get_nz_map_ctx(qcoeff, scan_idx, scan, bwl, height, tx_type);
624           cost += coeff_costs->nz_map_cost[coeff_ctx][is_nz];
625           if (is_nz) {
626             int eob_ctx = get_hv_eob_ctx(c, r, eob_ls);
627             cost += coeff_costs->hv_eob_cost[tx_class][eob_ctx][r == veob - 1];
628           }
629         }
630       }
631     }
632   }
633   return cost;
634 }
635 
get_nz_eob_map_cost_horiz(const LV_MAP_COEFF_COST * coeff_costs,const tran_low_t * qcoeff,uint16_t eob,int plane,const int16_t * scan,const int16_t * iscan,TX_SIZE tx_size,TX_TYPE tx_type)636 static INLINE int get_nz_eob_map_cost_horiz(
637     const LV_MAP_COEFF_COST *coeff_costs, const tran_low_t *qcoeff,
638     uint16_t eob, int plane, const int16_t *scan, const int16_t *iscan,
639     TX_SIZE tx_size, TX_TYPE tx_type) {
640   (void)tx_size;
641   (void)scan;
642   (void)eob;
643   (void)plane;
644   const TX_CLASS tx_class = get_tx_class(tx_type);
645   const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
646   const int width = tx_size_wide[tx_size];
647   const int height = tx_size_high[tx_size];
648   int16_t eob_ls[MAX_HVTX_SIZE];
649   get_eob_horiz(eob_ls, qcoeff, width, height);
650   int cost = 0;
651   for (int r = 0; r < height; ++r) {
652     int16_t heob = eob_ls[r];
653     assert(heob <= width);
654     int el_ctx = get_empty_line_ctx(r, eob_ls);
655     cost += coeff_costs->empty_line_cost[tx_class][el_ctx][heob == 0];
656     if (heob) {
657       for (int c = 0; c < heob; ++c) {
658         if (c + 1 != width) {
659           int coeff_idx = r * width + c;
660           int scan_idx = iscan[coeff_idx];
661           int is_nz = qcoeff[coeff_idx] != 0;
662           int coeff_ctx =
663               get_nz_map_ctx(qcoeff, scan_idx, scan, bwl, height, tx_type);
664           cost += coeff_costs->nz_map_cost[coeff_ctx][is_nz];
665           if (is_nz) {
666             int eob_ctx = get_hv_eob_ctx(r, c, eob_ls);
667             cost += coeff_costs->hv_eob_cost[tx_class][eob_ctx][c == heob - 1];
668           }
669         }
670       }
671     }
672   }
673   return cost;
674 }
675 #endif
676 
// Estimates the rate of coding one transform block's quantized coefficients
// (p->qcoeff, with eob p->eobs[block]) using the level-map scheme. It
// mirrors the symbol sequence of av1_write_coeffs_txb().
// Costs come from the tables x->coeff_costs[txs_ctx][plane_type], plus
// av1_cost_bit(128, .) for raw sign bits and Exp-Golomb bits. When eob == 0,
// only the skip-flag cost is returned.
int av1_cost_coeffs_txb(const AV1_COMMON *const cm, MACROBLOCK *x, int plane,
                        int blk_row, int blk_col, int block, TX_SIZE tx_size,
                        TXB_CTX *txb_ctx) {
  MACROBLOCKD *const xd = &x->e_mbd;
  TX_SIZE txs_ctx = get_txsize_context(tx_size);
  const PLANE_TYPE plane_type = get_plane_type(plane);
  const TX_TYPE tx_type =
      av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
  MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
  const struct macroblock_plane *p = &x->plane[plane];
  const int eob = p->eobs[block];
  const tran_low_t *const qcoeff = BLOCK_OFFSET(p->qcoeff, block);
  int c, cost;
  int txb_skip_ctx = txb_ctx->txb_skip_ctx;

  const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
  const int height = tx_size_high[tx_size];

  const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
  const int16_t *scan = scan_order->scan;

  LV_MAP_COEFF_COST *coeff_costs = &x->coeff_costs[txs_ctx][plane_type];

  cost = 0;

  // All-zero block: only the txb skip flag (value 1) is coded.
  if (eob == 0) {
    cost = coeff_costs->txb_skip_cost[txb_skip_ctx][1];
    return cost;
  }
  cost = coeff_costs->txb_skip_cost[txb_skip_ctx][0];

#if CONFIG_TXK_SEL
  cost += av1_tx_type_cost(cm, x, xd, mbmi->sb_type, plane, tx_size, tx_type);
#endif

  // Non-zero / end-of-block map, with the same eob_mode split as the writer.
#if CONFIG_CTX1D
  TX_CLASS tx_class = get_tx_class(tx_type);
  if (tx_class == TX_CLASS_2D) {
    cost += get_nz_eob_map_cost(coeff_costs, qcoeff, eob, plane, scan, tx_size,
                                tx_type);
  } else {
    const int width = tx_size_wide[tx_size];
    const int eob_offset = width + height;
    const int eob_mode = eob > eob_offset;
    cost += coeff_costs->eob_mode_cost[tx_class][eob_mode];
    if (eob_mode == 0) {
      cost += get_nz_eob_map_cost(coeff_costs, qcoeff, eob, plane, scan,
                                  tx_size, tx_type);
    } else {
      const int16_t *iscan = scan_order->iscan;
      assert(tx_class == TX_CLASS_VERT || tx_class == TX_CLASS_HORIZ);
      if (tx_class == TX_CLASS_VERT)
        cost += get_nz_eob_map_cost_vert(coeff_costs, qcoeff, eob, plane, scan,
                                         iscan, tx_size, tx_type);
      else
        cost += get_nz_eob_map_cost_horiz(coeff_costs, qcoeff, eob, plane, scan,
                                          iscan, tx_size, tx_type);
    }
  }
#else   // CONFIG_CTX1D
  cost += get_nz_eob_map_cost(coeff_costs, qcoeff, eob, plane, scan, tx_size,
                              tx_type);
#endif  // CONFIG_CTX1D

  // Per non-zero coefficient: sign, base-level flags, base range, and
  // Golomb remainder. The summation order differs from the writer's passes,
  // but each coded symbol is counted once.
  for (c = 0; c < eob; ++c) {
    tran_low_t v = qcoeff[scan[c]];
    int is_nz = (v != 0);
    int level = abs(v);

    if (is_nz) {
      int ctx_ls[NUM_BASE_LEVELS] = { 0 };
      int sign = (v < 0) ? 1 : 0;

      // sign bit cost
      // Scan position 0 uses the DC sign context; others are a flat raw bit.
      if (c == 0) {
        int dc_sign_ctx = txb_ctx->dc_sign_ctx;
        cost += coeff_costs->dc_sign_cost[dc_sign_ctx][sign];
      } else {
        cost += av1_cost_bit(128, sign);
      }

      get_base_ctx_set(qcoeff, scan[c], bwl, height, ctx_ls);

      // Base-level flags: [1] where the level stops, [0] where it continues.
      int i;
      for (i = 0; i < NUM_BASE_LEVELS; ++i) {
        if (level <= i) continue;

        if (level == i + 1) {
          cost += coeff_costs->base_cost[i][ctx_ls[i]][1];
          continue;
        }
        cost += coeff_costs->base_cost[i][ctx_ls[i]][0];
      }

      if (level > NUM_BASE_LEVELS) {
        int ctx;
        ctx = get_br_ctx(qcoeff, scan[c], bwl, height);
#if BR_NODE
        // lps_cost[ctx][k] appears to hold the whole base-range cost for
        // offset k, saturated at COEFF_BASE_RANGE. TODO confirm against the
        // code that fills the table.
        int base_range = level - 1 - NUM_BASE_LEVELS;
        if (base_range < COEFF_BASE_RANGE) {
          cost += coeff_costs->lps_cost[ctx][base_range];
        } else {
          cost += coeff_costs->lps_cost[ctx][COEFF_BASE_RANGE];
        }

#else
        // Unary base range, matching the writer's non-BR_NODE loop.
        for (int idx = 0; idx < COEFF_BASE_RANGE; ++idx) {
          if (level == (idx + 1 + NUM_BASE_LEVELS)) {
            cost += coeff_costs->lps_cost[ctx][1];
            break;
          }
          cost += coeff_costs->lps_cost[ctx][0];
        }
#endif
        if (level >= 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE) {
          // residual cost
          // Order-0 Exp-Golomb of r (the value write_golomb() encodes as
          // level + 1): length - 1 zero bits, then length bits of r.
          int r = level - COEFF_BASE_RANGE - NUM_BASE_LEVELS;
          int ri = r;
          int length = 0;

          while (ri) {
            ri >>= 1;
            ++length;
          }

          for (ri = 0; ri < length - 1; ++ri) cost += av1_cost_bit(128, 0);

          for (ri = length - 1; ri >= 0; --ri)
            cost += av1_cost_bit(128, (r >> ri) & 0x01);
        }
      }
    }
  }

  return cost;
}
813 
has_base(tran_low_t qc,int base_idx)814 static INLINE int has_base(tran_low_t qc, int base_idx) {
815   const int level = base_idx + 1;
816   return abs(qc) >= level;
817 }
818 
has_br(tran_low_t qc)819 static INLINE int has_br(tran_low_t qc) {
820   return abs(qc) >= 1 + NUM_BASE_LEVELS;
821 }
822 
get_sign_bit_cost(tran_low_t qc,int coeff_idx,const int (* dc_sign_cost)[2],int dc_sign_ctx)823 static INLINE int get_sign_bit_cost(tran_low_t qc, int coeff_idx,
824                                     const int (*dc_sign_cost)[2],
825                                     int dc_sign_ctx) {
826   const int sign = (qc < 0) ? 1 : 0;
827   // sign bit cost
828   if (coeff_idx == 0) {
829     return dc_sign_cost[dc_sign_ctx][sign];
830   } else {
831     return av1_cost_bit(128, sign);
832   }
833 }
get_golomb_cost(int abs_qc)834 static INLINE int get_golomb_cost(int abs_qc) {
835   if (abs_qc >= 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE) {
836     // residual cost
837     int r = abs_qc - COEFF_BASE_RANGE - NUM_BASE_LEVELS;
838     int ri = r;
839     int length = 0;
840 
841     while (ri) {
842       ri >>= 1;
843       ++length;
844     }
845 
846     return av1_cost_literal(2 * length - 1);
847   } else {
848     return 0;
849   }
850 }
851 
/* Builds the per-coefficient context cache for scan positions [0, eob).
 *
 * For every such position it stores the nonzero-neighbor count and the
 * nonzero-map context. For coefficients with |qc| >= 1 it additionally stores,
 * for each base level the coefficient reaches, the neighbor count and base
 * context (plus the shared base magnitude array). For coefficients with
 * |qc| > NUM_BASE_LEVELS it also stores the BR count, magnitude array and BR
 * context. Entries not listed above are not written by this function.
 */
void gen_txb_cache(TxbCache *txb_cache, TxbInfo *txb_info) {
  // gen_nz_count_arr
  const int16_t *scan = txb_info->scan_order->scan;
  const int bwl = txb_info->bwl;
  const int height = txb_info->height;
  tran_low_t *qcoeff = txb_info->qcoeff;
  const BASE_CTX_TABLE *base_ctx_table =
      txb_info->coeff_ctx_table->base_ctx_table;
  for (int c = 0; c < txb_info->eob; ++c) {
    const int coeff_idx = scan[c];  // raster order
    const int row = coeff_idx >> bwl;
    const int col = coeff_idx - (row << bwl);
#if REDUCE_CONTEXT_DEPENDENCY
    // Past MIN_SCAN_IDX_REDUCE_CONTEXT_DEPENDENCY the previous scan position is
    // passed to get_nz_count(); before it, (-1, -1) is passed instead.
    int prev_coeff_idx;
    int prev_row;
    int prev_col;
    if (c > MIN_SCAN_IDX_REDUCE_CONTEXT_DEPENDENCY) {
      prev_coeff_idx = scan[c - 1];  // raster order
      prev_row = prev_coeff_idx >> bwl;
      prev_col = prev_coeff_idx - (prev_row << bwl);
    } else {
      prev_coeff_idx = -1;
      prev_row = -1;
      prev_col = -1;
    }
    txb_cache->nz_count_arr[coeff_idx] =
        get_nz_count(qcoeff, bwl, height, row, col, prev_row, prev_col);
#else
    txb_cache->nz_count_arr[coeff_idx] =
        get_nz_count(qcoeff, bwl, height, row, col);
#endif
    const int nz_count = txb_cache->nz_count_arr[coeff_idx];
    txb_cache->nz_ctx_arr[coeff_idx] =
        get_nz_map_ctx_from_count(nz_count, coeff_idx, bwl, txb_info->tx_type);

    // gen_base_count_mag_arr
    // Zero coefficients code no base symbols, so their base/BR cache entries
    // are left untouched.
    if (!has_base(qcoeff[coeff_idx], 0)) continue;
    int *base_mag = txb_cache->base_mag_arr[coeff_idx];
    int count[NUM_BASE_LEVELS];
    get_base_count_mag(base_mag, count, qcoeff, bwl, height, row, col);

    // Fill only the base levels this coefficient actually reaches.
    for (int i = 0; i < NUM_BASE_LEVELS; ++i) {
      if (!has_base(qcoeff[coeff_idx], i)) break;
      txb_cache->base_count_arr[i][coeff_idx] = count[i];
      const int level = i + 1;
      txb_cache->base_ctx_arr[i][coeff_idx] =
          base_ctx_table[row != 0][col != 0][base_mag[0] > level][count[i]];
    }

    // gen_br_count_mag_arr
    // BR cache entries exist only for coefficients above all base levels.
    if (!has_br(qcoeff[coeff_idx])) continue;
    int *br_count = txb_cache->br_count_arr + coeff_idx;
    int *br_mag = txb_cache->br_mag_arr[coeff_idx];
    *br_count = get_br_count_mag(br_mag, qcoeff, bwl, height, row, col,
                                 NUM_BASE_LEVELS);
    txb_cache->br_ctx_arr[coeff_idx] =
        get_br_ctx_from_count_mag(row, col, *br_count, br_mag[0]);
  }
}
911 
get_level_prob(int level,int coeff_idx,const TxbCache * txb_cache,const LV_MAP_COEFF_COST * txb_costs)912 static INLINE const int *get_level_prob(int level, int coeff_idx,
913                                         const TxbCache *txb_cache,
914                                         const LV_MAP_COEFF_COST *txb_costs) {
915   if (level == 0) {
916     const int ctx = txb_cache->nz_ctx_arr[coeff_idx];
917     return txb_costs->nz_map_cost[ctx];
918   } else if (level >= 1 && level < 1 + NUM_BASE_LEVELS) {
919     const int idx = level - 1;
920     const int ctx = txb_cache->base_ctx_arr[idx][coeff_idx];
921     return txb_costs->base_cost[idx][ctx];
922   } else if (level >= 1 + NUM_BASE_LEVELS &&
923              level < 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE) {
924     const int ctx = txb_cache->br_ctx_arr[coeff_idx];
925     return txb_costs->lps_cost[ctx];
926   } else if (level >= 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE) {
927     printf("get_level_prob does not support golomb\n");
928     assert(0);
929     return 0;
930   } else {
931     assert(0);
932     return 0;
933   }
934 }
935 
get_lower_coeff(tran_low_t qc)936 static INLINE tran_low_t get_lower_coeff(tran_low_t qc) {
937   if (qc == 0) {
938     return 0;
939   }
940   return qc > 0 ? qc - 1 : qc + 1;
941 }
942 
update_mag_arr(int * mag_arr,int abs_qc)943 static INLINE void update_mag_arr(int *mag_arr, int abs_qc) {
944   if (mag_arr[0] == abs_qc) {
945     mag_arr[1] -= 1;
946     assert(mag_arr[1] >= 0);
947   }
948 }
949 
get_mag_from_mag_arr(const int * mag_arr)950 static INLINE int get_mag_from_mag_arr(const int *mag_arr) {
951   int mag;
952   if (mag_arr[1] > 0) {
953     mag = mag_arr[0];
954   } else if (mag_arr[0] > 0) {
955     mag = mag_arr[0] - 1;
956   } else {
957     // no neighbor
958     assert(mag_arr[0] == 0 && mag_arr[1] == 0);
959     mag = 0;
960   }
961   return mag;
962 }
963 
/* Predicts the context inputs of coeff_idx after its neighbor nb_coeff_idx
 * (current magnitude abs_nb_coeff) is lowered by one.
 *
 * count:  cached number of neighbors at or above `level`.
 * mag:    cached {max magnitude, holder count} pair for coeff_idx.
 * new_count / new_mag: receive the predicted values.
 * Returns 1 if either predicted value differs from the cached one, else 0.
 */
static int neighbor_level_down_update(int *new_count, int *new_mag, int count,
                                      const int *mag, int coeff_idx,
                                      tran_low_t abs_nb_coeff, int nb_coeff_idx,
                                      int level, const TxbInfo *txb_info) {
  const int bwl = txb_info->bwl;
  const int row = coeff_idx >> bwl;
  const int col = coeff_idx - (row << bwl);
  const int nb_row = nb_coeff_idx >> bwl;
  const int nb_col = nb_coeff_idx - (nb_row << bwl);
  int changed = 0;

  *new_count = count;
  *new_mag = get_mag_from_mag_arr(mag);

  // The neighbor drops below `level`, so it no longer counts.
  if (abs_nb_coeff == level) {
    --*new_count;
    assert(*new_count >= 0);
    changed = 1;
  }

  // Magnitudes are only tracked for neighbors at or below-right of coeff_idx.
  const int in_mag_region = nb_row >= row && nb_col >= col;
  if (in_mag_region && abs_nb_coeff == mag[0]) {
    assert(mag[1] > 0);
    // Only when the neighbor is the sole holder of the maximum does the
    // effective magnitude drop.
    if (mag[1] == 1) {
      --*new_mag;
      assert(*new_mag >= 0);
      changed = 1;
    }
  }
  return changed;
}
997 
try_neighbor_level_down_br(int coeff_idx,int nb_coeff_idx,const TxbCache * txb_cache,const LV_MAP_COEFF_COST * txb_costs,const TxbInfo * txb_info)998 static int try_neighbor_level_down_br(int coeff_idx, int nb_coeff_idx,
999                                       const TxbCache *txb_cache,
1000                                       const LV_MAP_COEFF_COST *txb_costs,
1001                                       const TxbInfo *txb_info) {
1002   const tran_low_t qc = txb_info->qcoeff[coeff_idx];
1003   const tran_low_t abs_qc = abs(qc);
1004   const int level = NUM_BASE_LEVELS + 1;
1005   if (abs_qc < level) return 0;
1006 
1007   const tran_low_t nb_coeff = txb_info->qcoeff[nb_coeff_idx];
1008   const tran_low_t abs_nb_coeff = abs(nb_coeff);
1009   const int count = txb_cache->br_count_arr[coeff_idx];
1010   const int *mag = txb_cache->br_mag_arr[coeff_idx];
1011   int new_count;
1012   int new_mag;
1013   const int update =
1014       neighbor_level_down_update(&new_count, &new_mag, count, mag, coeff_idx,
1015                                  abs_nb_coeff, nb_coeff_idx, level, txb_info);
1016   if (update) {
1017     const int row = coeff_idx >> txb_info->bwl;
1018     const int col = coeff_idx - (row << txb_info->bwl);
1019     const int ctx = txb_cache->br_ctx_arr[coeff_idx];
1020     const int org_cost = get_br_cost(abs_qc, ctx, txb_costs->lps_cost[ctx]);
1021 
1022     const int new_ctx = get_br_ctx_from_count_mag(row, col, new_count, new_mag);
1023     const int new_cost =
1024         get_br_cost(abs_qc, new_ctx, txb_costs->lps_cost[new_ctx]);
1025     const int cost_diff = -org_cost + new_cost;
1026     return cost_diff;
1027   } else {
1028     return 0;
1029   }
1030 }
1031 
try_neighbor_level_down_base(int coeff_idx,int nb_coeff_idx,const TxbCache * txb_cache,const LV_MAP_COEFF_COST * txb_costs,const TxbInfo * txb_info)1032 static int try_neighbor_level_down_base(int coeff_idx, int nb_coeff_idx,
1033                                         const TxbCache *txb_cache,
1034                                         const LV_MAP_COEFF_COST *txb_costs,
1035                                         const TxbInfo *txb_info) {
1036   const tran_low_t qc = txb_info->qcoeff[coeff_idx];
1037   const tran_low_t abs_qc = abs(qc);
1038   const BASE_CTX_TABLE *base_ctx_table =
1039       txb_info->coeff_ctx_table->base_ctx_table;
1040 
1041   int cost_diff = 0;
1042   for (int base_idx = 0; base_idx < NUM_BASE_LEVELS; ++base_idx) {
1043     const int level = base_idx + 1;
1044     if (abs_qc < level) continue;
1045 
1046     const tran_low_t nb_coeff = txb_info->qcoeff[nb_coeff_idx];
1047     const tran_low_t abs_nb_coeff = abs(nb_coeff);
1048 
1049     const int count = txb_cache->base_count_arr[base_idx][coeff_idx];
1050     const int *mag = txb_cache->base_mag_arr[coeff_idx];
1051     int new_count;
1052     int new_mag;
1053     const int update =
1054         neighbor_level_down_update(&new_count, &new_mag, count, mag, coeff_idx,
1055                                    abs_nb_coeff, nb_coeff_idx, level, txb_info);
1056     if (update) {
1057       const int row = coeff_idx >> txb_info->bwl;
1058       const int col = coeff_idx - (row << txb_info->bwl);
1059       const int ctx = txb_cache->base_ctx_arr[base_idx][coeff_idx];
1060       const int org_cost = get_base_cost(
1061           abs_qc, ctx, txb_costs->base_cost[base_idx][ctx], base_idx);
1062 
1063       const int new_ctx =
1064           base_ctx_table[row != 0][col != 0][new_mag > level][new_count];
1065       const int new_cost = get_base_cost(
1066           abs_qc, new_ctx, txb_costs->base_cost[base_idx][new_ctx], base_idx);
1067       cost_diff += -org_cost + new_cost;
1068     }
1069   }
1070   return cost_diff;
1071 }
1072 
/* Rate change of coeff_idx's nonzero-map symbol if its neighbor nb_coeff_idx
 * (currently +/-1) were lowered to zero (new cost minus old cost).
 *
 * Returns 0 when:
 *  - the neighbor is not +/-1,
 *  - coeff_idx lies at or beyond seg_eob (no nz-map symbol is coded there,
 *    the same `< seg_eob` rule used by get_low_coeff_cost), or
 *  - the neighbor is not earlier in scan order than coeff_idx (this cost
 *    model only credits earlier neighbors, as update_level_down does).
 *
 * txb_info is only read; the parameter stays non-const for interface
 * compatibility.
 */
static int try_neighbor_level_down_nz(int coeff_idx, int nb_coeff_idx,
                                      const TxbCache *txb_cache,
                                      const LV_MAP_COEFF_COST *txb_costs,
                                      TxbInfo *txb_info) {
  // assume eob doesn't change
  const tran_low_t abs_qc = abs(txb_info->qcoeff[coeff_idx]);
  const tran_low_t abs_nb_coeff = abs(txb_info->qcoeff[nb_coeff_idx]);
  if (abs_nb_coeff != 1) return 0;

  const int16_t *iscan = txb_info->scan_order->iscan;
  const int scan_idx = iscan[coeff_idx];
  // Use >= (not ==): every position at or past seg_eob codes no nz symbol.
  if (scan_idx >= txb_info->seg_eob) return 0;

  const int nb_scan_idx = iscan[nb_coeff_idx];
  if (nb_scan_idx >= scan_idx) return 0;

  const int count = txb_cache->nz_count_arr[coeff_idx];
  assert(count > 0);
  // get_nz_map_ctx_from_count() derives the context from the count alone, so
  // there is no need to temporarily modify the neighbor's qcoeff.
  const int new_ctx = get_nz_map_ctx_from_count(
      count - 1, coeff_idx, txb_info->bwl, txb_info->tx_type);
  const int ctx = txb_cache->nz_ctx_arr[coeff_idx];
  const int is_nz = abs_qc > 0;
  const int org_cost = txb_costs->nz_map_cost[ctx][is_nz];
  const int new_cost = txb_costs->nz_map_cost[new_ctx][is_nz];
  return new_cost - org_cost;
}
1104 
/* Rate change of lowering the magnitude of the coefficient at coeff_idx by
 * one, counting only symbols that coefficient codes itself: nonzero flag,
 * base flags, BR symbols, Golomb residual and, when it would become zero, its
 * eob flag and sign. Context effects on neighboring coefficients are not
 * included.
 *
 * low_coeff: receives the lowered value (0 if the coefficient is already 0).
 * Returns new cost minus old cost (0 for a zero coefficient). txb_info is only
 * read here.
 */
static int try_self_level_down(tran_low_t *low_coeff, int coeff_idx,
                               const TxbCache *txb_cache,
                               const LV_MAP_COEFF_COST *txb_costs,
                               TxbInfo *txb_info) {
  const tran_low_t qc = txb_info->qcoeff[coeff_idx];
  if (qc == 0) {
    *low_coeff = 0;
    return 0;
  }
  const tran_low_t abs_qc = abs(qc);
  *low_coeff = get_lower_coeff(qc);
  int cost_diff;
  if (*low_coeff == 0) {
    // |qc| == 1 -> 0: the coefficient stops being nonzero.
    const int scan_idx = txb_info->scan_order->iscan[coeff_idx];
    const int *level_cost =
        get_level_prob(abs_qc, coeff_idx, txb_cache, txb_costs);
    const int *low_level_cost =
        get_level_prob(abs(*low_coeff), coeff_idx, txb_cache, txb_costs);
    if (scan_idx < txb_info->seg_eob) {
      // When level-0, we code the binary of abs_qc > level
      // but when level-k k > 0 we code the binary of abs_qc == level
      // That's why we need this special treatment for level-0 map
      // TODO(angiebird): make level-0 consistent to other levels
      cost_diff = -level_cost[1] + low_level_cost[0] - low_level_cost[1];
    } else {
      // No nz-map symbol is coded at this position; only the first base flag
      // term goes away here (sign is handled below).
      cost_diff = -level_cost[1];
    }

    if (scan_idx < txb_info->seg_eob) {
      // The eob flag coded for this nonzero coefficient is no longer sent.
      const int eob_ctx = get_eob_ctx(txb_info->qcoeff, coeff_idx,
                                      txb_info->txs_ctx, txb_info->tx_type);
      cost_diff -=
          txb_costs->eob_cost[eob_ctx][scan_idx == (txb_info->eob - 1)];
    }

    // A zero coefficient carries no sign.
    const int sign_cost = get_sign_bit_cost(
        qc, coeff_idx, txb_costs->dc_sign_cost, txb_info->txb_ctx->dc_sign_ctx);
    cost_diff -= sign_cost;
  } else if (abs_qc <= NUM_BASE_LEVELS) {
    // Both magnitudes end inside the base levels: the terminating base flag
    // (coded 1) at index abs_qc - 1 disappears, and the flag at index
    // abs_qc - 2 switches from 0 to 1.
    const int *level_cost =
        get_level_prob(abs_qc, coeff_idx, txb_cache, txb_costs);
    const int *low_level_cost =
        get_level_prob(abs(*low_coeff), coeff_idx, txb_cache, txb_costs);
    cost_diff = -level_cost[1] + low_level_cost[1] - low_level_cost[0];
  } else if (abs_qc == NUM_BASE_LEVELS + 1) {
    // Leaving the BR range: the first BR symbol disappears and the last base
    // flag switches from 0 to 1.
    const int *level_cost =
        get_level_prob(abs_qc, coeff_idx, txb_cache, txb_costs);
    const int *low_level_cost =
        get_level_prob(abs(*low_coeff), coeff_idx, txb_cache, txb_costs);
#if BR_NODE
    cost_diff = -level_cost[0] + low_level_cost[1] - low_level_cost[0];
#else
    cost_diff = -level_cost[1] + low_level_cost[1] - low_level_cost[0];
#endif
  } else if (abs_qc < 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE) {
    // Both magnitudes inside the BR range (same BR context): only the BR
    // symbol cost changes.
    const int *level_cost =
        get_level_prob(abs_qc, coeff_idx, txb_cache, txb_costs);
    const int *low_level_cost =
        get_level_prob(abs(*low_coeff), coeff_idx, txb_cache, txb_costs);

#if BR_NODE
    cost_diff = -level_cost[abs_qc - 1 - NUM_BASE_LEVELS] +
                low_level_cost[abs(*low_coeff) - 1 - NUM_BASE_LEVELS];
#else
    cost_diff = -level_cost[1] + low_level_cost[1] - low_level_cost[0];
#endif
  } else if (abs_qc == 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE) {
    // Leaving the Golomb range: the residual disappears and the top BR
    // symbol changes to the terminating value.
    const int *low_level_cost =
        get_level_prob(abs(*low_coeff), coeff_idx, txb_cache, txb_costs);
#if BR_NODE
    cost_diff = -get_golomb_cost(abs_qc) - low_level_cost[COEFF_BASE_RANGE] +
                low_level_cost[COEFF_BASE_RANGE - 1];
#else
    cost_diff =
        -get_golomb_cost(abs_qc) + low_level_cost[1] - low_level_cost[0];
#endif
  } else {
    // Both magnitudes in the Golomb range: only the residual cost changes.
    assert(abs_qc > 1 + NUM_BASE_LEVELS + COEFF_BASE_RANGE);
    const tran_low_t abs_low_coeff = abs(*low_coeff);
    cost_diff = -get_golomb_cost(abs_qc) + get_golomb_cost(abs_low_coeff);
  }
  return cost_diff;
}
1188 
// try_level_down() can report per-position cost deltas in a
// COST_MAP_SIZE x COST_MAP_SIZE grid. The lowered coefficient's own delta is
// stored at [COST_MAP_OFFSET][COST_MAP_OFFSET]; a neighbor at
// (row + dr, col + dc) accumulates into
// [dr + COST_MAP_OFFSET][dc + COST_MAP_OFFSET].
// NOTE(review): this assumes every neighbor offset table used there stays
// within +/-COST_MAP_OFFSET -- confirm whenever those tables change.
#define COST_MAP_SIZE 5
#define COST_MAP_OFFSET 2
1191 
check_nz_neighbor(tran_low_t qc)1192 static INLINE int check_nz_neighbor(tran_low_t qc) { return abs(qc) == 1; }
1193 
check_base_neighbor(tran_low_t qc)1194 static INLINE int check_base_neighbor(tran_low_t qc) {
1195   return abs(qc) <= 1 + NUM_BASE_LEVELS;
1196 }
1197 
check_br_neighbor(tran_low_t qc)1198 static INLINE int check_br_neighbor(tran_low_t qc) {
1199   return abs(qc) > BR_MAG_OFFSET;
1200 }
1201 
// When FAST_OPTIMIZE_TXB is enabled, try_level_down() honors its fast_mode
// argument by visiting the reduced neighbor tables below instead of the full
// sig/base/br offset tables.
#define FAST_OPTIMIZE_TXB 1

#if FAST_OPTIMIZE_TXB
// (row, col) offsets used in fast mode for nonzero-map neighbors.
// try_level_down() subtracts them from the lowered coefficient's position, so
// these select the coefficients one row below and one column to the right.
// The tables are non-const because try_level_down() stores them in
// non-const int (*)[2] pointers.
#define ALNB_REF_OFFSET_NUM 2
static int alnb_ref_offset[ALNB_REF_OFFSET_NUM][2] = {
  { -1, 0 }, { 0, -1 },
};
// (row, col) offsets used in fast mode for base and BR neighbors: the four
// directly adjacent positions.
#define NB_REF_OFFSET_NUM 4
static int nb_ref_offset[NB_REF_OFFSET_NUM][2] = {
  { -1, 0 }, { 0, -1 }, { 1, 0 }, { 0, 1 },
};
#endif  // FAST_OPTIMIZE_TXB
1214 
// TODO(angiebird): add static to this function once it's called
/* Estimates the total rate change of lowering |qcoeff[coeff_idx]| by one.
 *
 * The result is the sum of:
 *  - the coefficient's own symbol cost change (try_self_level_down), counted
 *    only when its scan index is < eob, and
 *  - cost changes of in-block neighbors with scan index < eob whose
 *    nonzero-map, base or BR contexts are affected by this coefficient.
 *
 * cost_map: optional (may be NULL). When given, its COST_MAP_SIZE rows are
 *           cleared, the self delta is stored at the center and each
 *           neighbor's delta is accumulated at its relative position.
 * fast_mode: with FAST_OPTIMIZE_TXB, selects the reduced neighbor tables
 *            (alnb_ref_offset / nb_ref_offset); unused otherwise.
 * Returns 0 for a zero coefficient. txb_info->qcoeff is left as found.
 */
int try_level_down(int coeff_idx, const TxbCache *txb_cache,
                   const LV_MAP_COEFF_COST *txb_costs, TxbInfo *txb_info,
                   int (*cost_map)[COST_MAP_SIZE], int fast_mode) {
#if !FAST_OPTIMIZE_TXB
  (void)fast_mode;
#endif
  if (cost_map) {
    for (int i = 0; i < COST_MAP_SIZE; ++i) av1_zero(cost_map[i]);
  }

  tran_low_t qc = txb_info->qcoeff[coeff_idx];
  tran_low_t low_coeff;
  if (qc == 0) return 0;
  int accu_cost_diff = 0;

  const int16_t *iscan = txb_info->scan_order->iscan;
  const int eob = txb_info->eob;
  const int scan_idx = iscan[coeff_idx];
  // The coefficient's own symbols are only coded inside [0, eob).
  if (scan_idx < eob) {
    const int cost_diff = try_self_level_down(&low_coeff, coeff_idx, txb_cache,
                                              txb_costs, txb_info);
    if (cost_map)
      cost_map[0 + COST_MAP_OFFSET][0 + COST_MAP_OFFSET] = cost_diff;
    accu_cost_diff += cost_diff;
  }

  const int row = coeff_idx >> txb_info->bwl;
  const int col = coeff_idx - (row << txb_info->bwl);
  // Nonzero-map neighbors: only relevant when this coefficient becomes zero.
  if (check_nz_neighbor(qc)) {
#if FAST_OPTIMIZE_TXB
    int(*ref_offset)[2];
    int ref_num;
    if (fast_mode) {
      ref_offset = alnb_ref_offset;
      ref_num = ALNB_REF_OFFSET_NUM;
    } else {
      ref_offset = sig_ref_offset;
      ref_num = SIG_REF_OFFSET_NUM;
    }
#else
    int(*ref_offset)[2] = sig_ref_offset;
    const int ref_num = SIG_REF_OFFSET_NUM;
#endif
    for (int i = 0; i < ref_num; ++i) {
      // Offsets are subtracted from (row, col); presumably the context
      // helpers add them, so these are the positions whose context reaches
      // (row, col) -- TODO confirm against get_nz_count().
      const int nb_row = row - ref_offset[i][0];
      const int nb_col = col - ref_offset[i][1];
      const int nb_coeff_idx = nb_row * txb_info->stride + nb_col;

      // Skip positions outside the transform block (index is unused then).
      if (nb_row < 0 || nb_col < 0 || nb_row >= txb_info->height ||
          nb_col >= txb_info->stride)
        continue;

      const int nb_scan_idx = iscan[nb_coeff_idx];
      if (nb_scan_idx < eob) {
        // Argument order: the neighbor is the coefficient whose cost changes,
        // coeff_idx is the one being lowered.
        const int cost_diff = try_neighbor_level_down_nz(
            nb_coeff_idx, coeff_idx, txb_cache, txb_costs, txb_info);
        if (cost_map)
          cost_map[nb_row - row + COST_MAP_OFFSET]
                  [nb_col - col + COST_MAP_OFFSET] += cost_diff;
        accu_cost_diff += cost_diff;
      }
    }
  }

  // Base-level neighbors.
  if (check_base_neighbor(qc)) {
#if FAST_OPTIMIZE_TXB
    int(*ref_offset)[2];
    int ref_num;
    if (fast_mode) {
      ref_offset = nb_ref_offset;
      ref_num = NB_REF_OFFSET_NUM;
    } else {
      ref_offset = base_ref_offset;
      ref_num = BASE_CONTEXT_POSITION_NUM;
    }
#else
    int(*ref_offset)[2] = base_ref_offset;
    int ref_num = BASE_CONTEXT_POSITION_NUM;
#endif
    for (int i = 0; i < ref_num; ++i) {
      const int nb_row = row - ref_offset[i][0];
      const int nb_col = col - ref_offset[i][1];
      const int nb_coeff_idx = nb_row * txb_info->stride + nb_col;

      if (nb_row < 0 || nb_col < 0 || nb_row >= txb_info->height ||
          nb_col >= txb_info->stride)
        continue;

      const int nb_scan_idx = iscan[nb_coeff_idx];
      if (nb_scan_idx < eob) {
        const int cost_diff = try_neighbor_level_down_base(
            nb_coeff_idx, coeff_idx, txb_cache, txb_costs, txb_info);
        if (cost_map)
          cost_map[nb_row - row + COST_MAP_OFFSET]
                  [nb_col - col + COST_MAP_OFFSET] += cost_diff;
        accu_cost_diff += cost_diff;
      }
    }
  }

  // BR (base-range) neighbors.
  if (check_br_neighbor(qc)) {
#if FAST_OPTIMIZE_TXB
    int(*ref_offset)[2];
    int ref_num;
    if (fast_mode) {
      ref_offset = nb_ref_offset;
      ref_num = NB_REF_OFFSET_NUM;
    } else {
      ref_offset = br_ref_offset;
      ref_num = BR_CONTEXT_POSITION_NUM;
    }
#else
    int(*ref_offset)[2] = br_ref_offset;
    const int ref_num = BR_CONTEXT_POSITION_NUM;
#endif
    for (int i = 0; i < ref_num; ++i) {
      const int nb_row = row - ref_offset[i][0];
      const int nb_col = col - ref_offset[i][1];
      const int nb_coeff_idx = nb_row * txb_info->stride + nb_col;

      if (nb_row < 0 || nb_col < 0 || nb_row >= txb_info->height ||
          nb_col >= txb_info->stride)
        continue;

      const int nb_scan_idx = iscan[nb_coeff_idx];
      if (nb_scan_idx < eob) {
        const int cost_diff = try_neighbor_level_down_br(
            nb_coeff_idx, coeff_idx, txb_cache, txb_costs, txb_info);
        if (cost_map)
          cost_map[nb_row - row + COST_MAP_OFFSET]
                  [nb_col - col + COST_MAP_OFFSET] += cost_diff;
        accu_cost_diff += cost_diff;
      }
    }
  }

  return accu_cost_diff;
}
1354 
get_low_coeff_cost(int coeff_idx,const TxbCache * txb_cache,const LV_MAP_COEFF_COST * txb_costs,const TxbInfo * txb_info)1355 static int get_low_coeff_cost(int coeff_idx, const TxbCache *txb_cache,
1356                               const LV_MAP_COEFF_COST *txb_costs,
1357                               const TxbInfo *txb_info) {
1358   const tran_low_t qc = txb_info->qcoeff[coeff_idx];
1359   const int abs_qc = abs(qc);
1360   assert(abs_qc <= 1);
1361   int cost = 0;
1362   const int scan_idx = txb_info->scan_order->iscan[coeff_idx];
1363   if (scan_idx < txb_info->seg_eob) {
1364     const int *level_cost = get_level_prob(0, coeff_idx, txb_cache, txb_costs);
1365     cost += level_cost[qc != 0];
1366   }
1367 
1368   if (qc != 0) {
1369     const int base_idx = 0;
1370     const int ctx = txb_cache->base_ctx_arr[base_idx][coeff_idx];
1371     cost += get_base_cost(abs_qc, ctx, txb_costs->base_cost[base_idx][ctx],
1372                           base_idx);
1373     if (scan_idx < txb_info->seg_eob) {
1374       const int eob_ctx = get_eob_ctx(txb_info->qcoeff, coeff_idx,
1375                                       txb_info->txs_ctx, txb_info->tx_type);
1376       cost += txb_costs->eob_cost[eob_ctx][scan_idx == (txb_info->eob - 1)];
1377     }
1378     cost += get_sign_bit_cost(qc, coeff_idx, txb_costs->dc_sign_cost,
1379                               txb_info->txb_ctx->dc_sign_ctx);
1380   }
1381   return cost;
1382 }
1383 
set_eob(TxbInfo * txb_info,int eob)1384 static INLINE void set_eob(TxbInfo *txb_info, int eob) {
1385   txb_info->eob = eob;
1386   txb_info->seg_eob = AOMMIN(eob, tx_size_2d[txb_info->tx_size] - 1);
1387 }
1388 
1389 // TODO(angiebird): add static to this function once it's called
try_change_eob(int * new_eob,int coeff_idx,const TxbCache * txb_cache,const LV_MAP_COEFF_COST * txb_costs,TxbInfo * txb_info,int fast_mode)1390 int try_change_eob(int *new_eob, int coeff_idx, const TxbCache *txb_cache,
1391                    const LV_MAP_COEFF_COST *txb_costs, TxbInfo *txb_info,
1392                    int fast_mode) {
1393   assert(txb_info->eob > 0);
1394   const tran_low_t qc = txb_info->qcoeff[coeff_idx];
1395   const int abs_qc = abs(qc);
1396   if (abs_qc != 1) {
1397     *new_eob = -1;
1398     return 0;
1399   }
1400   const int16_t *iscan = txb_info->scan_order->iscan;
1401   const int16_t *scan = txb_info->scan_order->scan;
1402   const int scan_idx = iscan[coeff_idx];
1403   *new_eob = 0;
1404   int cost_diff = 0;
1405   cost_diff -= get_low_coeff_cost(coeff_idx, txb_cache, txb_costs, txb_info);
1406   // int coeff_cost =
1407   //     get_coeff_cost(qc, scan_idx, txb_info, txb_probs);
1408   // if (-cost_diff != coeff_cost) {
1409   //   printf("-cost_diff %d coeff_cost %d\n", -cost_diff, coeff_cost);
1410   //   get_low_coeff_cost(coeff_idx, txb_cache, txb_probs, txb_info);
1411   //   get_coeff_cost(qc, scan_idx, txb_info, txb_probs);
1412   // }
1413   for (int si = scan_idx - 1; si >= 0; --si) {
1414     const int ci = scan[si];
1415     if (txb_info->qcoeff[ci] != 0) {
1416       *new_eob = si + 1;
1417       break;
1418     } else {
1419       cost_diff -= get_low_coeff_cost(ci, txb_cache, txb_costs, txb_info);
1420     }
1421   }
1422 
1423   const int org_eob = txb_info->eob;
1424   set_eob(txb_info, *new_eob);
1425   cost_diff += try_level_down(coeff_idx, txb_cache, txb_costs, txb_info, NULL,
1426                               fast_mode);
1427   set_eob(txb_info, org_eob);
1428 
1429   if (*new_eob > 0) {
1430     // Note that get_eob_ctx does NOT actually account for qcoeff, so we don't
1431     // need to lower down the qcoeff here
1432     const int eob_ctx = get_eob_ctx(txb_info->qcoeff, scan[*new_eob - 1],
1433                                     txb_info->txs_ctx, txb_info->tx_type);
1434     cost_diff -= txb_costs->eob_cost[eob_ctx][0];
1435     cost_diff += txb_costs->eob_cost[eob_ctx][1];
1436   } else {
1437     const int txb_skip_ctx = txb_info->txb_ctx->txb_skip_ctx;
1438     cost_diff -= txb_costs->txb_skip_cost[txb_skip_ctx][0];
1439     cost_diff += txb_costs->txb_skip_cost[txb_skip_ctx][1];
1440   }
1441   return cost_diff;
1442 }
1443 
qcoeff_to_dqcoeff(tran_low_t qc,int dqv,int shift)1444 static INLINE tran_low_t qcoeff_to_dqcoeff(tran_low_t qc, int dqv, int shift) {
1445   int sgn = qc < 0 ? -1 : 1;
1446   return sgn * ((abs(qc) * dqv) >> shift);
1447 }
1448 
// TODO(angiebird): add static to this function once it's called
/* Lowers the magnitude of the coefficient at coeff_idx by one (no-op if it is
 * already zero), refreshes its dqcoeff, and incrementally updates the cached
 * nonzero/base/BR counts, magnitudes and contexts of neighbors whose contexts
 * include this coefficient.
 *
 * Only neighbors with scan index < eob are touched. eob itself and the
 * lowered coefficient's own cache entries are not updated here.
 */
void update_level_down(int coeff_idx, TxbCache *txb_cache, TxbInfo *txb_info) {
  const tran_low_t qc = txb_info->qcoeff[coeff_idx];
  // abs_qc is the magnitude *before* lowering; the `abs_qc == level` tests
  // below detect the coefficient dropping under a threshold.
  const int abs_qc = abs(qc);
  if (qc == 0) return;
  const tran_low_t low_coeff = get_lower_coeff(qc);
  txb_info->qcoeff[coeff_idx] = low_coeff;
  // dequant[0] is used for coeff_idx 0, dequant[1] for all other positions.
  const int dqv = txb_info->dequant[coeff_idx != 0];
  txb_info->dqcoeff[coeff_idx] =
      qcoeff_to_dqcoeff(low_coeff, dqv, txb_info->shift);

  const int row = coeff_idx >> txb_info->bwl;
  const int col = coeff_idx - (row << txb_info->bwl);
  const int eob = txb_info->eob;
  const int16_t *iscan = txb_info->scan_order->iscan;
  // Nonzero-map contexts. This cache counts the lowered coefficient in a
  // neighbor's nz context only when the neighbor comes later in scan order,
  // and the count changes only on a 1 -> 0 transition.
  for (int i = 0; i < SIG_REF_OFFSET_NUM; ++i) {
    const int nb_row = row - sig_ref_offset[i][0];
    const int nb_col = col - sig_ref_offset[i][1];

    if (!(nb_row >= 0 && nb_col >= 0 && nb_row < txb_info->height &&
          nb_col < txb_info->stride))
      continue;

    const int nb_coeff_idx = nb_row * txb_info->stride + nb_col;
    const int nb_scan_idx = iscan[nb_coeff_idx];
    if (nb_scan_idx < eob) {
      const int scan_idx = iscan[coeff_idx];
      if (scan_idx < nb_scan_idx) {
        const int level = 1;
        if (abs_qc == level) {
          txb_cache->nz_count_arr[nb_coeff_idx] -= 1;
          assert(txb_cache->nz_count_arr[nb_coeff_idx] >= 0);
        }
        const int count = txb_cache->nz_count_arr[nb_coeff_idx];
        txb_cache->nz_ctx_arr[nb_coeff_idx] = get_nz_map_ctx_from_count(
            count, nb_coeff_idx, txb_info->bwl, txb_info->tx_type);
        // int ref_ctx = get_nz_map_ctx(txb_info->qcoeff, nb_coeff_idx,
        // txb_info->bwl, tx_type);
        // if (ref_ctx != txb_cache->nz_ctx_arr[nb_coeff_idx])
        //   printf("nz ctx %d ref_ctx %d\n",
        //   txb_cache->nz_ctx_arr[nb_coeff_idx], ref_ctx);
      }
    }
  }

  // Base-level contexts of nonzero neighbors.
  const BASE_CTX_TABLE *base_ctx_table =
      txb_info->coeff_ctx_table->base_ctx_table;
  for (int i = 0; i < BASE_CONTEXT_POSITION_NUM; ++i) {
    const int nb_row = row - base_ref_offset[i][0];
    const int nb_col = col - base_ref_offset[i][1];
    const int nb_coeff_idx = nb_row * txb_info->stride + nb_col;

    if (!(nb_row >= 0 && nb_col >= 0 && nb_row < txb_info->height &&
          nb_col < txb_info->stride))
      continue;

    const tran_low_t nb_coeff = txb_info->qcoeff[nb_coeff_idx];
    if (!has_base(nb_coeff, 0)) continue;
    const int nb_scan_idx = iscan[nb_coeff_idx];
    if (nb_scan_idx < eob) {
      // The magnitude array is updated only for a lowered coefficient at or
      // below-right of the neighbor (mirrors the nb_row >= row &&
      // nb_col >= col test in neighbor_level_down_update).
      if (row >= nb_row && col >= nb_col)
        update_mag_arr(txb_cache->base_mag_arr[nb_coeff_idx], abs_qc);
      const int mag =
          get_mag_from_mag_arr(txb_cache->base_mag_arr[nb_coeff_idx]);
      // Refresh every base level the neighbor itself reaches.
      for (int base_idx = 0; base_idx < NUM_BASE_LEVELS; ++base_idx) {
        if (!has_base(nb_coeff, base_idx)) continue;
        const int level = base_idx + 1;
        if (abs_qc == level) {
          txb_cache->base_count_arr[base_idx][nb_coeff_idx] -= 1;
          assert(txb_cache->base_count_arr[base_idx][nb_coeff_idx] >= 0);
        }
        const int count = txb_cache->base_count_arr[base_idx][nb_coeff_idx];
        txb_cache->base_ctx_arr[base_idx][nb_coeff_idx] =
            base_ctx_table[nb_row != 0][nb_col != 0][mag > level][count];
        // int ref_ctx = get_base_ctx(txb_info->qcoeff, nb_coeff_idx,
        // txb_info->bwl, level);
        // if (ref_ctx != txb_cache->base_ctx_arr[base_idx][nb_coeff_idx]) {
        //   printf("base ctx %d ref_ctx %d\n",
        //   txb_cache->base_ctx_arr[base_idx][nb_coeff_idx], ref_ctx);
        // }
      }
    }
  }

  // BR contexts of neighbors above all base levels.
  for (int i = 0; i < BR_CONTEXT_POSITION_NUM; ++i) {
    const int nb_row = row - br_ref_offset[i][0];
    const int nb_col = col - br_ref_offset[i][1];
    const int nb_coeff_idx = nb_row * txb_info->stride + nb_col;

    if (!(nb_row >= 0 && nb_col >= 0 && nb_row < txb_info->height &&
          nb_col < txb_info->stride))
      continue;

    const int nb_scan_idx = iscan[nb_coeff_idx];
    const tran_low_t nb_coeff = txb_info->qcoeff[nb_coeff_idx];
    if (!has_br(nb_coeff)) continue;
    if (nb_scan_idx < eob) {
      const int level = 1 + NUM_BASE_LEVELS;
      if (abs_qc == level) {
        txb_cache->br_count_arr[nb_coeff_idx] -= 1;
        assert(txb_cache->br_count_arr[nb_coeff_idx] >= 0);
      }
      if (row >= nb_row && col >= nb_col)
        update_mag_arr(txb_cache->br_mag_arr[nb_coeff_idx], abs_qc);
      const int count = txb_cache->br_count_arr[nb_coeff_idx];
      const int mag = get_mag_from_mag_arr(txb_cache->br_mag_arr[nb_coeff_idx]);
      txb_cache->br_ctx_arr[nb_coeff_idx] =
          get_br_ctx_from_count_mag(nb_row, nb_col, count, mag);
      // int ref_ctx = get_level_ctx(txb_info->qcoeff, nb_coeff_idx,
      // txb_info->bwl);
      // if (ref_ctx != txb_cache->br_ctx_arr[nb_coeff_idx]) {
      //   printf("base ctx %d ref_ctx %d\n",
      //   txb_cache->br_ctx_arr[nb_coeff_idx], ref_ctx);
      // }
    }
  }
}
1566 
/* Returns the estimated bit cost of coding the quantized coefficient `qc` at
 * scan position `scan_idx`, using the contexts derived from txb_info->qcoeff.
 *
 * The cost is the sum of:
 *   - the nz-map (zero / non-zero) flag, when scan_idx < seg_eob;
 * and, for non-zero coefficients only:
 *   - the sign bit (get_sign_bit_cost()),
 *   - one base-level symbol per level in [0, NUM_BASE_LEVELS),
 *   - the base-range symbol plus Golomb remainder when |qc| > NUM_BASE_LEVELS,
 *   - the eob flag (is this the last coefficient?), when scan_idx < seg_eob. */
static int get_coeff_cost(tran_low_t qc, int scan_idx, TxbInfo *txb_info,
                          const LV_MAP_COEFF_COST *txb_costs) {
  const int16_t *const scan = txb_info->scan_order->scan;
  const int within_seg = scan_idx < txb_info->seg_eob;
  const int nonzero = (qc != 0);
  int total = 0;

  if (within_seg) {
    const int nz_ctx =
        get_nz_map_ctx(txb_info->qcoeff, scan_idx, scan, txb_info->bwl,
                       txb_info->height, txb_info->tx_type);
    total += txb_costs->nz_map_cost[nz_ctx][nonzero];
  }

  // A zero coefficient only pays for its nz-map flag.
  if (!nonzero) return total;

  const int pos = scan[scan_idx];
  const tran_low_t level = abs(qc);

  total += get_sign_bit_cost(qc, scan_idx, txb_costs->dc_sign_cost,
                             txb_info->txb_ctx->dc_sign_ctx);

  int base_ctx[NUM_BASE_LEVELS] = { 0 };
  get_base_ctx_set(txb_info->qcoeff, pos, txb_info->bwl, txb_info->height,
                   base_ctx);
  for (int lvl = 0; lvl < NUM_BASE_LEVELS; ++lvl) {
    total += get_base_cost(level, base_ctx[lvl],
                           txb_costs->base_cost[lvl][base_ctx[lvl]], lvl);
  }

  if (level > NUM_BASE_LEVELS) {
    const int br_ctx =
        get_br_ctx(txb_info->qcoeff, pos, txb_info->bwl, txb_info->height);
    total += get_br_cost(level, br_ctx, txb_costs->lps_cost[br_ctx]);
    total += get_golomb_cost(level);
  }

  if (within_seg) {
    const int is_last = (scan_idx == (txb_info->eob - 1));
    const int eob_ctx = get_eob_ctx(txb_info->qcoeff, pos, txb_info->txs_ctx,
                                    txb_info->tx_type);
    total += txb_costs->eob_cost[eob_ctx][is_last];
  }
  return total;
}
1611 
#if TEST_OPTIMIZE_TXB
#define ALL_REF_OFFSET_NUM 17
/* (row, col) offsets of every coefficient whose cost may depend on a given
 * coefficient; a neighbour is at (row - off[0], col - off[1]). */
static const int all_ref_offset[ALL_REF_OFFSET_NUM][2] = {
  { 0, 0 },  { -2, -1 }, { -2, 0 }, { -2, 1 }, { -1, -2 }, { -1, -1 },
  { -1, 0 }, { -1, 1 },  { 0, -2 }, { 0, -1 }, { 1, -2 },  { 1, -1 },
  { 1, 0 },  { 2, 0 },   { 0, 1 },  { 0, 2 },  { 1, 1 },
};

/* Sums get_coeff_cost() over the neighbours of (row, col) listed in
 * all_ref_offset that lie inside the block and before eob.  When cost_map is
 * non-NULL, sign * cost is accumulated at each neighbour's map position. */
static int sum_nb_coeff_cost(int row, int col,
                             const LV_MAP_COEFF_COST *txb_costs,
                             TxbInfo *txb_info, int (*cost_map)[COST_MAP_SIZE],
                             int sign) {
  int total = 0;
  for (int i = 0; i < ALL_REF_OFFSET_NUM; ++i) {
    const int nb_row = row - all_ref_offset[i][0];
    const int nb_col = col - all_ref_offset[i][1];
    // Bounds-check BEFORE indexing iscan: outside the block nb_coeff_idx can
    // be negative or past the end of the scan table.
    if (nb_row < 0 || nb_col < 0 || nb_row >= txb_info->height ||
        nb_col >= txb_info->stride)
      continue;
    const int nb_coeff_idx = nb_row * txb_info->stride + nb_col;
    const int nb_scan_idx = txb_info->scan_order->iscan[nb_coeff_idx];
    if (nb_scan_idx >= txb_info->eob) continue;
    const int cost = get_coeff_cost(txb_info->qcoeff[nb_coeff_idx],
                                    nb_scan_idx, txb_info, txb_costs);
    if (cost_map)
      cost_map[nb_row - row + COST_MAP_OFFSET]
              [nb_col - col + COST_MAP_OFFSET] += sign * cost;
    total += cost;
  }
  return total;
}

/* Reference (brute-force) version of try_level_down(): returns the change in
 * total neighbourhood cost when the coefficient at coeff_idx is lowered with
 * get_lower_coeff().  qcoeff is temporarily modified and then restored.
 * When cost_map is non-NULL it is zeroed and then receives per-neighbour
 * (new - old) cost deltas. */
static int try_level_down_ref(int coeff_idx, const LV_MAP_COEFF_COST *txb_costs,
                              TxbInfo *txb_info,
                              int (*cost_map)[COST_MAP_SIZE]) {
  if (cost_map) {
    for (int i = 0; i < COST_MAP_SIZE; ++i) av1_zero(cost_map[i]);
  }
  const tran_low_t qc = txb_info->qcoeff[coeff_idx];
  if (qc == 0) return 0;
  const int row = coeff_idx >> txb_info->bwl;
  const int col = coeff_idx - (row << txb_info->bwl);

  const int org_cost =
      sum_nb_coeff_cost(row, col, txb_costs, txb_info, cost_map, -1);
  txb_info->qcoeff[coeff_idx] = get_lower_coeff(qc);
  const int new_cost =
      sum_nb_coeff_cost(row, col, txb_costs, txb_info, cost_map, 1);
  txb_info->qcoeff[coeff_idx] = qc;
  return new_cost - org_cost;
}

/* Debug check: compares try_level_down() against try_level_down_ref() and
 * prints both cost maps when the cost deltas disagree. */
static void test_level_down(int coeff_idx, const TxbCache *txb_cache,
                            const LV_MAP_COEFF_COST *txb_costs,
                            TxbInfo *txb_info) {
  int cost_map[COST_MAP_SIZE][COST_MAP_SIZE];
  int ref_cost_map[COST_MAP_SIZE][COST_MAP_SIZE];
  const int cost_diff =
      try_level_down(coeff_idx, txb_cache, txb_costs, txb_info, cost_map, 0);
  const int cost_diff_ref =
      try_level_down_ref(coeff_idx, txb_costs, txb_info, ref_cost_map);
  if (cost_diff != cost_diff_ref) {
    printf("qc %d cost_diff %d cost_diff_ref %d\n", txb_info->qcoeff[coeff_idx],
           cost_diff, cost_diff_ref);
    for (int r = 0; r < COST_MAP_SIZE; ++r) {
      for (int c = 0; c < COST_MAP_SIZE; ++c) {
        printf("%d:%d ", cost_map[r][c], ref_cost_map[r][c]);
      }
      printf("\n");
    }
  }
}
#endif
1688 
1689 // TODO(angiebird): make this static once it's called
get_txb_cost(TxbInfo * txb_info,const LV_MAP_COEFF_COST * txb_costs)1690 int get_txb_cost(TxbInfo *txb_info, const LV_MAP_COEFF_COST *txb_costs) {
1691   int cost = 0;
1692   int txb_skip_ctx = txb_info->txb_ctx->txb_skip_ctx;
1693   const int16_t *scan = txb_info->scan_order->scan;
1694   if (txb_info->eob == 0) {
1695     cost = txb_costs->txb_skip_cost[txb_skip_ctx][1];
1696     return cost;
1697   }
1698   cost = txb_costs->txb_skip_cost[txb_skip_ctx][0];
1699   for (int c = 0; c < txb_info->eob; ++c) {
1700     tran_low_t qc = txb_info->qcoeff[scan[c]];
1701     int coeff_cost = get_coeff_cost(qc, c, txb_info, txb_costs);
1702     cost += coeff_cost;
1703   }
1704   return cost;
1705 }
1706 
#if TEST_OPTIMIZE_TXB
/* Debug check for try_change_eob(): when the last coefficient is +/-1,
 * compares its incremental cost estimate against a full recount with
 * get_txb_cost() after actually lowering that coefficient and moving eob.
 * txb_info is restored before returning; a mismatch is printed. */
void test_try_change_eob(TxbInfo *txb_info, const LV_MAP_COEFF_COST *txb_costs,
                         TxbCache *txb_cache) {
  const int eob = txb_info->eob;
  if (eob <= 0) return;

  const int last_ci = txb_info->scan_order->scan[eob - 1];
  const int last_coeff = txb_info->qcoeff[last_ci];
  if (abs(last_coeff) != 1) return;

  int new_eob;
  const int cost_diff =
      try_change_eob(&new_eob, last_ci, txb_cache, txb_costs, txb_info, 0);
  const int org_eob = txb_info->eob;
  const int cost = get_txb_cost(txb_info, txb_costs);

  // Apply the change, recount, then restore the original state.
  txb_info->qcoeff[last_ci] = get_lower_coeff(last_coeff);
  set_eob(txb_info, new_eob);
  const int new_cost = get_txb_cost(txb_info, txb_costs);
  set_eob(txb_info, org_eob);
  txb_info->qcoeff[last_ci] = last_coeff;

  const int ref_cost_diff = new_cost - cost;
  if (ref_cost_diff != cost_diff) {
    printf("org_eob %d new_eob %d cost_diff %d ref_cost_diff %d\n", org_eob,
           new_eob, cost_diff, ref_cost_diff);
  }
}
#endif
1737 
get_coeff_dist(tran_low_t tcoeff,tran_low_t dqcoeff,int shift)1738 static INLINE int64_t get_coeff_dist(tran_low_t tcoeff, tran_low_t dqcoeff,
1739                                      int shift) {
1740   const int64_t diff = (tcoeff - dqcoeff) * (1 << shift);
1741   const int64_t error = diff * diff;
1742   return error;
1743 }
1744 
1745 typedef struct LevelDownStats {
1746   int update;
1747   tran_low_t low_qc;
1748   tran_low_t low_dqc;
1749   int64_t rd_diff;
1750   int cost_diff;
1751   int64_t dist_diff;
1752   int new_eob;
1753 } LevelDownStats;
1754 
/* Evaluates lowering the coefficient at scan position scan_idx by one level
 * and reports the rate, distortion and RD deltas in *stats (nothing is
 * modified in txb_info).  A trailing +/-1 coefficient is evaluated with
 * try_change_eob(), which may also propose a new eob; any other coefficient
 * uses try_level_down().  stats->update is set when the RD delta is negative. */
void try_level_down_facade(LevelDownStats *stats, int scan_idx,
                           const TxbCache *txb_cache,
                           const LV_MAP_COEFF_COST *txb_costs,
                           TxbInfo *txb_info, int fast_mode) {
  const int coeff_idx = txb_info->scan_order->scan[scan_idx];
  const tran_low_t qc = txb_info->qcoeff[coeff_idx];
  stats->new_eob = -1;
  stats->update = 0;
  // A zero coefficient cannot be lowered.
  if (qc == 0) return;

  const int shift = txb_info->shift;
  const tran_low_t tqc = txb_info->tcoeff[coeff_idx];
  // dequant[0] is used for coeff_idx 0, dequant[1] for all others.
  const int dqv = txb_info->dequant[coeff_idx != 0];
  const int64_t cur_dist =
      get_coeff_dist(tqc, qcoeff_to_dqcoeff(qc, dqv, shift), shift);

  stats->low_qc = get_lower_coeff(qc);
  stats->low_dqc = qcoeff_to_dqcoeff(stats->low_qc, dqv, shift);
  const int64_t low_dist = get_coeff_dist(tqc, stats->low_dqc, shift);
  stats->dist_diff = low_dist - cur_dist;

  stats->new_eob = txb_info->eob;
  const int is_trailing_one =
      (scan_idx == txb_info->eob - 1) && (abs(qc) == 1);
  if (is_trailing_one) {
    stats->cost_diff = try_change_eob(&stats->new_eob, coeff_idx, txb_cache,
                                      txb_costs, txb_info, fast_mode);
  } else {
    stats->cost_diff = try_level_down(coeff_idx, txb_cache, txb_costs, txb_info,
                                      NULL, fast_mode);
#if TEST_OPTIMIZE_TXB
    test_level_down(coeff_idx, txb_cache, txb_costs, txb_info);
#endif
  }
  stats->rd_diff = RDCOST(txb_info->rdmult, stats->cost_diff, stats->dist_diff);
  stats->update = (stats->rd_diff < 0);
}
1796 
/* Greedy rate-distortion optimization of the quantized coefficients of one
 * transform block.
 *
 * Two passes are made in scan order:
 *   1. forward over [0, eob): every coefficient with |qc| == 1 is tried;
 *   2. backward from the (possibly reduced) eob: every coefficient is tried,
 *      except that once a coefficient at the end has stayed non-zero
 *      (eob_fix), |qc| == 1 coefficients are skipped because pass 1 already
 *      evaluated them.
 * Each trial goes through try_level_down_facade(); accepted trials are applied
 * with update_level_down() and set_eob().
 *
 * dry_run:   when non-zero, qcoeff/dqcoeff are worked on in local copies and
 *            the original pointers and eob are restored before returning.
 *            txb_cache is still passed to update_level_down() and is not
 *            restored.
 * fast_mode: forwarded to try_level_down_facade().
 * Returns 1 if at least one coefficient was lowered, 0 otherwise. */
static int optimize_txb(TxbInfo *txb_info, const LV_MAP_COEFF_COST *txb_costs,
                        TxbCache *txb_cache, int dry_run, int fast_mode) {
  int update = 0;
  if (txb_info->eob == 0) return update;
  // Running totals of accepted changes; only inspected by the
  // TEST_OPTIMIZE_TXB self-check below.
  int cost_diff = 0;
  int64_t dist_diff = 0;
  int64_t rd_diff = 0;
  const int max_eob = tx_size_2d[txb_info->tx_size];

#if TEST_OPTIMIZE_TXB
  int64_t sse;
  int64_t org_dist =
      av1_block_error_c(txb_info->tcoeff, txb_info->dqcoeff, max_eob, &sse) *
      (1 << (2 * txb_info->shift));
  int org_cost = get_txb_cost(txb_info, txb_costs);
#endif

  tran_low_t *org_qcoeff = txb_info->qcoeff;
  tran_low_t *org_dqcoeff = txb_info->dqcoeff;

  tran_low_t tmp_qcoeff[MAX_TX_SQUARE];
  tran_low_t tmp_dqcoeff[MAX_TX_SQUARE];
  const int org_eob = txb_info->eob;
  if (dry_run) {
    // Work on scratch copies so the caller's coefficient buffers stay intact.
    memcpy(tmp_qcoeff, org_qcoeff, sizeof(org_qcoeff[0]) * max_eob);
    memcpy(tmp_dqcoeff, org_dqcoeff, sizeof(org_dqcoeff[0]) * max_eob);
    txb_info->qcoeff = tmp_qcoeff;
    txb_info->dqcoeff = tmp_dqcoeff;
  }

  const int16_t *scan = txb_info->scan_order->scan;

  // Forward pass (nz_map): try lowering each coefficient with |qc| == 1.
  const int cur_eob = txb_info->eob;
  for (int si = 0; si < cur_eob; ++si) {
    const int coeff_idx = scan[si];
    tran_low_t qc = txb_info->qcoeff[coeff_idx];
    if (abs(qc) == 1) {
      LevelDownStats stats;
      try_level_down_facade(&stats, si, txb_cache, txb_costs, txb_info,
                            fast_mode);
      if (stats.update) {
        update = 1;
        cost_diff += stats.cost_diff;
        dist_diff += stats.dist_diff;
        rd_diff += stats.rd_diff;
        update_level_down(coeff_idx, txb_cache, txb_info);
        set_eob(txb_info, stats.new_eob);
      }
    }
  }

  // Backward pass (level-k map): try lowering every coefficient by one level.
  int eob_fix = 0;
  for (int si = txb_info->eob - 1; si >= 0; --si) {
    const int coeff_idx = scan[si];
    if (eob_fix == 1 && abs(txb_info->qcoeff[coeff_idx]) == 1) {
      // Once eob is fixed there is no need to retry coefficients with
      // abs(qc) == 1: the forward pass already evaluated them. Compare the
      // magnitude so that -1 is skipped just like +1.
      continue;
    }
    LevelDownStats stats;
    try_level_down_facade(&stats, si, txb_cache, txb_costs, txb_info,
                          fast_mode);
    if (stats.update) {
      update = 1;
      cost_diff += stats.cost_diff;
      dist_diff += stats.dist_diff;
      rd_diff += stats.rd_diff;
      update_level_down(coeff_idx, txb_cache, txb_info);
      set_eob(txb_info, stats.new_eob);
    }
    if (eob_fix == 0 && txb_info->qcoeff[coeff_idx] != 0) eob_fix = 1;
    // set_eob() may have shrunk eob; continue from the new last position.
    if (si > txb_info->eob) si = txb_info->eob;
  }
#if TEST_OPTIMIZE_TXB
  int64_t new_dist =
      av1_block_error_c(txb_info->tcoeff, txb_info->dqcoeff, max_eob, &sse) *
      (1 << (2 * txb_info->shift));
  int new_cost = get_txb_cost(txb_info, txb_costs);
  int64_t ref_dist_diff = new_dist - org_dist;
  int ref_cost_diff = new_cost - org_cost;
  // int64_t is not necessarily `long`; print through long long.
  if (cost_diff != ref_cost_diff || dist_diff != ref_dist_diff)
    printf(
        "overall rd_diff %lld\ncost_diff %d ref_cost_diff%d\ndist_diff %lld "
        "ref_dist_diff %lld\neob %d new_eob %d\n\n",
        (long long)rd_diff, cost_diff, ref_cost_diff, (long long)dist_diff,
        (long long)ref_dist_diff, org_eob, txb_info->eob);
#endif
  // The accumulators are only read by the self-check above.
  (void)cost_diff;
  (void)dist_diff;
  (void)rd_diff;
  if (dry_run) {
    txb_info->qcoeff = org_qcoeff;
    txb_info->dqcoeff = org_dqcoeff;
    set_eob(txb_info, org_eob);
  }
  return update;
}
1898 
1899 // These numbers are empirically obtained.
1900 static const int plane_rd_mult[REF_TYPES][PLANE_TYPES] = {
1901   { 17, 13 }, { 16, 10 },
1902 };
1903 
/* Runs optimize_txb() on the transform block `block` of plane `plane` and
 * returns the resulting eob.  When any coefficient was changed,
 * p->eobs[block] is updated to the new eob as well.  txb_ctx supplies the
 * skip and DC-sign contexts; fast_mode is forwarded to optimize_txb(). */
int av1_optimize_txb(const AV1_COMMON *cm, MACROBLOCK *x, int plane,
                     int blk_row, int blk_col, int block, TX_SIZE tx_size,
                     TXB_CTX *txb_ctx, int fast_mode) {
  MACROBLOCKD *const xd = &x->e_mbd;
  const PLANE_TYPE plane_type = get_plane_type(plane);
  const TX_SIZE txs_ctx = get_txsize_context(tx_size);
  const TX_TYPE tx_type =
      av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
  const MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
  const struct macroblock_plane *p = &x->plane[plane];
  struct macroblockd_plane *pd = &xd->plane[plane];
  const int eob = p->eobs[block];
  tran_low_t *qcoeff = BLOCK_OFFSET(p->qcoeff, block);
  tran_low_t *dqcoeff = BLOCK_OFFSET(pd->dqcoeff, block);
  const tran_low_t *tcoeff = BLOCK_OFFSET(p->coeff, block);
  const int16_t *dequant = pd->dequant;
  const int seg_eob = AOMMIN(eob, tx_size_2d[tx_size] - 1);
  const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
  const int stride = 1 << bwl;
  const int height = tx_size_high[tx_size];
  const int is_inter = is_inter_block(mbmi);
  const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
  // Borrow the cost tables by pointer: LV_MAP_COEFF_COST bundles all the
  // nz_map / base / lps / eob / sign / skip cost arrays, so copying it by
  // value on every call is needless work.
  const LV_MAP_COEFF_COST *const txb_costs =
      &x->coeff_costs[txs_ctx][plane_type];

  const int shift = av1_get_tx_scale(tx_size);
  // Multiply in 64 bits so a large rdmult cannot overflow int.
  const int64_t rdmult =
      ((int64_t)x->rdmult * plane_rd_mult[is_inter][plane_type] + 2) >> 2;

  TxbInfo txb_info = { qcoeff,
                       dqcoeff,
                       tcoeff,
                       dequant,
                       shift,
                       tx_size,
                       txs_ctx,
                       tx_type,
                       bwl,
                       stride,
                       height,
                       eob,
                       seg_eob,
                       scan_order,
                       txb_ctx,
                       rdmult,
                       &cm->coeff_ctx_table };

  TxbCache txb_cache;
  gen_txb_cache(&txb_cache, &txb_info);

  const int update =
      optimize_txb(&txb_info, txb_costs, &txb_cache, 0, fast_mode);
  if (update) p->eobs[block] = txb_info.eob;
  return txb_info.eob;
}
av1_get_txb_entropy_context(const tran_low_t * qcoeff,const SCAN_ORDER * scan_order,int eob)1958 int av1_get_txb_entropy_context(const tran_low_t *qcoeff,
1959                                 const SCAN_ORDER *scan_order, int eob) {
1960   const int16_t *scan = scan_order->scan;
1961   int cul_level = 0;
1962   int c;
1963 
1964   if (eob == 0) return 0;
1965   for (c = 0; c < eob; ++c) {
1966     cul_level += abs(qcoeff[scan[c]]);
1967   }
1968 
1969   cul_level = AOMMIN(COEFF_CONTEXT_MASK, cul_level);
1970   set_dc_sign(&cul_level, qcoeff[0]);
1971 
1972   return cul_level;
1973 }
1974 
/* Per-transform-block callback (arg is a struct tokenize_b_args *): computes
 * the block's entropy context from its quantized coefficients with
 * av1_get_txb_entropy_context() and stores it with av1_set_contexts().
 * plane_bsize is unused. */
void av1_update_txb_context_b(int plane, int block, int blk_row, int blk_col,
                              BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
                              void *arg) {
  (void)plane_bsize;
  struct tokenize_b_args *const args = arg;
  MACROBLOCK *const x = &args->td->mb;
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
  struct macroblock_plane *p = &x->plane[plane];
  struct macroblockd_plane *pd = &xd->plane[plane];

  const TX_TYPE tx_type =
      av1_get_tx_type(pd->plane_type, xd, blk_row, blk_col, block, tx_size);
  const SCAN_ORDER *const scan_order =
      get_scan(&args->cpi->common, tx_size, tx_type, mbmi);
  const uint16_t eob = p->eobs[block];

  const int cul_level = av1_get_txb_entropy_context(
      BLOCK_OFFSET(p->qcoeff, block), scan_order, eob);
  av1_set_contexts(xd, pd, plane, tx_size, cul_level, blk_col, blk_row);
}
1998 
av1_update_nz_eob_counts(FRAME_CONTEXT * fc,FRAME_COUNTS * counts,uint16_t eob,const tran_low_t * tcoeff,int plane,TX_SIZE tx_size,TX_TYPE tx_type,const int16_t * scan)1999 static INLINE void av1_update_nz_eob_counts(FRAME_CONTEXT *fc,
2000                                             FRAME_COUNTS *counts, uint16_t eob,
2001                                             const tran_low_t *tcoeff, int plane,
2002                                             TX_SIZE tx_size, TX_TYPE tx_type,
2003                                             const int16_t *scan) {
2004   const PLANE_TYPE plane_type = get_plane_type(plane);
2005   const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
2006   const int height = tx_size_high[tx_size];
2007   TX_SIZE txsize_ctx = get_txsize_context(tx_size);
2008 #if CONFIG_CTX1D
2009   const int width = tx_size_wide[tx_size];
2010   const int eob_offset = width + height;
2011   const TX_CLASS tx_class = get_tx_class(tx_type);
2012   const int seg_eob =
2013       (tx_class == TX_CLASS_2D) ? tx_size_2d[tx_size] : eob_offset;
2014 #else
2015   const int seg_eob = tx_size_2d[tx_size];
2016 #endif
2017   unsigned int(*nz_map_count)[SIG_COEF_CONTEXTS][2] =
2018       &counts->nz_map[txsize_ctx][plane_type];
2019   for (int c = 0; c < eob; ++c) {
2020     tran_low_t v = tcoeff[scan[c]];
2021     int is_nz = (v != 0);
2022     int coeff_ctx = get_nz_map_ctx(tcoeff, c, scan, bwl, height, tx_type);
2023     int eob_ctx = get_eob_ctx(tcoeff, scan[c], txsize_ctx, tx_type);
2024 
2025     if (c == seg_eob - 1) break;
2026 
2027     ++(*nz_map_count)[coeff_ctx][is_nz];
2028 #if LV_MAP_PROB
2029     update_bin(fc->nz_map_cdf[txsize_ctx][plane_type][coeff_ctx], is_nz, 2);
2030 #endif
2031 
2032     if (is_nz) {
2033       ++counts->eob_flag[txsize_ctx][plane_type][eob_ctx][c == (eob - 1)];
2034 #if LV_MAP_PROB
2035       update_bin(fc->eob_flag_cdf[txsize_ctx][plane_type][eob_ctx],
2036                  c == (eob - 1), 2);
2037 #endif
2038     }
2039   }
2040 }
2041 
#if CONFIG_CTX1D
/* Accumulates statistics for a TX_CLASS_VERT block, counted column by column.
 * Per-column eobs are derived from tcoeff with get_eob_vert() (the `eob`
 * argument is unused).
 *   - Each column gets an empty-line flag.
 *   - Each position up to the column's eob (excluding row height - 1) gets an
 *     nz-map flag.
 *   - Each non-zero such position gets an hv-eob flag.
 * With LV_MAP_PROB the matching CDFs in fc are adapted too. */
static INLINE void av1_update_nz_eob_counts_vert(
    FRAME_CONTEXT *fc, FRAME_COUNTS *counts, uint16_t eob,
    const tran_low_t *tcoeff, int plane, TX_SIZE tx_size, TX_TYPE tx_type,
    const int16_t *scan, const int16_t *iscan) {
  (void)eob;
  (void)fc;  // Only read when LV_MAP_PROB is enabled.
  const TX_SIZE txs_ctx = get_txsize_context(tx_size);
  const PLANE_TYPE plane_type = get_plane_type(plane);
  const TX_CLASS tx_class = get_tx_class(tx_type);
  const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
  const int width = tx_size_wide[tx_size];
  const int height = tx_size_high[tx_size];
  int16_t eob_ls[MAX_HVTX_SIZE];
  get_eob_vert(eob_ls, tcoeff, width, height);
  unsigned int(*nz_map_count)[SIG_COEF_CONTEXTS][2] =
      &counts->nz_map[txs_ctx][plane_type];
  for (int c = 0; c < width; ++c) {
    int16_t veob = eob_ls[c];
    assert(veob <= height);
    int el_ctx = get_empty_line_ctx(c, eob_ls);
    ++counts->empty_line[txs_ctx][plane_type][tx_class][el_ctx][veob == 0];
#if LV_MAP_PROB
    update_bin(fc->empty_line_cdf[txs_ctx][plane_type][tx_class][el_ctx],
               veob == 0, 2);
#endif
    if (veob) {
      for (int r = 0; r < veob; ++r) {
        if (r + 1 != height) {
          int coeff_idx = r * width + c;
          int scan_idx = iscan[coeff_idx];
          int is_nz = tcoeff[coeff_idx] != 0;
          int coeff_ctx =
              get_nz_map_ctx(tcoeff, scan_idx, scan, bwl, height, tx_type);
          ++(*nz_map_count)[coeff_ctx][is_nz];
#if LV_MAP_PROB
          update_bin(fc->nz_map_cdf[txs_ctx][plane_type][coeff_ctx], is_nz, 2);
#endif
          if (is_nz) {
            int eob_ctx = get_hv_eob_ctx(c, r, eob_ls);
            ++counts->hv_eob[txs_ctx][plane_type][tx_class][eob_ctx]
                            [r == veob - 1];
#if LV_MAP_PROB
            update_bin(fc->hv_eob_cdf[txs_ctx][plane_type][tx_class][eob_ctx],
                       r == veob - 1, 2);
#endif
          }
        }
      }
    }
  }
}

/* Row-by-row counterpart of av1_update_nz_eob_counts_vert() for
 * TX_CLASS_HORIZ blocks.  Per-row eobs come from get_eob_horiz().  No flag is
 * counted for column width - 1.  The `eob` argument is unused. */
static INLINE void av1_update_nz_eob_counts_horiz(
    FRAME_CONTEXT *fc, FRAME_COUNTS *counts, uint16_t eob,
    const tran_low_t *tcoeff, int plane, TX_SIZE tx_size, TX_TYPE tx_type,
    const int16_t *scan, const int16_t *iscan) {
  (void)eob;
  (void)fc;  // Only read when LV_MAP_PROB is enabled.
  const TX_SIZE txs_ctx = get_txsize_context(tx_size);
  const PLANE_TYPE plane_type = get_plane_type(plane);
  const TX_CLASS tx_class = get_tx_class(tx_type);
  const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
  const int width = tx_size_wide[tx_size];
  const int height = tx_size_high[tx_size];
  int16_t eob_ls[MAX_HVTX_SIZE];
  get_eob_horiz(eob_ls, tcoeff, width, height);
  unsigned int(*nz_map_count)[SIG_COEF_CONTEXTS][2] =
      &counts->nz_map[txs_ctx][plane_type];
  for (int r = 0; r < height; ++r) {
    int16_t heob = eob_ls[r];
    // Mirrors the veob <= height check of the vertical variant.
    assert(heob <= width);
    int el_ctx = get_empty_line_ctx(r, eob_ls);
    ++counts->empty_line[txs_ctx][plane_type][tx_class][el_ctx][heob == 0];
#if LV_MAP_PROB
    update_bin(fc->empty_line_cdf[txs_ctx][plane_type][tx_class][el_ctx],
               heob == 0, 2);
#endif
    if (heob) {
      for (int c = 0; c < heob; ++c) {
        if (c + 1 != width) {
          int coeff_idx = r * width + c;
          int scan_idx = iscan[coeff_idx];
          int is_nz = tcoeff[coeff_idx] != 0;
          int coeff_ctx =
              get_nz_map_ctx(tcoeff, scan_idx, scan, bwl, height, tx_type);
          ++(*nz_map_count)[coeff_ctx][is_nz];
#if LV_MAP_PROB
          update_bin(fc->nz_map_cdf[txs_ctx][plane_type][coeff_ctx], is_nz, 2);
#endif
          if (is_nz) {
            int eob_ctx = get_hv_eob_ctx(r, c, eob_ls);
            ++counts->hv_eob[txs_ctx][plane_type][tx_class][eob_ctx]
                            [c == heob - 1];
#if LV_MAP_PROB
            update_bin(fc->hv_eob_cdf[txs_ctx][plane_type][tx_class][eob_ctx],
                       c == heob - 1, 2);
#endif
          }
        }
      }
    }
  }
}
#endif  // CONFIG_CTX1D
2145 
av1_update_and_record_txb_context(int plane,int block,int blk_row,int blk_col,BLOCK_SIZE plane_bsize,TX_SIZE tx_size,void * arg)2146 void av1_update_and_record_txb_context(int plane, int block, int blk_row,
2147                                        int blk_col, BLOCK_SIZE plane_bsize,
2148                                        TX_SIZE tx_size, void *arg) {
2149   struct tokenize_b_args *const args = arg;
2150   const AV1_COMP *cpi = args->cpi;
2151   const AV1_COMMON *cm = &cpi->common;
2152   ThreadData *const td = args->td;
2153   MACROBLOCK *const x = &td->mb;
2154   MACROBLOCKD *const xd = &x->e_mbd;
2155   struct macroblock_plane *p = &x->plane[plane];
2156   struct macroblockd_plane *pd = &xd->plane[plane];
2157   MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
2158   int eob = p->eobs[block], update_eob = 0;
2159   const PLANE_TYPE plane_type = pd->plane_type;
2160   const tran_low_t *qcoeff = BLOCK_OFFSET(p->qcoeff, block);
2161   tran_low_t *tcoeff = BLOCK_OFFSET(x->mbmi_ext->tcoeff[plane], block);
2162   const int segment_id = mbmi->segment_id;
2163   const TX_TYPE tx_type =
2164       av1_get_tx_type(plane_type, xd, blk_row, blk_col, block, tx_size);
2165   const SCAN_ORDER *const scan_order = get_scan(cm, tx_size, tx_type, mbmi);
2166   const int16_t *scan = scan_order->scan;
2167   const int seg_eob = av1_get_tx_eob(&cpi->common.seg, segment_id, tx_size);
2168   int c, i;
2169   TXB_CTX txb_ctx;
2170   get_txb_ctx(plane_bsize, tx_size, plane, pd->above_context + blk_col,
2171               pd->left_context + blk_row, &txb_ctx);
2172   const int bwl = b_width_log2_lookup[txsize_to_bsize[tx_size]] + 2;
2173   const int height = tx_size_high[tx_size];
2174   int cul_level = 0;
2175 
2176   TX_SIZE txsize_ctx = get_txsize_context(tx_size);
2177   FRAME_CONTEXT *ec_ctx = xd->tile_ctx;
2178 
2179   memcpy(tcoeff, qcoeff, sizeof(*tcoeff) * seg_eob);
2180 
2181   ++td->counts->txb_skip[txsize_ctx][txb_ctx.txb_skip_ctx][eob == 0];
2182 #if LV_MAP_PROB
2183   update_bin(ec_ctx->txb_skip_cdf[txsize_ctx][txb_ctx.txb_skip_ctx], eob == 0,
2184              2);
2185 #endif
2186   x->mbmi_ext->txb_skip_ctx[plane][block] = txb_ctx.txb_skip_ctx;
2187 
2188   x->mbmi_ext->eobs[plane][block] = eob;
2189 
2190   if (eob == 0) {
2191     av1_set_contexts(xd, pd, plane, tx_size, 0, blk_col, blk_row);
2192     return;
2193   }
2194 
2195 #if CONFIG_TXK_SEL
2196   av1_update_tx_type_count(cm, xd, blk_row, blk_col, block, plane,
2197                            mbmi->sb_type, get_min_tx_size(tx_size), td->counts);
2198 #endif
2199 
2200 #if CONFIG_CTX1D
2201   TX_CLASS tx_class = get_tx_class(tx_type);
2202   if (tx_class == TX_CLASS_2D) {
2203     av1_update_nz_eob_counts(ec_ctx, td->counts, eob, tcoeff, plane, tx_size,
2204                              tx_type, scan);
2205   } else {
2206     const int width = tx_size_wide[tx_size];
2207     const int eob_offset = width + height;
2208     const int eob_mode = eob > eob_offset;
2209     const TX_SIZE txs_ctx = get_txsize_context(tx_size);
2210     ++td->counts->eob_mode[txs_ctx][plane_type][tx_class][eob_mode];
2211 #if LV_MAP_PROB
2212     update_bin(ec_ctx->eob_mode_cdf[txs_ctx][plane_type][tx_class], eob_mode,
2213                2);
2214 #endif
2215     if (eob_mode == 0) {
2216       av1_update_nz_eob_counts(ec_ctx, td->counts, eob, tcoeff, plane, tx_size,
2217                                tx_type, scan);
2218     } else {
2219       const int16_t *iscan = scan_order->iscan;
2220       assert(tx_class == TX_CLASS_VERT || tx_class == TX_CLASS_HORIZ);
2221       if (tx_class == TX_CLASS_VERT)
2222         av1_update_nz_eob_counts_vert(ec_ctx, td->counts, eob, tcoeff, plane,
2223                                       tx_size, tx_type, scan, iscan);
2224       else
2225         av1_update_nz_eob_counts_horiz(ec_ctx, td->counts, eob, tcoeff, plane,
2226                                        tx_size, tx_type, scan, iscan);
2227     }
2228   }
2229 #else   // CONFIG_CTX1D
2230   av1_update_nz_eob_counts(ec_ctx, td->counts, eob, tcoeff, plane, tx_size,
2231                            tx_type, scan);
2232 #endif  // CONFIG_CTX1D
2233 
2234   // Reverse process order to handle coefficient level and sign.
2235   for (i = 0; i < NUM_BASE_LEVELS; ++i) {
2236     update_eob = 0;
2237     for (c = eob - 1; c >= 0; --c) {
2238       tran_low_t v = qcoeff[scan[c]];
2239       tran_low_t level = abs(v);
2240       int ctx;
2241 
2242       if (level <= i) continue;
2243 
2244       ctx = get_base_ctx(tcoeff, scan[c], bwl, height, i + 1);
2245 
2246       if (level == i + 1) {
2247         ++td->counts->coeff_base[txsize_ctx][plane_type][i][ctx][1];
2248 #if LV_MAP_PROB
2249         update_bin(ec_ctx->coeff_base_cdf[txsize_ctx][plane_type][i][ctx], 1,
2250                    2);
2251 #endif
2252         if (c == 0) {
2253           int dc_sign_ctx = txb_ctx.dc_sign_ctx;
2254 
2255           ++td->counts->dc_sign[plane_type][dc_sign_ctx][v < 0];
2256 #if LV_MAP_PROB
2257           update_bin(ec_ctx->dc_sign_cdf[plane_type][dc_sign_ctx], v < 0, 2);
2258 #endif
2259           x->mbmi_ext->dc_sign_ctx[plane][block] = dc_sign_ctx;
2260         }
2261         cul_level += level;
2262         continue;
2263       }
2264       ++td->counts->coeff_base[txsize_ctx][plane_type][i][ctx][0];
2265 #if LV_MAP_PROB
2266       update_bin(ec_ctx->coeff_base_cdf[txsize_ctx][plane_type][i][ctx], 0, 2);
2267 #endif
2268       update_eob = AOMMAX(update_eob, c);
2269     }
2270   }
2271 
2272   for (c = update_eob; c >= 0; --c) {
2273     tran_low_t v = qcoeff[scan[c]];
2274     tran_low_t level = abs(v);
2275     int idx;
2276     int ctx;
2277 
2278     if (level <= NUM_BASE_LEVELS) continue;
2279 
2280     cul_level += level;
2281     if (c == 0) {
2282       int dc_sign_ctx = txb_ctx.dc_sign_ctx;
2283 
2284       ++td->counts->dc_sign[plane_type][dc_sign_ctx][v < 0];
2285 #if LV_MAP_PROB
2286       update_bin(ec_ctx->dc_sign_cdf[plane_type][dc_sign_ctx], v < 0, 2);
2287 #endif
2288       x->mbmi_ext->dc_sign_ctx[plane][block] = dc_sign_ctx;
2289     }
2290 
2291     // level is above 1.
2292     ctx = get_br_ctx(tcoeff, scan[c], bwl, height);
2293 
2294 #if BR_NODE
2295     int base_range = level - 1 - NUM_BASE_LEVELS;
2296     int br_set_idx = base_range < COEFF_BASE_RANGE
2297                          ? coeff_to_br_index[base_range]
2298                          : BASE_RANGE_SETS;
2299 
2300     for (idx = 0; idx < BASE_RANGE_SETS; ++idx) {
2301       if (idx == br_set_idx) {
2302         int br_base = br_index_to_coeff[br_set_idx];
2303         int br_offset = base_range - br_base;
2304         ++td->counts->coeff_br[txsize_ctx][plane_type][idx][ctx][1];
2305 #if LV_MAP_PROB
2306         update_bin(ec_ctx->coeff_br_cdf[txsize_ctx][plane_type][idx][ctx], 1,
2307                    2);
2308 #endif
2309         int extra_bits = (1 << br_extra_bits[idx]) - 1;
2310         for (int tok = 0; tok < extra_bits; ++tok) {
2311           if (br_offset == tok) {
2312             ++td->counts->coeff_lps[txsize_ctx][plane_type][ctx][1];
2313 #if LV_MAP_PROB
2314             update_bin(ec_ctx->coeff_lps_cdf[txsize_ctx][plane_type][ctx], 1,
2315                        2);
2316 #endif
2317             break;
2318           }
2319           ++td->counts->coeff_lps[txsize_ctx][plane_type][ctx][0];
2320 #if LV_MAP_PROB
2321           update_bin(ec_ctx->coeff_lps_cdf[txsize_ctx][plane_type][ctx], 0, 2);
2322 #endif
2323         }
2324         break;
2325       }
2326       ++td->counts->coeff_br[txsize_ctx][plane_type][idx][ctx][0];
2327 #if LV_MAP_PROB
2328       update_bin(ec_ctx->coeff_br_cdf[txsize_ctx][plane_type][idx][ctx], 0, 2);
2329 #endif
2330     }
2331 #else  // BR_NODE
2332     for (idx = 0; idx < COEFF_BASE_RANGE; ++idx) {
2333       if (level == (idx + 1 + NUM_BASE_LEVELS)) {
2334         ++td->counts->coeff_lps[txsize_ctx][plane_type][ctx][1];
2335 #if LV_MAP_PROB
2336         update_bin(ec_ctx->coeff_lps_cdf[txsize_ctx][plane_type][ctx], 1, 2);
2337 #endif
2338         break;
2339       }
2340       ++td->counts->coeff_lps[txsize_ctx][plane_type][ctx][0];
2341 #if LV_MAP_PROB
2342       update_bin(ec_ctx->coeff_lps_cdf[txsize_ctx][plane_type][ctx], 0, 2);
2343 #endif
2344     }
2345     if (idx < COEFF_BASE_RANGE) continue;
2346 #endif  // BR_NODE
2347     // use 0-th order Golomb code to handle the residual level.
2348   }
2349 
2350   cul_level = AOMMIN(COEFF_CONTEXT_MASK, cul_level);
2351 
2352   // DC value
2353   set_dc_sign(&cul_level, tcoeff[0]);
2354   av1_set_contexts(xd, pd, plane, tx_size, cul_level, blk_col, blk_row);
2355 
2356 #if CONFIG_ADAPT_SCAN
2357   // Since dqcoeff is not available here, we pass qcoeff into
2358   // av1_update_scan_count_facade(). The update behavior should be the same
2359   // because av1_update_scan_count_facade() only cares if coefficients are zero
2360   // or not.
2361   av1_update_scan_count_facade((AV1_COMMON *)cm, td->counts, tx_size, tx_type,
2362                                qcoeff, eob);
2363 #endif
2364 }
2365 
/*
 * Propagates the per-transform-block entropy contexts of the current coding
 * block and, on a real (non-dry) run, records the associated statistics.
 *
 * - Skipped block: the skip counter (context from av1_get_skip_context) is
 *   bumped on a real run, then the skip context is reset through
 *   av1_reset_skip_context().
 * - Real run (dry_run == 0): bumps the "not skipped" counter and visits every
 *   transformed block with av1_update_and_record_txb_context.
 * - DRY_RUN_NORMAL: visits every transformed block with
 *   av1_update_txb_context_b.
 * - Any other dry-run mode (DRY_RUN_COSTCOEFFS) is unsupported: a message is
 *   printed and an assertion fires.
 *
 * |rate| is currently unused.
 */
void av1_update_txb_context(const AV1_COMP *cpi, ThreadData *td,
                            RUN_TYPE dry_run, BLOCK_SIZE bsize, int *rate,
                            int mi_row, int mi_col) {
  const AV1_COMMON *const cm = &cpi->common;
  MACROBLOCK *const x = &td->mb;
  MACROBLOCKD *const xd = &x->e_mbd;
  MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi;
  const int skip_ctx = av1_get_skip_context(xd);
  // Blocks whose segment has SEG_LVL_SKIP active add nothing to the skip
  // counters.
  const int skip_inc =
      !segfeature_active(&cm->seg, mbmi->segment_id, SEG_LVL_SKIP);
  struct tokenize_b_args visitor_args = { cpi, td, NULL, 0 };

  (void)rate;

  if (mbmi->skip) {
    if (!dry_run) td->counts->skip[skip_ctx][1] += skip_inc;
    av1_reset_skip_context(xd, mi_row, mi_col, bsize);
    return;
  }

  if (dry_run) {
    if (dry_run != DRY_RUN_NORMAL) {
      printf("DRY_RUN_COSTCOEFFS is not supported yet\n");
      assert(0);
      return;
    }
    // Context propagation only; no statistics are gathered.
    av1_foreach_transformed_block(xd, bsize, mi_row, mi_col,
                                  av1_update_txb_context_b, &visitor_args);
    return;
  }

  td->counts->skip[skip_ctx][0] += skip_inc;
  av1_foreach_transformed_block(xd, bsize, mi_row, mi_col,
                                av1_update_and_record_txb_context,
                                &visitor_args);
}
2398 
/*
 * Decides whether one binary probability should be replaced by a value
 * derived from the branch counts |branch_cnt|, and accumulates the net bit
 * savings of that decision into |*savings|.
 *
 * - |update| != NULL: dry pass. Only ++update[flag] is recorded, where flag is
 *   1 when an update would be sent. Nothing is written and |*oldp| is left
 *   untouched.
 * - |update| == NULL: the update flag is written to |bc| with probability
 *   DIFF_UPDATE_PROB. If the flag is set, the new probability is delta-coded
 *   against |*oldp| and stored back into |*oldp|.
 */
static void find_new_prob(unsigned int *branch_cnt, aom_prob *oldp,
                          int *savings, int *update, aom_writer *const bc) {
  const aom_prob upd = DIFF_UPDATE_PROB;
  const int zero_flag_cost = (int)(av1_cost_zero(upd));
  aom_prob newp = get_binary_prob(branch_cnt[0], branch_cnt[1]);
  // The search may refine |newp| in place.
  const int gain =
      av1_prob_diff_update_savings_search(branch_cnt, *oldp, &newp, upd, 1);
  const int send_update = (gain > 0 && newp != *oldp) ? 1 : 0;

  // NOTE(review): both branches charge the cost of a *zero* flag (original
  // TODO(jingning): "1?"). Presumably |gain| already includes the one-vs-zero
  // flag cost difference; confirm in av1_prob_diff_update_savings_search().
  *savings += send_update ? gain - zero_flag_cost : -zero_flag_cost;

  if (update != NULL) {
    ++update[send_update];
    return;
  }

  aom_write(bc, send_update, upd);
  if (!send_update) return;

  // Send and adopt the new probability.
  av1_write_prob_diff_update(bc, newp, *oldp);
  *oldp = newp;
}
2425 
write_txb_probs(aom_writer * const bc,AV1_COMP * cpi,TX_SIZE tx_size)2426 static void write_txb_probs(aom_writer *const bc, AV1_COMP *cpi,
2427                             TX_SIZE tx_size) {
2428   FRAME_CONTEXT *fc = cpi->common.fc;
2429   FRAME_COUNTS *counts = cpi->td.counts;
2430   int savings = 0;
2431   int update[2] = { 0, 0 };
2432   int plane, ctx, level;
2433 
2434   for (ctx = 0; ctx < TXB_SKIP_CONTEXTS; ++ctx) {
2435     find_new_prob(counts->txb_skip[tx_size][ctx], &fc->txb_skip[tx_size][ctx],
2436                   &savings, update, bc);
2437   }
2438 
2439   for (plane = 0; plane < PLANE_TYPES; ++plane) {
2440     for (ctx = 0; ctx < SIG_COEF_CONTEXTS; ++ctx) {
2441       find_new_prob(counts->nz_map[tx_size][plane][ctx],
2442                     &fc->nz_map[tx_size][plane][ctx], &savings, update, bc);
2443     }
2444   }
2445 
2446   for (plane = 0; plane < PLANE_TYPES; ++plane) {
2447     for (ctx = 0; ctx < EOB_COEF_CONTEXTS; ++ctx) {
2448       find_new_prob(counts->eob_flag[tx_size][plane][ctx],
2449                     &fc->eob_flag[tx_size][plane][ctx], &savings, update, bc);
2450     }
2451   }
2452 
2453   for (level = 0; level < NUM_BASE_LEVELS; ++level) {
2454     for (plane = 0; plane < PLANE_TYPES; ++plane) {
2455       for (ctx = 0; ctx < COEFF_BASE_CONTEXTS; ++ctx) {
2456         find_new_prob(counts->coeff_base[tx_size][plane][level][ctx],
2457                       &fc->coeff_base[tx_size][plane][level][ctx], &savings,
2458                       update, bc);
2459       }
2460     }
2461   }
2462 
2463   for (plane = 0; plane < PLANE_TYPES; ++plane) {
2464     for (ctx = 0; ctx < LEVEL_CONTEXTS; ++ctx) {
2465       find_new_prob(counts->coeff_lps[tx_size][plane][ctx],
2466                     &fc->coeff_lps[tx_size][plane][ctx], &savings, update, bc);
2467     }
2468   }
2469 
2470   // Decide if to update the model for this tx_size
2471   if (update[1] == 0 || savings < 0) {
2472     aom_write_bit(bc, 0);
2473     return;
2474   }
2475   aom_write_bit(bc, 1);
2476 
2477   for (ctx = 0; ctx < TXB_SKIP_CONTEXTS; ++ctx) {
2478     find_new_prob(counts->txb_skip[tx_size][ctx], &fc->txb_skip[tx_size][ctx],
2479                   &savings, NULL, bc);
2480   }
2481 
2482   for (plane = 0; plane < PLANE_TYPES; ++plane) {
2483     for (ctx = 0; ctx < SIG_COEF_CONTEXTS; ++ctx) {
2484       find_new_prob(counts->nz_map[tx_size][plane][ctx],
2485                     &fc->nz_map[tx_size][plane][ctx], &savings, NULL, bc);
2486     }
2487   }
2488 
2489   for (plane = 0; plane < PLANE_TYPES; ++plane) {
2490     for (ctx = 0; ctx < EOB_COEF_CONTEXTS; ++ctx) {
2491       find_new_prob(counts->eob_flag[tx_size][plane][ctx],
2492                     &fc->eob_flag[tx_size][plane][ctx], &savings, NULL, bc);
2493     }
2494   }
2495 
2496   for (level = 0; level < NUM_BASE_LEVELS; ++level) {
2497     for (plane = 0; plane < PLANE_TYPES; ++plane) {
2498       for (ctx = 0; ctx < COEFF_BASE_CONTEXTS; ++ctx) {
2499         find_new_prob(counts->coeff_base[tx_size][plane][level][ctx],
2500                       &fc->coeff_base[tx_size][plane][level][ctx], &savings,
2501                       NULL, bc);
2502       }
2503     }
2504   }
2505 
2506   for (plane = 0; plane < PLANE_TYPES; ++plane) {
2507     for (ctx = 0; ctx < LEVEL_CONTEXTS; ++ctx) {
2508       find_new_prob(counts->coeff_lps[tx_size][plane][ctx],
2509                     &fc->coeff_lps[tx_size][plane][ctx], &savings, NULL, bc);
2510     }
2511   }
2512 }
2513 
/*
 * Writes the forward probability updates of the lv-map coefficient models.
 *
 * The dc_sign probabilities for every plane type and context come first.
 * Then come the per-tx_size tables (via write_txb_probs) for every transform
 * size from TX_4X4 up to the largest one allowed by the frame's tx_mode.
 *
 * When LV_MAP_PROB is enabled, nothing is written here.
 */
void av1_write_txb_probs(AV1_COMP *cpi, aom_writer *w) {
#if LV_MAP_PROB
  (void)cpi;
  (void)w;
#else
  AV1_COMMON *const cm = &cpi->common;
  FRAME_COUNTS *const counts = cpi->td.counts;
  const TX_SIZE max_tx_size = tx_mode_to_biggest_tx_size[cm->tx_mode];

  for (int plane = 0; plane < PLANE_TYPES; ++plane) {
    for (int ctx = 0; ctx < DC_SIGN_CONTEXTS; ++ctx) {
      av1_cond_prob_diff_update(w, &cm->fc->dc_sign[plane][ctx],
                                counts->dc_sign[plane][ctx], 1);
    }
  }

  for (TX_SIZE tx_size = TX_4X4; tx_size <= max_tx_size; ++tx_size) {
    write_txb_probs(w, cpi, tx_size);
  }
#endif  // LV_MAP_PROB
}
2532 
2533 #if CONFIG_TXK_SEL
/*
 * Searches all transform types for one transform block and keeps the one
 * with the lowest rate-distortion cost.
 *
 * A candidate tx_type is evaluated only if it passes two checks. First,
 * av1_get_tx_type() must report it as the type that would actually be used
 * for this block (for plane 0 the candidate is first stored in
 * mbmi->txk_type). Second, under CONFIG_EXT_TX, it must belong to the allowed
 * extended tx set. Each evaluated candidate is forward transformed/quantized,
 * optimized, and then has its distortion and coefficient rate measured.
 *
 * On return:
 *  - the winning candidate's RD stats have been merged into |rd_stats|;
 *  - for plane 0, mbmi->txk_type[] holds the winning type. For inter blocks
 *    whose winning entropy context is 0, the type is forced to DCT_DCT;
 *  - x->plane[plane].txb_entropy_ctx[block] holds the winner's context;
 *  - for intra blocks, the winner is re-quantized and inverse transformed so
 *    that the next transform block can predict from the reconstruction.
 *
 * NOTE(review): for inter blocks the coefficient buffers / eobs in
 * x->plane[plane] are not regenerated here, so they presumably still hold
 * the result of the last candidate evaluated, not the winner. Confirm that
 * callers do not rely on them.
 *
 * Returns the lowest RD cost found, or INT64_MAX if no candidate was
 * evaluated.
 */
int64_t av1_search_txk_type(const AV1_COMP *cpi, MACROBLOCK *x, int plane,
                            int block, int blk_row, int blk_col,
                            BLOCK_SIZE plane_bsize, TX_SIZE tx_size,
                            const ENTROPY_CONTEXT *a, const ENTROPY_CONTEXT *l,
                            int use_fast_coef_costing, RD_STATS *rd_stats) {
  const AV1_COMMON *cm = &cpi->common;
  MACROBLOCKD *xd = &x->e_mbd;
  MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi;
  const int is_inter = is_inter_block(mbmi);
  TX_TYPE txk_start = DCT_DCT;
  TX_TYPE txk_end = TX_TYPES - 1;
  TX_TYPE best_tx_type = txk_start;
  int64_t best_rd = INT64_MAX;
  // Entropy context (x->plane[plane].txb_entropy_ctx) of the best candidate.
  uint8_t best_txb_ctx = 0;
  const int coeff_ctx = combine_entropy_contexts(*a, *l);
  RD_STATS best_rd_stats;
  TX_TYPE tx_type;

  av1_invalid_rd_stats(&best_rd_stats);

  for (tx_type = txk_start; tx_type <= txk_end; ++tx_type) {
    if (plane == 0) mbmi->txk_type[(blk_row << 4) + blk_col] = tx_type;
    TX_TYPE ref_tx_type = av1_get_tx_type(get_plane_type(plane), xd, blk_row,
                                          blk_col, block, tx_size);
    if (tx_type != ref_tx_type) {
      // av1_get_tx_type() decides whether this tx_type is valid for the
      // current mode; if it resolves to a different type, skip the candidate.
      continue;
    }

#if CONFIG_EXT_TX
    const TxSetType tx_set_type =
        get_ext_tx_set_type(get_min_tx_size(tx_size), mbmi->sb_type, is_inter,
                            cm->reduced_tx_set_used);
    if (!av1_ext_tx_used[tx_set_type][tx_type]) continue;
#endif  // CONFIG_EXT_TX

    RD_STATS this_rd_stats;
    av1_invalid_rd_stats(&this_rd_stats);
    av1_xform_quant(cm, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
                    coeff_ctx, AV1_XFORM_QUANT_FP);
    av1_optimize_b(cm, x, plane, blk_row, blk_col, block, plane_bsize, tx_size,
                   a, l, 1);
    av1_dist_block(cpi, x, plane, plane_bsize, block, blk_row, blk_col, tx_size,
                   &this_rd_stats.dist, &this_rd_stats.sse,
                   OUTPUT_HAS_PREDICTED_PIXELS);
    const SCAN_ORDER *scan_order = get_scan(cm, tx_size, tx_type, mbmi);
    this_rd_stats.rate =
        av1_cost_coeffs(cpi, x, plane, blk_row, blk_col, block, tx_size,
                        scan_order, a, l, use_fast_coef_costing);
    // best_rd and the return value are int64_t. Keep the cost 64-bit so large
    // costs are not truncated before the comparison below.
    const int64_t rd =
        RDCOST(x->rdmult, this_rd_stats.rate, this_rd_stats.dist);

    if (rd < best_rd) {
      best_rd = rd;
      best_rd_stats = this_rd_stats;
      best_tx_type = tx_type;
      best_txb_ctx = x->plane[plane].txb_entropy_ctx[block];
    }
  }

  av1_merge_rd_stats(rd_stats, &best_rd_stats);

  // NOTE(review): a zero entropy context presumably means the winner has no
  // nonzero coefficients; inter blocks then fall back to DCT_DCT.
  if (best_txb_ctx == 0 && is_inter) best_tx_type = DCT_DCT;

  if (plane == 0) mbmi->txk_type[(blk_row << 4) + blk_col] = best_tx_type;
  x->plane[plane].txb_entropy_ctx[block] = best_txb_ctx;

  if (!is_inter) {
    // Intra mode needs the decoded result so that the next transform block
    // can use it for prediction. Redo the winner; txk_type now selects it.
    av1_xform_quant(cm, x, plane, block, blk_row, blk_col, plane_bsize, tx_size,
                    coeff_ctx, AV1_XFORM_QUANT_FP);
    av1_optimize_b(cm, x, plane, blk_row, blk_col, block, plane_bsize, tx_size,
                   a, l, 1);

    av1_inverse_transform_block_facade(xd, plane, block, blk_row, blk_col,
                                       x->plane[plane].eobs[block]);
  }
  return best_rd;
}
2614 #endif  // CONFIG_TXK_SEL
2615