1 /*
2  * Copyright (c) 2018, Alliance for Open Media. All rights reserved
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <stdio.h>
13 #include <stdlib.h>
14 #include <math.h>
15 #include <float.h>
16 #include <string.h>
17 
18 #include "tools/txfm_analyzer/txfm_graph.h"
19 
/* Flavour of source code the generator prints. */
typedef enum CODE_TYPE {
  CODE_TYPE_C = 0,      /* portable C using half_btf()/apply_value() */
  CODE_TYPE_SSE2 = 1,   /* SSE2 intrinsics operating on 16-bit lanes */
  CODE_TYPE_SSE4_1 = 2  /* SSE4.1 intrinsics operating on 32-bit lanes */
} CODE_TYPE;
25 
get_cos_idx(double value,int mod)26 int get_cos_idx(double value, int mod) {
27   return round(acos(fabs(value)) / PI * mod);
28 }
29 
cos_text_arr(double value,int mod,char * text,int size)30 char *cos_text_arr(double value, int mod, char *text, int size) {
31   int num = get_cos_idx(value, mod);
32   if (value < 0) {
33     snprintf(text, size, "-cospi[%2d]", num);
34   } else {
35     snprintf(text, size, " cospi[%2d]", num);
36   }
37 
38   if (num == 0)
39     printf("v: %f -> %d/%d v==-1 is %d\n", value, num, mod, value == -1);
40 
41   return text;
42 }
43 
cos_text_sse2(double w0,double w1,int mod,char * text,int size)44 char *cos_text_sse2(double w0, double w1, int mod, char *text, int size) {
45   int idx0 = get_cos_idx(w0, mod);
46   int idx1 = get_cos_idx(w1, mod);
47   char p[] = "p";
48   char n[] = "m";
49   char *sgn0 = w0 < 0 ? n : p;
50   char *sgn1 = w1 < 0 ? n : p;
51   snprintf(text, size, "cospi_%s%02d_%s%02d", sgn0, idx0, sgn1, idx1);
52   return text;
53 }
54 
cos_text_sse4_1(double w,int mod,char * text,int size)55 char *cos_text_sse4_1(double w, int mod, char *text, int size) {
56   int idx = get_cos_idx(w, mod);
57   char p[] = "p";
58   char n[] = "m";
59   char *sgn = w < 0 ? n : p;
60   snprintf(text, size, "cospi_%s%02d", sgn, idx);
61   return text;
62 }
63 
node_to_code_c(Node * node,const char * buf0,const char * buf1)64 void node_to_code_c(Node *node, const char *buf0, const char *buf1) {
65   int cnt = 0;
66   for (int i = 0; i < 2; i++) {
67     if (fabs(node->inWeight[i]) == 1 || fabs(node->inWeight[i]) == 0) cnt++;
68   }
69   if (cnt == 2) {
70     int cnt2 = 0;
71     printf("  %s[%d] =", buf1, node->nodeIdx);
72     for (int i = 0; i < 2; i++) {
73       if (fabs(node->inWeight[i]) == 1) {
74         cnt2++;
75       }
76     }
77     if (cnt2 == 2) {
78       printf(" apply_value(");
79     }
80     int cnt1 = 0;
81     for (int i = 0; i < 2; i++) {
82       if (node->inWeight[i] == 1) {
83         if (cnt1 > 0)
84           printf(" + %s[%d]", buf0, node->inNodeIdx[i]);
85         else
86           printf(" %s[%d]", buf0, node->inNodeIdx[i]);
87         cnt1++;
88       } else if (node->inWeight[i] == -1) {
89         if (cnt1 > 0)
90           printf(" - %s[%d]", buf0, node->inNodeIdx[i]);
91         else
92           printf("-%s[%d]", buf0, node->inNodeIdx[i]);
93         cnt1++;
94       }
95     }
96     if (cnt2 == 2) {
97       printf(", stage_range[stage])");
98     }
99     printf(";\n");
100   } else {
101     char w0[100];
102     char w1[100];
103     printf(
104         "  %s[%d] = half_btf(%s, %s[%d], %s, %s[%d], "
105         "cos_bit);\n",
106         buf1, node->nodeIdx, cos_text_arr(node->inWeight[0], COS_MOD, w0, 100),
107         buf0, node->inNodeIdx[0],
108         cos_text_arr(node->inWeight[1], COS_MOD, w1, 100), buf0,
109         node->inNodeIdx[1]);
110   }
111 }
112 
gen_code_c(Node * node,int stage_num,int node_num,TYPE_TXFM type)113 void gen_code_c(Node *node, int stage_num, int node_num, TYPE_TXFM type) {
114   char *fun_name = new char[100];
115   get_fun_name(fun_name, 100, type, node_num);
116 
117   printf("\n");
118   printf(
119       "void av1_%s(const int32_t *input, int32_t *output, int8_t cos_bit, "
120       "const int8_t* stage_range) "
121       "{\n",
122       fun_name);
123   printf("  assert(output != input);\n");
124   printf("  const int32_t size = %d;\n", node_num);
125   printf("  const int32_t *cospi = cospi_arr(cos_bit);\n");
126   printf("\n");
127 
128   printf("  int32_t stage = 0;\n");
129   printf("  int32_t *bf0, *bf1;\n");
130   printf("  int32_t step[%d];\n", node_num);
131 
132   const char *buf0 = "bf0";
133   const char *buf1 = "bf1";
134   const char *input = "input";
135 
136   int si = 0;
137   printf("\n");
138   printf("  // stage %d;\n", si);
139   printf("  apply_range(stage, input, %s, size, stage_range[stage]);\n", input);
140 
141   si = 1;
142   printf("\n");
143   printf("  // stage %d;\n", si);
144   printf("  stage++;\n");
145   if (si % 2 == (stage_num - 1) % 2) {
146     printf("  %s = output;\n", buf1);
147   } else {
148     printf("  %s = step;\n", buf1);
149   }
150 
151   for (int ni = 0; ni < node_num; ni++) {
152     int idx = get_idx(si, ni, node_num);
153     node_to_code_c(node + idx, input, buf1);
154   }
155 
156   printf("  range_check_buf(stage, input, bf1, size, stage_range[stage]);\n");
157 
158   for (int si = 2; si < stage_num; si++) {
159     printf("\n");
160     printf("  // stage %d\n", si);
161     printf("  stage++;\n");
162     if (si % 2 == (stage_num - 1) % 2) {
163       printf("  %s = step;\n", buf0);
164       printf("  %s = output;\n", buf1);
165     } else {
166       printf("  %s = output;\n", buf0);
167       printf("  %s = step;\n", buf1);
168     }
169 
170     // computation code
171     for (int ni = 0; ni < node_num; ni++) {
172       int idx = get_idx(si, ni, node_num);
173       node_to_code_c(node + idx, buf0, buf1);
174     }
175 
176     if (si != stage_num - 1) {
177       printf(
178           "  range_check_buf(stage, input, bf1, size, stage_range[stage]);\n");
179     }
180   }
181   printf("  apply_range(stage, input, output, size, stage_range[stage]);\n");
182   printf("}\n");
183 }
184 
single_node_to_code_sse2(Node * node,const char * buf0,const char * buf1)185 void single_node_to_code_sse2(Node *node, const char *buf0, const char *buf1) {
186   printf("  %s[%2d] =", buf1, node->nodeIdx);
187   if (node->inWeight[0] == 1 && node->inWeight[1] == 1) {
188     printf(" _mm_adds_epi16(%s[%d], %s[%d])", buf0, node->inNodeIdx[0], buf0,
189            node->inNodeIdx[1]);
190   } else if (node->inWeight[0] == 1 && node->inWeight[1] == -1) {
191     printf(" _mm_subs_epi16(%s[%d], %s[%d])", buf0, node->inNodeIdx[0], buf0,
192            node->inNodeIdx[1]);
193   } else if (node->inWeight[0] == -1 && node->inWeight[1] == 1) {
194     printf(" _mm_subs_epi16(%s[%d], %s[%d])", buf0, node->inNodeIdx[1], buf0,
195            node->inNodeIdx[0]);
196   } else if (node->inWeight[0] == 1 && node->inWeight[1] == 0) {
197     printf(" %s[%d]", buf0, node->inNodeIdx[0]);
198   } else if (node->inWeight[0] == 0 && node->inWeight[1] == 1) {
199     printf(" %s[%d]", buf0, node->inNodeIdx[1]);
200   } else if (node->inWeight[0] == -1 && node->inWeight[1] == 0) {
201     printf(" _mm_subs_epi16(__zero, %s[%d])", buf0, node->inNodeIdx[0]);
202   } else if (node->inWeight[0] == 0 && node->inWeight[1] == -1) {
203     printf(" _mm_subs_epi16(__zero, %s[%d])", buf0, node->inNodeIdx[1]);
204   }
205   printf(";\n");
206 }
207 
pair_node_to_code_sse2(Node * node,Node * partnerNode,const char * buf0,const char * buf1)208 void pair_node_to_code_sse2(Node *node, Node *partnerNode, const char *buf0,
209                             const char *buf1) {
210   char temp0[100];
211   char temp1[100];
212   // btf_16_sse2_type0(w0, w1, in0, in1, out0, out1)
213   if (node->inNodeIdx[0] != partnerNode->inNodeIdx[0])
214     printf("  btf_16_sse2(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d]);\n",
215            cos_text_sse2(node->inWeight[0], node->inWeight[1], COS_MOD, temp0,
216                          100),
217            cos_text_sse2(partnerNode->inWeight[1], partnerNode->inWeight[0],
218                          COS_MOD, temp1, 100),
219            buf0, node->inNodeIdx[0], buf0, node->inNodeIdx[1], buf1,
220            node->nodeIdx, buf1, partnerNode->nodeIdx);
221   else
222     printf("  btf_16_sse2(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d]);\n",
223            cos_text_sse2(node->inWeight[0], node->inWeight[1], COS_MOD, temp0,
224                          100),
225            cos_text_sse2(partnerNode->inWeight[0], partnerNode->inWeight[1],
226                          COS_MOD, temp1, 100),
227            buf0, node->inNodeIdx[0], buf0, node->inNodeIdx[1], buf1,
228            node->nodeIdx, buf1, partnerNode->nodeIdx);
229 }
230 
get_partner_node(Node * node)231 Node *get_partner_node(Node *node) {
232   int diff = node->inNode[1]->nodeIdx - node->nodeIdx;
233   return node + diff;
234 }
235 
node_to_code_sse2(Node * node,const char * buf0,const char * buf1)236 void node_to_code_sse2(Node *node, const char *buf0, const char *buf1) {
237   int cnt = 0;
238   int cnt1 = 0;
239   if (node->visited == 0) {
240     node->visited = 1;
241     for (int i = 0; i < 2; i++) {
242       if (fabs(node->inWeight[i]) == 1 || fabs(node->inWeight[i]) == 0) cnt++;
243       if (fabs(node->inWeight[i]) == 1) cnt1++;
244     }
245     if (cnt == 2) {
246       if (cnt1 == 2) {
247         // has a partner
248         Node *partnerNode = get_partner_node(node);
249         partnerNode->visited = 1;
250         single_node_to_code_sse2(node, buf0, buf1);
251         single_node_to_code_sse2(partnerNode, buf0, buf1);
252       } else {
253         single_node_to_code_sse2(node, buf0, buf1);
254       }
255     } else {
256       Node *partnerNode = get_partner_node(node);
257       partnerNode->visited = 1;
258       pair_node_to_code_sse2(node, partnerNode, buf0, buf1);
259     }
260   }
261 }
262 
gen_cospi_list_sse2(Node * node,int stage_num,int node_num)263 void gen_cospi_list_sse2(Node *node, int stage_num, int node_num) {
264   int visited[65][65][2][2];
265   memset(visited, 0, sizeof(visited));
266   char text[100];
267   char text1[100];
268   char text2[100];
269   int size = 100;
270   printf("\n");
271   for (int si = 1; si < stage_num; si++) {
272     for (int ni = 0; ni < node_num; ni++) {
273       int idx = get_idx(si, ni, node_num);
274       int cnt = 0;
275       Node *node0 = node + idx;
276       if (node0->visited == 0) {
277         node0->visited = 1;
278         for (int i = 0; i < 2; i++) {
279           if (fabs(node0->inWeight[i]) == 1 || fabs(node0->inWeight[i]) == 0)
280             cnt++;
281         }
282         if (cnt != 2) {
283           {
284             double w0 = node0->inWeight[0];
285             double w1 = node0->inWeight[1];
286             int idx0 = get_cos_idx(w0, COS_MOD);
287             int idx1 = get_cos_idx(w1, COS_MOD);
288             int sgn0 = w0 < 0 ? 1 : 0;
289             int sgn1 = w1 < 0 ? 1 : 0;
290 
291             if (!visited[idx0][idx1][sgn0][sgn1]) {
292               visited[idx0][idx1][sgn0][sgn1] = 1;
293               printf("  __m128i %s = pair_set_epi16(%s, %s);\n",
294                      cos_text_sse2(w0, w1, COS_MOD, text, size),
295                      cos_text_arr(w0, COS_MOD, text1, size),
296                      cos_text_arr(w1, COS_MOD, text2, size));
297             }
298           }
299           Node *node1 = get_partner_node(node0);
300           node1->visited = 1;
301           if (node1->inNode[0]->nodeIdx != node0->inNode[0]->nodeIdx) {
302             double w0 = node1->inWeight[0];
303             double w1 = node1->inWeight[1];
304             int idx0 = get_cos_idx(w0, COS_MOD);
305             int idx1 = get_cos_idx(w1, COS_MOD);
306             int sgn0 = w0 < 0 ? 1 : 0;
307             int sgn1 = w1 < 0 ? 1 : 0;
308 
309             if (!visited[idx1][idx0][sgn1][sgn0]) {
310               visited[idx1][idx0][sgn1][sgn0] = 1;
311               printf("  __m128i %s = pair_set_epi16(%s, %s);\n",
312                      cos_text_sse2(w1, w0, COS_MOD, text, size),
313                      cos_text_arr(w1, COS_MOD, text1, size),
314                      cos_text_arr(w0, COS_MOD, text2, size));
315             }
316           } else {
317             double w0 = node1->inWeight[0];
318             double w1 = node1->inWeight[1];
319             int idx0 = get_cos_idx(w0, COS_MOD);
320             int idx1 = get_cos_idx(w1, COS_MOD);
321             int sgn0 = w0 < 0 ? 1 : 0;
322             int sgn1 = w1 < 0 ? 1 : 0;
323 
324             if (!visited[idx0][idx1][sgn0][sgn1]) {
325               visited[idx0][idx1][sgn0][sgn1] = 1;
326               printf("  __m128i %s = pair_set_epi16(%s, %s);\n",
327                      cos_text_sse2(w0, w1, COS_MOD, text, size),
328                      cos_text_arr(w0, COS_MOD, text1, size),
329                      cos_text_arr(w1, COS_MOD, text2, size));
330             }
331           }
332         }
333       }
334     }
335   }
336 }
337 
gen_code_sse2(Node * node,int stage_num,int node_num,TYPE_TXFM type)338 void gen_code_sse2(Node *node, int stage_num, int node_num, TYPE_TXFM type) {
339   char *fun_name = new char[100];
340   get_fun_name(fun_name, 100, type, node_num);
341 
342   printf("\n");
343   printf(
344       "void %s_sse2(const __m128i *input, __m128i *output, int8_t cos_bit) "
345       "{\n",
346       fun_name);
347 
348   printf("  const int32_t* cospi = cospi_arr(cos_bit);\n");
349   printf("  const __m128i __zero = _mm_setzero_si128();\n");
350   printf("  const __m128i __rounding = _mm_set1_epi32(1 << (cos_bit - 1));\n");
351 
352   graph_reset_visited(node, stage_num, node_num);
353   gen_cospi_list_sse2(node, stage_num, node_num);
354   graph_reset_visited(node, stage_num, node_num);
355   for (int si = 1; si < stage_num; si++) {
356     char in[100];
357     char out[100];
358     printf("\n");
359     printf("  // stage %d\n", si);
360     if (si == 1)
361       snprintf(in, 100, "%s", "input");
362     else
363       snprintf(in, 100, "x%d", si - 1);
364     if (si == stage_num - 1) {
365       snprintf(out, 100, "%s", "output");
366     } else {
367       snprintf(out, 100, "x%d", si);
368       printf("  __m128i %s[%d];\n", out, node_num);
369     }
370     // computation code
371     for (int ni = 0; ni < node_num; ni++) {
372       int idx = get_idx(si, ni, node_num);
373       node_to_code_sse2(node + idx, in, out);
374     }
375   }
376 
377   printf("}\n");
378 }
gen_cospi_list_sse4_1(Node * node,int stage_num,int node_num)379 void gen_cospi_list_sse4_1(Node *node, int stage_num, int node_num) {
380   int visited[65][2];
381   memset(visited, 0, sizeof(visited));
382   char text[100];
383   char text1[100];
384   int size = 100;
385   printf("\n");
386   for (int si = 1; si < stage_num; si++) {
387     for (int ni = 0; ni < node_num; ni++) {
388       int idx = get_idx(si, ni, node_num);
389       Node *node0 = node + idx;
390       if (node0->visited == 0) {
391         int cnt = 0;
392         node0->visited = 1;
393         for (int i = 0; i < 2; i++) {
394           if (fabs(node0->inWeight[i]) == 1 || fabs(node0->inWeight[i]) == 0)
395             cnt++;
396         }
397         if (cnt != 2) {
398           for (int i = 0; i < 2; i++) {
399             if (fabs(node0->inWeight[i]) != 1 &&
400                 fabs(node0->inWeight[i]) != 0) {
401               double w = node0->inWeight[i];
402               int idx = get_cos_idx(w, COS_MOD);
403               int sgn = w < 0 ? 1 : 0;
404 
405               if (!visited[idx][sgn]) {
406                 visited[idx][sgn] = 1;
407                 printf("  __m128i %s = _mm_set1_epi32(%s);\n",
408                        cos_text_sse4_1(w, COS_MOD, text, size),
409                        cos_text_arr(w, COS_MOD, text1, size));
410               }
411             }
412           }
413           Node *node1 = get_partner_node(node0);
414           node1->visited = 1;
415         }
416       }
417     }
418   }
419 }
420 
single_node_to_code_sse4_1(Node * node,const char * buf0,const char * buf1)421 void single_node_to_code_sse4_1(Node *node, const char *buf0,
422                                 const char *buf1) {
423   printf("  %s[%2d] =", buf1, node->nodeIdx);
424   if (node->inWeight[0] == 1 && node->inWeight[1] == 1) {
425     printf(" _mm_add_epi32(%s[%d], %s[%d])", buf0, node->inNodeIdx[0], buf0,
426            node->inNodeIdx[1]);
427   } else if (node->inWeight[0] == 1 && node->inWeight[1] == -1) {
428     printf(" _mm_sub_epi32(%s[%d], %s[%d])", buf0, node->inNodeIdx[0], buf0,
429            node->inNodeIdx[1]);
430   } else if (node->inWeight[0] == -1 && node->inWeight[1] == 1) {
431     printf(" _mm_sub_epi32(%s[%d], %s[%d])", buf0, node->inNodeIdx[1], buf0,
432            node->inNodeIdx[0]);
433   } else if (node->inWeight[0] == 1 && node->inWeight[1] == 0) {
434     printf(" %s[%d]", buf0, node->inNodeIdx[0]);
435   } else if (node->inWeight[0] == 0 && node->inWeight[1] == 1) {
436     printf(" %s[%d]", buf0, node->inNodeIdx[1]);
437   } else if (node->inWeight[0] == -1 && node->inWeight[1] == 0) {
438     printf(" _mm_sub_epi32(__zero, %s[%d])", buf0, node->inNodeIdx[0]);
439   } else if (node->inWeight[0] == 0 && node->inWeight[1] == -1) {
440     printf(" _mm_sub_epi32(__zero, %s[%d])", buf0, node->inNodeIdx[1]);
441   }
442   printf(";\n");
443 }
444 
pair_node_to_code_sse4_1(Node * node,Node * partnerNode,const char * buf0,const char * buf1)445 void pair_node_to_code_sse4_1(Node *node, Node *partnerNode, const char *buf0,
446                               const char *buf1) {
447   char temp0[100];
448   char temp1[100];
449   if (node->inWeight[0] * partnerNode->inWeight[0] < 0) {
450     /* type0
451      * cos  sin
452      * sin -cos
453      */
454     // btf_32_sse2_type0(w0, w1, in0, in1, out0, out1)
455     // out0 = w0*in0 + w1*in1
456     // out1 = -w0*in1 + w1*in0
457     printf(
458         "  btf_32_type0_sse4_1_new(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d], "
459         "__rounding, cos_bit);\n",
460         cos_text_sse4_1(node->inWeight[0], COS_MOD, temp0, 100),
461         cos_text_sse4_1(node->inWeight[1], COS_MOD, temp1, 100), buf0,
462         node->inNodeIdx[0], buf0, node->inNodeIdx[1], buf1, node->nodeIdx, buf1,
463         partnerNode->nodeIdx);
464   } else {
465     /* type1
466      *  cos sin
467      * -sin cos
468      */
469     // btf_32_sse2_type1(w0, w1, in0, in1, out0, out1)
470     // out0 = w0*in0 + w1*in1
471     // out1 = w0*in1 - w1*in0
472     printf(
473         "  btf_32_type1_sse4_1_new(%s, %s, %s[%d], %s[%d], %s[%d], %s[%d], "
474         "__rounding, cos_bit);\n",
475         cos_text_sse4_1(node->inWeight[0], COS_MOD, temp0, 100),
476         cos_text_sse4_1(node->inWeight[1], COS_MOD, temp1, 100), buf0,
477         node->inNodeIdx[0], buf0, node->inNodeIdx[1], buf1, node->nodeIdx, buf1,
478         partnerNode->nodeIdx);
479   }
480 }
481 
node_to_code_sse4_1(Node * node,const char * buf0,const char * buf1)482 void node_to_code_sse4_1(Node *node, const char *buf0, const char *buf1) {
483   int cnt = 0;
484   int cnt1 = 0;
485   if (node->visited == 0) {
486     node->visited = 1;
487     for (int i = 0; i < 2; i++) {
488       if (fabs(node->inWeight[i]) == 1 || fabs(node->inWeight[i]) == 0) cnt++;
489       if (fabs(node->inWeight[i]) == 1) cnt1++;
490     }
491     if (cnt == 2) {
492       if (cnt1 == 2) {
493         // has a partner
494         Node *partnerNode = get_partner_node(node);
495         partnerNode->visited = 1;
496         single_node_to_code_sse4_1(node, buf0, buf1);
497         single_node_to_code_sse4_1(partnerNode, buf0, buf1);
498       } else {
499         single_node_to_code_sse2(node, buf0, buf1);
500       }
501     } else {
502       Node *partnerNode = get_partner_node(node);
503       partnerNode->visited = 1;
504       pair_node_to_code_sse4_1(node, partnerNode, buf0, buf1);
505     }
506   }
507 }
508 
gen_code_sse4_1(Node * node,int stage_num,int node_num,TYPE_TXFM type)509 void gen_code_sse4_1(Node *node, int stage_num, int node_num, TYPE_TXFM type) {
510   char *fun_name = new char[100];
511   get_fun_name(fun_name, 100, type, node_num);
512 
513   printf("\n");
514   printf(
515       "void %s_sse4_1(const __m128i *input, __m128i *output, int8_t cos_bit) "
516       "{\n",
517       fun_name);
518 
519   printf("  const int32_t* cospi = cospi_arr(cos_bit);\n");
520   printf("  const __m128i __zero = _mm_setzero_si128();\n");
521   printf("  const __m128i __rounding = _mm_set1_epi32(1 << (cos_bit - 1));\n");
522 
523   graph_reset_visited(node, stage_num, node_num);
524   gen_cospi_list_sse4_1(node, stage_num, node_num);
525   graph_reset_visited(node, stage_num, node_num);
526   for (int si = 1; si < stage_num; si++) {
527     char in[100];
528     char out[100];
529     printf("\n");
530     printf("  // stage %d\n", si);
531     if (si == 1)
532       snprintf(in, 100, "%s", "input");
533     else
534       snprintf(in, 100, "x%d", si - 1);
535     if (si == stage_num - 1) {
536       snprintf(out, 100, "%s", "output");
537     } else {
538       snprintf(out, 100, "x%d", si);
539       printf("  __m128i %s[%d];\n", out, node_num);
540     }
541     // computation code
542     for (int ni = 0; ni < node_num; ni++) {
543       int idx = get_idx(si, ni, node_num);
544       node_to_code_sse4_1(node + idx, in, out);
545     }
546   }
547 
548   printf("}\n");
549 }
550 
gen_hybrid_code(CODE_TYPE code_type,TYPE_TXFM txfm_type,int node_num)551 void gen_hybrid_code(CODE_TYPE code_type, TYPE_TXFM txfm_type, int node_num) {
552   int stage_num = get_hybrid_stage_num(txfm_type, node_num);
553 
554   Node *node = new Node[node_num * stage_num];
555   init_graph(node, stage_num, node_num);
556 
557   gen_hybrid_graph_1d(node, stage_num, node_num, 0, 0, node_num, txfm_type);
558 
559   switch (code_type) {
560     case CODE_TYPE_C: gen_code_c(node, stage_num, node_num, txfm_type); break;
561     case CODE_TYPE_SSE2:
562       gen_code_sse2(node, stage_num, node_num, txfm_type);
563       break;
564     case CODE_TYPE_SSE4_1:
565       gen_code_sse4_1(node, stage_num, node_num, txfm_type);
566       break;
567   }
568 
569   delete[] node;
570 }
571 
main(int argc,char ** argv)572 int main(int argc, char **argv) {
573   CODE_TYPE code_type = CODE_TYPE_SSE4_1;
574   for (int txfm_type = TYPE_DCT; txfm_type < TYPE_LAST; txfm_type++) {
575     for (int node_num = 4; node_num <= 64; node_num *= 2) {
576       gen_hybrid_code(code_type, (TYPE_TXFM)txfm_type, node_num);
577     }
578   }
579   return 0;
580 }
581