1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "onnx.pb.h"
16 
17 #include <algorithm>
18 #include <float.h>
19 #include <fstream>
20 #include <google/protobuf/io/coded_stream.h>
21 #include <google/protobuf/io/zero_copy_stream_impl.h>
22 #include <google/protobuf/message.h>
23 #include <google/protobuf/text_format.h>
24 #include <iostream>
25 #include <limits.h>
26 #include <limits>
27 #include <set>
28 #include <stdio.h>
29 
read_proto_from_binary(const char * filepath,onnx::ModelProto * message)30 static bool read_proto_from_binary(const char* filepath, onnx::ModelProto* message)
31 {
32     std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
33     if (!fs.is_open())
34     {
35         fprintf(stderr, "open failed %s\n", filepath);
36         return false;
37     }
38 
39     google::protobuf::io::IstreamInputStream input(&fs);
40     google::protobuf::io::CodedInputStream codedstr(&input);
41 
42 #if GOOGLE_PROTOBUF_VERSION >= 3011000
43     codedstr.SetTotalBytesLimit(INT_MAX);
44 #else
45     codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);
46 #endif
47 
48     bool success = message->ParseFromCodedStream(&codedstr);
49 
50     fs.close();
51 
52     return success;
53 }
54 
get_node_attr_ai(const onnx::NodeProto & node,const char * key)55 static std::vector<int> get_node_attr_ai(const onnx::NodeProto& node, const char* key)
56 {
57     std::vector<int> v;
58 
59     for (int i = 0; i < node.attribute_size(); i++)
60     {
61         const onnx::AttributeProto& attr = node.attribute(i);
62         if (attr.name() == key)
63         {
64             v.resize(attr.ints_size());
65             for (int j = 0; j < attr.ints_size(); j++)
66             {
67                 v[j] = std::max(std::min(attr.ints(j), (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
68             }
69 
70             break;
71         }
72     }
73 
74     return v;
75 }
76 
get_node_attr_af(const onnx::NodeProto & node,const char * key)77 static std::vector<float> get_node_attr_af(const onnx::NodeProto& node, const char* key)
78 {
79     std::vector<float> v;
80 
81     for (int i = 0; i < node.attribute_size(); i++)
82     {
83         const onnx::AttributeProto& attr = node.attribute(i);
84         if (attr.name() == key)
85         {
86             v.resize(attr.floats_size());
87             for (int j = 0; j < attr.floats_size(); j++)
88             {
89                 v[j] = attr.floats(j);
90             }
91 
92             break;
93         }
94     }
95 
96     return v;
97 }
98 
get_node_attr_i(const onnx::NodeProto & node,const char * key,int def=0)99 static int get_node_attr_i(const onnx::NodeProto& node, const char* key, int def = 0)
100 {
101     for (int i = 0; i < node.attribute_size(); i++)
102     {
103         const onnx::AttributeProto& attr = node.attribute(i);
104         if (attr.name() == key)
105         {
106             return std::max(std::min(attr.i(), (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
107         }
108     }
109 
110     return def;
111 }
112 
get_node_attr_f(const onnx::NodeProto & node,const char * key,float def=0.f)113 static float get_node_attr_f(const onnx::NodeProto& node, const char* key, float def = 0.f)
114 {
115     for (int i = 0; i < node.attribute_size(); i++)
116     {
117         const onnx::AttributeProto& attr = node.attribute(i);
118         if (attr.name() == key)
119         {
120             return attr.f();
121         }
122     }
123 
124     return def;
125 }
126 
get_node_attr_s(const onnx::NodeProto & node,const char * key,const std::string & def=std::string ())127 static std::string get_node_attr_s(const onnx::NodeProto& node, const char* key, const std::string& def = std::string())
128 {
129     for (int i = 0; i < node.attribute_size(); i++)
130     {
131         const onnx::AttributeProto& attr = node.attribute(i);
132         if (attr.name() == key)
133         {
134             return attr.s();
135         }
136     }
137 
138     return def;
139 }
140 
// Read the tensor attribute named key from node.
// Returns a default-constructed (empty) TensorProto when absent.
static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char* key)
{
    const int attr_count = node.attribute_size();
    for (int i = 0; i < attr_count; i++)
    {
        const onnx::AttributeProto& attr = node.attribute(i);
        if (attr.name() != key)
            continue;

        return attr.t();
    }

    return onnx::TensorProto();
}
154 
get_node_attr_from_input_f(const onnx::TensorProto & tp)155 static float get_node_attr_from_input_f(const onnx::TensorProto& tp)
156 {
157     float v = 0.f;
158 
159     // float
160     if (tp.data_type() == 1)
161     {
162         const float* shape_data = 0;
163         if (tp.has_raw_data())
164         {
165             shape_data = (const float*)tp.raw_data().data();
166         }
167         else
168         {
169             shape_data = tp.float_data().data();
170         }
171         v = shape_data[0];
172     }
173     // double
174     else if (tp.data_type() == 11)
175     {
176         const double* shape_data = 0;
177         if (tp.has_raw_data())
178         {
179             shape_data = (const double*)tp.raw_data().data();
180         }
181         else
182         {
183             shape_data = tp.double_data().data();
184         }
185         v = shape_data[0];
186     }
187     // int64
188     else if (tp.data_type() == 7)
189     {
190         const int64_t* shape_data = 0;
191         if (tp.has_raw_data())
192         {
193             shape_data = (const int64_t*)tp.raw_data().data();
194         }
195         else
196         {
197             shape_data = tp.int64_data().data();
198         }
199         v = std::max(std::min(shape_data[0], (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
200     }
201     // int32
202     else if (tp.data_type() == 6)
203     {
204         const int32_t* shape_data = 0;
205         if (tp.has_raw_data())
206         {
207             shape_data = (const int32_t*)tp.raw_data().data();
208         }
209         else
210         {
211             shape_data = tp.int32_data().data();
212         }
213         v = shape_data[0];
214     }
215     else
216     {
217         fprintf(stderr, "Unknown data type %d\n", tp.data_type());
218         abort();
219     }
220 
221     return v;
222 }
223 
get_node_attr_from_input_ai(const onnx::TensorProto & tp)224 static std::vector<int> get_node_attr_from_input_ai(const onnx::TensorProto& tp)
225 {
226     int size = 0;
227 
228     std::vector<int> v;
229 
230     // int64
231     if (tp.data_type() == 7)
232     {
233         const int64_t* shape_data = 0;
234         if (tp.has_raw_data())
235         {
236             shape_data = (const int64_t*)tp.raw_data().data();
237             size = (int)(tp.raw_data().size() / 8);
238         }
239         else
240         {
241             shape_data = tp.int64_data().data();
242             size = tp.int64_data_size();
243         }
244         for (int j = 0; j < size; j++)
245         {
246             int vi = std::max(std::min(shape_data[j], (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
247             v.push_back(vi);
248         }
249     }
250     // int32
251     else if (tp.data_type() == 6)
252     {
253         const int32_t* shape_data = 0;
254         if (tp.has_raw_data())
255         {
256             shape_data = (const int32_t*)tp.raw_data().data();
257             size = (int)(tp.raw_data().size() / 4);
258         }
259         else
260         {
261             shape_data = tp.int32_data().data();
262             size = tp.int32_data_size();
263         }
264         for (int j = 0; j < size; j++)
265         {
266             v.push_back(shape_data[j]);
267         }
268     }
269     else
270     {
271         fprintf(stderr, "Unknown data type %d\n", tp.data_type());
272     }
273 
274     return v;
275 }
276 
get_node_attr_from_input_af(const onnx::TensorProto & tp)277 static std::vector<float> get_node_attr_from_input_af(const onnx::TensorProto& tp)
278 {
279     int size = 0;
280 
281     std::vector<float> v;
282 
283     // float
284     if (tp.data_type() == 1)
285     {
286         const float* shape_data = 0;
287         if (tp.has_raw_data())
288         {
289             shape_data = (const float*)tp.raw_data().data();
290             size = (int)(tp.raw_data().size() / 4);
291         }
292         else
293         {
294             shape_data = tp.float_data().data();
295             size = tp.float_data_size();
296         }
297         for (int j = 0; j < size; j++)
298         {
299             v.push_back(shape_data[j]);
300         }
301     }
302     // double
303     else if (tp.data_type() == 11)
304     {
305         const double* shape_data = 0;
306         if (tp.has_raw_data())
307         {
308             shape_data = (const double*)tp.raw_data().data();
309             size = (int)(tp.raw_data().size() / 8);
310         }
311         else
312         {
313             shape_data = tp.double_data().data();
314             size = tp.double_data_size();
315         }
316         for (int j = 0; j < size; j++)
317         {
318             v.push_back((float)shape_data[j]);
319         }
320     }
321     else
322     {
323         fprintf(stderr, "Unknown data type %d\n", tp.data_type());
324     }
325 
326     return v;
327 }
328 
get_tensor_proto_data_size(const onnx::TensorProto & tp)329 static int get_tensor_proto_data_size(const onnx::TensorProto& tp)
330 {
331     if (tp.has_raw_data())
332     {
333         const std::string& raw_data = tp.raw_data();
334         int size = (int)raw_data.size() / 4;
335         return size;
336     }
337     else if (tp.data_type() == 1)
338     {
339         return tp.float_data_size();
340     }
341 
342     return 0;
343 }
344 
fwrite_tensor_proto_data(const onnx::TensorProto & tp,FILE * bp)345 static void fwrite_tensor_proto_data(const onnx::TensorProto& tp, FILE* bp)
346 {
347     int size = get_tensor_proto_data_size(tp);
348 
349     if (tp.has_raw_data())
350     {
351         const std::string& raw_data = tp.raw_data();
352         fwrite(raw_data.data(), sizeof(float), size, bp);
353     }
354     else if (tp.data_type() == 1)
355     {
356         fwrite(tp.float_data().data(), sizeof(float), size, bp);
357     }
358 }
359 
// Fold "weight <= Reshape(weight)" patterns into the weight table.
// When a Reshape's input is itself a stored weight, the output blob is
// registered as a weight with the requested dims and the Reshape node is
// neutralized (renamed to noop_reducedncnn so later passes drop it).
// node_reference counts are decremented for the consumed inputs and
// reduced_node_count is incremented per removed node.
static void fuse_weight_reshape(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // weight <= Reshape(weight)
        if (node->op_type() == "Reshape")
        {
            // check weight
            if (weights.find(node->input(0)) == weights.end())
                continue;

            // alias the output blob to the same weight data
            weights[node->output(0)] = weights[node->input(0)];

            // set weight shape directly
            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                // pre-opset-5 style: target shape lives in the "shape" attribute
                shape = get_node_attr_ai(*node, "shape");
            }
            else if (node->input_size() == 2)
            {
                // opset 5
                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            weights[node->output(0)].clear_dims();
            for (int j = 0; j < shape.size(); j++)
            {
                weights[node->output(0)].add_dims(shape[j]);
            }

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            if (node->input_size() == 2)
            {
                // the shape tensor input is consumed too
                node_reference[node->input(1)] -= 1;
            }

            reduced_node_count += 1;
            i += 1; // skip past the node just consumed
        }
    }
}
408 
// Fold "weight <= Transpose(weight)" patterns for 2-D weights with
// perm = (1, 0): the weight matrix is transposed in place in the weight
// table and the Transpose node is neutralized (noop_reducedncnn).
// Only 2-D weights with an exact (1, 0) permutation are handled.
static void fuse_weight_transpose(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // weight <= Transpose(weight)
        if (node->op_type() == "Transpose")
        {
            // check weight
            if (weights.find(node->input(0)) == weights.end())
                continue;

            // only plain 2-D matrices are transposed here
            if (weights[node->input(0)].dims_size() != 2)
                continue;

            // perm = (1, 0)
            std::vector<int> perm = get_node_attr_ai(*node, "perm");
            if (perm.size() != 2)
                continue;
            if (perm[0] != 1 || perm[1] != 0)
                continue;

            // register the output blob as a copy of the weight
            weights[node->output(0)] = weights[node->input(0)];

            // permute weight
            {
                onnx::TensorProto& B = weights[node->output(0)];

                const int h = B.dims(0);
                const int w = B.dims(1);

                std::vector<float> permuted_data;
                permuted_data.reserve((size_t)h * w);
                // data may live in raw_data or in the typed float field
                const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();

                // column-major walk of the source = row-major of the transpose
                for (int j = 0; j < w; j++)
                {
                    for (int k = 0; k < h; k++)
                    {
                        float vb = bptr[k * w + j];
                        permuted_data.push_back(vb);
                    }
                }

                // swap the dims to reflect the transpose
                B.set_dims(0, w);
                B.set_dims(1, h);

                // write back through the same storage the data came from
                if (B.has_raw_data())
                {
                    B.set_raw_data(permuted_data.data(), permuted_data.size() * sizeof(float));
                }
                else
                {
                    for (int j = 0; j < (int)permuted_data.size(); j++)
                        B.set_float_data(j, permuted_data[j]);
                }
            }

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;

            reduced_node_count += 1;
            i += 1; // skip past the node just consumed
        }
    }
}
479 
// Recognize the Reshape - Transpose - Reshape idiom (with an optional
// interleaved Constant node) that implements channel shuffle, and replace
// the trio with a single ShuffleChannel node.
// Two layouts are matched:
//   forward: reshape to (1, groups, channels_per_group, h, w), perm 0 2 1 3 4
//   reverse: reshape to (channels_per_group, groups, h*w), perm 1 0 2
// The final node is rewritten in place to ShuffleChannel with "group" and
// "reverse" attributes; the first two nodes become noop_reducedncnn.
static void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // ShuffleChannel <= Reshape - Transpose - Reshape
        // ShuffleChannel <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            // the intermediate blob must feed exactly one consumer
            if (node_reference[node->output(0)] != 1)
                continue;

            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                // pre-opset-5 style: shape attribute
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end())
                    continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // 1 groups channels_per_group, height, width
            // reverse style = channels_per_group, groups, height * width
            if (shape.size() != 5 && shape.size() != 3)
                continue;

            if (shape.size() == 5 && shape[0] != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // tolerate a Constant node (e.g. the second reshape's shape) in between
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // 0 2 1 3 4
            // reverse style = 1 0 2
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 5 && perm.size() != 3)
                continue;

            if (perm.size() == 5 && (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3 || perm[4] != 4))
                continue;

            if (perm.size() == 3 && (perm[0] != 1 || perm[1] != 0 || perm[2] != 2))
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end())
                    continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // 1, -1, height, width
            // reverse style = group, -1, channels_per_group, height, width
            if (shape3.size() != 4 && shape3.size() != 5)
                continue;

            if (shape3.size() == 4 && (shape3[0] != 1 || (shape3[1] != -1 && shape3[1] != shape[1] * shape[2])))
                continue;

            if (shape3.size() == 5 && (shape3[0] != shape[1] || shape3[2] != shape[0] || shape3[3] * shape3[4] != shape[2]))
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            // intermediate blobs are gone
            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // rewrite the last Reshape into the fused op, fed by the original input
            node3->set_op_type("ShuffleChannel");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("group");
            attr_group->set_i(shape[1]);

            onnx::AttributeProto* attr_reverse = node3->add_attribute();
            attr_reverse->set_name("reverse");
            attr_reverse->set_i(shape.size() == 3);

            reduced_node_count += 2;
            i += 2; // skip the two consumed nodes
        }
    }
}
607 
// After a reverse-type ShuffleChannel, a pair of Gather(axis=0) nodes
// picking indices 0 and 1 implements a 2-way split. Collapse the two
// Gathers into a single Split node along axis 1; the first Gather becomes
// noop_reducedncnn and the second is rewritten in place.
static void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1)
        if (node->op_type() == "ShuffleChannel")
        {
            // reverse = 1
            int reverse = get_node_attr_i(*node, "reverse");
            if (reverse != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            if (node2->op_type() != "Gather" || node3->op_type() != "Gather")
                continue;

            // both Gathers must consume the ShuffleChannel output
            if (node2->input(0) != node->output(0) || node3->input(0) != node->output(0))
                continue;

            // axis = 0
            int gather2_axis = get_node_attr_i(*node2, "axis");
            if (gather2_axis != 0)
                continue;

            // indices = 0
            if (weights.find(node2->input(1)) == weights.end())
                continue;

            std::vector<int> gather2_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
            if (gather2_indices.size() != 1 || gather2_indices[0] != 0)
                continue;

            // axis = 0
            int gather3_axis = get_node_attr_i(*node3, "axis");
            if (gather3_axis != 0)
                continue;

            // indices = 1
            if (weights.find(node3->input(1)) == weights.end())
                continue;

            std::vector<int> gather3_indices = get_node_attr_from_input_ai(weights[node3->input(1)]);
            if (gather3_indices.size() != 1 || gather3_indices[0] != 1)
                continue;

            // reduce
            node2->set_op_type("noop_reducedncnn");

            // the shuffle output loses both Gather consumers; the Split re-adds one below
            node_reference[node->output(0)] -= 2;
            node_reference[node2->input(1)] -= 1;
            node_reference[node3->input(1)] -= 1;

            // rewrite node3 as a 2-output Split; output order keeps
            // node2's blob first so downstream consumers stay wired
            node3->set_op_type("Split");
            node3->clear_input();
            node3->add_input(node->output(0));
            node3->add_output(node3->output(0));
            node3->set_output(0, node2->output(0));

            node3->clear_attribute();
            onnx::AttributeProto* attr_axis = node3->add_attribute();
            attr_axis->set_name("axis");
            attr_axis->set_i(1);

            reduced_node_count += 1;
            i += 1; // skip past the consumed Gather
        }
    }
}
684 
// Recognize HardSwish patterns and collapse them to a single HardSwish
// node with alpha = 1/6 and beta = 0.5 (or the HardSigmoid's own
// alpha/beta in the second pass).
// Pass 1 matches x * relu6(x + 3) / 6 spelled as
//   Add(+3) - Clip(0,6) - Mul(X,) - Div(/6)  (or Mul(*(1/6)),
//   optionally with a Constant node before the final op).
// Pass 2 matches x * HardSigmoid(x).
static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Div(/6)
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Mul(*(1/6))
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Div(/6)
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Mul(*(1/6))
        //     out = x * F.relu6(x + 3, inplace=True) / 6
        if (node->op_type() == "Add")
        {
            // intermediate blob must have a single consumer
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 3 >= node_count)
                continue;

            // the addend must be a stored constant
            if (weights.find(node->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& add_three = weights[node->input(1)];
            // scalar check: no dims and exactly one element
            if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1)
                continue;

            float constant_add_three = get_node_attr_from_input_f(add_three);
            if (constant_add_three != 3.f)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);

            // tolerate a Constant node before the final Div/Mul
            if (node4->op_type() == "Constant")
            {
                if (i + 4 >= node_count)
                    continue;

                node4 = mutable_graph->mutable_node(i + 4);
            }

            if (node2->op_type() != "Clip" || node3->op_type() != "Mul" || (node4->op_type() != "Div" && node4->op_type() != "Mul"))
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // the Clip must be a relu6, i.e. min 0 / max 6
            float relu6_min;
            float relu6_max;
            if (node2->input_size() == 1)
            {
                // pre-opset-11 style: bounds as attributes
                relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
                relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
            }
            else
            {
                // opset 11+: bounds as constant inputs
                const onnx::TensorProto& min_tp = weights[node2->input(1)];
                const onnx::TensorProto& max_tp = weights[node2->input(2)];

                relu6_min = get_node_attr_from_input_f(min_tp);
                relu6_max = get_node_attr_from_input_f(max_tp);
            }
            if (relu6_min != 0.f || relu6_max != 6.f)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            // the Mul must combine the original x with the clipped branch
            if (node3->input(0) != node->input(0) || node3->input(1) != node2->output(0))
                continue;

            if (weights.find(node4->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& div_six = weights[node4->input(1)];
            if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1)
                continue;

            // Div by 6 and Mul by 1/6 are equivalent spellings
            float constant_div_six = get_node_attr_from_input_f(div_six);
            if (node4->op_type() == "Div" && constant_div_six != 6.f)
                continue;
            if (node4->op_type() == "Mul" && constant_div_six != 1 / 6.f)
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->input(1)] -= 1;
            node_reference[node->output(0)] -= 1;
            if (node2->input_size() == 3)
            {
                node_reference[node2->input(1)] -= 1;
                node_reference[node2->input(2)] -= 1;
            }
            node_reference[node2->output(0)] -= 1;
            node_reference[node3->output(0)] -= 1;
            node_reference[node4->input(1)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));

            // rewrite the final op as the fused HardSwish on the original x
            node4->set_op_type("HardSwish");
            node4->clear_input();
            node4->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node4->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(1.f / 6.f);

            onnx::AttributeProto* attr_beta = node4->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(3.f / 6.f);

            reduced_node_count += 3;
            i += 3; // skip the three consumed nodes
        }
    }

    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSwish <= HardSigmoid - Mul
        //     out = x * hsigmoid(x)
        if (node->op_type() == "HardSigmoid")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            // keep the sigmoid's own coefficients (onnx defaults 0.2 / 0.5)
            float alpha = get_node_attr_f(*node, "alpha", 0.2f);
            float beta = get_node_attr_f(*node, "beta", 0.5f);

            if (i + 1 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);

            if (node2->op_type() != "Mul")
                continue;

            // the Mul must combine the original x with the sigmoid branch
            if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0))
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->output(0)] -= 1;

            blob_names.erase(node->output(0));

            node2->set_op_type("HardSwish");
            node2->clear_input();
            node2->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node2->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(alpha);

            onnx::AttributeProto* attr_beta = node2->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(beta);

            reduced_node_count += 1;
            i += 1; // skip past the consumed node
        }
    }
}
859 
// Fuse the Add(+3) - Clip(0,6) - Div(/6) (or Mul(*1/6)) chain produced by
// F.relu6(x + 3) / 6 into a single HardSigmoid node with alpha=1/6, beta=0.5.
// An interleaved Constant node between Clip and Div/Mul is tolerated.
// weights            : graph initializers, used to read the scalar constants
// node_reference     : per-blob consumer counts, decremented for every edge
//                      the fused pattern stops consuming
// blob_names         : set of live blob names; fused intermediates are erased
// reduced_node_count : accumulates the number of eliminated nodes
static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSigmoid <= Add(+3) - Clip(0,6) - Div(/6)
        // HardSigmoid <= Add(+3) - Clip(0,6) - Mul(*(1/6))
        // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Div(/6)
        // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Mul(*(1/6))
        //     out = F.relu6(x + 3, inplace=True) / 6
        if (node->op_type() == "Add")
        {
            // the Add output must feed exactly one consumer (the Clip)
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            // the addend must be a scalar initializer equal to 3
            if (weights.find(node->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& add_three = weights[node->input(1)];
            if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1)
                continue;

            float constant_add_three = get_node_attr_from_input_f(add_three);
            if (constant_add_three != 3.f)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // skip an interleaved Constant and look one node further for Div/Mul
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Clip" || (node3->op_type() != "Div" && node3->op_type() != "Mul"))
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // Clip bounds: either attributes (opset < 11) or extra inputs (opset >= 11)
            float relu6_min;
            float relu6_max;
            if (node2->input_size() == 1)
            {
                relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
                relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
            }
            else
            {
                const onnx::TensorProto& min_tp = weights[node2->input(1)];
                const onnx::TensorProto& max_tp = weights[node2->input(2)];

                relu6_min = get_node_attr_from_input_f(min_tp);
                relu6_max = get_node_attr_from_input_f(max_tp);
            }
            // must be an exact relu6
            if (relu6_min != 0.f || relu6_max != 6.f)
                continue;

            // divisor (or reciprocal multiplier) must be a scalar initializer
            if (weights.find(node3->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& div_six = weights[node3->input(1)];
            if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1)
                continue;

            float constant_div_six = get_node_attr_from_input_f(div_six);
            if (node3->op_type() == "Div" && constant_div_six != 6.f)
                continue;
            if (node3->op_type() == "Mul" && constant_div_six != 1 / 6.f)
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            node_reference[node->input(1)] -= 1;
            node_reference[node->output(0)] -= 1;
            if (node2->input_size() == 3)
            {
                node_reference[node2->input(1)] -= 1;
                node_reference[node2->input(2)] -= 1;
            }
            node_reference[node2->output(0)] -= 1;
            node_reference[node3->input(1)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // repurpose the Div/Mul node as the fused HardSigmoid
            node3->set_op_type("HardSigmoid");
            node3->clear_input();
            node3->add_input(node->input(0));

            // y = clamp(alpha * x + beta, 0, 1) with alpha=1/6, beta=3/6
            onnx::AttributeProto* attr_alpha = node3->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(1.f / 6.f);

            onnx::AttributeProto* attr_beta = node3->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(3.f / 6.f);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
973 
fuse_swish(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)974 static void fuse_swish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
975 {
976     int node_count = mutable_graph->node_size();
977     for (int i = 0; i < node_count; i++)
978     {
979         onnx::NodeProto* node = mutable_graph->mutable_node(i);
980 
981         // Swish <= Sigmoid - Mul
982         //     x * torch.sigmoid(x)
983         if (node->op_type() == "Sigmoid")
984         {
985             if (node_reference[node->output(0)] != 1)
986                 continue;
987 
988             if (i + 1 >= node_count)
989                 continue;
990 
991             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
992 
993             if (node2->op_type() != "Mul")
994                 continue;
995 
996             if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0))
997                 continue;
998 
999             // reduce
1000             node->set_op_type("noop_reducedncnn");
1001 
1002             node_reference[node->input(0)] -= 1;
1003             node_reference[node->output(0)] -= 1;
1004 
1005             blob_names.erase(node->output(0));
1006 
1007             node2->set_op_type("Swish");
1008             node2->clear_input();
1009             node2->add_input(node->input(0));
1010 
1011             reduced_node_count += 1;
1012             i += 1;
1013         }
1014     }
1015 }
1016 
fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1017 static void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1018 {
1019     int node_count = mutable_graph->node_size();
1020     for (int i = 0; i < node_count; i++)
1021     {
1022         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1023 
1024         // BatchNormalization <= Unsqueeze - BatchNormalization - Squeeze
1025         if (node->op_type() == "Unsqueeze")
1026         {
1027             if (node_reference[node->output(0)] != 1)
1028                 continue;
1029 
1030             if (i + 2 >= node_count)
1031                 continue;
1032 
1033             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1034             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1035 
1036             if (node2->op_type() != "BatchNormalization" || node3->op_type() != "Squeeze")
1037                 continue;
1038 
1039             if (node_reference[node2->output(0)] != 1)
1040                 continue;
1041 
1042             if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0))
1043                 continue;
1044 
1045             // reduce
1046             node->set_op_type("noop_reducedncnn");
1047             node3->set_op_type("noop_reducedncnn");
1048 
1049             node_reference[node->output(0)] -= 1;
1050             node_reference[node2->output(0)] -= 1;
1051 
1052             blob_names.erase(node->output(0));
1053             blob_names.erase(node2->output(0));
1054 
1055             node2->set_input(0, node->input(0));
1056             node2->set_output(0, node3->output(0));
1057 
1058             reduced_node_count += 2;
1059             i += 2;
1060         }
1061     }
1062 }
1063 
fuse_unsqueeze_prelu(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1064 static void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1065 {
1066     int node_count = mutable_graph->node_size();
1067     for (int i = 0; i < node_count; i++)
1068     {
1069         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1070 
1071         // PReLU <= Unsqueeze - PReLU
1072         if (node->op_type() == "Unsqueeze")
1073         {
1074             // check weight
1075             if (weights.find(node->input(0)) == weights.end())
1076                 continue;
1077 
1078             onnx::TensorProto& B = weights[node->input(0)];
1079             if (B.dims_size() != 1)
1080                 continue;
1081 
1082             if (node_reference[node->output(0)] != 1)
1083                 continue;
1084 
1085             // axes = (1, 2)
1086             std::vector<int> axes = get_node_attr_ai(*node, "axes");
1087             if (axes.size() != 2)
1088                 continue;
1089             if (axes[0] != 1 || axes[1] != 2)
1090                 continue;
1091 
1092             if (i + 1 >= node_count)
1093                 continue;
1094 
1095             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1096 
1097             if (node2->op_type() != "PRelu")
1098                 continue;
1099 
1100             if (node2->input(1) != node->output(0))
1101                 continue;
1102 
1103             // reduce
1104             node->set_op_type("noop_reducedncnn");
1105 
1106             node_reference[node->output(0)] -= 1;
1107 
1108             blob_names.erase(node->output(0));
1109 
1110             node2->set_input(1, node->input(0));
1111 
1112             reduced_node_count += 1;
1113             i += 1;
1114         }
1115     }
1116 }
1117 
// Fuse the L2-normalize idiom x / expand(clip(reduceL2(x, axis=1), min=eps))
// into a single Normalize node carrying the eps attribute. The Expand may
// optionally take its target shape from an explicit Shape(x) node.
static void fuse_normalize(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Normalize <= X - ReduceL2 - Clip - Expand - Div
        // Normalize <= X - ReduceL2 - Clip - Shape - Expand - Div
        if (node->op_type() == "ReduceL2")
        {
            // the norm blob must have a single consumer
            if (node_reference[node->output(0)] != 1)
                continue;

            // axes = (1)
            std::vector<int> axes = get_node_attr_ai(*node, "axes");
            if (axes.size() != 1)
                continue;
            if (axes[0] != 1)
                continue;

            if (i + 3 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);

            // with a Shape node the Expand/Div shift one slot further down
            bool has_shape_node = node3->op_type() == "Shape";
            onnx::NodeProto* node_shape = 0;
            if (has_shape_node)
            {
                if (i + 4 >= node_count)
                    continue;

                node_shape = node3;
                node3 = mutable_graph->mutable_node(i + 3);
                node4 = mutable_graph->mutable_node(i + 4);
            }

            if (node2->op_type() != "Clip" || node3->op_type() != "Expand" || node4->op_type() != "Div")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            // chain check: Clip eats the norm, Expand eats the clipped norm,
            // Div divides the original input by the expanded norm
            if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)
                    || node4->input(0) != node->input(0) || node4->input(1) != node3->output(0))
                continue;

            if (has_shape_node)
            {
                // Shape must read the original input and feed Expand's shape input
                if (node_shape->input(0) != node->input(0) || node3->input(1) != node_shape->output(0))
                    continue;
            }

            // +eps
            // Clip lower bound = eps: attribute (opset < 11) or input (opset >= 11)
            float clip_min;
            if (node2->input_size() == 1)
            {
                clip_min = get_node_attr_f(*node2, "min", -FLT_MAX);
            }
            else
            {
                const onnx::TensorProto& min_tp = weights[node2->input(1)];

                clip_min = get_node_attr_from_input_f(min_tp);
            }

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            if (has_shape_node)
            {
                node_shape->set_op_type("noop_reducedncnn");
            }
            node3->set_op_type("noop_reducedncnn");

            // X is consumed twice when the Shape node exists (Shape + Div)
            node_reference[node->input(0)] -= has_shape_node ? 2 : 1;
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (has_shape_node)
            {
                node_reference[node_shape->output(0)] -= 1;
            }
            node_reference[node3->output(0)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            if (has_shape_node)
            {
                blob_names.erase(node_shape->output(0));
            }
            blob_names.erase(node3->output(0));

            // repurpose the Div node as the fused Normalize
            node4->set_op_type("Normalize");
            node4->clear_input();
            node4->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node4->add_attribute();
            attr_alpha->set_name("eps");
            attr_alpha->set_f(clip_min);

            reduced_node_count += has_shape_node ? 4 : 3;
            i += has_shape_node ? 4 : 3;
        }
    }
}
1229 
fuse_groupnorm(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1230 static void fuse_groupnorm(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1231 {
1232     int node_count = mutable_graph->node_size();
1233     for (int i = 0; i < node_count; i++)
1234     {
1235         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1236 
1237         // GroupNorm <= X - Reshape - InstanceNormalization - Reshape - Mul - Add
1238         if (node->op_type() == "Reshape")
1239         {
1240             if (node_reference[node->output(0)] != 1)
1241                 continue;
1242 
1243             std::vector<int> shape;
1244             if (node->input_size() == 1)
1245             {
1246                 shape = get_node_attr_ai(*node, "shape");
1247             }
1248             else
1249             {
1250                 // skip weight reshape
1251                 if (weights.find(node->input(1)) == weights.end())
1252                     continue;
1253 
1254                 shape = get_node_attr_from_input_ai(weights[node->input(1)]);
1255             }
1256 
1257             // 0, group, -1
1258             if (shape.size() != 3)
1259                 continue;
1260 
1261             if (shape[0] != 0 || shape[2] != -1)
1262                 continue;
1263 
1264             int groups = shape[1];
1265 
1266             if (i + 4 >= node_count)
1267                 continue;
1268 
1269             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1270             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1271             onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
1272             onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
1273 
1274             if (node2->op_type() != "InstanceNormalization" || node3->op_type() != "Reshape" || node4->op_type() != "Mul" || node5->op_type() != "Add")
1275                 continue;
1276 
1277             if (node_reference[node2->output(0)] != 1)
1278                 continue;
1279 
1280             if (node_reference[node3->output(0)] != 1)
1281                 continue;
1282 
1283             if (node_reference[node4->output(0)] != 1)
1284                 continue;
1285 
1286             if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)
1287                     || node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0))
1288                 continue;
1289 
1290             // +eps
1291             float eps = get_node_attr_f(*node2, "epsilon", 1e-05f);
1292 
1293             // InstanceNormalization S=1 B=0
1294             std::vector<float> S = get_node_attr_from_input_af(weights[node2->input(1)]);
1295             std::vector<float> B = get_node_attr_from_input_af(weights[node2->input(2)]);
1296             if ((int)S.size() != groups || (int)B.size() != groups)
1297                 continue;
1298 
1299             bool instancenorm_affine = false;
1300             for (int j = 0; j < groups; j++)
1301             {
1302                 if (S[j] != 1.f || B[j] != 0.f)
1303                 {
1304                     instancenorm_affine = true;
1305                     break;
1306                 }
1307             }
1308 
1309             if (instancenorm_affine)
1310                 continue;
1311 
1312             std::vector<int> shape2;
1313             if (node3->input_size() == 1)
1314             {
1315                 shape2 = get_node_attr_ai(*node3, "shape");
1316             }
1317             else
1318             {
1319                 // skip weight reshape
1320                 if (weights.find(node3->input(1)) == weights.end())
1321                     continue;
1322 
1323                 shape2 = get_node_attr_from_input_ai(weights[node3->input(1)]);
1324             }
1325 
1326             // 1, channels, w, h
1327             if (shape2.size() != 4)
1328                 continue;
1329 
1330             if (shape2[0] != 1)
1331                 continue;
1332 
1333             int channels = shape2[1];
1334 
1335             // affine
1336             int affine = 0;
1337             std::vector<float> affine_S = get_node_attr_from_input_af(weights[node4->input(1)]);
1338             std::vector<float> affine_B = get_node_attr_from_input_af(weights[node5->input(1)]);
1339             if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && affine_B[0] == 0.f)
1340             {
1341                 affine = 0;
1342             }
1343             else if ((int)affine_S.size() != channels && (int)affine_B.size() != channels)
1344             {
1345                 // we only allow per-channel affine
1346                 continue;
1347             }
1348 
1349             for (int j = 0; j < channels; j++)
1350             {
1351                 if (affine_S[j] != 1.f || affine_B[j] != 0.f)
1352                 {
1353                     affine = 1;
1354                     break;
1355                 }
1356             }
1357 
1358             // reduce
1359             node->set_op_type("noop_reducedncnn");
1360             node2->set_op_type("noop_reducedncnn");
1361             node3->set_op_type("noop_reducedncnn");
1362             node4->set_op_type("noop_reducedncnn");
1363 
1364             if (node->input_size() == 2)
1365             {
1366                 node_reference[node->input(1)] -= 1;
1367             }
1368             node_reference[node->output(0)] -= 1;
1369             node_reference[node2->input(1)] -= 1;
1370             node_reference[node2->input(2)] -= 1;
1371             node_reference[node2->output(0)] -= 1;
1372             if (node3->input_size() == 2)
1373             {
1374                 node_reference[node3->input(1)] -= 1;
1375             }
1376             node_reference[node3->output(0)] -= 1;
1377             node_reference[node4->output(0)] -= 1;
1378 
1379             std::string affine_scale = node4->input(1);
1380             std::string affine_bias = node5->input(1);
1381 
1382             node_reference[affine_scale] -= 1;
1383             node_reference[affine_bias] -= 1;
1384 
1385             blob_names.erase(node->output(0));
1386             blob_names.erase(node2->output(0));
1387             blob_names.erase(node3->output(0));
1388             blob_names.erase(node4->output(0));
1389 
1390             node5->set_op_type("GroupNorm");
1391             node5->clear_input();
1392             node5->add_input(node->input(0));
1393             if (affine)
1394             {
1395                 node5->add_input(affine_scale);
1396                 node5->add_input(affine_bias);
1397             }
1398 
1399             onnx::AttributeProto* attr_groups = node5->add_attribute();
1400             attr_groups->set_name("groups");
1401             attr_groups->set_i(groups);
1402 
1403             onnx::AttributeProto* attr_channels = node5->add_attribute();
1404             attr_channels->set_name("channels");
1405             attr_channels->set_i(channels);
1406 
1407             onnx::AttributeProto* attr_eps = node5->add_attribute();
1408             attr_eps->set_name("epsilon");
1409             attr_eps->set_f(eps);
1410 
1411             onnx::AttributeProto* attr_affine = node5->add_attribute();
1412             attr_affine->set_name("affine");
1413             attr_affine->set_i(affine);
1414 
1415             reduced_node_count += 4;
1416             i += 4;
1417         }
1418     }
1419 }
1420 
// Fuse the dynamic-batch flatten idiom
//   Reshape(x, concat(unsqueeze(gather(shape(x), 0)), unsqueeze(-1)))
// i.e. x.view(x.size(0), -1), into a single Flatten node.
static void fuse_flatten(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Flatten <= X - Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat - Reshape
        if (node->op_type() == "Shape")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            // six following nodes are required
            if (i + 6 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);

            if (node2->op_type() != "Gather" || node3->op_type() != "Constant" || node4->op_type() != "Unsqueeze" || node5->op_type() != "Unsqueeze"
                    || node6->op_type() != "Concat" || node7->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // the Constant (node3) may legitimately have extra consumers, so
            // its reference count is deliberately not checked
            //             if (node_reference[node3->output(0)] != 1)
            //                 continue;

            if (node_reference[node4->output(0)] != 1)
                continue;

            if (node_reference[node5->output(0)] != 1)
                continue;

            if (node_reference[node6->output(0)] != 1)
                continue;

            // wiring check: Gather reads the shape, the two Unsqueezes feed
            // Concat, and Reshape consumes X plus the concatenated shape
            if (node2->input(0) != node->output(0) || node4->input(0) != node2->output(0) || node5->input(0) != node3->output(0)
                    || node6->input(0) != node4->output(0) || node6->input(1) != node5->output(0)
                    || node7->input(0) != node->input(0) || node7->input(1) != node6->output(0))
                continue;

            // axis = 0
            int gather_axis = get_node_attr_i(*node2, "axis");
            if (gather_axis != 0)
                continue;

            // indices = 0, i.e. gather picks x.size(0)
            if (weights.find(node2->input(1)) == weights.end())
                continue;

            std::vector<int> gather_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
            if (gather_indices.size() != 1 || gather_indices[0] != 0)
                continue;

            // axes = (0)
            std::vector<int> unsqueeze_axes = get_node_attr_ai(*node4, "axes");
            if (unsqueeze_axes.size() != 1)
                continue;
            if (unsqueeze_axes[0] != 0)
                continue;

            // axes = (0)
            std::vector<int> unsqueeze2_axes = get_node_attr_ai(*node5, "axes");
            if (unsqueeze2_axes.size() != 1)
                continue;
            if (unsqueeze2_axes[0] != 0)
                continue;

            // data = -1, the flattened trailing dimension
            if (weights.find(node5->input(0)) == weights.end())
                continue;

            std::vector<int> unsqueeze2_data = get_node_attr_from_input_ai(weights[node5->input(0)]);
            if (unsqueeze2_data.size() != 1 || unsqueeze2_data[0] != -1)
                continue;

            // axis = 0
            int concat_axis = get_node_attr_i(*node6, "axis");
            if (concat_axis != 0)
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            // the Constant is left alone for the same reason as above
            //             node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->output(0)] -= 1;
            node_reference[node2->input(1)] -= 1;
            node_reference[node2->output(0)] -= 1;
            //             node_reference[node3->output(0)] -= 1;
            node_reference[node4->output(0)] -= 1;
            node_reference[node5->input(0)] -= 1;
            node_reference[node5->output(0)] -= 1;
            node_reference[node6->output(0)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            //             blob_names.erase(node3->output(0));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));

            // repurpose the Reshape node as the fused Flatten
            node7->set_op_type("Flatten");
            node7->clear_input();
            node7->add_input(node->input(0));

            reduced_node_count += 5;
            i += 5;
        }
    }
}
1542 
// Fuse Reshape(-1,C,r,r,H,W) - Transpose(0,1,4,2,5,3) - Reshape(-1,C,rH,rW)
// into a single PixelShuffle node with scale_factor r. An interleaved
// Constant node before the second Reshape is tolerated.
static void fuse_pixelshuffle(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // PixelShuffle <= Reshape - Transpose - Reshape
        // PixelShuffle <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            // target shape: attribute (opset < 5) or initializer input
            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end())
                    continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // -1, 3, upscale_factor, upscale_factor, height, width
            if (shape.size() != 6)
                continue;

            if (shape[0] != 1 && shape[0] != -1)
                continue;

            // the two upscale factors must be equal
            if (shape[2] != shape[3])
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // skip an interleaved Constant and look one node further
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // 0 1 4 2 5 3
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 6)
                continue;

            if (perm[0] != 0 || perm[1] != 1 || perm[2] != 4 || perm[3] != 2 || perm[4] != 5 || perm[5] != 3)
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end())
                    continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // -1, 3, height, width
            if (shape3.size() != 4)
                continue;

            if (shape3[0] != 1 && shape3[0] != -1)
                continue;

            // output dims must match channels and upscaled spatial extents
            if (shape3[1] != shape[1] || shape3[2] != shape[2] * shape[4] || shape3[3] != shape[3] * shape[5])
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // repurpose the final Reshape as the fused PixelShuffle
            node3->set_op_type("PixelShuffle");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("scale_factor");
            attr_group->set_i(shape[2]);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
