1 /*
2  * Copyright (c) 2016, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at https://www.aomedia.org/license/software-license. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at https://www.aomedia.org/license/patent-license.
10  */
11 
12 #include <math.h>
13 #include <stdlib.h>
14 #include "EbDefinitions.h"
15 #include "EbModeDecisionProcess.h"
16 #include "aom_dsp_rtcd.h"
17 
18 #define DIVIDE_AND_ROUND(x, y) (((x) + ((y) >> 1)) / (y))
19 
20 // Generate a random number in the range [0, 32768).
lcg_rand16(unsigned int * state)21 static INLINE unsigned int lcg_rand16(unsigned int *state) {
22     *state = (unsigned int)(*state * 1103515245ULL + 12345);
23     return *state / 65536 % 32768;
24 }
25 
// Builds the name of a dimension-specialized plain-C k-means routine, e.g.
// AV1_K_MEANS_RENAME(svt_av1_k_means, 1) expands to svt_av1_k_means_dim1_c.
#define AV1_K_MEANS_RENAME(func, dim) func##_dim##dim##_c

// Prototypes of the plain-C k-means routines for 1-D and 2-D data. They are
// presumably defined by the k_means_template.h instantiations further down in
// this file. The unsuffixed svt_av1_*_dim1/_dim2 names called below presumably
// dispatch to these (or to SIMD variants) through aom_dsp_rtcd.h - TODO confirm.
void AV1_K_MEANS_RENAME(svt_av1_calc_indices, 1)(const int *data, const int *centroids,
                                                 uint8_t *indices, int n, int k);
void AV1_K_MEANS_RENAME(svt_av1_calc_indices, 2)(const int *data, const int *centroids,
                                                 uint8_t *indices, int n, int k);
void AV1_K_MEANS_RENAME(svt_av1_k_means, 1)(const int *data, int *centroids, uint8_t *indices,
                                            int n, int k, int max_itr);
void AV1_K_MEANS_RENAME(svt_av1_k_means, 2)(const int *data, int *centroids, uint8_t *indices,
                                            int n, int k, int max_itr);
36 
// Assigns each of the 'n' 'dim'-dimensional points in 'data' to one of the
// 'k' 'centroids' (presumably the nearest) and writes that centroid's index
// to 'indices'. Only dim == 1 and dim == 2 are instantiated. Any other
// dimension fires an assertion in debug builds and does nothing otherwise.
static inline void av1_calc_indices(const int *data, const int *centroids, uint8_t *indices, int n,
                                    int k, int dim) {
    switch (dim) {
    case 1: svt_av1_calc_indices_dim1(data, centroids, indices, n, k); break;
    case 2: svt_av1_calc_indices_dim2(data, centroids, indices, n, k); break;
    default: assert(0 && "Untemplated k means dimension"); break;
    }
}
49 
// Runs up to 'max_itr' k-means iterations over the 'n' 'dim'-dimensional
// points in 'data', starting from the 'k' initial guesses in 'centroids'.
// On return 'centroids' holds the refined centroids (rounded to integers) and
// 'indices' the centroid index of every point. Only dim == 1 and dim == 2 are
// instantiated. Any other dimension fires an assertion in debug builds and
// does nothing otherwise.
static inline void av1_k_means(const int *data, int *centroids, uint8_t *indices, int n, int k,
                               int dim, int max_itr) {
    switch (dim) {
    case 1: svt_av1_k_means_dim1(data, centroids, indices, n, k, max_itr); break;
    case 2: svt_av1_k_means_dim2(data, centroids, indices, n, k, max_itr); break;
    default: assert(0 && "Untemplated k means dimension"); break;
    }
}
64 
// Instantiate the k-means implementations for 1-D and 2-D data. The template
// header is included once per dimension. AV1_K_MEANS_DIM tells the template
// which variant to generate and is undefined again after each inclusion.
#define AV1_K_MEANS_DIM 1
#include "k_means_template.h"
#undef AV1_K_MEANS_DIM
#define AV1_K_MEANS_DIM 2
#include "k_means_template.h"
#undef AV1_K_MEANS_DIM
71 
// qsort() comparator ordering ints ascending. It uses the (a > b) - (a < b)
// idiom instead of subtraction, because a - b overflows (undefined behavior)
// when the operands have opposite signs and large magnitudes.
static int int_comparer(const void *a, const void *b) {
    const int lhs = *(const int *)a;
    const int rhs = *(const int *)b;
    return (lhs > rhs) - (lhs < rhs);
}

