1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "onnx.pb.h"
16 
17 #include <algorithm>
18 #include <float.h>
19 #include <fstream>
20 #include <google/protobuf/io/coded_stream.h>
21 #include <google/protobuf/io/zero_copy_stream_impl.h>
22 #include <google/protobuf/message.h>
23 #include <google/protobuf/text_format.h>
24 #include <iostream>
25 #include <limits.h>
26 #include <limits>
27 #include <set>
28 #include <stdio.h>
29 
read_proto_from_binary(const char * filepath,onnx::ModelProto * message)30 static bool read_proto_from_binary(const char* filepath, onnx::ModelProto* message)
31 {
32     std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
33     if (!fs.is_open())
34     {
35         fprintf(stderr, "open failed %s\n", filepath);
36         return false;
37     }
38 
39     google::protobuf::io::IstreamInputStream input(&fs);
40     google::protobuf::io::CodedInputStream codedstr(&input);
41 
42 #if GOOGLE_PROTOBUF_VERSION >= 3011000
43     codedstr.SetTotalBytesLimit(INT_MAX);
44 #else
45     codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);
46 #endif
47 
48     bool success = message->ParseFromCodedStream(&codedstr);
49 
50     fs.close();
51 
52     return success;
53 }
54 
get_node_attr_ai(const onnx::NodeProto & node,const char * key)55 static std::vector<int> get_node_attr_ai(const onnx::NodeProto& node, const char* key)
56 {
57     std::vector<int> v;
58 
59     for (int i = 0; i < node.attribute_size(); i++)
60     {
61         const onnx::AttributeProto& attr = node.attribute(i);
62         if (attr.name() == key)
63         {
64             v.resize(attr.ints_size());
65             for (int j = 0; j < attr.ints_size(); j++)
66             {
67                 v[j] = std::max(std::min(attr.ints(j), (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
68             }
69 
70             break;
71         }
72     }
73 
74     return v;
75 }
76 
get_node_attr_af(const onnx::NodeProto & node,const char * key)77 static std::vector<float> get_node_attr_af(const onnx::NodeProto& node, const char* key)
78 {
79     std::vector<float> v;
80 
81     for (int i = 0; i < node.attribute_size(); i++)
82     {
83         const onnx::AttributeProto& attr = node.attribute(i);
84         if (attr.name() == key)
85         {
86             v.resize(attr.floats_size());
87             for (int j = 0; j < attr.floats_size(); j++)
88             {
89                 v[j] = attr.floats(j);
90             }
91 
92             break;
93         }
94     }
95 
96     return v;
97 }
98 
get_node_attr_i(const onnx::NodeProto & node,const char * key,int def=0)99 static int get_node_attr_i(const onnx::NodeProto& node, const char* key, int def = 0)
100 {
101     for (int i = 0; i < node.attribute_size(); i++)
102     {
103         const onnx::AttributeProto& attr = node.attribute(i);
104         if (attr.name() == key)
105         {
106             return std::max(std::min(attr.i(), (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
107         }
108     }
109 
110     return def;
111 }
112 
get_node_attr_f(const onnx::NodeProto & node,const char * key,float def=0.f)113 static float get_node_attr_f(const onnx::NodeProto& node, const char* key, float def = 0.f)
114 {
115     for (int i = 0; i < node.attribute_size(); i++)
116     {
117         const onnx::AttributeProto& attr = node.attribute(i);
118         if (attr.name() == key)
119         {
120             return attr.f();
121         }
122     }
123 
124     return def;
125 }
126 
get_node_attr_s(const onnx::NodeProto & node,const char * key,const std::string & def=std::string ())127 static std::string get_node_attr_s(const onnx::NodeProto& node, const char* key, const std::string& def = std::string())
128 {
129     for (int i = 0; i < node.attribute_size(); i++)
130     {
131         const onnx::AttributeProto& attr = node.attribute(i);
132         if (attr.name() == key)
133         {
134             return attr.s();
135         }
136     }
137 
138     return def;
139 }
140 
// Fetch the tensor attribute named key from node.
// Returns a default-constructed (empty) TensorProto when absent.
static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char* key)
{
    const int attr_count = node.attribute_size();
    for (int i = 0; i < attr_count; i++)
    {
        const onnx::AttributeProto& attr = node.attribute(i);
        if (attr.name() != key)
            continue;

        return attr.t();
    }

    return onnx::TensorProto();
}
154 
get_node_attr_from_input_f(const onnx::TensorProto & tp)155 static float get_node_attr_from_input_f(const onnx::TensorProto& tp)
156 {
157     float v = 0.f;
158 
159     // float
160     if (tp.data_type() == 1)
161     {
162         const float* shape_data = 0;
163         if (tp.has_raw_data())
164         {
165             shape_data = (const float*)tp.raw_data().data();
166         }
167         else
168         {
169             shape_data = tp.float_data().data();
170         }
171         v = shape_data[0];
172     }
173     // double
174     else if (tp.data_type() == 11)
175     {
176         const double* shape_data = 0;
177         if (tp.has_raw_data())
178         {
179             shape_data = (const double*)tp.raw_data().data();
180         }
181         else
182         {
183             shape_data = tp.double_data().data();
184         }
185         v = shape_data[0];
186     }
187     // int64
188     else if (tp.data_type() == 7)
189     {
190         const int64_t* shape_data = 0;
191         if (tp.has_raw_data())
192         {
193             shape_data = (const int64_t*)tp.raw_data().data();
194         }
195         else
196         {
197             shape_data = tp.int64_data().data();
198         }
199         v = std::max(std::min(shape_data[0], (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
200     }
201     // int32
202     else if (tp.data_type() == 6)
203     {
204         const int32_t* shape_data = 0;
205         if (tp.has_raw_data())
206         {
207             shape_data = (const int32_t*)tp.raw_data().data();
208         }
209         else
210         {
211             shape_data = tp.int32_data().data();
212         }
213         v = shape_data[0];
214     }
215     else
216     {
217         fprintf(stderr, "Unknown data type %d\n", tp.data_type());
218         abort();
219     }
220 
221     return v;
222 }
223 
get_node_attr_from_input_ai(const onnx::TensorProto & tp)224 static std::vector<int> get_node_attr_from_input_ai(const onnx::TensorProto& tp)
225 {
226     int size = 0;
227 
228     std::vector<int> v;
229 
230     // int64
231     if (tp.data_type() == 7)
232     {
233         const int64_t* shape_data = 0;
234         if (tp.has_raw_data())
235         {
236             shape_data = (const int64_t*)tp.raw_data().data();
237             size = (int)(tp.raw_data().size() / 8);
238         }
239         else
240         {
241             shape_data = tp.int64_data().data();
242             size = tp.int64_data_size();
243         }
244         for (int j = 0; j < size; j++)
245         {
246             int vi = std::max(std::min(shape_data[j], (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
247             v.push_back(vi);
248         }
249     }
250     // int32
251     else if (tp.data_type() == 6)
252     {
253         const int32_t* shape_data = 0;
254         if (tp.has_raw_data())
255         {
256             shape_data = (const int32_t*)tp.raw_data().data();
257             size = (int)(tp.raw_data().size() / 4);
258         }
259         else
260         {
261             shape_data = tp.int32_data().data();
262             size = tp.int32_data_size();
263         }
264         for (int j = 0; j < size; j++)
265         {
266             v.push_back(shape_data[j]);
267         }
268     }
269     else
270     {
271         fprintf(stderr, "Unknown data type %d\n", tp.data_type());
272     }
273 
274     return v;
275 }
276 
get_node_attr_from_input_af(const onnx::TensorProto & tp)277 static std::vector<float> get_node_attr_from_input_af(const onnx::TensorProto& tp)
278 {
279     int size = 0;
280 
281     std::vector<float> v;
282 
283     // float
284     if (tp.data_type() == 1)
285     {
286         const float* shape_data = 0;
287         if (tp.has_raw_data())
288         {
289             shape_data = (const float*)tp.raw_data().data();
290             size = (int)(tp.raw_data().size() / 4);
291         }
292         else
293         {
294             shape_data = tp.float_data().data();
295             size = tp.float_data_size();
296         }
297         for (int j = 0; j < size; j++)
298         {
299             v.push_back(shape_data[j]);
300         }
301     }
302     // double
303     else if (tp.data_type() == 11)
304     {
305         const double* shape_data = 0;
306         if (tp.has_raw_data())
307         {
308             shape_data = (const double*)tp.raw_data().data();
309             size = (int)(tp.raw_data().size() / 8);
310         }
311         else
312         {
313             shape_data = tp.double_data().data();
314             size = tp.double_data_size();
315         }
316         for (int j = 0; j < size; j++)
317         {
318             v.push_back((float)shape_data[j]);
319         }
320     }
321     else
322     {
323         fprintf(stderr, "Unknown data type %d\n", tp.data_type());
324     }
325 
326     return v;
327 }
328 
get_tensor_proto_data_size(const onnx::TensorProto & tp)329 static int get_tensor_proto_data_size(const onnx::TensorProto& tp)
330 {
331     if (tp.has_raw_data())
332     {
333         const std::string& raw_data = tp.raw_data();
334         int size = (int)raw_data.size() / 4;
335         return size;
336     }
337     else if (tp.data_type() == 1)
338     {
339         return tp.float_data_size();
340     }
341 
342     return 0;
343 }
344 
fwrite_tensor_proto_data(const onnx::TensorProto & tp,FILE * bp)345 static void fwrite_tensor_proto_data(const onnx::TensorProto& tp, FILE* bp)
346 {
347     int size = get_tensor_proto_data_size(tp);
348 
349     if (tp.has_raw_data())
350     {
351         const std::string& raw_data = tp.raw_data();
352         fwrite(raw_data.data(), sizeof(float), size, bp);
353     }
354     else if (tp.data_type() == 1)
355     {
356         fwrite(tp.float_data().data(), sizeof(float), size, bp);
357     }
358 }
359 
// Fold Reshape nodes whose input is a constant weight: the Reshape output name
// becomes a new weights entry sharing the input's data but with the requested
// dims, and the node itself is neutralized to "noop_reducedncnn". Reference
// counts of the consumed inputs are decremented and reduced_node_count is
// incremented for each removed node. blob_names is unused by this pass.
static void fuse_weight_reshape(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // weight <= Reshape(weight)
        if (node->op_type() == "Reshape")
        {
            // check weight
            if (weights.find(node->input(0)) == weights.end())
                continue;

            // alias the output name to the same tensor data as the input weight
            weights[node->output(0)] = weights[node->input(0)];

            // set weight shape directly
            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                // pre-opset-5 form: shape comes from a node attribute
                shape = get_node_attr_ai(*node, "shape");
            }
            else if (node->input_size() == 2)
            {
                // opset 5
                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            weights[node->output(0)].clear_dims();
            for (int j = 0; j < shape.size(); j++)
            {
                weights[node->output(0)].add_dims(shape[j]);
            }

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            if (node->input_size() == 2)
            {
                // the shape initializer is consumed as well
                node_reference[node->input(1)] -= 1;
            }

            reduced_node_count += 1;
            i += 1; // NOTE(review): this also skips the node after the Reshape -- confirm intended
        }
    }
}
408 
// Fold Transpose nodes applied to a constant 2D weight with perm == (1, 0):
// the weight data is permuted (h x w -> w x h) under the Transpose's output
// name and the node is neutralized to "noop_reducedncnn". Reference counts
// and reduced_node_count are updated accordingly; blob_names is unused here.
static void fuse_weight_transpose(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // weight <= Transpose(weight)
        if (node->op_type() == "Transpose")
        {
            // check weight
            if (weights.find(node->input(0)) == weights.end())
                continue;

            // only 2D weights are handled by this pass
            if (weights[node->input(0)].dims_size() != 2)
                continue;

            // perm = (1, 0)
            std::vector<int> perm = get_node_attr_ai(*node, "perm");
            if (perm.size() != 2)
                continue;
            if (perm[0] != 1 || perm[1] != 0)
                continue;

            weights[node->output(0)] = weights[node->input(0)];

            // permute weight
            {
                onnx::TensorProto& B = weights[node->output(0)];

                const int h = B.dims(0);
                const int w = B.dims(1);

                std::vector<float> permuted_data;
                permuted_data.reserve((size_t)h * w);
                // NOTE(review): raw_data is assumed to hold float32 -- other dtypes are not handled
                const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();

                // column-major walk of the original equals row-major order of the transpose
                for (int j = 0; j < w; j++)
                {
                    for (int k = 0; k < h; k++)
                    {
                        float vb = bptr[k * w + j];
                        permuted_data.push_back(vb);
                    }
                }

                B.set_dims(0, w);
                B.set_dims(1, h);

                if (B.has_raw_data())
                {
                    B.set_raw_data(permuted_data.data(), permuted_data.size() * sizeof(float));
                }
                else
                {
                    // overwrite in place; element count is unchanged by a transpose
                    for (int j = 0; j < (int)permuted_data.size(); j++)
                        B.set_float_data(j, permuted_data[j]);
                }
            }

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;

            reduced_node_count += 1;
            i += 1; // NOTE(review): this also skips the node after the Transpose -- confirm intended
        }
    }
}
479 
// Rewrite the Reshape - Transpose - Reshape pattern (optionally with a
// Constant node between the Transpose and the final Reshape) into a single
// ShuffleChannel node. Two layouts are matched: the regular style
// (1, groups, channels_per_group, h, w) with perm (0,2,1,3,4), and the
// "reverse" style (channels_per_group, groups, h*w) with perm (1,0,2).
// The surviving node carries "group" and "reverse" attributes; the two
// absorbed nodes are neutralized, their blobs erased, and reference counts
// plus reduced_node_count updated.
static void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // ShuffleChannel <= Reshape - Transpose - Reshape
        // ShuffleChannel <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            // the intermediate blob must have exactly one consumer
            if (node_reference[node->output(0)] != 1)
                continue;

            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                // pre-opset-5 form: shape is a node attribute
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end())
                    continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // 1 groups channels_per_group, height, width
            // reverse style = channels_per_group, groups, height * width
            if (shape.size() != 5 && shape.size() != 3)
                continue;

            if (shape.size() == 5 && shape[0] != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            if (node3->op_type() == "Constant")
            {
                // tolerate an interleaved Constant feeding the final Reshape
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // 0 2 1 3 4
            // reverse style = 1 0 2
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 5 && perm.size() != 3)
                continue;

            if (perm.size() == 5 && (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3 || perm[4] != 4))
                continue;

            if (perm.size() == 3 && (perm[0] != 1 || perm[1] != 0 || perm[2] != 2))
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end())
                    continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // 1, -1, height, width
            // reverse style = group, -1, channels_per_group, height, width
            if (shape3.size() != 4 && shape3.size() != 5)
                continue;

            if (shape3.size() == 4 && (shape3[0] != 1 || (shape3[1] != -1 && shape3[1] != shape[1] * shape[2])))
                continue;

            if (shape3.size() == 5 && (shape3[0] != shape[1] || shape3[2] != shape[0] || shape3[3] * shape3[4] != shape[2]))
                continue;

            // reduce: neutralize the first two nodes, repurpose the third
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            node3->set_op_type("ShuffleChannel");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("group");
            attr_group->set_i(shape[1]);

            onnx::AttributeProto* attr_reverse = node3->add_attribute();
            attr_reverse->set_name("reverse");
            // the 3-element shape form is the reverse-style shuffle
            attr_reverse->set_i(shape.size() == 3);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
607 
// Rewrite a reverse-style ShuffleChannel followed by two Gather ops (both on
// axis 0, picking indices 0 and 1 respectively, both reading the shuffle's
// output) into a single Split along axis 1. The second Gather is repurposed
// as the Split with two outputs; the first is neutralized. weights supplies
// the Gather index tensors; blob_names is unused by this pass.
static void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1)
        if (node->op_type() == "ShuffleChannel")
        {
            // reverse = 1
            int reverse = get_node_attr_i(*node, "reverse");
            if (reverse != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            if (node2->op_type() != "Gather" || node3->op_type() != "Gather")
                continue;

            // both Gathers must consume the shuffle output
            if (node2->input(0) != node->output(0) || node3->input(0) != node->output(0))
                continue;

            // axis = 0
            int gather2_axis = get_node_attr_i(*node2, "axis");
            if (gather2_axis != 0)
                continue;

            // indices = 0
            if (weights.find(node2->input(1)) == weights.end())
                continue;

            std::vector<int> gather2_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
            if (gather2_indices.size() != 1 || gather2_indices[0] != 0)
                continue;

            // axis = 0
            int gather3_axis = get_node_attr_i(*node3, "axis");
            if (gather3_axis != 0)
                continue;

            // indices = 1
            if (weights.find(node3->input(1)) == weights.end())
                continue;

            std::vector<int> gather3_indices = get_node_attr_from_input_ai(weights[node3->input(1)]);
            if (gather3_indices.size() != 1 || gather3_indices[0] != 1)
                continue;

            // reduce: drop the first Gather, turn the second into the Split
            node2->set_op_type("noop_reducedncnn");

            node_reference[node->output(0)] -= 2;
            node_reference[node2->input(1)] -= 1;
            node_reference[node3->input(1)] -= 1;

            node3->set_op_type("Split");
            node3->clear_input();
            node3->add_input(node->output(0));
            // duplicate the existing output, then overwrite slot 0 so the
            // Split yields node2's output first and node3's original second
            node3->add_output(node3->output(0));
            node3->set_output(0, node2->output(0));

            node3->clear_attribute();
            onnx::AttributeProto* attr_axis = node3->add_attribute();
            attr_axis->set_name("axis");
            attr_axis->set_i(1);

            reduced_node_count += 1;
            i += 1;
        }
    }
}
684 
// Two-pass HardSwish fusion.
// Pass 1 rewrites out = x * relu6(x + 3) / 6 -- expressed as
// Add(+3) - Clip(0,6) - Mul(X,) - Div(/6) (or Mul(*1/6), optionally with an
// interleaved Constant) -- into a single HardSwish(alpha=1/6, beta=0.5).
// Pass 2 rewrites HardSigmoid - Mul (out = x * hsigmoid(x)) into HardSwish,
// carrying over the HardSigmoid's alpha/beta. Absorbed nodes are neutralized
// to "noop_reducedncnn"; reference counts, blob_names and reduced_node_count
// are updated accordingly.
static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Div(/6)
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Mul(*(1/6))
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Div(/6)
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Mul(*(1/6))
        //     out = x * F.relu6(x + 3, inplace=True) / 6
        if (node->op_type() == "Add")
        {
            // each intermediate blob must have exactly one consumer
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 3 >= node_count)
                continue;

            if (weights.find(node->input(1)) == weights.end())
                continue;

            // the addend must be the scalar constant 3
            const onnx::TensorProto& add_three = weights[node->input(1)];
            if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1)
                continue;

            float constant_add_three = get_node_attr_from_input_f(add_three);
            if (constant_add_three != 3.f)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);

            if (node4->op_type() == "Constant")
            {
                // tolerate an interleaved Constant before the final Div/Mul
                if (i + 4 >= node_count)
                    continue;

                node4 = mutable_graph->mutable_node(i + 4);
            }

            if (node2->op_type() != "Clip" || node3->op_type() != "Mul" || (node4->op_type() != "Div" && node4->op_type() != "Mul"))
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // the Clip must be an exact relu6: min 0, max 6
            float relu6_min;
            float relu6_max;
            if (node2->input_size() == 1)
            {
                // pre-opset-11 Clip: bounds are attributes
                relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
                relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
            }
            else
            {
                // opset-11+ Clip: bounds are constant inputs
                const onnx::TensorProto& min_tp = weights[node2->input(1)];
                const onnx::TensorProto& max_tp = weights[node2->input(2)];

                relu6_min = get_node_attr_from_input_f(min_tp);
                relu6_max = get_node_attr_from_input_f(max_tp);
            }
            if (relu6_min != 0.f || relu6_max != 6.f)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            // the Mul must combine the original x with the clipped branch
            if (node3->input(0) != node->input(0) || node3->input(1) != node2->output(0))
                continue;

            if (weights.find(node4->input(1)) == weights.end())
                continue;

            // the divisor/factor must be the scalar constant 6 (or 1/6 for Mul)
            const onnx::TensorProto& div_six = weights[node4->input(1)];
            if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1)
                continue;

            float constant_div_six = get_node_attr_from_input_f(div_six);
            if (node4->op_type() == "Div" && constant_div_six != 6.f)
                continue;
            if (node4->op_type() == "Mul" && constant_div_six != 1 / 6.f)
                continue;

            // reduce: neutralize the first three nodes, repurpose the fourth
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->input(1)] -= 1;
            node_reference[node->output(0)] -= 1;
            if (node2->input_size() == 3)
            {
                node_reference[node2->input(1)] -= 1;
                node_reference[node2->input(2)] -= 1;
            }
            node_reference[node2->output(0)] -= 1;
            node_reference[node3->output(0)] -= 1;
            node_reference[node4->input(1)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));

            node4->set_op_type("HardSwish");
            node4->clear_input();
            node4->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node4->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(1.f / 6.f);

            onnx::AttributeProto* attr_beta = node4->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(3.f / 6.f);

            reduced_node_count += 3;
            i += 3;
        }
    }

    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSwish <= HardSigmoid - Mul
        //     out = x * hsigmoid(x)
        if (node->op_type() == "HardSigmoid")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            // ONNX HardSigmoid defaults: alpha 0.2, beta 0.5
            float alpha = get_node_attr_f(*node, "alpha", 0.2f);
            float beta = get_node_attr_f(*node, "beta", 0.5f);

            if (i + 1 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);

            if (node2->op_type() != "Mul")
                continue;

            // the Mul must combine x with hsigmoid(x)
            if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0))
                continue;

            // reduce: neutralize the HardSigmoid, repurpose the Mul
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->output(0)] -= 1;

            blob_names.erase(node->output(0));

            node2->set_op_type("HardSwish");
            node2->clear_input();
            node2->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node2->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(alpha);

            onnx::AttributeProto* attr_beta = node2->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(beta);

            reduced_node_count += 1;
            i += 1;
        }
    }
}
859 
// Collapse the expanded hardsigmoid pattern Add(+3) -> Clip(0,6) -> Div(/6)
// (or Mul(*1/6), optionally with an interleaved Constant node) into a single
// HardSigmoid node.  Matched predecessor nodes are renamed to
// "noop_reducedncnn" so a later pass can drop them; node_reference (blob
// consumer counts) and blob_names are updated to keep the bookkeeping
// consistent.
static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSigmoid <= Add(+3) - Clip(0,6) - Div(/6)
        // HardSigmoid <= Add(+3) - Clip(0,6) - Mul(*(1/6))
        // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Div(/6)
        // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Mul(*(1/6))
        //     out = F.relu6(x + 3, inplace=True) / 6
        if (node->op_type() == "Add")
        {
            // the Add output must have exactly one consumer, otherwise the
            // intermediate blob is still needed elsewhere and cannot be fused
            if (node_reference[node->output(0)] != 1)
                continue;

            // need at least two following nodes (Clip and Div/Mul)
            if (i + 2 >= node_count)
                continue;

            // the addend must be a known initializer, not a dynamic blob
            if (weights.find(node->input(1)) == weights.end())
                continue;

            // ... and a scalar equal to 3
            const onnx::TensorProto& add_three = weights[node->input(1)];
            if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1)
                continue;

            float constant_add_three = get_node_attr_from_input_f(add_three);
            if (constant_add_three != 3.f)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // tolerate a Constant node sitting between Clip and Div/Mul
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Clip" || (node3->op_type() != "Div" && node3->op_type() != "Mul"))
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // Clip bounds come either from attributes (opset <= 10) or from
            // extra inputs (opset 11+); both spellings are accepted here
            float relu6_min;
            float relu6_max;
            if (node2->input_size() == 1)
            {
                relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
                relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
            }
            else
            {
                const onnx::TensorProto& min_tp = weights[node2->input(1)];
                const onnx::TensorProto& max_tp = weights[node2->input(2)];

                relu6_min = get_node_attr_from_input_f(min_tp);
                relu6_max = get_node_attr_from_input_f(max_tp);
            }
            // the Clip must be exactly relu6
            if (relu6_min != 0.f || relu6_max != 6.f)
                continue;

            if (weights.find(node3->input(1)) == weights.end())
                continue;

            // the divisor/factor must be a scalar initializer too
            const onnx::TensorProto& div_six = weights[node3->input(1)];
            if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1)
                continue;

            // Div must divide by 6, Mul must multiply by 1/6
            float constant_div_six = get_node_attr_from_input_f(div_six);
            if (node3->op_type() == "Div" && constant_div_six != 6.f)
                continue;
            if (node3->op_type() == "Mul" && constant_div_six != 1 / 6.f)
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            // drop the consumed edges: the +3 constant, the two internal
            // blobs, the optional clip bounds, and the /6 constant
            node_reference[node->input(1)] -= 1;
            node_reference[node->output(0)] -= 1;
            if (node2->input_size() == 3)
            {
                node_reference[node2->input(1)] -= 1;
                node_reference[node2->input(2)] -= 1;
            }
            node_reference[node2->output(0)] -= 1;
            node_reference[node3->input(1)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // rewrite the tail node into HardSigmoid(x) reading X directly
            node3->set_op_type("HardSigmoid");
            node3->clear_input();
            node3->add_input(node->input(0));

            // y = alpha * x + beta clamped to [0,1]: alpha=1/6, beta=3/6
            onnx::AttributeProto* attr_alpha = node3->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(1.f / 6.f);

            onnx::AttributeProto* attr_beta = node3->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(3.f / 6.f);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
973 
fuse_swish(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)974 static void fuse_swish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
975 {
976     int node_count = mutable_graph->node_size();
977     for (int i = 0; i < node_count; i++)
978     {
979         onnx::NodeProto* node = mutable_graph->mutable_node(i);
980 
981         // Swish <= Sigmoid - Mul
982         //     x * torch.sigmoid(x)
983         if (node->op_type() == "Sigmoid")
984         {
985             if (node_reference[node->output(0)] != 1)
986                 continue;
987 
988             if (i + 1 >= node_count)
989                 continue;
990 
991             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
992 
993             if (node2->op_type() != "Mul")
994                 continue;
995 
996             if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0))
997                 continue;
998 
999             // reduce
1000             node->set_op_type("noop_reducedncnn");
1001 
1002             node_reference[node->input(0)] -= 1;
1003             node_reference[node->output(0)] -= 1;
1004 
1005             blob_names.erase(node->output(0));
1006 
1007             node2->set_op_type("Swish");
1008             node2->clear_input();
1009             node2->add_input(node->input(0));
1010 
1011             reduced_node_count += 1;
1012             i += 1;
1013         }
1014     }
1015 }
1016 
fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1017 static void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1018 {
1019     int node_count = mutable_graph->node_size();
1020     for (int i = 0; i < node_count; i++)
1021     {
1022         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1023 
1024         // BatchNormalization <= Unsqueeze - BatchNormalization - Squeeze
1025         if (node->op_type() == "Unsqueeze")
1026         {
1027             if (node_reference[node->output(0)] != 1)
1028                 continue;
1029 
1030             if (i + 2 >= node_count)
1031                 continue;
1032 
1033             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1034             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1035 
1036             if (node2->op_type() != "BatchNormalization" || node3->op_type() != "Squeeze")
1037                 continue;
1038 
1039             if (node_reference[node2->output(0)] != 1)
1040                 continue;
1041 
1042             if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0))
1043                 continue;
1044 
1045             // reduce
1046             node->set_op_type("noop_reducedncnn");
1047             node3->set_op_type("noop_reducedncnn");
1048 
1049             node_reference[node->output(0)] -= 1;
1050             node_reference[node2->output(0)] -= 1;
1051 
1052             blob_names.erase(node->output(0));
1053             blob_names.erase(node2->output(0));
1054 
1055             node2->set_input(0, node->input(0));
1056             node2->set_output(0, node3->output(0));
1057 
1058             reduced_node_count += 2;
1059             i += 2;
1060         }
1061     }
1062 }
1063 
fuse_unsqueeze_prelu(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1064 static void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1065 {
1066     int node_count = mutable_graph->node_size();
1067     for (int i = 0; i < node_count; i++)
1068     {
1069         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1070 
1071         // PReLU <= Unsqueeze - PReLU
1072         if (node->op_type() == "Unsqueeze")
1073         {
1074             // check weight
1075             if (weights.find(node->input(0)) == weights.end())
1076                 continue;
1077 
1078             onnx::TensorProto& B = weights[node->input(0)];
1079             if (B.dims_size() != 1)
1080                 continue;
1081 
1082             if (node_reference[node->output(0)] != 1)
1083                 continue;
1084 
1085             // axes = (1, 2)
1086             std::vector<int> axes = get_node_attr_ai(*node, "axes");
1087             if (axes.size() != 2)
1088                 continue;
1089             if (axes[0] != 1 || axes[1] != 2)
1090                 continue;
1091 
1092             if (i + 1 >= node_count)
1093                 continue;
1094 
1095             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1096 
1097             if (node2->op_type() != "PRelu")
1098                 continue;
1099 
1100             if (node2->input(1) != node->output(0))
1101                 continue;
1102 
1103             // reduce
1104             node->set_op_type("noop_reducedncnn");
1105 
1106             node_reference[node->output(0)] -= 1;
1107 
1108             blob_names.erase(node->output(0));
1109 
1110             node2->set_input(1, node->input(0));
1111 
1112             reduced_node_count += 1;
1113             i += 1;
1114         }
1115     }
1116 }
1117 
// Collapse the expanded L2-normalize pattern
//   X - ReduceL2(axes=1) - Clip(min=eps) - [Shape] - Expand - Div
// into a single Normalize node with an "eps" attribute.  The optional Shape
// node (feeding Expand) is tolerated.  Matched nodes are renamed to
// "noop_reducedncnn"; node_reference/blob_names bookkeeping is updated.
static void fuse_normalize(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Normalize <= X - ReduceL2 - Clip - Expand - Div
        // Normalize <= X - ReduceL2 - Clip - Shape - Expand - Div
        if (node->op_type() == "ReduceL2")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            // axes = (1): the L2 norm must be taken over axis 1 only
            std::vector<int> axes = get_node_attr_ai(*node, "axes");
            if (axes.size() != 1)
                continue;
            if (axes[0] != 1)
                continue;

            if (i + 3 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);

            // if a Shape node is interleaved, shift node3/node4 one slot down
            bool has_shape_node = node3->op_type() == "Shape";
            onnx::NodeProto* node_shape = 0;
            if (has_shape_node)
            {
                if (i + 4 >= node_count)
                    continue;

                node_shape = node3;
                node3 = mutable_graph->mutable_node(i + 3);
                node4 = mutable_graph->mutable_node(i + 4);
            }

            if (node2->op_type() != "Clip" || node3->op_type() != "Expand" || node4->op_type() != "Div")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            // verify the wiring: X feeds both ReduceL2 and Div, and the
            // chain ReduceL2 -> Clip -> Expand -> Div is intact
            if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)
                    || node4->input(0) != node->input(0) || node4->input(1) != node3->output(0))
                continue;

            if (has_shape_node)
            {
                // Shape must read X and feed the Expand's shape input
                if (node_shape->input(0) != node->input(0) || node3->input(1) != node_shape->output(0))
                    continue;
            }

            // +eps: the Clip lower bound becomes the Normalize epsilon;
            // it comes from an attribute (opset <= 10) or an input (opset 11+)
            float clip_min;
            if (node2->input_size() == 1)
            {
                clip_min = get_node_attr_f(*node2, "min", -FLT_MAX);
            }
            else
            {
                const onnx::TensorProto& min_tp = weights[node2->input(1)];

                clip_min = get_node_attr_from_input_f(min_tp);
            }

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            if (has_shape_node)
            {
                node_shape->set_op_type("noop_reducedncnn");
            }
            node3->set_op_type("noop_reducedncnn");

            // X loses two consumers (ReduceL2 and Shape) when Shape is
            // present, otherwise only one (ReduceL2); Div's use survives
            node_reference[node->input(0)] -= has_shape_node ? 2 : 1;
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (has_shape_node)
            {
                node_reference[node_shape->output(0)] -= 1;
            }
            node_reference[node3->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            if (has_shape_node)
            {
                blob_names.erase(node_shape->output(0));
            }
            blob_names.erase(node3->output(0));

            // rewrite the tail Div into Normalize(X)
            node4->set_op_type("Normalize");
            node4->clear_input();
            node4->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node4->add_attribute();
            attr_alpha->set_name("eps");
            attr_alpha->set_f(clip_min);

            reduced_node_count += has_shape_node ? 4 : 3;
            i += has_shape_node ? 4 : 3;
        }
    }
}
1233 
fuse_groupnorm(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1234 static void fuse_groupnorm(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1235 {
1236     int node_count = mutable_graph->node_size();
1237     for (int i = 0; i < node_count; i++)
1238     {
1239         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1240 
1241         // GroupNorm <= X - Reshape - InstanceNormalization - Reshape - Mul - Add
1242         if (node->op_type() == "Reshape")
1243         {
1244             if (node_reference[node->output(0)] != 1)
1245                 continue;
1246 
1247             std::vector<int> shape;
1248             if (node->input_size() == 1)
1249             {
1250                 shape = get_node_attr_ai(*node, "shape");
1251             }
1252             else
1253             {
1254                 // skip weight reshape
1255                 if (weights.find(node->input(1)) == weights.end())
1256                     continue;
1257 
1258                 shape = get_node_attr_from_input_ai(weights[node->input(1)]);
1259             }
1260 
1261             // 0, group, -1
1262             if (shape.size() != 3)
1263                 continue;
1264 
1265             if (shape[0] != 0 || shape[2] != -1)
1266                 continue;
1267 
1268             int groups = shape[1];
1269 
1270             if (i + 4 >= node_count)
1271                 continue;
1272 
1273             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1274             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1275             onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
1276             onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
1277 
1278             if (node2->op_type() != "InstanceNormalization" || node3->op_type() != "Reshape" || node4->op_type() != "Mul" || node5->op_type() != "Add")
1279                 continue;
1280 
1281             if (node_reference[node2->output(0)] != 1)
1282                 continue;
1283 
1284             if (node_reference[node3->output(0)] != 1)
1285                 continue;
1286 
1287             if (node_reference[node4->output(0)] != 1)
1288                 continue;
1289 
1290             if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)
1291                     || node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0))
1292                 continue;
1293 
1294             // +eps
1295             float eps = get_node_attr_f(*node2, "epsilon", 1e-05f);
1296 
1297             // InstanceNormalization S=1 B=0
1298             std::vector<float> S = get_node_attr_from_input_af(weights[node2->input(1)]);
1299             std::vector<float> B = get_node_attr_from_input_af(weights[node2->input(2)]);
1300             if ((int)S.size() != groups || (int)B.size() != groups)
1301                 continue;
1302 
1303             bool instancenorm_affine = false;
1304             for (int j = 0; j < groups; j++)
1305             {
1306                 if (S[j] != 1.f || B[j] != 0.f)
1307                 {
1308                     instancenorm_affine = true;
1309                     break;
1310                 }
1311             }
1312 
1313             if (instancenorm_affine)
1314                 continue;
1315 
1316             std::vector<int> shape2;
1317             if (node3->input_size() == 1)
1318             {
1319                 shape2 = get_node_attr_ai(*node3, "shape");
1320             }
1321             else
1322             {
1323                 // skip weight reshape
1324                 if (weights.find(node3->input(1)) == weights.end())
1325                     continue;
1326 
1327                 shape2 = get_node_attr_from_input_ai(weights[node3->input(1)]);
1328             }
1329 
1330             // 1, channels, w, h
1331             if (shape2.size() != 4)
1332                 continue;
1333 
1334             if (shape2[0] != 1)
1335                 continue;
1336 
1337             int channels = shape2[1];
1338 
1339             // affine
1340             std::vector<float> affine_S = get_node_attr_from_input_af(weights[node4->input(1)]);
1341             std::vector<float> affine_B = get_node_attr_from_input_af(weights[node5->input(1)]);
1342             if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && affine_B[0] == 0.f)
1343             {
1344                 // no affine
1345             }
1346             else if ((int)affine_S.size() != channels && (int)affine_B.size() != channels)
1347             {
1348                 // we only allow per-channel affine
1349                 continue;
1350             }
1351 
1352             // reduce
1353             node->set_op_type("noop_reducedncnn");
1354             node2->set_op_type("noop_reducedncnn");
1355             node3->set_op_type("noop_reducedncnn");
1356             node4->set_op_type("noop_reducedncnn");
1357 
1358             if (node->input_size() == 2)
1359             {
1360                 node_reference[node->input(1)] -= 1;
1361             }
1362             node_reference[node->output(0)] -= 1;
1363             node_reference[node2->input(1)] -= 1;
1364             node_reference[node2->input(2)] -= 1;
1365             node_reference[node2->output(0)] -= 1;
1366             if (node3->input_size() == 2)
1367             {
1368                 node_reference[node3->input(1)] -= 1;
1369             }
1370             node_reference[node3->output(0)] -= 1;
1371             node_reference[node4->output(0)] -= 1;
1372 
1373             blob_names.erase(node->output(0));
1374             blob_names.erase(node2->output(0));
1375             blob_names.erase(node3->output(0));
1376             blob_names.erase(node4->output(0));
1377 
1378             std::string affine_scale = node4->input(1);
1379             std::string affine_bias = node5->input(1);
1380 
1381             node5->set_op_type("GroupNorm");
1382             node5->clear_input();
1383             node5->add_input(node->input(0));
1384             node5->add_input(affine_scale);
1385             node5->add_input(affine_bias);
1386 
1387             onnx::AttributeProto* attr_groups = node5->add_attribute();
1388             attr_groups->set_name("groups");
1389             attr_groups->set_i(groups);
1390 
1391             onnx::AttributeProto* attr_channels = node5->add_attribute();
1392             attr_channels->set_name("channels");
1393             attr_channels->set_i(channels);
1394 
1395             onnx::AttributeProto* attr_eps = node5->add_attribute();
1396             attr_eps->set_name("epsilon");
1397             attr_eps->set_f(eps);
1398 
1399             onnx::AttributeProto* attr_affine = node5->add_attribute();
1400             attr_affine->set_name("affine");
1401             attr_affine->set_i(1);
1402 
1403             reduced_node_count += 4;
1404             i += 4;
1405         }
1406     }
1407 }
1408 
// Collapse the expanded layer normalization pattern
//   X - ReduceMean - Sub - Pow(2) - ReduceMean - Add(eps) - Sqrt - Div [- Mul - Add]
// into a single LayerNorm node.  The optional trailing Mul/Add pair supplies
// per-element affine scale/bias.  Matched nodes are renamed to
// "noop_reducedncnn"; node_reference/blob_names bookkeeping is updated.
static void fuse_layernorm(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div
        // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div - Mul - Add
        if (node->op_type() == "ReduceMean")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            std::vector<int> axes = get_node_attr_ai(*node, "axes");

            // only trailing-axis normalization is recognized:
            // -1
            // -2 -1
            if (axes.size() != 1 && axes.size() != 2)
                continue;

            int normed_axes = (int)axes.size();
            if (normed_axes == 1 && axes[0] != -1)
                continue;
            if (normed_axes == 2 && (axes[0] != -2 || axes[1] != -1))
                continue;

            if (i + 6 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);

            if (node2->op_type() != "Sub" || node3->op_type() != "Pow" || node4->op_type() != "ReduceMean" || node5->op_type() != "Add" || node6->op_type() != "Sqrt" || node7->op_type() != "Div")
                continue;

            // the Sub output (x - mean) is consumed twice: by Pow and by Div
            if (node_reference[node2->output(0)] != 2)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            if (node_reference[node4->output(0)] != 1)
                continue;

            if (node_reference[node5->output(0)] != 1)
                continue;

            if (node_reference[node6->output(0)] != 1)
                continue;

            // verify the wiring of the whole chain, including the diamond:
            // Sub reads X and mean; Div reads Sub's output and Sqrt's output
            if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0)
                    || node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0)
                    || node5->input(0) != node4->output(0) || node6->input(0) != node5->output(0)
                    || node7->input(0) != node2->output(0) || node7->input(1) != node6->output(0))
                continue;

            // Pow exponent must be the scalar constant 2
            if (weights.find(node3->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& pow_two = weights[node3->input(1)];
            if (pow_two.dims_size() != 0 || get_tensor_proto_data_size(pow_two) != 1)
                continue;

            float constant_pow_two = get_node_attr_from_input_f(pow_two);
            if (constant_pow_two != 2.f)
                continue;

            // the second ReduceMean must use the same axes as the first
            std::vector<int> axes4 = get_node_attr_ai(*node4, "axes");

            // -1
            // -2 -1
            if ((int)axes4.size() != normed_axes)
                continue;

            if (normed_axes == 1 && axes4[0] != -1)
                continue;
            if (normed_axes == 2 && (axes4[0] != -2 || axes4[1] != -1))
                continue;

            // the Add addend becomes the LayerNorm epsilon
            if (weights.find(node5->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& add_eps = weights[node5->input(1)];
            if (add_eps.dims_size() != 0 || get_tensor_proto_data_size(add_eps) != 1)
                continue;

            float eps = get_node_attr_from_input_f(add_eps);

            // look ahead for an optional Mul/Add affine pair; the while loop
            // is only a breakable block and runs at most once
            int affine = 0;
            while (i + 8 < node_count)
            {
                onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
                onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);

                if (node8->op_type() != "Mul" || node9->op_type() != "Add")
                    break;

                if (node_reference[node7->output(0)] != 1)
                    break;

                if (node_reference[node8->output(0)] != 1)
                    break;

                if (node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0))
                    break;

                // affine scale and bias must have matching element counts
                std::vector<float> affine_S = get_node_attr_from_input_af(weights[node8->input(1)]);
                std::vector<float> affine_B = get_node_attr_from_input_af(weights[node9->input(1)]);
                if (affine_S.size() != affine_B.size())
                    break;

                affine = 1;
                break;
            }

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");

            // release every consumed edge; note X is decremented twice here
            // (as node->input(0) and node2->input(0), the same blob) ...
            node_reference[node->input(0)] -= 1;
            node_reference[node2->input(0)] -= 1;
            node_reference[node2->input(1)] -= 1;
            node_reference[node3->input(0)] -= 1;
            node_reference[node3->input(1)] -= 1;
            node_reference[node4->input(0)] -= 1;
            node_reference[node5->input(0)] -= 1;
            node_reference[node5->input(1)] -= 1;
            node_reference[node6->input(0)] -= 1;
            node_reference[node7->input(0)] -= 1;
            node_reference[node7->input(1)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));

            // ... and re-incremented once, since the fused LayerNorm node
            // keeps consuming X
            node_reference[node->input(0)] += 1;

            if (affine == 0)
            {
                // no affine pair: the Div becomes the LayerNorm node
                node7->set_op_type("LayerNorm");
                node7->clear_input();
                node7->add_input(node->input(0));

                onnx::AttributeProto* attr_eps = node7->add_attribute();
                attr_eps->set_name("epsilon");
                attr_eps->set_f(eps);

                onnx::AttributeProto* attr_affine = node7->add_attribute();
                attr_affine->set_name("affine");
                attr_affine->set_i(affine);

                reduced_node_count += 6;
                i += 6;
            }
            else // if (affine == 1)
            {
                // affine pair present: fold Div and Mul away, and turn the
                // trailing Add into LayerNorm(X, scale, bias)
                onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
                onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);

                node7->set_op_type("noop_reducedncnn");
                node8->set_op_type("noop_reducedncnn");

                node_reference[node8->input(0)] -= 1;
                node_reference[node9->input(0)] -= 1;

                blob_names.erase(node7->output(0));
                blob_names.erase(node8->output(0));

                std::string affine_scale = node8->input(1);
                std::string affine_bias = node9->input(1);

                node9->set_op_type("LayerNorm");
                node9->clear_input();
                node9->add_input(node->input(0));
                node9->add_input(affine_scale);
                node9->add_input(affine_bias);

                onnx::AttributeProto* attr_eps = node9->add_attribute();
                attr_eps->set_name("epsilon");
                attr_eps->set_f(eps);

                onnx::AttributeProto* attr_affine = node9->add_attribute();
                attr_affine->set_name("affine");
                attr_affine->set_i(affine);

                reduced_node_count += 8;
                i += 8;
            }
        }
    }
}
1613 
// Collapse the dynamic flatten-to-[batch, -1] subgraph
//     X - Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat - Reshape
// into a single Flatten node operating directly on X.
// weights        : initializer tensors by name (used to read Gather indices / Unsqueeze data)
// node_reference : consumer count per blob name; decremented for every edge removed
// blob_names     : set of live blob names; intermediate blobs are erased
// reduced_node_count : running total of nodes eliminated across all passes
static void fuse_flatten(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Flatten <= X - Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat - Reshape
        if (node->op_type() == "Shape")
        {
            // every intermediate blob must have exactly one consumer,
            // otherwise removing it would break another branch of the graph
            if (node_reference[node->output(0)] != 1)
                continue;

            // need six more nodes to complete the pattern
            if (i + 6 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);

            if (node2->op_type() != "Gather" || node3->op_type() != "Constant" || node4->op_type() != "Unsqueeze" || node5->op_type() != "Unsqueeze"
                    || node6->op_type() != "Concat" || node7->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // the Constant (node3) is deliberately left alone; it may have other consumers
            //             if (node_reference[node3->output(0)] != 1)
            //                 continue;

            if (node_reference[node4->output(0)] != 1)
                continue;

            if (node_reference[node5->output(0)] != 1)
                continue;

            if (node_reference[node6->output(0)] != 1)
                continue;

            // verify the wiring: Shape->Gather->Unsqueeze, Constant->Unsqueeze,
            // both Unsqueezes into Concat, and Reshape taking (X, concat result)
            if (node2->input(0) != node->output(0) || node4->input(0) != node2->output(0) || node5->input(0) != node3->output(0)
                    || node6->input(0) != node4->output(0) || node6->input(1) != node5->output(0)
                    || node7->input(0) != node->input(0) || node7->input(1) != node6->output(0))
                continue;

            // axis = 0
            int gather_axis = get_node_attr_i(*node2, "axis");
            if (gather_axis != 0)
                continue;

            // indices = 0  (i.e. Gather picks the batch dimension off the shape)
            if (weights.find(node2->input(1)) == weights.end())
                continue;

            std::vector<int> gather_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
            if (gather_indices.size() != 1 || gather_indices[0] != 0)
                continue;

            // axes = (0)
            std::vector<int> unsqueeze_axes = get_node_attr_ai(*node4, "axes");
            if (unsqueeze_axes.size() != 1)
                continue;
            if (unsqueeze_axes[0] != 0)
                continue;

            // axes = (0)
            std::vector<int> unsqueeze2_axes = get_node_attr_ai(*node5, "axes");
            if (unsqueeze2_axes.size() != 1)
                continue;
            if (unsqueeze2_axes[0] != 0)
                continue;

            // data = -1  (the second Concat operand, making the target shape [batch, -1])
            if (weights.find(node5->input(0)) == weights.end())
                continue;

            std::vector<int> unsqueeze2_data = get_node_attr_from_input_ai(weights[node5->input(0)]);
            if (unsqueeze2_data.size() != 1 || unsqueeze2_data[0] != -1)
                continue;

            // axis = 0
            int concat_axis = get_node_attr_i(*node6, "axis");
            if (concat_axis != 0)
                continue;

            // reduce: mark all matched helper nodes dead (a later pass strips noop_reducedncnn)
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            //             node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->output(0)] -= 1;
            node_reference[node2->input(1)] -= 1;
            node_reference[node2->output(0)] -= 1;
            //             node_reference[node3->output(0)] -= 1;
            node_reference[node4->output(0)] -= 1;
            node_reference[node5->input(0)] -= 1;
            node_reference[node5->output(0)] -= 1;
            node_reference[node6->output(0)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            //             blob_names.erase(node3->output(0));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));

            // repurpose the trailing Reshape as the fused Flatten on X
            node7->set_op_type("Flatten");
            node7->clear_input();
            node7->add_input(node->input(0));

            reduced_node_count += 5;
            i += 5;
        }
    }
}
1735 
// Recognize the depth-to-space (pixel shuffle) idiom
//     Reshape(6D) - Transpose(0 1 4 2 5 3) - Reshape(4D)
// (optionally with an interleaved Constant node) and replace it with a single
// ncnn PixelShuffle node carrying a scale_factor attribute.
// weights        : initializer tensors by name (used to read Reshape target shapes)
// node_reference : consumer count per blob name; decremented for every edge removed
// blob_names     : set of live blob names; intermediate blobs are erased
// reduced_node_count : running total of nodes eliminated across all passes
static void fuse_pixelshuffle(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // PixelShuffle <= Reshape - Transpose - Reshape
        // PixelShuffle <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            // shape may come from an attribute (1-input Reshape) or a weight initializer
            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end())
                    continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // -1, 3, upscale_factor, upscale_factor, height, width
            if (shape.size() != 6)
                continue;

            if (shape[0] != 1 && shape[0] != -1)
                continue;

            // the two upscale factors must match (square pixel shuffle only)
            if (shape[2] != shape[3])
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // tolerate a Constant node (the second Reshape's shape tensor) in between
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // 0 1 4 2 5 3
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 6)
                continue;

            if (perm[0] != 0 || perm[1] != 1 || perm[2] != 4 || perm[3] != 2 || perm[4] != 5 || perm[5] != 3)
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end())
                    continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // -1, 3, height, width
            if (shape3.size() != 4)
                continue;

            if (shape3[0] != 1 && shape3[0] != -1)
                continue;

            // output spatial dims must equal input dims scaled by the upscale factor
            if (shape3[1] != shape[1] || shape3[2] != shape[2] * shape[4] || shape3[3] != shape[3] * shape[5])
                continue;

            // reduce: kill the first Reshape and the Transpose
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // repurpose the final Reshape as the fused PixelShuffle
            node3->set_op_type("PixelShuffle");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("scale_factor");
            attr_group->set_i(shape[2]);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
