1 /*
2  * Copyright (c) 2018, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include "tools/txfm_analyzer/txfm_graph.h"
13 
14 #include <stdio.h>
15 #include <stdlib.h>
16 #include <math.h>
17 
// Lets this file refer to the graph node type as plain `Node`.
// NOTE(review): the definition of struct Node is expected to come from
// txfm_graph.h — confirm.
typedef struct Node Node;
19 
get_fun_name(char * str_fun_name,int str_buf_size,const TYPE_TXFM type,const int txfm_size)20 void get_fun_name(char *str_fun_name, int str_buf_size, const TYPE_TXFM type,
21                   const int txfm_size) {
22   if (type == TYPE_DCT)
23     snprintf(str_fun_name, str_buf_size, "fdct%d_new", txfm_size);
24   else if (type == TYPE_ADST)
25     snprintf(str_fun_name, str_buf_size, "fadst%d_new", txfm_size);
26   else if (type == TYPE_IDCT)
27     snprintf(str_fun_name, str_buf_size, "idct%d_new", txfm_size);
28   else if (type == TYPE_IADST)
29     snprintf(str_fun_name, str_buf_size, "iadst%d_new", txfm_size);
30 }
31 
get_txfm_type_name(char * str_fun_name,int str_buf_size,const TYPE_TXFM type,const int txfm_size)32 void get_txfm_type_name(char *str_fun_name, int str_buf_size,
33                         const TYPE_TXFM type, const int txfm_size) {
34   if (type == TYPE_DCT)
35     snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_DCT%d", txfm_size);
36   else if (type == TYPE_ADST)
37     snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_ADST%d", txfm_size);
38   else if (type == TYPE_IDCT)
39     snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_DCT%d", txfm_size);
40   else if (type == TYPE_IADST)
41     snprintf(str_fun_name, str_buf_size, "TXFM_TYPE_ADST%d", txfm_size);
42 }
43 
get_hybrid_2d_type_name(char * buf,int buf_size,const TYPE_TXFM type0,const TYPE_TXFM type1,const int txfm_size0,const int txfm_size1)44 void get_hybrid_2d_type_name(char *buf, int buf_size, const TYPE_TXFM type0,
45                              const TYPE_TXFM type1, const int txfm_size0,
46                              const int txfm_size1) {
47   if (type0 == TYPE_DCT && type1 == TYPE_DCT)
48     snprintf(buf, buf_size, "_dct_dct_%dx%d", txfm_size1, txfm_size0);
49   else if (type0 == TYPE_DCT && type1 == TYPE_ADST)
50     snprintf(buf, buf_size, "_dct_adst_%dx%d", txfm_size1, txfm_size0);
51   else if (type0 == TYPE_ADST && type1 == TYPE_ADST)
52     snprintf(buf, buf_size, "_adst_adst_%dx%d", txfm_size1, txfm_size0);
53   else if (type0 == TYPE_ADST && type1 == TYPE_DCT)
54     snprintf(buf, buf_size, "_adst_dct_%dx%d", txfm_size1, txfm_size0);
55 }
56 
get_inv_type(TYPE_TXFM type)57 TYPE_TXFM get_inv_type(TYPE_TXFM type) {
58   if (type == TYPE_DCT)
59     return TYPE_IDCT;
60   else if (type == TYPE_ADST)
61     return TYPE_IADST;
62   else if (type == TYPE_IDCT)
63     return TYPE_DCT;
64   else if (type == TYPE_IADST)
65     return TYPE_ADST;
66   else
67     return TYPE_LAST;
68 }
69 
reference_dct_1d(double * in,double * out,int size)70 void reference_dct_1d(double *in, double *out, int size) {
71   const double kInvSqrt2 = 0.707106781186547524400844362104;
72   for (int k = 0; k < size; k++) {
73     out[k] = 0;  // initialize out[k]
74     for (int n = 0; n < size; n++) {
75       out[k] += in[n] * cos(PI * (2 * n + 1) * k / (2 * size));
76     }
77     if (k == 0) out[k] = out[k] * kInvSqrt2;
78   }
79 }
80 
reference_dct_2d(double * in,double * out,int size)81 void reference_dct_2d(double *in, double *out, int size) {
82   double *tempOut = new double[size * size];
83   // dct each row: in -> out
84   for (int r = 0; r < size; r++) {
85     reference_dct_1d(in + r * size, out + r * size, size);
86   }
87 
88   for (int r = 0; r < size; r++) {
89     // out ->tempOut
90     for (int c = 0; c < size; c++) {
91       tempOut[r * size + c] = out[c * size + r];
92     }
93   }
94   for (int r = 0; r < size; r++) {
95     reference_dct_1d(tempOut + r * size, out + r * size, size);
96   }
97   delete[] tempOut;
98 }
99 
reference_adst_1d(double * in,double * out,int size)100 void reference_adst_1d(double *in, double *out, int size) {
101   for (int k = 0; k < size; k++) {
102     out[k] = 0;  // initialize out[k]
103     for (int n = 0; n < size; n++) {
104       out[k] += in[n] * sin(PI * (2 * n + 1) * (2 * k + 1) / (4 * size));
105     }
106   }
107 }
108 
reference_hybrid_2d(double * in,double * out,int size,int type0,int type1)109 void reference_hybrid_2d(double *in, double *out, int size, int type0,
110                          int type1) {
111   double *tempOut = new double[size * size];
112   // dct each row: in -> out
113   for (int r = 0; r < size; r++) {
114     if (type0 == TYPE_DCT)
115       reference_dct_1d(in + r * size, out + r * size, size);
116     else
117       reference_adst_1d(in + r * size, out + r * size, size);
118   }
119 
120   for (int r = 0; r < size; r++) {
121     // out ->tempOut
122     for (int c = 0; c < size; c++) {
123       tempOut[r * size + c] = out[c * size + r];
124     }
125   }
126   for (int r = 0; r < size; r++) {
127     if (type1 == TYPE_DCT)
128       reference_dct_1d(tempOut + r * size, out + r * size, size);
129     else
130       reference_adst_1d(tempOut + r * size, out + r * size, size);
131   }
132   delete[] tempOut;
133 }
134 
reference_hybrid_2d_new(double * in,double * out,int size0,int size1,int type0,int type1)135 void reference_hybrid_2d_new(double *in, double *out, int size0, int size1,
136                              int type0, int type1) {
137   double *tempOut = new double[size0 * size1];
138   // dct each row: in -> out
139   for (int r = 0; r < size1; r++) {
140     if (type0 == TYPE_DCT)
141       reference_dct_1d(in + r * size0, out + r * size0, size0);
142     else
143       reference_adst_1d(in + r * size0, out + r * size0, size0);
144   }
145 
146   for (int r = 0; r < size1; r++) {
147     // out ->tempOut
148     for (int c = 0; c < size0; c++) {
149       tempOut[c * size1 + r] = out[r * size0 + c];
150     }
151   }
152   for (int r = 0; r < size0; r++) {
153     if (type1 == TYPE_DCT)
154       reference_dct_1d(tempOut + r * size1, out + r * size1, size1);
155     else
156       reference_adst_1d(tempOut + r * size1, out + r * size1, size1);
157   }
158   delete[] tempOut;
159 }
160 
// Returns the position of the highest set bit of |x| (floor(log2(x))).
// For x == 0 the internal result is -1, which is returned converted to
// unsigned int (i.e. UINT_MAX). Callers that store it in an int see -1.
unsigned int get_max_bit(unsigned int x) {
  int highest = -1;
  for (; x != 0; x >>= 1) {
    ++highest;
  }
  return highest;
}
169 
// Reverses the low (max_bit + 1) bits of |x|: bit i of the input becomes bit
// (max_bit - i) of the result, and input bits above |max_bit| are discarded.
// Example: bitwise_reverse(1, 2) == 4 (001 -> 100).
// Assumes a 32-bit unsigned int and max_bit <= 31.
// A negative |max_bit| denotes an empty bit field and yields 0. It occurs via
// get_max_bit(0), e.g. gen_P_graph with N == 1. Without this guard the final
// shift count would be 32, which is undefined behavior for a 32-bit operand.
unsigned int bitwise_reverse(unsigned int x, int max_bit) {
  if (max_bit < 0) {
    return 0;
  }
  // Full 32-bit reversal by swapping progressively smaller bit groups.
  x = ((x >> 16) & 0x0000ffff) | ((x & 0x0000ffff) << 16);
  x = ((x >> 8) & 0x00ff00ff) | ((x & 0x00ff00ff) << 8);
  x = ((x >> 4) & 0x0f0f0f0f) | ((x & 0x0f0f0f0f) << 4);
  x = ((x >> 2) & 0x33333333) | ((x & 0x33333333) << 2);
  x = ((x >> 1) & 0x55555555) | ((x & 0x55555555) << 1);
  // Bring the reversed field down from the top of the word to bits [0, max_bit].
  return x >> (31 - max_bit);
}
179 
// Flattens (row |ri|, column |ci|) into an index of a row-major array with
// |cSize| columns. Graph nodes are stored with stage as row, node as column.
int get_idx(int ri, int ci, int cSize) {
  const int row_start = ri * cSize;
  return row_start + ci;
}
181 
add_node(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int in,double w)182 void add_node(Node *node, int stage_num, int node_num, int stage_idx,
183               int node_idx, int in, double w) {
184   int outIdx = get_idx(stage_idx, node_idx, node_num);
185   int inIdx = get_idx(stage_idx - 1, in, node_num);
186   int idx = node[outIdx].inNodeNum;
187   if (idx < 2) {
188     node[outIdx].inNode[idx] = &node[inIdx];
189     node[outIdx].inNodeIdx[idx] = in;
190     node[outIdx].inWeight[idx] = w;
191     idx++;
192     node[outIdx].inNodeNum = idx;
193   } else {
194     printf("Error: inNode is full");
195   }
196 }
197 
connect_node(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int in0,double w0,int in1,double w1)198 void connect_node(Node *node, int stage_num, int node_num, int stage_idx,
199                   int node_idx, int in0, double w0, int in1, double w1) {
200   int outIdx = get_idx(stage_idx, node_idx, node_num);
201   int inIdx0 = get_idx(stage_idx - 1, in0, node_num);
202   int inIdx1 = get_idx(stage_idx - 1, in1, node_num);
203 
204   int idx = 0;
205   // if(w0 != 0) {
206   node[outIdx].inNode[idx] = &node[inIdx0];
207   node[outIdx].inNodeIdx[idx] = in0;
208   node[outIdx].inWeight[idx] = w0;
209   idx++;
210   //}
211 
212   // if(w1 != 0) {
213   node[outIdx].inNode[idx] = &node[inIdx1];
214   node[outIdx].inNodeIdx[idx] = in1;
215   node[outIdx].inWeight[idx] = w1;
216   idx++;
217   //}
218 
219   node[outIdx].inNodeNum = idx;
220 }
221 
propagate(Node * node,int stage_num,int node_num,int stage_idx)222 void propagate(Node *node, int stage_num, int node_num, int stage_idx) {
223   for (int ni = 0; ni < node_num; ni++) {
224     int outIdx = get_idx(stage_idx, ni, node_num);
225     node[outIdx].value = 0;
226     for (int k = 0; k < node[outIdx].inNodeNum; k++) {
227       node[outIdx].value +=
228           node[outIdx].inNode[k]->value * node[outIdx].inWeight[k];
229     }
230   }
231 }
232 
// Scales |value| by 2^(-bit).
//  - bit > 0: divides by 2^bit, rounding to nearest with ties away from zero
//    (negative inputs are handled by symmetry), e.g. round_shift(5, 1) == 3,
//    round_shift(-5, 1) == -3.
//  - bit <= 0: multiplies by 2^(-bit).
// The rounding offset is built as int64_t so that shifts of 32 or more bits
// are well-defined; scaling up uses multiplication because left-shifting a
// negative value is undefined behavior.
// Preconditions: |bit| < 63, value != INT64_MIN when bit > 0, and the scaled
// result fits in int64_t.
int64_t round_shift(int64_t value, int bit) {
  if (bit > 0) {
    if (value < 0) {
      return -round_shift(-value, bit);
    }
    return (value + ((int64_t)1 << (bit - 1))) >> bit;
  }
  return value * ((int64_t)1 << (-bit));
}
244 
round_shift_array(int32_t * arr,int size,int bit)245 void round_shift_array(int32_t *arr, int size, int bit) {
246   if (bit == 0) {
247     return;
248   } else {
249     for (int i = 0; i < size; i++) {
250       arr[i] = round_shift(arr[i], bit);
251     }
252   }
253 }
254 
graph_reset_visited(Node * node,int stage_num,int node_num)255 void graph_reset_visited(Node *node, int stage_num, int node_num) {
256   for (int si = 0; si < stage_num; si++) {
257     for (int ni = 0; ni < node_num; ni++) {
258       int idx = get_idx(si, ni, node_num);
259       node[idx].visited = 0;
260     }
261   }
262 }
263 
estimate_value(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int estimate_bit)264 void estimate_value(Node *node, int stage_num, int node_num, int stage_idx,
265                     int node_idx, int estimate_bit) {
266   if (stage_idx > 0) {
267     int outIdx = get_idx(stage_idx, node_idx, node_num);
268     int64_t out = 0;
269     node[outIdx].value = 0;
270     for (int k = 0; k < node[outIdx].inNodeNum; k++) {
271       int64_t w = round(node[outIdx].inWeight[k] * (1 << estimate_bit));
272       int64_t v = round(node[outIdx].inNode[k]->value);
273       out += v * w;
274     }
275     node[outIdx].value = round_shift(out, estimate_bit);
276   }
277 }
278 
amplify_value(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int amplify_bit)279 void amplify_value(Node *node, int stage_num, int node_num, int stage_idx,
280                    int node_idx, int amplify_bit) {
281   int outIdx = get_idx(stage_idx, node_idx, node_num);
282   node[outIdx].value = round_shift(round(node[outIdx].value), -amplify_bit);
283 }
284 
propagate_estimate_amlify(Node * node,int stage_num,int node_num,int stage_idx,int amplify_bit,int estimate_bit)285 void propagate_estimate_amlify(Node *node, int stage_num, int node_num,
286                                int stage_idx, int amplify_bit,
287                                int estimate_bit) {
288   for (int ni = 0; ni < node_num; ni++) {
289     estimate_value(node, stage_num, node_num, stage_idx, ni, estimate_bit);
290     amplify_value(node, stage_num, node_num, stage_idx, ni, amplify_bit);
291   }
292 }
293 
init_graph(Node * node,int stage_num,int node_num)294 void init_graph(Node *node, int stage_num, int node_num) {
295   for (int si = 0; si < stage_num; si++) {
296     for (int ni = 0; ni < node_num; ni++) {
297       int outIdx = get_idx(si, ni, node_num);
298       node[outIdx].stageIdx = si;
299       node[outIdx].nodeIdx = ni;
300       node[outIdx].value = 0;
301       node[outIdx].inNodeNum = 0;
302       if (si >= 1) {
303         connect_node(node, stage_num, node_num, si, ni, ni, 1, ni, 0);
304       }
305     }
306   }
307 }
308 
gen_B_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int N,int star)309 void gen_B_graph(Node *node, int stage_num, int node_num, int stage_idx,
310                  int node_idx, int N, int star) {
311   for (int i = 0; i < N / 2; i++) {
312     int out = node_idx + i;
313     int in1 = node_idx + N - 1 - i;
314     if (star == 1) {
315       connect_node(node, stage_num, node_num, stage_idx + 1, out, out, -1, in1,
316                    1);
317     } else {
318       connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, in1,
319                    1);
320     }
321   }
322   for (int i = N / 2; i < N; i++) {
323     int out = node_idx + i;
324     int in1 = node_idx + N - 1 - i;
325     if (star == 1) {
326       connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, in1,
327                    1);
328     } else {
329       connect_node(node, stage_num, node_num, stage_idx + 1, out, out, -1, in1,
330                    1);
331     }
332   }
333 }
334 
gen_P_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int N)335 void gen_P_graph(Node *node, int stage_num, int node_num, int stage_idx,
336                  int node_idx, int N) {
337   int max_bit = get_max_bit(N - 1);
338   for (int i = 0; i < N; i++) {
339     int out = node_idx + bitwise_reverse(i, max_bit);
340     int in = node_idx + i;
341     connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
342   }
343 }
344 
gen_type1_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int N)345 void gen_type1_graph(Node *node, int stage_num, int node_num, int stage_idx,
346                      int node_idx, int N) {
347   int max_bit = get_max_bit(N);
348   for (int ni = 0; ni < N / 2; ni++) {
349     int ai = bitwise_reverse(N + ni, max_bit);
350     int out = node_idx + ni;
351     int in1 = node_idx + N - ni - 1;
352     connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
353                  sin(PI * ai / (2 * 2 * N)), in1, cos(PI * ai / (2 * 2 * N)));
354   }
355   for (int ni = N / 2; ni < N; ni++) {
356     int ai = bitwise_reverse(N + ni, max_bit);
357     int out = node_idx + ni;
358     int in1 = node_idx + N - ni - 1;
359     connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
360                  cos(PI * ai / (2 * 2 * N)), in1, -sin(PI * ai / (2 * 2 * N)));
361   }
362 }
363 
gen_type2_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int N)364 void gen_type2_graph(Node *node, int stage_num, int node_num, int stage_idx,
365                      int node_idx, int N) {
366   for (int ni = 0; ni < N / 4; ni++) {
367     int out = node_idx + ni;
368     connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, out, 0);
369   }
370 
371   for (int ni = N / 4; ni < N / 2; ni++) {
372     int out = node_idx + ni;
373     int in1 = node_idx + N - ni - 1;
374     connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
375                  -cos(PI / 4), in1, cos(-PI / 4));
376   }
377 
378   for (int ni = N / 2; ni < N * 3 / 4; ni++) {
379     int out = node_idx + ni;
380     int in1 = node_idx + N - ni - 1;
381     connect_node(node, stage_num, node_num, stage_idx + 1, out, out,
382                  cos(-PI / 4), in1, cos(PI / 4));
383   }
384 
385   for (int ni = N * 3 / 4; ni < N; ni++) {
386     int out = node_idx + ni;
387     connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, out, 0);
388   }
389 }
390 
// Emits one "type 3" rotation layer of the R sub-graph. gen_R_graph calls it
// for even idx other than 0 and max_idx - 1. It connects stage |stage_idx| to
// stage |stage_idx| + 1 for the N nodes starting at |node_idx|.
//
// The N nodes are processed in groups of N_over_i = 2 << (idx / 2) nodes,
// and each group is split into four quarters of N_over_i / 4 nodes:
//   1st and 4th quarter: identity (weight 1 on the node itself).
//   2nd and 3rd quarter: a two-input rotation mixing each node with its
//     mirror node (N - 1 - position) using cos/sin(kj * PI / i). Here kj is
//     the bit-reversed value of (i / 4 + j) and j is the group number.
// Groups in the upper half [0, N/2) use the negated weights of groups in
// the lower half [N/2, N).
void gen_type3_graph(Node *node, int stage_num, int node_num, int stage_idx,
                     int node_idx, int idx, int N) {
  // TODO(angiebird): Simplify and clarify this function

  // Angle denominator for this layer; halves every two values of idx.
  int i = 2 * N / (1 << (idx / 2));
  int max_bit =
      get_max_bit(i / 2) - 1;  // the max_bit counts on i/2 instead of N here
  // Number of nodes per group; doubles every two values of idx.
  int N_over_i = 2 << (idx / 2);

  // Upper half of the block.
  for (int nj = 0; nj < N / 2; nj += N_over_i) {
    int j = nj / (N_over_i);
    int kj = bitwise_reverse(i / 4 + j, max_bit);
    // printf("kj = %d\n", kj);

    // I_N/2i   --- 0
    int offset = nj;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in = out;
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    }

    // -C_Kj/i --- S_Kj/i
    offset += N_over_i / 4;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in0 = out;
      double w0 = -cos(kj * PI / i);
      int in1 = N - (offset + ni) - 1 + node_idx;  // mirror node
      double w1 = sin(kj * PI / i);
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
                   w1);
    }

    // S_kj/i  --- -C_Kj/i
    offset += N_over_i / 4;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in0 = out;
      double w0 = -sin(kj * PI / i);
      int in1 = N - (offset + ni) - 1 + node_idx;  // mirror node
      double w1 = -cos(kj * PI / i);
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
                   w1);
    }

    // I_N/2i   --- 0
    offset += N_over_i / 4;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in = out;
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    }
  }

  // Lower half of the block: same layout with the rotation weights negated.
  // j keeps counting from where the upper half stopped.
  for (int nj = N / 2; nj < N; nj += N_over_i) {
    int j = nj / N_over_i;
    int kj = bitwise_reverse(i / 4 + j, max_bit);

    // I_N/2i --- 0
    int offset = nj;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in = out;
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    }

    // C_kj/i --- -S_Kj/i
    offset += N_over_i / 4;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in0 = out;
      double w0 = cos(kj * PI / i);
      int in1 = N - (offset + ni) - 1 + node_idx;  // mirror node
      double w1 = -sin(kj * PI / i);
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
                   w1);
    }

    // S_kj/i --- C_Kj/i
    offset += N_over_i / 4;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in0 = out;
      double w0 = sin(kj * PI / i);
      int in1 = N - (offset + ni) - 1 + node_idx;  // mirror node
      double w1 = cos(kj * PI / i);
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in0, w0, in1,
                   w1);
    }

    // I_N/2i --- 0
    offset += N_over_i / 4;
    for (int ni = 0; ni < N_over_i / 4; ni++) {
      int out = node_idx + offset + ni;
      int in = out;
      connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
    }
  }
}
491 
gen_type4_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int idx,int N)492 void gen_type4_graph(Node *node, int stage_num, int node_num, int stage_idx,
493                      int node_idx, int idx, int N) {
494   int B_size = 1 << ((idx + 1) / 2);
495   for (int ni = 0; ni < N; ni += B_size) {
496     gen_B_graph(node, stage_num, node_num, stage_idx, node_idx + ni, B_size,
497                 (ni / B_size) % 2);
498   }
499 }
500 
gen_R_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int N)501 void gen_R_graph(Node *node, int stage_num, int node_num, int stage_idx,
502                  int node_idx, int N) {
503   int max_idx = 2 * (get_max_bit(N) + 1) - 3;
504   for (int idx = 0; idx < max_idx; idx++) {
505     int s = stage_idx + max_idx - idx - 1;
506     if (idx == 0) {
507       // type 1
508       gen_type1_graph(node, stage_num, node_num, s, node_idx, N);
509     } else if (idx == max_idx - 1) {
510       // type 2
511       gen_type2_graph(node, stage_num, node_num, s, node_idx, N);
512     } else if ((idx + 1) % 2 == 0) {
513       // type 4
514       gen_type4_graph(node, stage_num, node_num, s, node_idx, idx, N);
515     } else if ((idx + 1) % 2 == 1) {
516       // type 3
517       gen_type3_graph(node, stage_num, node_num, s, node_idx, idx, N);
518     } else {
519       printf("check gen_R_graph()\n");
520     }
521   }
522 }
523 
gen_DCT_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int N)524 void gen_DCT_graph(Node *node, int stage_num, int node_num, int stage_idx,
525                    int node_idx, int N) {
526   if (N > 2) {
527     gen_B_graph(node, stage_num, node_num, stage_idx, node_idx, N, 0);
528     gen_DCT_graph(node, stage_num, node_num, stage_idx + 1, node_idx, N / 2);
529     gen_R_graph(node, stage_num, node_num, stage_idx + 1, node_idx + N / 2,
530                 N / 2);
531   } else {
532     // generate dct_2
533     connect_node(node, stage_num, node_num, stage_idx + 1, node_idx, node_idx,
534                  cos(PI / 4), node_idx + 1, cos(PI / 4));
535     connect_node(node, stage_num, node_num, stage_idx + 1, node_idx + 1,
536                  node_idx + 1, -cos(PI / 4), node_idx, cos(PI / 4));
537   }
538 }
539 
get_dct_stage_num(int size)540 int get_dct_stage_num(int size) { return 2 * get_max_bit(size); }
541 
gen_DCT_graph_1d(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int dct_node_num)542 void gen_DCT_graph_1d(Node *node, int stage_num, int node_num, int stage_idx,
543                       int node_idx, int dct_node_num) {
544   gen_DCT_graph(node, stage_num, node_num, stage_idx, node_idx, dct_node_num);
545   int dct_stage_num = get_dct_stage_num(dct_node_num);
546   gen_P_graph(node, stage_num, node_num, stage_idx + dct_stage_num - 2,
547               node_idx, dct_node_num);
548 }
549 
gen_adst_B_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_idx)550 void gen_adst_B_graph(Node *node, int stage_num, int node_num, int stage_idx,
551                       int node_idx, int adst_idx) {
552   int size = 1 << (adst_idx + 1);
553   for (int ni = 0; ni < size / 2; ni++) {
554     int nOut = node_idx + ni;
555     int nIn = nOut + size / 2;
556     connect_node(node, stage_num, node_num, stage_idx + 1, nOut, nOut, 1, nIn,
557                  1);
558     // printf("nOut: %d nIn: %d\n", nOut, nIn);
559   }
560   for (int ni = size / 2; ni < size; ni++) {
561     int nOut = node_idx + ni;
562     int nIn = nOut - size / 2;
563     connect_node(node, stage_num, node_num, stage_idx + 1, nOut, nOut, -1, nIn,
564                  1);
565     // printf("ndctOut: %d nIn: %d\n", nOut, nIn);
566   }
567 }
568 
gen_adst_U_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_idx,int adst_node_num)569 void gen_adst_U_graph(Node *node, int stage_num, int node_num, int stage_idx,
570                       int node_idx, int adst_idx, int adst_node_num) {
571   int size = 1 << (adst_idx + 1);
572   for (int ni = 0; ni < adst_node_num; ni += size) {
573     gen_adst_B_graph(node, stage_num, node_num, stage_idx, node_idx + ni,
574                      adst_idx);
575   }
576 }
577 
gen_adst_T_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,double freq)578 void gen_adst_T_graph(Node *node, int stage_num, int node_num, int stage_idx,
579                       int node_idx, double freq) {
580   connect_node(node, stage_num, node_num, stage_idx + 1, node_idx, node_idx,
581                cos(freq * PI), node_idx + 1, sin(freq * PI));
582   connect_node(node, stage_num, node_num, stage_idx + 1, node_idx + 1,
583                node_idx + 1, -cos(freq * PI), node_idx, sin(freq * PI));
584 }
585 
gen_adst_E_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_idx)586 void gen_adst_E_graph(Node *node, int stage_num, int node_num, int stage_idx,
587                       int node_idx, int adst_idx) {
588   int size = 1 << (adst_idx);
589   for (int i = 0; i < size / 2; i++) {
590     int ni = i * 2;
591     double fi = (1 + 4 * i) * 1.0 / (1 << (adst_idx + 1));
592     gen_adst_T_graph(node, stage_num, node_num, stage_idx, node_idx + ni, fi);
593   }
594 }
595 
gen_adst_V_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_idx,int adst_node_num)596 void gen_adst_V_graph(Node *node, int stage_num, int node_num, int stage_idx,
597                       int node_idx, int adst_idx, int adst_node_num) {
598   int size = 1 << (adst_idx);
599   for (int i = 0; i < adst_node_num / size; i++) {
600     if (i % 2 == 1) {
601       int ni = i * size;
602       gen_adst_E_graph(node, stage_num, node_num, stage_idx, node_idx + ni,
603                        adst_idx);
604     }
605   }
606 }
gen_adst_VJ_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_node_num)607 void gen_adst_VJ_graph(Node *node, int stage_num, int node_num, int stage_idx,
608                        int node_idx, int adst_node_num) {
609   for (int i = 0; i < adst_node_num / 2; i++) {
610     int ni = i * 2;
611     double fi = (1 + 4 * i) * 1.0 / (4 * adst_node_num);
612     gen_adst_T_graph(node, stage_num, node_num, stage_idx, node_idx + ni, fi);
613   }
614 }
gen_adst_Q_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_node_num)615 void gen_adst_Q_graph(Node *node, int stage_num, int node_num, int stage_idx,
616                       int node_idx, int adst_node_num) {
617   // reverse order when idx is 1, 3, 5, 7 ...
618   // example of adst_node_num = 8:
619   //   0 1 2 3 4 5 6 7
620   // --> 0 7 2 5 4 3 6 1
621   for (int ni = 0; ni < adst_node_num; ni++) {
622     if (ni % 2 == 0) {
623       int out = node_idx + ni;
624       connect_node(node, stage_num, node_num, stage_idx + 1, out, out, 1, out,
625                    0);
626     } else {
627       int out = node_idx + ni;
628       int in = node_idx + adst_node_num - ni;
629       connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
630     }
631   }
632 }
gen_adst_Ibar_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_node_num)633 void gen_adst_Ibar_graph(Node *node, int stage_num, int node_num, int stage_idx,
634                          int node_idx, int adst_node_num) {
635   // reverse order
636   // 0 1 2 3 --> 3 2 1 0
637   for (int ni = 0; ni < adst_node_num; ni++) {
638     int out = node_idx + ni;
639     int in = node_idx + adst_node_num - ni - 1;
640     connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
641   }
642 }
643 
// Input index feeding output |out| of the Q permutation: even outputs map to
// themselves, odd outputs map to adst_node_num - out.
int get_Q_out2in(int adst_node_num, int out) {
  return (out % 2 == 0) ? out : adst_node_num - out;
}
653 
// Input index feeding output |out| of the reversal (Ibar) permutation.
int get_Ibar_out2in(int adst_node_num, int out) {
  const int last = adst_node_num - 1;
  return last - out;
}
657 
gen_adst_IbarQ_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_node_num)658 void gen_adst_IbarQ_graph(Node *node, int stage_num, int node_num,
659                           int stage_idx, int node_idx, int adst_node_num) {
660   // in -> Ibar -> Q -> out
661   for (int ni = 0; ni < adst_node_num; ni++) {
662     int out = node_idx + ni;
663     int in = node_idx +
664              get_Ibar_out2in(adst_node_num, get_Q_out2in(adst_node_num, ni));
665     connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
666   }
667 }
668 
gen_adst_D_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_node_num)669 void gen_adst_D_graph(Node *node, int stage_num, int node_num, int stage_idx,
670                       int node_idx, int adst_node_num) {
671   // reverse order
672   for (int ni = 0; ni < adst_node_num; ni++) {
673     int out = node_idx + ni;
674     int in = out;
675     if (ni % 2 == 0) {
676       connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
677     } else {
678       connect_node(node, stage_num, node_num, stage_idx + 1, out, in, -1, in,
679                    0);
680     }
681   }
682 }
683 
get_hadamard_idx(int x,int adst_node_num)684 int get_hadamard_idx(int x, int adst_node_num) {
685   int max_bit = get_max_bit(adst_node_num - 1);
686   x = bitwise_reverse(x, max_bit);
687 
688   // gray code
689   int c = x & 1;
690   int p = x & 1;
691   int y = c;
692 
693   for (int i = 1; i <= max_bit; i++) {
694     p = c;
695     c = (x >> i) & 1;
696     y += (c ^ p) << i;
697   }
698   return y;
699 }
700 
gen_adst_Ht_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_node_num)701 void gen_adst_Ht_graph(Node *node, int stage_num, int node_num, int stage_idx,
702                        int node_idx, int adst_node_num) {
703   for (int ni = 0; ni < adst_node_num; ni++) {
704     int out = node_idx + ni;
705     int in = node_idx + get_hadamard_idx(ni, adst_node_num);
706     connect_node(node, stage_num, node_num, stage_idx + 1, out, in, 1, in, 0);
707   }
708 }
709 
gen_adst_HtD_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_node_num)710 void gen_adst_HtD_graph(Node *node, int stage_num, int node_num, int stage_idx,
711                         int node_idx, int adst_node_num) {
712   for (int ni = 0; ni < adst_node_num; ni++) {
713     int out = node_idx + ni;
714     int in = node_idx + get_hadamard_idx(ni, adst_node_num);
715     double inW;
716     if (ni % 2 == 0)
717       inW = 1;
718     else
719       inW = -1;
720     connect_node(node, stage_num, node_num, stage_idx + 1, out, in, inW, in, 0);
721   }
722 }
723 
get_adst_stage_num(int adst_node_num)724 int get_adst_stage_num(int adst_node_num) {
725   return 2 * get_max_bit(adst_node_num) + 2;
726 }
727 
gen_iadst_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_node_num)728 int gen_iadst_graph(Node *node, int stage_num, int node_num, int stage_idx,
729                     int node_idx, int adst_node_num) {
730   int max_bit = get_max_bit(adst_node_num);
731   int si = 0;
732   gen_adst_IbarQ_graph(node, stage_num, node_num, stage_idx + si, node_idx,
733                        adst_node_num);
734   si++;
735   gen_adst_VJ_graph(node, stage_num, node_num, stage_idx + si, node_idx,
736                     adst_node_num);
737   si++;
738   for (int adst_idx = max_bit - 1; adst_idx >= 1; adst_idx--) {
739     gen_adst_U_graph(node, stage_num, node_num, stage_idx + si, node_idx,
740                      adst_idx, adst_node_num);
741     si++;
742     gen_adst_V_graph(node, stage_num, node_num, stage_idx + si, node_idx,
743                      adst_idx, adst_node_num);
744     si++;
745   }
746   gen_adst_HtD_graph(node, stage_num, node_num, stage_idx + si, node_idx,
747                      adst_node_num);
748   si++;
749   return si + 1;
750 }
751 
gen_adst_graph(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int adst_node_num)752 int gen_adst_graph(Node *node, int stage_num, int node_num, int stage_idx,
753                    int node_idx, int adst_node_num) {
754   int hybrid_stage_num = get_hybrid_stage_num(TYPE_ADST, adst_node_num);
755   // generate a adst tempNode
756   Node *tempNode = new Node[hybrid_stage_num * adst_node_num];
757   init_graph(tempNode, hybrid_stage_num, adst_node_num);
758   int si = gen_iadst_graph(tempNode, hybrid_stage_num, adst_node_num, 0, 0,
759                            adst_node_num);
760 
761   // tempNode's inverse graph to node[stage_idx][node_idx]
762   gen_inv_graph(tempNode, hybrid_stage_num, adst_node_num, node, stage_num,
763                 node_num, stage_idx, node_idx);
764   delete[] tempNode;
765   return si;
766 }
767 
connect_layer_2d(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int dct_node_num)768 void connect_layer_2d(Node *node, int stage_num, int node_num, int stage_idx,
769                       int node_idx, int dct_node_num) {
770   for (int first = 0; first < dct_node_num; first++) {
771     for (int second = 0; second < dct_node_num; second++) {
772       // int sIn = stage_idx;
773       int sOut = stage_idx + 1;
774       int nIn = node_idx + first * dct_node_num + second;
775       int nOut = node_idx + second * dct_node_num + first;
776 
777       // printf("sIn: %d nIn: %d sOut: %d nOut: %d\n", sIn, nIn, sOut, nOut);
778 
779       connect_node(node, stage_num, node_num, sOut, nOut, nIn, 1, nIn, 0);
780     }
781   }
782 }
783 
connect_layer_2d_new(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int dct_node_num0,int dct_node_num1)784 void connect_layer_2d_new(Node *node, int stage_num, int node_num,
785                           int stage_idx, int node_idx, int dct_node_num0,
786                           int dct_node_num1) {
787   for (int i = 0; i < dct_node_num1; i++) {
788     for (int j = 0; j < dct_node_num0; j++) {
789       // int sIn = stage_idx;
790       int sOut = stage_idx + 1;
791       int nIn = node_idx + i * dct_node_num0 + j;
792       int nOut = node_idx + j * dct_node_num1 + i;
793 
794       // printf("sIn: %d nIn: %d sOut: %d nOut: %d\n", sIn, nIn, sOut, nOut);
795 
796       connect_node(node, stage_num, node_num, sOut, nOut, nIn, 1, nIn, 0);
797     }
798   }
799 }
800 
gen_DCT_graph_2d(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int dct_node_num)801 void gen_DCT_graph_2d(Node *node, int stage_num, int node_num, int stage_idx,
802                       int node_idx, int dct_node_num) {
803   int dct_stage_num = get_dct_stage_num(dct_node_num);
804   // put 2 layers of dct_node_num DCTs on the graph
805   for (int ni = 0; ni < dct_node_num; ni++) {
806     gen_DCT_graph_1d(node, stage_num, node_num, stage_idx,
807                      node_idx + ni * dct_node_num, dct_node_num);
808     gen_DCT_graph_1d(node, stage_num, node_num, stage_idx + dct_stage_num,
809                      node_idx + ni * dct_node_num, dct_node_num);
810   }
811   // connect first layer and second layer
812   connect_layer_2d(node, stage_num, node_num, stage_idx + dct_stage_num - 1,
813                    node_idx, dct_node_num);
814 }
815 
get_hybrid_stage_num(int type,int hybrid_node_num)816 int get_hybrid_stage_num(int type, int hybrid_node_num) {
817   if (type == TYPE_DCT || type == TYPE_IDCT) {
818     return get_dct_stage_num(hybrid_node_num);
819   } else if (type == TYPE_ADST || type == TYPE_IADST) {
820     return get_adst_stage_num(hybrid_node_num);
821   }
822   return 0;
823 }
824 
get_hybrid_2d_stage_num(int type0,int type1,int hybrid_node_num)825 int get_hybrid_2d_stage_num(int type0, int type1, int hybrid_node_num) {
826   int stage_num = 0;
827   stage_num += get_hybrid_stage_num(type0, hybrid_node_num);
828   stage_num += get_hybrid_stage_num(type1, hybrid_node_num);
829   return stage_num;
830 }
831 
get_hybrid_2d_stage_num_new(int type0,int type1,int hybrid_node_num0,int hybrid_node_num1)832 int get_hybrid_2d_stage_num_new(int type0, int type1, int hybrid_node_num0,
833                                 int hybrid_node_num1) {
834   int stage_num = 0;
835   stage_num += get_hybrid_stage_num(type0, hybrid_node_num0);
836   stage_num += get_hybrid_stage_num(type1, hybrid_node_num1);
837   return stage_num;
838 }
839 
get_hybrid_amplify_factor(int type,int hybrid_node_num)840 int get_hybrid_amplify_factor(int type, int hybrid_node_num) {
841   return get_max_bit(hybrid_node_num) - 1;
842 }
843 
gen_hybrid_graph_1d(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int hybrid_node_num,int type)844 void gen_hybrid_graph_1d(Node *node, int stage_num, int node_num, int stage_idx,
845                          int node_idx, int hybrid_node_num, int type) {
846   if (type == TYPE_DCT) {
847     gen_DCT_graph_1d(node, stage_num, node_num, stage_idx, node_idx,
848                      hybrid_node_num);
849   } else if (type == TYPE_ADST) {
850     gen_adst_graph(node, stage_num, node_num, stage_idx, node_idx,
851                    hybrid_node_num);
852   } else if (type == TYPE_IDCT) {
853     int hybrid_stage_num = get_hybrid_stage_num(type, hybrid_node_num);
854     // generate a dct tempNode
855     Node *tempNode = new Node[hybrid_stage_num * hybrid_node_num];
856     init_graph(tempNode, hybrid_stage_num, hybrid_node_num);
857     gen_DCT_graph_1d(tempNode, hybrid_stage_num, hybrid_node_num, 0, 0,
858                      hybrid_node_num);
859 
860     // tempNode's inverse graph to node[stage_idx][node_idx]
861     gen_inv_graph(tempNode, hybrid_stage_num, hybrid_node_num, node, stage_num,
862                   node_num, stage_idx, node_idx);
863     delete[] tempNode;
864   } else if (type == TYPE_IADST) {
865     int hybrid_stage_num = get_hybrid_stage_num(type, hybrid_node_num);
866     // generate a adst tempNode
867     Node *tempNode = new Node[hybrid_stage_num * hybrid_node_num];
868     init_graph(tempNode, hybrid_stage_num, hybrid_node_num);
869     gen_adst_graph(tempNode, hybrid_stage_num, hybrid_node_num, 0, 0,
870                    hybrid_node_num);
871 
872     // tempNode's inverse graph to node[stage_idx][node_idx]
873     gen_inv_graph(tempNode, hybrid_stage_num, hybrid_node_num, node, stage_num,
874                   node_num, stage_idx, node_idx);
875     delete[] tempNode;
876   }
877 }
878 
gen_hybrid_graph_2d(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int hybrid_node_num,int type0,int type1)879 void gen_hybrid_graph_2d(Node *node, int stage_num, int node_num, int stage_idx,
880                          int node_idx, int hybrid_node_num, int type0,
881                          int type1) {
882   int hybrid_stage_num = get_hybrid_stage_num(type0, hybrid_node_num);
883 
884   for (int ni = 0; ni < hybrid_node_num; ni++) {
885     gen_hybrid_graph_1d(node, stage_num, node_num, stage_idx,
886                         node_idx + ni * hybrid_node_num, hybrid_node_num,
887                         type0);
888     gen_hybrid_graph_1d(node, stage_num, node_num, stage_idx + hybrid_stage_num,
889                         node_idx + ni * hybrid_node_num, hybrid_node_num,
890                         type1);
891   }
892 
893   // connect first layer and second layer
894   connect_layer_2d(node, stage_num, node_num, stage_idx + hybrid_stage_num - 1,
895                    node_idx, hybrid_node_num);
896 }
897 
gen_hybrid_graph_2d_new(Node * node,int stage_num,int node_num,int stage_idx,int node_idx,int hybrid_node_num0,int hybrid_node_num1,int type0,int type1)898 void gen_hybrid_graph_2d_new(Node *node, int stage_num, int node_num,
899                              int stage_idx, int node_idx, int hybrid_node_num0,
900                              int hybrid_node_num1, int type0, int type1) {
901   int hybrid_stage_num0 = get_hybrid_stage_num(type0, hybrid_node_num0);
902 
903   for (int ni = 0; ni < hybrid_node_num1; ni++) {
904     gen_hybrid_graph_1d(node, stage_num, node_num, stage_idx,
905                         node_idx + ni * hybrid_node_num0, hybrid_node_num0,
906                         type0);
907   }
908   for (int ni = 0; ni < hybrid_node_num0; ni++) {
909     gen_hybrid_graph_1d(
910         node, stage_num, node_num, stage_idx + hybrid_stage_num0,
911         node_idx + ni * hybrid_node_num1, hybrid_node_num1, type1);
912   }
913 
914   // connect first layer and second layer
915   connect_layer_2d_new(node, stage_num, node_num,
916                        stage_idx + hybrid_stage_num0 - 1, node_idx,
917                        hybrid_node_num0, hybrid_node_num1);
918 }
919 
gen_inv_graph(Node * node,int stage_num,int node_num,Node * invNode,int inv_stage_num,int inv_node_num,int inv_stage_idx,int inv_node_idx)920 void gen_inv_graph(Node *node, int stage_num, int node_num, Node *invNode,
921                    int inv_stage_num, int inv_node_num, int inv_stage_idx,
922                    int inv_node_idx) {
923   // clean up inNodeNum in invNode because of add_node
924   for (int si = 1 + inv_stage_idx; si < inv_stage_idx + stage_num; si++) {
925     for (int ni = inv_node_idx; ni < inv_node_idx + node_num; ni++) {
926       int idx = get_idx(si, ni, inv_node_num);
927       invNode[idx].inNodeNum = 0;
928     }
929   }
930   // generate inverse graph of node on invNode
931   for (int si = 1; si < stage_num; si++) {
932     for (int ni = 0; ni < node_num; ni++) {
933       int invSi = stage_num - si;
934       int idx = get_idx(si, ni, node_num);
935       for (int k = 0; k < node[idx].inNodeNum; k++) {
936         int invNi = node[idx].inNodeIdx[k];
937         add_node(invNode, inv_stage_num, inv_node_num, invSi + inv_stage_idx,
938                  invNi + inv_node_idx, ni + inv_node_idx,
939                  node[idx].inWeight[k]);
940       }
941     }
942   }
943 }
944