1 /*
2  * Copyright (c) 2016, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #ifndef AV1_TXFM_H_
13 #define AV1_TXFM_H_
14 
15 #include <assert.h>
16 #include <math.h>
17 #include <stdio.h>
18 
19 #include "av1/common/enums.h"
20 #include "av1/common/blockd.h"
21 #include "aom/aom_integer.h"
22 #include "aom_dsp/aom_dsp_common.h"
23 
24 #ifdef __cplusplus
25 extern "C" {
26 #endif
27 
// Upper bound on the number of stages of a 1-D transform.
// NOTE(review): presumably sizes per-stage arrays such as stage_range and
// cos_bit — confirm against the users of this macro.
#define MAX_TXFM_STAGE_NUM 12

// Range (in bits) of cosine precisions accepted by cospi_arr(). Row i of
// cospi_arr_data is scaled by 2^(cos_bit_min + i), so the table holds
// cos_bit_max - cos_bit_min + 1 (= 7) rows.
static const int cos_bit_min = 10;
static const int cos_bit_max = 16;
32 
// Fixed-point cosine table:
//   cospi_arr_data[i][j] = (int)round(cos(M_PI*j/128) * (1<<(cos_bit_min+i)));
// Row i holds cos(j * PI / 128) for j = 0..63 with (cos_bit_min + i)
// fractional bits, i.e. 10 bits in row 0 up to 16 bits in row 6.
// Access it through cospi_arr(cos_bit) rather than indexing it directly.
static const int32_t cospi_arr_data[7][64] = {
  { 1024, 1024, 1023, 1021, 1019, 1016, 1013, 1009, 1004, 999, 993, 987, 980,
    972,  964,  955,  946,  936,  926,  915,  903,  891,  878, 865, 851, 837,
    822,  807,  792,  775,  759,  742,  724,  706,  688,  669, 650, 630, 610,
    590,  569,  548,  526,  505,  483,  460,  438,  415,  392, 369, 345, 321,
    297,  273,  249,  224,  200,  175,  150,  125,  100,  75,  50,  25 },
  { 2048, 2047, 2046, 2042, 2038, 2033, 2026, 2018, 2009, 1998, 1987,
    1974, 1960, 1945, 1928, 1911, 1892, 1872, 1851, 1829, 1806, 1782,
    1757, 1730, 1703, 1674, 1645, 1615, 1583, 1551, 1517, 1483, 1448,
    1412, 1375, 1338, 1299, 1260, 1220, 1179, 1138, 1096, 1053, 1009,
    965,  921,  876,  830,  784,  737,  690,  642,  595,  546,  498,
    449,  400,  350,  301,  251,  201,  151,  100,  50 },
  { 4096, 4095, 4091, 4085, 4076, 4065, 4052, 4036, 4017, 3996, 3973,
    3948, 3920, 3889, 3857, 3822, 3784, 3745, 3703, 3659, 3612, 3564,
    3513, 3461, 3406, 3349, 3290, 3229, 3166, 3102, 3035, 2967, 2896,
    2824, 2751, 2675, 2598, 2520, 2440, 2359, 2276, 2191, 2106, 2019,
    1931, 1842, 1751, 1660, 1567, 1474, 1380, 1285, 1189, 1092, 995,
    897,  799,  700,  601,  501,  401,  301,  201,  101 },
  { 8192, 8190, 8182, 8170, 8153, 8130, 8103, 8071, 8035, 7993, 7946,
    7895, 7839, 7779, 7713, 7643, 7568, 7489, 7405, 7317, 7225, 7128,
    7027, 6921, 6811, 6698, 6580, 6458, 6333, 6203, 6070, 5933, 5793,
    5649, 5501, 5351, 5197, 5040, 4880, 4717, 4551, 4383, 4212, 4038,
    3862, 3683, 3503, 3320, 3135, 2948, 2760, 2570, 2378, 2185, 1990,
    1795, 1598, 1401, 1202, 1003, 803,  603,  402,  201 },
  { 16384, 16379, 16364, 16340, 16305, 16261, 16207, 16143, 16069, 15986, 15893,
    15791, 15679, 15557, 15426, 15286, 15137, 14978, 14811, 14635, 14449, 14256,
    14053, 13842, 13623, 13395, 13160, 12916, 12665, 12406, 12140, 11866, 11585,
    11297, 11003, 10702, 10394, 10080, 9760,  9434,  9102,  8765,  8423,  8076,
    7723,  7366,  7005,  6639,  6270,  5897,  5520,  5139,  4756,  4370,  3981,
    3590,  3196,  2801,  2404,  2006,  1606,  1205,  804,   402 },
  { 32768, 32758, 32729, 32679, 32610, 32522, 32413, 32286, 32138, 31972, 31786,
    31581, 31357, 31114, 30853, 30572, 30274, 29957, 29622, 29269, 28899, 28511,
    28106, 27684, 27246, 26791, 26320, 25833, 25330, 24812, 24279, 23732, 23170,
    22595, 22006, 21403, 20788, 20160, 19520, 18868, 18205, 17531, 16846, 16151,
    15447, 14733, 14010, 13279, 12540, 11793, 11039, 10279, 9512,  8740,  7962,
    7180,  6393,  5602,  4808,  4011,  3212,  2411,  1608,  804 },
  { 65536, 65516, 65457, 65358, 65220, 65043, 64827, 64571, 64277, 63944, 63572,
    63162, 62714, 62228, 61705, 61145, 60547, 59914, 59244, 58538, 57798, 57022,
    56212, 55368, 54491, 53581, 52639, 51665, 50660, 49624, 48559, 47464, 46341,
    45190, 44011, 42806, 41576, 40320, 39040, 37736, 36410, 35062, 33692, 32303,
    30893, 29466, 28020, 26558, 25080, 23586, 22078, 20557, 19024, 17479, 15924,
    14359, 12785, 11204, 9616,  8022,  6424,  4821,  3216,  1608 }
};
77 
cospi_arr(int n)78 static INLINE const int32_t *cospi_arr(int n) {
79   return cospi_arr_data[n - cos_bit_min];
80 }
81 
// Divides |value| by 2^bit with rounding: adds 2^(bit-1), then shifts right
// arithmetically, so halves round toward +infinity (e.g. 5 -> 3, -5 -> -2 for
// bit == 1). |bit| must be >= 1.
static INLINE int32_t round_shift(int32_t value, int bit) {
  assert(bit >= 1);
  // Add the rounding offset in 64-bit arithmetic. In 32 bits,
  // value + 2^(bit-1) overflows (undefined behavior) for values near
  // INT32_MAX, and 1 << 31 itself overflows int when bit == 32.
  const int64_t rounded = (int64_t)value + ((int64_t)1 << (bit - 1));
  // For bit >= 1 the shifted result always fits back into int32_t.
  return (int32_t)(rounded >> bit);
}
86 
// Rescales the first |size| entries of |arr| in place by 2^-bit:
//   bit > 0  : each entry becomes round_shift(entry, bit) (rounded division);
//   bit < 0  : each entry is multiplied by 2^(-bit). Multiplication is used
//              instead of << so negative entries do not hit the undefined
//              left shift of a negative value;
//   bit == 0 : the array is left untouched.
static INLINE void round_shift_array(int32_t *arr, int size, int bit) {
  if (bit > 0) {
    for (int idx = 0; idx < size; ++idx) arr[idx] = round_shift(arr[idx], bit);
  } else if (bit < 0) {
    const int32_t scale = 1 << (-bit);
    for (int idx = 0; idx < size; ++idx) arr[idx] *= scale;
  }
}
103 
// Returns round_shift(w0 * in0 + w1 * in1, bit): a two-term weighted sum,
// rounded and divided by 2^bit.
//
// When CONFIG_COEFFICIENT_RANGE_CHECKING is enabled, the exact sum is also
// computed in 64 bits; if it does not fit in int32_t, a diagnostic is printed
// and an assert fires.
static INLINE int32_t half_btf(int32_t w0, int32_t in0, int32_t w1, int32_t in1,
                               int bit) {
  // Accumulate in uint32_t so that overflow wraps modulo 2^32 with
  // well-defined semantics. The plain signed expression w0 * in0 + w1 * in1
  // is undefined behavior on overflow. The wrapped value matches what that
  // expression produced in practice on two's-complement targets.
  const int32_t result_32 =
      (int32_t)((uint32_t)w0 * (uint32_t)in0 + (uint32_t)w1 * (uint32_t)in1);
#if CONFIG_COEFFICIENT_RANGE_CHECKING
  int64_t result_64 = (int64_t)w0 * (int64_t)in0 + (int64_t)w1 * (int64_t)in1;
  if (result_64 < INT32_MIN || result_64 > INT32_MAX) {
    printf("%s %d overflow result_32: %d result_64: %" PRId64
           " w0: %d in0: %d w1: %d in1: "
           "%d\n",
           __FILE__, __LINE__, result_32, result_64, w0, in0, w1, in1);
    assert(0 && "half_btf overflow");
  }
#endif
  return round_shift(result_32, bit);
}
119 
// Signature of a 1-D transform kernel that reads |input| and writes |output|.
// NOTE(review): |cos_bit| and |stage_range| look like per-stage arrays
// (presumably up to MAX_TXFM_STAGE_NUM entries) giving the cosine precision
// and allowed intermediate bit range of each stage — confirm against the
// kernel implementations.
typedef void (*TxfmFunc)(const int32_t *input, int32_t *output,
                         const int8_t *cos_bit, const int8_t *stage_range);
122 
// Enumerates the 1-D transform kernels, named by kind (DCT, ADST, identity)
// and size. The enumerator order defines the numeric values, so do not
// reorder entries.
typedef enum TXFM_TYPE {
  TXFM_TYPE_DCT4,
  TXFM_TYPE_DCT8,
  TXFM_TYPE_DCT16,
  TXFM_TYPE_DCT32,
  TXFM_TYPE_DCT64,
  TXFM_TYPE_ADST4,
  TXFM_TYPE_ADST8,
  TXFM_TYPE_ADST16,
  TXFM_TYPE_ADST32,
  TXFM_TYPE_IDENTITY4,
  TXFM_TYPE_IDENTITY8,
  TXFM_TYPE_IDENTITY16,
  TXFM_TYPE_IDENTITY32,
  TXFM_TYPE_IDENTITY64,
} TXFM_TYPE;
139 
// Read-only description of one 1-D transform pass. All members are const,
// so instances must be fully initialized at definition.
typedef struct TXFM_1D_CFG {
  // NOTE(review): presumably the number of points of the 1-D transform.
  const int txfm_size;
  // Number of stages; presumably the length of the stage_range and cos_bit
  // arrays (bounded by MAX_TXFM_STAGE_NUM) — confirm.
  const int stage_num;

  // NOTE(review): presumably the rounding shifts applied around the pass —
  // confirm against the 2-D transform code.
  const int8_t *shift;
  // Per-stage bit ranges, in the form passed to TxfmFunc's stage_range.
  const int8_t *stage_range;
  // Per-stage cosine precisions, in the form passed to TxfmFunc's cos_bit
  // (presumably values accepted by cospi_arr()).
  const int8_t *cos_bit;
  // Which 1-D kernel this configuration describes.
  const TXFM_TYPE txfm_type;
} TXFM_1D_CFG;
149 
// 2-D transform configuration: a column pass, a row pass, and the flips
// selected by set_flip_cfg() from the TX_TYPE.
typedef struct TXFM_2D_FLIP_CFG {
  int ud_flip;  // flip upside down
  int lr_flip;  // flip left to right
  const TXFM_1D_CFG *col_cfg;  // 1-D configuration applied to columns
  const TXFM_1D_CFG *row_cfg;  // 1-D configuration applied to rows
} TXFM_2D_FLIP_CFG;
156 
// Derives cfg->ud_flip and cfg->lr_flip from |tx_type|:
//   FLIPADST_DCT, FLIPADST_ADST, V_FLIPADST   -> upside-down flip only;
//   DCT_FLIPADST, ADST_FLIPADST, H_FLIPADST   -> left-right flip only;
//   FLIPADST_FLIPADST                         -> both flips;
//   every other listed type                   -> no flip.
// Unlisted types clear both flags and trigger assert(0).
static INLINE void set_flip_cfg(TX_TYPE tx_type, TXFM_2D_FLIP_CFG *cfg) {
  int flip_ud = 0;
  int flip_lr = 0;
  switch (tx_type) {
    case DCT_DCT:
    case ADST_DCT:
    case DCT_ADST:
    case ADST_ADST:
#if CONFIG_EXT_TX
    case IDTX:
    case V_DCT:
    case H_DCT:
    case V_ADST:
    case H_ADST:
#endif  // CONFIG_EXT_TX
      break;
#if CONFIG_EXT_TX
    case FLIPADST_DCT:
    case FLIPADST_ADST:
    case V_FLIPADST: flip_ud = 1; break;
    case DCT_FLIPADST:
    case ADST_FLIPADST:
    case H_FLIPADST: flip_lr = 1; break;
    case FLIPADST_FLIPADST:
      flip_ud = 1;
      flip_lr = 1;
      break;
#endif  // CONFIG_EXT_TX
    default: assert(0); break;
  }
  cfg->ud_flip = flip_ud;
  cfg->lr_flip = flip_lr;
}
198 
199 #if CONFIG_TXMG
// Returns |tx_size| with its width and height swapped (e.g. TX_4X8 ->
// TX_8X4). Square sizes map to themselves. Any size not listed triggers
// assert(0) and yields TX_INVALID.
static INLINE TX_SIZE av1_rotate_tx_size(TX_SIZE tx_size) {
  switch (tx_size) {
#if CONFIG_CHROMA_2X2
    case TX_2X2:
#endif
    case TX_4X4:
    case TX_8X8:
    case TX_16X16:
    case TX_32X32:
#if CONFIG_TX64X64
    case TX_64X64:
#endif
      // Square: swapping dimensions is the identity.
      return tx_size;
#if CONFIG_TX64X64
    case TX_32X64: return TX_64X32;
    case TX_64X32: return TX_32X64;
#endif
    case TX_4X8: return TX_8X4;
    case TX_8X4: return TX_4X8;
    case TX_8X16: return TX_16X8;
    case TX_16X8: return TX_8X16;
    case TX_16X32: return TX_32X16;
    case TX_32X16: return TX_16X32;
    case TX_4X16: return TX_16X4;
    case TX_16X4: return TX_4X16;
    case TX_8X32: return TX_32X8;
    case TX_32X8: return TX_8X32;
    default: assert(0); return TX_INVALID;
  }
}
227 
// Returns the transform type with its vertical and horizontal components
// exchanged (e.g. ADST_DCT <-> DCT_ADST, V_DCT <-> H_DCT). Types that are
// the same in both directions map to themselves. Any type not listed
// triggers assert(0) and yields TX_TYPES.
static INLINE TX_TYPE av1_rotate_tx_type(TX_TYPE tx_type) {
  switch (tx_type) {
    case DCT_DCT:
    case ADST_ADST:
#if CONFIG_EXT_TX
    case FLIPADST_FLIPADST:
    case IDTX:
#endif  // CONFIG_EXT_TX
#if CONFIG_MRC_TX
    case MRC_DCT:
#endif  // CONFIG_MRC_TX
      // Self-mapping types.
      return tx_type;
    case ADST_DCT: return DCT_ADST;
    case DCT_ADST: return ADST_DCT;
#if CONFIG_EXT_TX
    case FLIPADST_DCT: return DCT_FLIPADST;
    case DCT_FLIPADST: return FLIPADST_DCT;
    case ADST_FLIPADST: return FLIPADST_ADST;
    case FLIPADST_ADST: return ADST_FLIPADST;
    case V_DCT: return H_DCT;
    case H_DCT: return V_DCT;
    case V_ADST: return H_ADST;
    case H_ADST: return V_ADST;
    case V_FLIPADST: return H_FLIPADST;
    case H_FLIPADST: return V_FLIPADST;
#endif  // CONFIG_EXT_TX
    default: assert(0); return TX_TYPES;
  }
}
254 #endif  // CONFIG_TXMG
255 
256 #if CONFIG_MRC_TX
get_mrc_diff_mask_inter(const int16_t * diff,int diff_stride,uint8_t * mask,int mask_stride,int width,int height)257 static INLINE int get_mrc_diff_mask_inter(const int16_t *diff, int diff_stride,
258                                           uint8_t *mask, int mask_stride,
259                                           int width, int height) {
260   // placeholder mask generation function
261   assert(SIGNAL_MRC_MASK_INTER);
262   int n_masked_vals = 0;
263   for (int i = 0; i < height; ++i) {
264     for (int j = 0; j < width; ++j) {
265       mask[i * mask_stride + j] = diff[i * diff_stride + j] > 100 ? 1 : 0;
266       n_masked_vals += mask[i * mask_stride + j];
267     }
268   }
269   return n_masked_vals;
270 }
271 
// Placeholder mask generator for inter blocks working on a prediction.
// Writes mask[r * mask_stride + c] = 1 where pred[r * pred_stride + c] > 100
// (0 otherwise) over a width x height area, and returns the number of 1s.
static INLINE int get_mrc_pred_mask_inter(const uint8_t *pred, int pred_stride,
                                          uint8_t *mask, int mask_stride,
                                          int width, int height) {
  int marked = 0;
  for (int row = 0; row < height; ++row) {
    const uint8_t *src_row = pred + row * pred_stride;
    uint8_t *dst_row = mask + row * mask_stride;
    for (int col = 0; col < width; ++col) {
      const uint8_t hit = (uint8_t)(src_row[col] > 100);
      dst_row[col] = hit;
      marked += hit;
    }
  }
  return marked;
}
285 
get_mrc_diff_mask_intra(const int16_t * diff,int diff_stride,uint8_t * mask,int mask_stride,int width,int height)286 static INLINE int get_mrc_diff_mask_intra(const int16_t *diff, int diff_stride,
287                                           uint8_t *mask, int mask_stride,
288                                           int width, int height) {
289   // placeholder mask generation function
290   assert(SIGNAL_MRC_MASK_INTRA);
291   int n_masked_vals = 0;
292   for (int i = 0; i < height; ++i) {
293     for (int j = 0; j < width; ++j) {
294       mask[i * mask_stride + j] = diff[i * diff_stride + j] > 100 ? 1 : 0;
295       n_masked_vals += mask[i * mask_stride + j];
296     }
297   }
298   return n_masked_vals;
299 }
300 
// Placeholder mask generator for intra blocks working on a prediction.
// Writes mask[r * mask_stride + c] = 1 where pred[r * pred_stride + c] > 100
// (0 otherwise) over a width x height area, and returns the number of 1s.
static INLINE int get_mrc_pred_mask_intra(const uint8_t *pred, int pred_stride,
                                          uint8_t *mask, int mask_stride,
                                          int width, int height) {
  int marked = 0;
  for (int row = 0; row < height; ++row) {
    const uint8_t *src_row = pred + row * pred_stride;
    uint8_t *dst_row = mask + row * mask_stride;
    for (int col = 0; col < width; ++col) {
      const uint8_t hit = (uint8_t)(src_row[col] > 100);
      dst_row[col] = hit;
      marked += hit;
    }
  }
  return marked;
}
314 
get_mrc_diff_mask(const int16_t * diff,int diff_stride,uint8_t * mask,int mask_stride,int width,int height,int is_inter)315 static INLINE int get_mrc_diff_mask(const int16_t *diff, int diff_stride,
316                                     uint8_t *mask, int mask_stride, int width,
317                                     int height, int is_inter) {
318   if (is_inter) {
319     assert(USE_MRC_INTER && "MRC invalid for inter blocks");
320     assert(SIGNAL_MRC_MASK_INTER);
321     return get_mrc_diff_mask_inter(diff, diff_stride, mask, mask_stride, width,
322                                    height);
323   } else {
324     assert(USE_MRC_INTRA && "MRC invalid for intra blocks");
325     assert(SIGNAL_MRC_MASK_INTRA);
326     return get_mrc_diff_mask_intra(diff, diff_stride, mask, mask_stride, width,
327                                    height);
328   }
329 }
330 
get_mrc_pred_mask(const uint8_t * pred,int pred_stride,uint8_t * mask,int mask_stride,int width,int height,int is_inter)331 static INLINE int get_mrc_pred_mask(const uint8_t *pred, int pred_stride,
332                                     uint8_t *mask, int mask_stride, int width,
333                                     int height, int is_inter) {
334   if (is_inter) {
335     assert(USE_MRC_INTER && "MRC invalid for inter blocks");
336     return get_mrc_pred_mask_inter(pred, pred_stride, mask, mask_stride, width,
337                                    height);
338   } else {
339     assert(USE_MRC_INTRA && "MRC invalid for intra blocks");
340     return get_mrc_pred_mask_intra(pred, pred_stride, mask, mask_stride, width,
341                                    height);
342   }
343 }
344 
// Returns 1 when the mask is mixed (some but not all of the width * height
// samples are masked), and 0 when it is all-zero or all-one.
static INLINE int is_valid_mrc_mask(int n_masked_vals, int width, int height) {
  const int total = width * height;
  return n_masked_vals != 0 && n_masked_vals != total;
}
348 #endif  // CONFIG_MRC_TX
349 
// NOTE(review): the functions below are only declared here. The descriptions
// are inferred from their signatures and should be confirmed against the
// implementation.

// Presumably fills |stage_range_col| / |stage_range_row| with the per-stage
// bit ranges of the forward column and row passes described by |cfg| at bit
// depth |bd|.
void av1_gen_fwd_stage_range(int8_t *stage_range_col, int8_t *stage_range_row,
                             const TXFM_2D_FLIP_CFG *cfg, int bd);

// Inverse-transform counterpart of av1_gen_fwd_stage_range(). It also takes
// |fwd_shift| (presumably the forward transform's shift) — confirm.
void av1_gen_inv_stage_range(int8_t *stage_range_col, int8_t *stage_range_row,
                             const TXFM_2D_FLIP_CFG *cfg, int8_t fwd_shift,
                             int bd);

// Returns the forward 2-D transform configuration for (tx_type, tx_size).
TXFM_2D_FLIP_CFG av1_get_fwd_txfm_cfg(TX_TYPE tx_type, TX_SIZE tx_size);
#if CONFIG_TX64X64
// Forward configurations for the 64-point sizes, which take only a tx_type.
TXFM_2D_FLIP_CFG av1_get_fwd_txfm_64x64_cfg(TX_TYPE tx_type);
TXFM_2D_FLIP_CFG av1_get_fwd_txfm_64x32_cfg(TX_TYPE tx_type);
TXFM_2D_FLIP_CFG av1_get_fwd_txfm_32x64_cfg(TX_TYPE tx_type);
#endif  // CONFIG_TX64X64
// Returns the inverse 2-D transform configuration for (tx_type, tx_size).
TXFM_2D_FLIP_CFG av1_get_inv_txfm_cfg(TX_TYPE tx_type, TX_SIZE tx_size);
364 #ifdef __cplusplus
365 }
366 #endif  // __cplusplus
367 
368 #endif  // AV1_TXFM_H_
369