1856 
// Recognize the space-to-depth (reorg / YOLO-style) idiom
//     Reshape(6D) - Transpose(0 1 3 5 2 4) - Reshape(4D)
// (optionally with an interleaved Constant node) and replace it with a single
// ncnn Reorg node carrying a stride attribute.
// weights        : initializer tensors by name (used to read Reshape target shapes)
// node_reference : consumer count per blob name; decremented for every edge removed
// blob_names     : set of live blob names; intermediate blobs are erased
// reduced_node_count : running total of nodes eliminated across all passes
static void fuse_reorg(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Reorg <= Reshape - Transpose - Reshape
        // Reorg <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            // shape may come from an attribute (1-input Reshape) or a weight initializer
            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end())
                    continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // -1, 3, out_height, block_size, out_width, block_size
            if (shape.size() != 6)
                continue;

            if (shape[0] != 1 && shape[0] != -1)
                continue;

            // the two block sizes must match (square reorg only)
            if (shape[3] != shape[5])
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // tolerate a Constant node (the second Reshape's shape tensor) in between
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // 0 1 3 5 2 4
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 6)
                continue;

            if (perm[0] != 0 || perm[1] != 1 || perm[2] != 3 || perm[3] != 5 || perm[4] != 2 || perm[5] != 4)
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end())
                    continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // -1, out_channels, out_height, out_width
            if (shape3.size() != 4)
                continue;

            if (shape3[0] != 1 && shape3[0] != -1)
                continue;

            // channels multiply by block_size^2, spatial dims are the reduced ones
            if (shape3[1] != shape[1] * shape[3] * shape[5] || shape3[2] != shape[2] || shape3[3] != shape[4])
                continue;

            // reduce: kill the first Reshape and the Transpose
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // repurpose the final Reshape as the fused Reorg
            node3->set_op_type("Reorg");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("stride");
            attr_group->set_i(shape[3]);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