1663 
// Collapse the exported space-to-depth idiom into a single ncnn Reorg layer.
//
// Pattern (nodes must be consecutive in graph order):
//   Reshape(6d) -> Transpose(perm 0,1,3,5,2,4) -> Reshape(4d)
// optionally with a Constant node (the second Reshape's shape input) sitting
// between the Transpose and the final Reshape.
//
// On a match, the two leading nodes are neutralized ("noop_reducedncnn"), the
// final Reshape is rewritten in place into a "Reorg" node with a "stride"
// attribute, and node_reference / blob_names are updated for the two blobs
// that disappear.  reduced_node_count is incremented by 2 per fusion.
static void fuse_reorg(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Reorg <= Reshape - Transpose - Reshape
        // Reorg <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            // the intermediate blob must have exactly one consumer,
            // otherwise removing it would break another node
            if (node_reference[node->output(0)] != 1)
                continue;

            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                // old opset: target shape is an attribute
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end())
                    continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // -1, in_channels, out_height, block_size, out_width, block_size
            if (shape.size() != 6)
                continue;

            if (shape[0] != 1 && shape[0] != -1)
                continue;

            // both block_size factors must agree (square stride)
            if (shape[3] != shape[5])
                continue;

            if (i + 2 >= node_count)
                continue;

            // NOTE(review): node2/node3 are taken purely by graph position;
            // the input/output chaining between them and node is not verified
            // here — the single-reference checks above/below only partially
            // guard this.  Confirm topological sort guarantees adjacency.
            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            if (node3->op_type() == "Constant")
            {
                // tolerate the Constant carrying the second Reshape's shape
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // 0 1 3 5 2 4
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 6)
                continue;

            if (perm[0] != 0 || perm[1] != 1 || perm[2] != 3 || perm[3] != 5 || perm[4] != 2 || perm[5] != 4)
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end())
                    continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // -1, out_channels, out_height, out_width
            if (shape3.size() != 4)
                continue;

            if (shape3[0] != 1 && shape3[0] != -1)
                continue;

            // the final shape must be exactly what a Reorg with this stride
            // produces: channels multiplied by stride^2, spatial dims intact
            if (shape3[1] != shape[1] * shape[3] * shape[5] || shape3[2] != shape[2] || shape3[3] != shape[4])
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            // drop references to the shape weights and the two vanished blobs
            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // rewrite the last Reshape into the fused Reorg, fed by the
            // original input blob
            node3->set_op_type("Reorg");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("stride");
            attr_group->set_i(shape[3]);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