// Sorts 'centroids' ascending in place and compacts the distinct values into
// its leading entries.
// Returns the number of distinct values, which is 0 for an empty input.
// Entries past the returned count are left in an unspecified order.
static int av1_remove_duplicates(int *centroids, int num_centroids) {
    if (num_centroids <= 0)
        return 0;
    qsort(centroids, (size_t)num_centroids, sizeof(*centroids), int_comparer);
    int num_unique = 1; // centroids[0] is always kept
    for (int i = 1; i < num_centroids; ++i) {
        // Writes only go to indices < i (or i itself), so centroids[i - 1]
        // still holds its sorted value here.
        if (centroids[i] != centroids[i - 1]) // found a new unique centroid
            centroids[num_unique++] = centroids[i];
    }
    return num_unique;
}
87 
delta_encode_cost(const int * colors,int num,int bit_depth,int min_val)88 static int delta_encode_cost(const int *colors, int num, int bit_depth, int min_val) {
89     if (num <= 0)
90         return 0;
91     int bits_cost = bit_depth;
92     if (num == 1)
93         return bits_cost;
94     bits_cost += 2;
95     int       max_delta = 0;
96     int       deltas[PALETTE_MAX_SIZE];
97     const int min_bits = bit_depth - 3;
98     for (int i = 1; i < num; ++i) {
99         const int delta = colors[i] - colors[i - 1];
100         deltas[i - 1]   = delta;
101         assert(delta >= min_val);
102         if (delta > max_delta)
103             max_delta = delta;
104     }
105     int bits_per_delta = AOMMAX(av1_ceil_log2(max_delta + 1 - min_val), min_bits);
106     assert(bits_per_delta <= bit_depth);
107     int range = (1 << bit_depth) - colors[0] - min_val;
108     for (int i = 0; i < num - 1; ++i) {
109         bits_cost += bits_per_delta;
110         range -= deltas[i];
111         bits_per_delta = AOMMIN(bits_per_delta, av1_ceil_log2(range));
112     }
113     return bits_cost;
114 }
115 
svt_av1_index_color_cache(const uint16_t * color_cache,int n_cache,const uint16_t * colors,int n_colors,uint8_t * cache_color_found,int * out_cache_colors)116 int svt_av1_index_color_cache(const uint16_t *color_cache, int n_cache, const uint16_t *colors,
117                               int n_colors, uint8_t *cache_color_found, int *out_cache_colors) {
118     if (n_cache <= 0) {
119         for (int i = 0; i < n_colors; ++i) out_cache_colors[i] = colors[i];
120         return n_colors;
121     }
122     memset(cache_color_found, 0, n_cache * sizeof(*cache_color_found));
123     int n_in_cache = 0;
124     int in_cache_flags[PALETTE_MAX_SIZE];
125     memset(in_cache_flags, 0, sizeof(in_cache_flags));
126     for (int i = 0; i < n_cache && n_in_cache < n_colors; ++i) {
127         for (int j = 0; j < n_colors; ++j) {
128             if (colors[j] == color_cache[i]) {
129                 in_cache_flags[j]    = 1;
130                 cache_color_found[i] = 1;
131                 ++n_in_cache;
132                 break;
133             }
134         }
135     }
136     int j = 0;
137     for (int i = 0; i < n_colors; ++i)
138         if (!in_cache_flags[i])
139             out_cache_colors[j++] = colors[i];
140     assert(j == n_colors - n_in_cache);
141     return j;
142 }
143 
svt_av1_palette_color_cost_y(const PaletteModeInfo * const pmi,uint16_t * color_cache,int n_cache,int bit_depth)144 int svt_av1_palette_color_cost_y(const PaletteModeInfo *const pmi, uint16_t *color_cache,
145                                  int n_cache, int bit_depth) {
146     const int n = pmi->palette_size[0];
147     int       out_cache_colors[PALETTE_MAX_SIZE];
148     uint8_t   cache_color_found[2 * PALETTE_MAX_SIZE];
149     const int n_out_cache = svt_av1_index_color_cache(
150         color_cache, n_cache, pmi->palette_colors, n, cache_color_found, out_cache_colors);
151     const int total_bits = n_cache + delta_encode_cost(out_cache_colors, n_out_cache, bit_depth, 1);
152     return av1_cost_literal(total_bits);
153 }
154 
// Appends 'val' to 'cache' (current length *n) and increments *n, unless
// 'val' equals the most recently appended entry. svt_get_palette_cache feeds
// values in ascending order, so this drops duplicates there.
static void palette_add_to_cache(uint16_t *cache, int *n, uint16_t val) {
    const int len = *n;
    if (len <= 0 || cache[len - 1] != val) {
        cache[len] = val;
        *n         = len + 1;
    }
}
162 
svt_get_palette_cache(const MacroBlockD * const xd,int plane,uint16_t * cache)163 int svt_get_palette_cache(const MacroBlockD *const xd, int plane, uint16_t *cache) {
164     const int row = -xd->mb_to_top_edge >> 3;
165     // Do not refer to above SB row when on SB boundary.
166     const MbModeInfo *const above_mi = (row % (1 << MIN_SB_SIZE_LOG2)) ? xd->above_mbmi : NULL;
167     const MbModeInfo *const left_mi  = xd->left_mbmi;
168     int                     above_n = 0, left_n = 0;
169     if (above_mi)
170         above_n = above_mi->palette_mode_info.palette_size[plane != 0];
171     if (left_mi)
172         left_n = left_mi->palette_mode_info.palette_size[plane != 0];
173     if (above_n == 0 && left_n == 0)
174         return 0;
175     int             above_idx    = plane * PALETTE_MAX_SIZE;
176     int             left_idx     = plane * PALETTE_MAX_SIZE;
177     int             n            = 0;
178     const uint16_t *above_colors = above_mi ? above_mi->palette_mode_info.palette_colors : NULL;
179     const uint16_t *left_colors  = left_mi ? left_mi->palette_mode_info.palette_colors : NULL;
180     // Merge the sorted lists of base colors from above and left to get
181     // combined sorted color cache.
182     while (above_n > 0 && left_n > 0) {
183         uint16_t v_above = above_colors[above_idx];
184         uint16_t v_left  = left_colors[left_idx];
185         if (v_left < v_above) {
186             palette_add_to_cache(cache, &n, v_left);
187             ++left_idx, --left_n;
188         } else {
189             palette_add_to_cache(cache, &n, v_above);
190             ++above_idx, --above_n;
191             if (v_left == v_above)
192                 ++left_idx, --left_n;
193         }
194     }
195     while (above_n-- > 0) {
196         uint16_t val = above_colors[above_idx++];
197         palette_add_to_cache(cache, &n, val);
198     }
199     while (left_n-- > 0) {
200         uint16_t val = left_colors[left_idx++];
201         palette_add_to_cache(cache, &n, val);
202     }
203     assert(n <= 2 * PALETTE_MAX_SIZE);
204     return n;
205 }
206 // Returns sub-sampled dimensions of the given block.
207 // The output values for 'rows_within_bounds' and 'cols_within_bounds' will
208 // differ from 'height' and 'width' when part of the block is outside the
209 // right
210 // and/or bottom image boundary.
av1_get_block_dimensions(BlockSize bsize,int plane,const MacroBlockD * xd,int * width,int * height,int * rows_within_bounds,int * cols_within_bounds)211 void av1_get_block_dimensions(BlockSize bsize, int plane, const MacroBlockD *xd, int *width,
212                               int *height, int *rows_within_bounds, int *cols_within_bounds) {
213     const int block_height = block_size_high[bsize];
214     const int block_width  = block_size_wide[bsize];
215     const int block_rows   = (xd->mb_to_bottom_edge >= 0)
216           ? block_height
217           : (xd->mb_to_bottom_edge >> 3) + block_height;
218     const int block_cols   = (xd->mb_to_right_edge >= 0) ? block_width
219                                                          : (xd->mb_to_right_edge >> 3) + block_width;
220 
221     uint8_t subsampling_x = plane == 0 ? 0 : 1;
222     uint8_t subsampling_y = plane == 0 ? 0 : 1;
223 
224     assert(block_width >= block_cols);
225     assert(block_height >= block_rows);
226     const int plane_block_width  = block_width >> subsampling_x;
227     const int plane_block_height = block_height >> subsampling_y;
228     // Special handling for chroma sub8x8.
229     const int is_chroma_sub8_x = plane > 0 && plane_block_width < 4;
230     const int is_chroma_sub8_y = plane > 0 && plane_block_height < 4;
231     if (width)
232         *width = plane_block_width + 2 * is_chroma_sub8_x;
233     if (height)
234         *height = plane_block_height + 2 * is_chroma_sub8_y;
235     if (rows_within_bounds) {
236         *rows_within_bounds = (block_rows >> subsampling_y) + 2 * is_chroma_sub8_y;
237     }
238     if (cols_within_bounds) {
239         *cols_within_bounds = (block_cols >> subsampling_x) + 2 * is_chroma_sub8_x;
240     }
241 }
242 
243 // Bias toward using colors in the cache.
244 // TODO: Try other schemes to improve compression.
optimize_palette_colors(uint16_t * color_cache,int n_cache,int n_colors,int stride,int * centroids)245 static AOM_INLINE void optimize_palette_colors(uint16_t *color_cache, int n_cache, int n_colors,
246                                                int stride, int *centroids) {
247     if (n_cache <= 0)
248         return;
249     for (int i = 0; i < n_colors * stride; i += stride) {
250         int min_diff = abs(centroids[i] - (int)color_cache[0]);
251         int idx      = 0;
252         for (int j = 1; j < n_cache; ++j) {
253             const int this_diff = abs(centroids[i] - color_cache[j]);
254             if (this_diff < min_diff) {
255                 min_diff = this_diff;
256                 idx      = j;
257             }
258         }
259         if (min_diff <= 1)
260             centroids[i] = color_cache[idx];
261     }
262 }
263 // Extends 'color_map' array from 'orig_width x orig_height' to 'new_width x
264 // new_height'. Extra rows and columns are filled in by copying last valid
265 // row/column.
extend_palette_color_map(uint8_t * const color_map,int orig_width,int orig_height,int new_width,int new_height)266 static AOM_INLINE void extend_palette_color_map(uint8_t *const color_map, int orig_width,
267                                                 int orig_height, int new_width, int new_height) {
268     int j;
269     assert(new_width >= orig_width);
270     assert(new_height >= orig_height);
271     if (new_width == orig_width && new_height == orig_height)
272         return;
273 
274     for (j = orig_height - 1; j >= 0; --j) {
275         memmove(color_map + j * new_width, color_map + j * orig_width, orig_width);
276         // Copy last column to extra columns.
277         memset(color_map + j * new_width + orig_width,
278                color_map[j * new_width + orig_width - 1],
279                new_width - orig_width);
280     }
281     // Copy last row to extra rows.
282     for (j = orig_height; j < new_height; ++j) {
283         svt_memcpy(color_map + j * new_width, color_map + (orig_height - 1) * new_width, new_width);
284     }
285 }
// Turns 'n' candidate luma 'centroids' into the palette held in 'palette_info'.
// It snaps the centroids to nearby cache colors, then sorts and de-duplicates
// them (modifying 'centroids' in place). When fewer than PALETTE_MIN_SIZE
// distinct colors remain, palette_size[0] is set to 0 and nothing else is
// written. Otherwise the clipped colors and the palette size are stored in
// palette_info->pmi. palette_info->color_idx_map then receives the per-pixel
// palette index of the in-frame samples in 'data', extended to the full
// block size.
void palette_rd_y(PaletteInfo *palette_info, ModeDecisionContext *context_ptr, BlockSize bsize,
                  const int *data, int *centroids, int n, uint16_t *color_cache, int n_cache,
                  int bit_depth) {
    PaletteModeInfo *const pmi = &palette_info->pmi;

    optimize_palette_colors(color_cache, n_cache, n, 1, centroids);
    const int n_unique = av1_remove_duplicates(centroids, n);
    if (n_unique < PALETTE_MIN_SIZE) {
        // Too few unique colors to create a palette, and DC_PRED handles that
        // case well anyway, so skip.
        pmi->palette_size[0] = 0;
        return;
    }

    const int high_bd = bit_depth > EB_8BIT;
    for (int i = 0; i < n_unique; ++i) {
        if (high_bd)
            pmi->palette_colors[i] = clip_pixel_highbd(centroids[i], bit_depth);
        else
            pmi->palette_colors[i] = clip_pixel(centroids[i]);
    }
    pmi->palette_size[0] = n_unique;

    int block_width, block_height, rows, cols;
    av1_get_block_dimensions(
        bsize, 0, context_ptr->blk_ptr->av1xd, &block_width, &block_height, &rows, &cols);
    uint8_t *const color_map = palette_info->color_idx_map;
    av1_calc_indices(data, centroids, color_map, rows * cols, n_unique, 1);
    extend_palette_color_map(color_map, cols, rows, block_width, block_height);
}
313 
// Color-counting helpers, defined elsewhere. As used by search_palette_luma
// below, they presumably fill 'val_count' with a histogram of the sample
// levels in the rows x cols region at 'src' and return the number of distinct
// levels. The 8-bit variant takes no bit depth - TODO confirm how many bins
// it initializes.
int svt_av1_count_colors(const uint8_t *src, int stride, int rows, int cols, int *val_count);
int svt_av1_count_colors_highbd(uint16_t *src, int stride, int rows, int cols, int bit_depth,
                                int *val_count);