1977 
fuse_expand_broadcast(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1978 static void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1979 {
1980     int node_count = mutable_graph->node_size();
1981     for (int i = 0; i < node_count; i++)
1982     {
1983         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1984 
1985         // Add/Sub/Mul/Div/Min/Max <= Expand - Add/Sub/Mul/Div/Min/Max
1986         if (node->op_type() == "Expand")
1987         {
1988             if (node_reference[node->output(0)] != 1)
1989                 continue;
1990 
1991             if (i + 1 >= node_count)
1992                 continue;
1993 
1994             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1995 
1996             if (node2->op_type() != "Add" && node2->op_type() != "Sub" && node2->op_type() != "Mul" && node2->op_type() != "Div" && node2->op_type() != "Min" && node2->op_type() != "Max")
1997                 continue;
1998 
1999             if (node2->input(1) != node->output(0) && node2->input(0) != node->output(0))
2000                 continue;
2001 
2002             // reduce
2003             node->set_op_type("noop_reducedncnn");
2004 
2005             node_reference[node->output(0)] -= 1;
2006             if (node->input_size() == 2)
2007             {
2008                 node_reference[node->input(1)] -= 1;
2009             }
2010 
2011             blob_names.erase(node->output(0));
2012 
2013             if (node2->input(0) == node->output(0))
2014             {
2015                 node2->set_input(0, node->input(0));
2016             }
2017             else
2018             {
2019                 node2->set_input(1, node->input(0));
2020             }
2021 
2022             reduced_node_count += 1;
2023             i += 1;
2024         }
2025     }
2026 }
2027 
fuse_lstm_gru_rnn(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)2028 static void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
2029 {
2030     int node_count = mutable_graph->node_size();
2031     for (int i = 0; i < node_count; i++)
2032     {
2033         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2034 
2035         // LSTM(bi) <= LSTM(bi) - Transpose - Reshape - Transpose
2036         if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
2037         {
2038             if (node_reference[node->output(0)] != 1)
2039                 continue;
2040 
2041             if (i + 2 >= node_count)
2042                 continue;
2043 
2044             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2045             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
2046 
2047             if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
2048                 continue;
2049 
2050             if (node_reference[node2->output(0)] != 1)
2051                 continue;
2052 
2053             if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0))
2054                 continue;
2055 
2056             std::string direction = get_node_attr_s(*node, "direction");
2057             if (direction != "bidirectional")
2058                 continue;
2059 
2060             // 0 2 1 3
2061             std::vector<int> perm = get_node_attr_ai(*node2, "perm");
2062             if (perm.size() != 4)
2063                 continue;
2064 
2065             if (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3)
2066                 continue;
2067 
2068             std::vector<int> shape;
2069             if (node3->input_size() == 1)
2070             {
2071                 shape = get_node_attr_ai(*node3, "shape");
2072             }
2073             else
2074             {
2075                 // skip weight reshape
2076                 if (weights.find(node3->input(1)) == weights.end())
2077                     continue;
2078 
2079                 shape = get_node_attr_from_input_ai(weights[node3->input(1)]);
2080             }
2081 
2082             // 0 0 -1
2083             if (shape.size() != 3)
2084                 continue;
2085 
2086             if (shape[0] != 0 || shape[1] != 0 || shape[2] != -1)
2087                 continue;
2088 
2089             // reduce
2090             node2->set_op_type("noop_reducedncnn");
2091             node3->set_op_type("noop_reducedncnn");
2092 
2093             node_reference[node->output(0)] -= 1;
2094             node_reference[node2->output(0)] -= 1;
2095             if (node3->input_size() == 2)
2096             {
2097                 node_reference[node3->input(1)] -= 1;
2098             }
2099 
2100             blob_names.erase(node->output(0));
2101             blob_names.erase(node2->output(0));
2102 
2103             node->set_output(0, node3->output(0));
2104 
2105             reduced_node_count += 2;
2106             i += 2;
2107 
2108             if (i + 1 < node_count)
2109             {
2110                 if (node_reference[node3->output(0)] != 1)
2111                     continue;
2112 
2113                 onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 1);
2114 
2115                 if (node4->op_type() != "Transpose")
2116                     continue;
2117 
2118                 if (node4->input(0) != node->output(0))
2119                     continue;
2120 
2121                 // 1 0 2
2122                 std::vector<int> perm4 = get_node_attr_ai(*node4, "perm");
2123                 if (perm4.size() != 3)
2124                     continue;
2125 
2126                 if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2)
2127                     continue;
2128 
2129                 // reduce
2130                 node4->set_op_type("noop_reducedncnn");
2131 
2132                 node_reference[node->output(0)] -= 1;
2133 
2134                 blob_names.erase(node->output(0));
2135 
2136                 node->set_output(0, node4->output(0));
2137 
2138                 reduced_node_count += 1;
2139                 i += 1;
2140             }
2141         }
2142     }
2143 
2144     for (int i = 0; i < node_count; i++)
2145     {
2146         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2147 
2148         // LSTM(uni) <= LSTM(uni) - Squeeze - Transpose
2149         if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
2150         {
2151             if (node_reference[node->output(0)] != 1)
2152                 continue;
2153 
2154             if (i + 1 >= node_count)
2155                 continue;
2156 
2157             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2158 
2159             if (node2->op_type() != "Squeeze")
2160                 continue;
2161 
2162             if (node2->input(0) != node->output(0))
2163                 continue;
2164 
2165             std::string direction = get_node_attr_s(*node, "direction");
2166             if (direction == "bidirectional")
2167                 continue;
2168 
2169             // 1
2170             std::vector<int> axes = get_node_attr_ai(*node2, "axes");
2171             if (axes.size() != 1)
2172                 continue;
2173 
2174             if (axes[0] != 1)
2175                 continue;
2176 
2177             // reduce
2178             node2->set_op_type("noop_reducedncnn");
2179 
2180             node_reference[node->output(0)] -= 1;
2181 
2182             blob_names.erase(node->output(0));
2183 
2184             node->set_output(0, node2->output(0));
2185 
2186             reduced_node_count += 1;
2187             i += 1;
2188 
2189             if (i + 1 < node_count)
2190             {
2191                 if (node_reference[node2->output(0)] != 1)
2192                     continue;
2193 
2194                 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 1);
2195 
2196                 if (node3->op_type() != "Transpose")
2197                     continue;
2198 
2199                 if (node3->input(0) != node->output(0))
2200                     continue;
2201 
2202                 // 1 0 2
2203                 std::vector<int> perm4 = get_node_attr_ai(*node3, "perm");
2204                 if (perm4.size() != 3)
2205                     continue;
2206 
2207                 if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2)
2208                     continue;
2209 
2210                 // reduce
2211                 node3->set_op_type("noop_reducedncnn");
2212 
2213                 node_reference[node->output(0)] -= 1;
2214 
2215                 blob_names.erase(node->output(0));
2216 
2217                 node->set_output(0, node3->output(0));
2218 
2219                 reduced_node_count += 1;
2220                 i += 1;
2221             }
2222         }
2223     }
2224 
2225     for (int i = 0; i < node_count; i++)
2226     {
2227         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2228 
2229         // LSTM <= Transpose - LSTM
2230         if (node->op_type() == "Transpose")
2231         {
2232             if (node_reference[node->output(0)] != 1)
2233                 continue;
2234 
2235             // 1 0 2
2236             std::vector<int> perm = get_node_attr_ai(*node, "perm");
2237             if (perm.size() != 3)
2238                 continue;
2239 
2240             if (perm[0] != 1 || perm[1] != 0 || perm[2] != 2)
2241                 continue;
2242 
2243             if (i + 1 >= node_count)
2244                 continue;
2245 
2246             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2247 
2248             if (node2->op_type() != "LSTM" && node->op_type() != "GRU" && node->op_type() != "RNN")
2249                 continue;
2250 
2251             if (node2->input(0) != node->output(0))
2252                 continue;
2253 
2254             // reduce
2255             node->set_op_type("noop_reducedncnn");
2256 
2257             node_reference[node->output(0)] -= 1;
2258 
2259             blob_names.erase(node->output(0));
2260 
2261             node2->set_input(0, node->input(0));
2262 
2263             reduced_node_count += 1;
2264             i += 1;
2265         }
2266     }
2267 }
2268 
fuse_multiheadattention(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)2269 static void fuse_multiheadattention(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
2270 {
2271     int node_count = mutable_graph->node_size();
2272     for (int i = 0; i < node_count; i++)
2273     {
2274         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2275 
2276         // MultiHeadAttention <= MatMul(q) - Add
2277         //                      - MatMul(k) - Add
2278         //                      - MatMul(v) - Add
2279         //                      - Mul
2280         //                      - Reshape - Transpose
2281         //                      - Reshape - Reshape - Transpose - Transpose
2282         //                      - Gemm - Softmax - Gemm - Transpose - Reshape - MatMul - Add
2283         if (node->op_type() == "MatMul")
2284         {
2285             if (i + 19 >= node_count)
2286                 continue;
2287 
2288             if (node_reference[node->output(0)] != 1)
2289                 continue;
2290 
2291             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2292             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
2293             onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
2294             onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
2295             onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
2296             onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
2297             onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
2298             onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
2299             onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9);
2300             onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10);
2301             onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11);
2302             onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12);
2303             onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13);
2304             onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14);
2305             onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15);
2306             onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16);
2307             onnx::NodeProto* node18 = mutable_graph->mutable_node(i + 17);
2308             onnx::NodeProto* node19 = mutable_graph->mutable_node(i + 18);
2309             onnx::NodeProto* node20 = mutable_graph->mutable_node(i + 19);
2310 
2311             if (node2->op_type() != "Add" || node3->op_type() != "MatMul" || node4->op_type() != "Add" || node5->op_type() != "MatMul" || node6->op_type() != "Add" || node7->op_type() != "Mul" || node8->op_type() != "Reshape" || node9->op_type() != "Transpose" || node10->op_type() != "Reshape" || node11->op_type() != "Reshape" || node12->op_type() != "Transpose" || node13->op_type() != "Transpose" || node14->op_type() != "MatMul" || node15->op_type() != "Softmax" || node16->op_type() != "MatMul" || node17->op_type() != "Transpose" || node18->op_type() != "Reshape" || node19->op_type() != "MatMul" || node20->op_type() != "Add")
2312                 continue;
2313 
2314             if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 || node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 || node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 || node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 || node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 || node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 || node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 || node_reference[node16->output(0)] != 1 || node_reference[node17->output(0)] != 1 || node_reference[node18->output(0)] != 1 || node_reference[node19->output(0)] != 1)
2315                 continue;
2316 
2317             if (node2->input(0) != node->output(0) || node4->input(0) != node3->output(0) || node6->input(0) != node5->output(0) || node7->input(0) != node2->output(0) || node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0) || node10->input(0) != node4->output(0) || node11->input(0) != node6->output(0) || node12->input(0) != node11->output(0) || node13->input(0) != node10->output(0) || node14->input(0) != node9->output(0) || node14->input(1) != node13->output(0) || node15->input(0) != node14->output(0) || node16->input(0) != node15->output(0) || node16->input(1) != node12->output(0) || node17->input(0) != node16->output(0) || node18->input(0) != node17->output(0) || node19->input(0) != node18->output(0) || node20->input(0) != node19->output(0))
2318                 continue;
2319 
2320             std::vector<float> q_B = get_node_attr_from_input_af(weights[node2->input(1)]);
2321             std::vector<float> k_B = get_node_attr_from_input_af(weights[node4->input(1)]);
2322             std::vector<float> v_B = get_node_attr_from_input_af(weights[node6->input(1)]);
2323             std::vector<float> o_B = get_node_attr_from_input_af(weights[node20->input(1)]);
2324 
2325             if (q_B.size() != k_B.size() || q_B.size() != v_B.size() || q_B.size() != o_B.size())
2326                 continue;
2327 
2328             int embed_dim = q_B.size();
2329 
2330             // 1 0 2
2331             std::vector<int> perm9 = get_node_attr_ai(*node9, "perm");
2332             std::vector<int> perm12 = get_node_attr_ai(*node12, "perm");
2333             if (perm9.size() != 3 || perm12.size() != 3)
2334                 continue;
2335 
2336             if (perm9[0] != 1 || perm9[1] != 0 || perm9[2] != 2 || perm12[0] != 1 || perm12[1] != 0 || perm12[2] != 2)
2337                 continue;
2338 
2339             // 1 2 0
2340             std::vector<int> perm13 = get_node_attr_ai(*node13, "perm");
2341             if (perm13.size() != 3)
2342                 continue;
2343 
2344             if (perm13[0] != 1 || perm13[1] != 2 || perm13[2] != 0)
2345                 continue;
2346 
2347             // 1 0 2
2348             std::vector<int> perm17 = get_node_attr_ai(*node17, "perm");
2349             if (perm17.size() != 3)
2350                 continue;
2351 
2352             if (perm17[0] != 1 || perm17[1] != 0 || perm17[2] != 2)
2353                 continue;
2354 
2355             int softmax_axis = get_node_attr_i(*node15, "axis");
2356             if (softmax_axis != 2)
2357                 continue;
2358 
2359             // 1/-1, seqlen * num_heads, embed_dim / num_heads
2360             std::vector<int> shape8;
2361             std::vector<int> shape10;
2362             std::vector<int> shape11;
2363             if (node8->input_size() == 1)
2364             {
2365                 shape8 = get_node_attr_ai(*node8, "shape");
2366             }
2367             else
2368             {
2369                 // skip weight reshape
2370                 if (weights.find(node8->input(1)) == weights.end())
2371                     continue;
2372 
2373                 shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]);
2374             }
2375             if (node10->input_size() == 1)
2376             {
2377                 shape10 = get_node_attr_ai(*node10, "shape");
2378             }
2379             else
2380             {
2381                 // skip weight reshape
2382                 if (weights.find(node10->input(1)) == weights.end())
2383                     continue;
2384 
2385                 shape10 = get_node_attr_from_input_ai(weights[node10->input(1)]);
2386             }
2387             if (node11->input_size() == 1)
2388             {
2389                 shape11 = get_node_attr_ai(*node11, "shape");
2390             }
2391             else
2392             {
2393                 // skip weight reshape
2394                 if (weights.find(node11->input(1)) == weights.end())
2395                     continue;
2396 
2397                 shape11 = get_node_attr_from_input_ai(weights[node11->input(1)]);
2398             }
2399 
2400             if (shape8.size() != 3 || shape10.size() != 3 || shape11.size() != 3)
2401                 continue;
2402 
2403             if (shape8[1] != shape10[1] || shape8[1] != shape11[1] || shape8[2] != shape10[2] || shape8[2] != shape11[2])
2404                 continue;
2405 
2406             int num_heads = embed_dim / shape8[2];
2407 
2408             // 1, seqlen, embed_dim
2409             std::vector<int> shape18;
2410             if (node18->input_size() == 1)
2411             {
2412                 shape18 = get_node_attr_ai(*node18, "shape");
2413             }
2414             else
2415             {
2416                 // skip weight reshape
2417                 if (weights.find(node18->input(1)) == weights.end())
2418                     continue;
2419 
2420                 shape18 = get_node_attr_from_input_ai(weights[node18->input(1)]);
2421             }
2422 
2423             if (shape18.size() != 3)
2424                 continue;
2425 
2426             if (shape18[2] != embed_dim || shape18[1] * num_heads != shape8[1])
2427                 continue;
2428 
2429             // reduce
2430             node->set_op_type("noop_reducedncnn");
2431             node2->set_op_type("noop_reducedncnn");
2432             node3->set_op_type("noop_reducedncnn");
2433             node4->set_op_type("noop_reducedncnn");
2434             node5->set_op_type("noop_reducedncnn");
2435             node6->set_op_type("noop_reducedncnn");
2436             node7->set_op_type("noop_reducedncnn");
2437             node8->set_op_type("noop_reducedncnn");
2438             node9->set_op_type("noop_reducedncnn");
2439             node10->set_op_type("noop_reducedncnn");
2440             node11->set_op_type("noop_reducedncnn");
2441             node12->set_op_type("noop_reducedncnn");
2442             node13->set_op_type("noop_reducedncnn");
2443             node14->set_op_type("noop_reducedncnn");
2444             node15->set_op_type("noop_reducedncnn");
2445             node16->set_op_type("noop_reducedncnn");
2446             node17->set_op_type("noop_reducedncnn");
2447             node18->set_op_type("noop_reducedncnn");
2448             node19->set_op_type("noop_reducedncnn");
2449 
2450             node_reference[node2->input(0)] -= 1;
2451             node_reference[node4->input(0)] -= 1;
2452             node_reference[node6->input(0)] -= 1;
2453             node_reference[node7->input(0)] -= 1;
2454             node_reference[node7->input(1)] -= 1;
2455             node_reference[node8->input(0)] -= 1;
2456             if (node8->input_size() == 2)
2457             {
2458                 node_reference[node8->input(1)] -= 1;
2459             }
2460             node_reference[node9->input(0)] -= 1;
2461             node_reference[node10->input(0)] -= 1;
2462             if (node10->input_size() == 2)
2463             {
2464                 node_reference[node10->input(1)] -= 1;
2465             }
2466             node_reference[node11->input(0)] -= 1;
2467             if (node11->input_size() == 2)
2468             {
2469                 node_reference[node11->input(1)] -= 1;
2470             }
2471             node_reference[node12->input(0)] -= 1;
2472             node_reference[node13->input(0)] -= 1;
2473             node_reference[node14->input(0)] -= 1;
2474             node_reference[node14->input(1)] -= 1;
2475             node_reference[node15->input(0)] -= 1;
2476             node_reference[node16->input(0)] -= 1;
2477             node_reference[node16->input(1)] -= 1;
2478             node_reference[node17->input(0)] -= 1;
2479             node_reference[node18->input(0)] -= 1;
2480             if (node18->input_size() == 2)
2481             {
2482                 node_reference[node18->input(1)] -= 1;
2483             }
2484             node_reference[node19->input(0)] -= 1;
2485             node_reference[node20->input(0)] -= 1;
2486 
2487             blob_names.erase(node->output(0));
2488             blob_names.erase(node2->output(0));
2489             blob_names.erase(node3->output(0));
2490             blob_names.erase(node4->output(0));
2491             blob_names.erase(node5->output(0));
2492             blob_names.erase(node6->output(0));
2493             blob_names.erase(node7->output(0));
2494             blob_names.erase(node8->output(0));
2495             blob_names.erase(node9->output(0));
2496             blob_names.erase(node10->output(0));
2497             blob_names.erase(node11->output(0));
2498             blob_names.erase(node12->output(0));
2499             blob_names.erase(node13->output(0));
2500             blob_names.erase(node14->output(0));
2501             blob_names.erase(node15->output(0));
2502             blob_names.erase(node16->output(0));
2503             blob_names.erase(node17->output(0));
2504             blob_names.erase(node18->output(0));
2505             blob_names.erase(node19->output(0));
2506 
2507             std::string qw = node->input(1);
2508             std::string qb = node2->input(1);
2509             std::string kw = node3->input(1);
2510             std::string kb = node4->input(1);
2511             std::string vw = node5->input(1);
2512             std::string vb = node6->input(1);
2513             std::string ow = node19->input(1);
2514             std::string ob = node20->input(1);
2515 
2516             node20->set_op_type("MultiHeadAttention");
2517             node20->clear_input();
2518             node20->add_input(node->input(0));
2519             node20->add_input(node3->input(0));
2520             node20->add_input(node5->input(0));
2521             // q
2522             node20->add_input(qw);
2523             node20->add_input(qb);
2524             // k
2525             node20->add_input(kw);
2526             node20->add_input(kb);
2527             // v
2528             node20->add_input(vw);
2529             node20->add_input(vb);
2530             // out linear
2531             node20->add_input(ow);
2532             node20->add_input(ob);
2533 
2534             onnx::AttributeProto* attr_embed_dim = node20->add_attribute();
2535             attr_embed_dim->set_name("embed_dim");
2536             attr_embed_dim->set_i(embed_dim);
2537 
2538             onnx::AttributeProto* attr_num_heads = node20->add_attribute();
2539             attr_num_heads->set_name("num_heads");
2540             attr_num_heads->set_i(num_heads);
2541 
2542             reduced_node_count += 19;
2543             i += 19;
2544         }
2545     }
2546 
2547     for (int i = 0; i < node_count; i++)
2548     {
2549         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2550 
2551         // MultiHeadAttention <= MatMul(qkv) - Add - Split
2552         //                      - Mul
2553         //                      - Reshape - Transpose
2554         //                      - Reshape - Reshape - Transpose - Transpose
2555         //                      - Gemm - Softmax - Gemm - Transpose - Reshape - MatMul - Add
2556         if (node->op_type() == "MatMul")
2557         {
2558             if (i + 16 >= node_count)
2559                 continue;
2560 
2561             if (node_reference[node->output(0)] != 1)
2562                 continue;
2563 
2564             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2565             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
2566             onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
2567             onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
2568             onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
2569             onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
2570             onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
2571             onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
2572             onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9);
2573             onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10);
2574             onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11);
2575             onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12);
2576             onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13);
2577             onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14);
2578             onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15);
2579             onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16);
2580 
2581             if (node2->op_type() != "Add" || node3->op_type() != "Split" || node4->op_type() != "Mul" || node5->op_type() != "Reshape" || node6->op_type() != "Transpose" || node7->op_type() != "Reshape" || node8->op_type() != "Reshape" || node9->op_type() != "Transpose" || node10->op_type() != "Transpose" || node11->op_type() != "MatMul" || node12->op_type() != "Softmax" || node13->op_type() != "MatMul" || node14->op_type() != "Transpose" || node15->op_type() != "Reshape" || node16->op_type() != "MatMul" || node17->op_type() != "Add")
2582                 continue;
2583 
2584             if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 || node_reference[node3->output(1)] != 1 || node_reference[node3->output(2)] != 1 || node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 || node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 || node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 || node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 || node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 || node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 || node_reference[node16->output(0)] != 1)
2585                 continue;
2586 
2587             if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0) || node6->input(0) != node5->output(0) || node7->input(0) != node3->output(1) || node8->input(0) != node3->output(2) || node9->input(0) != node8->output(0) || node10->input(0) != node7->output(0) || node11->input(0) != node6->output(0) || node11->input(1) != node10->output(0) || node12->input(0) != node11->output(0) || node13->input(0) != node12->output(0) || node13->input(1) != node9->output(0) || node14->input(0) != node13->output(0) || node15->input(0) != node14->output(0) || node16->input(0) != node15->output(0) || node17->input(0) != node16->output(0))
2588                 continue;
2589 
2590             std::vector<float> qkv_B = get_node_attr_from_input_af(weights[node2->input(1)]);
2591             std::vector<float> o_B = get_node_attr_from_input_af(weights[node17->input(1)]);
2592 
2593             if (qkv_B.size() != o_B.size() * 3)
2594                 continue;
2595 
2596             int embed_dim = o_B.size();
2597 
2598             // 1 0 2
2599             std::vector<int> perm6 = get_node_attr_ai(*node6, "perm");
2600             std::vector<int> perm9 = get_node_attr_ai(*node9, "perm");
2601             if (perm6.size() != 3 || perm9.size() != 3)
2602                 continue;
2603 
2604             if (perm6[0] != 1 || perm6[1] != 0 || perm6[2] != 2 || perm9[0] != 1 || perm9[1] != 0 || perm9[2] != 2)
2605                 continue;
2606 
2607             // 1 2 0
2608             std::vector<int> perm10 = get_node_attr_ai(*node10, "perm");
2609             if (perm10.size() != 3)
2610                 continue;
2611 
2612             if (perm10[0] != 1 || perm10[1] != 2 || perm10[2] != 0)
2613                 continue;
2614 
2615             // 1 0 2
2616             std::vector<int> perm14 = get_node_attr_ai(*node14, "perm");
2617             if (perm14.size() != 3)
2618                 continue;
2619 
2620             if (perm14[0] != 1 || perm14[1] != 0 || perm14[2] != 2)
2621                 continue;
2622 
2623             int softmax_axis = get_node_attr_i(*node12, "axis");
2624             if (softmax_axis != 2)
2625                 continue;
2626 
2627             // 1/-1, seqlen * num_heads, embed_dim / num_heads
2628             std::vector<int> shape5;
2629             std::vector<int> shape7;
2630             std::vector<int> shape8;
2631             if (node5->input_size() == 1)
2632             {
2633                 shape5 = get_node_attr_ai(*node5, "shape");
2634             }
2635             else
2636             {
2637                 // skip weight reshape
2638                 if (weights.find(node5->input(1)) == weights.end())
2639                     continue;
2640 
2641                 shape5 = get_node_attr_from_input_ai(weights[node5->input(1)]);
2642             }
2643             if (node7->input_size() == 1)
2644             {
2645                 shape7 = get_node_attr_ai(*node7, "shape");
2646             }
2647             else
2648             {
2649                 // skip weight reshape
2650                 if (weights.find(node7->input(1)) == weights.end())
2651                     continue;
2652 
2653                 shape7 = get_node_attr_from_input_ai(weights[node7->input(1)]);
2654             }
2655             if (node8->input_size() == 1)
2656             {
2657                 shape8 = get_node_attr_ai(*node8, "shape");
2658             }
2659             else
2660             {
2661                 // skip weight reshape
2662                 if (weights.find(node8->input(1)) == weights.end())
2663                     continue;
2664 
2665                 shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]);
2666             }
2667 
2668             if (shape5.size() != 3 || shape7.size() != 3 || shape8.size() != 3)
2669                 continue;
2670 
2671             if (shape5[1] != shape7[1] || shape5[1] != shape8[1] || shape5[2] != shape7[2] || shape5[2] != shape8[2])
2672                 continue;
2673 
2674             int num_heads = embed_dim / shape5[2];
2675 
2676             // 1, seqlen, embed_dim
2677             std::vector<int> shape15;
2678             if (node15->input_size() == 1)
2679             {
2680                 shape15 = get_node_attr_ai(*node15, "shape");
2681             }
2682             else
2683             {
2684                 // skip weight reshape
2685                 if (weights.find(node15->input(1)) == weights.end())
2686                     continue;
2687 
2688                 shape15 = get_node_attr_from_input_ai(weights[node15->input(1)]);
2689             }
2690 
2691             if (shape15.size() != 3)
2692                 continue;
2693 
2694             if (shape15[2] != embed_dim || shape15[1] * num_heads != shape8[1])
2695                 continue;
2696 
2697             // reduce
2698             node->set_op_type("noop_reducedncnn");
2699             node2->set_op_type("noop_reducedncnn");
2700             node3->set_op_type("noop_reducedncnn");
2701             node4->set_op_type("noop_reducedncnn");
2702             node5->set_op_type("noop_reducedncnn");
2703             node6->set_op_type("noop_reducedncnn");
2704             node7->set_op_type("noop_reducedncnn");
2705             node8->set_op_type("noop_reducedncnn");
2706             node9->set_op_type("noop_reducedncnn");
2707             node10->set_op_type("noop_reducedncnn");
2708             node11->set_op_type("noop_reducedncnn");
2709             node12->set_op_type("noop_reducedncnn");
2710             node13->set_op_type("noop_reducedncnn");
2711             node14->set_op_type("noop_reducedncnn");
2712             node15->set_op_type("noop_reducedncnn");
2713             node16->set_op_type("noop_reducedncnn");
2714 
2715             node_reference[node2->input(0)] -= 1;
2716             node_reference[node3->input(0)] -= 1;
2717             node_reference[node4->input(0)] -= 1;
2718             node_reference[node4->input(1)] -= 1;
2719             node_reference[node5->input(0)] -= 1;
2720             if (node5->input_size() == 2)
2721             {
2722                 node_reference[node5->input(1)] -= 1;
2723             }
2724             node_reference[node6->input(0)] -= 1;
2725             node_reference[node7->input(0)] -= 1;
2726             if (node7->input_size() == 2)
2727             {
2728                 node_reference[node7->input(1)] -= 1;
2729             }
2730             node_reference[node8->input(0)] -= 1;
2731             if (node8->input_size() == 2)
2732             {
2733                 node_reference[node8->input(1)] -= 1;
2734             }
2735             node_reference[node9->input(0)] -= 1;
2736             node_reference[node10->input(0)] -= 1;
2737             node_reference[node11->input(0)] -= 1;
2738             node_reference[node11->input(1)] -= 1;
2739             node_reference[node12->input(0)] -= 1;
2740             node_reference[node13->input(0)] -= 1;
2741             node_reference[node13->input(1)] -= 1;
2742             node_reference[node14->input(0)] -= 1;
2743             node_reference[node15->input(0)] -= 1;
2744             if (node15->input_size() == 2)
2745             {
2746                 node_reference[node15->input(1)] -= 1;
2747             }
2748             node_reference[node16->input(0)] -= 1;
2749             node_reference[node17->input(0)] -= 1;
2750 
2751             blob_names.erase(node->output(0));
2752             blob_names.erase(node2->output(0));
2753             blob_names.erase(node3->output(0));
2754             blob_names.erase(node3->output(1));
2755             blob_names.erase(node3->output(2));
2756             blob_names.erase(node4->output(0));
2757             blob_names.erase(node5->output(0));
2758             blob_names.erase(node6->output(0));
2759             blob_names.erase(node7->output(0));
2760             blob_names.erase(node8->output(0));
2761             blob_names.erase(node9->output(0));
2762             blob_names.erase(node10->output(0));
2763             blob_names.erase(node11->output(0));
2764             blob_names.erase(node12->output(0));
2765             blob_names.erase(node13->output(0));
2766             blob_names.erase(node14->output(0));
2767             blob_names.erase(node15->output(0));
2768             blob_names.erase(node16->output(0));
2769 
2770             std::string qkvw = node->input(1);
2771             std::string qkvb = node2->input(1);
2772             std::string ow = node16->input(1);
2773             std::string ob = node17->input(1);
2774 
2775             node17->set_op_type("MultiHeadAttention");
2776             node17->clear_input();
2777             node17->add_input(node->input(0));
2778             // qkv
2779             node17->add_input(qkvw);
2780             node17->add_input(qkvb);
2781             // out linear
2782             node17->add_input(ow);
2783             node17->add_input(ob);
2784 
2785             onnx::AttributeProto* attr_embed_dim = node17->add_attribute();
2786             attr_embed_dim->set_name("embed_dim");
2787             attr_embed_dim->set_i(embed_dim);
2788 
2789             onnx::AttributeProto* attr_num_heads = node17->add_attribute();
2790             attr_num_heads->set_name("num_heads");
2791             attr_num_heads->set_i(num_heads);
2792 
2793             reduced_node_count += 16;
2794             i += 16;
2795         }
2796     }
2797 }
2798 
fuse_binaryop_with_scalar(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)2799 static void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
2800 {
2801     int node_count = mutable_graph->node_size();
2802     for (int i = 0; i < node_count; i++)
2803     {
2804         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2805 
2806         // Add/Sub/Mul/Div/Min/Max/Pow(a, x)
2807         if (node->op_type() == "Add" || node->op_type() == "Sub" || node->op_type() == "Mul" || node->op_type() == "Div" || node->op_type() == "Max" || node->op_type() == "Min" || node->op_type() == "Pow")
2808         {
2809             if (weights.find(node->input(0)) == weights.end())
2810                 continue;
2811 
2812             const onnx::TensorProto& scalar_b = weights[node->input(0)];
2813             if (scalar_b.dims_size() != 0 || get_tensor_proto_data_size(scalar_b) != 1)
2814                 continue;
2815 
2816             if (node->op_type() == "Sub")
2817             {
2818                 node->set_op_type("RSub");
2819             }
2820             else if (node->op_type() == "Div")
2821             {
2822                 node->set_op_type("RDiv");
2823             }
2824 
2825             float b = get_node_attr_from_input_f(scalar_b);
2826 
2827             node_reference[node->input(0)] -= 1;
2828 
2829             std::string input = node->input(1);
2830 
2831             node->clear_input();
2832             node->add_input(input);
2833 
2834             onnx::AttributeProto* attr_with_scalar = node->add_attribute();
2835             attr_with_scalar->set_name("with_scalar");
2836             attr_with_scalar->set_i(1);
2837 
2838             onnx::AttributeProto* attr_b = node->add_attribute();
2839             attr_b->set_name("b");
2840             attr_b->set_f(b);
2841         }
2842     }
2843 
2844     for (int i = 0; i < node_count; i++)
2845     {
2846         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2847 
2848         // Add/Sub/Mul/Div/Min/Max/Pow(x, b)
2849         if (node->op_type() == "Add" || node->op_type() == "Sub" || node->op_type() == "Mul" || node->op_type() == "Div" || node->op_type() == "Max" || node->op_type() == "Min" || node->op_type() == "Pow")
2850         {
2851             if (weights.find(node->input(1)) == weights.end())
2852                 continue;
2853 
2854             const onnx::TensorProto& scalar_b = weights[node->input(1)];
2855             if (scalar_b.dims_size() != 0 || get_tensor_proto_data_size(scalar_b) != 1)
2856                 continue;
2857 
2858             float b = get_node_attr_from_input_f(scalar_b);
2859 
2860             node_reference[node->input(1)] -= 1;
2861 
2862             std::string input = node->input(0);
2863 
2864             node->clear_input();
2865             node->add_input(input);
2866 
2867             onnx::AttributeProto* attr_with_scalar = node->add_attribute();
2868             attr_with_scalar->set_name("with_scalar");
2869             attr_with_scalar->set_i(1);
2870 
2871             onnx::AttributeProto* attr_b = node->add_attribute();
2872             attr_b->set_name("b");
2873             attr_b->set_f(b);
2874         }
2875     }
2876 }
2877 
main(int argc,char ** argv)2878 int main(int argc, char** argv)
2879 {
2880     if (!(argc == 2 || argc == 4))
2881     {
2882         fprintf(stderr, "Usage: %s [onnxpb] [ncnnparam] [ncnnbin]\n", argv[0]);
2883         return -1;
2884     }
2885 
2886     const char* onnxpb = argv[1];
2887     const char* ncnn_prototxt = argc == 4 ? argv[2] : "ncnn.param";
2888     const char* ncnn_modelbin = argc == 4 ? argv[3] : "ncnn.bin";
2889 
2890     onnx::ModelProto model;
2891 
2892     // load
2893     bool s1 = read_proto_from_binary(onnxpb, &model);
2894     if (!s1)
2895     {
2896         fprintf(stderr, "read_proto_from_binary failed\n");
2897         return -1;
2898     }
2899 
2900     FILE* pp = fopen(ncnn_prototxt, "wb");
2901     FILE* bp = fopen(ncnn_modelbin, "wb");
2902 
2903     // magic
2904     fprintf(pp, "7767517\n");
2905 
2906     const onnx::GraphProto& graph = model.graph();
2907     onnx::GraphProto* mutable_graph = model.mutable_graph();
2908 
2909     int node_count = graph.node_size();
2910 
2911     // node reference
2912     std::map<std::string, int> node_reference;
2913 
2914     // weight node and weight reshape node
2915     std::map<std::string, onnx::TensorProto> weights;
2916 
2917     for (int j = 0; j < graph.initializer_size(); j++)
2918     {
2919         const onnx::TensorProto& initializer = graph.initializer(j);
2920 
2921         //         fprintf(stderr, "weight = %s %d\n", initializer.name().c_str(), initializer.data_type());
2922 
2923         weights[initializer.name()] = initializer;
2924     }
2925 
2926     // topological sort
2927     {
2928         // name -> producer node index
2929         std::set<std::string> producers;
2930         for (int j = 0; j < graph.input_size(); j++)
2931         {
2932             const std::string& input_name = graph.input(j).name();
2933             producers.insert(input_name);
2934         }
2935 
2936         for (int i = 0; i < node_count;)
2937         {
2938             onnx::NodeProto* node = mutable_graph->mutable_node(i);
2939 
2940             bool swapnode = false;
2941             std::string missing_input_name;
2942             for (int j = 0; j < (int)node->input_size(); j++)
2943             {
2944                 const std::string& input_name = node->input(j);
2945                 if (input_name.empty())
2946                     continue;
2947 
2948                 if (producers.find(input_name) == producers.end() && weights.find(input_name) == weights.end())
2949                 {
2950                     swapnode = true;
2951                     missing_input_name = input_name;
2952                     break;
2953                 }
2954             }
2955 
2956             if (!swapnode)
2957             {
2958                 for (int j = 0; j < (int)node->output_size(); j++)
2959                 {
2960                     const std::string& output_name = node->output(j);
2961                     if (output_name.empty())
2962                         continue;
2963 
2964                     producers.insert(output_name);
2965                 }
2966 
2967                 i++;
2968                 continue;
2969             }
2970 
2971             // find node that produce missing_input_name
2972             int q = i + 1;
2973             for (; q < node_count; q++)
2974             {
2975                 onnx::NodeProto* nodeq = mutable_graph->mutable_node(q);
2976                 bool found = false;
2977                 for (int j = 0; j < (int)nodeq->output_size(); j++)
2978                 {
2979                     const std::string& output_name = nodeq->output(j);
2980                     if (output_name == missing_input_name)
2981                     {
2982                         found = true;
2983                         break;
2984                     }
2985                 }
2986 
2987                 if (found)
2988                     break;
2989             }
2990 
2991             if (q == node_count)
2992             {
2993                 fprintf(stderr, "cannot find node produces %s but node %d requires it\n", missing_input_name.c_str(), i);
2994                 return -1;
2995             }
2996 
2997             // fprintf(stderr, "swap %d %d\n", i, q);
2998             // swap this node with q
2999             onnx::NodeProto* nodeq = mutable_graph->mutable_node(q);
3000             onnx::NodeProto tmp = *node;
3001             *node = *nodeq;
3002             *nodeq = tmp;
3003         }
3004     }
3005 
3006     // global definition line
3007     // [layer count] [blob count]
3008     std::set<std::string> blob_names;
3009     for (int i = 0; i < node_count; i++)
3010     {
3011         const onnx::NodeProto& node = graph.node(i);
3012 
3013         const std::string& op = node.op_type();
3014 
3015         std::string name = node.name();
3016         if (name.empty())
3017         {
3018             name = node.output(0);
3019         }
3020 
3021         if (op == "Constant")
3022         {
3023             onnx::TensorProto tensor = get_node_attr_tensor(node, "value");
3024             weights[node.output(0)] = tensor;
3025         }
3026 
3027         for (int j = 0; j < (int)node.input_size(); j++)
3028         {
3029             const std::string& input_name = node.input(j);
3030 
3031             blob_names.insert(input_name);
3032 
3033             if (node_reference.find(input_name) == node_reference.end())
3034             {
3035                 node_reference[input_name] = 1;
3036             }
3037             else
3038             {
3039                 node_reference[input_name] = node_reference[input_name] + 1;
3040             }
3041         }
3042 
3043         if (op == "Dropout")
3044         {
3045             const std::string& output_name = node.output(0);
3046             blob_names.insert(output_name);
3047             node_reference[output_name] = 0;
3048             continue;
3049         }
3050 
3051         for (int j = 0; j < (int)node.output_size(); j++)
3052         {
3053             const std::string& output_name = node.output(j);
3054 
3055             blob_names.insert(output_name);
3056 
3057             node_reference[output_name] = 0;
3058         }
3059     }
3060 
3061     // include Input node
3062     int input_node_count = 0;
3063     for (int j = 0; j < graph.input_size(); j++)
3064     {
3065         const std::string& input_name = graph.input(j).name();
3066 
3067         // check weight
3068         if (weights.find(input_name) != weights.end())
3069             continue;
3070 
3071         blob_names.insert(input_name);
3072 
3073         input_node_count++;
3074     }
3075 
3076     //     for (auto a: node_reference)
3077     //     {
3078     //         fprintf(stderr, "a = %s %d\n", a.first.c_str(), a.second);
3079     //     }
3080 
    // op chain fusion
    // each pass pattern-matches a known onnx subgraph and rewrites it in
    // mutable_graph to a single fused op, updating weights / node_reference /
    // blob_names and accumulating the number of absorbed nodes in
    // reduced_node_count (used below when computing the layer count)
    int reduced_node_count = 0;
    fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
    fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
    fuse_shufflechannel(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
    fuse_shufflechannel_split(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
    fuse_hardsigmoid(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
    fuse_hardswish(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
    fuse_swish(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
    fuse_batchnorm1d_squeeze_unsqueeze(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
    fuse_unsqueeze_prelu(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
    fuse_normalize(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
    fuse_groupnorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
    fuse_layernorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
    fuse_flatten(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
    fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
    fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
    fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
    fuse_lstm_gru_rnn(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
    fuse_multiheadattention(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
    fuse_binaryop_with_scalar(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3102 
    // reduce common const weight node_reference
    // inputs listed below are consumed as layer weights/parameters (written to
    // the bin file or folded into the param line), not as runtime blobs, so
    // their reference counts are decremented here; a weight whose count drops
    // to zero will not get a MemoryData layer later
    for (int i = 0; i < node_count; i++)
    {
        const onnx::NodeProto& node = graph.node(i);

        const std::string& op = node.op_type();

        if (op == "BatchNormalization")
        {
            // inputs 1..4: scale, B, mean, var
            node_reference[node.input(1)] -= 1;
            node_reference[node.input(2)] -= 1;
            node_reference[node.input(3)] -= 1;
            node_reference[node.input(4)] -= 1;
        }
        else if (op == "BiasGelu")
        {
            node_reference[node.input(1)] -= 1;
        }
        else if (op == "Clip")
        {
            // opset 11+ carries min/max as inputs instead of attributes
            if (node.input_size() == 3)
            {
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
            }
        }
        else if (op == "Conv")
        {
            // weight, plus optional bias
            node_reference[node.input(1)] -= 1;
            if (node.input_size() == 3)
            {
                node_reference[node.input(2)] -= 1;
            }
        }
        else if (op == "ConvTranspose")
        {
            node_reference[node.input(1)] -= 1;
            if (node.input_size() == 3)
            {
                node_reference[node.input(2)] -= 1;
            }
        }
        else if (op == "EmbedLayerNormalization")
        {
            node_reference[node.input(1)] -= 1;
            node_reference[node.input(2)] -= 1;
            node_reference[node.input(3)] -= 1;
            node_reference[node.input(4)] -= 1;
            node_reference[node.input(5)] -= 1;
            node_reference[node.input(6)] -= 1;
        }
        else if (op == "Gemm")
        {
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            float beta = get_node_attr_f(node, "beta", 1.f);
            int transA = get_node_attr_i(node, "transA", 0);
            int transB = get_node_attr_i(node, "transB", 0);

            // only the InnerProduct-convertible form consumes B and C as weights
            if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
            {
                // InnerProduct-like A * B + C
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
            }
        }
        else if (op == "GroupNorm")
        {
            int affine = get_node_attr_i(node, "affine", 1);
            if (affine)
            {
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
            }
        }
        else if (op == "GRU")
        {
            // all inputs but input 0 (the sequence) are weights/initial states
            for (int j = 1; j < node.input_size(); j++)
            {
                node_reference[node.input(j)] -= 1;
            }
        }
        else if (op == "InstanceNormalization")
        {
            node_reference[node.input(1)] -= 1;
            node_reference[node.input(2)] -= 1;
        }
        else if (op == "LayerNorm")
        {
            int affine = get_node_attr_i(node, "affine", 1);
            if (affine)
            {
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
            }
        }
        else if (op == "LSTM")
        {
            for (int j = 1; j < node.input_size(); j++)
            {
                node_reference[node.input(j)] -= 1;
            }
        }
        else if (op == "MatMul")
        {
            // only a constant 2-D B operand becomes an InnerProduct weight
            if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
            {
                // InnerProduct
                node_reference[node.input(1)] -= 1;
            }
        }
        else if (op == "MultiHeadAttention")
        {
            if (node.input_size() == 5)
            {
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
                node_reference[node.input(3)] -= 1;
                node_reference[node.input(4)] -= 1;
            }
            else
            {
                node_reference[node.input(3)] -= 1;
                node_reference[node.input(4)] -= 1;
                node_reference[node.input(5)] -= 1;
                node_reference[node.input(6)] -= 1;
                node_reference[node.input(7)] -= 1;
                node_reference[node.input(8)] -= 1;
                node_reference[node.input(9)] -= 1;
                node_reference[node.input(10)] -= 1;
            }
        }
        else if (op == "Pad")
        {
            if (node.input_size() >= 2)
            {
                node_reference[node.input(1)] -= 1;
            }
        }
        else if (op == "PRelu")
        {
            node_reference[node.input(1)] -= 1;
        }
        else if (op == "Reshape")
        {
            if (node.input_size() >= 2)
            {
                node_reference[node.input(1)] -= 1;
            }
        }
        else if (op == "Resize")
        {
            if (node.input_size() == 2)
            {
                // opset 10
                node_reference[node.input(1)] -= 1;
            }
            else
            {
                // opset 11+
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
                if (node.input_size() >= 4)
                {
                    node_reference[node.input(3)] -= 1;
                }
            }
        }
        else if (op == "RNN")
        {
            for (int j = 1; j < node.input_size(); j++)
            {
                node_reference[node.input(j)] -= 1;
            }
        }
        else if (op == "SkipLayerNormalization")
        {
            node_reference[node.input(2)] -= 1;
            node_reference[node.input(3)] -= 1;
            node_reference[node.input(4)] -= 1;
        }
        else if (op == "Slice")
        {
            // starts/ends, optional axes/steps
            if (node.input_size() >= 2)
            {
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
                if (node.input_size() >= 4)
                    node_reference[node.input(3)] -= 1;
                if (node.input_size() >= 5)
                    node_reference[node.input(4)] -= 1;
            }
        }
        else if (op == "Upsample")
        {
            if (node.input_size() >= 2)
            {
                node_reference[node.input(1)] -= 1;
            }
        }
        else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
        {
            if (node.input_size() >= 2)
            {
                node_reference[node.input(1)] -= 1;
            }
        }
    }
3310 
3311     //         for (auto a: node_reference)
3312     //         {
3313     //             fprintf(stderr, "b = %s %d\n", a.first.c_str(), a.second);
3314     //         }
3315 
    // count all weight node with zero reference
    // a weight whose count reached zero is fully absorbed into some layer and
    // will not be emitted as a MemoryData layer nor counted as a blob
    int zero_reference_weight_node_count = 0;
    for (std::map<std::string, onnx::TensorProto>::iterator it = weights.begin(); it != weights.end(); it++)
    {
        const std::string& input_name = it->first;

        // there may be some weight nodes in initializer but none of the graph node use them
        // add them to blob_names so we could get proper blob count later
        blob_names.insert(input_name);

        int refcount = node_reference[input_name];
        if (refcount == 0)
            zero_reference_weight_node_count++;
    }
3330 
    // we always treat constant node as weight or binaryop_weights
    // do not count it twice for layer_count
    int constant_node_count_moved_to_weight = 0;
    for (int i = 0; i < node_count; i++)
    {
        const onnx::NodeProto& node = graph.node(i);

        const std::string& op = node.op_type();

        if (op == "Constant")
        {
            constant_node_count_moved_to_weight++;
        }
    }

    // some op may have anonymous input
    // LSTM sequence_lens
    // drop the empty-name pseudo blob so it never appears in counts or output
    blob_names.erase("");
    node_reference.erase("");
3350 
    // remove node_reference entry with reference equals to one
    int split_layer_count = 0;
    int splitncnn_blob_count = 0;
    // split node reference
    // every blob consumed more than once needs an ncnn Split layer producing
    // one clone blob per consumer
    std::map<std::string, int> split_node_reference;
    for (std::map<std::string, int>::iterator it = node_reference.begin(); it != node_reference.end(); it++)
    {
        if (it->second > 1)
        {
            split_layer_count++;
            splitncnn_blob_count += it->second;

            split_node_reference[it->first] = it->second;
        }
    }

    // param file header: layer count and blob count
    // layers = graph nodes - Constants (moved to weights) + MemoryData weights
    //          (weights minus zero-ref ones) - fused-away nodes + Input layers
    //          + Split layers
    // blobs  = named blobs - zero-ref weights + split clone blobs
    fprintf(pp, "%zu %zu\n", node_count - constant_node_count_moved_to_weight + weights.size() - zero_reference_weight_node_count - reduced_node_count + input_node_count + split_layer_count, blob_names.size() - zero_reference_weight_node_count + splitncnn_blob_count);
3368 
    int internal_split = 0; // running index used to name weight Split layers below

    // place Input at the beginning
    for (int j = 0; j < graph.input_size(); j++)
    {
        const std::string& input_name = graph.input(j).name();

        // check weight
        if (weights.find(input_name) != weights.end())
            continue;

        fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", input_name.c_str(), input_name.c_str());

        // a graph input consumed by multiple nodes needs an immediate Split
        int refcount = node_reference[input_name];
        if (refcount <= 1)
        {
            continue;
        }

        char splitname[256];
        sprintf(splitname, "splitncnn_input%d", j);
        fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
        fprintf(pp, " %s", input_name.c_str());

        // clone blobs are named <blob>_splitncnn_<k>, one per consumer
        for (int k = 0; k < refcount; k++)
        {
            fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
        }
        fprintf(pp, "\n");
    }
3399 
    // place MemoryData next
    // every weight still referenced as a runtime blob becomes a MemoryData
    // layer; its tensor data goes to the bin file
    for (std::map<std::string, onnx::TensorProto>::iterator weight_it = weights.begin(); weight_it != weights.end(); weight_it++)
    {
        const std::string& input_name = weight_it->first;

        // zero-ref weights were absorbed into layers — skip them
        int refcount = node_reference[input_name];
        if (refcount == 0)
        {
            continue;
        }

        fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", input_name.c_str(), input_name.c_str());

        const onnx::TensorProto& M = weights[input_name];

        // onnx dims are outermost-first; ncnn params are emitted innermost-first
        // (0=w 1=h 2=c), and a leading dim of 1 is elided
        if (M.dims_size() == 0)
        {
            fprintf(pp, " 0=%d", get_tensor_proto_data_size(M));
        }
        else if (M.dims_size() == 1)
        {
            fprintf(pp, " 0=%d", (int)M.dims(0));
        }
        else if (M.dims_size() == 2)
        {
            fprintf(pp, " 0=%d", (int)M.dims(1));
            if (M.dims(0) != 1)
            {
                fprintf(pp, " 1=%d", (int)M.dims(0));
            }
        }
        else if (M.dims_size() == 3)
        {
            fprintf(pp, " 0=%d", (int)M.dims(2));
            fprintf(pp, " 1=%d", (int)M.dims(1));
            if (M.dims(0) != 1)
            {
                fprintf(pp, " 2=%d", (int)M.dims(0));
            }
        }
        else if (M.dims_size() == 4)
        {
            // the outermost dim is not emitted — presumably batch == 1; verify for 4-D weights
            fprintf(pp, " 0=%d", (int)M.dims(3));
            fprintf(pp, " 1=%d", (int)M.dims(2));
            fprintf(pp, " 2=%d", (int)M.dims(1));
        }

        fprintf(pp, "\n");

        fwrite_tensor_proto_data(M, bp);

        // multi-consumer weight blob also needs a Split layer
        if (refcount <= 1)
        {
            continue;
        }

        char splitname[256];
        sprintf(splitname, "splitncnn_%d", internal_split);
        fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);

        fprintf(pp, " %s", input_name.c_str());

        for (int k = 0; k < refcount; k++)
        {
            fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
        }
        fprintf(pp, "\n");

        internal_split++;
    }
3470 
    // main emission loop: one param line (and bin payload) per surviving node
    for (int i = 0; i < node_count; i++)
    {
        const onnx::NodeProto& node = graph.node(i);

        const std::string& op = node.op_type();

        //         fprintf(stderr, "op = %s\n", op.c_str());

        // placeholder left behind by the fusion passes — already absorbed
        if (op == "noop_reducedncnn")
        {
            continue;
        }

        // unnamed nodes borrow their first output blob name as the layer name
        std::string name = node.name();
        if (name.empty())
        {
            name = node.output(0);
        }

        int input_size = node.input_size();
        int output_size = node.output_size();

        // shrink the declared input count: weight inputs with zero remaining
        // references are consumed as layer data, and empty names are
        // anonymous optional inputs — neither appears as a blob reference
        for (int j = 0; j < (int)node.input_size(); j++)
        {
            const std::string& input_name = node.input(j);

            // check weight
            if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0)
            {
                input_size--;
            }

            if (input_name.empty())
            {
                input_size--;
            }

            //             fprintf(stderr, "  input = %s\n", input_name.c_str());
        }
        /*
        for (int j=0; j<(int)node.output_size(); j++)
        {
            const std::string& output_name = node.output(j);
            fprintf(stderr, "  output = %s\n", output_name.c_str());
        }
        */
3517 
        // map the onnx op type to the ncnn layer type name (first column of
        // the param line); op-specific parameters are written further below
        if (op == "Abs")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Acos")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Add")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "Asin")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Atan")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "AveragePool" || op == "MaxPool")
        {
            // 1-d kernel selects the 1-d pooling layer
            std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
            if (kernel_shape.size() == 1)
            {
                fprintf(pp, "%-16s", "Pooling1D");
            }
            else
            {
                fprintf(pp, "%-16s", "Pooling");
            }
        }
        else if (op == "BatchNormalization")
        {
            fprintf(pp, "%-16s", "BatchNorm");
        }
        else if (op == "BiasGelu")
        {
            fprintf(pp, "%-16s", "BiasGelu");
        }
        else if (op == "Ceil")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Clip")
        {
            fprintf(pp, "%-16s", "Clip");
        }
        else if (op == "Concat")
        {
            fprintf(pp, "%-16s", "Concat");
        }
        else if (op == "Constant")
        {
            // already moved to weights — emits no layer
            continue;
        }
        else if (op == "Conv")
        {
            std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
            if (kernel_shape.size() == 1)
            {
                fprintf(pp, "%-16s", "Convolution1D");
            }
            else
            {
                // grouped convolution maps to the depthwise variant
                int group = get_node_attr_i(node, "group", 1);
                if (group > 1)
                {
                    fprintf(pp, "%-16s", "ConvolutionDepthWise");
                }
                else
                {
                    fprintf(pp, "%-16s", "Convolution");
                }
            }
        }
        else if (op == "ConvTranspose")
        {
            int group = get_node_attr_i(node, "group", 1);
            if (group > 1)
            {
                fprintf(pp, "%-16s", "DeconvolutionDepthWise");
            }
            else
            {
                fprintf(pp, "%-16s", "Deconvolution");
            }
        }
        else if (op == "Cos")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "DepthToSpace")
        {
            fprintf(pp, "%-16s", "PixelShuffle");
        }
        else if (op == "Div")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "Dropout")
        {
            // only the data output survives (mask output was dropped earlier)
            fprintf(pp, "%-16s", "Dropout");
            output_size = 1;
        }
        else if (op == "Elu")
        {
            fprintf(pp, "%-16s", "ELU");
        }
        else if (op == "EmbedLayerNormalization")
        {
            fprintf(pp, "%-16s", "EmbedLayerNormalization");
        }
        else if (op == "Exp")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Flatten")
        {
            fprintf(pp, "%-16s", "Flatten");
        }
        else if (op == "Floor")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Gemm")
        {
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            float beta = get_node_attr_f(node, "beta", 1.f);
            int transA = get_node_attr_i(node, "transA", 0);
            int transB = get_node_attr_i(node, "transB", 0);

            if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
            {
                // InnerProduct-like A * B + C
                fprintf(pp, "%-16s", "InnerProduct");
            }
            else
            {
                fprintf(pp, "%-16s", "Gemm");
            }
        }
        else if (op == "GlobalAveragePool")
        {
            fprintf(pp, "%-16s", "Pooling");
        }
        else if (op == "GlobalMaxPool")
        {
            fprintf(pp, "%-16s", "Pooling");
        }
        else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
        {
            fprintf(pp, "%-16s", "Pooling");
        }
        else if (op == "GroupNorm")
        {
            fprintf(pp, "%-16s", "GroupNorm");
        }
        else if (op == "GRU")
        {
            fprintf(pp, "%-16s", "GRU");
        }
        else if (op == "HardSigmoid")
        {
            fprintf(pp, "%-16s", "HardSigmoid");
        }
        else if (op == "HardSwish")
        {
            fprintf(pp, "%-16s", "HardSwish");
        }
        else if (op == "ImageScaler")
        {
            fprintf(pp, "%-16s", "Scale");
        }
        else if (op == "InstanceNormalization")
        {
            fprintf(pp, "%-16s", "InstanceNorm");
        }
        else if (op == "LayerNorm")
        {
            fprintf(pp, "%-16s", "LayerNorm");
        }
        else if (op == "LeakyRelu")
        {
            fprintf(pp, "%-16s", "ReLU");
        }
        else if (op == "Log")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "LRN")
        {
            fprintf(pp, "%-16s", "LRN");
        }
        else if (op == "LSTM")
        {
            fprintf(pp, "%-16s", "LSTM");
        }
        else if (op == "MatMul")
        {
            // constant 2-D B → InnerProduct; otherwise generic Gemm
            if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
            {
                fprintf(pp, "%-16s", "InnerProduct");
            }
            else
            {
                fprintf(pp, "%-16s", "Gemm");
            }
        }
        else if (op == "Max")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "Min")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "Mul")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "MultiHeadAttention")
        {
            fprintf(pp, "%-16s", "MultiHeadAttention");
        }
        else if (op == "Neg")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Normalize")
        {
            fprintf(pp, "%-16s", "Normalize");
        }
        else if (op == "Pad")
        {
            fprintf(pp, "%-16s", "Padding");
        }
        else if (op == "PixelShuffle")
        {
            fprintf(pp, "%-16s", "PixelShuffle");
        }
        else if (op == "Pow")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "PRelu")
        {
            fprintf(pp, "%-16s", "PReLU");
        }
        else if (op == "Reciprocal")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp")
        {
            fprintf(pp, "%-16s", "Reduction");
        }
        else if (op == "Relu")
        {
            fprintf(pp, "%-16s", "ReLU");
        }
        else if (op == "Reorg")
        {
            fprintf(pp, "%-16s", "Reorg");
        }
        else if (op == "Reshape")
        {
            fprintf(pp, "%-16s", "Reshape");
        }
        else if (op == "RNN")
        {
            fprintf(pp, "%-16s", "RNN");
        }
        else if (op == "RDiv")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "RSub")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "ShuffleChannel")
        {
            fprintf(pp, "%-16s", "ShuffleChannel");
        }
        else if (op == "Sigmoid")
        {
            fprintf(pp, "%-16s", "Sigmoid");
        }
        else if (op == "Sin")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "SkipLayerNormalization")
        {
            fprintf(pp, "%-16s", "SkipLayerNormalization");
        }
        else if (op == "Slice")
        {
            fprintf(pp, "%-16s", "Crop");
        }
        else if (op == "Softmax")
        {
            fprintf(pp, "%-16s", "Softmax");
        }
        else if (op == "Softplus")
        {
            fprintf(pp, "%-16s", "Softplus");
        }
        else if (op == "Split")
        {
            // onnx Split (chunking) maps to ncnn Slice, not ncnn Split (fan-out)
            fprintf(pp, "%-16s", "Slice");
        }
        else if (op == "Sqrt")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Squeeze")
        {
            fprintf(pp, "%-16s", "Squeeze");
        }
        else if (op == "Sub")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "Sum")
        {
            fprintf(pp, "%-16s", "Eltwise");
        }
        else if (op == "Swish")
        {
            fprintf(pp, "%-16s", "Swish");
        }
        else if (op == "Tan")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Tanh")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Transpose")
        {
            fprintf(pp, "%-16s", "Permute");
        }
        else if (op == "Upsample" || op == "Resize")
        {
            fprintf(pp, "%-16s", "Interp");
        }
        else if (op == "Unsqueeze")
        {
            fprintf(pp, "%-16s", "ExpandDims");
        }
        else
        {
            // TODO
            // unknown op: warn, then emit the raw op name so the param file
            // structure stays intact (ncnn will reject the unknown layer)
            fprintf(stderr, "%s not supported yet!\n", op.c_str());
            fprintf(pp, "%-16s", op.c_str());
        }