1784 
fuse_expand_broadcast(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1785 static void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1786 {
1787     int node_count = mutable_graph->node_size();
1788     for (int i = 0; i < node_count; i++)
1789     {
1790         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1791 
1792         // Add/Sub/Mul/Div/Min/Max <= Expand - Add/Sub/Mul/Div/Min/Max
1793         if (node->op_type() == "Expand")
1794         {
1795             if (node_reference[node->output(0)] != 1)
1796                 continue;
1797 
1798             if (i + 1 >= node_count)
1799                 continue;
1800 
1801             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1802 
1803             if (node2->op_type() != "Add" && node2->op_type() != "Sub" && node2->op_type() != "Mul" && node2->op_type() != "Div" && node2->op_type() != "Min" && node2->op_type() != "Max")
1804                 continue;
1805 
1806             if (node2->input(1) != node->output(0) && node2->input(0) != node->output(0))
1807                 continue;
1808 
1809             // reduce
1810             node->set_op_type("noop_reducedncnn");
1811 
1812             node_reference[node->output(0)] -= 1;
1813             if (node->input_size() == 2)
1814             {
1815                 node_reference[node->input(1)] -= 1;
1816             }
1817 
1818             blob_names.erase(node->output(0));
1819 
1820             node2->set_input(1, node->input(0));
1821 
1822             reduced_node_count += 1;
1823             i += 1;
1824         }
1825     }
1826 }
1827 
fuse_lstm_gru_rnn(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1828 static void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1829 {
1830     int node_count = mutable_graph->node_size();
1831     for (int i = 0; i < node_count; i++)
1832     {
1833         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1834 
1835         // LSTM(bi) <= LSTM(bi) - Transpose - Reshape - Transpose
1836         if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
1837         {
1838             if (node_reference[node->output(0)] != 1)
1839                 continue;
1840 
1841             if (i + 2 >= node_count)
1842                 continue;
1843 
1844             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1845             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1846 
1847             if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
1848                 continue;
1849 
1850             if (node_reference[node2->output(0)] != 1)
1851                 continue;
1852 
1853             if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0))
1854                 continue;
1855 
1856             std::string direction = get_node_attr_s(*node, "direction");
1857             if (direction != "bidirectional")
1858                 continue;
1859 
1860             // 0 2 1 3
1861             std::vector<int> perm = get_node_attr_ai(*node2, "perm");
1862             if (perm.size() != 4)
1863                 continue;
1864 
1865             if (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3)
1866                 continue;
1867 
1868             std::vector<int> shape;
1869             if (node3->input_size() == 1)
1870             {
1871                 shape = get_node_attr_ai(*node3, "shape");
1872             }
1873             else
1874             {
1875                 // skip weight reshape
1876                 if (weights.find(node3->input(1)) == weights.end())
1877                     continue;
1878 
1879                 shape = get_node_attr_from_input_ai(weights[node3->input(1)]);
1880             }
1881 
1882             // 0 0 -1
1883             if (shape.size() != 3)
1884                 continue;
1885 
1886             if (shape[0] != 0 || shape[1] != 0 || shape[2] != -1)
1887                 continue;
1888 
1889             // reduce
1890             node2->set_op_type("noop_reducedncnn");
1891             node3->set_op_type("noop_reducedncnn");
1892 
1893             node_reference[node->output(0)] -= 1;
1894             node_reference[node2->output(0)] -= 1;
1895             if (node3->input_size() == 2)
1896             {
1897                 node_reference[node3->input(1)] -= 1;
1898             }
1899 
1900             blob_names.erase(node->output(0));
1901             if (node->output_size() > 1)
1902             {
1903                 for (int j = 1; j < node->output_size(); j++)
1904                 {
1905                     blob_names.erase(node->output(j));
1906                 }
1907             }
1908             blob_names.erase(node2->output(0));
1909 
1910             node->clear_output();
1911             node->add_output(node3->output(0));
1912 
1913             reduced_node_count += 2;
1914             i += 2;
1915 
1916             if (i + 1 < node_count)
1917             {
1918                 if (node_reference[node3->output(0)] != 1)
1919                     continue;
1920 
1921                 onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 1);
1922 
1923                 if (node4->op_type() != "Transpose")
1924                     continue;
1925 
1926                 if (node4->input(0) != node->output(0))
1927                     continue;
1928 
1929                 // 1 0 2
1930                 std::vector<int> perm4 = get_node_attr_ai(*node4, "perm");
1931                 if (perm4.size() != 3)
1932                     continue;
1933 
1934                 if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2)
1935                     continue;
1936 
1937                 // reduce
1938                 node4->set_op_type("noop_reducedncnn");
1939 
1940                 node_reference[node->output(0)] -= 1;
1941 
1942                 blob_names.erase(node->output(0));
1943 
1944                 node->clear_output();
1945                 node->add_output(node4->output(0));
1946 
1947                 reduced_node_count += 1;
1948                 i += 1;
1949             }
1950         }
1951     }
1952 
1953     for (int i = 0; i < node_count; i++)
1954     {
1955         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1956 
1957         // LSTM(uni) <= LSTM(uni) - Squeeze - Transpose
1958         if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
1959         {
1960             if (node_reference[node->output(0)] != 1)
1961                 continue;
1962 
1963             if (i + 1 >= node_count)
1964                 continue;
1965 
1966             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1967 
1968             if (node2->op_type() != "Squeeze")
1969                 continue;
1970 
1971             if (node2->input(0) != node->output(0))
1972                 continue;
1973 
1974             std::string direction = get_node_attr_s(*node, "direction");
1975             if (direction == "bidirectional")
1976                 continue;
1977 
1978             // 1
1979             std::vector<int> axes = get_node_attr_ai(*node2, "axes");
1980             if (axes.size() != 1)
1981                 continue;
1982 
1983             if (axes[0] != 1)
1984                 continue;
1985 
1986             // reduce
1987             node2->set_op_type("noop_reducedncnn");
1988 
1989             node_reference[node->output(0)] -= 1;
1990 
1991             blob_names.erase(node->output(0));
1992             if (node->output_size() > 1)
1993             {
1994                 for (int j = 1; j < node->output_size(); j++)
1995                 {
1996                     blob_names.erase(node->output(j));
1997                 }
1998             }
1999 
2000             node->clear_output();
2001             node->add_output(node2->output(0));
2002 
2003             reduced_node_count += 1;
2004             i += 1;
2005 
2006             if (i + 1 < node_count)
2007             {
2008                 if (node_reference[node2->output(0)] != 1)
2009                     continue;
2010 
2011                 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 1);
2012 
2013                 if (node3->op_type() != "Transpose")
2014                     continue;
2015 
2016                 if (node3->input(0) != node->output(0))
2017                     continue;
2018 
2019                 // 1 0 2
2020                 std::vector<int> perm4 = get_node_attr_ai(*node3, "perm");
2021                 if (perm4.size() != 3)
2022                     continue;
2023 
2024                 if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2)
2025                     continue;
2026 
2027                 // reduce
2028                 node3->set_op_type("noop_reducedncnn");
2029 
2030                 node_reference[node->output(0)] -= 1;
2031 
2032                 blob_names.erase(node->output(0));
2033 
2034                 node->clear_output();
2035                 node->add_output(node3->output(0));
2036 
2037                 reduced_node_count += 1;
2038                 i += 1;
2039             }
2040         }
2041     }
2042 
2043     for (int i = 0; i < node_count; i++)
2044     {
2045         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2046 
2047         // LSTM <= Transpose - LSTM
2048         if (node->op_type() == "Transpose")
2049         {
2050             if (node_reference[node->output(0)] != 1)
2051                 continue;
2052 
2053             // 1 0 2
2054             std::vector<int> perm = get_node_attr_ai(*node, "perm");
2055             if (perm.size() != 3)
2056                 continue;
2057 
2058             if (perm[0] != 1 || perm[1] != 0 || perm[2] != 2)
2059                 continue;
2060 
2061             if (i + 1 >= node_count)
2062                 continue;
2063 
2064             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2065 
2066             if (node2->op_type() != "LSTM" && node->op_type() != "GRU" && node->op_type() != "RNN")
2067                 continue;
2068 
2069             if (node2->input(0) != node->output(0))
2070                 continue;
2071 
2072             // reduce
2073             node->set_op_type("noop_reducedncnn");
2074 
2075             node_reference[node->output(0)] -= 1;
2076 
2077             blob_names.erase(node->output(0));
2078 
2079             node2->set_input(0, node->input(0));
2080 
2081             reduced_node_count += 1;
2082             i += 1;
2083         }
2084     }
2085 }
2086 
main(int argc,char ** argv)2087 int main(int argc, char** argv)
2088 {
2089     const char* onnxpb = argv[1];
2090     const char* ncnn_prototxt = argc >= 4 ? argv[2] : "ncnn.param";
2091     const char* ncnn_modelbin = argc >= 4 ? argv[3] : "ncnn.bin";
2092 
2093     onnx::ModelProto model;
2094 
2095     // load
2096     bool s1 = read_proto_from_binary(onnxpb, &model);
2097     if (!s1)
2098     {
2099         fprintf(stderr, "read_proto_from_binary failed\n");
2100         return -1;
2101     }
2102 
2103     FILE* pp = fopen(ncnn_prototxt, "wb");
2104     FILE* bp = fopen(ncnn_modelbin, "wb");
2105 
2106     // magic
2107     fprintf(pp, "7767517\n");
2108 
2109     const onnx::GraphProto& graph = model.graph();
2110     onnx::GraphProto* mutable_graph = model.mutable_graph();
2111 
2112     int node_count = graph.node_size();
2113 
2114     // node reference
2115     std::map<std::string, int> node_reference;
2116 
2117     // weight node and weight reshape node
2118     std::map<std::string, onnx::TensorProto> weights;
2119 
2120     for (int j = 0; j < graph.initializer_size(); j++)
2121     {
2122         const onnx::TensorProto& initializer = graph.initializer(j);
2123 
2124         //         fprintf(stderr, "weight = %s %d\n", initializer.name().c_str(), initializer.data_type());
2125 
2126         weights[initializer.name()] = initializer;
2127     }
2128 
2129     // topological sort
2130     {
2131         // name -> producer node index
2132         std::set<std::string> producers;
2133         for (int j = 0; j < graph.input_size(); j++)
2134         {
2135             const std::string& input_name = graph.input(j).name();
2136             producers.insert(input_name);
2137         }
2138 
2139         for (int i = 0; i < node_count;)
2140         {
2141             onnx::NodeProto* node = mutable_graph->mutable_node(i);
2142 
2143             bool swapnode = false;
2144             std::string missing_input_name;
2145             for (int j = 0; j < (int)node->input_size(); j++)
2146             {
2147                 const std::string& input_name = node->input(j);
2148                 if (input_name.empty())
2149                     continue;
2150 
2151                 if (producers.find(input_name) == producers.end() && weights.find(input_name) == weights.end())
2152                 {
2153                     swapnode = true;
2154                     missing_input_name = input_name;
2155                     break;
2156                 }
2157             }
2158 
2159             if (!swapnode)
2160             {
2161                 for (int j = 0; j < (int)node->output_size(); j++)
2162                 {
2163                     const std::string& output_name = node->output(j);
2164                     if (output_name.empty())
2165                         continue;
2166 
2167                     producers.insert(output_name);
2168                 }
2169 
2170                 i++;
2171                 continue;
2172             }
2173 
2174             // find node that produce missing_input_name
2175             int q = i + 1;
2176             for (; q < node_count; q++)
2177             {
2178                 onnx::NodeProto* nodeq = mutable_graph->mutable_node(q);
2179                 bool found = false;
2180                 for (int j = 0; j < (int)nodeq->output_size(); j++)
2181                 {
2182                     const std::string& output_name = nodeq->output(j);
2183                     if (output_name == missing_input_name)
2184                     {
2185                         found = true;
2186                         break;
2187                     }
2188                 }
2189 
2190                 if (found)
2191                     break;
2192             }
2193 
2194             if (q == node_count)
2195             {
2196                 fprintf(stderr, "cannot find node produces %s but node %d requires it\n", missing_input_name.c_str(), i);
2197                 return -1;
2198             }
2199 
2200             // fprintf(stderr, "swap %d %d\n", i, q);
2201             // swap this node with q
2202             onnx::NodeProto* nodeq = mutable_graph->mutable_node(q);
2203             onnx::NodeProto tmp = *node;
2204             *node = *nodeq;
2205             *nodeq = tmp;
2206         }
2207     }
2208 
2209     // global definition line
2210     // [layer count] [blob count]
2211     std::set<std::string> blob_names;
2212     for (int i = 0; i < node_count; i++)
2213     {
2214         const onnx::NodeProto& node = graph.node(i);
2215 
2216         const std::string& op = node.op_type();
2217 
2218         std::string name = node.name();
2219         if (name.empty())
2220         {
2221             name = node.output(0);
2222         }
2223 
2224         if (op == "Constant")
2225         {
2226             onnx::TensorProto tensor = get_node_attr_tensor(node, "value");
2227             weights[node.output(0)] = tensor;
2228         }
2229 
2230         for (int j = 0; j < (int)node.input_size(); j++)
2231         {
2232             const std::string& input_name = node.input(j);
2233 
2234             blob_names.insert(input_name);
2235 
2236             if (node_reference.find(input_name) == node_reference.end())
2237             {
2238                 node_reference[input_name] = 1;
2239             }
2240             else
2241             {
2242                 node_reference[input_name] = node_reference[input_name] + 1;
2243             }
2244         }
2245 
2246         if (op == "Dropout")
2247         {
2248             const std::string& output_name = node.output(0);
2249             blob_names.insert(output_name);
2250             node_reference[output_name] = 0;
2251             continue;
2252         }
2253 
2254         for (int j = 0; j < (int)node.output_size(); j++)
2255         {
2256             const std::string& output_name = node.output(j);
2257 
2258             blob_names.insert(output_name);
2259 
2260             node_reference[output_name] = 0;
2261         }
2262     }
2263 
2264     // include Input node
2265     int input_node_count = 0;
2266     for (int j = 0; j < graph.input_size(); j++)
2267     {
2268         const std::string& input_name = graph.input(j).name();
2269 
2270         // check weight
2271         if (weights.find(input_name) != weights.end())
2272             continue;
2273 
2274         blob_names.insert(input_name);
2275 
2276         input_node_count++;
2277     }
2278 
2279     //     for (auto a: node_reference)
2280     //     {
2281     //         fprintf(stderr, "a = %s %d\n", a.first.c_str(), a.second);
2282     //     }
2283 
2284     // op chain fusion
2285     int reduced_node_count = 0;
2286     fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2287     fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2288     fuse_shufflechannel(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2289     fuse_shufflechannel_split(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2290     fuse_hardsigmoid(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2291     fuse_hardswish(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2292     fuse_swish(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2293     fuse_batchnorm1d_squeeze_unsqueeze(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2294     fuse_unsqueeze_prelu(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2295     fuse_normalize(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2296     fuse_groupnorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2297     fuse_flatten(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2298     fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2299     fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2300     fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2301     fuse_lstm_gru_rnn(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2302 
2303     // reduce common const weight node_reference
2304     for (int i = 0; i < node_count; i++)
2305     {
2306         const onnx::NodeProto& node = graph.node(i);
2307 
2308         const std::string& op = node.op_type();
2309 
2310         if (op == "Add" || op == "Sub" || op == "Mul" || op == "Div" || op == "Max" || op == "Min" || op == "Pow")
2311         {
2312             // binaryop with scalar
2313             if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0)
2314             {
2315                 node_reference[node.input(1)] -= 1;
2316             }
2317         }
2318         else if (op == "Attention")
2319         {
2320             node_reference[node.input(1)] -= 1;
2321             node_reference[node.input(2)] -= 1;
2322             node_reference[node.input(3)] -= 1;
2323         }
2324         else if (op == "BatchNormalization")
2325         {
2326             node_reference[node.input(1)] -= 1;
2327             node_reference[node.input(2)] -= 1;
2328             node_reference[node.input(3)] -= 1;
2329             node_reference[node.input(4)] -= 1;
2330         }
2331         else if (op == "BiasGelu")
2332         {
2333             node_reference[node.input(1)] -= 1;
2334         }
2335         else if (op == "Clip")
2336         {
2337             if (node.input_size() == 3)
2338             {
2339                 node_reference[node.input(1)] -= 1;
2340                 node_reference[node.input(2)] -= 1;
2341             }
2342         }
2343         else if (op == "Conv")
2344         {
2345             node_reference[node.input(1)] -= 1;
2346             if (node.input_size() == 3)
2347             {
2348                 node_reference[node.input(2)] -= 1;
2349             }
2350         }
2351         else if (op == "ConvTranspose")
2352         {
2353             node_reference[node.input(1)] -= 1;
2354             if (node.input_size() == 3)
2355             {
2356                 node_reference[node.input(2)] -= 1;
2357             }
2358         }
2359         else if (op == "EmbedLayerNormalization")
2360         {
2361             node_reference[node.input(1)] -= 1;
2362             node_reference[node.input(2)] -= 1;
2363             node_reference[node.input(3)] -= 1;
2364             node_reference[node.input(4)] -= 1;
2365             node_reference[node.input(5)] -= 1;
2366             node_reference[node.input(6)] -= 1;
2367         }
2368         else if (op == "Gemm")
2369         {
2370             float alpha = get_node_attr_f(node, "alpha", 1.f);
2371             float beta = get_node_attr_f(node, "beta", 1.f);
2372             int transA = get_node_attr_i(node, "transA", 0);
2373             int transB = get_node_attr_i(node, "transB", 0);
2374 
2375             if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
2376             {
2377                 // InnerProduct-like A * B + C
2378                 node_reference[node.input(1)] -= 1;
2379                 node_reference[node.input(2)] -= 1;
2380             }
2381         }
2382         else if (op == "GroupNorm")
2383         {
2384             int affine = get_node_attr_i(node, "affine", 1);
2385             if (affine)
2386             {
2387                 node_reference[node.input(1)] -= 1;
2388                 node_reference[node.input(2)] -= 1;
2389             }
2390         }
2391         else if (op == "GRU")
2392         {
2393             for (int j = 1; j < node.input_size(); j++)
2394             {
2395                 node_reference[node.input(j)] -= 1;
2396             }
2397         }
2398         else if (op == "InstanceNormalization")
2399         {
2400             node_reference[node.input(1)] -= 1;
2401             node_reference[node.input(2)] -= 1;
2402         }
2403         else if (op == "LSTM")
2404         {
2405             for (int j = 1; j < node.input_size(); j++)
2406             {
2407                 node_reference[node.input(j)] -= 1;
2408             }
2409         }
2410         else if (op == "MatMul")
2411         {
2412             if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
2413             {
2414                 // InnerProduct
2415                 node_reference[node.input(1)] -= 1;
2416             }
2417         }
2418         else if (op == "Pad")
2419         {
2420             if (node.input_size() >= 2)
2421             {
2422                 node_reference[node.input(1)] -= 1;
2423             }
2424         }
2425         else if (op == "PRelu")
2426         {
2427             node_reference[node.input(1)] -= 1;
2428         }
2429         else if (op == "Reshape")
2430         {
2431             if (node.input_size() >= 2)
2432             {
2433                 node_reference[node.input(1)] -= 1;
2434             }
2435         }
        else if (op == "Resize")
        {
            if (node.input_size() == 2)
            {
                // opset 10: input(1) is the scales tensor, folded into layer params
                node_reference[node.input(1)] -= 1;
            }
            else
            {
                // opset 11+: input(1)=roi, input(2)=scales, optional input(3)=sizes
                // all are consumed as parameters, not runtime blobs
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
                if (node.input_size() >= 4)
                {
                    node_reference[node.input(3)] -= 1;
                }
            }
        }
        else if (op == "RNN")
        {
            // every input after input(0) (weights, biases, initial states) is
            // baked into the layer, so drop their blob references
            for (int j = 1; j < node.input_size(); j++)
            {
                node_reference[node.input(j)] -= 1;
            }
        }
        else if (op == "SkipLayerNormalization")
        {
            // inputs 2..4 are gamma/beta/bias weights consumed by the layer
            node_reference[node.input(2)] -= 1;
            node_reference[node.input(3)] -= 1;
            node_reference[node.input(4)] -= 1;
        }
        else if (op == "Slice")
        {
            // opset 10+ moves starts/ends (and optional axes/steps) from
            // attributes to inputs; they become layer params here
            if (node.input_size() >= 2)
            {
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
                if (node.input_size() >= 4)
                    node_reference[node.input(3)] -= 1;
                if (node.input_size() >= 5)
                    node_reference[node.input(4)] -= 1;
            }
        }
        else if (op == "Upsample")
        {
            // input(1) is the scales tensor
            if (node.input_size() >= 2)
            {
                node_reference[node.input(1)] -= 1;
            }
        }
        else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
        {
            // input(1) is the target output size
            if (node.input_size() >= 2)
            {
                node_reference[node.input(1)] -= 1;
            }
        }
    }
2494 
2495     //         for (auto a: node_reference)
2496     //         {
2497     //             fprintf(stderr, "b = %s %d\n", a.first.c_str(), a.second);
2498     //         }
2499 
    // count all weight node with zero reference
    // (weights whose every consumer folded them into layer params above;
    // these will not be emitted as MemoryData layers or blobs)
    int zero_reference_weight_node_count = 0;
    for (std::map<std::string, onnx::TensorProto>::iterator it = weights.begin(); it != weights.end(); it++)
    {
        const std::string& input_name = it->first;

        // note: operator[] default-inserts 0 for names never referenced
        int refcount = node_reference[input_name];
        if (refcount == 0)
            zero_reference_weight_node_count++;
    }

    // we always treat constant node as weight or binaryop_weights
    // do not count it twice for layer_count
    int constant_node_count_moved_to_weight = 0;
    for (int i = 0; i < node_count; i++)
    {
        const onnx::NodeProto& node = graph.node(i);

        const std::string& op = node.op_type();

        if (op == "Constant")
        {
            constant_node_count_moved_to_weight++;
        }
    }

    // some op may have anonymous input
    // LSTM sequence_lens
    blob_names.erase("");
    node_reference.erase("");

    // remove node_reference entry with reference equals to one
    int split_layer_count = 0;
    int splitncnn_blob_count = 0;
    // split node reference
    // blobs consumed more than once need an ncnn Split layer producing one
    // clone blob per consumer; record the fan-out per blob name
    std::map<std::string, int> split_node_reference;
    for (std::map<std::string, int>::iterator it = node_reference.begin(); it != node_reference.end(); it++)
    {
        if (it->second > 1)
        {
            split_layer_count++;
            splitncnn_blob_count += it->second;

            split_node_reference[it->first] = it->second;
        }
    }

    // .param header line: total layer count and total blob count.
    // layers = graph nodes - Constants (folded into weights) + MemoryData
    //          layers for live weights - nodes fused away earlier + Input
    //          layers + Split layers
    // blobs  = named blobs - dead weights + per-consumer split clones
    fprintf(pp, "%zu %zu\n", node_count - constant_node_count_moved_to_weight + weights.size() - zero_reference_weight_node_count - reduced_node_count + input_node_count + split_layer_count, blob_names.size() - zero_reference_weight_node_count + splitncnn_blob_count);
2548 
    // counter used to name Split layers generated for multi-consumer weights
    int internal_split = 0;

    // place Input at the beginning
    for (int j = 0; j < graph.input_size(); j++)
    {
        const std::string& input_name = graph.input(j).name();

        // check weight
        // graph inputs that are initializers are weights, not runtime inputs
        if (weights.find(input_name) != weights.end())
            continue;

        // Input layer: 0 inputs, 1 output named after the graph input
        fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", input_name.c_str(), input_name.c_str());

        int refcount = node_reference[input_name];
        if (refcount <= 1)
        {
            continue;
        }

        // fan-out > 1: insert a Split producing one clone blob per consumer
        char splitname[256];
        sprintf(splitname, "splitncnn_input%d", j);
        fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
        fprintf(pp, " %s", input_name.c_str());

        for (int k = 0; k < refcount; k++)
        {
            fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
        }
        fprintf(pp, "\n");
    }
2579 
    // place MemoryData next
    // each weight still referenced as a runtime blob becomes a MemoryData
    // layer carrying its shape in the param file and its data in the bin file
    for (std::map<std::string, onnx::TensorProto>::iterator weight_it = weights.begin(); weight_it != weights.end(); weight_it++)
    {
        const std::string& input_name = weight_it->first;

        // refcount 0 means every consumer folded this weight into its params
        int refcount = node_reference[input_name];
        if (refcount == 0)
        {
            continue;
        }

        fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", input_name.c_str(), input_name.c_str());

        const onnx::TensorProto& M = weights[input_name];

        // ONNX dims are outermost-first; ncnn MemoryData params are
        // 0=w 1=h 2=c, so dims are emitted innermost-first
        if (M.dims_size() == 0)
        {
            // scalar / shapeless: use the flat element count as width
            fprintf(pp, " 0=%d", get_tensor_proto_data_size(M));
        }
        else if (M.dims_size() == 1)
        {
            fprintf(pp, " 0=%d", (int)M.dims(0));
        }
        else if (M.dims_size() == 2)
        {
            fprintf(pp, " 0=%d", (int)M.dims(1));
            if (M.dims(0) != 1)
            {
                fprintf(pp, " 1=%d", (int)M.dims(0));
            }
        }
        else if (M.dims_size() == 3)
        {
            fprintf(pp, " 0=%d", (int)M.dims(2));
            fprintf(pp, " 1=%d", (int)M.dims(1));
            if (M.dims(0) != 1)
            {
                fprintf(pp, " 2=%d", (int)M.dims(0));
            }
        }
        else if (M.dims_size() == 4)
        {
            // leading (batch) dim dropped
            fprintf(pp, " 0=%d", (int)M.dims(3));
            fprintf(pp, " 1=%d", (int)M.dims(2));
            fprintf(pp, " 2=%d", (int)M.dims(1));
        }

        fprintf(pp, "\n");

        fwrite_tensor_proto_data(M, bp);

        if (refcount <= 1)
        {
            continue;
        }

        // multi-consumer weight blob: add a Split just like for Inputs
        char splitname[256];
        sprintf(splitname, "splitncnn_%d", internal_split);
        fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);

        fprintf(pp, " %s", input_name.c_str());

        for (int k = 0; k < refcount; k++)
        {
            fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
        }
        fprintf(pp, "\n");

        internal_split++;
    }
2650 
    // main emission loop: one .param line (plus .bin payload) per graph node
    for (int i = 0; i < node_count; i++)
    {
        const onnx::NodeProto& node = graph.node(i);

        const std::string& op = node.op_type();

        //         fprintf(stderr, "op = %s\n", op.c_str());

        // nodes fused away by earlier graph rewrites are placeholders; skip
        if (op == "noop_reducedncnn")
        {
            continue;
        }

        // ONNX node names are optional; fall back to the first output name
        std::string name = node.name();
        if (name.empty())
        {
            name = node.output(0);
        }

        int input_size = node.input_size();
        int output_size = node.output_size();

        // compute the effective input count: drop weights that were folded
        // into layer params (refcount 0) and anonymous (empty-name) inputs
        for (int j = 0; j < (int)node.input_size(); j++)
        {
            const std::string& input_name = node.input(j);

            // check weight
            if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0)
            {
                input_size--;
            }

            if (input_name.empty())
            {
                input_size--;
            }

            //             fprintf(stderr, "  input = %s\n", input_name.c_str());
        }
        /*
        for (int j=0; j<(int)node.output_size(); j++)
        {
            const std::string& output_name = node.output(j);
            fprintf(stderr, "  output = %s\n", output_name.c_str());
        }
        */
2697 
        // map the ONNX op type to the ncnn layer type (first column of the
        // .param line); op-specific parameters are emitted further below
        if (op == "Abs")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Acos")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Add")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "Asin")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Atan")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Attention")
        {
            fprintf(pp, "%-16s", "Attention");
        }
        else if (op == "AveragePool" || op == "MaxPool")
        {
            fprintf(pp, "%-16s", "Pooling");
        }
        else if (op == "BatchNormalization")
        {
            fprintf(pp, "%-16s", "BatchNorm");
        }
        else if (op == "BiasGelu")
        {
            fprintf(pp, "%-16s", "BiasGelu");
        }
        else if (op == "Ceil")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Clip")
        {
            fprintf(pp, "%-16s", "Clip");
        }
        else if (op == "Concat")
        {
            fprintf(pp, "%-16s", "Concat");
        }
        else if (op == "Constant")
        {
            // constants were moved into the weights map; emit no layer
            continue;
        }
        else if (op == "Conv")
        {
            // grouped convolution maps to the DepthWise variant
            int group = get_node_attr_i(node, "group", 1);
            if (group > 1)
            {
                fprintf(pp, "%-16s", "ConvolutionDepthWise");
            }
            else
            {
                fprintf(pp, "%-16s", "Convolution");
            }
        }
        else if (op == "ConvTranspose")
        {
            int group = get_node_attr_i(node, "group", 1);
            if (group > 1)
            {
                fprintf(pp, "%-16s", "DeconvolutionDepthWise");
            }
            else
            {
                fprintf(pp, "%-16s", "Deconvolution");
            }
        }
        else if (op == "Cos")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "DepthToSpace")
        {
            fprintf(pp, "%-16s", "PixelShuffle");
        }
        else if (op == "Div")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "Dropout")
        {
            // inference dropout is identity; the optional mask output is dropped
            fprintf(pp, "%-16s", "Dropout");
            output_size = 1;
        }
        else if (op == "Elu")
        {
            fprintf(pp, "%-16s", "ELU");
        }
        else if (op == "EmbedLayerNormalization")
        {
            fprintf(pp, "%-16s", "EmbedLayerNormalization");
        }
        else if (op == "Exp")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Flatten")
        {
            fprintf(pp, "%-16s", "Flatten");
        }
        else if (op == "Floor")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Gemm")
        {
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            float beta = get_node_attr_f(node, "beta", 1.f);
            int transA = get_node_attr_i(node, "transA", 0);
            int transB = get_node_attr_i(node, "transB", 0);

            if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
            {
                // InnerProduct-like A * B + C
                fprintf(pp, "%-16s", "InnerProduct");
            }
            else
            {
                fprintf(pp, "%-16s", "Gemm");
            }
        }
        else if (op == "GlobalAveragePool")
        {
            fprintf(pp, "%-16s", "Pooling");
        }
        else if (op == "GlobalMaxPool")
        {
            fprintf(pp, "%-16s", "Pooling");
        }
        else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
        {
            fprintf(pp, "%-16s", "Pooling");
        }
        else if (op == "GroupNorm")
        {
            fprintf(pp, "%-16s", "GroupNorm");
        }
        else if (op == "GRU")
        {
            fprintf(pp, "%-16s", "GRU");
        }
        else if (op == "HardSigmoid")
        {
            fprintf(pp, "%-16s", "HardSigmoid");
        }
        else if (op == "HardSwish")
        {
            fprintf(pp, "%-16s", "HardSwish");
        }
        else if (op == "ImageScaler")
        {
            fprintf(pp, "%-16s", "Scale");
        }
        else if (op == "InstanceNormalization")
        {
            fprintf(pp, "%-16s", "InstanceNorm");
        }
        else if (op == "LeakyRelu")
        {
            // ReLU with a nonzero slope parameter
            fprintf(pp, "%-16s", "ReLU");
        }
        else if (op == "Log")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "LRN")
        {
            fprintf(pp, "%-16s", "LRN");
        }
        else if (op == "LSTM")
        {
            fprintf(pp, "%-16s", "LSTM");
        }
        else if (op == "MatMul")
        {
            // MatMul against a constant 2-D weight is a fully-connected layer
            if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
            {
                fprintf(pp, "%-16s", "InnerProduct");
            }
            else
            {
                fprintf(pp, "%-16s", "Gemm");
            }
        }
        else if (op == "Max")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "Min")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "Mul")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "Neg")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Normalize")
        {
            fprintf(pp, "%-16s", "Normalize");
        }
        else if (op == "Pad")
        {
            fprintf(pp, "%-16s", "Padding");
        }
        else if (op == "PixelShuffle")
        {
            fprintf(pp, "%-16s", "PixelShuffle");
        }
        else if (op == "Pow")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "PRelu")
        {
            fprintf(pp, "%-16s", "PReLU");
        }
        else if (op == "Reciprocal")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp")
        {
            fprintf(pp, "%-16s", "Reduction");
        }
        else if (op == "Relu")
        {
            fprintf(pp, "%-16s", "ReLU");
        }
        else if (op == "Reorg")
        {
            fprintf(pp, "%-16s", "Reorg");
        }
        else if (op == "Reshape")
        {
            fprintf(pp, "%-16s", "Reshape");
        }
        else if (op == "RNN")
        {
            fprintf(pp, "%-16s", "RNN");
        }
        else if (op == "ShuffleChannel")
        {
            fprintf(pp, "%-16s", "ShuffleChannel");
        }
        else if (op == "Sigmoid")
        {
            fprintf(pp, "%-16s", "Sigmoid");
        }
        else if (op == "Sin")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "SkipLayerNormalization")
        {
            fprintf(pp, "%-16s", "SkipLayerNormalization");
        }
        else if (op == "Slice")
        {
            // ONNX Slice (tensor cropping) maps to ncnn Crop
            fprintf(pp, "%-16s", "Crop");
        }
        else if (op == "Softmax")
        {
            fprintf(pp, "%-16s", "Softmax");
        }
        else if (op == "Softplus")
        {
            fprintf(pp, "%-16s", "Softplus");
        }
        else if (op == "Split")
        {
            // ONNX Split (chunking) maps to ncnn Slice; unrelated to ncnn's
            // own blob-fanout Split layer
            fprintf(pp, "%-16s", "Slice");
        }
        else if (op == "Sqrt")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Squeeze")
        {
            fprintf(pp, "%-16s", "Squeeze");
        }
        else if (op == "Sub")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "Sum")
        {
            fprintf(pp, "%-16s", "Eltwise");
        }
        else if (op == "Swish")
        {
            fprintf(pp, "%-16s", "Swish");
        }
        else if (op == "Tan")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Tanh")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Transpose")
        {
            fprintf(pp, "%-16s", "Permute");
        }
        else if (op == "Upsample" || op == "Resize")
        {
            fprintf(pp, "%-16s", "Interp");
        }
        else if (op == "Unsqueeze")
        {
            fprintf(pp, "%-16s", "ExpandDims");
        }
        else
        {
            // TODO
            // unknown op: warn, but still emit its name so the param file
            // structure stays intact for inspection
            fprintf(stderr, "%s not supported yet!\n", op.c_str());
            fprintf(pp, "%-16s", op.c_str());
        }