317 /****************************************
318    determine all palette luma candidates
319  ****************************************/
search_palette_luma(PictureControlSet * pcs_ptr,ModeDecisionContext * context_ptr,PaletteInfo * palette_cand,uint32_t * tot_palette_cands)320 void search_palette_luma(PictureControlSet *pcs_ptr, ModeDecisionContext *context_ptr,
321                          PaletteInfo *palette_cand, uint32_t *tot_palette_cands) {
322     int    colors;
323     EbBool is16bit = context_ptr->hbd_mode_decision > 0;
324 
325     EbPictureBufferDesc *src_pic    = is16bit ? pcs_ptr->input_frame16bit
326                                               : pcs_ptr->parent_pcs_ptr->enhanced_picture_ptr;
327     const int            src_stride = src_pic->stride_y;
328 
329     const uint8_t *const src = src_pic->buffer_y +
330         (((context_ptr->blk_origin_x + src_pic->origin_x) +
331           (context_ptr->blk_origin_y + src_pic->origin_y) * src_pic->stride_y)
332          << is16bit);
333     int block_width, block_height, rows, cols;
334     MacroBlockD *xd = context_ptr->blk_ptr->av1xd;
335     BlockSize    bsize = context_ptr->blk_geom->bsize;
336     av1_get_block_dimensions(context_ptr->blk_geom->bsize,
337                              0,
338                              context_ptr->blk_ptr->av1xd,
339                              &block_width,
340                              &block_height,
341                              &rows,
342                              &cols);
343 
344     int count_buf[1 << 12]; // Maximum (1 << 12) color levels.
345 
346     unsigned bit_depth = pcs_ptr->parent_pcs_ptr->scs_ptr->encoder_bit_depth;
347     if (is16bit)
348         colors = svt_av1_count_colors_highbd(
349             (uint16_t *)src, src_stride, rows, cols, bit_depth, count_buf);
350     else
351         colors = svt_av1_count_colors(src, src_stride, rows, cols, count_buf);
352 
353     if (colors > 1 && colors <= 64) {
354         int        r, c, i;
355         const int  max_itr = 50;
356         int *const data    = context_ptr->palette_buffer.kmeans_data_buf;
357         int        centroids[PALETTE_MAX_SIZE];
358         int        lb, ub;
359 
360 #define GENERATE_KMEANS_DATA(src_data_type)                                    \
361     do {                                                                       \
362         lb = ub = ((src_data_type)src)[0];                                     \
363         for (r = 0; r < rows; ++r) {                                           \
364             for (c = 0; c < cols; ++c) {                                       \
365                 int val            = ((src_data_type)src)[r * src_stride + c]; \
366                 data[r * cols + c] = val;                                      \
367                 if (val < lb)                                                  \
368                     lb = val;                                                  \
369                 else if (val > ub)                                             \
370                     ub = val;                                                  \
371             }                                                                  \
372         }                                                                      \
373     } while (0)
374 
375         if (is16bit)
376             GENERATE_KMEANS_DATA(uint16_t *);
377         else
378             GENERATE_KMEANS_DATA(uint8_t *);
379 
380         uint16_t  color_cache[2 * PALETTE_MAX_SIZE];
381         const int n_cache = svt_get_palette_cache(xd, 0, color_cache);
382 
383         // Find the dominant colors, stored in top_colors[].
384         int top_colors[PALETTE_MAX_SIZE] = {0};
385         for (i = 0; i < AOMMIN(colors, PALETTE_MAX_SIZE); ++i) {
386             int max_count = 0;
387             for (int j = 0; j < (1 << bit_depth); ++j) {
388                 if (count_buf[j] > max_count) {
389                     max_count     = count_buf[j];
390                     top_colors[i] = j;
391                 }
392             }
393             assert(max_count > 0);
394             count_buf[top_colors[i]] = 0;
395         }
396 
397         // Try the dominant colors directly.
398         // TODO: Try to avoid duplicate computation in cases
399         // where the dominant colors and the k-means results are similar.
400 
401         int step = (pcs_ptr->parent_pcs_ptr->palette_level == 6) ? 2 : 1;
402         for (int n = AOMMIN(colors, PALETTE_MAX_SIZE); n >= 2; n -= step) {
403             for (i = 0; i < n; ++i) centroids[i] = top_colors[i];
404 
405             palette_rd_y(&palette_cand[*tot_palette_cands],
406                          context_ptr,
407                          bsize,
408                          data,
409                          centroids,
410                          n,
411                          color_cache,
412                          n_cache,
413                          bit_depth);
414 
415             //consider this candidate if it has some non zero palette
416             if (palette_cand[*tot_palette_cands].pmi.palette_size[0] > 2)
417                 (*tot_palette_cands)++;
418             assert((*tot_palette_cands) <= 14);
419         }
420 
421         // K-means clustering.
422         for (int n = AOMMIN(colors, PALETTE_MAX_SIZE); n >= 2; --n) {
423             if (colors == PALETTE_MIN_SIZE) {
424                 // Special case: These colors automatically become the centroids.
425                 assert(colors == n);
426                 centroids[0] = lb;
427                 centroids[1] = ub;
428             } else {
429                 for (i = 0; i < n; ++i) { centroids[i] = lb + (2 * i + 1) * (ub - lb) / n / 2; }
430                 uint8_t *const color_map = palette_cand[*tot_palette_cands].color_idx_map;
431                 av1_k_means(data, centroids, color_map, rows * cols, n, 1, max_itr);
432             }
433 
434             palette_rd_y(&palette_cand[*tot_palette_cands],
435                          context_ptr,
436                          bsize,
437                          data,
438                          centroids,
439                          n,
440                          color_cache,
441                          n_cache,
442                          bit_depth);
443 
444             //consider this candidate if it has some non zero palette
445             if (palette_cand[*tot_palette_cands].pmi.palette_size[0] > 2)
446                 (*tot_palette_cands)++;
447 
448             assert((*tot_palette_cands) <= 14);
449         }
450     }
451 }
452 
/* clang-format off */
// Pointer to a per-palette-size table of color-index CDFs; indexed
// [palette_size_idx][color_ctx] by cost_and_tokenize_map.
 typedef AomCdfProb(*MapCdf)[PALETTE_COLOR_INDEX_CONTEXTS]
     [CDF_SIZE(PALETTE_COLORS)];