3877 
        // layer name plus adjusted input/output blob counts
        fprintf(pp, " %-24s %d %d", name.c_str(), input_size, output_size);

        for (int j = 0; j < (int)node.input_size(); j++)
        {
            std::string input_name = node.input(j);

            // check weight
            if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0)
            {
                continue;
            }

            if (input_name.empty())
            {
                continue;
            }

            // consume one clone of a split blob; the counter counts down, so
            // later consumers receive lower _splitncnn_ indices
            if (split_node_reference.find(input_name) != split_node_reference.end())
            {
                int refidx = split_node_reference[input_name] - 1;
                split_node_reference[input_name] = refidx;

                char splitsuffix[256];
                sprintf(splitsuffix, "_splitncnn_%d", refidx);
                input_name = input_name + splitsuffix;
            }

            fprintf(pp, " %s", input_name.c_str());
        }

        for (int j = 0; j < output_size; j++)
        {
            const std::string& output_name = node.output(j);

            fprintf(pp, " %s", output_name.c_str());
        }
3914 
3915         if (op == "Abs")
3916         {
3917             int op_type = 0;
3918             fprintf(pp, " 0=%d", op_type);
3919         }
3920         else if (op == "Acos")
3921         {
3922             int op_type = 13;
3923             fprintf(pp, " 0=%d", op_type);
3924         }
3925         else if (op == "Add")
3926         {
3927             int op_type = 0;
3928             fprintf(pp, " 0=%d", op_type);
3929 
3930             int with_scalar = get_node_attr_i(node, "with_scalar", 0);
3931             float b = get_node_attr_f(node, "b", 0.f);
3932             if (with_scalar)
3933             {
3934                 fprintf(pp, " 1=%d", with_scalar);
3935                 fprintf(pp, " 2=%e", b);
3936             }
3937         }
3938         else if (op == "Asin")
3939         {
3940             int op_type = 12;
3941             fprintf(pp, " 0=%d", op_type);
3942         }
3943         else if (op == "Atan")
3944         {
3945             int op_type = 14;
3946             fprintf(pp, " 0=%d", op_type);
3947         }
3948         else if (op == "AveragePool" || op == "MaxPool")
3949         {
3950             std::string auto_pad = get_node_attr_s(node, "auto_pad");
3951             int ceil_mode = get_node_attr_i(node, "ceil_mode", 0);
3952             std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
3953             std::vector<int> strides = get_node_attr_ai(node, "strides");
3954             std::vector<int> pads = get_node_attr_ai(node, "pads");
3955 
3956             int pool = op == "AveragePool" ? 1 : 0;
3957             int pad_mode = 1;
3958 
3959             if (auto_pad == "SAME_UPPER")
3960             {
3961                 pad_mode = 2;
3962             }
3963             else if (auto_pad == "SAME_LOWER")
3964             {
3965                 pad_mode = 3;
3966             }
3967 
3968             if (ceil_mode == 1)
3969             {
3970                 pad_mode = 0;
3971             }
3972 
3973             fprintf(pp, " 0=%d", pool);
3974 
3975             if (kernel_shape.size() == 1)
3976             {
3977                 fprintf(pp, " 1=%d", kernel_shape[0]);
3978             }
3979             else if (kernel_shape.size() == 2)
3980             {
3981                 fprintf(pp, " 1=%d", kernel_shape[1]);
3982                 fprintf(pp, " 11=%d", kernel_shape[0]);
3983             }
3984 
3985             if (strides.size() == 1)
3986             {
3987                 fprintf(pp, " 2=%d", strides[0]);
3988             }
3989             else if (strides.size() == 2)
3990             {
3991                 fprintf(pp, " 2=%d", strides[1]);
3992                 fprintf(pp, " 12=%d", strides[0]);
3993             }
3994 
3995             if (pads.size() == 1)
3996             {
3997                 fprintf(pp, " 3=%d", pads[0]);
3998             }
3999             else if (pads.size() == 2)
4000             {
4001                 fprintf(pp, " 3=%d", pads[1]);
4002                 fprintf(pp, " 13=%d", pads[0]);
4003             }
4004             else if (pads.size() == 4)
4005             {
4006                 fprintf(pp, " 3=%d", pads[1]);
4007                 fprintf(pp, " 13=%d", pads[0]);
4008                 fprintf(pp, " 14=%d", pads[3]);
4009                 fprintf(pp, " 15=%d", pads[2]);
4010             }
4011 
4012             fprintf(pp, " 5=%d", pad_mode);
4013 
4014             if (op == "AveragePool")
4015             {
4016                 int avgpool_count_include_pad = get_node_attr_i(node, "count_include_pad", 0);
4017                 fprintf(pp, " 6=%d", avgpool_count_include_pad);
4018             }
4019         }
        else if (op == "BatchNormalization")
        {
            // ncnn BatchNorm: emit channel count, then slope(scale), mean,
            // variance (with epsilon folded in), bias -- in that exact order.
            float epsilon = get_node_attr_f(node, "epsilon", 1e-5f);

            const onnx::TensorProto& scale = weights[node.input(1)];
            const onnx::TensorProto& B = weights[node.input(2)];
            const onnx::TensorProto& mean = weights[node.input(3)];
            const onnx::TensorProto& var = weights[node.input(4)];

            int channels = get_tensor_proto_data_size(scale);

            fprintf(pp, " 0=%d", channels);

            fwrite_tensor_proto_data(scale, bp);
            fwrite_tensor_proto_data(mean, bp);
            // apply epsilon to var
            // (presumably so the runtime need not add it per inference -- the
            // per-element rewrite below is why fwrite_tensor_proto_data is not used)
            {
                // raw_data (byte blob) takes precedence over typed float_data
                const float* v = var.has_raw_data() ? (const float*)var.raw_data().data() : var.float_data().data();

                for (int j = 0; j < channels; j++)
                {
                    float ve = v[j] + epsilon;
                    fwrite(&ve, sizeof(float), 1, bp);
                }
            }
            fwrite_tensor_proto_data(B, bp);
        }
        else if (op == "BiasGelu")
        {
            // Fused Add(bias)+Gelu contrib op; param 0 = bias element count.
            const onnx::TensorProto& B = weights[node.input(1)];

            fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));

            // leading 0 tag marks the following blob as raw fp32 data
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B, bp);
        }