3029 
        // layer name plus effective input/output counts computed above
        fprintf(pp, " %-24s %d %d", name.c_str(), input_size, output_size);

        for (int j = 0; j < (int)node.input_size(); j++)
        {
            std::string input_name = node.input(j);

            // check weight
            // folded weights are no longer blobs; skip them here too so the
            // printed names match input_size
            if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0)
            {
                continue;
            }

            if (input_name.empty())
            {
                continue;
            }

            // blobs with fan-out > 1 were routed through a Split layer;
            // consume one split clone, counting down from refcount-1
            if (split_node_reference.find(input_name) != split_node_reference.end())
            {
                int refidx = split_node_reference[input_name] - 1;
                split_node_reference[input_name] = refidx;

                char splitsuffix[256];
                sprintf(splitsuffix, "_splitncnn_%d", refidx);
                input_name = input_name + splitsuffix;
            }

            fprintf(pp, " %s", input_name.c_str());
        }

        for (int j = 0; j < output_size; j++)
        {
            const std::string& output_name = node.output(j);

            fprintf(pp, " %s", output_name.c_str());
        }
3066 
        // op-specific parameter emission starts here; op_type values are
        // ncnn UnaryOp/BinaryOp operation enums
        if (op == "Abs")
        {
            int op_type = 0;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Acos")
        {
            int op_type = 13;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Add")
        {
            int op_type = 0;
            fprintf(pp, " 0=%d", op_type);

            // adding a 0-dim (scalar) weight: fold it into the layer as
            // 1=with_scalar, 2=scalar value, instead of a second input blob
            if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0)
            {
                float b = get_node_attr_from_input_f(weights[node.input(1)]);
                fprintf(pp, " 1=1");
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "Asin")
        {
            int op_type = 12;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Atan")
        {
            int op_type = 14;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Attention")
        {
            int num_heads = get_node_attr_i(node, "num_heads", 1);

            // input(1) = packed QKV weight, input(2) = packed QKV bias
            const onnx::TensorProto& W = weights[node.input(1)];
            const onnx::TensorProto& B = weights[node.input(2)];

            // 0=bias element count, 1=head count, 2=weight element count
            fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
            fprintf(pp, " 1=%d", num_heads);
            fprintf(pp, " 2=%d", get_tensor_proto_data_size(W));

            // each weight blob in the .bin is preceded by a 4-byte tag;
            // 0 marks raw fp32 data
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(W, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B, bp);
        }
        else if (op == "AveragePool" || op == "MaxPool")
        {
            std::string auto_pad = get_node_attr_s(node, "auto_pad");
            int ceil_mode = get_node_attr_i(node, "ceil_mode", 0);
            std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
            std::vector<int> strides = get_node_attr_ai(node, "strides");
            std::vector<int> pads = get_node_attr_ai(node, "pads");

            // ncnn Pooling: 0=max pooling, 1=average pooling
            int pool = op == "AveragePool" ? 1 : 0;
            int pad_mode = 1;

            if (auto_pad == "SAME_UPPER")
            {
                pad_mode = 2;
            }
            else if (auto_pad == "SAME_LOWER")
            {
                pad_mode = 3;
            }

            // ceil_mode overrides auto_pad: pad_mode 0 is full padding
            if (ceil_mode == 1)
            {
                pad_mode = 0;
            }

            fprintf(pp, " 0=%d", pool);

            // kernel/stride/pad attrs are [h, w] in ONNX; ncnn takes
            // width in the base id and height in the +10 id
            if (kernel_shape.size() == 1)
            {
                fprintf(pp, " 1=%d", kernel_shape[0]);
            }
            else if (kernel_shape.size() == 2)
            {
                fprintf(pp, " 1=%d", kernel_shape[1]);
                fprintf(pp, " 11=%d", kernel_shape[0]);
            }

            if (strides.size() == 1)
            {
                fprintf(pp, " 2=%d", strides[0]);
            }
            else if (strides.size() == 2)
            {
                fprintf(pp, " 2=%d", strides[1]);
                fprintf(pp, " 12=%d", strides[0]);
            }

            // ONNX 4-element pads are [top, left, bottom, right]
            if (pads.size() == 1)
            {
                fprintf(pp, " 3=%d", pads[0]);
            }
            else if (pads.size() == 2)
            {
                fprintf(pp, " 3=%d", pads[1]);
                fprintf(pp, " 13=%d", pads[0]);
            }
            else if (pads.size() == 4)
            {
                fprintf(pp, " 3=%d", pads[1]);
                fprintf(pp, " 13=%d", pads[0]);
                fprintf(pp, " 14=%d", pads[3]);
                fprintf(pp, " 15=%d", pads[2]);
            }

            fprintf(pp, " 5=%d", pad_mode);

            if (op == "AveragePool")
            {
                int avgpool_count_include_pad = get_node_attr_i(node, "count_include_pad", 0);
                fprintf(pp, " 6=%d", avgpool_count_include_pad);
            }
        }
        else if (op == "BatchNormalization")
        {
            float epsilon = get_node_attr_f(node, "epsilon", 1e-5f);

            // ONNX input order: X, scale, B, mean, var
            const onnx::TensorProto& scale = weights[node.input(1)];
            const onnx::TensorProto& B = weights[node.input(2)];
            const onnx::TensorProto& mean = weights[node.input(3)];
            const onnx::TensorProto& var = weights[node.input(4)];

            int channels = get_tensor_proto_data_size(scale);

            fprintf(pp, " 0=%d", channels);

            // .bin order: scale(slope), mean, var (+epsilon), bias
            fwrite_tensor_proto_data(scale, bp);
            fwrite_tensor_proto_data(mean, bp);
            // apply epsilon to var
            // folding epsilon here lets the runtime skip adding it per-inference
            {
                const float* v = var.has_raw_data() ? (const float*)var.raw_data().data() : var.float_data().data();

                for (int j = 0; j < channels; j++)
                {
                    float ve = v[j] + epsilon;
                    fwrite(&ve, sizeof(float), 1, bp);
                }
            }
            fwrite_tensor_proto_data(B, bp);
        }
        else if (op == "BiasGelu")
        {
            const onnx::TensorProto& B = weights[node.input(1)];

            // 0=bias element count
            fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));

            // fp32 tag preceding the bias data in the .bin
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B, bp);
        }
        else if (op == "Ceil")
        {
            // ncnn UnaryOp enum for ceil
            int op_type = 3;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Clip")
        {
            float min;
            float max;
            if (node.input_size() == 1)
            {
                // opset < 11: min/max are attributes
                min = get_node_attr_f(node, "min", -FLT_MAX);
                max = get_node_attr_f(node, "max", FLT_MAX);
            }
            else
            {
                // opset 11+: min/max are constant inputs
                const onnx::TensorProto& min_tp = weights[node.input(1)];
                const onnx::TensorProto& max_tp = weights[node.input(2)];

                min = get_node_attr_from_input_f(min_tp);
                max = get_node_attr_from_input_f(max_tp);
            }

            fprintf(pp, " 0=%e", min);
            fprintf(pp, " 1=%e", max);
        }
        else if (op == "Concat")
        {
            // ONNX axis counts the batch dim; ncnn blobs have none, so shift by 1
            int axis = get_node_attr_i(node, "axis", 1);
            fprintf(pp, " 0=%d", axis - 1);
        }
        else if (op == "Constant")
        {
            // never reach here
        }
        else if (op == "Conv")
        {
            const onnx::TensorProto& W = weights[node.input(1)];

            // weight layout is [num_output, in/group, kh, kw]; dims(0) is num_output
            int num_filter = W.dims(0);
            int has_bias = node.input_size() == 3 ? 1 : 0;

            std::string auto_pad = get_node_attr_s(node, "auto_pad");
            std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
            std::vector<int> dilations = get_node_attr_ai(node, "dilations");
            std::vector<int> strides = get_node_attr_ai(node, "strides");
            std::vector<int> pads = get_node_attr_ai(node, "pads");
            int group = get_node_attr_i(node, "group", 1);

            fprintf(pp, " 0=%d", num_filter);

            // [h, w] attrs emitted as width (base id) + height (+10 id)
            if (kernel_shape.size() == 1)
            {
                fprintf(pp, " 1=%d", kernel_shape[0]);
            }
            else if (kernel_shape.size() == 2)
            {
                fprintf(pp, " 1=%d", kernel_shape[1]);
                fprintf(pp, " 11=%d", kernel_shape[0]);
            }

            if (dilations.size() == 1)
            {
                fprintf(pp, " 2=%d", dilations[0]);
            }
            else if (dilations.size() == 2)
            {
                fprintf(pp, " 2=%d", dilations[1]);
                fprintf(pp, " 12=%d", dilations[0]);
            }

            if (strides.size() == 1)
            {
                fprintf(pp, " 3=%d", strides[0]);
            }
            else if (strides.size() == 2)
            {
                fprintf(pp, " 3=%d", strides[1]);
                fprintf(pp, " 13=%d", strides[0]);
            }

            // -233/-234 are ncnn sentinels for SAME_UPPER/SAME_LOWER padding
            if (auto_pad == "SAME_UPPER")
            {
                fprintf(pp, " 4=-233");
            }
            else if (auto_pad == "SAME_LOWER")
            {
                fprintf(pp, " 4=-234");
            }
            else
            {
                // explicit pads; ONNX 4-element order is [top, left, bottom, right]
                if (pads.size() == 1)
                {
                    fprintf(pp, " 4=%d", pads[0]);
                }
                else if (pads.size() == 2)
                {
                    fprintf(pp, " 4=%d", pads[1]);
                    fprintf(pp, " 14=%d", pads[0]);
                }
                else if (pads.size() == 4)
                {
                    fprintf(pp, " 4=%d", pads[1]);
                    fprintf(pp, " 14=%d", pads[0]);
                    fprintf(pp, " 15=%d", pads[3]);
                    fprintf(pp, " 16=%d", pads[2]);
                }
            }

            fprintf(pp, " 5=%d", has_bias);

            fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));

            if (group > 1)
            {
                fprintf(pp, " 7=%d", group);
            }

            // fp32 tag, then weight data, then (optionally) untagged bias data
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(W, bp);

            if (has_bias)
            {
                const onnx::TensorProto& B = weights[node.input(2)];
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "ConvTranspose")
        {
            const onnx::TensorProto& W = weights[node.input(1)];

            int has_bias = node.input_size() == 3 ? 1 : 0;

            std::string auto_pad = get_node_attr_s(node, "auto_pad");
            std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
            std::vector<int> dilations = get_node_attr_ai(node, "dilations");
            std::vector<int> strides = get_node_attr_ai(node, "strides");
            std::vector<int> output_padding = get_node_attr_ai(node, "output_padding");
            std::vector<int> output_shape = get_node_attr_ai(node, "output_shape");
            std::vector<int> pads = get_node_attr_ai(node, "pads");
            int group = get_node_attr_i(node, "group", 1);
            // ConvTranspose weight layout is [in, out/group, kh, kw];
            // total outputs = dims(1) * group
            int num_filter = W.dims(1) * group;

            fprintf(pp, " 0=%d", num_filter);

            // [h, w] attrs emitted as width (base id) + height (+10/+1 id)
            if (kernel_shape.size() == 1)
            {
                fprintf(pp, " 1=%d", kernel_shape[0]);
            }
            else if (kernel_shape.size() == 2)
            {
                fprintf(pp, " 1=%d", kernel_shape[1]);
                fprintf(pp, " 11=%d", kernel_shape[0]);
            }

            if (dilations.size() == 1)
            {
                fprintf(pp, " 2=%d", dilations[0]);
            }
            else if (dilations.size() == 2)
            {
                fprintf(pp, " 2=%d", dilations[1]);
                fprintf(pp, " 12=%d", dilations[0]);
            }

            if (strides.size() == 1)
            {
                fprintf(pp, " 3=%d", strides[0]);
            }
            else if (strides.size() == 2)
            {
                fprintf(pp, " 3=%d", strides[1]);
                fprintf(pp, " 13=%d", strides[0]);
            }

            // -233/-234 are ncnn sentinels for SAME_UPPER/SAME_LOWER padding
            if (auto_pad == "SAME_UPPER")
            {
                fprintf(pp, " 4=-233");
            }
            else if (auto_pad == "SAME_LOWER")
            {
                fprintf(pp, " 4=-234");
            }
            else
            {
                // explicit pads; ONNX 4-element order is [top, left, bottom, right]
                if (pads.size() == 1)
                {
                    fprintf(pp, " 4=%d", pads[0]);
                }
                else if (pads.size() == 2)
                {
                    fprintf(pp, " 4=%d", pads[1]);
                    fprintf(pp, " 14=%d", pads[0]);
                }
                else if (pads.size() == 4)
                {
                    fprintf(pp, " 4=%d", pads[1]);
                    fprintf(pp, " 14=%d", pads[0]);
                    fprintf(pp, " 15=%d", pads[3]);
                    fprintf(pp, " 16=%d", pads[2]);
                }
            }

            if (output_padding.size() == 1)
            {
                fprintf(pp, " 18=%d", output_padding[0]);
            }
            else if (output_padding.size() == 2)
            {
                fprintf(pp, " 18=%d", output_padding[1]);
                fprintf(pp, " 19=%d", output_padding[0]);
            }

            if (output_shape.size() == 1)
            {
                fprintf(pp, " 20=%d", output_shape[0]);
            }
            else if (output_shape.size() == 2)
            {
                fprintf(pp, " 20=%d", output_shape[1]);
                fprintf(pp, " 21=%d", output_shape[0]);
            }

            fprintf(pp, " 5=%d", has_bias);

            fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));

            if (group > 1)
            {
                fprintf(pp, " 7=%d", group);
            }

            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            int maxk = 0;
            if (kernel_shape.size() == 2)
            {
                maxk = kernel_shape[1] * kernel_shape[0];
            }
            else
            {
                // NOTE(review): for a 1-d kernel this squares kernel_shape[0];
                // verify this matches the intended per-filter element count
                maxk = kernel_shape[0] * kernel_shape[0];
            }
            int weight_data_size = get_tensor_proto_data_size(W);
            const float* weight_data = 0;