// Pointer to a color-index rate table; dereferenced and indexed
// [palette_size_idx][color_ctx][color_idx] by cost_and_tokenize_map.
 typedef const int(*ColorCost)[PALETTE_SIZES][PALETTE_COLOR_INDEX_CONTEXTS]
     [PALETTE_COLORS];
/* clang-format on */

// Inputs for costing or tokenizing one plane's palette color-index map.
typedef struct {
    int       rows;        // map rows inside the frame (rows_within_bounds)
    int       cols;        // map columns inside the frame (cols_within_bounds)
    int       n_colors;    // palette size
    int       plane_width; // row stride of color_map (full block width)
    uint8_t * color_map;   // per-pixel palette indices
    MapCdf    map_cdf;     // CDFs adapted while tokenizing (NULL when costing)
    ColorCost color_cost;  // rate table used when costing (NULL when tokenizing)
} Av1ColorMapParam;
469 
// Fills 'params' for tokenizing the palette color map of 'plane' of the block
// in 'blk_ptr'. The color map comes from blk_ptr->palette_info, the palette
// size from the block's mode info, and the CDFs from 'frame_context' (the UV
// CDFs for any non-zero plane). No rate table is attached.
static void get_palette_params(FRAME_CONTEXT *frame_context, BlkStruct *blk_ptr, int plane,
                               BlockSize bsize, Av1ColorMapParam *params) {
    const MacroBlockD *const     xd  = blk_ptr->av1xd;
    const PaletteModeInfo *const pmi = &xd->mi[0]->mbmi.palette_mode_info;

    params->color_map  = blk_ptr->palette_info.color_idx_map;
    params->map_cdf    = (plane == 0) ? frame_context->palette_y_color_index_cdf
                                      : frame_context->palette_uv_color_index_cdf;
    params->color_cost = NULL;
    params->n_colors   = pmi->palette_size[plane];
    av1_get_block_dimensions(
        bsize, plane, xd, &params->plane_width, NULL, &params->rows, &params->cols);
}
483 
// Clears 'params' and fills it for tokenizing a color map of kind 'type'.
// Only PALETTE_MAP is supported. Any other type fires an assertion in debug
// builds and leaves 'params' zeroed. 'tx_size' is unused.
static void get_color_map_params(FRAME_CONTEXT *frame_context, BlkStruct *blk_ptr, int plane,
                                 BlockSize bsize, TxSize tx_size, COLOR_MAP_TYPE type,
                                 Av1ColorMapParam *params) {
    (void)tx_size;
    memset(params, 0, sizeof(*params));
    if (type != PALETTE_MAP) {
        assert(0 && "Invalid color map type");
        return;
    }
    get_palette_params(frame_context, blk_ptr, plane, bsize, params);
}
// Fills 'params' for estimating the rate of the candidate color map in
// 'palette_info' for 'plane'. Only the luma rate table is wired up. For any
// non-zero plane color_cost stays NULL, so such params cannot be used to
// compute a rate. map_cdf is always NULL.
static void get_palette_params_rate(PaletteInfo *palette_info, MdRateEstimationContext *rate_table,
                                    BlkStruct *blk_ptr, int plane, BlockSize bsize,
                                    Av1ColorMapParam *params) {
    params->map_cdf    = NULL;
    params->color_cost = NULL;
    if (plane == 0)
        params->color_cost = (ColorCost)&rate_table->palette_ycolor_fac_bitss;
    params->color_map = palette_info->color_idx_map;
    params->n_colors  = palette_info->pmi.palette_size[plane];

    av1_get_block_dimensions(bsize,
                             plane,
                             blk_ptr->av1xd,
                             &params->plane_width,
                             NULL,
                             &params->rows,
                             &params->cols);
}
508 
// Clears 'params' and fills it for rate estimation of a color map of kind
// 'type'. Only PALETTE_MAP is supported. Any other type fires an assertion in
// debug builds and leaves 'params' zeroed.
static void get_color_map_params_rate(PaletteInfo *palette_info, MdRateEstimationContext *rate_table,
                                      BlkStruct *blk_ptr, int plane, BlockSize bsize,
                                      COLOR_MAP_TYPE type, Av1ColorMapParam *params) {
    memset(params, 0, sizeof(*params));
    if (type != PALETTE_MAP) {
        assert(0 && "Invalid color map type");
        return;
    }
    get_palette_params_rate(palette_info, rate_table, blk_ptr, plane, bsize, params);
}
522 
// Visits every position of the rows x cols color-index map except (0, 0),
// which callers code separately. The walk is anti-diagonal: k = i + j
// increases, and within a diagonal the column j decreases. For each position
// it derives the color context and the index to code via
// av1_get_palette_color_index_context_optimized(), then either:
//  - calc_rate != 0: adds (*param->color_cost)[size][ctx][idx] to the rate, or
//  - calc_rate == 0: writes a token (index + CDF from 'map_pb_cdf') at *t and
//    advances *t. If 'allow_update_cdf' is set it also adapts param->map_cdf.
// Returns the accumulated rate (0 when tokenizing).
static int cost_and_tokenize_map(Av1ColorMapParam *param, TOKENEXTRA **t, int plane, int calc_rate,
                                 int allow_update_cdf, MapCdf map_pb_cdf) {
    const uint8_t *const color_map         = param->color_map;
    MapCdf               map_cdf           = param->map_cdf;
    ColorCost            color_cost        = param->color_cost;
    const int            plane_block_width = param->plane_width;
    const int            rows              = param->rows;
    const int            cols              = param->cols;
    const int            n                 = param->n_colors;
    const int            palette_size_idx  = n - PALETTE_MIN_SIZE;
    int                  this_rate         = 0;

    (void)plane;
    // Costing needs a rate table, which is only provided for luma.
    // Tokenizing needs an output cursor, the CDF table to reference and,
    // when adapting, the CDFs to update.
    assert(!calc_rate || color_cost != NULL);
    assert(calc_rate || (t != NULL && map_pb_cdf != NULL));
    assert(calc_rate || !allow_update_cdf || map_cdf != NULL);

    for (int k = 1; k < rows + cols - 1; ++k) {
        for (int j = AOMMIN(k, cols - 1); j >= AOMMAX(0, k - rows + 1); --j) {
            const int i = k - j;
            int       color_new_idx;
            const int color_ctx = av1_get_palette_color_index_context_optimized(
                color_map, plane_block_width, i, j, n, &color_new_idx);
            assert(color_new_idx >= 0 && color_new_idx < n);
            if (calc_rate) {
                this_rate += (*color_cost)[palette_size_idx][color_ctx][color_new_idx];
            } else {
                (*t)->token         = color_new_idx;
                (*t)->color_map_cdf = map_pb_cdf[palette_size_idx][color_ctx];
                ++(*t);
                if (allow_update_cdf)
                    update_cdf(map_cdf[palette_size_idx][color_ctx], color_new_idx, n);
            }
        }
    }
    return this_rate;
}
564 
// Tokenizes the palette color map of 'plane' (0 = luma, 1 = chroma) for the
// block in 'blk_ptr', appending tokens at *t and advancing it. The first index
// is emitted raw (no CDF). The remaining ones reference the plane's
// color-index CDFs in 'frame_context', which are adapted when
// 'allow_update_cdf' is set. 'tx_size' is forwarded but unused.
void svt_av1_tokenize_color_map(FRAME_CONTEXT *frame_context, BlkStruct *blk_ptr, int plane,
                                TOKENEXTRA **t, BlockSize bsize, TxSize tx_size,
                                COLOR_MAP_TYPE type, int allow_update_cdf) {
    assert(plane == 0 || plane == 1);
    Av1ColorMapParam params;
    get_color_map_params(frame_context, blk_ptr, plane, bsize, tx_size, type, &params);

    // The first color index does not use context or entropy.
    TOKENEXTRA *const first = *t;
    first->token            = params.color_map[0];
    first->color_map_cdf    = NULL;
    *t                      = first + 1;

    MapCdf cdf_table = (plane == 0) ? frame_context->palette_y_color_index_cdf
                                    : frame_context->palette_uv_color_index_cdf;
    cost_and_tokenize_map(&params, t, plane, 0, allow_update_cdf, cdf_table);
}
// Returns the estimated rate of the candidate color map in 'palette_info'
// for 'plane'. The first index, which is coded without context, is not
// included.
// NOTE(review): only a luma rate table is provided. With plane == 1 the
// color_cost table is NULL, which the costing walk cannot use - callers
// should pass plane 0.
int svt_av1_cost_color_map(PaletteInfo *palette_info, MdRateEstimationContext *rate_table,
                           BlkStruct *blk_ptr, int plane, BlockSize bsize, COLOR_MAP_TYPE type) {
    assert(plane == 0 || plane == 1);
    Av1ColorMapParam params;
    get_color_map_params_rate(palette_info, rate_table, blk_ptr, plane, bsize, type, &params);
    // Rate-only pass: no token output, no CDF table, no CDF adaptation.
    return cost_and_tokenize_map(&params, NULL, plane, 1, 0, NULL);
}
588