4058         else if (op == "Ceil")
4059         {
4060             int op_type = 3;
4061             fprintf(pp, " 0=%d", op_type);
4062         }
4063         else if (op == "Clip")
4064         {
4065             float min;
4066             float max;
4067             if (node.input_size() == 1)
4068             {
4069                 min = get_node_attr_f(node, "min", -FLT_MAX);
4070                 max = get_node_attr_f(node, "max", FLT_MAX);
4071             }
4072             else
4073             {
4074                 min = weights.find(node.input(1)) != weights.end() ? get_node_attr_from_input_f(weights[node.input(1)]) : -FLT_MAX;
4075                 max = weights.find(node.input(2)) != weights.end() ? get_node_attr_from_input_f(weights[node.input(2)]) : FLT_MAX;
4076             }
4077 
4078             fprintf(pp, " 0=%e", min);
4079             fprintf(pp, " 1=%e", max);
4080         }
4081         else if (op == "Concat")
4082         {
4083             int axis = get_node_attr_i(node, "axis", 1);
4084             fprintf(pp, " 0=%d", axis > 0 ? axis - 1 : axis);
4085         }
        else if (op == "Constant")
        {
            // never reach here -- Constant nodes are presumably absorbed into
            // `weights` before this loop runs (TODO confirm upstream folding)
        }
        else if (op == "Conv")
        {
            // ncnn Convolution: params, then a tagged fp32 weight blob and
            // optional bias blob.
            const onnx::TensorProto& W = weights[node.input(1)];

            // W.dims(0) is the output-channel count of the ONNX conv weight
            int num_filter = W.dims(0);
            int has_bias = node.input_size() == 3 ? 1 : 0;

            std::string auto_pad = get_node_attr_s(node, "auto_pad");
            std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
            std::vector<int> dilations = get_node_attr_ai(node, "dilations");
            std::vector<int> strides = get_node_attr_ai(node, "strides");
            std::vector<int> pads = get_node_attr_ai(node, "pads");
            int group = get_node_attr_i(node, "group", 1);

            // param 0 = number of output channels
            fprintf(pp, " 0=%d", num_filter);

            // param 1 = kernel_w, 11 = kernel_h (last spatial dim first)
            if (kernel_shape.size() == 1)
            {
                fprintf(pp, " 1=%d", kernel_shape[0]);
            }
            else if (kernel_shape.size() == 2)
            {
                fprintf(pp, " 1=%d", kernel_shape[1]);
                fprintf(pp, " 11=%d", kernel_shape[0]);
            }

            // param 2 = dilation_w, 12 = dilation_h
            if (dilations.size() == 1)
            {
                fprintf(pp, " 2=%d", dilations[0]);
            }
            else if (dilations.size() == 2)
            {
                fprintf(pp, " 2=%d", dilations[1]);
                fprintf(pp, " 12=%d", dilations[0]);
            }

            // param 3 = stride_w, 13 = stride_h
            if (strides.size() == 1)
            {
                fprintf(pp, " 3=%d", strides[0]);
            }
            else if (strides.size() == 2)
            {
                fprintf(pp, " 3=%d", strides[1]);
                fprintf(pp, " 13=%d", strides[0]);
            }

            // -233/-234 are magic pad values: compute SAME_UPPER/SAME_LOWER at runtime
            if (auto_pad == "SAME_UPPER")
            {
                fprintf(pp, " 4=-233");
            }
            else if (auto_pad == "SAME_LOWER")
            {
                fprintf(pp, " 4=-234");
            }
            else
            {
                // param 4 = pad_left, 14 = pad_top, 15 = pad_right, 16 = pad_bottom
                if (pads.size() == 1)
                {
                    fprintf(pp, " 4=%d", pads[0]);
                }
                else if (pads.size() == 2)
                {
                    fprintf(pp, " 4=%d", pads[1]);
                    fprintf(pp, " 14=%d", pads[0]);
                }
                else if (pads.size() == 4)
                {
                    fprintf(pp, " 4=%d", pads[1]);
                    fprintf(pp, " 14=%d", pads[0]);
                    fprintf(pp, " 15=%d", pads[3]);
                    fprintf(pp, " 16=%d", pads[2]);
                }
            }

            fprintf(pp, " 5=%d", has_bias);

            // param 6 = total weight element count
            fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));

            if (group > 1)
            {
                fprintf(pp, " 7=%d", group);
            }

            // leading 0 tag marks the weight blob as raw fp32 data
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(W, bp);

            if (has_bias)
            {
                const onnx::TensorProto& B = weights[node.input(2)];
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "ConvTranspose")
        {
            // ncnn Deconvolution: params, then fp32 weights reordered from
            // inch-outch to outch-inch, then optional bias.
            const onnx::TensorProto& W = weights[node.input(1)];

            int has_bias = node.input_size() == 3 ? 1 : 0;

            std::string auto_pad = get_node_attr_s(node, "auto_pad");
            std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
            std::vector<int> dilations = get_node_attr_ai(node, "dilations");
            std::vector<int> strides = get_node_attr_ai(node, "strides");
            std::vector<int> output_padding = get_node_attr_ai(node, "output_padding");
            std::vector<int> output_shape = get_node_attr_ai(node, "output_shape");
            std::vector<int> pads = get_node_attr_ai(node, "pads");
            int group = get_node_attr_i(node, "group", 1);
            // deconv weight stores output channels per group in dims(1),
            // hence dims(1) * group is the total output-channel count
            int num_filter = W.dims(1) * group;

            fprintf(pp, " 0=%d", num_filter);

            // param 1 = kernel_w, 11 = kernel_h
            if (kernel_shape.size() == 1)
            {
                fprintf(pp, " 1=%d", kernel_shape[0]);
            }
            else if (kernel_shape.size() == 2)
            {
                fprintf(pp, " 1=%d", kernel_shape[1]);
                fprintf(pp, " 11=%d", kernel_shape[0]);
            }

            // param 2 = dilation_w, 12 = dilation_h
            if (dilations.size() == 1)
            {
                fprintf(pp, " 2=%d", dilations[0]);
            }
            else if (dilations.size() == 2)
            {
                fprintf(pp, " 2=%d", dilations[1]);
                fprintf(pp, " 12=%d", dilations[0]);
            }

            // param 3 = stride_w, 13 = stride_h
            if (strides.size() == 1)
            {
                fprintf(pp, " 3=%d", strides[0]);
            }
            else if (strides.size() == 2)
            {
                fprintf(pp, " 3=%d", strides[1]);
                fprintf(pp, " 13=%d", strides[0]);
            }

            // -233/-234 are magic pad values: compute SAME_UPPER/SAME_LOWER at runtime
            if (auto_pad == "SAME_UPPER")
            {
                fprintf(pp, " 4=-233");
            }
            else if (auto_pad == "SAME_LOWER")
            {
                fprintf(pp, " 4=-234");
            }
            else
            {
                // param 4 = pad_left, 14 = pad_top, 15 = pad_right, 16 = pad_bottom
                if (pads.size() == 1)
                {
                    fprintf(pp, " 4=%d", pads[0]);
                }
                else if (pads.size() == 2)
                {
                    fprintf(pp, " 4=%d", pads[1]);
                    fprintf(pp, " 14=%d", pads[0]);
                }
                else if (pads.size() == 4)
                {
                    fprintf(pp, " 4=%d", pads[1]);
                    fprintf(pp, " 14=%d", pads[0]);
                    fprintf(pp, " 15=%d", pads[3]);
                    fprintf(pp, " 16=%d", pads[2]);
                }
            }

            // param 18/19 = output_padding (w, h)
            if (output_padding.size() == 1)
            {
                fprintf(pp, " 18=%d", output_padding[0]);
            }
            else if (output_padding.size() == 2)
            {
                fprintf(pp, " 18=%d", output_padding[1]);
                fprintf(pp, " 19=%d", output_padding[0]);
            }

            // param 20/21 = explicit output shape (w, h)
            if (output_shape.size() == 1)
            {
                fprintf(pp, " 20=%d", output_shape[0]);
            }
            else if (output_shape.size() == 2)
            {
                fprintf(pp, " 20=%d", output_shape[1]);
                fprintf(pp, " 21=%d", output_shape[0]);
            }

            fprintf(pp, " 5=%d", has_bias);

            // param 6 = total weight element count
            fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));

            if (group > 1)
            {
                fprintf(pp, " 7=%d", group);
            }

            // leading 0 tag marks the weight blob as raw fp32 data
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            // maxk = number of elements per 2-d kernel plane
            int maxk = 0;
            if (kernel_shape.size() == 2)
            {
                maxk = kernel_shape[1] * kernel_shape[0];
            }
            else
            {
                // NOTE(review): this squares the kernel extent when
                // kernel_shape is not 2-d; for a 1-d deconv that looks wrong
                // (would over-count elements in num_input below) -- confirm
                maxk = kernel_shape[0] * kernel_shape[0];
            }
            int weight_data_size = get_tensor_proto_data_size(W);
            const float* weight_data = 0;
            if (W.has_raw_data())
            {
                // raw_data (byte blob) takes precedence over typed float_data
                weight_data = (const float*)W.raw_data().data();
            }
            else if (W.data_type() == 1)
            {
                // data_type 1 == float32
                weight_data = W.float_data().data();
            }
            for (int g = 0; g < group; g++)
            {
                // reorder weight from inch-outch to outch-inch
                int num_filter_g = num_filter / group;
                int num_input = weight_data_size / maxk / num_filter_g / group;
                const float* weight_data_ptr = weight_data + g * maxk * num_filter_g * num_input;
                for (int k = 0; k < num_filter_g; k++)
                {
                    for (int j = 0; j < num_input; j++)
                    {
                        fwrite(weight_data_ptr + (j * num_filter_g + k) * maxk, sizeof(float), maxk, bp);
                    }
                }
            }

            if (has_bias)
            {
                const onnx::TensorProto& B = weights[node.input(2)];
                fwrite_tensor_proto_data(B, bp);
            }
        }