3477             if (W.has_raw_data())
3478             {
3479                 weight_data = (const float*)W.raw_data().data();
3480             }
3481             else if (W.data_type() == 1)
3482             {
3483                 weight_data = W.float_data().data();
3484             }
3485             for (int g = 0; g < group; g++)
3486             {
3487                 // reorder weight from inch-outch to outch-inch
3488                 int num_filter_g = num_filter / group;
3489                 int num_input = weight_data_size / maxk / num_filter_g / group;
3490                 const float* weight_data_ptr = weight_data + g * maxk * num_filter_g * num_input;
3491                 for (int k = 0; k < num_filter_g; k++)
3492                 {
3493                     for (int j = 0; j < num_input; j++)
3494                     {
3495                         fwrite(weight_data_ptr + (j * num_filter_g + k) * maxk, sizeof(float), maxk, bp);
3496                     }
3497                 }
3498             }
3499 
3500             if (has_bias)
3501             {
3502                 const onnx::TensorProto& B = weights[node.input(2)];
3503                 fwrite_tensor_proto_data(B, bp);
3504             }
3505         }
        else if (op == "Cos")
        {
            // ncnn UnaryOp operation type: 10 = COS
            int op_type = 10;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "DepthToSpace")
        {
            // pixelshuffle
            int scale_factor = get_node_attr_i(node, "blocksize", 1);
            std::string mode = get_node_attr_s(node, "mode");
            fprintf(pp, " 0=%d", scale_factor);
            // 1=mode: 0 for column-row-depth, 1 for depth-column-row ordering.
            // NOTE(review): when the mode attribute is absent neither branch
            // writes param 1, so the layer's built-in default applies — confirm
            // it matches ONNX's default mode "DCR".
            if (mode == "CRD")
            {
                fprintf(pp, " 1=0");
            }
            else if (mode == "DCR")
            {
                fprintf(pp, " 1=1");
            }
        }
        else if (op == "Div")
        {
            // ncnn BinaryOp operation type: 3 = DIV
            int op_type = 3;
            fprintf(pp, " 0=%d", op_type);

            // fold a scalar (0-dim) constant divisor into the layer:
            // 1=with_scalar, 2=scalar value
            if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0)
            {
                float b = get_node_attr_from_input_f(weights[node.input(1)]);
                fprintf(pp, " 1=1");
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "Dropout")
        {
            // no-op
        }
        else if (op == "Elu")
        {
            // 0=alpha of elu: y = alpha * (exp(x) - 1) for x < 0
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            fprintf(pp, " 0=%e", alpha);
        }
        else if (op == "EmbedLayerNormalization")
        {
            // com.microsoft EmbedLayerNormalization inputs (presumably):
            // 2=word embedding table, 3=position embedding table,
            // 5=layernorm gamma, 6=layernorm beta — verify against the op spec
            const onnx::TensorProto& words = weights[node.input(2)];
            const onnx::TensorProto& positions = weights[node.input(3)];
            const onnx::TensorProto& W = weights[node.input(5)];
            const onnx::TensorProto& B = weights[node.input(6)];

            fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
            fprintf(pp, " 1=%d", get_tensor_proto_data_size(words));
            fprintf(pp, " 2=%d", get_tensor_proto_data_size(positions));

            // each weight blob is prefixed with a 0 tag (raw fp32 data)
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(words, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(positions, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(W, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B, bp);
        }
        else if (op == "Exp")
        {
            // ncnn UnaryOp operation type: 7 = EXP
            int op_type = 7;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Flatten")
        {
            // ncnn Flatten always flattens everything after the batch axis,
            // so only axis == 1 is representable
            int axis = get_node_attr_i(node, "axis", 1);
            if (axis != 1)
            {
                fprintf(stderr, "Unsupported Flatten axis %d!\n", axis);
            }
        }
        else if (op == "Floor")
        {
            // ncnn UnaryOp operation type: 2 = FLOOR
            int op_type = 2;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Gemm")
        {
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            float beta = get_node_attr_f(node, "beta", 1.f);
            int transA = get_node_attr_i(node, "transA", 0);
            int transB = get_node_attr_i(node, "transB", 0);

            // The common fully-connected pattern (y = x * B^T + C with unit
            // scales) maps to an InnerProduct layer with baked-in weights.
            if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
            {
                // InnerProduct-like A * B + C
                const onnx::TensorProto& B = weights[node.input(1)];
                const onnx::TensorProto& C = weights[node.input(2)];

                // 0=num_output, 1=bias_term, 2=weight_data_size
                fprintf(pp, " 0=%d", get_tensor_proto_data_size(C));
                fprintf(pp, " 1=1");
                fprintf(pp, " 2=%d", get_tensor_proto_data_size(B));

                // 0 tag marks raw fp32 weight data
                int quantize_tag = 0;
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // transB == 1 means B is already (num_output, num_input) — no reorder needed
                fwrite_tensor_proto_data(B, bp);
                fwrite_tensor_proto_data(C, bp);
            }
            else
            {
                // gemm
                // general case: emit a runtime Gemm layer with its attributes
                fprintf(pp, " 0=%e", alpha);
                fprintf(pp, " 1=%e", beta);
                fprintf(pp, " 2=%d", transA);
                fprintf(pp, " 3=%d", transB);
            }
        }
        else if (op == "GlobalAveragePool")
        {
            // Pooling layer: 0=pool type (1 = average), 4=global pooling flag
            int pool = 1;
            int global_pool = 1;

            fprintf(pp, " 0=%d", pool);
            fprintf(pp, " 4=%d", global_pool);
        }
        else if (op == "GlobalMaxPool")
        {
            // Pooling layer: 0=pool type (0 = max), 4=global pooling flag
            int pool = 0;
            int global_pool = 1;

            fprintf(pp, " 0=%d", pool);
            fprintf(pp, " 4=%d", global_pool);
        }
        else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
        {
            // torch-style adaptive pooling: output spatial size is fixed,
            // kernel/stride are derived at runtime
            int pool = 0;
            if (op == "adaptive_avg_pool2d")
            {
                pool = 1;
            }
            int adaptive_pooling = 1;
            // target output shape comes from the second (constant) input
            const onnx::TensorProto& out_shape_tp = weights[node.input(1)];
            std::vector<int> out_shape = get_node_attr_from_input_ai(out_shape_tp);

            fprintf(pp, " 0=%d", pool);
            fprintf(pp, " 7=%d", adaptive_pooling);
            if (out_shape.size() == 1)
            {
                fprintf(pp, " 8=%d", out_shape[0]);
            }
            else if (out_shape.size() == 2)
            {
                // out_w
                fprintf(pp, " 8=%d", out_shape[1]);
                // out_h
                fprintf(pp, " 18=%d", out_shape[0]);
            }
        }
        else if (op == "GroupNorm")
        {
            // fused GroupNorm op: 0=groups, 1=channels, 2=eps, 3=affine flag
            int groups = get_node_attr_i(node, "groups", 1);
            int channels = get_node_attr_i(node, "channels", 1);
            float eps = get_node_attr_f(node, "epsilon", 1e-5f);
            int affine = get_node_attr_i(node, "affine", 1);

            fprintf(pp, " 0=%d", groups);
            fprintf(pp, " 1=%d", channels);
            fprintf(pp, " 2=%e", eps);
            fprintf(pp, " 3=%d", affine);
            if (affine)
            {
                // per-channel scale (gamma) then shift (beta), written back-to-back
                const onnx::TensorProto& scale = weights[node.input(1)];
                const onnx::TensorProto& B = weights[node.input(2)];

                fwrite_tensor_proto_data(scale, bp);
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "GRU")
        {
            // ONNX GRU weights: W = input weights, R = recurrent weights,
            // B = biases, each concatenating the three gates per direction.
            const onnx::TensorProto& W = weights[node.input(1)];
            const onnx::TensorProto& R = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];

            int hidden_size = get_node_attr_i(node, "hidden_size", 0);
            std::string direction = get_node_attr_s(node, "direction");

            // 0=forward, 1=reverse, 2=bidirectional
            int direction_type = 0;
            if (direction == "forward")
            {
                direction_type = 0;
            }
            else if (direction == "reverse")
            {
                direction_type = 1;
            }
            else if (direction == "bidirectional")
            {
                direction_type = 2;
            }

            int weight_data_size = get_tensor_proto_data_size(W);

            fprintf(pp, " 0=%d", hidden_size);
            fprintf(pp, " 1=%d", weight_data_size);
            fprintf(pp, " 2=%d", direction_type);

            int num_directions = direction_type == 2 ? 2 : 1;

            // 0 tag before each blob marks raw fp32 data
            int quantize_tag = 0;

            // reorder num_directions-URN-hidden-size to num_directions-RUN-hidden-size
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // per-gate slice size within one direction
                int weight_data_size_g = get_tensor_proto_data_size(W) / 3 / num_directions;
                const float* wptr = W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data();

                const float* uptr = wptr;
                const float* rptr = wptr + weight_data_size_g;
                const float* nptr = wptr + weight_data_size_g * 2;
                fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                fwrite(nptr, sizeof(float), weight_data_size_g, bp);

                // second direction, if present, follows with the same reorder
                if (direction_type == 2)
                {
                    uptr += weight_data_size_g * 3;
                    rptr += weight_data_size_g * 3;
                    nptr += weight_data_size_g * 3;
                    fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(nptr, sizeof(float), weight_data_size_g, bp);
                }
            }

            // reduce U and R bias except N
            // reorder num_directions-URN-hidden to num_directions-RUN-hidden
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // ONNX B packs [Wb(U,R,N), Rb(U,R,N)] per direction; U and R
                // input/recurrent biases are summed, N biases are kept separate.
                int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 3 / num_directions;
                const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
                const float* wuptr = bptr;
                const float* wrptr = bptr + bias_data_size_g;
                const float* wnptr = bptr + bias_data_size_g * 2;
                const float* buptr = bptr + bias_data_size_g * 3;
                const float* brptr = bptr + bias_data_size_g * 4;
                const float* bnptr = bptr + bias_data_size_g * 5;

                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = wrptr[j] + brptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = wuptr[j] + buptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                fwrite(wnptr, sizeof(float), bias_data_size_g, bp);
                fwrite(bnptr, sizeof(float), bias_data_size_g, bp);

                if (direction_type == 2)
                {
                    wuptr += bias_data_size_g * 6;
                    wrptr += bias_data_size_g * 6;
                    wnptr += bias_data_size_g * 6;
                    buptr += bias_data_size_g * 6;
                    brptr += bias_data_size_g * 6;
                    bnptr += bias_data_size_g * 6;

                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = wrptr[j] + brptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = wuptr[j] + buptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    fwrite(wnptr, sizeof(float), bias_data_size_g, bp);
                    fwrite(bnptr, sizeof(float), bias_data_size_g, bp);
                }
            }

            // reorder num_directions-URN-hidden-hidden to num_directions-RUN-hidden-hidden
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                int weight_data_size_g = get_tensor_proto_data_size(R) / 3 / num_directions;
                const float* Rptr = R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data();

                const float* uptr = Rptr;
                const float* rptr = Rptr + weight_data_size_g;
                const float* nptr = Rptr + weight_data_size_g * 2;
                fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                fwrite(nptr, sizeof(float), weight_data_size_g, bp);

                if (direction_type == 2)
                {
                    uptr += weight_data_size_g * 3;
                    rptr += weight_data_size_g * 3;
                    nptr += weight_data_size_g * 3;
                    fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(nptr, sizeof(float), weight_data_size_g, bp);
                }
            }
        }
        else if (op == "HardSigmoid")
        {
            // y = clamp(alpha * x + beta, 0, 1); 0=alpha, 1=beta
            float alpha = get_node_attr_f(node, "alpha", 0.2f);
            float beta = get_node_attr_f(node, "beta", 0.5f);

            fprintf(pp, " 0=%e", alpha);
            fprintf(pp, " 1=%e", beta);
        }
        else if (op == "HardSwish")
        {
            // y = x * clamp(alpha * x + beta, 0, 1); 0=alpha, 1=beta
            // NOTE(review): ONNX HardSwish fixes alpha = 1/6 — the 0.2 default
            // here mirrors HardSigmoid; confirm which fused op this targets.
            float alpha = get_node_attr_f(node, "alpha", 0.2f);
            float beta = get_node_attr_f(node, "beta", 0.5f);

            fprintf(pp, " 0=%e", alpha);
            fprintf(pp, " 1=%e", beta);
        }
        else if (op == "ImageScaler")
        {
            // maps to a Scale layer: one shared scale broadcast per channel,
            // plus a per-channel bias
            std::vector<float> bias = get_node_attr_af(node, "bias");
            float scale = get_node_attr_f(node, "scale", 1.f);

            // channel count is implied by the bias vector length
            int channels = (int)bias.size();

            // 0=scale_data_size, 1=bias_term
            fprintf(pp, " 0=%d", channels);
            fprintf(pp, " 1=1");

            // replicate the scalar scale once per channel
            for (int j = 0; j < channels; j++)
            {
                fwrite(&scale, sizeof(float), 1, bp);
            }
            fwrite(&bias[0], sizeof(float), channels, bp);
        }
        else if (op == "InstanceNormalization")
        {
            float eps = get_node_attr_f(node, "epsilon", 1e-5f);

            // discard affine-less S=1 B=0
            // (if scale is all-ones and bias all-zeros the affine transform is
            // an identity and can be dropped from the exported layer)
            std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
            std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
            int channels = (int)affine_S.size();
            int affine = 0;
            {
                for (int j = 0; j < channels; j++)
                {
                    if (affine_S[j] != 1.f || affine_B[j] != 0.f)
                    {
                        affine = 1;
                        break;
                    }
                }
            }

            // 0=channels, 1=eps, 2=affine flag
            fprintf(pp, " 0=%d", channels);
            fprintf(pp, " 1=%e", eps);
            fprintf(pp, " 2=%d", affine);
            if (affine)
            {
                // per-channel gamma then beta
                const onnx::TensorProto& scale = weights[node.input(1)];
                const onnx::TensorProto& B = weights[node.input(2)];

                fwrite_tensor_proto_data(scale, bp);
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "LeakyRelu")
        {
            // 0=negative slope
            float alpha = get_node_attr_f(node, "alpha", 0.01f);

            fprintf(pp, " 0=%e", alpha);
        }
        else if (op == "Log")
        {
            // ncnn UnaryOp operation type: 8 = LOG
            int op_type = 8;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "LRN")
        {
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            float beta = get_node_attr_f(node, "beta", 0.5f);
            float bias = get_node_attr_f(node, "bias", 1.f);
            int size = get_node_attr_i(node, "size", 1);

            // ONNX LRN normalizes across channels only
            int norm_region = 0;

            // 0=region, 1=local_size, 2=alpha, 3=beta, 4=bias
            fprintf(pp, " 0=%d", norm_region);
            fprintf(pp, " 1=%d", size);
            fprintf(pp, " 2=%e", alpha);
            fprintf(pp, " 3=%e", beta);
            fprintf(pp, " 4=%e", bias);
        }
        else if (op == "LSTM")
        {
            // ONNX LSTM weights: W = input weights, R = recurrent weights,
            // B = biases, each concatenating the four gates per direction.
            const onnx::TensorProto& W = weights[node.input(1)];
            const onnx::TensorProto& R = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];

            int hidden_size = get_node_attr_i(node, "hidden_size", 0);
            std::string direction = get_node_attr_s(node, "direction");

            // 0=forward, 1=reverse, 2=bidirectional
            int direction_type = 0;
            if (direction == "forward")
            {
                direction_type = 0;
            }
            else if (direction == "reverse")
            {
                direction_type = 1;
            }
            else if (direction == "bidirectional")
            {
                direction_type = 2;
            }

            int weight_data_size = get_tensor_proto_data_size(W);

            fprintf(pp, " 0=%d", hidden_size);
            fprintf(pp, " 1=%d", weight_data_size);
            fprintf(pp, " 2=%d", direction_type);

            int num_directions = direction_type == 2 ? 2 : 1;

            // 0 tag before each blob marks raw fp32 data
            int quantize_tag = 0;

            // reorder num_directions-IOFG-hidden-size to num_directions-IFOG-hidden-size
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // per-gate slice size within one direction
                int weight_data_size_g = get_tensor_proto_data_size(W) / 4 / num_directions;
                const float* wptr = W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data();

                const float* iptr = wptr;
                const float* optr = wptr + weight_data_size_g;
                const float* fptr = wptr + weight_data_size_g * 2;
                const float* gptr = wptr + weight_data_size_g * 3;
                fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                fwrite(optr, sizeof(float), weight_data_size_g, bp);
                fwrite(gptr, sizeof(float), weight_data_size_g, bp);

                // second direction, if present, follows with the same reorder
                if (direction_type == 2)
                {
                    iptr += weight_data_size_g * 4;
                    optr += weight_data_size_g * 4;
                    fptr += weight_data_size_g * 4;
                    gptr += weight_data_size_g * 4;
                    fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(optr, sizeof(float), weight_data_size_g, bp);
                    fwrite(gptr, sizeof(float), weight_data_size_g, bp);
                }
            }

            // reduce xc and hc bias
            // reorder num_directions-IOFG-hidden to num_directions-IFOG-hidden
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // ONNX B packs [Wb(IOFG), Rb(IOFG)] per direction; input and
                // recurrent biases are summed per gate on export.
                int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 4 / num_directions;
                const float* xcbptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
                const float* xiptr = xcbptr;
                const float* xoptr = xcbptr + bias_data_size_g;
                const float* xfptr = xcbptr + bias_data_size_g * 2;
                const float* xgptr = xcbptr + bias_data_size_g * 3;
                const float* hiptr = xcbptr + bias_data_size_g * 4;
                const float* hoptr = xcbptr + bias_data_size_g * 5;
                const float* hfptr = xcbptr + bias_data_size_g * 6;
                const float* hgptr = xcbptr + bias_data_size_g * 7;

                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xiptr[j] + hiptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xfptr[j] + hfptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xoptr[j] + hoptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xgptr[j] + hgptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }

                if (direction_type == 2)
                {
                    xiptr += bias_data_size_g * 8;
                    xoptr += bias_data_size_g * 8;
                    xfptr += bias_data_size_g * 8;
                    xgptr += bias_data_size_g * 8;
                    hiptr += bias_data_size_g * 8;
                    hoptr += bias_data_size_g * 8;
                    hfptr += bias_data_size_g * 8;
                    hgptr += bias_data_size_g * 8;

                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xiptr[j] + hiptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xfptr[j] + hfptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xoptr[j] + hoptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xgptr[j] + hgptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                }
            }

            // reorder num_directions-IOFG-hidden-hidden to num_directions-IFOG-hidden-hidden
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                int weight_data_size_g = get_tensor_proto_data_size(R) / 4 / num_directions;
                const float* rptr = R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data();

                const float* iptr = rptr;
                const float* optr = rptr + weight_data_size_g;
                const float* fptr = rptr + weight_data_size_g * 2;
                const float* gptr = rptr + weight_data_size_g * 3;
                fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                fwrite(optr, sizeof(float), weight_data_size_g, bp);
                fwrite(gptr, sizeof(float), weight_data_size_g, bp);

                if (direction_type == 2)
                {
                    iptr += weight_data_size_g * 4;
                    optr += weight_data_size_g * 4;
                    fptr += weight_data_size_g * 4;
                    gptr += weight_data_size_g * 4;
                    fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(optr, sizeof(float), weight_data_size_g, bp);
                    fwrite(gptr, sizeof(float), weight_data_size_g, bp);
                }
            }
        }
        else if (op == "MatMul")
        {
            // A 2-D constant right operand turns the MatMul into a
            // fully-connected (InnerProduct) layer with baked-in weights.
            if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
            {
                // InnerProduct
                const onnx::TensorProto& B = weights[node.input(1)];

                int weight_data_size = get_tensor_proto_data_size(B);

                // B is (num_input, num_output): last dim is the output width
                int num_output = B.dims(B.dims_size() - 1);
                int num_input = weight_data_size / num_output;

                // 0=num_output, 1=bias_term (none), 2=weight_data_size
                fprintf(pp, " 0=%d", num_output);
                fprintf(pp, " 1=0");
                fprintf(pp, " 2=%d", weight_data_size);

                // 0 tag marks raw fp32 weight data
                int quantize_tag = 0;
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // reorder num_input-num_output to num_output-num_input
                {
                    const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();

                    // transpose element-by-element into the ncnn layout
                    for (int j = 0; j < num_output; j++)
                    {
                        for (int k = 0; k < num_input; k++)
                        {
                            float vb = bptr[k * num_output + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }

                // fwrite_tensor_proto_data(B, bp)
            }
            else
            {
                // default matrix multiplication
            }
        }
        else if (op == "Max")
        {
            // ncnn BinaryOp operation type: 4 = MAX
            int op_type = 4;
            fprintf(pp, " 0=%d", op_type);

            // fold a scalar (0-dim) constant operand into the layer:
            // 1=with_scalar, 2=scalar value
            if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0)
            {
                float b = get_node_attr_from_input_f(weights[node.input(1)]);
                fprintf(pp, " 1=1");
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "Min")
        {
            // ncnn BinaryOp operation type: 5 = MIN
            int op_type = 5;
            fprintf(pp, " 0=%d", op_type);

            if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0)
            {
                float b = get_node_attr_from_input_f(weights[node.input(1)]);
                fprintf(pp, " 1=1");
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "Mul")
        {
            // ncnn BinaryOp operation type: 2 = MUL
            int op_type = 2;
            fprintf(pp, " 0=%d", op_type);

            if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0)
            {
                float b = get_node_attr_from_input_f(weights[node.input(1)]);
                fprintf(pp, " 1=1");
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "Neg")
        {
            // ncnn UnaryOp operation type: 1 = NEG
            int op_type = 1;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Normalize")
        {
            float eps = get_node_attr_f(node, "eps", 0.f);
            // single shared scale value of 1.0 (no learned scaling)
            int scale_data_size = 1;

            fprintf(pp, " 1=1"); // channel_shared
            fprintf(pp, " 2=%e", eps);
            fprintf(pp, " 3=%d", scale_data_size);
            fprintf(pp, " 9=1"); // TODO hardcode pytorch style

            const float scale_data[1] = {1.f};
            fwrite(scale_data, sizeof(float), 1, bp);
        }
        else if (op == "Pad")
        {
            std::string mode = get_node_attr_s(node, "mode");
            float value = get_node_attr_f(node, "value", 0.f);

            // opset < 11 carries pads as an attribute; opset >= 11 moves them
            // to a second (constant) input
            std::vector<int> pads;
            if (node.input_size() == 1)
            {
                pads = get_node_attr_ai(node, "pads");
            }
            else
            {
                pads = get_node_attr_from_input_ai(weights[node.input(1)]);
            }

            // ncnn padding type: 0=constant, 1=replicate edge, 2=reflect
            int type = 0;
            if (mode == "constant")
            {
                type = 0;
            }
            else if (mode == "edge")
            {
                type = 1;
            }
            else if (mode == "reflect")
            {
                type = 2;
            }

            // ONNX pads layout is [begin per axis..., end per axis...];
            // pick the spatial entries depending on input rank
            int pad_size = (int)pads.size();
            int top = 0;
            int bottom = 0;
            int left = 0;
            int right = 0;
            int front = 0;
            int behind = 0;
            if (pad_size == 8)
            {
                //NCHW
                top = pads[2];
                bottom = pads[6];
                left = pads[3];
                right = pads[7];
                front = pads[1];
                behind = pads[5];
            }
            else if (pad_size == 6)
            {
                //NHW
                top = pads[1];
                bottom = pads[4];
                left = pads[2];
                right = pads[5];
            }
            else
            {
                //NW
                // NOTE(review): this assumes pad_size == 4 ([n_begin, w_begin,
                // n_end, w_end]); a smaller pads vector would index out of
                // bounds here — confirm callers always supply rank >= 2 pads.
                left = pads[1];
                right = pads[3];
            }

            // 0=top, 1=bottom, 2=left, 3=right, 4=type, 5=pad value,
            // 7=front (channel begin), 8=behind (channel end)
            fprintf(pp, " 0=%d", top);
            fprintf(pp, " 1=%d", bottom);
            fprintf(pp, " 2=%d", left);
            fprintf(pp, " 3=%d", right);
            fprintf(pp, " 4=%d", type);
            fprintf(pp, " 5=%e", value);
            fprintf(pp, " 7=%d", front);
            fprintf(pp, " 8=%d", behind);
4235         }
4236         else if (op == "Pow")
4237         {
4238             int op_type = 6;
4239             fprintf(pp, " 0=%d", op_type);
4240 
4241             if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0)
4242             {
4243                 float b = get_node_attr_from_input_f(weights[node.input(1)]);
4244                 fprintf(pp, " 1=1");
4245                 fprintf(pp, " 2=%e", b);
4246             }
4247         }
4248         else if (op == "PixelShuffle")
4249         {
4250             int scale_factor = get_node_attr_i(node, "scale_factor", 1);
4251             fprintf(pp, " 0=%d", scale_factor);
4252         }
4253         else if (op == "PRelu")
4254         {
4255             const onnx::TensorProto& slope = weights[node.input(1)];
4256 
4257             int num_slope = get_tensor_proto_data_size(slope);
4258 
4259             fprintf(pp, " 0=%d", num_slope);
4260 
4261             fwrite_tensor_proto_data(slope, bp);
4262         }
4263         else if (op == "Reciprocal")
4264         {
4265             int op_type = 15;
4266             fprintf(pp, " 0=%d", op_type);
4267         }
4268         else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp")
4269         {
4270             int op_type = -233;
4271             if (op == "ReduceSum")
4272                 op_type = 0;
4273             else if (op == "ReduceSumSquare")
4274                 op_type = 2;
4275             else if (op == "ReduceMean")
4276                 op_type = 3;
4277             else if (op == "ReduceMax")
4278                 op_type = 4;
4279             else if (op == "ReduceMin")
4280                 op_type = 5;
4281             else if (op == "ReduceProd")
4282                 op_type = 6;
4283             else if (op == "ReduceL1")
4284                 op_type = 7;
4285             else if (op == "ReduceL2")
4286                 op_type = 8;
4287             else if (op == "ReduceLogSum")
4288                 op_type = 9;
4289             else if (op == "ReduceLogSumExp")
4290                 op_type = 10;
4291             fprintf(pp, " 0=%d", op_type);
4292 
4293             std::vector<int> axes = get_node_attr_ai(node, "axes");
4294             int keepdims = get_node_attr_i(node, "keepdims", 1);
4295 
4296             if (axes.size() > 0)
4297             {
4298                 // if axes set, reduce according to axes
4299                 fprintf(pp, " 1=%d", 0);
4300                 fprintf(pp, " -23303=%zu", axes.size());
4301                 for (size_t j = 0; j < axes.size(); j++)
4302                 {
4303                     if (axes[j] == 0 || axes[j] > 3 || axes[j] < -3)
4304                         fprintf(stderr, "Unsupported reduction axes !\n");
4305                     fprintf(pp, ",%d", axes[j]);
4306                 }
4307             }
4308             else
4309             {
4310                 // if axes not set, reduce all axes by default
4311                 fprintf(pp, " 1=%d", 1);
4312             }
4313             fprintf(pp, " 4=%d", keepdims);
4314         }
4315         else if (op == "Reorg")
4316         {
4317             int stride = get_node_attr_i(node, "stride", 1);
4318             fprintf(pp, " 0=%d", stride);
4319         }
4320         else if (op == "Reshape")
4321         {
4322             std::vector<int> shape;
4323 
4324             if (node.input_size() == 1)
4325             {
4326                 shape = get_node_attr_ai(node, "shape");
4327             }
4328             else
4329             {
4330                 shape = get_node_attr_from_input_ai(weights[node.input(1)]);
4331             }
4332 
4333             if (shape.size() == 1)
4334             {
4335                 fprintf(pp, " 0=%d", shape[0]); // should never reach here
4336             }
4337             else if (shape.size() == 2)
4338             {
4339                 fprintf(pp, " 0=%d", shape[1]);
4340             }
4341             else if (shape.size() == 3)
4342             {
4343                 fprintf(pp, " 0=%d", shape[2]);
4344                 fprintf(pp, " 1=%d", shape[1]);
4345             }
4346             else if (shape.size() == 4)
4347             {
4348                 fprintf(pp, " 0=%d", shape[3]);
4349                 fprintf(pp, " 1=%d", shape[2]);
4350                 fprintf(pp, " 2=%d", shape[1]);
4351             }
4352             else if (shape.size() == 5)
4353             {
4354                 fprintf(pp, " 0=%d", shape[4] * shape[3]);
4355                 fprintf(pp, " 1=%d", shape[2]);
4356                 fprintf(pp, " 2=%d", shape[1]);
4357             }
4358         }
4359         else if (op == "Resize")
4360         {
4361             std::string mode = get_node_attr_s(node, "mode");
4362             std::string align = get_node_attr_s(node, "coordinate_transformation_mode");
4363 
4364             std::vector<float> scales;
4365             std::vector<int> sizes;
4366             if (node.input_size() == 2)
4367             {
4368                 // opset 10
4369                 scales = get_node_attr_from_input_af(weights[node.input(1)]);
4370             }
4371             else
4372             {
4373                 // opset 11+
4374                 scales = get_node_attr_from_input_af(weights[node.input(2)]);
4375                 if (node.input_size() >= 4)
4376                 {
4377                     sizes = get_node_attr_from_input_ai(weights[node.input(3)]);
4378                 }
4379             }
4380 
4381             int resize_type = 1;
4382             if (mode == "nearest")
4383             {
4384                 resize_type = 1;
4385             }
4386             else if (mode == "linear")
4387             {
4388                 resize_type = 2;
4389             }
4390             else if (mode == "cubic")
4391             {
4392                 resize_type = 3;
4393             }
4394 
4395             if (scales.empty() && sizes.empty())
4396             {
4397                 fprintf(stderr, "Unsupported Resize scales and sizes are all empty!\n");
4398             }
4399 
4400             float h_scale = 1.f;
4401             float w_scale = 1.f;
4402             if (scales.size() == 2)
4403             {
4404                 w_scale = scales[1];
4405             }
4406             else if (scales.size() == 3)
4407             {
4408                 h_scale = scales[1];
4409                 w_scale = scales[2];
4410             }
4411             else if (scales.size() == 4)
4412             {
4413                 h_scale = scales[2];
4414                 w_scale = scales[3];
4415 
4416                 if (scales[1] != 1.f)
4417                     fprintf(stderr, "Unsupported Resize scales !\n");
4418             }
4419 
4420             int output_height = 0;
4421             int output_width = 0;
4422             if (sizes.size() == 2)
4423             {
4424                 output_width = sizes[1];
4425             }
4426             else if (sizes.size() == 3)
4427             {
4428                 output_height = sizes[1];
4429                 output_width = sizes[2];
4430             }
4431             else if (sizes.size() == 4)
4432             {
4433                 output_height = sizes[2];
4434                 output_width = sizes[3];
4435             }
4436 
4437             int align_corner = 0;
4438             if (align == "align_corners")
4439             {
4440                 align_corner = 1;
4441             }
4442 
4443             fprintf(pp, " 0=%d", resize_type);
4444             fprintf(pp, " 1=%e", h_scale);
4445             fprintf(pp, " 2=%e", w_scale);
4446             fprintf(pp, " 3=%d", output_height);
4447             fprintf(pp, " 4=%d", output_width);
4448             fprintf(pp, " 6=%d", align_corner);
4449         }
4450         else if (op == "RNN")
4451         {
4452             const onnx::TensorProto& W = weights[node.input(1)];
4453             const onnx::TensorProto& R = weights[node.input(2)];
4454             const onnx::TensorProto& B = weights[node.input(3)];
4455 
4456             int hidden_size = get_node_attr_i(node, "hidden_size", 0);
4457             std::string direction = get_node_attr_s(node, "direction");
4458 
4459             int direction_type = 0;
4460             if (direction == "forward")
4461             {
4462                 direction_type = 0;
4463             }
4464             else if (direction == "reverse")
4465             {
4466                 direction_type = 1;
4467             }
4468             else if (direction == "bidirectional")
4469             {
4470                 direction_type = 2;
4471             }
4472 
4473             int weight_data_size = get_tensor_proto_data_size(W);
4474 
4475             fprintf(pp, " 0=%d", hidden_size);
4476             fprintf(pp, " 1=%d", weight_data_size);
4477             fprintf(pp, " 2=%d", direction_type);
4478 
4479             int num_directions = direction_type == 2 ? 2 : 1;
4480 
4481             int quantize_tag = 0;
4482 
4483             fwrite(&quantize_tag, sizeof(int), 1, bp);
4484             fwrite_tensor_proto_data(W, bp);
4485 
4486             // reduce xc and hc bias
4487             {
4488                 fwrite(&quantize_tag, sizeof(int), 1, bp);
4489 
4490                 int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / num_directions;
4491                 const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
4492                 const float* xiptr = bptr;
4493                 const float* hiptr = bptr + bias_data_size_g;
4494 
4495                 for (int j = 0; j < bias_data_size_g; j++)
4496                 {
4497                     float vb = xiptr[j] + hiptr[j];
4498                     fwrite(&vb, sizeof(float), 1, bp);
4499                 }
4500 
4501                 if (direction_type == 2)
4502                 {
4503                     xiptr += bias_data_size_g * 2;
4504                     hiptr += bias_data_size_g * 2;
4505 
4506                     for (int j = 0; j < bias_data_size_g; j++)
4507                     {
4508                         float vb = xiptr[j] + hiptr[j];
4509                         fwrite(&vb, sizeof(float), 1, bp);
4510                     }
4511                 }
4512             }
4513 
4514             fwrite(&quantize_tag, sizeof(int), 1, bp);
4515             fwrite_tensor_proto_data(R, bp);
4516         }
4517         else if (op == "ShuffleChannel")
4518         {
4519             int group = get_node_attr_i(node, "group", 1);
4520             int reverse = get_node_attr_i(node, "reverse", 0);
4521             fprintf(pp, " 0=%d", group);
4522             fprintf(pp, " 1=%d", reverse);
4523         }
4524         else if (op == "Sigmoid")
4525         {
4526         }
4527         else if (op == "Sin")
4528         {
4529             int op_type = 9;
4530             fprintf(pp, " 0=%d", op_type);
4531         }
4532         else if (op == "SkipLayerNormalization")
4533         {
4534             const onnx::TensorProto& W = weights[node.input(2)];
4535             const onnx::TensorProto& B = weights[node.input(3)];
4536             const onnx::TensorProto& B2 = weights[node.input(4)];
4537 
4538             fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
4539 
4540             int quantize_tag = 0;
4541             fwrite(&quantize_tag, sizeof(int), 1, bp);
4542 
4543             fwrite_tensor_proto_data(W, bp);
4544 
4545             fwrite(&quantize_tag, sizeof(int), 1, bp);
4546 
4547             fwrite_tensor_proto_data(B, bp);
4548 
4549             fwrite(&quantize_tag, sizeof(int), 1, bp);
4550 
4551             fwrite_tensor_proto_data(B2, bp);
4552         }
4553         else if (op == "Slice")
4554         {
4555             std::vector<int> starts;
4556             std::vector<int> ends;
4557             std::vector<int> axes;
4558             std::vector<int> steps;
4559             if (node.input_size() == 1)
4560             {
4561                 starts = get_node_attr_ai(node, "starts");
4562                 ends = get_node_attr_ai(node, "ends");
4563                 axes = get_node_attr_ai(node, "axes");
4564                 steps = get_node_attr_ai(node, "steps"); // TODO
4565             }
4566             else
4567             {
4568                 starts = get_node_attr_from_input_ai(weights[node.input(1)]);
4569                 ends = get_node_attr_from_input_ai(weights[node.input(2)]);
4570                 if (node.input_size() >= 4)
4571                     axes = get_node_attr_from_input_ai(weights[node.input(3)]);
4572                 if (node.input_size() >= 5)
4573                     steps = get_node_attr_from_input_ai(weights[node.input(4)]);
4574             }
4575 
4576             // assert step == 1
4577             for (int i = 0; i < (int)steps.size(); i++)
4578             {
4579                 if (steps[i] != 1)
4580                     fprintf(stderr, "Unsupported slice step !\n");
4581             }
4582 
4583             // filter out N-dim axis
4584             if (!axes.empty())
4585             {
4586                 for (int i = 0; i < (int)axes.size(); i++)
4587                 {
4588                     int axis = axes[i];
4589                     if (axis == 0)
4590                     {
4591                         starts.erase(starts.begin() + i);
4592                         ends.erase(ends.begin() + i);
4593                         axes.erase(axes.begin() + i);
4594                         break;
4595                     }
4596                 }
4597             }
4598 
4599             fprintf(pp, " -23309=%d", (int)starts.size());
4600             for (int i = 0; i < (int)starts.size(); i++)
4601             {
4602                 fprintf(pp, ",%d", starts[i]);
4603             }
4604             fprintf(pp, " -23310=%d", (int)ends.size());
4605             for (int i = 0; i < (int)ends.size(); i++)
4606             {
4607                 fprintf(pp, ",%d", ends[i]);
4608             }
4609             if (!axes.empty())
4610             {
4611                 fprintf(pp, " -23311=%d", (int)axes.size());
4612                 for (int i = 0; i < (int)axes.size(); i++)
4613                 {
4614                     int axis = axes[i];
4615                     if (axis == 0 || axis > 3 || axis < -3)
4616                         fprintf(stderr, "Unsupported slice axes !\n");
4617 
4618                     if (axis > 0)
4619                         axis = axis - 1; // -1 for skip N-dim
4620 
4621                     fprintf(pp, ",%d", axis);
4622                 }
4623             }
4624         }
4625         else if (op == "Softmax")
4626         {
4627             int axis = get_node_attr_i(node, "axis", 1);
4628             fprintf(pp, " 0=%d", axis - 1);
4629             fprintf(pp, " 1=1");
4630         }
4631         else if (op == "Split")
4632         {
4633             int axis = get_node_attr_i(node, "axis", 0);
4634             std::vector<int> split = get_node_attr_ai(node, "split");
4635             if (axis < 1)
4636                 fprintf(stderr, "Unsupported split axis !\n");
4637 
4638             fprintf(pp, " -23300=%d", output_size);
4639             if (split.empty())
4640             {
4641                 for (int i = 0; i < output_size; i++)
4642                 {
4643                     fprintf(pp, ",-233");
4644                 }
4645             }
4646             else
4647             {
4648                 for (size_t i = 0; i < split.size() - 1; i++)
4649                 {
4650                     fprintf(pp, ",%d", split[i]);
4651                 }
4652                 fprintf(pp, ",-233");
4653             }
4654             fprintf(pp, " 1=%d", axis - 1);
4655         }
4656         else if (op == "Sqrt")
4657         {
4658             int op_type = 5;
4659             fprintf(pp, " 0=%d", op_type);
4660         }
4661         else if (op == "Squeeze")
4662         {
4663             std::vector<int> axes = get_node_attr_ai(node, "axes");
4664 
4665             if (axes.empty())
4666             {
4667                 fprintf(pp, " 0=1");
4668                 fprintf(pp, " 1=1");
4669                 fprintf(pp, " 2=1");
4670             }
4671             else
4672             {
4673                 fprintf(pp, " -23303=%zu", axes.size());
4674                 for (int i = 0; i < (int)axes.size(); i++)
4675                 {
4676                     if (axes[i] == 0 || axes[i] > 3 || axes[i] < -3)
4677                         fprintf(stderr, "Unsupported squeeze axes !\n");
4678                     fprintf(pp, ",%d", axes[i]);
4679                 }
4680             }
4681         }
4682         else if (op == "Sub")
4683         {
4684             int op_type = 1;
4685             fprintf(pp, " 0=%d", op_type);
4686 
4687             if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0)
4688             {
4689                 float b = get_node_attr_from_input_f(weights[node.input(1)]);
4690                 fprintf(pp, " 1=1");
4691                 fprintf(pp, " 2=%e", b);
4692             }
4693         }
4694         else if (op == "Sum")
4695         {
4696             int op_type = 1;
4697             fprintf(pp, " 0=%d", op_type);
4698         }
4699         else if (op == "Swish")
4700         {
4701         }
4702         else if (op == "Tan")
4703         {
4704             int op_type = 11;
4705             fprintf(pp, " 0=%d", op_type);
4706         }
4707         else if (op == "Tanh")
4708         {
4709             int op_type = 16;
4710             fprintf(pp, " 0=%d", op_type);
4711         }
4712         else if (op == "Transpose")
4713         {
4714             std::vector<int> perm = get_node_attr_ai(node, "perm");
4715 
4716             if (perm.size() == 3)
4717             {
4718                 if (perm[1] == 1 && perm[2] == 2)
4719                     fprintf(pp, " 0=0"); // w h
4720                 else if (perm[1] == 2 && perm[2] == 1)
4721                     fprintf(pp, " 0=1"); // h w
4722                 else if (perm[0] == 1 && perm[1] == 0 && perm[2] == 2)
4723                     fprintf(pp, " 0=0"); // w h
4724                 else if (perm[0] == 2 && perm[1] == 0 && perm[2] == 1)
4725                     fprintf(pp, " 0=1"); // h w
4726             }
4727             else if (perm.size() == 4)
4728             {
4729                 if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3)
4730                     fprintf(pp, " 0=0"); // w h c
4731                 else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 2)
4732                     fprintf(pp, " 0=1"); // h w c
4733                 else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3)
4734                     fprintf(pp, " 0=2"); // w c h
4735                 else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 1)
4736                     fprintf(pp, " 0=3"); // c w h
4737                 else if (perm[1] == 3 && perm[2] == 1 && perm[3] == 2)
4738                     fprintf(pp, " 0=4"); // h c w
4739                 else if (perm[1] == 3 && perm[2] == 2 && perm[3] == 1)
4740                     fprintf(pp, " 0=5"); // c h w
4741             }
4742             else if (perm.size() == 5)
4743             {
4744                 if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3 && perm[4] == 4)
4745                     fprintf(pp, " 0=0"); // wx h c
4746                 else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 4 && perm[4] == 2)
4747                     fprintf(pp, " 0=1"); // h wx c
4748                 else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3 && perm[4] == 4)
4749                     fprintf(pp, " 0=2"); // wx c h
4750                 else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 4 && perm[4] == 1)
4751                     fprintf(pp, " 0=3"); // c wx h
4752                 else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 1 && perm[4] == 2)
4753                     fprintf(pp, " 0=4"); // h c wx
4754                 else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 2 && perm[4] == 1)
4755                     fprintf(pp, " 0=5"); // c h wx
4756                 else
4757                     fprintf(stderr, "Unsupported transpose type !\n");
4758             }
4759         }
4760         else if (op == "Upsample")
4761         {
4762             std::string mode = get_node_attr_s(node, "mode");
4763             std::string align = get_node_attr_s(node, "coordinate_transformation_mode");
4764 
4765             std::vector<float> scales;
4766 
4767             if (node.input_size() == 1)
4768             {
4769                 scales = get_node_attr_af(node, "scales");
4770             }
4771             else
4772             {
4773                 scales = get_node_attr_from_input_af(weights[node.input(1)]);
4774             }
4775 
4776             int resize_type = 1;
4777             if (mode == "nearest")
4778             {
4779                 resize_type = 1;
4780             }
4781             else if (mode == "bilinear" || mode == "linear")
4782             {
4783                 resize_type = 2;
4784             }
4785             else if (mode == "trilinear")
4786             {
4787                 fprintf(stderr, "Unsupported Upsample mode !\n");
4788             }
4789 
4790             float h_scale = 1.f;
4791             float w_scale = 1.f;
4792             if (scales.size() == 2)
4793             {
4794                 w_scale = scales[1];
4795             }
4796             else if (scales.size() == 3)
4797             {
4798                 h_scale = scales[1];
4799                 w_scale = scales[2];
4800             }
4801             else if (scales.size() == 4)
4802             {
4803                 h_scale = scales[2];
4804                 w_scale = scales[3];
4805 
4806                 if (scales[1] != 1.f)
4807                     fprintf(stderr, "Unsupported Upsample scales !\n");
4808             }
4809             else
4810             {
4811                 fprintf(stderr, "Unsupported Upsample scales !\n");
4812             }
4813 
4814             int align_corner = 0;
4815             if (align == "align_corners")
4816             {
4817                 align_corner = 1;
4818             }
4819 
4820             fprintf(pp, " 0=%d", resize_type);
4821             fprintf(pp, " 1=%e", h_scale);
4822             fprintf(pp, " 2=%e", w_scale);
4823             fprintf(pp, " 6=%d", align_corner);
4824         }
4825         else if (op == "Unsqueeze")
4826         {
4827             std::vector<int> axes = get_node_attr_ai(node, "axes");
4828 
4829             fprintf(pp, " -23303=%zu", axes.size());
4830             for (int i = 0; i < (int)axes.size(); i++)
4831             {
4832                 if (axes[i] == 0 || axes[i] > 4 || axes[i] < -4)
4833                     fprintf(stderr, "Unsupported unsqueeze axes !\n");
4834                 fprintf(pp, ",%d", axes[i]);
4835             }
4836         }
4837         else
4838         {
4839             // TODO op specific param
4840             for (int j = 0; j < node.attribute_size(); j++)
4841             {
4842                 const onnx::AttributeProto& attr = node.attribute(j);
4843                 if (attr.type() == 1)
4844                 {
4845                     fprintf(stderr, "  # %s=%g\n", attr.name().c_str(), attr.f());
4846                 }
4847                 else if (attr.type() == 2)
4848                 {
4849                     fprintf(stderr, "  # %s=%lld\n", attr.name().c_str(), (long long)attr.i());
4850                 }
4851                 else if (attr.type() == 3)
4852                 {
4853                     fprintf(stderr, "  # %s=%s\n", attr.name().c_str(), attr.s().c_str());
4854                 }
4855                 else
4856                 {
4857                     fprintf(stderr, "  # %s %d\n", attr.name().c_str(), attr.type());
4858                 }
4859             }
4860         }
4861 
4862         fprintf(pp, "\n");
4863 
4864         for (int j = 0; j < output_size; j++)
4865         {
4866             const std::string& output_name = node.output(j);
4867             if (node_reference.find(output_name) != node_reference.end())
4868             {
4869                 int refcount = node_reference[output_name];
4870                 if (refcount > 1)
4871                 {
4872                     char splitname[256];
4873                     sprintf(splitname, "splitncnn_%d", internal_split);
4874                     fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
4875 
4876                     fprintf(pp, " %s", output_name.c_str());
4877 
4878                     for (int k = 0; k < refcount; k++)
4879                     {
4880                         fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k);
4881                     }
4882                     fprintf(pp, "\n");
4883 
4884                     internal_split++;
4885                 }
4886             }
4887         }
4888     }
4889 
4890     fclose(pp);
4891     fclose(bp);
4892 
4893     return 0;
4894 }
4895