4332         else if (op == "Cos")
4333         {
4334             int op_type = 10;
4335             fprintf(pp, " 0=%d", op_type);
4336         }
4337         else if (op == "DepthToSpace")
4338         {
4339             // pixelshuffle
4340             int scale_factor = get_node_attr_i(node, "blocksize", 1);
4341             std::string mode = get_node_attr_s(node, "mode");
4342             fprintf(pp, " 0=%d", scale_factor);
4343             if (mode == "CRD")
4344             {
4345                 fprintf(pp, " 1=0");
4346             }
4347             else if (mode == "DCR")
4348             {
4349                 fprintf(pp, " 1=1");
4350             }
4351         }
4352         else if (op == "Div")
4353         {
4354             int op_type = 3;
4355             fprintf(pp, " 0=%d", op_type);
4356 
4357             int with_scalar = get_node_attr_i(node, "with_scalar", 0);
4358             float b = get_node_attr_f(node, "b", 0.f);
4359             if (with_scalar)
4360             {
4361                 fprintf(pp, " 1=%d", with_scalar);
4362                 fprintf(pp, " 2=%e", b);
4363             }
4364         }
        else if (op == "Dropout")
        {
            // inference-time Dropout is identity: emit no params, pass data through
        }
4369         else if (op == "Elu")
4370         {
4371             float alpha = get_node_attr_f(node, "alpha", 1.f);
4372             fprintf(pp, " 0=%e", alpha);
4373         }
        else if (op == "EmbedLayerNormalization")
        {
            // ONNX contrib op: word + position embeddings followed by LayerNorm.
            // Per the contrib schema: input(2)=word embedding, input(3)=position
            // embedding, input(5)=LN gamma, input(6)=LN beta; the segment
            // embedding (input 4) is not consumed here.
            const onnx::TensorProto& words = weights[node.input(2)];
            const onnx::TensorProto& positions = weights[node.input(3)];
            const onnx::TensorProto& W = weights[node.input(5)];
            const onnx::TensorProto& B = weights[node.input(6)];

            // 0 = embed dim (from beta length), 1/2 = embedding table sizes
            fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
            fprintf(pp, " 1=%d", get_tensor_proto_data_size(words));
            fprintf(pp, " 2=%d", get_tensor_proto_data_size(positions));

            // each blob below is prefixed with a 0 tag meaning raw fp32 data
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(words, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(positions, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(W, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B, bp);
        }
4402         else if (op == "Exp")
4403         {
4404             int op_type = 7;
4405             fprintf(pp, " 0=%d", op_type);
4406         }
4407         else if (op == "Flatten")
4408         {
4409             int axis = get_node_attr_i(node, "axis", 1);
4410             if (axis != 1)
4411             {
4412                 fprintf(stderr, "Unsupported Flatten axis %d!\n", axis);
4413             }
4414         }
4415         else if (op == "Floor")
4416         {
4417             int op_type = 2;
4418             fprintf(pp, " 0=%d", op_type);
4419         }
        else if (op == "Gemm")
        {
            // Gemm computes alpha*A*B + beta*C; the plain transB case maps to
            // ncnn InnerProduct with constant weights, everything else stays Gemm.
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            float beta = get_node_attr_f(node, "beta", 1.f);
            int transA = get_node_attr_i(node, "transA", 0);
            int transB = get_node_attr_i(node, "transB", 0);

            if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
            {
                // InnerProduct-like A * B + C
                // param 0 = output count (from C), 1 = has bias, 2 = weight count
                const onnx::TensorProto& B = weights[node.input(1)];
                const onnx::TensorProto& C = weights[node.input(2)];

                fprintf(pp, " 0=%d", get_tensor_proto_data_size(C));
                fprintf(pp, " 1=1");
                fprintf(pp, " 2=%d", get_tensor_proto_data_size(B));

                // leading 0 tag marks the weight blob as raw fp32 data
                int quantize_tag = 0;
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                fwrite_tensor_proto_data(B, bp);
                fwrite_tensor_proto_data(C, bp);
            }
            else
            {
                // gemm: pass coefficients and transpose flags through as params
                fprintf(pp, " 0=%e", alpha);
                fprintf(pp, " 1=%e", beta);
                fprintf(pp, " 2=%d", transA);
                fprintf(pp, " 3=%d", transB);
            }
        }
4452         else if (op == "GlobalAveragePool")
4453         {
4454             int pool = 1;
4455             int global_pool = 1;
4456 
4457             fprintf(pp, " 0=%d", pool);
4458             fprintf(pp, " 4=%d", global_pool);
4459         }
4460         else if (op == "GlobalMaxPool")
4461         {
4462             int pool = 0;
4463             int global_pool = 1;
4464 
4465             fprintf(pp, " 0=%d", pool);
4466             fprintf(pp, " 4=%d", global_pool);
4467         }
4468         else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
4469         {
4470             int pool = 0;
4471             if (op == "adaptive_avg_pool2d")
4472             {
4473                 pool = 1;
4474             }
4475             int adaptive_pooling = 1;
4476             const onnx::TensorProto& out_shape_tp = weights[node.input(1)];
4477             std::vector<int> out_shape = get_node_attr_from_input_ai(out_shape_tp);
4478 
4479             fprintf(pp, " 0=%d", pool);
4480             fprintf(pp, " 7=%d", adaptive_pooling);
4481             if (out_shape.size() == 1)
4482             {
4483                 fprintf(pp, " 8=%d", out_shape[0]);
4484             }
4485             else if (out_shape.size() == 2)
4486             {
4487                 // out_w
4488                 fprintf(pp, " 8=%d", out_shape[1]);
4489                 // out_h
4490                 fprintf(pp, " 18=%d", out_shape[0]);
4491             }
4492         }
        else if (op == "GroupNorm")
        {
            int groups = get_node_attr_i(node, "groups", 1);
            int channels = get_node_attr_i(node, "channels", 1);
            float eps = get_node_attr_f(node, "epsilon", 1e-5f);
            int affine = get_node_attr_i(node, "affine", 1);

            if (affine)
            {
                // discard affine-less S=1 B=0
                std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
                std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
                if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && affine_B[0] == 0.f)
                {
                    affine = 0;
                }
                else
                {
                    // NOTE(review): indexes affine_S/affine_B up to `channels`
                    // but only the size-1 case is checked above; assumes both
                    // vectors hold at least `channels` entries -- confirm
                    affine = 0;
                    {
                        for (int j = 0; j < channels; j++)
                        {
                            if (affine_S[j] != 1.f || affine_B[j] != 0.f)
                            {
                                affine = 1;
                                break;
                            }
                        }
                    }
                }
            }

            fprintf(pp, " 0=%d", groups);
            fprintf(pp, " 1=%d", channels);
            fprintf(pp, " 2=%e", eps);
            fprintf(pp, " 3=%d", affine);
            // scale and bias blobs only follow when a real affine transform remains
            if (affine)
            {
                const onnx::TensorProto& scale = weights[node.input(1)];
                const onnx::TensorProto& B = weights[node.input(2)];

                fwrite_tensor_proto_data(scale, bp);
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "GRU")
        {
            // W = input-hidden weights, R = recurrent weights, B = stacked
            // biases; the ONNX gate order is update/reset/new ("URN" below),
            // while the target format wants reset/update/new ("RUN").
            const onnx::TensorProto& W = weights[node.input(1)];
            const onnx::TensorProto& R = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];

            int hidden_size = get_node_attr_i(node, "hidden_size", 0);
            std::string direction = get_node_attr_s(node, "direction");

            // param 2: 0 = forward, 1 = reverse, 2 = bidirectional
            int direction_type = 0;
            if (direction == "forward")
            {
                direction_type = 0;
            }
            else if (direction == "reverse")
            {
                direction_type = 1;
            }
            else if (direction == "bidirectional")
            {
                direction_type = 2;
            }

            int weight_data_size = get_tensor_proto_data_size(W);

            fprintf(pp, " 0=%d", hidden_size);
            fprintf(pp, " 1=%d", weight_data_size);
            fprintf(pp, " 2=%d", direction_type);

            int num_directions = direction_type == 2 ? 2 : 1;

            // each blob below is prefixed with this 0 tag meaning raw fp32 data
            int quantize_tag = 0;

            // reorder num_directions-URN-hidden-size to num_directions-RUN-hidden-size
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // per-gate, per-direction slice size
                int weight_data_size_g = get_tensor_proto_data_size(W) / 3 / num_directions;
                const float* wptr = W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data();

                const float* uptr = wptr;
                const float* rptr = wptr + weight_data_size_g;
                const float* nptr = wptr + weight_data_size_g * 2;
                fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                fwrite(nptr, sizeof(float), weight_data_size_g, bp);

                // the reverse direction's gates follow the forward ones in the buffer
                if (direction_type == 2)
                {
                    uptr += weight_data_size_g * 3;
                    rptr += weight_data_size_g * 3;
                    nptr += weight_data_size_g * 3;
                    fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(nptr, sizeof(float), weight_data_size_g, bp);
                }
            }

            // reduce U and R bias except N
            // reorder num_directions-URN-hidden to num_directions-RUN-hidden
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // B packs [Wb(u,r,n), Rb(u,r,n)] per direction, hence / 2 / 3
                int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 3 / num_directions;
                const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
                const float* wuptr = bptr;
                const float* wrptr = bptr + bias_data_size_g;
                const float* wnptr = bptr + bias_data_size_g * 2;
                const float* buptr = bptr + bias_data_size_g * 3;
                const float* brptr = bptr + bias_data_size_g * 4;
                const float* bnptr = bptr + bias_data_size_g * 5;

                // r and u gates fold their two bias vectors into a single sum;
                // the n gate keeps Wb and Rb separate (both written below)
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = wrptr[j] + brptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = wuptr[j] + buptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                fwrite(wnptr, sizeof(float), bias_data_size_g, bp);
                fwrite(bnptr, sizeof(float), bias_data_size_g, bp);

                if (direction_type == 2)
                {
                    wuptr += bias_data_size_g * 6;
                    wrptr += bias_data_size_g * 6;
                    wnptr += bias_data_size_g * 6;
                    buptr += bias_data_size_g * 6;
                    brptr += bias_data_size_g * 6;
                    bnptr += bias_data_size_g * 6;

                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = wrptr[j] + brptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = wuptr[j] + buptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    fwrite(wnptr, sizeof(float), bias_data_size_g, bp);
                    fwrite(bnptr, sizeof(float), bias_data_size_g, bp);
                }
            }

            // reorder num_directions-URN-hidden-hidden to num_directions-RUN-hidden-hidden
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                int weight_data_size_g = get_tensor_proto_data_size(R) / 3 / num_directions;
                const float* Rptr = R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data();

                const float* uptr = Rptr;
                const float* rptr = Rptr + weight_data_size_g;
                const float* nptr = Rptr + weight_data_size_g * 2;
                fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                fwrite(nptr, sizeof(float), weight_data_size_g, bp);

                if (direction_type == 2)
                {
                    uptr += weight_data_size_g * 3;
                    rptr += weight_data_size_g * 3;
                    nptr += weight_data_size_g * 3;
                    fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(nptr, sizeof(float), weight_data_size_g, bp);
                }
            }
        }
4672         else if (op == "HardSigmoid")
4673         {
4674             float alpha = get_node_attr_f(node, "alpha", 0.2f);
4675             float beta = get_node_attr_f(node, "beta", 0.5f);
4676 
4677             fprintf(pp, " 0=%e", alpha);
4678             fprintf(pp, " 1=%e", beta);
4679         }
4680         else if (op == "HardSwish")
4681         {
4682             float alpha = get_node_attr_f(node, "alpha", 0.2f);
4683             float beta = get_node_attr_f(node, "beta", 0.5f);
4684 
4685             fprintf(pp, " 0=%e", alpha);
4686             fprintf(pp, " 1=%e", beta);
4687         }
4688         else if (op == "ImageScaler")
4689         {
4690             std::vector<float> bias = get_node_attr_af(node, "bias");
4691             float scale = get_node_attr_f(node, "scale", 1.f);
4692 
4693             int channels = (int)bias.size();
4694 
4695             fprintf(pp, " 0=%d", channels);
4696             fprintf(pp, " 1=1");
4697 
4698             for (int j = 0; j < channels; j++)
4699             {
4700                 fwrite(&scale, sizeof(float), 1, bp);
4701             }
4702             fwrite(&bias[0], sizeof(float), channels, bp);
4703         }
        else if (op == "InstanceNormalization")
        {
            float eps = get_node_attr_f(node, "epsilon", 1e-5f);

            // discard affine-less S=1 B=0
            // scale (input 1) and bias (input 2) give the channel count and
            // decide whether an affine transform is actually present
            std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
            std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
            int channels = (int)affine_S.size();
            int affine = 0;
            {
                // affine only matters when some channel deviates from S=1, B=0
                for (int j = 0; j < channels; j++)
                {
                    if (affine_S[j] != 1.f || affine_B[j] != 0.f)
                    {
                        affine = 1;
                        break;
                    }
                }
            }

            fprintf(pp, " 0=%d", channels);
            fprintf(pp, " 1=%e", eps);
            fprintf(pp, " 2=%d", affine);
            // scale/bias blobs only follow when a real affine transform remains
            if (affine)
            {
                const onnx::TensorProto& scale = weights[node.input(1)];
                const onnx::TensorProto& B = weights[node.input(2)];

                fwrite_tensor_proto_data(scale, bp);
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "LayerNorm")
        {
            float eps = get_node_attr_f(node, "epsilon", 1e-5f);
            int affine = get_node_attr_i(node, "affine", 1);

            if (affine)
            {
                // discard affine-less S=1 B=0
                std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
                std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
                int affine_size = (int)affine_S.size();
                affine = 0;
                {
                    // keep affine only if some element deviates from S=1, B=0
                    for (int j = 0; j < affine_size; j++)
                    {
                        if (affine_S[j] != 1.f || affine_B[j] != 0.f)
                        {
                            affine = 1;
                            break;
                        }
                    }
                }

                // param 0 (normalized element count) is only written when the
                // affine transform survives the check above
                if (affine)
                {
                    fprintf(pp, " 0=%d", affine_size);
                }
            }

            fprintf(pp, " 1=%e", eps);
            fprintf(pp, " 2=%d", affine);

            // gamma/beta blobs only follow when affine remains enabled
            if (affine)
            {
                const onnx::TensorProto& scale = weights[node.input(1)];
                const onnx::TensorProto& B = weights[node.input(2)];

                fwrite_tensor_proto_data(scale, bp);
                fwrite_tensor_proto_data(B, bp);
            }
        }
4777         else if (op == "LeakyRelu")
4778         {
4779             float alpha = get_node_attr_f(node, "alpha", 0.01f);
4780 
4781             fprintf(pp, " 0=%e", alpha);
4782         }
4783         else if (op == "Log")
4784         {
4785             int op_type = 8;
4786             fprintf(pp, " 0=%d", op_type);
4787         }
4788         else if (op == "LRN")
4789         {
4790             float alpha = get_node_attr_f(node, "alpha", 1.f);
4791             float beta = get_node_attr_f(node, "beta", 0.5f);
4792             float bias = get_node_attr_f(node, "bias", 1.f);
4793             int size = get_node_attr_i(node, "size", 1);
4794 
4795             int norm_region = 0;
4796 
4797             fprintf(pp, " 0=%d", norm_region);
4798             fprintf(pp, " 1=%d", size);
4799             fprintf(pp, " 2=%e", alpha);
4800             fprintf(pp, " 3=%e", beta);
4801             fprintf(pp, " 4=%e", bias);
4802         }
        else if (op == "LSTM")
        {
            // W: input-to-hidden weights, R: hidden-to-hidden (recurrent) weights, B: biases.
            // NOTE(review): input(3) is read unconditionally — assumes the LSTM node always
            // carries a bias initializer; confirm for bias-less models.
            const onnx::TensorProto& W = weights[node.input(1)];
            const onnx::TensorProto& R = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];

            int hidden_size = get_node_attr_i(node, "hidden_size", 0);
            std::string direction = get_node_attr_s(node, "direction");

            // map onnx direction string to ncnn LSTM direction param
            int direction_type = 0;
            if (direction == "forward")
            {
                direction_type = 0;
            }
            else if (direction == "reverse")
            {
                direction_type = 1;
            }
            else if (direction == "bidirectional")
            {
                direction_type = 2;
            }

            int weight_data_size = get_tensor_proto_data_size(W);

            fprintf(pp, " 0=%d", hidden_size);
            fprintf(pp, " 1=%d", weight_data_size);
            fprintf(pp, " 2=%d", direction_type);

            int num_directions = direction_type == 2 ? 2 : 1;

            // 0 = raw float32 weight storage
            int quantize_tag = 0;

            // reorder num_directions-IOFG-hidden-size to num_directions-IFOG-hidden-size
            // (onnx gate order is I,O,F,G; ncnn expects I,F,O,G — hence f is written before o)
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                int weight_data_size_g = get_tensor_proto_data_size(W) / 4 / num_directions;
                const float* wptr = W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data();

                const float* iptr = wptr;
                const float* optr = wptr + weight_data_size_g;
                const float* fptr = wptr + weight_data_size_g * 2;
                const float* gptr = wptr + weight_data_size_g * 3;
                fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                fwrite(optr, sizeof(float), weight_data_size_g, bp);
                fwrite(gptr, sizeof(float), weight_data_size_g, bp);

                if (direction_type == 2)
                {
                    // the reverse direction's 4 gate blocks follow immediately after the forward ones
                    iptr += weight_data_size_g * 4;
                    optr += weight_data_size_g * 4;
                    fptr += weight_data_size_g * 4;
                    gptr += weight_data_size_g * 4;
                    fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(optr, sizeof(float), weight_data_size_g, bp);
                    fwrite(gptr, sizeof(float), weight_data_size_g, bp);
                }
            }

            // reduce xc and hc bias
            // reorder num_directions-IOFG-hidden to num_directions-IFOG-hidden
            // onnx stores input bias Wb and recurrent bias Rb separately; ncnn wants Wb+Rb summed
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // /2 for the Wb|Rb halves, /4 for the gates, /num_directions per direction
                int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 4 / num_directions;
                const float* xcbptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
                const float* xiptr = xcbptr;
                const float* xoptr = xcbptr + bias_data_size_g;
                const float* xfptr = xcbptr + bias_data_size_g * 2;
                const float* xgptr = xcbptr + bias_data_size_g * 3;
                const float* hiptr = xcbptr + bias_data_size_g * 4;
                const float* hoptr = xcbptr + bias_data_size_g * 5;
                const float* hfptr = xcbptr + bias_data_size_g * 6;
                const float* hgptr = xcbptr + bias_data_size_g * 7;

                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xiptr[j] + hiptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xfptr[j] + hfptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xoptr[j] + hoptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xgptr[j] + hgptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }

                if (direction_type == 2)
                {
                    // advance all eight pointers past the forward direction's bias block
                    xiptr += bias_data_size_g * 8;
                    xoptr += bias_data_size_g * 8;
                    xfptr += bias_data_size_g * 8;
                    xgptr += bias_data_size_g * 8;
                    hiptr += bias_data_size_g * 8;
                    hoptr += bias_data_size_g * 8;
                    hfptr += bias_data_size_g * 8;
                    hgptr += bias_data_size_g * 8;

                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xiptr[j] + hiptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xfptr[j] + hfptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xoptr[j] + hoptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xgptr[j] + hgptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                }
            }

            // reorder num_directions-IOFG-hidden-hidden to num_directions-IFOG-hidden-hidden
            // (same gate reordering as W, applied to the recurrent weights R)
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                int weight_data_size_g = get_tensor_proto_data_size(R) / 4 / num_directions;
                const float* rptr = R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data();

                const float* iptr = rptr;
                const float* optr = rptr + weight_data_size_g;
                const float* fptr = rptr + weight_data_size_g * 2;
                const float* gptr = rptr + weight_data_size_g * 3;
                fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                fwrite(optr, sizeof(float), weight_data_size_g, bp);
                fwrite(gptr, sizeof(float), weight_data_size_g, bp);

                if (direction_type == 2)
                {
                    iptr += weight_data_size_g * 4;
                    optr += weight_data_size_g * 4;
                    fptr += weight_data_size_g * 4;
                    gptr += weight_data_size_g * 4;
                    fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(optr, sizeof(float), weight_data_size_g, bp);
                    fwrite(gptr, sizeof(float), weight_data_size_g, bp);
                }
            }
        }
        else if (op == "MatMul")
        {
            // MatMul against a constant 2-D initializer becomes InnerProduct;
            // anything else falls through to a runtime matrix multiplication.
            if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
            {
                // InnerProduct
                const onnx::TensorProto& B = weights[node.input(1)];

                int weight_data_size = get_tensor_proto_data_size(B);

                // onnx stores B as (num_input, num_output): last dim is the output count
                int num_output = B.dims(B.dims_size() - 1);
                int num_input = weight_data_size / num_output;

                fprintf(pp, " 0=%d", num_output);
                fprintf(pp, " 1=0"); // bias_term = 0
                fprintf(pp, " 2=%d", weight_data_size);

                int quantize_tag = 0;
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // reorder num_input-num_output to num_output-num_input
                // (transpose element-by-element; ncnn InnerProduct expects output-major rows)
                {
                    const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();

                    for (int j = 0; j < num_output; j++)
                    {
                        for (int k = 0; k < num_input; k++)
                        {
                            float vb = bptr[k * num_output + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }

                // fwrite_tensor_proto_data(B, bp)
            }
            else
            {
                // default matrix multiplication
            }
        }
5005         else if (op == "Max")
5006         {
5007             int op_type = 4;
5008             fprintf(pp, " 0=%d", op_type);
5009 
5010             int with_scalar = get_node_attr_i(node, "with_scalar", 0);
5011             float b = get_node_attr_f(node, "b", 0.f);
5012             if (with_scalar)
5013             {
5014                 fprintf(pp, " 1=%d", with_scalar);
5015                 fprintf(pp, " 2=%e", b);
5016             }
5017         }
5018         else if (op == "Min")
5019         {
5020             int op_type = 5;
5021             fprintf(pp, " 0=%d", op_type);
5022 
5023             int with_scalar = get_node_attr_i(node, "with_scalar", 0);
5024             float b = get_node_attr_f(node, "b", 0.f);
5025             if (with_scalar)
5026             {
5027                 fprintf(pp, " 1=%d", with_scalar);
5028                 fprintf(pp, " 2=%e", b);
5029             }
5030         }
5031         else if (op == "Mul")
5032         {
5033             int op_type = 2;
5034             fprintf(pp, " 0=%d", op_type);
5035 
5036             int with_scalar = get_node_attr_i(node, "with_scalar", 0);
5037             float b = get_node_attr_f(node, "b", 0.f);
5038             if (with_scalar)
5039             {
5040                 fprintf(pp, " 1=%d", with_scalar);
5041                 fprintf(pp, " 2=%e", b);
5042             }
5043         }
        else if (op == "MultiHeadAttention")
        {
            int embed_dim = get_node_attr_i(node, "embed_dim", 0);
            int num_heads = get_node_attr_i(node, "num_heads", 0);

            fprintf(pp, " 0=%d", embed_dim);
            fprintf(pp, " 1=%d", num_heads);

            // 5 inputs: fused qkv weight/bias + output weight/bias.
            // Otherwise: separate q/k/v/out weight and bias pairs at inputs 3..10.
            if (node.input_size() == 5)
            {
                const onnx::TensorProto& qkvw = weights[node.input(1)];
                const onnx::TensorProto& qkvb = weights[node.input(2)];
                const onnx::TensorProto& ow = weights[node.input(3)];
                const onnx::TensorProto& ob = weights[node.input(4)];

                // per-projection weight count (ow is embed_dim x embed_dim)
                int weight_data_size = get_tensor_proto_data_size(ow);

                fprintf(pp, " 2=%d", weight_data_size);

                int quantize_tag = 0;

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose qw
                // indexing assumes qkvw rows span the input dim and the q|k|v output
                // columns sit side by side (row stride embed_dim * 3)
                {
                    const float* wptr = qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
                    const float* bptr = qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim * 3 + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }

                    fwrite(bptr, sizeof(float), embed_dim, bp);
                }

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose kw
                {
                    const float* wptr = qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
                    const float* bptr = qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();
                    // k bias is the second embed_dim-sized chunk of qkvb
                    bptr += embed_dim;

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim * 3 + j + embed_dim];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }

                    fwrite(bptr, sizeof(float), embed_dim, bp);
                }

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose vw
                {
                    const float* wptr = qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
                    const float* bptr = qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();
                    // v bias is the third embed_dim-sized chunk of qkvb
                    bptr += embed_dim * 2;

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim * 3 + j + embed_dim * 2];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }

                    fwrite(bptr, sizeof(float), embed_dim, bp);
                }

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose ow
                {
                    const float* wptr = ow.has_raw_data() ? (const float*)ow.raw_data().data() : ow.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(ob, bp);
            }
            else
            {
                const onnx::TensorProto& qw = weights[node.input(3)];
                const onnx::TensorProto& qb = weights[node.input(4)];
                const onnx::TensorProto& kw = weights[node.input(5)];
                const onnx::TensorProto& kb = weights[node.input(6)];
                const onnx::TensorProto& vw = weights[node.input(7)];
                const onnx::TensorProto& vb = weights[node.input(8)];
                const onnx::TensorProto& ow = weights[node.input(9)];
                const onnx::TensorProto& ob = weights[node.input(10)];

                int weight_data_size = get_tensor_proto_data_size(qw);

                fprintf(pp, " 2=%d", weight_data_size);

                int quantize_tag = 0;

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose qw
                {
                    const float* wptr = qw.has_raw_data() ? (const float*)qw.raw_data().data() : qw.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(qb, bp);

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose kw
                {
                    const float* wptr = kw.has_raw_data() ? (const float*)kw.raw_data().data() : kw.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(kb, bp);

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose vw
                {
                    const float* wptr = vw.has_raw_data() ? (const float*)vw.raw_data().data() : vw.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(vb, bp);

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose ow
                {
                    const float* wptr = ow.has_raw_data() ? (const float*)ow.raw_data().data() : ow.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(ob, bp);
            }
        }
5219         else if (op == "Neg")
5220         {
5221             int op_type = 1;
5222             fprintf(pp, " 0=%d", op_type);
5223         }
5224         else if (op == "Normalize")
5225         {
5226             float eps = get_node_attr_f(node, "eps", 0.f);
5227             int scale_data_size = 1;
5228 
5229             fprintf(pp, " 1=1"); // channel_shared
5230             fprintf(pp, " 2=%e", eps);
5231             fprintf(pp, " 3=%d", scale_data_size);
5232             fprintf(pp, " 9=1"); // TODO hardcode pytorch style
5233 
5234             const float scale_data[1] = {1.f};
5235             fwrite(scale_data, sizeof(float), 1, bp);
5236         }
5237         else if (op == "Pad")
5238         {
5239             std::string mode = get_node_attr_s(node, "mode");
5240             float value = get_node_attr_f(node, "value", 0.f);
5241 
5242             std::vector<int> pads;
5243             if (node.input_size() == 1)
5244             {
5245                 pads = get_node_attr_ai(node, "pads");
5246             }
5247             else
5248             {
5249                 pads = get_node_attr_from_input_ai(weights[node.input(1)]);
5250             }
5251 
5252             int type = 0;
5253             if (mode == "constant")
5254             {
5255                 type = 0;
5256             }
5257             else if (mode == "edge")
5258             {
5259                 type = 1;
5260             }
5261             else if (mode == "reflect")
5262             {
5263                 type = 2;
5264             }
5265 
5266             int pad_size = (int)pads.size();
5267             int top = 0;
5268             int bottom = 0;
5269             int left = 0;
5270             int right = 0;
5271             int front = 0;
5272             int behind = 0;
5273             if (pad_size == 8)
5274             {
5275                 //NCHW
5276                 top = pads[2];
5277                 bottom = pads[6];
5278                 left = pads[3];
5279                 right = pads[7];
5280                 front = pads[1];
5281                 behind = pads[5];
5282             }
5283             else if (pad_size == 6)
5284             {
5285                 //NHW
5286                 top = pads[1];
5287                 bottom = pads[4];
5288                 left = pads[2];
5289                 right = pads[5];
5290             }
5291             else
5292             {
5293                 //NW
5294                 left = pads[1];
5295                 right = pads[3];
5296             }
5297 
5298             fprintf(pp, " 0=%d", top);
5299             fprintf(pp, " 1=%d", bottom);
5300             fprintf(pp, " 2=%d", left);
5301             fprintf(pp, " 3=%d", right);
5302             fprintf(pp, " 4=%d", type);
5303             fprintf(pp, " 5=%e", value);
5304             fprintf(pp, " 7=%d", front);
5305             fprintf(pp, " 8=%d", behind);
5306         }
5307         else if (op == "Pow")
5308         {
5309             int op_type = 6;
5310             fprintf(pp, " 0=%d", op_type);
5311 
5312             int with_scalar = get_node_attr_i(node, "with_scalar", 0);
5313             float b = get_node_attr_f(node, "b", 0.f);
5314             if (with_scalar)
5315             {
5316                 fprintf(pp, " 1=%d", with_scalar);
5317                 fprintf(pp, " 2=%e", b);
5318             }
5319         }
5320         else if (op == "PixelShuffle")
5321         {
5322             int scale_factor = get_node_attr_i(node, "scale_factor", 1);
5323             fprintf(pp, " 0=%d", scale_factor);
5324         }
5325         else if (op == "PRelu")
5326         {
5327             const onnx::TensorProto& slope = weights[node.input(1)];
5328 
5329             int num_slope = get_tensor_proto_data_size(slope);
5330 
5331             fprintf(pp, " 0=%d", num_slope);
5332 
5333             fwrite_tensor_proto_data(slope, bp);
5334         }
5335         else if (op == "Reciprocal")
5336         {
5337             int op_type = 15;
5338             fprintf(pp, " 0=%d", op_type);
5339         }
5340         else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp")
5341         {
5342             int op_type = -233;
5343             if (op == "ReduceSum")
5344                 op_type = 0;
5345             else if (op == "ReduceSumSquare")
5346                 op_type = 2;
5347             else if (op == "ReduceMean")
5348                 op_type = 3;
5349             else if (op == "ReduceMax")
5350                 op_type = 4;
5351             else if (op == "ReduceMin")
5352                 op_type = 5;
5353             else if (op == "ReduceProd")
5354                 op_type = 6;
5355             else if (op == "ReduceL1")
5356                 op_type = 7;
5357             else if (op == "ReduceL2")
5358                 op_type = 8;
5359             else if (op == "ReduceLogSum")
5360                 op_type = 9;
5361             else if (op == "ReduceLogSumExp")
5362                 op_type = 10;
5363             fprintf(pp, " 0=%d", op_type);
5364 
5365             std::vector<int> axes = get_node_attr_ai(node, "axes");
5366             int keepdims = get_node_attr_i(node, "keepdims", 1);
5367 
5368             if (axes.size() > 0)
5369             {
5370                 // if axes set, reduce according to axes
5371                 fprintf(pp, " 1=%d", 0);
5372                 fprintf(pp, " -23303=%zu", axes.size());
5373                 for (size_t j = 0; j < axes.size(); j++)
5374                 {
5375                     if (axes[j] == 0 || axes[j] > 3 || axes[j] < -3)
5376                         fprintf(stderr, "Unsupported reduction axes !\n");
5377                     fprintf(pp, ",%d", axes[j]);
5378                 }
5379             }
5380             else
5381             {
5382                 // if axes not set, reduce all axes by default
5383                 fprintf(pp, " 1=%d", 1);
5384             }
5385             fprintf(pp, " 4=%d", keepdims);
5386         }
5387         else if (op == "Reorg")
5388         {
5389             int stride = get_node_attr_i(node, "stride", 1);
5390             fprintf(pp, " 0=%d", stride);
5391         }
5392         else if (op == "Reshape")
5393         {
5394             std::vector<int> shape;
5395 
5396             if (node.input_size() == 1)
5397             {
5398                 shape = get_node_attr_ai(node, "shape");
5399             }
5400             else
5401             {
5402                 shape = get_node_attr_from_input_ai(weights[node.input(1)]);
5403             }
5404 
5405             if (shape.size() == 1)
5406             {
5407                 fprintf(pp, " 0=%d", shape[0]); // should never reach here
5408             }
5409             else if (shape.size() == 2)
5410             {
5411                 fprintf(pp, " 0=%d", shape[1]);
5412             }
5413             else if (shape.size() == 3)
5414             {
5415                 fprintf(pp, " 0=%d", shape[2]);
5416                 fprintf(pp, " 1=%d", shape[1]);
5417             }
5418             else if (shape.size() == 4)
5419             {
5420                 fprintf(pp, " 0=%d", shape[3]);
5421                 fprintf(pp, " 1=%d", shape[2]);
5422                 fprintf(pp, " 2=%d", shape[1]);
5423             }
5424             else if (shape.size() == 5)
5425             {
5426                 fprintf(pp, " 0=%d", shape[4] * shape[3]);
5427                 fprintf(pp, " 1=%d", shape[2]);
5428                 fprintf(pp, " 2=%d", shape[1]);
5429             }
5430         }
        else if (op == "Resize")
        {
            std::string mode = get_node_attr_s(node, "mode");
            std::string align = get_node_attr_s(node, "coordinate_transformation_mode");

            // scales and sizes come from different inputs depending on the opset
            std::vector<float> scales;
            std::vector<int> sizes;
            if (node.input_size() == 2)
            {
                // opset 10
                scales = get_node_attr_from_input_af(weights[node.input(1)]);
            }
            else
            {
                // opset 11+
                scales = get_node_attr_from_input_af(weights[node.input(2)]);
                if (node.input_size() >= 4)
                {
                    sizes = get_node_attr_from_input_ai(weights[node.input(3)]);
                }
            }

            // Interp resize_type: 1 = nearest, 2 = bilinear, 3 = bicubic
            int resize_type = 1;
            if (mode == "nearest")
            {
                resize_type = 1;
            }
            else if (mode == "linear")
            {
                resize_type = 2;
            }
            else if (mode == "cubic")
            {
                resize_type = 3;
            }

            if (scales.empty() && sizes.empty())
            {
                fprintf(stderr, "Unsupported Resize scales and sizes are all empty!\n");
            }

            // pick h/w scale factors by input rank; leading batch/channel scales must be 1
            float h_scale = 1.f;
            float w_scale = 1.f;
            if (scales.size() == 2)
            {
                w_scale = scales[1];
            }
            else if (scales.size() == 3)
            {
                h_scale = scales[1];
                w_scale = scales[2];
            }
            else if (scales.size() == 4)
            {
                h_scale = scales[2];
                w_scale = scales[3];

                if (scales[1] != 1.f)
                    fprintf(stderr, "Unsupported Resize scales !\n");
            }

            // explicit output size (opset 11+); 0 means "derive from scale"
            int output_height = 0;
            int output_width = 0;
            if (sizes.size() == 2)
            {
                output_width = sizes[1];
            }
            else if (sizes.size() == 3)
            {
                output_height = sizes[1];
                output_width = sizes[2];
            }
            else if (sizes.size() == 4)
            {
                output_height = sizes[2];
                output_width = sizes[3];
            }

            int align_corner = 0;
            if (align == "align_corners")
            {
                align_corner = 1;
            }

            fprintf(pp, " 0=%d", resize_type);
            fprintf(pp, " 1=%e", h_scale);
            fprintf(pp, " 2=%e", w_scale);
            fprintf(pp, " 3=%d", output_height);
            fprintf(pp, " 4=%d", output_width);
            fprintf(pp, " 6=%d", align_corner);
        }
        else if (op == "RNN")
        {
            // W: input-to-hidden weights, R: hidden-to-hidden weights, B: biases.
            // NOTE(review): input(3) is read unconditionally — assumes a bias
            // initializer is always present; confirm for bias-less models.
            const onnx::TensorProto& W = weights[node.input(1)];
            const onnx::TensorProto& R = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];

            int hidden_size = get_node_attr_i(node, "hidden_size", 0);
            std::string direction = get_node_attr_s(node, "direction");

            // map onnx direction string to ncnn direction param
            int direction_type = 0;
            if (direction == "forward")
            {
                direction_type = 0;
            }
            else if (direction == "reverse")
            {
                direction_type = 1;
            }
            else if (direction == "bidirectional")
            {
                direction_type = 2;
            }

            int weight_data_size = get_tensor_proto_data_size(W);

            fprintf(pp, " 0=%d", hidden_size);
            fprintf(pp, " 1=%d", weight_data_size);
            fprintf(pp, " 2=%d", direction_type);

            int num_directions = direction_type == 2 ? 2 : 1;

            // 0 = raw float32 weight storage
            int quantize_tag = 0;

            // plain RNN has a single gate, so W needs no gate reordering
            fwrite(&quantize_tag, sizeof(int), 1, bp);
            fwrite_tensor_proto_data(W, bp);

            // reduce xc and hc bias
            // onnx stores input bias Wb and recurrent bias Rb separately; ncnn wants Wb+Rb summed
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // /2 for the Wb|Rb halves, /num_directions per direction
                int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / num_directions;
                const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
                const float* xiptr = bptr;
                const float* hiptr = bptr + bias_data_size_g;

                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xiptr[j] + hiptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }

                if (direction_type == 2)
                {
                    // advance past the forward direction's Wb|Rb pair
                    xiptr += bias_data_size_g * 2;
                    hiptr += bias_data_size_g * 2;

                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xiptr[j] + hiptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                }
            }

            fwrite(&quantize_tag, sizeof(int), 1, bp);
            fwrite_tensor_proto_data(R, bp);
        }
        else if (op == "RDiv")
        {
            // reversed division (scalar / x) -> ncnn BinaryOp, op_type 8 = rdiv
            int op_type = 8;
            fprintf(pp, " 0=%d", op_type);

            // with_scalar is set by the fuser when one operand folded to constant b
            int with_scalar = get_node_attr_i(node, "with_scalar", 0);
            float b = get_node_attr_f(node, "b", 0.f);
            if (with_scalar)
            {
                fprintf(pp, " 1=%d", with_scalar);
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "RSub")
        {
            // reversed subtraction (scalar - x) -> ncnn BinaryOp, op_type 7 = rsub
            int op_type = 7;
            fprintf(pp, " 0=%d", op_type);

            // with_scalar is set by the fuser when one operand folded to constant b
            int with_scalar = get_node_attr_i(node, "with_scalar", 0);
            float b = get_node_attr_f(node, "b", 0.f);
            if (with_scalar)
            {
                fprintf(pp, " 1=%d", with_scalar);
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "ShuffleChannel")
        {
            // ncnn ShuffleChannel: 0=group count, 1=reverse shuffle flag
            int group = get_node_attr_i(node, "group", 1);
            int reverse = get_node_attr_i(node, "reverse", 0);
            fprintf(pp, " 0=%d", group);
            fprintf(pp, " 1=%d", reverse);
        }
        else if (op == "Sigmoid")
        {
            // no param
        }
        else if (op == "Sin")
        {
            // ncnn UnaryOp, op_type 9 = sin
            int op_type = 9;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "SkipLayerNormalization")
        {
            // fused skip + layernorm: input(2)=gamma, input(3)=beta, input(4)=bias
            const onnx::TensorProto& W = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];
            const onnx::TensorProto& B2 = weights[node.input(4)];

            // 0 = affine size (number of normalized channels, taken from beta)
            fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));

            // quantize_tag == 0 marks raw fp32 data; written before each tensor
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(W, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B2, bp);
        }
        else if (op == "Slice")
        {
            // opset <= 9 keeps starts/ends/axes as attributes (single input);
            // opset >= 10 passes them as extra constant inputs
            std::vector<int> starts;
            std::vector<int> ends;
            std::vector<int> axes;
            std::vector<int> steps;
            if (node.input_size() == 1)
            {
                starts = get_node_attr_ai(node, "starts");
                ends = get_node_attr_ai(node, "ends");
                axes = get_node_attr_ai(node, "axes");
                steps = get_node_attr_ai(node, "steps"); // TODO
            }
            else
            {
                starts = get_node_attr_from_input_ai(weights[node.input(1)]);
                ends = get_node_attr_from_input_ai(weights[node.input(2)]);
                if (node.input_size() >= 4)
                    axes = get_node_attr_from_input_ai(weights[node.input(3)]);
                if (node.input_size() >= 5)
                    steps = get_node_attr_from_input_ai(weights[node.input(4)]);
            }

            // assert step == 1
            // ncnn Crop has no stride support; warn but keep converting
            for (int i = 0; i < (int)steps.size(); i++)
            {
                if (steps[i] != 1)
                    fprintf(stderr, "Unsupported slice step !\n");
            }

            // filter out N-dim axis
            // ncnn blobs have no batch dim, so drop a slice along axis 0 (at most one can exist)
            if (!axes.empty())
            {
                for (int i = 0; i < (int)axes.size(); i++)
                {
                    int axis = axes[i];
                    if (axis == 0)
                    {
                        starts.erase(starts.begin() + i);
                        ends.erase(ends.begin() + i);
                        axes.erase(axes.begin() + i);
                        break;
                    }
                }
            }

            // ncnn Crop array params: -23309=starts, -23310=ends, -23311=axes
            fprintf(pp, " -23309=%d", (int)starts.size());
            for (int i = 0; i < (int)starts.size(); i++)
            {
                fprintf(pp, ",%d", starts[i]);
            }
            fprintf(pp, " -23310=%d", (int)ends.size());
            for (int i = 0; i < (int)ends.size(); i++)
            {
                fprintf(pp, ",%d", ends[i]);
            }
            if (!axes.empty())
            {
                fprintf(pp, " -23311=%d", (int)axes.size());
                for (int i = 0; i < (int)axes.size(); i++)
                {
                    int axis = axes[i];
                    if (axis == 0 || axis > 3 || axis < -3)
                        fprintf(stderr, "Unsupported slice axes !\n");

                    if (axis > 0)
                        axis = axis - 1; // -1 for skip N-dim

                    fprintf(pp, ",%d", axis);
                }
            }
        }
        else if (op == "Softmax")
        {
            // axis - 1 skips the batch dim; 1=1 selects fixbug0 behavior in ncnn Softmax
            int axis = get_node_attr_i(node, "axis", 1);
            fprintf(pp, " 0=%d", axis - 1);
            fprintf(pp, " 1=1");
        }
        else if (op == "Split")
        {
            // -> ncnn Slice layer; axis 0 (batch) cannot be split
            int axis = get_node_attr_i(node, "axis", 0);
            std::vector<int> split = get_node_attr_ai(node, "split");
            if (axis < 1)
                fprintf(stderr, "Unsupported split axis !\n");

            // -23300 = slices array; -233 means "take the remainder"
            fprintf(pp, " -23300=%d", output_size);
            if (split.empty())
            {
                // no explicit sizes: split evenly across all outputs
                for (int i = 0; i < output_size; i++)
                {
                    fprintf(pp, ",-233");
                }
            }
            else
            {
                // emit explicit sizes for all but the last slice; last takes the rest
                for (size_t i = 0; i < split.size() - 1; i++)
                {
                    fprintf(pp, ",%d", split[i]);
                }
                fprintf(pp, ",-233");
            }
            fprintf(pp, " 1=%d", axis - 1); // -1 to skip N-dim
        }
        else if (op == "Sqrt")
        {
            // ncnn UnaryOp, op_type 5 = sqrt
            int op_type = 5;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Squeeze")
        {
            std::vector<int> axes = get_node_attr_ai(node, "axes");

            if (axes.empty())
            {
                // no axes given: squeeze all of w/h/c (ncnn params 0/1/2)
                fprintf(pp, " 0=1");
                fprintf(pp, " 1=1");
                fprintf(pp, " 2=1");
            }
            else
            {
                // -23303 = explicit axes array; batch axis (0) and |axis| > 3 unsupported
                fprintf(pp, " -23303=%zu", axes.size());
                for (int i = 0; i < (int)axes.size(); i++)
                {
                    if (axes[i] == 0 || axes[i] > 3 || axes[i] < -3)
                        fprintf(stderr, "Unsupported squeeze axes !\n");
                    fprintf(pp, ",%d", axes[i]);
                }
            }
        }
        else if (op == "Sub")
        {
            // ncnn BinaryOp, op_type 1 = sub
            int op_type = 1;
            fprintf(pp, " 0=%d", op_type);

            // with_scalar is set by the fuser when one operand folded to constant b
            int with_scalar = get_node_attr_i(node, "with_scalar", 0);
            float b = get_node_attr_f(node, "b", 0.f);
            if (with_scalar)
            {
                fprintf(pp, " 1=%d", with_scalar);
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "Sum")
        {
            // n-ary add -> ncnn Eltwise, op_type 1 = sum
            int op_type = 1;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Swish")
        {
            // no param
        }
        else if (op == "Tan")
        {
            // ncnn UnaryOp, op_type 11 = tan
            int op_type = 11;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Tanh")
        {
            // ncnn UnaryOp, op_type 16 = tanh
            int op_type = 16;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Transpose")
        {
            // map the ONNX permutation (including leading batch axis) to an
            // ncnn Permute order_type; unmatched 3-d/4-d perms fall through silently
            std::vector<int> perm = get_node_attr_ai(node, "perm");

            if (perm.size() == 3)
            {
                if (perm[1] == 1 && perm[2] == 2)
                    fprintf(pp, " 0=0"); // w h
                else if (perm[1] == 2 && perm[2] == 1)
                    fprintf(pp, " 0=1"); // h w
                else if (perm[0] == 1 && perm[1] == 0 && perm[2] == 2)
                    fprintf(pp, " 0=0"); // w h
                else if (perm[0] == 2 && perm[1] == 0 && perm[2] == 1)
                    fprintf(pp, " 0=1"); // h w
            }
            else if (perm.size() == 4)
            {
                if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3)
                    fprintf(pp, " 0=0"); // w h c
                else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 2)
                    fprintf(pp, " 0=1"); // h w c
                else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3)
                    fprintf(pp, " 0=2"); // w c h
                else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 1)
                    fprintf(pp, " 0=3"); // c w h
                else if (perm[1] == 3 && perm[2] == 1 && perm[3] == 2)
                    fprintf(pp, " 0=4"); // h c w
                else if (perm[1] == 3 && perm[2] == 2 && perm[3] == 1)
                    fprintf(pp, " 0=5"); // c h w
            }
            else if (perm.size() == 5)
            {
                // 5-d perms: the two innermost dims are folded into one "wx" axis
                if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3 && perm[4] == 4)
                    fprintf(pp, " 0=0"); // wx h c
                else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 4 && perm[4] == 2)
                    fprintf(pp, " 0=1"); // h wx c
                else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3 && perm[4] == 4)
                    fprintf(pp, " 0=2"); // wx c h
                else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 4 && perm[4] == 1)
                    fprintf(pp, " 0=3"); // c wx h
                else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 1 && perm[4] == 2)
                    fprintf(pp, " 0=4"); // h c wx
                else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 2 && perm[4] == 1)
                    fprintf(pp, " 0=5"); // c h wx
                else
                    fprintf(stderr, "Unsupported transpose type !\n");
            }
        }
        else if (op == "Upsample")
        {
            std::string mode = get_node_attr_s(node, "mode");
            std::string align = get_node_attr_s(node, "coordinate_transformation_mode");

            // opset 7/8 keeps scales as an attribute; opset 9 passes them as input(1)
            std::vector<float> scales;

            if (node.input_size() == 1)
            {
                scales = get_node_attr_af(node, "scales");
            }
            else
            {
                scales = get_node_attr_from_input_af(weights[node.input(1)]);
            }

            // ncnn Interp resize_type: 1=nearest, 2=bilinear
            int resize_type = 1;
            if (mode == "nearest")
            {
                resize_type = 1;
            }
            else if (mode == "bilinear" || mode == "linear")
            {
                resize_type = 2;
            }
            else if (mode == "trilinear")
            {
                fprintf(stderr, "Unsupported Upsample mode !\n");
            }

            // extract h/w scale factors; scales layout depends on input rank:
            // 2 -> [n, w], 3 -> [n, h, w], 4 -> [n, c, h, w] (c-scale must be 1)
            float h_scale = 1.f;
            float w_scale = 1.f;
            if (scales.size() == 2)
            {
                w_scale = scales[1];
            }
            else if (scales.size() == 3)
            {
                h_scale = scales[1];
                w_scale = scales[2];
            }
            else if (scales.size() == 4)
            {
                h_scale = scales[2];
                w_scale = scales[3];

                if (scales[1] != 1.f)
                    fprintf(stderr, "Unsupported Upsample scales !\n");
            }
            else
            {
                fprintf(stderr, "Unsupported Upsample scales !\n");
            }

            int align_corner = 0;
            if (align == "align_corners")
            {
                align_corner = 1;
            }

            fprintf(pp, " 0=%d", resize_type);
            fprintf(pp, " 1=%e", h_scale);
            fprintf(pp, " 2=%e", w_scale);
            fprintf(pp, " 6=%d", align_corner);
        }
        else if (op == "Unsqueeze")
        {
            std::vector<int> axes = get_node_attr_ai(node, "axes");

            // -23303 = axes array for ncnn ExpandDims; batch axis (0) and |axis| > 4 unsupported
            fprintf(pp, " -23303=%zu", axes.size());
            for (int i = 0; i < (int)axes.size(); i++)
            {
                if (axes[i] == 0 || axes[i] > 4 || axes[i] < -4)
                    fprintf(stderr, "Unsupported unsqueeze axes !\n");
                fprintf(pp, ",%d", axes[i]);
            }
        }
        else
        {
            // TODO op specific param
            // unknown op: dump its attributes to stderr as a conversion hint;
            // type values are onnx::AttributeProto enum (1=FLOAT, 2=INT, 3=STRING)
            for (int j = 0; j < node.attribute_size(); j++)
            {
                const onnx::AttributeProto& attr = node.attribute(j);
                if (attr.type() == 1)
                {
                    fprintf(stderr, "  # %s=%g\n", attr.name().c_str(), attr.f());
                }
                else if (attr.type() == 2)
                {
                    fprintf(stderr, "  # %s=%lld\n", attr.name().c_str(), (long long)attr.i());
                }
                else if (attr.type() == 3)
                {
                    fprintf(stderr, "  # %s=%s\n", attr.name().c_str(), attr.s().c_str());
                }
                else
                {
                    fprintf(stderr, "  # %s %d\n", attr.name().c_str(), attr.type());
                }
            }
        }
5962 
        fprintf(pp, "\n");

        // When an output blob is consumed by more than one downstream layer,
        // append an ncnn Split layer so every consumer gets its own copy
        for (int j = 0; j < output_size; j++)
        {
            const std::string& output_name = node.output(j);
            if (node_reference.find(output_name) != node_reference.end())
            {
                int refcount = node_reference[output_name];
                if (refcount > 1)
                {
                    char splitname[256];
                    sprintf(splitname, "splitncnn_%d", internal_split);
                    fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);

                    fprintf(pp, " %s", output_name.c_str());

                    // fan-out blobs follow the "<name>_splitncnn_<k>" naming convention
                    for (int k = 0; k < refcount; k++)
                    {
                        fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k);
                    }
                    fprintf(pp, "\n");

                    internal_split++;
                }
            }
        }
5989     }
5990 
5991     fclose(pp);
5992     fclose(bp);
5993 
5994     return 0;
5995 }
5996