1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "onnx.pb.h"
16 
17 #include <algorithm>
18 #include <float.h>
19 #include <fstream>
20 #include <google/protobuf/io/coded_stream.h>
21 #include <google/protobuf/io/zero_copy_stream_impl.h>
22 #include <google/protobuf/message.h>
23 #include <google/protobuf/text_format.h>
24 #include <iostream>
25 #include <limits.h>
26 #include <limits>
27 #include <set>
28 #include <stdio.h>
29 
read_proto_from_binary(const char * filepath,onnx::ModelProto * message)30 static bool read_proto_from_binary(const char* filepath, onnx::ModelProto* message)
31 {
32     std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
33     if (!fs.is_open())
34     {
35         fprintf(stderr, "open failed %s\n", filepath);
36         return false;
37     }
38 
39     google::protobuf::io::IstreamInputStream input(&fs);
40     google::protobuf::io::CodedInputStream codedstr(&input);
41 
42 #if GOOGLE_PROTOBUF_VERSION >= 3011000
43     codedstr.SetTotalBytesLimit(INT_MAX);
44 #else
45     codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);
46 #endif
47 
48     bool success = message->ParseFromCodedStream(&codedstr);
49 
50     fs.close();
51 
52     return success;
53 }
54 
get_node_attr_ai(const onnx::NodeProto & node,const char * key)55 static std::vector<int> get_node_attr_ai(const onnx::NodeProto& node, const char* key)
56 {
57     std::vector<int> v;
58 
59     for (int i = 0; i < node.attribute_size(); i++)
60     {
61         const onnx::AttributeProto& attr = node.attribute(i);
62         if (attr.name() == key)
63         {
64             v.resize(attr.ints_size());
65             for (int j = 0; j < attr.ints_size(); j++)
66             {
67                 v[j] = std::max(std::min(attr.ints(j), (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
68             }
69 
70             break;
71         }
72     }
73 
74     return v;
75 }
76 
get_node_attr_af(const onnx::NodeProto & node,const char * key)77 static std::vector<float> get_node_attr_af(const onnx::NodeProto& node, const char* key)
78 {
79     std::vector<float> v;
80 
81     for (int i = 0; i < node.attribute_size(); i++)
82     {
83         const onnx::AttributeProto& attr = node.attribute(i);
84         if (attr.name() == key)
85         {
86             v.resize(attr.floats_size());
87             for (int j = 0; j < attr.floats_size(); j++)
88             {
89                 v[j] = attr.floats(j);
90             }
91 
92             break;
93         }
94     }
95 
96     return v;
97 }
98 
get_node_attr_i(const onnx::NodeProto & node,const char * key,int def=0)99 static int get_node_attr_i(const onnx::NodeProto& node, const char* key, int def = 0)
100 {
101     for (int i = 0; i < node.attribute_size(); i++)
102     {
103         const onnx::AttributeProto& attr = node.attribute(i);
104         if (attr.name() == key)
105         {
106             return std::max(std::min(attr.i(), (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
107         }
108     }
109 
110     return def;
111 }
112 
get_node_attr_f(const onnx::NodeProto & node,const char * key,float def=0.f)113 static float get_node_attr_f(const onnx::NodeProto& node, const char* key, float def = 0.f)
114 {
115     for (int i = 0; i < node.attribute_size(); i++)
116     {
117         const onnx::AttributeProto& attr = node.attribute(i);
118         if (attr.name() == key)
119         {
120             return attr.f();
121         }
122     }
123 
124     return def;
125 }
126 
get_node_attr_s(const onnx::NodeProto & node,const char * key,const std::string & def=std::string ())127 static std::string get_node_attr_s(const onnx::NodeProto& node, const char* key, const std::string& def = std::string())
128 {
129     for (int i = 0; i < node.attribute_size(); i++)
130     {
131         const onnx::AttributeProto& attr = node.attribute(i);
132         if (attr.name() == key)
133         {
134             return attr.s();
135         }
136     }
137 
138     return def;
139 }
140 
// Fetch the named tensor attribute of a node.
// Returns a default-constructed (empty) TensorProto when absent.
static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char* key)
{
    const int attr_count = node.attribute_size();
    for (int idx = 0; idx < attr_count; idx++)
    {
        const onnx::AttributeProto& attr = node.attribute(idx);
        if (attr.name() == key)
            return attr.t();
    }

    return onnx::TensorProto();
}
154 
get_node_attr_from_input_f(const onnx::TensorProto & tp)155 static float get_node_attr_from_input_f(const onnx::TensorProto& tp)
156 {
157     float v = 0.f;
158 
159     // float
160     if (tp.data_type() == 1)
161     {
162         const float* shape_data = 0;
163         if (tp.has_raw_data())
164         {
165             shape_data = (const float*)tp.raw_data().data();
166         }
167         else
168         {
169             shape_data = tp.float_data().data();
170         }
171         v = shape_data[0];
172     }
173     // double
174     else if (tp.data_type() == 11)
175     {
176         const double* shape_data = 0;
177         if (tp.has_raw_data())
178         {
179             shape_data = (const double*)tp.raw_data().data();
180         }
181         else
182         {
183             shape_data = tp.double_data().data();
184         }
185         v = shape_data[0];
186     }
187     // int64
188     else if (tp.data_type() == 7)
189     {
190         const int64_t* shape_data = 0;
191         if (tp.has_raw_data())
192         {
193             shape_data = (const int64_t*)tp.raw_data().data();
194         }
195         else
196         {
197             shape_data = tp.int64_data().data();
198         }
199         v = std::max(std::min(shape_data[0], (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
200     }
201     // int32
202     else if (tp.data_type() == 6)
203     {
204         const int32_t* shape_data = 0;
205         if (tp.has_raw_data())
206         {
207             shape_data = (const int32_t*)tp.raw_data().data();
208         }
209         else
210         {
211             shape_data = tp.int32_data().data();
212         }
213         v = shape_data[0];
214     }
215     else
216     {
217         fprintf(stderr, "Unknown data type %d\n", tp.data_type());
218         abort();
219     }
220 
221     return v;
222 }
223 
get_node_attr_from_input_ai(const onnx::TensorProto & tp)224 static std::vector<int> get_node_attr_from_input_ai(const onnx::TensorProto& tp)
225 {
226     int size = 0;
227 
228     std::vector<int> v;
229 
230     // int64
231     if (tp.data_type() == 7)
232     {
233         const int64_t* shape_data = 0;
234         if (tp.has_raw_data())
235         {
236             shape_data = (const int64_t*)tp.raw_data().data();
237             size = (int)(tp.raw_data().size() / 8);
238         }
239         else
240         {
241             shape_data = tp.int64_data().data();
242             size = tp.int64_data_size();
243         }
244         for (int j = 0; j < size; j++)
245         {
246             int vi = std::max(std::min(shape_data[j], (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
247             v.push_back(vi);
248         }
249     }
250     // int32
251     else if (tp.data_type() == 6)
252     {
253         const int32_t* shape_data = 0;
254         if (tp.has_raw_data())
255         {
256             shape_data = (const int32_t*)tp.raw_data().data();
257             size = (int)(tp.raw_data().size() / 4);
258         }
259         else
260         {
261             shape_data = tp.int32_data().data();
262             size = tp.int32_data_size();
263         }
264         for (int j = 0; j < size; j++)
265         {
266             v.push_back(shape_data[j]);
267         }
268     }
269     else
270     {
271         fprintf(stderr, "Unknown data type %d\n", tp.data_type());
272     }
273 
274     return v;
275 }
276 
get_node_attr_from_input_af(const onnx::TensorProto & tp)277 static std::vector<float> get_node_attr_from_input_af(const onnx::TensorProto& tp)
278 {
279     int size = 0;
280 
281     std::vector<float> v;
282 
283     // float
284     if (tp.data_type() == 1)
285     {
286         const float* shape_data = 0;
287         if (tp.has_raw_data())
288         {
289             shape_data = (const float*)tp.raw_data().data();
290             size = (int)(tp.raw_data().size() / 4);
291         }
292         else
293         {
294             shape_data = tp.float_data().data();
295             size = tp.float_data_size();
296         }
297         for (int j = 0; j < size; j++)
298         {
299             v.push_back(shape_data[j]);
300         }
301     }
302     // double
303     else if (tp.data_type() == 11)
304     {
305         const double* shape_data = 0;
306         if (tp.has_raw_data())
307         {
308             shape_data = (const double*)tp.raw_data().data();
309             size = (int)(tp.raw_data().size() / 8);
310         }
311         else
312         {
313             shape_data = tp.double_data().data();
314             size = tp.double_data_size();
315         }
316         for (int j = 0; j < size; j++)
317         {
318             v.push_back((float)shape_data[j]);
319         }
320     }
321     else
322     {
323         fprintf(stderr, "Unknown data type %d\n", tp.data_type());
324     }
325 
326     return v;
327 }
328 
get_tensor_proto_data_size(const onnx::TensorProto & tp)329 static int get_tensor_proto_data_size(const onnx::TensorProto& tp)
330 {
331     if (tp.has_raw_data())
332     {
333         const std::string& raw_data = tp.raw_data();
334         int size = (int)raw_data.size() / 4;
335         return size;
336     }
337     else if (tp.data_type() == 1)
338     {
339         return tp.float_data_size();
340     }
341 
342     return 0;
343 }
344 
fwrite_tensor_proto_data(const onnx::TensorProto & tp,FILE * bp)345 static void fwrite_tensor_proto_data(const onnx::TensorProto& tp, FILE* bp)
346 {
347     int size = get_tensor_proto_data_size(tp);
348 
349     if (tp.has_raw_data())
350     {
351         const std::string& raw_data = tp.raw_data();
352         fwrite(raw_data.data(), sizeof(float), size, bp);
353     }
354     else if (tp.data_type() == 1)
355     {
356         fwrite(tp.float_data().data(), sizeof(float), size, bp);
357     }
358 }
359 
// Fold  weight <= Reshape(weight)  into the weight itself.
// When a Reshape's input is a known initializer, the reshape can be applied
// at conversion time: the Reshape output name is registered in `weights`
// with the new dims and the node is renamed to noop_reducedncnn so later
// passes drop it. Reference counts for the consumed inputs are decremented
// and reduced_node_count is bumped for the caller's bookkeeping.
static void fuse_weight_reshape(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // weight <= Reshape(weight)
        if (node->op_type() == "Reshape")
        {
            // check weight
            if (weights.find(node->input(0)) == weights.end())
                continue;

            // alias the reshaped output to a copy of the source weight
            weights[node->output(0)] = weights[node->input(0)];

            // set weight shape directly
            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                shape = get_node_attr_ai(*node, "shape");
            }
            else if (node->input_size() == 2)
            {
                // opset 5
                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            weights[node->output(0)].clear_dims();
            for (int j = 0; j < shape.size(); j++)
            {
                weights[node->output(0)].add_dims(shape[j]);
            }

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            if (node->input_size() == 2)
            {
                // the shape initializer is consumed too
                node_reference[node->input(1)] -= 1;
            }

            reduced_node_count += 1;
            // NOTE(review): together with the loop increment this skips the
            // node following the Reshape — mirrors the other fuse passes;
            // confirm the skip is intended for a single-node fusion
            i += 1;
        }
    }
}
408 
// Fold  weight <= Transpose(weight, perm=(1,0))  into the weight itself.
// Only 2D initializers with the exact (1,0) permutation are handled: the
// matrix data is transposed at conversion time, the dims are swapped, and
// the Transpose node is renamed to noop_reducedncnn so later passes drop it.
static void fuse_weight_transpose(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // weight <= Transpose(weight)
        if (node->op_type() == "Transpose")
        {
            // check weight
            if (weights.find(node->input(0)) == weights.end())
                continue;

            if (weights[node->input(0)].dims_size() != 2)
                continue;

            // perm = (1, 0)
            std::vector<int> perm = get_node_attr_ai(*node, "perm");
            if (perm.size() != 2)
                continue;
            if (perm[0] != 1 || perm[1] != 0)
                continue;

            // register the transposed output as a copy of the source weight
            weights[node->output(0)] = weights[node->input(0)];

            // permute weight
            {
                onnx::TensorProto& B = weights[node->output(0)];

                const int h = B.dims(0);
                const int w = B.dims(1);

                std::vector<float> permuted_data;
                permuted_data.reserve((size_t)h * w);
                // float payload lives either in raw_data or in float_data
                const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();

                // h x w -> w x h
                for (int j = 0; j < w; j++)
                {
                    for (int k = 0; k < h; k++)
                    {
                        float vb = bptr[k * w + j];
                        permuted_data.push_back(vb);
                    }
                }

                B.set_dims(0, w);
                B.set_dims(1, h);

                // write back into the same storage kind the tensor used
                if (B.has_raw_data())
                {
                    B.set_raw_data(permuted_data.data(), permuted_data.size() * sizeof(float));
                }
                else
                {
                    for (int j = 0; j < (int)permuted_data.size(); j++)
                        B.set_float_data(j, permuted_data[j]);
                }
            }

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;

            reduced_node_count += 1;
            // NOTE(review): skips the next node as in the other fuse passes;
            // confirm intended for a single-node fusion
            i += 1;
        }
    }
}
479 
// Collapse the Reshape - Transpose - Reshape channel-shuffle idiom into a
// single ShuffleChannel node (an interleaved Constant before the final
// Reshape is tolerated). Two layouts are recognized:
//   forward:  reshape to (1, groups, ch_per_group, h, w), perm (0,2,1,3,4)
//   reverse:  reshape to (ch_per_group, groups, h*w),     perm (1,0,2)
// The first two nodes become noop_reducedncnn; the final Reshape is
// rewritten in place into ShuffleChannel with group/reverse attributes.
static void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // ShuffleChannel <= Reshape - Transpose - Reshape
        // ShuffleChannel <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            // the intermediate blob must have exactly one consumer
            if (node_reference[node->output(0)] != 1)
                continue;

            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end())
                    continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // 1 groups channels_per_group, height, width
            // reverse style = channels_per_group, groups, height * width
            if (shape.size() != 5 && shape.size() != 3)
                continue;

            if (shape.size() == 5 && shape[0] != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // tolerate a Constant between Transpose and the final Reshape
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // 0 2 1 3 4
            // reverse style = 1 0 2
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 5 && perm.size() != 3)
                continue;

            if (perm.size() == 5 && (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3 || perm[4] != 4))
                continue;

            if (perm.size() == 3 && (perm[0] != 1 || perm[1] != 0 || perm[2] != 2))
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end())
                    continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // 1, -1, height, width
            // reverse style = group, -1, channels_per_group, height, width
            if (shape3.size() != 4 && shape3.size() != 5)
                continue;

            if (shape3.size() == 4 && (shape3[0] != 1 || (shape3[1] != -1 && shape3[1] != shape[1] * shape[2])))
                continue;

            if (shape3.size() == 5 && (shape3[0] != shape[1] || shape3[2] != shape[0] || shape3[3] * shape3[4] != shape[2]))
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            // intermediate blobs disappear with the fused nodes
            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // rewrite the final Reshape into the fused op, fed by the
            // original input of the first Reshape
            node3->set_op_type("ShuffleChannel");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("group");
            attr_group->set_i(shape[1]);

            onnx::AttributeProto* attr_reverse = node3->add_attribute();
            attr_reverse->set_name("reverse");
            attr_reverse->set_i(shape.size() == 3);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
607 
// Collapse ShuffleChannel(reverse) followed by two Gathers (indices 0 and 1
// on axis 0) into a single Split node on axis 1.
// The second Gather is marked noop_reducedncnn; the third node is rewritten
// in place into Split with both Gather outputs as its two outputs.
static void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1)
        if (node->op_type() == "ShuffleChannel")
        {
            // reverse = 1
            int reverse = get_node_attr_i(*node, "reverse");
            if (reverse != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            if (node2->op_type() != "Gather" || node3->op_type() != "Gather")
                continue;

            // both Gathers must consume the ShuffleChannel output
            if (node2->input(0) != node->output(0) || node3->input(0) != node->output(0))
                continue;

            // axis = 0
            int gather2_axis = get_node_attr_i(*node2, "axis");
            if (gather2_axis != 0)
                continue;

            // indices = 0
            if (weights.find(node2->input(1)) == weights.end())
                continue;

            std::vector<int> gather2_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
            if (gather2_indices.size() != 1 || gather2_indices[0] != 0)
                continue;

            // axis = 0
            int gather3_axis = get_node_attr_i(*node3, "axis");
            if (gather3_axis != 0)
                continue;

            // indices = 1
            if (weights.find(node3->input(1)) == weights.end())
                continue;

            std::vector<int> gather3_indices = get_node_attr_from_input_ai(weights[node3->input(1)]);
            if (gather3_indices.size() != 1 || gather3_indices[0] != 1)
                continue;

            // reduce
            node2->set_op_type("noop_reducedncnn");

            // the ShuffleChannel output loses both Gather consumers,
            // then gains the Split below as its single consumer
            node_reference[node->output(0)] -= 2;
            node_reference[node2->input(1)] -= 1;
            node_reference[node3->input(1)] -= 1;

            node3->set_op_type("Split");
            node3->clear_input();
            node3->add_input(node->output(0));
            // duplicate node3's own output, then rename slot 0 to the first
            // Gather's output: Split ends with outputs (gather2_out, gather3_out)
            node3->add_output(node3->output(0));
            node3->set_output(0, node2->output(0));

            node3->clear_attribute();
            onnx::AttributeProto* attr_axis = node3->add_attribute();
            attr_axis->set_name("axis");
            attr_axis->set_i(1);

            reduced_node_count += 1;
            i += 1;
        }
    }
}
684 
// Collapse two hardswish idioms into a single HardSwish node:
//   pass 1: Add(+3) - Clip(0,6) - Mul(x,·) - Div(/6) (or Mul(*1/6)),
//           with an optional interleaved Constant before the last node;
//   pass 2: HardSigmoid - Mul(x,·).
// Matched prefix nodes become noop_reducedncnn; the last node is rewritten
// in place into HardSwish with alpha/beta attributes. Reference counts and
// blob_names are updated for the removed intermediates.
static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Div(/6)
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Mul(*(1/6))
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Div(/6)
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Mul(*(1/6))
        //     out = x * F.relu6(x + 3, inplace=True) / 6
        if (node->op_type() == "Add")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 3 >= node_count)
                continue;

            // the added constant must be a known scalar initializer
            if (weights.find(node->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& add_three = weights[node->input(1)];
            if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1)
                continue;

            float constant_add_three = get_node_attr_from_input_f(add_three);
            if (constant_add_three != 3.f)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);

            // tolerate a Constant between Mul and the final Div/Mul
            if (node4->op_type() == "Constant")
            {
                if (i + 4 >= node_count)
                    continue;

                node4 = mutable_graph->mutable_node(i + 4);
            }

            if (node2->op_type() != "Clip" || node3->op_type() != "Mul" || (node4->op_type() != "Div" && node4->op_type() != "Mul"))
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // the Clip must be an exact relu6: min 0, max 6
            float relu6_min;
            float relu6_max;
            if (node2->input_size() == 1)
            {
                relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
                relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
            }
            else
            {
                // NOTE(review): assumes Clip min/max (inputs 1 and 2) are
                // initializers; a non-weight input would default-insert an
                // empty tensor here and abort in get_node_attr_from_input_f,
                // and input_size == 2 would make input(2) invalid — confirm
                // upstream graphs always provide both as weights
                const onnx::TensorProto& min_tp = weights[node2->input(1)];
                const onnx::TensorProto& max_tp = weights[node2->input(2)];

                relu6_min = get_node_attr_from_input_f(min_tp);
                relu6_max = get_node_attr_from_input_f(max_tp);
            }
            if (relu6_min != 0.f || relu6_max != 6.f)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            // Mul must combine the original x with the clipped branch
            if (node3->input(0) != node->input(0) || node3->input(1) != node2->output(0))
                continue;

            if (weights.find(node4->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& div_six = weights[node4->input(1)];
            if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1)
                continue;

            // Div by 6 or Mul by 1/6 both mean "/6"
            float constant_div_six = get_node_attr_from_input_f(div_six);
            if (node4->op_type() == "Div" && constant_div_six != 6.f)
                continue;
            if (node4->op_type() == "Mul" && constant_div_six != 1 / 6.f)
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->input(1)] -= 1;
            node_reference[node->output(0)] -= 1;
            if (node2->input_size() == 3)
            {
                node_reference[node2->input(1)] -= 1;
                node_reference[node2->input(2)] -= 1;
            }
            node_reference[node2->output(0)] -= 1;
            node_reference[node3->output(0)] -= 1;
            node_reference[node4->input(1)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));

            // rewrite the final node into HardSwish(x)
            node4->set_op_type("HardSwish");
            node4->clear_input();
            node4->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node4->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(1.f / 6.f);

            onnx::AttributeProto* attr_beta = node4->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(3.f / 6.f);

            reduced_node_count += 3;
            i += 3;
        }
    }

    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSwish <= HardSigmoid - Mul
        //     out = x * hsigmoid(x)
        if (node->op_type() == "HardSigmoid")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            // carry the HardSigmoid coefficients over to the fused node
            float alpha = get_node_attr_f(*node, "alpha", 0.2f);
            float beta = get_node_attr_f(*node, "beta", 0.5f);

            if (i + 1 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);

            if (node2->op_type() != "Mul")
                continue;

            // Mul must combine x with hsigmoid(x)
            if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0))
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->output(0)] -= 1;

            blob_names.erase(node->output(0));

            node2->set_op_type("HardSwish");
            node2->clear_input();
            node2->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node2->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(alpha);

            onnx::AttributeProto* attr_beta = node2->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(beta);

            reduced_node_count += 1;
            i += 1;
        }
    }
}
859 
// Fuse Add(+3) - Clip(0,6) - Div(/6) (or Mul(*(1/6)), optionally with an
// interposed Constant node) into a single HardSigmoid with alpha=1/6, beta=0.5.
// Fused-away nodes are renamed "noop_reducedncnn" so a later pass can drop
// them; node_reference counts and blob_names are kept consistent.
static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSigmoid <= Add(+3) - Clip(0,6) - Div(/6)
        // HardSigmoid <= Add(+3) - Clip(0,6) - Mul(*(1/6))
        // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Div(/6)
        // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Mul(*(1/6))
        //     out = F.relu6(x + 3, inplace=True) / 6
        if (node->op_type() == "Add")
        {
            // the Add result may only feed the Clip
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            // the added operand must be a known scalar weight equal to 3
            if (weights.find(node->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& add_three = weights[node->input(1)];
            if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1)
                continue;

            float constant_add_three = get_node_attr_from_input_f(add_three);
            if (constant_add_three != 3.f)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // tolerate a Constant node sitting between Clip and Div/Mul
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Clip" || (node3->op_type() != "Div" && node3->op_type() != "Mul"))
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // Clip bounds come either from attributes or from constant inputs
            float relu6_min;
            float relu6_max;
            if (node2->input_size() == 1)
            {
                relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
                relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
            }
            else
            {
                // NOTE(review): operator[] inserts an empty tensor when the bound
                // is not a weight; the 0/6 checks below then reject the match
                const onnx::TensorProto& min_tp = weights[node2->input(1)];
                const onnx::TensorProto& max_tp = weights[node2->input(2)];

                relu6_min = get_node_attr_from_input_f(min_tp);
                relu6_max = get_node_attr_from_input_f(max_tp);
            }
            if (relu6_min != 0.f || relu6_max != 6.f)
                continue;

            // the divisor (Div) must be scalar 6, the multiplier (Mul) scalar 1/6
            if (weights.find(node3->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& div_six = weights[node3->input(1)];
            if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1)
                continue;

            float constant_div_six = get_node_attr_from_input_f(div_six);
            if (node3->op_type() == "Div" && constant_div_six != 6.f)
                continue;
            if (node3->op_type() == "Mul" && constant_div_six != 1 / 6.f)
                continue;

            // reduce: blank out Add and Clip, release their constants and blobs
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            node_reference[node->input(1)] -= 1;
            node_reference[node->output(0)] -= 1;
            if (node2->input_size() == 3)
            {
                node_reference[node2->input(1)] -= 1;
                node_reference[node2->input(2)] -= 1;
            }
            node_reference[node2->output(0)] -= 1;
            node_reference[node3->input(1)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // rewrite the tail node: hardsigmoid(x) with alpha=1/6, beta=0.5
            node3->set_op_type("HardSigmoid");
            node3->clear_input();
            node3->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node3->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(1.f / 6.f);

            onnx::AttributeProto* attr_beta = node3->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(3.f / 6.f);

            // skip past the two nodes consumed by this fusion
            reduced_node_count += 2;
            i += 2;
        }
    }
}
973 
fuse_swish(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)974 static void fuse_swish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
975 {
976     int node_count = mutable_graph->node_size();
977     for (int i = 0; i < node_count; i++)
978     {
979         onnx::NodeProto* node = mutable_graph->mutable_node(i);
980 
981         // Swish <= Sigmoid - Mul
982         //     x * torch.sigmoid(x)
983         if (node->op_type() == "Sigmoid")
984         {
985             if (node_reference[node->output(0)] != 1)
986                 continue;
987 
988             if (i + 1 >= node_count)
989                 continue;
990 
991             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
992 
993             if (node2->op_type() != "Mul")
994                 continue;
995 
996             if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0))
997                 continue;
998 
999             // reduce
1000             node->set_op_type("noop_reducedncnn");
1001 
1002             node_reference[node->input(0)] -= 1;
1003             node_reference[node->output(0)] -= 1;
1004 
1005             blob_names.erase(node->output(0));
1006 
1007             node2->set_op_type("Swish");
1008             node2->clear_input();
1009             node2->add_input(node->input(0));
1010 
1011             reduced_node_count += 1;
1012             i += 1;
1013         }
1014     }
1015 }
1016 
fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1017 static void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1018 {
1019     int node_count = mutable_graph->node_size();
1020     for (int i = 0; i < node_count; i++)
1021     {
1022         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1023 
1024         // BatchNormalization <= Unsqueeze - BatchNormalization - Squeeze
1025         if (node->op_type() == "Unsqueeze")
1026         {
1027             if (node_reference[node->output(0)] != 1)
1028                 continue;
1029 
1030             if (i + 2 >= node_count)
1031                 continue;
1032 
1033             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1034             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1035 
1036             if (node2->op_type() != "BatchNormalization" || node3->op_type() != "Squeeze")
1037                 continue;
1038 
1039             if (node_reference[node2->output(0)] != 1)
1040                 continue;
1041 
1042             if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0))
1043                 continue;
1044 
1045             // reduce
1046             node->set_op_type("noop_reducedncnn");
1047             node3->set_op_type("noop_reducedncnn");
1048 
1049             node_reference[node->output(0)] -= 1;
1050             node_reference[node2->output(0)] -= 1;
1051 
1052             blob_names.erase(node->output(0));
1053             blob_names.erase(node2->output(0));
1054 
1055             node2->set_input(0, node->input(0));
1056             node2->set_output(0, node3->output(0));
1057 
1058             reduced_node_count += 2;
1059             i += 2;
1060         }
1061     }
1062 }
1063 
fuse_unsqueeze_prelu(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1064 static void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1065 {
1066     int node_count = mutable_graph->node_size();
1067     for (int i = 0; i < node_count; i++)
1068     {
1069         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1070 
1071         // PReLU <= Unsqueeze - PReLU
1072         if (node->op_type() == "Unsqueeze")
1073         {
1074             // check weight
1075             if (weights.find(node->input(0)) == weights.end())
1076                 continue;
1077 
1078             onnx::TensorProto& B = weights[node->input(0)];
1079             if (B.dims_size() != 1)
1080                 continue;
1081 
1082             if (node_reference[node->output(0)] != 1)
1083                 continue;
1084 
1085             // axes = (1, 2)
1086             std::vector<int> axes = get_node_attr_ai(*node, "axes");
1087             if (axes.size() != 2)
1088                 continue;
1089             if (axes[0] != 1 || axes[1] != 2)
1090                 continue;
1091 
1092             if (i + 1 >= node_count)
1093                 continue;
1094 
1095             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1096 
1097             if (node2->op_type() != "PRelu")
1098                 continue;
1099 
1100             if (node2->input(1) != node->output(0))
1101                 continue;
1102 
1103             // reduce
1104             node->set_op_type("noop_reducedncnn");
1105 
1106             node_reference[node->output(0)] -= 1;
1107 
1108             blob_names.erase(node->output(0));
1109 
1110             node2->set_input(1, node->input(0));
1111 
1112             reduced_node_count += 1;
1113             i += 1;
1114         }
1115     }
1116 }
1117 
// Fuse ReduceL2 - Clip - [Shape] - Expand - Div into a single Normalize node
// (L2 normalization along axis 1, with eps taken from the Clip lower bound).
// Matched nodes become "noop_reducedncnn" placeholders; node_reference counts
// and blob_names are kept consistent.
static void fuse_normalize(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Normalize <= X - ReduceL2 - Clip - Expand - Div
        // Normalize <= X - ReduceL2 - Clip - Shape - Expand - Div
        if (node->op_type() == "ReduceL2")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            // axes = (1) -- only channel-axis normalization is supported
            std::vector<int> axes = get_node_attr_ai(*node, "axes");
            if (axes.size() != 1)
                continue;
            if (axes[0] != 1)
                continue;

            if (i + 3 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);

            // an optional Shape node may feed the Expand; when present, shift
            // the Expand/Div candidates one slot down
            bool has_shape_node = node3->op_type() == "Shape";
            onnx::NodeProto* node_shape = 0;
            if (has_shape_node)
            {
                if (i + 4 >= node_count)
                    continue;

                node_shape = node3;
                node3 = mutable_graph->mutable_node(i + 3);
                node4 = mutable_graph->mutable_node(i + 4);
            }

            if (node2->op_type() != "Clip" || node3->op_type() != "Expand" || node4->op_type() != "Div")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            // check the wiring: Div divides X by the expanded, clipped L2 norm
            if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)
                    || node4->input(0) != node->input(0) || node4->input(1) != node3->output(0))
                continue;

            if (has_shape_node)
            {
                // Shape must read X and feed the Expand's target-shape input
                if (node_shape->input(0) != node->input(0) || node3->input(1) != node_shape->output(0))
                    continue;
            }

            // +eps -- the Clip lower bound becomes the Normalize eps;
            // Clip carries min either as an attribute or as a constant input
            float clip_min;
            if (node2->input_size() == 1)
            {
                clip_min = get_node_attr_f(*node2, "min", -FLT_MAX);
            }
            else
            {
                const onnx::TensorProto& min_tp = weights[node2->input(1)];

                clip_min = get_node_attr_from_input_f(min_tp);
            }

            // reduce: blank out everything except the Div, which is rewritten
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            if (has_shape_node)
            {
                node_shape->set_op_type("noop_reducedncnn");
            }
            node3->set_op_type("noop_reducedncnn");

            // X was consumed by ReduceL2 (+Shape when present) in the fused region
            node_reference[node->input(0)] -= has_shape_node ? 2 : 1;
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (has_shape_node)
            {
                node_reference[node_shape->output(0)] -= 1;
            }
            node_reference[node3->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                // Expand's explicit target-shape operand also loses a consumer
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            if (has_shape_node)
            {
                blob_names.erase(node_shape->output(0));
            }
            blob_names.erase(node3->output(0));

            node4->set_op_type("Normalize");
            node4->clear_input();
            node4->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node4->add_attribute();
            attr_alpha->set_name("eps");
            attr_alpha->set_f(clip_min);

            // skip past the nodes consumed by this fusion
            reduced_node_count += has_shape_node ? 4 : 3;
            i += has_shape_node ? 4 : 3;
        }
    }
}
1233 
fuse_groupnorm(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1234 static void fuse_groupnorm(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1235 {
1236     int node_count = mutable_graph->node_size();
1237     for (int i = 0; i < node_count; i++)
1238     {
1239         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1240 
1241         // GroupNorm <= X - Reshape - InstanceNormalization - Reshape - Mul - Add
1242         if (node->op_type() == "Reshape")
1243         {
1244             if (node_reference[node->output(0)] != 1)
1245                 continue;
1246 
1247             std::vector<int> shape;
1248             if (node->input_size() == 1)
1249             {
1250                 shape = get_node_attr_ai(*node, "shape");
1251             }
1252             else
1253             {
1254                 // skip weight reshape
1255                 if (weights.find(node->input(1)) == weights.end())
1256                     continue;
1257 
1258                 shape = get_node_attr_from_input_ai(weights[node->input(1)]);
1259             }
1260 
1261             // 0, group, -1
1262             if (shape.size() != 3)
1263                 continue;
1264 
1265             if (shape[0] != 0 || shape[2] != -1)
1266                 continue;
1267 
1268             int groups = shape[1];
1269 
1270             if (i + 4 >= node_count)
1271                 continue;
1272 
1273             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1274             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1275             onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
1276             onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
1277 
1278             if (node2->op_type() != "InstanceNormalization" || node3->op_type() != "Reshape" || node4->op_type() != "Mul" || node5->op_type() != "Add")
1279                 continue;
1280 
1281             if (node_reference[node2->output(0)] != 1)
1282                 continue;
1283 
1284             if (node_reference[node3->output(0)] != 1)
1285                 continue;
1286 
1287             if (node_reference[node4->output(0)] != 1)
1288                 continue;
1289 
1290             if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)
1291                     || node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0))
1292                 continue;
1293 
1294             // +eps
1295             float eps = get_node_attr_f(*node2, "epsilon", 1e-05f);
1296 
1297             // InstanceNormalization S=1 B=0
1298             std::vector<float> S = get_node_attr_from_input_af(weights[node2->input(1)]);
1299             std::vector<float> B = get_node_attr_from_input_af(weights[node2->input(2)]);
1300             if ((int)S.size() != groups || (int)B.size() != groups)
1301                 continue;
1302 
1303             bool instancenorm_affine = false;
1304             for (int j = 0; j < groups; j++)
1305             {
1306                 if (S[j] != 1.f || B[j] != 0.f)
1307                 {
1308                     instancenorm_affine = true;
1309                     break;
1310                 }
1311             }
1312 
1313             if (instancenorm_affine)
1314                 continue;
1315 
1316             std::vector<int> shape2;
1317             if (node3->input_size() == 1)
1318             {
1319                 shape2 = get_node_attr_ai(*node3, "shape");
1320             }
1321             else
1322             {
1323                 // skip weight reshape
1324                 if (weights.find(node3->input(1)) == weights.end())
1325                     continue;
1326 
1327                 shape2 = get_node_attr_from_input_ai(weights[node3->input(1)]);
1328             }
1329 
1330             // 1, channels, w, h
1331             if (shape2.size() != 4)
1332                 continue;
1333 
1334             if (shape2[0] != 1)
1335                 continue;
1336 
1337             int channels = shape2[1];
1338 
1339             // affine
1340             std::vector<float> affine_S = get_node_attr_from_input_af(weights[node4->input(1)]);
1341             std::vector<float> affine_B = get_node_attr_from_input_af(weights[node5->input(1)]);
1342             if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && affine_B[0] == 0.f)
1343             {
1344                 // no affine
1345             }
1346             else if ((int)affine_S.size() != channels && (int)affine_B.size() != channels)
1347             {
1348                 // we only allow per-channel affine
1349                 continue;
1350             }
1351 
1352             // reduce
1353             node->set_op_type("noop_reducedncnn");
1354             node2->set_op_type("noop_reducedncnn");
1355             node3->set_op_type("noop_reducedncnn");
1356             node4->set_op_type("noop_reducedncnn");
1357 
1358             if (node->input_size() == 2)
1359             {
1360                 node_reference[node->input(1)] -= 1;
1361             }
1362             node_reference[node->output(0)] -= 1;
1363             node_reference[node2->input(1)] -= 1;
1364             node_reference[node2->input(2)] -= 1;
1365             node_reference[node2->output(0)] -= 1;
1366             if (node3->input_size() == 2)
1367             {
1368                 node_reference[node3->input(1)] -= 1;
1369             }
1370             node_reference[node3->output(0)] -= 1;
1371             node_reference[node4->output(0)] -= 1;
1372 
1373             blob_names.erase(node->output(0));
1374             blob_names.erase(node2->output(0));
1375             blob_names.erase(node3->output(0));
1376             blob_names.erase(node4->output(0));
1377 
1378             std::string affine_scale = node4->input(1);
1379             std::string affine_bias = node5->input(1);
1380 
1381             node5->set_op_type("GroupNorm");
1382             node5->clear_input();
1383             node5->add_input(node->input(0));
1384             node5->add_input(affine_scale);
1385             node5->add_input(affine_bias);
1386 
1387             onnx::AttributeProto* attr_groups = node5->add_attribute();
1388             attr_groups->set_name("groups");
1389             attr_groups->set_i(groups);
1390 
1391             onnx::AttributeProto* attr_channels = node5->add_attribute();
1392             attr_channels->set_name("channels");
1393             attr_channels->set_i(channels);
1394 
1395             onnx::AttributeProto* attr_eps = node5->add_attribute();
1396             attr_eps->set_name("epsilon");
1397             attr_eps->set_f(eps);
1398 
1399             onnx::AttributeProto* attr_affine = node5->add_attribute();
1400             attr_affine->set_name("affine");
1401             attr_affine->set_i(1);
1402 
1403             reduced_node_count += 4;
1404             i += 4;
1405         }
1406     }
1407 }
1408 
// Fuse ReduceMean - Sub - Pow(2) - ReduceMean - Add(eps) - Sqrt - Div
// (optionally followed by Mul - Add for the learned affine) into a single
// LayerNorm node normalizing over the last one or two axes. Fused-away nodes
// become "noop_reducedncnn" placeholders; node_reference counts and
// blob_names are kept consistent.
static void fuse_layernorm(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div
        // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div - Mul - Add
        if (node->op_type() == "ReduceMean")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            std::vector<int> axes = get_node_attr_ai(*node, "axes");

            // -1
            // -2 -1
            if (axes.size() != 1 && axes.size() != 2)
                continue;

            int normed_axes = (int)axes.size();
            if (normed_axes == 1 && axes[0] != -1)
                continue;
            if (normed_axes == 2 && (axes[0] != -2 || axes[1] != -1))
                continue;

            if (i + 6 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);

            if (node2->op_type() != "Sub" || node3->op_type() != "Pow" || node4->op_type() != "ReduceMean" || node5->op_type() != "Add" || node6->op_type() != "Sqrt" || node7->op_type() != "Div")
                continue;

            // the Sub output (x - mean) feeds both Pow and Div, hence count 2
            if (node_reference[node2->output(0)] != 2)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            if (node_reference[node4->output(0)] != 1)
                continue;

            if (node_reference[node5->output(0)] != 1)
                continue;

            if (node_reference[node6->output(0)] != 1)
                continue;

            // verify the full wiring of the variance/normalize dataflow
            if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0)
                    || node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0)
                    || node5->input(0) != node4->output(0) || node6->input(0) != node5->output(0)
                    || node7->input(0) != node2->output(0) || node7->input(1) != node6->output(0))
                continue;

            // Pow exponent must be the scalar weight 2
            if (weights.find(node3->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& pow_two = weights[node3->input(1)];
            if (pow_two.dims_size() != 0 || get_tensor_proto_data_size(pow_two) != 1)
                continue;

            float constant_pow_two = get_node_attr_from_input_f(pow_two);
            if (constant_pow_two != 2.f)
                continue;

            // the inner ReduceMean must use the same axes as the outer one
            std::vector<int> axes4 = get_node_attr_ai(*node4, "axes");

            // -1
            // -2 -1
            if ((int)axes4.size() != normed_axes)
                continue;

            if (normed_axes == 1 && axes4[0] != -1)
                continue;
            if (normed_axes == 2 && (axes4[0] != -2 || axes4[1] != -1))
                continue;

            // the Add operand is the epsilon scalar
            if (weights.find(node5->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& add_eps = weights[node5->input(1)];
            if (add_eps.dims_size() != 0 || get_tensor_proto_data_size(add_eps) != 1)
                continue;

            float eps = get_node_attr_from_input_f(add_eps);

            // probe for a trailing Mul - Add affine pair; the while loop is a
            // single-pass "breakable block", not a real loop
            int affine = 0;
            while (i + 8 < node_count)
            {
                onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
                onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);

                if (node8->op_type() != "Mul" || node9->op_type() != "Add")
                    break;

                if (node_reference[node7->output(0)] != 1)
                    break;

                if (node_reference[node8->output(0)] != 1)
                    break;

                if (node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0))
                    break;

                // affine scale and bias must have matching element counts
                std::vector<float> affine_S = get_node_attr_from_input_af(weights[node8->input(1)]);
                std::vector<float> affine_B = get_node_attr_from_input_af(weights[node9->input(1)]);
                if (affine_S.size() != affine_B.size())
                    break;

                affine = 1;
                break;
            }

            // reduce: blank out the six statistics nodes
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node2->input(0)] -= 1;
            node_reference[node2->input(1)] -= 1;
            node_reference[node3->input(0)] -= 1;
            node_reference[node3->input(1)] -= 1;
            node_reference[node4->input(0)] -= 1;
            node_reference[node5->input(0)] -= 1;
            node_reference[node5->input(1)] -= 1;
            node_reference[node6->input(0)] -= 1;
            node_reference[node7->input(0)] -= 1;
            node_reference[node7->input(1)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));

            // X lost two references above (consumed by ReduceMean and Sub) but
            // the fused LayerNorm will consume it exactly once, so restore one
            node_reference[node->input(0)] += 1;

            if (affine == 0)
            {
                // no affine tail: the Div becomes the LayerNorm
                node7->set_op_type("LayerNorm");
                node7->clear_input();
                node7->add_input(node->input(0));

                onnx::AttributeProto* attr_eps = node7->add_attribute();
                attr_eps->set_name("epsilon");
                attr_eps->set_f(eps);

                onnx::AttributeProto* attr_affine = node7->add_attribute();
                attr_affine->set_name("affine");
                attr_affine->set_i(affine);

                reduced_node_count += 6;
                i += 6;
            }
            else // if (affine == 1)
            {
                // affine tail: Div and Mul are blanked, the final Add becomes
                // the LayerNorm with scale/bias inputs
                onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
                onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);

                node7->set_op_type("noop_reducedncnn");
                node8->set_op_type("noop_reducedncnn");

                node_reference[node8->input(0)] -= 1;
                node_reference[node9->input(0)] -= 1;

                blob_names.erase(node7->output(0));
                blob_names.erase(node8->output(0));

                std::string affine_scale = node8->input(1);
                std::string affine_bias = node9->input(1);

                node9->set_op_type("LayerNorm");
                node9->clear_input();
                node9->add_input(node->input(0));
                node9->add_input(affine_scale);
                node9->add_input(affine_bias);

                onnx::AttributeProto* attr_eps = node9->add_attribute();
                attr_eps->set_name("epsilon");
                attr_eps->set_f(eps);

                onnx::AttributeProto* attr_affine = node9->add_attribute();
                attr_affine->set_name("affine");
                attr_affine->set_i(affine);

                reduced_node_count += 8;
                i += 8;
            }
        }
    }
}
1613 
fuse_flatten(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1614 static void fuse_flatten(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1615 {
1616     int node_count = mutable_graph->node_size();
1617     for (int i = 0; i < node_count; i++)
1618     {
1619         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1620 
1621         // Flatten <= X - Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat - Reshape
1622         if (node->op_type() == "Shape")
1623         {
1624             if (node_reference[node->output(0)] != 1)
1625                 continue;
1626 
1627             if (i + 6 >= node_count)
1628                 continue;
1629 
1630             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1631             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1632             onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
1633             onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
1634             onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
1635             onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
1636 
1637             if (node2->op_type() != "Gather" || node3->op_type() != "Constant" || node4->op_type() != "Unsqueeze" || node5->op_type() != "Unsqueeze"
1638                     || node6->op_type() != "Concat" || node7->op_type() != "Reshape")
1639                 continue;
1640 
1641             if (node_reference[node2->output(0)] != 1)
1642                 continue;
1643 
1644             //             if (node_reference[node3->output(0)] != 1)
1645             //                 continue;
1646 
1647             if (node_reference[node4->output(0)] != 1)
1648                 continue;
1649 
1650             if (node_reference[node5->output(0)] != 1)
1651                 continue;
1652 
1653             if (node_reference[node6->output(0)] != 1)
1654                 continue;
1655 
1656             if (node2->input(0) != node->output(0) || node4->input(0) != node2->output(0) || node5->input(0) != node3->output(0)
1657                     || node6->input(0) != node4->output(0) || node6->input(1) != node5->output(0)
1658                     || node7->input(0) != node->input(0) || node7->input(1) != node6->output(0))
1659                 continue;
1660 
1661             // axis = 0
1662             int gather_axis = get_node_attr_i(*node2, "axis");
1663             if (gather_axis != 0)
1664                 continue;
1665 
1666             // indices = 0
1667             if (weights.find(node2->input(1)) == weights.end())
1668                 continue;
1669 
1670             std::vector<int> gather_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
1671             if (gather_indices.size() != 1 || gather_indices[0] != 0)
1672                 continue;
1673 
1674             // axes = (0)
1675             std::vector<int> unsqueeze_axes = get_node_attr_ai(*node4, "axes");
1676             if (unsqueeze_axes.size() != 1)
1677                 continue;
1678             if (unsqueeze_axes[0] != 0)
1679                 continue;
1680 
1681             // axes = (0)
1682             std::vector<int> unsqueeze2_axes = get_node_attr_ai(*node5, "axes");
1683             if (unsqueeze2_axes.size() != 1)
1684                 continue;
1685             if (unsqueeze2_axes[0] != 0)
1686                 continue;
1687 
1688             // data = -1
1689             if (weights.find(node5->input(0)) == weights.end())
1690                 continue;
1691 
1692             std::vector<int> unsqueeze2_data = get_node_attr_from_input_ai(weights[node5->input(0)]);
1693             if (unsqueeze2_data.size() != 1 || unsqueeze2_data[0] != -1)
1694                 continue;
1695 
1696             // axis = 0
1697             int concat_axis = get_node_attr_i(*node6, "axis");
1698             if (concat_axis != 0)
1699                 continue;
1700 
1701             // reduce
1702             node->set_op_type("noop_reducedncnn");
1703             node2->set_op_type("noop_reducedncnn");
1704             //             node3->set_op_type("noop_reducedncnn");
1705             node4->set_op_type("noop_reducedncnn");
1706             node5->set_op_type("noop_reducedncnn");
1707             node6->set_op_type("noop_reducedncnn");
1708 
1709             node_reference[node->input(0)] -= 1;
1710             node_reference[node->output(0)] -= 1;
1711             node_reference[node2->input(1)] -= 1;
1712             node_reference[node2->output(0)] -= 1;
1713             //             node_reference[node3->output(0)] -= 1;
1714             node_reference[node4->output(0)] -= 1;
1715             node_reference[node5->input(0)] -= 1;
1716             node_reference[node5->output(0)] -= 1;
1717             node_reference[node6->output(0)] -= 1;
1718 
1719             blob_names.erase(node->output(0));
1720             blob_names.erase(node2->output(0));
1721             //             blob_names.erase(node3->output(0));
1722             blob_names.erase(node4->output(0));
1723             blob_names.erase(node5->output(0));
1724             blob_names.erase(node6->output(0));
1725 
1726             node7->set_op_type("Flatten");
1727             node7->clear_input();
1728             node7->add_input(node->input(0));
1729 
1730             reduced_node_count += 5;
1731             i += 5;
1732         }
1733     }
1734 }
1735 
fuse_pixelshuffle(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1736 static void fuse_pixelshuffle(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1737 {
1738     int node_count = mutable_graph->node_size();
1739     for (int i = 0; i < node_count; i++)
1740     {
1741         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1742 
1743         // PixelShuffle <= Reshape - Transpose - Reshape
1744         // PixelShuffle <= Reshape - Transpose - Constant - Reshape
1745         if (node->op_type() == "Reshape")
1746         {
1747             if (node_reference[node->output(0)] != 1)
1748                 continue;
1749 
1750             std::vector<int> shape;
1751             if (node->input_size() == 1)
1752             {
1753                 shape = get_node_attr_ai(*node, "shape");
1754             }
1755             else
1756             {
1757                 // skip weight reshape
1758                 if (weights.find(node->input(1)) == weights.end())
1759                     continue;
1760 
1761                 shape = get_node_attr_from_input_ai(weights[node->input(1)]);
1762             }
1763 
1764             // -1, 3, upscale_factor, upscale_factor, height, width
1765             if (shape.size() != 6)
1766                 continue;
1767 
1768             if (shape[0] != 1 && shape[0] != -1)
1769                 continue;
1770 
1771             if (shape[2] != shape[3])
1772                 continue;
1773 
1774             if (i + 2 >= node_count)
1775                 continue;
1776 
1777             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1778             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1779 
1780             if (node3->op_type() == "Constant")
1781             {
1782                 if (i + 3 >= node_count)
1783                     continue;
1784 
1785                 node3 = mutable_graph->mutable_node(i + 3);
1786             }
1787 
1788             if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
1789                 continue;
1790 
1791             if (node_reference[node2->output(0)] != 1)
1792                 continue;
1793 
1794             // 0 1 4 2 5 3
1795             std::vector<int> perm = get_node_attr_ai(*node2, "perm");
1796             if (perm.size() != 6)
1797                 continue;
1798 
1799             if (perm[0] != 0 || perm[1] != 1 || perm[2] != 4 || perm[3] != 2 || perm[4] != 5 || perm[5] != 3)
1800                 continue;
1801 
1802             std::vector<int> shape3;
1803             if (node3->input_size() == 1)
1804             {
1805                 shape3 = get_node_attr_ai(*node3, "shape");
1806             }
1807             else
1808             {
1809                 // skip weight reshape
1810                 if (weights.find(node3->input(1)) == weights.end())
1811                     continue;
1812 
1813                 shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
1814             }
1815 
1816             // -1, 3, height, width
1817             if (shape3.size() != 4)
1818                 continue;
1819 
1820             if (shape3[0] != 1 && shape3[0] != -1)
1821                 continue;
1822 
1823             if (shape3[1] != shape[1] || shape3[2] != shape[2] * shape[4] || shape3[3] != shape[3] * shape[5])
1824                 continue;
1825 
1826             // reduce
1827             node->set_op_type("noop_reducedncnn");
1828             node2->set_op_type("noop_reducedncnn");
1829 
1830             if (node->input_size() == 2)
1831             {
1832                 node_reference[node->input(1)] -= 1;
1833             }
1834             node_reference[node->output(0)] -= 1;
1835             node_reference[node2->output(0)] -= 1;
1836             if (node3->input_size() == 2)
1837             {
1838                 node_reference[node3->input(1)] -= 1;
1839             }
1840 
1841             blob_names.erase(node->output(0));
1842             blob_names.erase(node2->output(0));
1843 
1844             node3->set_op_type("PixelShuffle");
1845             node3->set_input(0, node->input(0));
1846 
1847             onnx::AttributeProto* attr_group = node3->add_attribute();
1848             attr_group->set_name("scale_factor");
1849             attr_group->set_i(shape[2]);
1850 
1851             reduced_node_count += 2;
1852             i += 2;
1853         }
1854     }
1855 }
1856 
fuse_reorg(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1857 static void fuse_reorg(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1858 {
1859     int node_count = mutable_graph->node_size();
1860     for (int i = 0; i < node_count; i++)
1861     {
1862         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1863 
1864         // PixelShuffle <= Reshape - Transpose - Reshape
1865         // PixelShuffle <= Reshape - Transpose - Constant - Reshape
1866         if (node->op_type() == "Reshape")
1867         {
1868             if (node_reference[node->output(0)] != 1)
1869                 continue;
1870 
1871             std::vector<int> shape;
1872             if (node->input_size() == 1)
1873             {
1874                 shape = get_node_attr_ai(*node, "shape");
1875             }
1876             else
1877             {
1878                 // skip weight reshape
1879                 if (weights.find(node->input(1)) == weights.end())
1880                     continue;
1881 
1882                 shape = get_node_attr_from_input_ai(weights[node->input(1)]);
1883             }
1884 
1885             // -1, 3, out_height, block_size, out_width, block_size
1886             if (shape.size() != 6)
1887                 continue;
1888 
1889             if (shape[0] != 1 && shape[0] != -1)
1890                 continue;
1891 
1892             if (shape[3] != shape[5])
1893                 continue;
1894 
1895             if (i + 2 >= node_count)
1896                 continue;
1897 
1898             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1899             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1900 
1901             if (node3->op_type() == "Constant")
1902             {
1903                 if (i + 3 >= node_count)
1904                     continue;
1905 
1906                 node3 = mutable_graph->mutable_node(i + 3);
1907             }
1908 
1909             if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
1910                 continue;
1911 
1912             if (node_reference[node2->output(0)] != 1)
1913                 continue;
1914 
1915             // 0 1 3 5 2 4
1916             std::vector<int> perm = get_node_attr_ai(*node2, "perm");
1917             if (perm.size() != 6)
1918                 continue;
1919 
1920             if (perm[0] != 0 || perm[1] != 1 || perm[2] != 3 || perm[3] != 5 || perm[4] != 2 || perm[5] != 4)
1921                 continue;
1922 
1923             std::vector<int> shape3;
1924             if (node3->input_size() == 1)
1925             {
1926                 shape3 = get_node_attr_ai(*node3, "shape");
1927             }
1928             else
1929             {
1930                 // skip weight reshape
1931                 if (weights.find(node3->input(1)) == weights.end())
1932                     continue;
1933 
1934                 shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
1935             }
1936 
1937             // -1, out_channels, out_height, out_width
1938             if (shape3.size() != 4)
1939                 continue;
1940 
1941             if (shape3[0] != 1 && shape3[0] != -1)
1942                 continue;
1943 
1944             if (shape3[1] != shape[1] * shape[3] * shape[5] || shape3[2] != shape[2] || shape3[3] != shape[4])
1945                 continue;
1946 
1947             // reduce
1948             node->set_op_type("noop_reducedncnn");
1949             node2->set_op_type("noop_reducedncnn");
1950 
1951             if (node->input_size() == 2)
1952             {
1953                 node_reference[node->input(1)] -= 1;
1954             }
1955             node_reference[node->output(0)] -= 1;
1956             node_reference[node2->output(0)] -= 1;
1957             if (node3->input_size() == 2)
1958             {
1959                 node_reference[node3->input(1)] -= 1;
1960             }
1961 
1962             blob_names.erase(node->output(0));
1963             blob_names.erase(node2->output(0));
1964 
1965             node3->set_op_type("Reorg");
1966             node3->set_input(0, node->input(0));
1967 
1968             onnx::AttributeProto* attr_group = node3->add_attribute();
1969             attr_group->set_name("stride");
1970             attr_group->set_i(shape[3]);
1971 
1972             reduced_node_count += 2;
1973             i += 2;
1974         }
1975     }
1976 }
1977 
fuse_expand_broadcast(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1978 static void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1979 {
1980     int node_count = mutable_graph->node_size();
1981     for (int i = 0; i < node_count; i++)
1982     {
1983         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1984 
1985         // Add/Sub/Mul/Div/Min/Max <= Expand - Add/Sub/Mul/Div/Min/Max
1986         if (node->op_type() == "Expand")
1987         {
1988             if (node_reference[node->output(0)] != 1)
1989                 continue;
1990 
1991             if (i + 1 >= node_count)
1992                 continue;
1993 
1994             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1995 
1996             if (node2->op_type() != "Add" && node2->op_type() != "Sub" && node2->op_type() != "Mul" && node2->op_type() != "Div" && node2->op_type() != "Min" && node2->op_type() != "Max")
1997                 continue;
1998 
1999             if (node2->input(1) != node->output(0) && node2->input(0) != node->output(0))
2000                 continue;
2001 
2002             // reduce
2003             node->set_op_type("noop_reducedncnn");
2004 
2005             node_reference[node->output(0)] -= 1;
2006             if (node->input_size() == 2)
2007             {
2008                 node_reference[node->input(1)] -= 1;
2009             }
2010 
2011             blob_names.erase(node->output(0));
2012 
2013             node2->set_input(1, node->input(0));
2014 
2015             reduced_node_count += 1;
2016             i += 1;
2017         }
2018     }
2019 }
2020 
fuse_lstm_gru_rnn(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)2021 static void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
2022 {
2023     int node_count = mutable_graph->node_size();
2024     for (int i = 0; i < node_count; i++)
2025     {
2026         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2027 
2028         // LSTM(bi) <= LSTM(bi) - Transpose - Reshape - Transpose
2029         if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
2030         {
2031             if (node_reference[node->output(0)] != 1)
2032                 continue;
2033 
2034             if (i + 2 >= node_count)
2035                 continue;
2036 
2037             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2038             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
2039 
2040             if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
2041                 continue;
2042 
2043             if (node_reference[node2->output(0)] != 1)
2044                 continue;
2045 
2046             if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0))
2047                 continue;
2048 
2049             std::string direction = get_node_attr_s(*node, "direction");
2050             if (direction != "bidirectional")
2051                 continue;
2052 
2053             // 0 2 1 3
2054             std::vector<int> perm = get_node_attr_ai(*node2, "perm");
2055             if (perm.size() != 4)
2056                 continue;
2057 
2058             if (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3)
2059                 continue;
2060 
2061             std::vector<int> shape;
2062             if (node3->input_size() == 1)
2063             {
2064                 shape = get_node_attr_ai(*node3, "shape");
2065             }
2066             else
2067             {
2068                 // skip weight reshape
2069                 if (weights.find(node3->input(1)) == weights.end())
2070                     continue;
2071 
2072                 shape = get_node_attr_from_input_ai(weights[node3->input(1)]);
2073             }
2074 
2075             // 0 0 -1
2076             if (shape.size() != 3)
2077                 continue;
2078 
2079             if (shape[0] != 0 || shape[1] != 0 || shape[2] != -1)
2080                 continue;
2081 
2082             // reduce
2083             node2->set_op_type("noop_reducedncnn");
2084             node3->set_op_type("noop_reducedncnn");
2085 
2086             node_reference[node->output(0)] -= 1;
2087             node_reference[node2->output(0)] -= 1;
2088             if (node3->input_size() == 2)
2089             {
2090                 node_reference[node3->input(1)] -= 1;
2091             }
2092 
2093             blob_names.erase(node->output(0));
2094             blob_names.erase(node2->output(0));
2095 
2096             node->set_output(0, node3->output(0));
2097 
2098             reduced_node_count += 2;
2099             i += 2;
2100 
2101             if (i + 1 < node_count)
2102             {
2103                 if (node_reference[node3->output(0)] != 1)
2104                     continue;
2105 
2106                 onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 1);
2107 
2108                 if (node4->op_type() != "Transpose")
2109                     continue;
2110 
2111                 if (node4->input(0) != node->output(0))
2112                     continue;
2113 
2114                 // 1 0 2
2115                 std::vector<int> perm4 = get_node_attr_ai(*node4, "perm");
2116                 if (perm4.size() != 3)
2117                     continue;
2118 
2119                 if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2)
2120                     continue;
2121 
2122                 // reduce
2123                 node4->set_op_type("noop_reducedncnn");
2124 
2125                 node_reference[node->output(0)] -= 1;
2126 
2127                 blob_names.erase(node->output(0));
2128 
2129                 node->clear_output();
2130                 node->add_output(node4->output(0));
2131 
2132                 reduced_node_count += 1;
2133                 i += 1;
2134             }
2135         }
2136     }
2137 
2138     for (int i = 0; i < node_count; i++)
2139     {
2140         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2141 
2142         // LSTM(uni) <= LSTM(uni) - Squeeze - Transpose
2143         if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
2144         {
2145             if (node_reference[node->output(0)] != 1)
2146                 continue;
2147 
2148             if (i + 1 >= node_count)
2149                 continue;
2150 
2151             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2152 
2153             if (node2->op_type() != "Squeeze")
2154                 continue;
2155 
2156             if (node2->input(0) != node->output(0))
2157                 continue;
2158 
2159             std::string direction = get_node_attr_s(*node, "direction");
2160             if (direction == "bidirectional")
2161                 continue;
2162 
2163             // 1
2164             std::vector<int> axes = get_node_attr_ai(*node2, "axes");
2165             if (axes.size() != 1)
2166                 continue;
2167 
2168             if (axes[0] != 1)
2169                 continue;
2170 
2171             // reduce
2172             node2->set_op_type("noop_reducedncnn");
2173 
2174             node_reference[node->output(0)] -= 1;
2175 
2176             blob_names.erase(node->output(0));
2177 
2178             node->set_output(0, node2->output(0));
2179 
2180             reduced_node_count += 1;
2181             i += 1;
2182 
2183             if (i + 1 < node_count)
2184             {
2185                 if (node_reference[node2->output(0)] != 1)
2186                     continue;
2187 
2188                 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 1);
2189 
2190                 if (node3->op_type() != "Transpose")
2191                     continue;
2192 
2193                 if (node3->input(0) != node->output(0))
2194                     continue;
2195 
2196                 // 1 0 2
2197                 std::vector<int> perm4 = get_node_attr_ai(*node3, "perm");
2198                 if (perm4.size() != 3)
2199                     continue;
2200 
2201                 if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2)
2202                     continue;
2203 
2204                 // reduce
2205                 node3->set_op_type("noop_reducedncnn");
2206 
2207                 node_reference[node->output(0)] -= 1;
2208 
2209                 blob_names.erase(node->output(0));
2210 
2211                 node->clear_output();
2212                 node->add_output(node3->output(0));
2213 
2214                 reduced_node_count += 1;
2215                 i += 1;
2216             }
2217         }
2218     }
2219 
2220     for (int i = 0; i < node_count; i++)
2221     {
2222         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2223 
2224         // LSTM <= Transpose - LSTM
2225         if (node->op_type() == "Transpose")
2226         {
2227             if (node_reference[node->output(0)] != 1)
2228                 continue;
2229 
2230             // 1 0 2
2231             std::vector<int> perm = get_node_attr_ai(*node, "perm");
2232             if (perm.size() != 3)
2233                 continue;
2234 
2235             if (perm[0] != 1 || perm[1] != 0 || perm[2] != 2)
2236                 continue;
2237 
2238             if (i + 1 >= node_count)
2239                 continue;
2240 
2241             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2242 
2243             if (node2->op_type() != "LSTM" && node->op_type() != "GRU" && node->op_type() != "RNN")
2244                 continue;
2245 
2246             if (node2->input(0) != node->output(0))
2247                 continue;
2248 
2249             // reduce
2250             node->set_op_type("noop_reducedncnn");
2251 
2252             node_reference[node->output(0)] -= 1;
2253 
2254             blob_names.erase(node->output(0));
2255 
2256             node2->set_input(0, node->input(0));
2257 
2258             reduced_node_count += 1;
2259             i += 1;
2260         }
2261     }
2262 }
2263 
fuse_multiheadattention(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)2264 static void fuse_multiheadattention(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
2265 {
2266     int node_count = mutable_graph->node_size();
2267     for (int i = 0; i < node_count; i++)
2268     {
2269         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2270 
2271         // MultiHeadAttention <= MatMul(q) - Add
2272         //                      - MatMul(k) - Add
2273         //                      - MatMul(v) - Add
2274         //                      - Mul
2275         //                      - Reshape - Transpose
2276         //                      - Reshape - Reshape - Transpose - Transpose
2277         //                      - Gemm - Softmax - Gemm - Transpose - Reshape - MatMul - Add
2278         if (node->op_type() == "MatMul")
2279         {
2280             if (i + 19 >= node_count)
2281                 continue;
2282 
2283             if (node_reference[node->output(0)] != 1)
2284                 continue;
2285 
2286             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2287             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
2288             onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
2289             onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
2290             onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
2291             onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
2292             onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
2293             onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
2294             onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9);
2295             onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10);
2296             onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11);
2297             onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12);
2298             onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13);
2299             onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14);
2300             onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15);
2301             onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16);
2302             onnx::NodeProto* node18 = mutable_graph->mutable_node(i + 17);
2303             onnx::NodeProto* node19 = mutable_graph->mutable_node(i + 18);
2304             onnx::NodeProto* node20 = mutable_graph->mutable_node(i + 19);
2305 
2306             if (node2->op_type() != "Add" || node3->op_type() != "MatMul" || node4->op_type() != "Add" || node5->op_type() != "MatMul" || node6->op_type() != "Add" || node7->op_type() != "Mul" || node8->op_type() != "Reshape" || node9->op_type() != "Transpose" || node10->op_type() != "Reshape" || node11->op_type() != "Reshape" || node12->op_type() != "Transpose" || node13->op_type() != "Transpose" || node14->op_type() != "MatMul" || node15->op_type() != "Softmax" || node16->op_type() != "MatMul" || node17->op_type() != "Transpose" || node18->op_type() != "Reshape" || node19->op_type() != "MatMul" || node20->op_type() != "Add")
2307                 continue;
2308 
2309             if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 || node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 || node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 || node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 || node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 || node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 || node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 || node_reference[node16->output(0)] != 1 || node_reference[node17->output(0)] != 1 || node_reference[node18->output(0)] != 1 || node_reference[node19->output(0)] != 1)
2310                 continue;
2311 
2312             if (node2->input(0) != node->output(0) || node4->input(0) != node3->output(0) || node6->input(0) != node5->output(0) || node7->input(0) != node2->output(0) || node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0) || node10->input(0) != node4->output(0) || node11->input(0) != node6->output(0) || node12->input(0) != node11->output(0) || node13->input(0) != node10->output(0) || node14->input(0) != node9->output(0) || node14->input(1) != node13->output(0) || node15->input(0) != node14->output(0) || node16->input(0) != node15->output(0) || node16->input(1) != node12->output(0) || node17->input(0) != node16->output(0) || node18->input(0) != node17->output(0) || node19->input(0) != node18->output(0) || node20->input(0) != node19->output(0))
2313                 continue;
2314 
2315             std::vector<float> q_B = get_node_attr_from_input_af(weights[node2->input(1)]);
2316             std::vector<float> k_B = get_node_attr_from_input_af(weights[node4->input(1)]);
2317             std::vector<float> v_B = get_node_attr_from_input_af(weights[node6->input(1)]);
2318             std::vector<float> o_B = get_node_attr_from_input_af(weights[node20->input(1)]);
2319 
2320             if (q_B.size() != k_B.size() || q_B.size() != v_B.size() || q_B.size() != o_B.size())
2321                 continue;
2322 
2323             int embed_dim = q_B.size();
2324 
2325             // 1 0 2
2326             std::vector<int> perm9 = get_node_attr_ai(*node9, "perm");
2327             std::vector<int> perm12 = get_node_attr_ai(*node12, "perm");
2328             if (perm9.size() != 3 || perm12.size() != 3)
2329                 continue;
2330 
2331             if (perm9[0] != 1 || perm9[1] != 0 || perm9[2] != 2 || perm12[0] != 1 || perm12[1] != 0 || perm12[2] != 2)
2332                 continue;
2333 
2334             // 1 2 0
2335             std::vector<int> perm13 = get_node_attr_ai(*node13, "perm");
2336             if (perm13.size() != 3)
2337                 continue;
2338 
2339             if (perm13[0] != 1 || perm13[1] != 2 || perm13[2] != 0)
2340                 continue;
2341 
2342             // 1 0 2
2343             std::vector<int> perm17 = get_node_attr_ai(*node17, "perm");
2344             if (perm17.size() != 3)
2345                 continue;
2346 
2347             if (perm17[0] != 1 || perm17[1] != 0 || perm17[2] != 2)
2348                 continue;
2349 
2350             int softmax_axis = get_node_attr_i(*node15, "axis");
2351             if (softmax_axis != 2)
2352                 continue;
2353 
2354             // 1/-1, seqlen * num_heads, embed_dim / num_heads
2355             std::vector<int> shape8;
2356             std::vector<int> shape10;
2357             std::vector<int> shape11;
2358             if (node8->input_size() == 1)
2359             {
2360                 shape8 = get_node_attr_ai(*node8, "shape");
2361             }
2362             else
2363             {
2364                 // skip weight reshape
2365                 if (weights.find(node8->input(1)) == weights.end())
2366                     continue;
2367 
2368                 shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]);
2369             }
2370             if (node10->input_size() == 1)
2371             {
2372                 shape10 = get_node_attr_ai(*node10, "shape");
2373             }
2374             else
2375             {
2376                 // skip weight reshape
2377                 if (weights.find(node10->input(1)) == weights.end())
2378                     continue;
2379 
2380                 shape10 = get_node_attr_from_input_ai(weights[node10->input(1)]);
2381             }
2382             if (node11->input_size() == 1)
2383             {
2384                 shape11 = get_node_attr_ai(*node11, "shape");
2385             }
2386             else
2387             {
2388                 // skip weight reshape
2389                 if (weights.find(node11->input(1)) == weights.end())
2390                     continue;
2391 
2392                 shape11 = get_node_attr_from_input_ai(weights[node11->input(1)]);
2393             }
2394 
2395             if (shape8.size() != 3 || shape10.size() != 3 || shape11.size() != 3)
2396                 continue;
2397 
2398             if (shape8[1] != shape10[1] || shape8[1] != shape11[1] || shape8[2] != shape10[2] || shape8[2] != shape11[2])
2399                 continue;
2400 
2401             int num_heads = embed_dim / shape8[2];
2402 
2403             // 1, seqlen, embed_dim
2404             std::vector<int> shape18;
2405             if (node18->input_size() == 1)
2406             {
2407                 shape18 = get_node_attr_ai(*node18, "shape");
2408             }
2409             else
2410             {
2411                 // skip weight reshape
2412                 if (weights.find(node18->input(1)) == weights.end())
2413                     continue;
2414 
2415                 shape18 = get_node_attr_from_input_ai(weights[node18->input(1)]);
2416             }
2417 
2418             if (shape18.size() != 3)
2419                 continue;
2420 
2421             if (shape18[2] != embed_dim || shape18[1] * num_heads != shape8[1])
2422                 continue;
2423 
2424             // reduce
2425             node->set_op_type("noop_reducedncnn");
2426             node2->set_op_type("noop_reducedncnn");
2427             node3->set_op_type("noop_reducedncnn");
2428             node4->set_op_type("noop_reducedncnn");
2429             node5->set_op_type("noop_reducedncnn");
2430             node6->set_op_type("noop_reducedncnn");
2431             node7->set_op_type("noop_reducedncnn");
2432             node8->set_op_type("noop_reducedncnn");
2433             node9->set_op_type("noop_reducedncnn");
2434             node10->set_op_type("noop_reducedncnn");
2435             node11->set_op_type("noop_reducedncnn");
2436             node12->set_op_type("noop_reducedncnn");
2437             node13->set_op_type("noop_reducedncnn");
2438             node14->set_op_type("noop_reducedncnn");
2439             node15->set_op_type("noop_reducedncnn");
2440             node16->set_op_type("noop_reducedncnn");
2441             node17->set_op_type("noop_reducedncnn");
2442             node18->set_op_type("noop_reducedncnn");
2443             node19->set_op_type("noop_reducedncnn");
2444 
2445             node_reference[node2->input(0)] -= 1;
2446             node_reference[node4->input(0)] -= 1;
2447             node_reference[node6->input(0)] -= 1;
2448             node_reference[node7->input(0)] -= 1;
2449             node_reference[node7->input(1)] -= 1;
2450             node_reference[node8->input(0)] -= 1;
2451             if (node8->input_size() == 2)
2452             {
2453                 node_reference[node8->input(1)] -= 1;
2454             }
2455             node_reference[node9->input(0)] -= 1;
2456             node_reference[node10->input(0)] -= 1;
2457             if (node10->input_size() == 2)
2458             {
2459                 node_reference[node10->input(1)] -= 1;
2460             }
2461             node_reference[node11->input(0)] -= 1;
2462             if (node11->input_size() == 2)
2463             {
2464                 node_reference[node11->input(1)] -= 1;
2465             }
2466             node_reference[node12->input(0)] -= 1;
2467             node_reference[node13->input(0)] -= 1;
2468             node_reference[node14->input(0)] -= 1;
2469             node_reference[node14->input(1)] -= 1;
2470             node_reference[node15->input(0)] -= 1;
2471             node_reference[node16->input(0)] -= 1;
2472             node_reference[node16->input(1)] -= 1;
2473             node_reference[node17->input(0)] -= 1;
2474             node_reference[node18->input(0)] -= 1;
2475             if (node18->input_size() == 2)
2476             {
2477                 node_reference[node18->input(1)] -= 1;
2478             }
2479             node_reference[node19->input(0)] -= 1;
2480             node_reference[node20->input(0)] -= 1;
2481 
2482             blob_names.erase(node->output(0));
2483             blob_names.erase(node2->output(0));
2484             blob_names.erase(node3->output(0));
2485             blob_names.erase(node4->output(0));
2486             blob_names.erase(node5->output(0));
2487             blob_names.erase(node6->output(0));
2488             blob_names.erase(node7->output(0));
2489             blob_names.erase(node8->output(0));
2490             blob_names.erase(node9->output(0));
2491             blob_names.erase(node10->output(0));
2492             blob_names.erase(node11->output(0));
2493             blob_names.erase(node12->output(0));
2494             blob_names.erase(node13->output(0));
2495             blob_names.erase(node14->output(0));
2496             blob_names.erase(node15->output(0));
2497             blob_names.erase(node16->output(0));
2498             blob_names.erase(node17->output(0));
2499             blob_names.erase(node18->output(0));
2500             blob_names.erase(node19->output(0));
2501 
2502             std::string qw = node->input(1);
2503             std::string qb = node2->input(1);
2504             std::string kw = node3->input(1);
2505             std::string kb = node4->input(1);
2506             std::string vw = node5->input(1);
2507             std::string vb = node6->input(1);
2508             std::string ow = node19->input(1);
2509             std::string ob = node20->input(1);
2510 
2511             node20->set_op_type("MultiHeadAttention");
2512             node20->clear_input();
2513             node20->add_input(node->input(0));
2514             node20->add_input(node3->input(0));
2515             node20->add_input(node5->input(0));
2516             // q
2517             node20->add_input(qw);
2518             node20->add_input(qb);
2519             // k
2520             node20->add_input(kw);
2521             node20->add_input(kb);
2522             // v
2523             node20->add_input(vw);
2524             node20->add_input(vb);
2525             // out linear
2526             node20->add_input(ow);
2527             node20->add_input(ob);
2528 
2529             onnx::AttributeProto* attr_embed_dim = node20->add_attribute();
2530             attr_embed_dim->set_name("embed_dim");
2531             attr_embed_dim->set_i(embed_dim);
2532 
2533             onnx::AttributeProto* attr_num_heads = node20->add_attribute();
2534             attr_num_heads->set_name("num_heads");
2535             attr_num_heads->set_i(num_heads);
2536 
2537             reduced_node_count += 19;
2538             i += 19;
2539         }
2540     }
2541 
    // second pattern: self-attention where q, k and v come from one packed
    // projection (a single MatMul + Add producing qkv, then a 3-way Split)
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // MultiHeadAttention <= MatMul(qkv) - Add - Split
        //                      - Mul
        //                      - Reshape - Transpose
        //                      - Reshape - Reshape - Transpose - Transpose
        //                      - Gemm - Softmax - Gemm - Transpose - Reshape - MatMul - Add
        if (node->op_type() == "MatMul")
        {
            // need a full 17-node window starting at this MatMul
            if (i + 16 >= node_count)
                continue;

            // the packed qkv blob must feed exactly one consumer
            if (node_reference[node->output(0)] != 1)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
            onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
            onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
            onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9);
            onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10);
            onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11);
            onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12);
            onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13);
            onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14);
            onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15);
            onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16);

            // op types must match the expected sequence exactly
            if (node2->op_type() != "Add" || node3->op_type() != "Split" || node4->op_type() != "Mul" || node5->op_type() != "Reshape" || node6->op_type() != "Transpose" || node7->op_type() != "Reshape" || node8->op_type() != "Reshape" || node9->op_type() != "Transpose" || node10->op_type() != "Transpose" || node11->op_type() != "MatMul" || node12->op_type() != "Softmax" || node13->op_type() != "MatMul" || node14->op_type() != "Transpose" || node15->op_type() != "Reshape" || node16->op_type() != "MatMul" || node17->op_type() != "Add")
                continue;

            // every intermediate blob (including all three Split outputs) must
            // have exactly one consumer, so collapsing the subgraph cannot
            // break any other node
            if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 || node_reference[node3->output(1)] != 1 || node_reference[node3->output(2)] != 1 || node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 || node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 || node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 || node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 || node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 || node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 || node_reference[node16->output(0)] != 1)
                continue;

            // verify the data-flow edges wire the nodes together as drawn above
            // (q path via Split output 0, k via output 1, v via output 2)
            if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0) || node6->input(0) != node5->output(0) || node7->input(0) != node3->output(1) || node8->input(0) != node3->output(2) || node9->input(0) != node8->output(0) || node10->input(0) != node7->output(0) || node11->input(0) != node6->output(0) || node11->input(1) != node10->output(0) || node12->input(0) != node11->output(0) || node13->input(0) != node12->output(0) || node13->input(1) != node9->output(0) || node14->input(0) != node13->output(0) || node15->input(0) != node14->output(0) || node16->input(0) != node15->output(0) || node17->input(0) != node16->output(0))
                continue;

            std::vector<float> qkv_B = get_node_attr_from_input_af(weights[node2->input(1)]);
            std::vector<float> o_B = get_node_attr_from_input_af(weights[node17->input(1)]);

            // packed qkv bias holds q, k and v biases back to back, so it is
            // exactly three times the output-projection bias length
            if (qkv_B.size() != o_B.size() * 3)
                continue;

            int embed_dim = o_B.size();

            // q and k paths must transpose with perm (1, 0, 2)
            std::vector<int> perm6 = get_node_attr_ai(*node6, "perm");
            std::vector<int> perm9 = get_node_attr_ai(*node9, "perm");
            if (perm6.size() != 3 || perm9.size() != 3)
                continue;

            if (perm6[0] != 1 || perm6[1] != 0 || perm6[2] != 2 || perm9[0] != 1 || perm9[1] != 0 || perm9[2] != 2)
                continue;

            // k^T path must transpose with perm (1, 2, 0)
            std::vector<int> perm10 = get_node_attr_ai(*node10, "perm");
            if (perm10.size() != 3)
                continue;

            if (perm10[0] != 1 || perm10[1] != 2 || perm10[2] != 0)
                continue;

            // attention-output path must transpose back with perm (1, 0, 2)
            std::vector<int> perm14 = get_node_attr_ai(*node14, "perm");
            if (perm14.size() != 3)
                continue;

            if (perm14[0] != 1 || perm14[1] != 0 || perm14[2] != 2)
                continue;

            // softmax must run over the last (score) axis
            int softmax_axis = get_node_attr_i(*node12, "axis");
            if (softmax_axis != 2)
                continue;

            // expected reshape target: 1/-1, seqlen * num_heads, embed_dim / num_heads
            // each Reshape carries its shape either as an attribute (single
            // input) or as a constant second input; bail out if the shape is
            // not a known constant
            std::vector<int> shape5;
            std::vector<int> shape7;
            std::vector<int> shape8;
            if (node5->input_size() == 1)
            {
                shape5 = get_node_attr_ai(*node5, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node5->input(1)) == weights.end())
                    continue;

                shape5 = get_node_attr_from_input_ai(weights[node5->input(1)]);
            }
            if (node7->input_size() == 1)
            {
                shape7 = get_node_attr_ai(*node7, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node7->input(1)) == weights.end())
                    continue;

                shape7 = get_node_attr_from_input_ai(weights[node7->input(1)]);
            }
            if (node8->input_size() == 1)
            {
                shape8 = get_node_attr_ai(*node8, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node8->input(1)) == weights.end())
                    continue;

                shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]);
            }

            if (shape5.size() != 3 || shape7.size() != 3 || shape8.size() != 3)
                continue;

            // q, k and v reshapes must agree on seqlen*num_heads and head size
            if (shape5[1] != shape7[1] || shape5[1] != shape8[1] || shape5[2] != shape7[2] || shape5[2] != shape8[2])
                continue;

            // shape[2] is the per-head dimension
            int num_heads = embed_dim / shape5[2];

            // final reshape target: 1, seqlen, embed_dim
            std::vector<int> shape15;
            if (node15->input_size() == 1)
            {
                shape15 = get_node_attr_ai(*node15, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node15->input(1)) == weights.end())
                    continue;

                shape15 = get_node_attr_from_input_ai(weights[node15->input(1)]);
            }

            if (shape15.size() != 3)
                continue;

            // final reshape must restore embed_dim and be consistent with the
            // per-head split (seqlen * num_heads)
            if (shape15[2] != embed_dim || shape15[1] * num_heads != shape8[1])
                continue;

            // pattern confirmed - neutralize all nodes except the last one,
            // which is rewritten in place as the fused MultiHeadAttention op
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");
            node7->set_op_type("noop_reducedncnn");
            node8->set_op_type("noop_reducedncnn");
            node9->set_op_type("noop_reducedncnn");
            node10->set_op_type("noop_reducedncnn");
            node11->set_op_type("noop_reducedncnn");
            node12->set_op_type("noop_reducedncnn");
            node13->set_op_type("noop_reducedncnn");
            node14->set_op_type("noop_reducedncnn");
            node15->set_op_type("noop_reducedncnn");
            node16->set_op_type("noop_reducedncnn");

            // drop the references the erased nodes held on their inputs
            // (constant shape inputs are optional, hence the input_size checks)
            node_reference[node2->input(0)] -= 1;
            node_reference[node3->input(0)] -= 1;
            node_reference[node4->input(0)] -= 1;
            node_reference[node4->input(1)] -= 1;
            node_reference[node5->input(0)] -= 1;
            if (node5->input_size() == 2)
            {
                node_reference[node5->input(1)] -= 1;
            }
            node_reference[node6->input(0)] -= 1;
            node_reference[node7->input(0)] -= 1;
            if (node7->input_size() == 2)
            {
                node_reference[node7->input(1)] -= 1;
            }
            node_reference[node8->input(0)] -= 1;
            if (node8->input_size() == 2)
            {
                node_reference[node8->input(1)] -= 1;
            }
            node_reference[node9->input(0)] -= 1;
            node_reference[node10->input(0)] -= 1;
            node_reference[node11->input(0)] -= 1;
            node_reference[node11->input(1)] -= 1;
            node_reference[node12->input(0)] -= 1;
            node_reference[node13->input(0)] -= 1;
            node_reference[node13->input(1)] -= 1;
            node_reference[node14->input(0)] -= 1;
            node_reference[node15->input(0)] -= 1;
            if (node15->input_size() == 2)
            {
                node_reference[node15->input(1)] -= 1;
            }
            node_reference[node16->input(0)] -= 1;
            node_reference[node17->input(0)] -= 1;

            // all intermediate blobs vanish from the graph
            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));
            blob_names.erase(node3->output(1));
            blob_names.erase(node3->output(2));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));
            blob_names.erase(node7->output(0));
            blob_names.erase(node8->output(0));
            blob_names.erase(node9->output(0));
            blob_names.erase(node10->output(0));
            blob_names.erase(node11->output(0));
            blob_names.erase(node12->output(0));
            blob_names.erase(node13->output(0));
            blob_names.erase(node14->output(0));
            blob_names.erase(node15->output(0));
            blob_names.erase(node16->output(0));

            std::string qkvw = node->input(1);
            std::string qkvb = node2->input(1);
            std::string ow = node16->input(1);
            std::string ob = node17->input(1);

            // rebuild node17 as the fused op: input, packed qkv weights, then
            // the output-projection weights
            node17->set_op_type("MultiHeadAttention");
            node17->clear_input();
            node17->add_input(node->input(0));
            // qkv
            node17->add_input(qkvw);
            node17->add_input(qkvb);
            // out linear
            node17->add_input(ow);
            node17->add_input(ob);

            onnx::AttributeProto* attr_embed_dim = node17->add_attribute();
            attr_embed_dim->set_name("embed_dim");
            attr_embed_dim->set_i(embed_dim);

            onnx::AttributeProto* attr_num_heads = node17->add_attribute();
            attr_num_heads->set_name("num_heads");
            attr_num_heads->set_i(num_heads);

            // skip past the 16 consumed nodes
            reduced_node_count += 16;
            i += 16;
        }
    }
2792 }
2793 
fuse_binaryop_with_scalar(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)2794 static void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
2795 {
2796     int node_count = mutable_graph->node_size();
2797     for (int i = 0; i < node_count; i++)
2798     {
2799         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2800 
2801         // Add/Sub/Mul/Div/Min/Max/Pow
2802         if (node->op_type() == "Add" || node->op_type() == "Sub" || node->op_type() == "Mul" || node->op_type() == "Div" || node->op_type() == "Max" || node->op_type() == "Min" || node->op_type() == "Pow")
2803         {
2804             if (weights.find(node->input(1)) == weights.end())
2805                 continue;
2806 
2807             const onnx::TensorProto& scalar_b = weights[node->input(1)];
2808             if (scalar_b.dims_size() != 0 || get_tensor_proto_data_size(scalar_b) != 1)
2809                 continue;
2810 
2811             float b = get_node_attr_from_input_f(scalar_b);
2812 
2813             node_reference[node->input(1)] -= 1;
2814 
2815             std::string input = node->input(0);
2816 
2817             node->clear_input();
2818             node->add_input(input);
2819 
2820             onnx::AttributeProto* attr_with_scalar = node->add_attribute();
2821             attr_with_scalar->set_name("with_scalar");
2822             attr_with_scalar->set_i(1);
2823 
2824             onnx::AttributeProto* attr_b = node->add_attribute();
2825             attr_b->set_name("b");
2826             attr_b->set_f(b);
2827         }
2828     }
2829 }
2830 
main(int argc,char ** argv)2831 int main(int argc, char** argv)
2832 {
2833     if (!(argc == 2 || argc == 4))
2834     {
2835         fprintf(stderr, "Usage: %s [onnxpb] [ncnnparam] [ncnnbin]\n", argv[0]);
2836         return -1;
2837     }
2838 
2839     const char* onnxpb = argv[1];
2840     const char* ncnn_prototxt = argc == 4 ? argv[2] : "ncnn.param";
2841     const char* ncnn_modelbin = argc == 4 ? argv[3] : "ncnn.bin";
2842 
2843     onnx::ModelProto model;
2844 
2845     // load
2846     bool s1 = read_proto_from_binary(onnxpb, &model);
2847     if (!s1)
2848     {
2849         fprintf(stderr, "read_proto_from_binary failed\n");
2850         return -1;
2851     }
2852 
2853     FILE* pp = fopen(ncnn_prototxt, "wb");
2854     FILE* bp = fopen(ncnn_modelbin, "wb");
2855 
2856     // magic
2857     fprintf(pp, "7767517\n");
2858 
2859     const onnx::GraphProto& graph = model.graph();
2860     onnx::GraphProto* mutable_graph = model.mutable_graph();
2861 
2862     int node_count = graph.node_size();
2863 
2864     // node reference
2865     std::map<std::string, int> node_reference;
2866 
2867     // weight node and weight reshape node
2868     std::map<std::string, onnx::TensorProto> weights;
2869 
2870     for (int j = 0; j < graph.initializer_size(); j++)
2871     {
2872         const onnx::TensorProto& initializer = graph.initializer(j);
2873 
2874         //         fprintf(stderr, "weight = %s %d\n", initializer.name().c_str(), initializer.data_type());
2875 
2876         weights[initializer.name()] = initializer;
2877     }
2878 
2879     // topological sort
2880     {
2881         // name -> producer node index
2882         std::set<std::string> producers;
2883         for (int j = 0; j < graph.input_size(); j++)
2884         {
2885             const std::string& input_name = graph.input(j).name();
2886             producers.insert(input_name);
2887         }
2888 
2889         for (int i = 0; i < node_count;)
2890         {
2891             onnx::NodeProto* node = mutable_graph->mutable_node(i);
2892 
2893             bool swapnode = false;
2894             std::string missing_input_name;
2895             for (int j = 0; j < (int)node->input_size(); j++)
2896             {
2897                 const std::string& input_name = node->input(j);
2898                 if (input_name.empty())
2899                     continue;
2900 
2901                 if (producers.find(input_name) == producers.end() && weights.find(input_name) == weights.end())
2902                 {
2903                     swapnode = true;
2904                     missing_input_name = input_name;
2905                     break;
2906                 }
2907             }
2908 
2909             if (!swapnode)
2910             {
2911                 for (int j = 0; j < (int)node->output_size(); j++)
2912                 {
2913                     const std::string& output_name = node->output(j);
2914                     if (output_name.empty())
2915                         continue;
2916 
2917                     producers.insert(output_name);
2918                 }
2919 
2920                 i++;
2921                 continue;
2922             }
2923 
2924             // find node that produce missing_input_name
2925             int q = i + 1;
2926             for (; q < node_count; q++)
2927             {
2928                 onnx::NodeProto* nodeq = mutable_graph->mutable_node(q);
2929                 bool found = false;
2930                 for (int j = 0; j < (int)nodeq->output_size(); j++)
2931                 {
2932                     const std::string& output_name = nodeq->output(j);
2933                     if (output_name == missing_input_name)
2934                     {
2935                         found = true;
2936                         break;
2937                     }
2938                 }
2939 
2940                 if (found)
2941                     break;
2942             }
2943 
2944             if (q == node_count)
2945             {
2946                 fprintf(stderr, "cannot find node produces %s but node %d requires it\n", missing_input_name.c_str(), i);
2947                 return -1;
2948             }
2949 
2950             // fprintf(stderr, "swap %d %d\n", i, q);
2951             // swap this node with q
2952             onnx::NodeProto* nodeq = mutable_graph->mutable_node(q);
2953             onnx::NodeProto tmp = *node;
2954             *node = *nodeq;
2955             *nodeq = tmp;
2956         }
2957     }
2958 
2959     // global definition line
2960     // [layer count] [blob count]
2961     std::set<std::string> blob_names;
2962     for (int i = 0; i < node_count; i++)
2963     {
2964         const onnx::NodeProto& node = graph.node(i);
2965 
2966         const std::string& op = node.op_type();
2967 
2968         std::string name = node.name();
2969         if (name.empty())
2970         {
2971             name = node.output(0);
2972         }
2973 
2974         if (op == "Constant")
2975         {
2976             onnx::TensorProto tensor = get_node_attr_tensor(node, "value");
2977             weights[node.output(0)] = tensor;
2978         }
2979 
2980         for (int j = 0; j < (int)node.input_size(); j++)
2981         {
2982             const std::string& input_name = node.input(j);
2983 
2984             blob_names.insert(input_name);
2985 
2986             if (node_reference.find(input_name) == node_reference.end())
2987             {
2988                 node_reference[input_name] = 1;
2989             }
2990             else
2991             {
2992                 node_reference[input_name] = node_reference[input_name] + 1;
2993             }
2994         }
2995 
2996         if (op == "Dropout")
2997         {
2998             const std::string& output_name = node.output(0);
2999             blob_names.insert(output_name);
3000             node_reference[output_name] = 0;
3001             continue;
3002         }
3003 
3004         for (int j = 0; j < (int)node.output_size(); j++)
3005         {
3006             const std::string& output_name = node.output(j);
3007 
3008             blob_names.insert(output_name);
3009 
3010             node_reference[output_name] = 0;
3011         }
3012     }
3013 
3014     // include Input node
3015     int input_node_count = 0;
3016     for (int j = 0; j < graph.input_size(); j++)
3017     {
3018         const std::string& input_name = graph.input(j).name();
3019 
3020         // check weight
3021         if (weights.find(input_name) != weights.end())
3022             continue;
3023 
3024         blob_names.insert(input_name);
3025 
3026         input_node_count++;
3027     }
3028 
3029     //     for (auto a: node_reference)
3030     //     {
3031     //         fprintf(stderr, "a = %s %d\n", a.first.c_str(), a.second);
3032     //     }
3033 
3034     // op chain fusion
3035     int reduced_node_count = 0;
3036     fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3037     fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3038     fuse_shufflechannel(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3039     fuse_shufflechannel_split(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3040     fuse_hardsigmoid(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3041     fuse_hardswish(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3042     fuse_swish(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3043     fuse_batchnorm1d_squeeze_unsqueeze(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3044     fuse_unsqueeze_prelu(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3045     fuse_normalize(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3046     fuse_groupnorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3047     fuse_layernorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3048     fuse_flatten(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3049     fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3050     fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3051     fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3052     fuse_lstm_gru_rnn(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3053     fuse_multiheadattention(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3054     fuse_binaryop_with_scalar(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3055 
    // reduce common const weight node_reference
    // each branch decrements the reference count of inputs that are constant
    // weights consumed directly by the layer (filters, biases, norm stats,
    // clip bounds, shapes, ...); a weight whose count drops to zero is later
    // baked into the layer's binary data instead of becoming a MemoryData layer
    for (int i = 0; i < node_count; i++)
    {
        const onnx::NodeProto& node = graph.node(i);

        const std::string& op = node.op_type();

        if (op == "BatchNormalization")
        {
            // inputs 1..4: scale, B, mean, var
            node_reference[node.input(1)] -= 1;
            node_reference[node.input(2)] -= 1;
            node_reference[node.input(3)] -= 1;
            node_reference[node.input(4)] -= 1;
        }
        else if (op == "BiasGelu")
        {
            node_reference[node.input(1)] -= 1;
        }
        else if (op == "Clip")
        {
            // min/max given as constant inputs (onnx opset 11+ style)
            if (node.input_size() == 3)
            {
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
            }
        }
        else if (op == "Conv")
        {
            // weight, plus optional bias
            node_reference[node.input(1)] -= 1;
            if (node.input_size() == 3)
            {
                node_reference[node.input(2)] -= 1;
            }
        }
        else if (op == "ConvTranspose")
        {
            // weight, plus optional bias
            node_reference[node.input(1)] -= 1;
            if (node.input_size() == 3)
            {
                node_reference[node.input(2)] -= 1;
            }
        }
        else if (op == "EmbedLayerNormalization")
        {
            node_reference[node.input(1)] -= 1;
            node_reference[node.input(2)] -= 1;
            node_reference[node.input(3)] -= 1;
            node_reference[node.input(4)] -= 1;
            node_reference[node.input(5)] -= 1;
            node_reference[node.input(6)] -= 1;
        }
        else if (op == "Gemm")
        {
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            float beta = get_node_attr_f(node, "beta", 1.f);
            int transA = get_node_attr_i(node, "transA", 0);
            int transB = get_node_attr_i(node, "transB", 0);

            if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
            {
                // InnerProduct-like A * B + C
                // only this form embeds B and C as layer weights; the general
                // Gemm keeps them as runtime blobs
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
            }
        }
        else if (op == "GroupNorm")
        {
            // gamma/beta only exist when affine
            int affine = get_node_attr_i(node, "affine", 1);
            if (affine)
            {
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
            }
        }
        else if (op == "GRU")
        {
            // all non-data inputs (W, R, B, ...) are weights
            for (int j = 1; j < node.input_size(); j++)
            {
                node_reference[node.input(j)] -= 1;
            }
        }
        else if (op == "InstanceNormalization")
        {
            node_reference[node.input(1)] -= 1;
            node_reference[node.input(2)] -= 1;
        }
        else if (op == "LayerNorm")
        {
            int affine = get_node_attr_i(node, "affine", 1);
            if (affine)
            {
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
            }
        }
        else if (op == "LSTM")
        {
            for (int j = 1; j < node.input_size(); j++)
            {
                node_reference[node.input(j)] -= 1;
            }
        }
        else if (op == "MatMul")
        {
            // only the constant-2D-B case becomes InnerProduct with baked weight
            if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
            {
                // InnerProduct
                node_reference[node.input(1)] -= 1;
            }
        }
        else if (op == "MultiHeadAttention")
        {
            // 5-input and 11-input variants carry their projection weights at
            // different positions -- presumably set up by fuse_multiheadattention
            if (node.input_size() == 5)
            {
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
                node_reference[node.input(3)] -= 1;
                node_reference[node.input(4)] -= 1;
            }
            else
            {
                node_reference[node.input(3)] -= 1;
                node_reference[node.input(4)] -= 1;
                node_reference[node.input(5)] -= 1;
                node_reference[node.input(6)] -= 1;
                node_reference[node.input(7)] -= 1;
                node_reference[node.input(8)] -= 1;
                node_reference[node.input(9)] -= 1;
                node_reference[node.input(10)] -= 1;
            }
        }
        else if (op == "Pad")
        {
            // pads tensor input
            if (node.input_size() >= 2)
            {
                node_reference[node.input(1)] -= 1;
            }
        }
        else if (op == "PRelu")
        {
            // slope
            node_reference[node.input(1)] -= 1;
        }
        else if (op == "Reshape")
        {
            // shape tensor input
            if (node.input_size() >= 2)
            {
                node_reference[node.input(1)] -= 1;
            }
        }
        else if (op == "Resize")
        {
            if (node.input_size() == 2)
            {
                // opset 10
                node_reference[node.input(1)] -= 1;
            }
            else
            {
                // opset 11+
                // roi, scales and optional sizes
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
                if (node.input_size() >= 4)
                {
                    node_reference[node.input(3)] -= 1;
                }
            }
        }
        else if (op == "RNN")
        {
            for (int j = 1; j < node.input_size(); j++)
            {
                node_reference[node.input(j)] -= 1;
            }
        }
        else if (op == "SkipLayerNormalization")
        {
            node_reference[node.input(2)] -= 1;
            node_reference[node.input(3)] -= 1;
            node_reference[node.input(4)] -= 1;
        }
        else if (op == "Slice")
        {
            // starts, ends, optional axes, optional steps
            if (node.input_size() >= 2)
            {
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
                if (node.input_size() >= 4)
                    node_reference[node.input(3)] -= 1;
                if (node.input_size() >= 5)
                    node_reference[node.input(4)] -= 1;
            }
        }
        else if (op == "Upsample")
        {
            // scales
            if (node.input_size() >= 2)
            {
                node_reference[node.input(1)] -= 1;
            }
        }
        else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
        {
            // target output size
            if (node.input_size() >= 2)
            {
                node_reference[node.input(1)] -= 1;
            }
        }
    }
3263 
    //         for (auto a: node_reference)
    //         {
    //             fprintf(stderr, "b = %s %d\n", a.first.c_str(), a.second);
    //         }

    // count all weight node with zero reference
    // these weights get embedded into layer binaries, so they need neither a
    // MemoryData layer nor a blob slot in the emitted counts
    int zero_reference_weight_node_count = 0;
    for (std::map<std::string, onnx::TensorProto>::iterator it = weights.begin(); it != weights.end(); it++)
    {
        const std::string& input_name = it->first;

        int refcount = node_reference[input_name];
        if (refcount == 0)
            zero_reference_weight_node_count++;
    }

    // we always treat constant node as weight or binaryop_weights
    // do not count it twice for layer_count
    int constant_node_count_moved_to_weight = 0;
    for (int i = 0; i < node_count; i++)
    {
        const onnx::NodeProto& node = graph.node(i);

        const std::string& op = node.op_type();

        if (op == "Constant")
        {
            constant_node_count_moved_to_weight++;
        }
    }

    // some op may have anonymous input
    // LSTM sequence_lens
    blob_names.erase("");
    node_reference.erase("");

    // remove node_reference entry with reference equals to one
    // blobs consumed more than once need an ncnn Split layer that produces one
    // output copy per consumer; record those and tally layer/blob deltas
    int split_layer_count = 0;
    int splitncnn_blob_count = 0;
    // split node reference
    std::map<std::string, int> split_node_reference;
    for (std::map<std::string, int>::iterator it = node_reference.begin(); it != node_reference.end(); it++)
    {
        if (it->second > 1)
        {
            split_layer_count++;
            splitncnn_blob_count += it->second;

            split_node_reference[it->first] = it->second;
        }
    }
3315 
    // param file layer/blob counts:
    //   layers = onnx nodes - Constant nodes folded into weights
    //            + weights emitted as MemoryData (non-zero refcount)
    //            - nodes removed by fusion + Input layers + Split layers
    //   blobs  = blob names - embedded weights + extra split output copies
    fprintf(pp, "%zu %zu\n", node_count - constant_node_count_moved_to_weight + weights.size() - zero_reference_weight_node_count - reduced_node_count + input_node_count + split_layer_count, blob_names.size() - zero_reference_weight_node_count + splitncnn_blob_count);

    int internal_split = 0;

    // place Input at the beginning
    for (int j = 0; j < graph.input_size(); j++)
    {
        const std::string& input_name = graph.input(j).name();

        // check weight
        // graph "inputs" that are initializers are weights, not real inputs
        if (weights.find(input_name) != weights.end())
            continue;

        fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", input_name.c_str(), input_name.c_str());

        int refcount = node_reference[input_name];
        if (refcount <= 1)
        {
            continue;
        }

        // input consumed more than once: emit a Split layer producing one
        // "<name>_splitncnn_N" copy per consumer
        char splitname[256];
        sprintf(splitname, "splitncnn_input%d", j);
        fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
        fprintf(pp, " %s", input_name.c_str());

        for (int k = 0; k < refcount; k++)
        {
            fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
        }
        fprintf(pp, "\n");
    }
3348 
    // place MemoryData next
    // every weight still referenced as a runtime blob becomes a MemoryData
    // layer: shape goes to the param file (pp), raw data to the bin file (bp)
    for (std::map<std::string, onnx::TensorProto>::iterator weight_it = weights.begin(); weight_it != weights.end(); weight_it++)
    {
        const std::string& input_name = weight_it->first;

        // zero refcount: weight already embedded into a layer's binary data
        int refcount = node_reference[input_name];
        if (refcount == 0)
        {
            continue;
        }

        fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", input_name.c_str(), input_name.c_str());

        const onnx::TensorProto& M = weights[input_name];

        // dims are written innermost-first (0=w 1=h 2=c); a leading dim is
        // dropped when 1 (2-d/3-d) or always for 4-d -- presumably treated
        // as a batch dimension, TODO confirm
        if (M.dims_size() == 0)
        {
            fprintf(pp, " 0=%d", get_tensor_proto_data_size(M));
        }
        else if (M.dims_size() == 1)
        {
            fprintf(pp, " 0=%d", (int)M.dims(0));
        }
        else if (M.dims_size() == 2)
        {
            fprintf(pp, " 0=%d", (int)M.dims(1));
            if (M.dims(0) != 1)
            {
                fprintf(pp, " 1=%d", (int)M.dims(0));
            }
        }
        else if (M.dims_size() == 3)
        {
            fprintf(pp, " 0=%d", (int)M.dims(2));
            fprintf(pp, " 1=%d", (int)M.dims(1));
            if (M.dims(0) != 1)
            {
                fprintf(pp, " 2=%d", (int)M.dims(0));
            }
        }
        else if (M.dims_size() == 4)
        {
            fprintf(pp, " 0=%d", (int)M.dims(3));
            fprintf(pp, " 1=%d", (int)M.dims(2));
            fprintf(pp, " 2=%d", (int)M.dims(1));
        }

        fprintf(pp, "\n");

        fwrite_tensor_proto_data(M, bp);

        if (refcount <= 1)
        {
            continue;
        }

        // weight blob consumed more than once: emit a Split layer, same
        // naming scheme as for multi-consumer inputs
        char splitname[256];
        sprintf(splitname, "splitncnn_%d", internal_split);
        fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);

        fprintf(pp, " %s", input_name.c_str());

        for (int k = 0; k < refcount; k++)
        {
            fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
        }
        fprintf(pp, "\n");

        internal_split++;
    }
3419 
    // emit one ncnn layer per remaining onnx node
    for (int i = 0; i < node_count; i++)
    {
        const onnx::NodeProto& node = graph.node(i);

        const std::string& op = node.op_type();

        //         fprintf(stderr, "op = %s\n", op.c_str());

        // node was absorbed by one of the fusion passes: emit nothing
        if (op == "noop_reducedncnn")
        {
            continue;
        }

        // fall back to the first output blob name when the node is unnamed
        std::string name = node.name();
        if (name.empty())
        {
            name = node.output(0);
        }

        int input_size = node.input_size();
        int output_size = node.output_size();

        for (int j = 0; j < (int)node.input_size(); j++)
        {
            const std::string& input_name = node.input(j);

            // check weight
            // weights with zero refcount are baked into the layer's binary
            // data and are not listed as bottom blobs
            if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0)
            {
                input_size--;
            }

            // anonymous (empty) inputs, e.g. LSTM sequence_lens, are dropped
            if (input_name.empty())
            {
                input_size--;
            }

            //             fprintf(stderr, "  input = %s\n", input_name.c_str());
        }
        /*
        for (int j=0; j<(int)node.output_size(); j++)
        {
            const std::string& output_name = node.output(j);
            fprintf(stderr, "  output = %s\n", output_name.c_str());
        }
        */
3466 
        // map the onnx op type to the ncnn layer type (first param column);
        // the layer-specific parameters are written further below
        if (op == "Abs")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Acos")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Add")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "Asin")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Atan")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "AveragePool" || op == "MaxPool")
        {
            fprintf(pp, "%-16s", "Pooling");
        }
        else if (op == "BatchNormalization")
        {
            fprintf(pp, "%-16s", "BatchNorm");
        }
        else if (op == "BiasGelu")
        {
            fprintf(pp, "%-16s", "BiasGelu");
        }
        else if (op == "Ceil")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Clip")
        {
            fprintf(pp, "%-16s", "Clip");
        }
        else if (op == "Concat")
        {
            fprintf(pp, "%-16s", "Concat");
        }
        else if (op == "Constant")
        {
            // constants were folded into weights earlier; emit no layer
            continue;
        }
        else if (op == "Conv")
        {
            // grouped convolution maps to the depthwise variant
            int group = get_node_attr_i(node, "group", 1);
            if (group > 1)
            {
                fprintf(pp, "%-16s", "ConvolutionDepthWise");
            }
            else
            {
                fprintf(pp, "%-16s", "Convolution");
            }
        }
        else if (op == "ConvTranspose")
        {
            int group = get_node_attr_i(node, "group", 1);
            if (group > 1)
            {
                fprintf(pp, "%-16s", "DeconvolutionDepthWise");
            }
            else
            {
                fprintf(pp, "%-16s", "Deconvolution");
            }
        }
        else if (op == "Cos")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "DepthToSpace")
        {
            fprintf(pp, "%-16s", "PixelShuffle");
        }
        else if (op == "Div")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "Dropout")
        {
            fprintf(pp, "%-16s", "Dropout");
            // keep only the data output, ignore the optional mask output
            output_size = 1;
        }
        else if (op == "Elu")
        {
            fprintf(pp, "%-16s", "ELU");
        }
        else if (op == "EmbedLayerNormalization")
        {
            fprintf(pp, "%-16s", "EmbedLayerNormalization");
        }
        else if (op == "Exp")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Flatten")
        {
            fprintf(pp, "%-16s", "Flatten");
        }
        else if (op == "Floor")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Gemm")
        {
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            float beta = get_node_attr_f(node, "beta", 1.f);
            int transA = get_node_attr_i(node, "transA", 0);
            int transB = get_node_attr_i(node, "transB", 0);

            // must match the weight-embedding decision made in the
            // node_reference reduction pass above
            if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
            {
                // InnerProduct-like A * B + C
                fprintf(pp, "%-16s", "InnerProduct");
            }
            else
            {
                fprintf(pp, "%-16s", "Gemm");
            }
        }
        else if (op == "GlobalAveragePool")
        {
            fprintf(pp, "%-16s", "Pooling");
        }
        else if (op == "GlobalMaxPool")
        {
            fprintf(pp, "%-16s", "Pooling");
        }
        else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
        {
            fprintf(pp, "%-16s", "Pooling");
        }
        else if (op == "GroupNorm")
        {
            fprintf(pp, "%-16s", "GroupNorm");
        }
        else if (op == "GRU")
        {
            fprintf(pp, "%-16s", "GRU");
        }
        else if (op == "HardSigmoid")
        {
            fprintf(pp, "%-16s", "HardSigmoid");
        }
        else if (op == "HardSwish")
        {
            fprintf(pp, "%-16s", "HardSwish");
        }
        else if (op == "ImageScaler")
        {
            fprintf(pp, "%-16s", "Scale");
        }
        else if (op == "InstanceNormalization")
        {
            fprintf(pp, "%-16s", "InstanceNorm");
        }
        else if (op == "LayerNorm")
        {
            fprintf(pp, "%-16s", "LayerNorm");
        }
        else if (op == "LeakyRelu")
        {
            // LeakyRelu becomes ReLU with a slope parameter
            fprintf(pp, "%-16s", "ReLU");
        }
        else if (op == "Log")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "LRN")
        {
            fprintf(pp, "%-16s", "LRN");
        }
        else if (op == "LSTM")
        {
            fprintf(pp, "%-16s", "LSTM");
        }
        else if (op == "MatMul")
        {
            // constant 2-D B operand: plain fully-connected layer; otherwise
            // a generic runtime matrix multiply
            if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
            {
                fprintf(pp, "%-16s", "InnerProduct");
            }
            else
            {
                fprintf(pp, "%-16s", "Gemm");
            }
        }
        else if (op == "Max")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "Min")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "Mul")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "MultiHeadAttention")
        {
            fprintf(pp, "%-16s", "MultiHeadAttention");
        }
        else if (op == "Neg")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Normalize")
        {
            fprintf(pp, "%-16s", "Normalize");
        }
        else if (op == "Pad")
        {
            fprintf(pp, "%-16s", "Padding");
        }
        else if (op == "PixelShuffle")
        {
            fprintf(pp, "%-16s", "PixelShuffle");
        }
        else if (op == "Pow")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "PRelu")
        {
            fprintf(pp, "%-16s", "PReLU");
        }
        else if (op == "Reciprocal")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp")
        {
            // all onnx reductions share the ncnn Reduction layer; the variant
            // is selected via parameters written later
            fprintf(pp, "%-16s", "Reduction");
        }
        else if (op == "Relu")
        {
            fprintf(pp, "%-16s", "ReLU");
        }
        else if (op == "Reorg")
        {
            fprintf(pp, "%-16s", "Reorg");
        }
        else if (op == "Reshape")
        {
            fprintf(pp, "%-16s", "Reshape");
        }
        else if (op == "RNN")
        {
            fprintf(pp, "%-16s", "RNN");
        }
        else if (op == "ShuffleChannel")
        {
            fprintf(pp, "%-16s", "ShuffleChannel");
        }
        else if (op == "Sigmoid")
        {
            fprintf(pp, "%-16s", "Sigmoid");
        }
        else if (op == "Sin")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "SkipLayerNormalization")
        {
            fprintf(pp, "%-16s", "SkipLayerNormalization");
        }
        else if (op == "Slice")
        {
            fprintf(pp, "%-16s", "Crop");
        }
        else if (op == "Softmax")
        {
            fprintf(pp, "%-16s", "Softmax");
        }
        else if (op == "Softplus")
        {
            fprintf(pp, "%-16s", "Softplus");
        }
        else if (op == "Split")
        {
            // note: onnx Split (chunking) maps to ncnn Slice, while onnx
            // Slice maps to ncnn Crop above -- the names do not line up
            fprintf(pp, "%-16s", "Slice");
        }
        else if (op == "Sqrt")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Squeeze")
        {
            fprintf(pp, "%-16s", "Squeeze");
        }
        else if (op == "Sub")
        {
            fprintf(pp, "%-16s", "BinaryOp");
        }
        else if (op == "Sum")
        {
            fprintf(pp, "%-16s", "Eltwise");
        }
        else if (op == "Swish")
        {
            fprintf(pp, "%-16s", "Swish");
        }
        else if (op == "Tan")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Tanh")
        {
            fprintf(pp, "%-16s", "UnaryOp");
        }
        else if (op == "Transpose")
        {
            fprintf(pp, "%-16s", "Permute");
        }
        else if (op == "Upsample" || op == "Resize")
        {
            fprintf(pp, "%-16s", "Interp");
        }
        else if (op == "Unsqueeze")
        {
            fprintf(pp, "%-16s", "ExpandDims");
        }
        else
        {
            // TODO
            // unknown op: warn, but still emit the raw op name so the param
            // file stays structurally complete
            fprintf(stderr, "%s not supported yet!\n", op.c_str());
            fprintf(pp, "%-16s", op.c_str());
        }
3802 
        // layer name, bottom (input) blob count, top (output) blob count
        fprintf(pp, " %-24s %d %d", name.c_str(), input_size, output_size);

        for (int j = 0; j < (int)node.input_size(); j++)
        {
            std::string input_name = node.input(j);

            // check weight
            // skip weights baked into the layer binary (mirrors the
            // input_size adjustment made earlier)
            if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0)
            {
                continue;
            }

            // skip anonymous inputs
            if (input_name.empty())
            {
                continue;
            }

            // multi-consumer blobs are read through their Split layer
            // outputs: consume one "<name>_splitncnn_N" copy per use,
            // counting the remaining references down
            if (split_node_reference.find(input_name) != split_node_reference.end())
            {
                int refidx = split_node_reference[input_name] - 1;
                split_node_reference[input_name] = refidx;

                char splitsuffix[256];
                sprintf(splitsuffix, "_splitncnn_%d", refidx);
                input_name = input_name + splitsuffix;
            }

            fprintf(pp, " %s", input_name.c_str());
        }

        for (int j = 0; j < output_size; j++)
        {
            const std::string& output_name = node.output(j);

            fprintf(pp, " %s", output_name.c_str());
        }
3839 
        // layer-specific parameters (id=value pairs appended to the layer line)
        if (op == "Abs")
        {
            // UnaryOp operation type for abs
            int op_type = 0;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Acos")
        {
            // UnaryOp operation type for acos
            int op_type = 13;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Add")
        {
            // BinaryOp operation type for add
            int op_type = 0;
            fprintf(pp, " 0=%d", op_type);

            // scalar operand fused into the op -- presumably set by
            // fuse_binaryop_with_scalar
            int with_scalar = get_node_attr_i(node, "with_scalar", 0);
            float b = get_node_attr_f(node, "b", 0.f);
            if (with_scalar)
            {
                fprintf(pp, " 1=%d", with_scalar);
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "Asin")
        {
            int op_type = 12;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Atan")
        {
            int op_type = 14;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "AveragePool" || op == "MaxPool")
        {
            std::string auto_pad = get_node_attr_s(node, "auto_pad");
            int ceil_mode = get_node_attr_i(node, "ceil_mode", 0);
            std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
            std::vector<int> strides = get_node_attr_ai(node, "strides");
            std::vector<int> pads = get_node_attr_ai(node, "pads");

            // 0 = max pooling, 1 = average pooling
            int pool = op == "AveragePool" ? 1 : 0;
            // pad mode: 1 default, 2/3 for SAME_UPPER/SAME_LOWER auto_pad,
            // 0 when ceil_mode is requested
            int pad_mode = 1;

            if (auto_pad == "SAME_UPPER")
            {
                pad_mode = 2;
            }
            else if (auto_pad == "SAME_LOWER")
            {
                pad_mode = 3;
            }

            if (ceil_mode == 1)
            {
                pad_mode = 0;
            }

            fprintf(pp, " 0=%d", pool);

            // onnx spatial attributes are ordered [h, w]; ncnn takes the
            // width value in the base id and the height in the 1x id
            if (kernel_shape.size() == 1)
            {
                fprintf(pp, " 1=%d", kernel_shape[0]);
            }
            else if (kernel_shape.size() == 2)
            {
                fprintf(pp, " 1=%d", kernel_shape[1]);
                fprintf(pp, " 11=%d", kernel_shape[0]);
            }

            if (strides.size() == 1)
            {
                fprintf(pp, " 2=%d", strides[0]);
            }
            else if (strides.size() == 2)
            {
                fprintf(pp, " 2=%d", strides[1]);
                fprintf(pp, " 12=%d", strides[0]);
            }

            // onnx pads order is [begin..., end...]; 4-element form maps to
            // left/top/right/bottom
            if (pads.size() == 1)
            {
                fprintf(pp, " 3=%d", pads[0]);
            }
            else if (pads.size() == 2)
            {
                fprintf(pp, " 3=%d", pads[1]);
                fprintf(pp, " 13=%d", pads[0]);
            }
            else if (pads.size() == 4)
            {
                fprintf(pp, " 3=%d", pads[1]);
                fprintf(pp, " 13=%d", pads[0]);
                fprintf(pp, " 14=%d", pads[3]);
                fprintf(pp, " 15=%d", pads[2]);
            }

            fprintf(pp, " 5=%d", pad_mode);

            if (op == "AveragePool")
            {
                int avgpool_count_include_pad = get_node_attr_i(node, "count_include_pad", 0);
                fprintf(pp, " 6=%d", avgpool_count_include_pad);
            }
        }
        else if (op == "BatchNormalization")
        {
            float epsilon = get_node_attr_f(node, "epsilon", 1e-5f);

            const onnx::TensorProto& scale = weights[node.input(1)];
            const onnx::TensorProto& B = weights[node.input(2)];
            const onnx::TensorProto& mean = weights[node.input(3)];
            const onnx::TensorProto& var = weights[node.input(4)];

            int channels = get_tensor_proto_data_size(scale);

            fprintf(pp, " 0=%d", channels);

            // binary layout: scale, mean, var(+epsilon), bias
            fwrite_tensor_proto_data(scale, bp);
            fwrite_tensor_proto_data(mean, bp);
            // apply epsilon to var
            // epsilon is folded into the stored variance so it need not be
            // added again at inference time
            {
                const float* v = var.has_raw_data() ? (const float*)var.raw_data().data() : var.float_data().data();

                for (int j = 0; j < channels; j++)
                {
                    float ve = v[j] + epsilon;
                    fwrite(&ve, sizeof(float), 1, bp);
                }
            }
            fwrite_tensor_proto_data(B, bp);
        }
        else if (op == "BiasGelu")
        {
            const onnx::TensorProto& B = weights[node.input(1)];

            fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));

            // leading int 0 marks unquantized float32 data
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B, bp);
        }
        else if (op == "Ceil")
        {
            // UnaryOp operation type for ceil
            int op_type = 3;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Clip")
        {
            // min/max come from attributes (older opsets) or from constant
            // inputs (opset 11+ style); missing bounds default to +/-FLT_MAX
            float min;
            float max;
            if (node.input_size() == 1)
            {
                min = get_node_attr_f(node, "min", -FLT_MAX);
                max = get_node_attr_f(node, "max", FLT_MAX);
            }
            else
            {
                min = weights.find(node.input(1)) != weights.end() ? get_node_attr_from_input_f(weights[node.input(1)]) : -FLT_MAX;
                max = weights.find(node.input(2)) != weights.end() ? get_node_attr_from_input_f(weights[node.input(2)]) : FLT_MAX;
            }

            fprintf(pp, " 0=%e", min);
            fprintf(pp, " 1=%e", max);
        }
4006         else if (op == "Concat")
4007         {
4008             int axis = get_node_attr_i(node, "axis", 1);
4009             fprintf(pp, " 0=%d", axis - 1);
4010         }
4011         else if (op == "Constant")
4012         {
4013             // never reach here
4014         }
        else if (op == "Conv")
        {
            // Convolution: input 1 is the weight tensor W.
            // W.dims(0) is the output channel count (num_filter).
            const onnx::TensorProto& W = weights[node.input(1)];

            int num_filter = W.dims(0);
            // optional input 2 is the bias
            int has_bias = node.input_size() == 3 ? 1 : 0;

            std::string auto_pad = get_node_attr_s(node, "auto_pad");
            std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
            std::vector<int> dilations = get_node_attr_ai(node, "dilations");
            std::vector<int> strides = get_node_attr_ai(node, "strides");
            std::vector<int> pads = get_node_attr_ai(node, "pads");
            int group = get_node_attr_i(node, "group", 1);

            // param 0 = num_output
            fprintf(pp, " 0=%d", num_filter);

            // ONNX spatial attribute order is [h, w]; ncnn writes w to the base
            // id and h to the id+10 variant. A 1-element vector means 1D conv.
            if (kernel_shape.size() == 1)
            {
                fprintf(pp, " 1=%d", kernel_shape[0]);
            }
            else if (kernel_shape.size() == 2)
            {
                fprintf(pp, " 1=%d", kernel_shape[1]);   // kernel_w
                fprintf(pp, " 11=%d", kernel_shape[0]);  // kernel_h
            }

            if (dilations.size() == 1)
            {
                fprintf(pp, " 2=%d", dilations[0]);
            }
            else if (dilations.size() == 2)
            {
                fprintf(pp, " 2=%d", dilations[1]);   // dilation_w
                fprintf(pp, " 12=%d", dilations[0]);  // dilation_h
            }

            if (strides.size() == 1)
            {
                fprintf(pp, " 3=%d", strides[0]);
            }
            else if (strides.size() == 2)
            {
                fprintf(pp, " 3=%d", strides[1]);   // stride_w
                fprintf(pp, " 13=%d", strides[0]);  // stride_h
            }

            // -233 / -234 are ncnn's magic pad values for SAME_UPPER / SAME_LOWER
            if (auto_pad == "SAME_UPPER")
            {
                fprintf(pp, " 4=-233");
            }
            else if (auto_pad == "SAME_LOWER")
            {
                fprintf(pp, " 4=-234");
            }
            else
            {
                // ONNX pads order is [top, left, bottom, right] for 4 values
                if (pads.size() == 1)
                {
                    fprintf(pp, " 4=%d", pads[0]);
                }
                else if (pads.size() == 2)
                {
                    fprintf(pp, " 4=%d", pads[1]);   // pad_left
                    fprintf(pp, " 14=%d", pads[0]);  // pad_top
                }
                else if (pads.size() == 4)
                {
                    fprintf(pp, " 4=%d", pads[1]);   // pad_left
                    fprintf(pp, " 14=%d", pads[0]);  // pad_top
                    fprintf(pp, " 15=%d", pads[3]);  // pad_right
                    fprintf(pp, " 16=%d", pads[2]);  // pad_bottom
                }
            }

            fprintf(pp, " 5=%d", has_bias);

            // param 6 = total weight element count
            fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));

            if (group > 1)
            {
                fprintf(pp, " 7=%d", group);
            }

            // leading int flag before the weight payload (0 = plain fp32 data)
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(W, bp);

            if (has_bias)
            {
                const onnx::TensorProto& B = weights[node.input(2)];
                fwrite_tensor_proto_data(B, bp);
            }
        }
4109         else if (op == "ConvTranspose")
4110         {
4111             const onnx::TensorProto& W = weights[node.input(1)];
4112 
4113             int has_bias = node.input_size() == 3 ? 1 : 0;
4114 
4115             std::string auto_pad = get_node_attr_s(node, "auto_pad");
4116             std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
4117             std::vector<int> dilations = get_node_attr_ai(node, "dilations");
4118             std::vector<int> strides = get_node_attr_ai(node, "strides");
4119             std::vector<int> output_padding = get_node_attr_ai(node, "output_padding");
4120             std::vector<int> output_shape = get_node_attr_ai(node, "output_shape");
4121             std::vector<int> pads = get_node_attr_ai(node, "pads");
4122             int group = get_node_attr_i(node, "group", 1);
4123             int num_filter = W.dims(1) * group;
4124 
4125             fprintf(pp, " 0=%d", num_filter);
4126 
4127             if (kernel_shape.size() == 1)
4128             {
4129                 fprintf(pp, " 1=%d", kernel_shape[0]);
4130             }
4131             else if (kernel_shape.size() == 2)
4132             {
4133                 fprintf(pp, " 1=%d", kernel_shape[1]);
4134                 fprintf(pp, " 11=%d", kernel_shape[0]);
4135             }
4136 
4137             if (dilations.size() == 1)
4138             {
4139                 fprintf(pp, " 2=%d", dilations[0]);
4140             }
4141             else if (dilations.size() == 2)
4142             {
4143                 fprintf(pp, " 2=%d", dilations[1]);
4144                 fprintf(pp, " 12=%d", dilations[0]);
4145             }
4146 
4147             if (strides.size() == 1)
4148             {
4149                 fprintf(pp, " 3=%d", strides[0]);
4150             }
4151             else if (strides.size() == 2)
4152             {
4153                 fprintf(pp, " 3=%d", strides[1]);
4154                 fprintf(pp, " 13=%d", strides[0]);
4155             }
4156 
4157             if (auto_pad == "SAME_UPPER")
4158             {
4159                 fprintf(pp, " 4=-233");
4160             }
4161             else if (auto_pad == "SAME_LOWER")
4162             {
4163                 fprintf(pp, " 4=-234");
4164             }
4165             else
4166             {
4167                 if (pads.size() == 1)
4168                 {
4169                     fprintf(pp, " 4=%d", pads[0]);
4170                 }
4171                 else if (pads.size() == 2)
4172                 {
4173                     fprintf(pp, " 4=%d", pads[1]);
4174                     fprintf(pp, " 14=%d", pads[0]);
4175                 }
4176                 else if (pads.size() == 4)
4177                 {
4178                     fprintf(pp, " 4=%d", pads[1]);
4179                     fprintf(pp, " 14=%d", pads[0]);
4180                     fprintf(pp, " 15=%d", pads[3]);
4181                     fprintf(pp, " 16=%d", pads[2]);
4182                 }
4183             }
4184 
4185             if (output_padding.size() == 1)
4186             {
4187                 fprintf(pp, " 18=%d", output_padding[0]);
4188             }
4189             else if (output_padding.size() == 2)
4190             {
4191                 fprintf(pp, " 18=%d", output_padding[1]);
4192                 fprintf(pp, " 19=%d", output_padding[0]);
4193             }
4194 
4195             if (output_shape.size() == 1)
4196             {
4197                 fprintf(pp, " 20=%d", output_shape[0]);
4198             }
4199             else if (output_shape.size() == 2)
4200             {
4201                 fprintf(pp, " 20=%d", output_shape[1]);
4202                 fprintf(pp, " 21=%d", output_shape[0]);
4203             }
4204 
4205             fprintf(pp, " 5=%d", has_bias);
4206 
4207             fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));
4208 
4209             if (group > 1)
4210             {
4211                 fprintf(pp, " 7=%d", group);
4212             }
4213 
4214             int quantize_tag = 0;
4215             fwrite(&quantize_tag, sizeof(int), 1, bp);
4216 
4217             int maxk = 0;
4218             if (kernel_shape.size() == 2)
4219             {
4220                 maxk = kernel_shape[1] * kernel_shape[0];
4221             }
4222             else
4223             {
4224                 maxk = kernel_shape[0] * kernel_shape[0];
4225             }
4226             int weight_data_size = get_tensor_proto_data_size(W);
4227             const float* weight_data = 0;
4228             if (W.has_raw_data())
4229             {
4230                 weight_data = (const float*)W.raw_data().data();
4231             }
4232             else if (W.data_type() == 1)
4233             {
4234                 weight_data = W.float_data().data();
4235             }
4236             for (int g = 0; g < group; g++)
4237             {
4238                 // reorder weight from inch-outch to outch-inch
4239                 int num_filter_g = num_filter / group;
4240                 int num_input = weight_data_size / maxk / num_filter_g / group;
4241                 const float* weight_data_ptr = weight_data + g * maxk * num_filter_g * num_input;
4242                 for (int k = 0; k < num_filter_g; k++)
4243                 {
4244                     for (int j = 0; j < num_input; j++)
4245                     {
4246                         fwrite(weight_data_ptr + (j * num_filter_g + k) * maxk, sizeof(float), maxk, bp);
4247                     }
4248                 }
4249             }
4250 
4251             if (has_bias)
4252             {
4253                 const onnx::TensorProto& B = weights[node.input(2)];
4254                 fwrite_tensor_proto_data(B, bp);
4255             }
4256         }
        else if (op == "Cos")
        {
            // UnaryOp: op_type 10 selects cos
            int op_type = 10;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "DepthToSpace")
        {
            // pixelshuffle
            int scale_factor = get_node_attr_i(node, "blocksize", 1);
            std::string mode = get_node_attr_s(node, "mode");
            fprintf(pp, " 0=%d", scale_factor);
            // param 1 = mode: 0 for CRD (column-row-depth), 1 for DCR
            if (mode == "CRD")
            {
                fprintf(pp, " 1=0");
            }
            else if (mode == "DCR")
            {
                fprintf(pp, " 1=1");
            }
        }
        else if (op == "Div")
        {
            // BinaryOp: op_type 3 selects division
            int op_type = 3;
            fprintf(pp, " 0=%d", op_type);

            // with_scalar/b are synthetic attributes (presumably set by an
            // earlier constant-folding pass) describing "tensor / scalar b"
            int with_scalar = get_node_attr_i(node, "with_scalar", 0);
            float b = get_node_attr_f(node, "b", 0.f);
            if (with_scalar)
            {
                fprintf(pp, " 1=%d", with_scalar);
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "Dropout")
        {
            // Dropout is identity at inference time
            // no-op
        }
        else if (op == "Elu")
        {
            // param 0 = alpha coefficient of the ELU
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            fprintf(pp, " 0=%e", alpha);
        }
        else if (op == "EmbedLayerNormalization")
        {
            // Fused embedding + layernorm: word embedding (input 2),
            // position embedding (input 3), layernorm gamma W (input 5)
            // and beta B (input 6).
            const onnx::TensorProto& words = weights[node.input(2)];
            const onnx::TensorProto& positions = weights[node.input(3)];
            const onnx::TensorProto& W = weights[node.input(5)];
            const onnx::TensorProto& B = weights[node.input(6)];

            fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
            fprintf(pp, " 1=%d", get_tensor_proto_data_size(words));
            fprintf(pp, " 2=%d", get_tensor_proto_data_size(positions));

            // each weight blob is preceded by an int flag (0 = plain fp32 data)
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(words, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(positions, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(W, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B, bp);
        }
        else if (op == "Exp")
        {
            // UnaryOp: op_type 7 selects exp
            int op_type = 7;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Flatten")
        {
            // Only axis==1 maps onto ncnn's Flatten; other axes are rejected
            // with a warning (conversion continues, output may be wrong).
            int axis = get_node_attr_i(node, "axis", 1);
            if (axis != 1)
            {
                fprintf(stderr, "Unsupported Flatten axis %d!\n", axis);
            }
        }
        else if (op == "Floor")
        {
            // UnaryOp: op_type 2 selects floor
            int op_type = 2;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Gemm")
        {
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            float beta = get_node_attr_f(node, "beta", 1.f);
            int transA = get_node_attr_i(node, "transA", 0);
            int transB = get_node_attr_i(node, "transB", 0);

            // The common fully-connected pattern (y = x*B^T + C) becomes an
            // InnerProduct layer with baked-in weights; anything else is
            // emitted as a generic Gemm layer with runtime inputs.
            if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
            {
                // InnerProduct-like A * B + C
                const onnx::TensorProto& B = weights[node.input(1)];
                const onnx::TensorProto& C = weights[node.input(2)];

                fprintf(pp, " 0=%d", get_tensor_proto_data_size(C));  // num_output
                fprintf(pp, " 1=1");                                   // bias_term
                fprintf(pp, " 2=%d", get_tensor_proto_data_size(B));  // weight size

                // leading int flag before the weight payload (0 = plain fp32 data)
                int quantize_tag = 0;
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                fwrite_tensor_proto_data(B, bp);
                fwrite_tensor_proto_data(C, bp);
            }
            else
            {
                // gemm
                fprintf(pp, " 0=%e", alpha);
                fprintf(pp, " 1=%e", beta);
                fprintf(pp, " 2=%d", transA);
                fprintf(pp, " 3=%d", transB);
            }
        }
        else if (op == "GlobalAveragePool")
        {
            // Pooling: pool 1 = average, with global flag set
            int pool = 1;
            int global_pool = 1;

            fprintf(pp, " 0=%d", pool);
            fprintf(pp, " 4=%d", global_pool);
        }
        else if (op == "GlobalMaxPool")
        {
            // Pooling: pool 0 = max, with global flag set
            int pool = 0;
            int global_pool = 1;

            fprintf(pp, " 0=%d", pool);
            fprintf(pp, " 4=%d", global_pool);
        }
        else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
        {
            // Adaptive pooling (torch-exported op): pool 1 = average, 0 = max.
            int pool = 0;
            if (op == "adaptive_avg_pool2d")
            {
                pool = 1;
            }
            int adaptive_pooling = 1;
            // input 1 is a constant tensor holding the target output shape
            const onnx::TensorProto& out_shape_tp = weights[node.input(1)];
            std::vector<int> out_shape = get_node_attr_from_input_ai(out_shape_tp);

            fprintf(pp, " 0=%d", pool);
            fprintf(pp, " 7=%d", adaptive_pooling);
            if (out_shape.size() == 1)
            {
                fprintf(pp, " 8=%d", out_shape[0]);
            }
            else if (out_shape.size() == 2)
            {
                // out_w
                fprintf(pp, " 8=%d", out_shape[1]);
                // out_h
                fprintf(pp, " 18=%d", out_shape[0]);
            }
        }
        else if (op == "GroupNorm")
        {
            int groups = get_node_attr_i(node, "groups", 1);
            int channels = get_node_attr_i(node, "channels", 1);
            float eps = get_node_attr_f(node, "epsilon", 1e-5f);
            int affine = get_node_attr_i(node, "affine", 1);

            if (affine)
            {
                // discard affine-less S=1 B=0
                std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
                std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
                if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && affine_B[0] == 0.f)
                {
                    affine = 0;
                }
                else
                {
                    // re-enable affine only if some scale != 1 or bias != 0
                    affine = 0;
                    {
                        // NOTE(review): loop bound is 'channels' (an attribute),
                        // not affine_S.size() -- if they ever disagree this reads
                        // out of bounds; confirm upstream guarantees they match.
                        for (int j = 0; j < channels; j++)
                        {
                            if (affine_S[j] != 1.f || affine_B[j] != 0.f)
                            {
                                affine = 1;
                                break;
                            }
                        }
                    }
                }
            }

            fprintf(pp, " 0=%d", groups);
            fprintf(pp, " 1=%d", channels);
            fprintf(pp, " 2=%e", eps);
            fprintf(pp, " 3=%d", affine);
            if (affine)
            {
                // gamma then beta, raw fp32
                const onnx::TensorProto& scale = weights[node.input(1)];
                const onnx::TensorProto& B = weights[node.input(2)];

                fwrite_tensor_proto_data(scale, bp);
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "GRU")
        {
            // GRU weights: W = input-hidden weights (input 1),
            // R = hidden-hidden weights (input 2), B = biases (input 3).
            // ONNX stores gates in U,R,N order; ncnn expects R,U,N, so every
            // blob below is re-sliced accordingly before writing.
            const onnx::TensorProto& W = weights[node.input(1)];
            const onnx::TensorProto& R = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];

            int hidden_size = get_node_attr_i(node, "hidden_size", 0);
            std::string direction = get_node_attr_s(node, "direction");

            // 0 = forward, 1 = reverse, 2 = bidirectional
            int direction_type = 0;
            if (direction == "forward")
            {
                direction_type = 0;
            }
            else if (direction == "reverse")
            {
                direction_type = 1;
            }
            else if (direction == "bidirectional")
            {
                direction_type = 2;
            }

            int weight_data_size = get_tensor_proto_data_size(W);

            fprintf(pp, " 0=%d", hidden_size);
            fprintf(pp, " 1=%d", weight_data_size);
            fprintf(pp, " 2=%d", direction_type);

            int num_directions = direction_type == 2 ? 2 : 1;

            // each weight blob is preceded by an int flag (0 = plain fp32 data)
            int quantize_tag = 0;

            // reorder num_directions-URN-hidden-size to num_directions-RUN-hidden-size
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // per-gate slice size within one direction
                int weight_data_size_g = get_tensor_proto_data_size(W) / 3 / num_directions;
                const float* wptr = W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data();

                const float* uptr = wptr;
                const float* rptr = wptr + weight_data_size_g;
                const float* nptr = wptr + weight_data_size_g * 2;
                fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                fwrite(nptr, sizeof(float), weight_data_size_g, bp);

                if (direction_type == 2)
                {
                    // advance to the reverse direction's three gates
                    uptr += weight_data_size_g * 3;
                    rptr += weight_data_size_g * 3;
                    nptr += weight_data_size_g * 3;
                    fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(nptr, sizeof(float), weight_data_size_g, bp);
                }
            }

            // reduce U and R bias except N
            // reorder num_directions-URN-hidden to num_directions-RUN-hidden
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // ONNX B holds Wb (input bias) and Rb (recurrent bias): 2 sets x 3 gates
                int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 3 / num_directions;
                const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
                const float* wuptr = bptr;
                const float* wrptr = bptr + bias_data_size_g;
                const float* wnptr = bptr + bias_data_size_g * 2;
                const float* buptr = bptr + bias_data_size_g * 3;
                const float* brptr = bptr + bias_data_size_g * 4;
                const float* bnptr = bptr + bias_data_size_g * 5;

                // R and U gate biases can be pre-summed (Wb + Rb); the N gate's
                // two biases must stay separate, so they are written back-to-back
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = wrptr[j] + brptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = wuptr[j] + buptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                fwrite(wnptr, sizeof(float), bias_data_size_g, bp);
                fwrite(bnptr, sizeof(float), bias_data_size_g, bp);

                if (direction_type == 2)
                {
                    // advance to the reverse direction's bias block
                    wuptr += bias_data_size_g * 6;
                    wrptr += bias_data_size_g * 6;
                    wnptr += bias_data_size_g * 6;
                    buptr += bias_data_size_g * 6;
                    brptr += bias_data_size_g * 6;
                    bnptr += bias_data_size_g * 6;

                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = wrptr[j] + brptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = wuptr[j] + buptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    fwrite(wnptr, sizeof(float), bias_data_size_g, bp);
                    fwrite(bnptr, sizeof(float), bias_data_size_g, bp);
                }
            }

            // reorder num_directions-URN-hidden-hidden to num_directions-RUN-hidden-hidden
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                int weight_data_size_g = get_tensor_proto_data_size(R) / 3 / num_directions;
                const float* Rptr = R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data();

                const float* uptr = Rptr;
                const float* rptr = Rptr + weight_data_size_g;
                const float* nptr = Rptr + weight_data_size_g * 2;
                fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                fwrite(nptr, sizeof(float), weight_data_size_g, bp);

                if (direction_type == 2)
                {
                    // advance to the reverse direction's three gates
                    uptr += weight_data_size_g * 3;
                    rptr += weight_data_size_g * 3;
                    nptr += weight_data_size_g * 3;
                    fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(nptr, sizeof(float), weight_data_size_g, bp);
                }
            }
        }
        else if (op == "HardSigmoid")
        {
            // y = clamp(alpha * x + beta, 0, 1); params 0/1 = alpha/beta
            float alpha = get_node_attr_f(node, "alpha", 0.2f);
            float beta = get_node_attr_f(node, "beta", 0.5f);

            fprintf(pp, " 0=%e", alpha);
            fprintf(pp, " 1=%e", beta);
        }
        else if (op == "HardSwish")
        {
            // y = x * hardsigmoid(x); same alpha/beta parameterization
            float alpha = get_node_attr_f(node, "alpha", 0.2f);
            float beta = get_node_attr_f(node, "beta", 0.5f);

            fprintf(pp, " 0=%e", alpha);
            fprintf(pp, " 1=%e", beta);
        }
        else if (op == "ImageScaler")
        {
            // Per-channel scale + bias; emitted as a Scale layer where every
            // channel shares the same scale value.
            std::vector<float> bias = get_node_attr_af(node, "bias");
            float scale = get_node_attr_f(node, "scale", 1.f);

            int channels = (int)bias.size();

            fprintf(pp, " 0=%d", channels);
            fprintf(pp, " 1=1");  // bias_term

            // broadcast the scalar scale across all channels
            for (int j = 0; j < channels; j++)
            {
                fwrite(&scale, sizeof(float), 1, bp);
            }
            fwrite(&bias[0], sizeof(float), channels, bp);
        }
        else if (op == "InstanceNormalization")
        {
            float eps = get_node_attr_f(node, "epsilon", 1e-5f);

            // discard affine-less S=1 B=0
            std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
            std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
            int channels = (int)affine_S.size();
            // affine is enabled only if some scale != 1 or bias != 0
            int affine = 0;
            {
                for (int j = 0; j < channels; j++)
                {
                    if (affine_S[j] != 1.f || affine_B[j] != 0.f)
                    {
                        affine = 1;
                        break;
                    }
                }
            }

            fprintf(pp, " 0=%d", channels);
            fprintf(pp, " 1=%e", eps);
            fprintf(pp, " 2=%d", affine);
            if (affine)
            {
                // gamma then beta, raw fp32
                const onnx::TensorProto& scale = weights[node.input(1)];
                const onnx::TensorProto& B = weights[node.input(2)];

                fwrite_tensor_proto_data(scale, bp);
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "LayerNorm")
        {
            float eps = get_node_attr_f(node, "epsilon", 1e-5f);
            int affine = get_node_attr_i(node, "affine", 1);

            if (affine)
            {
                // discard affine-less S=1 B=0
                std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
                std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
                int affine_size = (int)affine_S.size();
                // affine is re-enabled only if some scale != 1 or bias != 0
                affine = 0;
                {
                    for (int j = 0; j < affine_size; j++)
                    {
                        if (affine_S[j] != 1.f || affine_B[j] != 0.f)
                        {
                            affine = 1;
                            break;
                        }
                    }
                }

                if (affine)
                {
                    // param 0 = normalized element count
                    fprintf(pp, " 0=%d", affine_size);
                }
            }

            fprintf(pp, " 1=%e", eps);
            fprintf(pp, " 2=%d", affine);

            if (affine)
            {
                // gamma then beta, raw fp32
                const onnx::TensorProto& scale = weights[node.input(1)];
                const onnx::TensorProto& B = weights[node.input(2)];

                fwrite_tensor_proto_data(scale, bp);
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "LeakyRelu")
        {
            // ReLU with negative slope alpha
            float alpha = get_node_attr_f(node, "alpha", 0.01f);

            fprintf(pp, " 0=%e", alpha);
        }
        else if (op == "Log")
        {
            // UnaryOp: op_type 8 selects log
            int op_type = 8;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "LRN")
        {
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            float beta = get_node_attr_f(node, "beta", 0.5f);
            float bias = get_node_attr_f(node, "bias", 1.f);
            int size = get_node_attr_i(node, "size", 1);

            // 0 = across-channels normalization region
            int norm_region = 0;

            fprintf(pp, " 0=%d", norm_region);
            fprintf(pp, " 1=%d", size);
            fprintf(pp, " 2=%e", alpha);
            fprintf(pp, " 3=%e", beta);
            fprintf(pp, " 4=%e", bias);
        }
4728         else if (op == "LSTM")
4729         {
4730             const onnx::TensorProto& W = weights[node.input(1)];
4731             const onnx::TensorProto& R = weights[node.input(2)];
4732             const onnx::TensorProto& B = weights[node.input(3)];
4733 
4734             int hidden_size = get_node_attr_i(node, "hidden_size", 0);
4735             std::string direction = get_node_attr_s(node, "direction");
4736 
4737             int direction_type = 0;
4738             if (direction == "forward")
4739             {
4740                 direction_type = 0;
4741             }
4742             else if (direction == "reverse")
4743             {
4744                 direction_type = 1;
4745             }
4746             else if (direction == "bidirectional")
4747             {
4748                 direction_type = 2;
4749             }
4750 
4751             int weight_data_size = get_tensor_proto_data_size(W);
4752 
            // LSTM layer params: 0=hidden size, 1=W element count, 2=direction
            // (0=forward, 1=reverse, 2=bidirectional)
            fprintf(pp, " 0=%d", hidden_size);
            fprintf(pp, " 1=%d", weight_data_size);
            fprintf(pp, " 2=%d", direction_type);

            int num_directions = direction_type == 2 ? 2 : 1;

            // 0 = raw float32 weights, no quantization
            int quantize_tag = 0;

            // reorder num_directions-IOFG-hidden-size to num_directions-IFOG-hidden-size
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // element count of one gate's weight matrix for one direction
                // (W packs 4 gates per direction)
                int weight_data_size_g = get_tensor_proto_data_size(W) / 4 / num_directions;
                const float* wptr = W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data();

                // ONNX packs gates as I,O,F,G at consecutive offsets; ncnn wants I,F,O,G
                const float* iptr = wptr;
                const float* optr = wptr + weight_data_size_g;
                const float* fptr = wptr + weight_data_size_g * 2;
                const float* gptr = wptr + weight_data_size_g * 3;
                fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                fwrite(optr, sizeof(float), weight_data_size_g, bp);
                fwrite(gptr, sizeof(float), weight_data_size_g, bp);

                if (direction_type == 2)
                {
                    // advance past the forward direction's 4 gates, then emit the
                    // reverse direction in the same I,F,O,G order
                    iptr += weight_data_size_g * 4;
                    optr += weight_data_size_g * 4;
                    fptr += weight_data_size_g * 4;
                    gptr += weight_data_size_g * 4;
                    fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(optr, sizeof(float), weight_data_size_g, bp);
                    fwrite(gptr, sizeof(float), weight_data_size_g, bp);
                }
            }

            // reduce xc and hc bias
            // reorder num_directions-IOFG-hidden to num_directions-IFOG-hidden
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // B packs input bias (Wb) and recurrent bias (Rb): 2 blocks x 4 gates
                // per direction; bias_data_size_g is one gate's bias length
                int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 4 / num_directions;
                const float* xcbptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
                // x* = input-side bias, h* = hidden-side bias, each in ONNX I,O,F,G order
                const float* xiptr = xcbptr;
                const float* xoptr = xcbptr + bias_data_size_g;
                const float* xfptr = xcbptr + bias_data_size_g * 2;
                const float* xgptr = xcbptr + bias_data_size_g * 3;
                const float* hiptr = xcbptr + bias_data_size_g * 4;
                const float* hoptr = xcbptr + bias_data_size_g * 5;
                const float* hfptr = xcbptr + bias_data_size_g * 6;
                const float* hgptr = xcbptr + bias_data_size_g * 7;

                // ncnn uses a single fused bias per gate, so sum Wb+Rb elementwise,
                // emitted in I,F,O,G order
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xiptr[j] + hiptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xfptr[j] + hfptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xoptr[j] + hoptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xgptr[j] + hgptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }

                if (direction_type == 2)
                {
                    // step over the forward direction's 8 bias blocks to reach the
                    // reverse direction, then repeat the fused emission
                    xiptr += bias_data_size_g * 8;
                    xoptr += bias_data_size_g * 8;
                    xfptr += bias_data_size_g * 8;
                    xgptr += bias_data_size_g * 8;
                    hiptr += bias_data_size_g * 8;
                    hoptr += bias_data_size_g * 8;
                    hfptr += bias_data_size_g * 8;
                    hgptr += bias_data_size_g * 8;

                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xiptr[j] + hiptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xfptr[j] + hfptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xoptr[j] + hoptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xgptr[j] + hgptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                }
            }

            // reorder num_directions-IOFG-hidden-hidden to num_directions-IFOG-hidden-hidden
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // recurrent weights R get the same IOFG -> IFOG gate reorder as W
                int weight_data_size_g = get_tensor_proto_data_size(R) / 4 / num_directions;
                const float* rptr = R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data();

                const float* iptr = rptr;
                const float* optr = rptr + weight_data_size_g;
                const float* fptr = rptr + weight_data_size_g * 2;
                const float* gptr = rptr + weight_data_size_g * 3;
                fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                fwrite(optr, sizeof(float), weight_data_size_g, bp);
                fwrite(gptr, sizeof(float), weight_data_size_g, bp);

                if (direction_type == 2)
                {
                    // reverse-direction recurrent weights, same I,F,O,G emission order
                    iptr += weight_data_size_g * 4;
                    optr += weight_data_size_g * 4;
                    fptr += weight_data_size_g * 4;
                    gptr += weight_data_size_g * 4;
                    fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(optr, sizeof(float), weight_data_size_g, bp);
                    fwrite(gptr, sizeof(float), weight_data_size_g, bp);
                }
            }
        }
        else if (op == "MatMul")
        {
            // MatMul with a constant 2D right-hand operand becomes an ncnn
            // InnerProduct layer; anything else stays a runtime matrix multiply.
            if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
            {
                // InnerProduct
                const onnx::TensorProto& B = weights[node.input(1)];

                int weight_data_size = get_tensor_proto_data_size(B);

                // last dim of B is the output width; the rest gives the input width
                int num_output = B.dims(B.dims_size() - 1);
                int num_input = weight_data_size / num_output;

                fprintf(pp, " 0=%d", num_output); // num_output
                fprintf(pp, " 1=0"); // no bias term
                fprintf(pp, " 2=%d", weight_data_size);

                // 0 = raw float32 weights, no quantization
                int quantize_tag = 0;
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // reorder num_input-num_output to num_output-num_input
                {
                    const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();

                    // transpose: emit output row j gathered from column j of B
                    for (int j = 0; j < num_output; j++)
                    {
                        for (int k = 0; k < num_input; k++)
                        {
                            float vb = bptr[k * num_output + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }

                // fwrite_tensor_proto_data(B, bp)
            }
            else
            {
                // default matrix multiplication
            }
        }
        else if (op == "Max")
        {
            // BinaryOp: op_type 4 = max
            int op_type = 4;
            fprintf(pp, " 0=%d", op_type);

            // optional scalar second operand folded in by an earlier fusion pass
            // (attributes "with_scalar"/"b") — TODO confirm which pass sets them
            int with_scalar = get_node_attr_i(node, "with_scalar", 0);
            float b = get_node_attr_f(node, "b", 0.f);
            if (with_scalar)
            {
                fprintf(pp, " 1=%d", with_scalar);
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "Min")
        {
            // BinaryOp: op_type 5 = min
            int op_type = 5;
            fprintf(pp, " 0=%d", op_type);

            int with_scalar = get_node_attr_i(node, "with_scalar", 0);
            float b = get_node_attr_f(node, "b", 0.f);
            if (with_scalar)
            {
                fprintf(pp, " 1=%d", with_scalar);
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "Mul")
        {
            // BinaryOp: op_type 2 = mul
            int op_type = 2;
            fprintf(pp, " 0=%d", op_type);

            int with_scalar = get_node_attr_i(node, "with_scalar", 0);
            float b = get_node_attr_f(node, "b", 0.f);
            if (with_scalar)
            {
                fprintf(pp, " 1=%d", with_scalar);
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "MultiHeadAttention")
        {
            // Custom fused attention node (produced by an earlier fusion pass).
            // Params: 0=embed_dim, 1=num_heads, 2=per-projection weight count.
            int embed_dim = get_node_attr_i(node, "embed_dim", 0);
            int num_heads = get_node_attr_i(node, "num_heads", 0);

            fprintf(pp, " 0=%d", embed_dim);
            fprintf(pp, " 1=%d", num_heads);

            if (node.input_size() == 5)
            {
                // Packed layout: input(1)=qkv weight, input(2)=qkv bias,
                // input(3)=output weight, input(4)=output bias.
                const onnx::TensorProto& qkvw = weights[node.input(1)];
                const onnx::TensorProto& qkvb = weights[node.input(2)];
                const onnx::TensorProto& ow = weights[node.input(3)];
                const onnx::TensorProto& ob = weights[node.input(4)];

                int weight_data_size = get_tensor_proto_data_size(ow);

                fprintf(pp, " 2=%d", weight_data_size);

                // 0 = raw float32 weights, no quantization
                int quantize_tag = 0;

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose qw
                {
                    // qkvw is indexed as [row k][col j] with 3*embed_dim columns
                    // (q, k, v slices side by side); emit the q slice transposed
                    const float* wptr = qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
                    const float* bptr = qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim * 3 + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }

                    fwrite(bptr, sizeof(float), embed_dim, bp);
                }

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose kw
                {
                    const float* wptr = qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
                    const float* bptr = qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();
                    // k bias starts after the q bias block
                    bptr += embed_dim;

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            // + embed_dim selects the k column slice of qkvw
                            float vb = wptr[k * embed_dim * 3 + j + embed_dim];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }

                    fwrite(bptr, sizeof(float), embed_dim, bp);
                }

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose vw
                {
                    const float* wptr = qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
                    const float* bptr = qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();
                    // v bias starts after the q and k bias blocks
                    bptr += embed_dim * 2;

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            // + embed_dim * 2 selects the v column slice of qkvw
                            float vb = wptr[k * embed_dim * 3 + j + embed_dim * 2];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }

                    fwrite(bptr, sizeof(float), embed_dim, bp);
                }

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose ow
                {
                    const float* wptr = ow.has_raw_data() ? (const float*)ow.raw_data().data() : ow.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(ob, bp);
            }
            else
            {
                // Separate layout: q/k/v/out weight+bias pairs at inputs 3..10.
                // NOTE(review): assumes input_size() >= 11 for this path — the
                // fusion pass that emits this node presumably guarantees it.
                const onnx::TensorProto& qw = weights[node.input(3)];
                const onnx::TensorProto& qb = weights[node.input(4)];
                const onnx::TensorProto& kw = weights[node.input(5)];
                const onnx::TensorProto& kb = weights[node.input(6)];
                // NOTE: local float temporaries named vb below shadow this tensor
                const onnx::TensorProto& vw = weights[node.input(7)];
                const onnx::TensorProto& vb = weights[node.input(8)];
                const onnx::TensorProto& ow = weights[node.input(9)];
                const onnx::TensorProto& ob = weights[node.input(10)];

                int weight_data_size = get_tensor_proto_data_size(qw);

                fprintf(pp, " 2=%d", weight_data_size);

                // 0 = raw float32 weights, no quantization
                int quantize_tag = 0;

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose qw
                {
                    const float* wptr = qw.has_raw_data() ? (const float*)qw.raw_data().data() : qw.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(qb, bp);

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose kw
                {
                    const float* wptr = kw.has_raw_data() ? (const float*)kw.raw_data().data() : kw.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(kb, bp);

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose vw
                {
                    const float* wptr = vw.has_raw_data() ? (const float*)vw.raw_data().data() : vw.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                // outer-scope vb (the v bias tensor), not the loop temporary
                fwrite_tensor_proto_data(vb, bp);

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose ow
                {
                    const float* wptr = ow.has_raw_data() ? (const float*)ow.raw_data().data() : ow.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(ob, bp);
            }
        }
        else if (op == "Neg")
        {
            // UnaryOp: op_type 1 = negation
            int op_type = 1;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Normalize")
        {
            // L2 normalization with a single shared scale of 1.0
            float eps = get_node_attr_f(node, "eps", 0.f);
            int scale_data_size = 1;

            fprintf(pp, " 1=1"); // channel_shared
            fprintf(pp, " 2=%e", eps);
            fprintf(pp, " 3=%d", scale_data_size);
            fprintf(pp, " 9=1"); // TODO hardcode pytorch style

            // identity scale so the layer performs pure normalization
            const float scale_data[1] = {1.f};
            fwrite(scale_data, sizeof(float), 1, bp);
        }
        else if (op == "Pad")
        {
            std::string mode = get_node_attr_s(node, "mode");
            float value = get_node_attr_f(node, "value", 0.f);

            // pads come from an attribute (opset <= 10) or a constant input (opset 11+)
            std::vector<int> pads;
            if (node.input_size() == 1)
            {
                pads = get_node_attr_ai(node, "pads");
            }
            else
            {
                pads = get_node_attr_from_input_ai(weights[node.input(1)]);
            }

            // ncnn Padding type: 0=constant, 1=replicate edge, 2=reflect
            int type = 0;
            if (mode == "constant")
            {
                type = 0;
            }
            else if (mode == "edge")
            {
                type = 1;
            }
            else if (mode == "reflect")
            {
                type = 2;
            }

            // ONNX pads layout is [d0_begin, d1_begin, ..., d0_end, d1_end, ...];
            // the batch (N) padding is dropped since ncnn tensors have no batch dim
            int pad_size = (int)pads.size();
            int top = 0;
            int bottom = 0;
            int left = 0;
            int right = 0;
            int front = 0;
            int behind = 0;
            if (pad_size == 8)
            {
                //NCHW
                top = pads[2];     // H begin
                bottom = pads[6];  // H end
                left = pads[3];    // W begin
                right = pads[7];   // W end
                front = pads[1];   // C begin
                behind = pads[5];  // C end
            }
            else if (pad_size == 6)
            {
                //NHW
                top = pads[1];
                bottom = pads[4];
                left = pads[2];
                right = pads[5];
            }
            else
            {
                //NW
                left = pads[1];
                right = pads[3];
            }

            fprintf(pp, " 0=%d", top);
            fprintf(pp, " 1=%d", bottom);
            fprintf(pp, " 2=%d", left);
            fprintf(pp, " 3=%d", right);
            fprintf(pp, " 4=%d", type);
            fprintf(pp, " 5=%e", value);
            fprintf(pp, " 7=%d", front);
            fprintf(pp, " 8=%d", behind);
        }
        else if (op == "Pow")
        {
            // BinaryOp: op_type 6 = pow
            int op_type = 6;
            fprintf(pp, " 0=%d", op_type);

            // optional scalar exponent folded in by an earlier fusion pass
            int with_scalar = get_node_attr_i(node, "with_scalar", 0);
            float b = get_node_attr_f(node, "b", 0.f);
            if (with_scalar)
            {
                fprintf(pp, " 1=%d", with_scalar);
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "PixelShuffle")
        {
            int scale_factor = get_node_attr_i(node, "scale_factor", 1);
            fprintf(pp, " 0=%d", scale_factor);
        }
        else if (op == "PRelu")
        {
            // per-channel (or shared, when num_slope==1) negative slopes
            const onnx::TensorProto& slope = weights[node.input(1)];

            int num_slope = get_tensor_proto_data_size(slope);

            fprintf(pp, " 0=%d", num_slope);

            fwrite_tensor_proto_data(slope, bp);
        }
        else if (op == "Reciprocal")
        {
            // UnaryOp: op_type 15 = reciprocal
            int op_type = 15;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp")
        {
            // map ONNX reduce ops onto ncnn Reduction op_type codes;
            // -233 marks "unset" and should be unreachable given the condition above
            int op_type = -233;
            if (op == "ReduceSum")
                op_type = 0;
            else if (op == "ReduceSumSquare")
                op_type = 2;
            else if (op == "ReduceMean")
                op_type = 3;
            else if (op == "ReduceMax")
                op_type = 4;
            else if (op == "ReduceMin")
                op_type = 5;
            else if (op == "ReduceProd")
                op_type = 6;
            else if (op == "ReduceL1")
                op_type = 7;
            else if (op == "ReduceL2")
                op_type = 8;
            else if (op == "ReduceLogSum")
                op_type = 9;
            else if (op == "ReduceLogSumExp")
                op_type = 10;
            fprintf(pp, " 0=%d", op_type);

            std::vector<int> axes = get_node_attr_ai(node, "axes");
            int keepdims = get_node_attr_i(node, "keepdims", 1);

            if (axes.size() > 0)
            {
                // if axes set, reduce according to axes
                fprintf(pp, " 1=%d", 0);
                // -23303 is the param key for the axes integer array
                fprintf(pp, " -23303=%zu", axes.size());
                for (size_t j = 0; j < axes.size(); j++)
                {
                    // axis 0 (batch) and |axis| > 3 cannot be expressed in ncnn
                    if (axes[j] == 0 || axes[j] > 3 || axes[j] < -3)
                        fprintf(stderr, "Unsupported reduction axes !\n");
                    fprintf(pp, ",%d", axes[j]);
                }
            }
            else
            {
                // if axes not set, reduce all axes by default
                fprintf(pp, " 1=%d", 1);
            }
            fprintf(pp, " 4=%d", keepdims);
        }
        else if (op == "Reorg")
        {
            int stride = get_node_attr_i(node, "stride", 1);
            fprintf(pp, " 0=%d", stride);
        }
        else if (op == "Reshape")
        {
            // target shape comes from an attribute (legacy) or a constant input
            std::vector<int> shape;

            if (node.input_size() == 1)
            {
                shape = get_node_attr_ai(node, "shape");
            }
            else
            {
                shape = get_node_attr_from_input_ai(weights[node.input(1)]);
            }

            // ncnn params: 0=w, 1=h, 2=c; the leading batch dim of the ONNX
            // shape is dropped and dims are reversed (ONNX is outermost-first)
            if (shape.size() == 1)
            {
                fprintf(pp, " 0=%d", shape[0]); // should never reach here
            }
            else if (shape.size() == 2)
            {
                fprintf(pp, " 0=%d", shape[1]);
            }
            else if (shape.size() == 3)
            {
                fprintf(pp, " 0=%d", shape[2]);
                fprintf(pp, " 1=%d", shape[1]);
            }
            else if (shape.size() == 4)
            {
                fprintf(pp, " 0=%d", shape[3]);
                fprintf(pp, " 1=%d", shape[2]);
                fprintf(pp, " 2=%d", shape[1]);
            }
            else if (shape.size() == 5)
            {
                // 5-D collapses the two innermost dims into w
                fprintf(pp, " 0=%d", shape[4] * shape[3]);
                fprintf(pp, " 1=%d", shape[2]);
                fprintf(pp, " 2=%d", shape[1]);
            }
        }
        else if (op == "Resize")
        {
            std::string mode = get_node_attr_s(node, "mode");
            std::string align = get_node_attr_s(node, "coordinate_transformation_mode");

            std::vector<float> scales;
            std::vector<int> sizes;
            if (node.input_size() == 2)
            {
                // opset 10
                scales = get_node_attr_from_input_af(weights[node.input(1)]);
            }
            else
            {
                // opset 11+
                // input(1) is roi (ignored); input(2)=scales, input(3)=sizes
                scales = get_node_attr_from_input_af(weights[node.input(2)]);
                if (node.input_size() >= 4)
                {
                    sizes = get_node_attr_from_input_ai(weights[node.input(3)]);
                }
            }

            // ncnn Interp resize_type: 1=nearest, 2=bilinear, 3=bicubic
            int resize_type = 1;
            if (mode == "nearest")
            {
                resize_type = 1;
            }
            else if (mode == "linear")
            {
                resize_type = 2;
            }
            else if (mode == "cubic")
            {
                resize_type = 3;
            }

            if (scales.empty() && sizes.empty())
            {
                fprintf(stderr, "Unsupported Resize scales and sizes are all empty!\n");
            }

            // scales layout mirrors the input rank; batch/channel scales must be 1
            float h_scale = 1.f;
            float w_scale = 1.f;
            if (scales.size() == 2)
            {
                w_scale = scales[1];
            }
            else if (scales.size() == 3)
            {
                h_scale = scales[1];
                w_scale = scales[2];
            }
            else if (scales.size() == 4)
            {
                h_scale = scales[2];
                w_scale = scales[3];

                if (scales[1] != 1.f)
                    fprintf(stderr, "Unsupported Resize scales !\n");
            }

            // explicit output dims take the same rank-dependent layout as scales
            int output_height = 0;
            int output_width = 0;
            if (sizes.size() == 2)
            {
                output_width = sizes[1];
            }
            else if (sizes.size() == 3)
            {
                output_height = sizes[1];
                output_width = sizes[2];
            }
            else if (sizes.size() == 4)
            {
                output_height = sizes[2];
                output_width = sizes[3];
            }

            int align_corner = 0;
            if (align == "align_corners")
            {
                align_corner = 1;
            }

            fprintf(pp, " 0=%d", resize_type);
            fprintf(pp, " 1=%e", h_scale);
            fprintf(pp, " 2=%e", w_scale);
            fprintf(pp, " 3=%d", output_height);
            fprintf(pp, " 4=%d", output_width);
            fprintf(pp, " 6=%d", align_corner);
        }
        else if (op == "RNN")
        {
            // ONNX RNN inputs: W=input weights, R=recurrent weights, B=biases
            const onnx::TensorProto& W = weights[node.input(1)];
            const onnx::TensorProto& R = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];

            int hidden_size = get_node_attr_i(node, "hidden_size", 0);
            std::string direction = get_node_attr_s(node, "direction");

            // 0=forward, 1=reverse, 2=bidirectional (matches ncnn RNN param 2)
            int direction_type = 0;
            if (direction == "forward")
            {
                direction_type = 0;
            }
            else if (direction == "reverse")
            {
                direction_type = 1;
            }
            else if (direction == "bidirectional")
            {
                direction_type = 2;
            }

            int weight_data_size = get_tensor_proto_data_size(W);

            fprintf(pp, " 0=%d", hidden_size);
            fprintf(pp, " 1=%d", weight_data_size);
            fprintf(pp, " 2=%d", direction_type);

            int num_directions = direction_type == 2 ? 2 : 1;

            // 0 = raw float32 weights, no quantization
            int quantize_tag = 0;

            // vanilla RNN has a single gate, so W needs no gate reorder
            fwrite(&quantize_tag, sizeof(int), 1, bp);
            fwrite_tensor_proto_data(W, bp);

            // reduce xc and hc bias
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // B packs input bias (Wb) then recurrent bias (Rb) per direction;
                // ncnn uses one fused bias, so sum them elementwise
                int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / num_directions;
                const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
                const float* xiptr = bptr;
                const float* hiptr = bptr + bias_data_size_g;

                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xiptr[j] + hiptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }

                if (direction_type == 2)
                {
                    // advance to the reverse direction's bias pair and fuse it too
                    xiptr += bias_data_size_g * 2;
                    hiptr += bias_data_size_g * 2;

                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xiptr[j] + hiptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                }
            }

            fwrite(&quantize_tag, sizeof(int), 1, bp);
            fwrite_tensor_proto_data(R, bp);
        }
        else if (op == "ShuffleChannel")
        {
            int group = get_node_attr_i(node, "group", 1);
            int reverse = get_node_attr_i(node, "reverse", 0);
            fprintf(pp, " 0=%d", group);
            fprintf(pp, " 1=%d", reverse);
        }
        else if (op == "Sigmoid")
        {
            // no param
        }
        else if (op == "Sin")
        {
            // UnaryOp: op_type 9 = sin
            int op_type = 9;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "SkipLayerNormalization")
        {
            // presumably input(2)=gamma, input(3)=beta, input(4)=extra bias —
            // TODO confirm against the pass that emits this fused node
            const onnx::TensorProto& W = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];
            const onnx::TensorProto& B2 = weights[node.input(4)];

            // param 0 = affine element count
            fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));

            // 0 = raw float32 weights, no quantization
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(W, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B2, bp);
        }
        else if (op == "Slice")
        {
            // starts/ends/axes/steps come from attributes (opset <= 9)
            // or from constant inputs (opset 10+)
            std::vector<int> starts;
            std::vector<int> ends;
            std::vector<int> axes;
            std::vector<int> steps;
            if (node.input_size() == 1)
            {
                starts = get_node_attr_ai(node, "starts");
                ends = get_node_attr_ai(node, "ends");
                axes = get_node_attr_ai(node, "axes");
                steps = get_node_attr_ai(node, "steps"); // TODO
            }
            else
            {
                starts = get_node_attr_from_input_ai(weights[node.input(1)]);
                ends = get_node_attr_from_input_ai(weights[node.input(2)]);
                if (node.input_size() >= 4)
                    axes = get_node_attr_from_input_ai(weights[node.input(3)]);
                if (node.input_size() >= 5)
                    steps = get_node_attr_from_input_ai(weights[node.input(4)]);
            }

            // assert step == 1
            for (int i = 0; i < (int)steps.size(); i++)
            {
                if (steps[i] != 1)
                    fprintf(stderr, "Unsupported slice step !\n");
            }

            // filter out N-dim axis
            // (ncnn has no batch dim; at most one axis-0 entry is expected,
            // hence the break after the first match)
            if (!axes.empty())
            {
                for (int i = 0; i < (int)axes.size(); i++)
                {
                    int axis = axes[i];
                    if (axis == 0)
                    {
                        starts.erase(starts.begin() + i);
                        ends.erase(ends.begin() + i);
                        axes.erase(axes.begin() + i);
                        break;
                    }
                }
            }

            // -23309/-23310/-23311 are the Crop param keys for the
            // starts/ends/axes integer arrays
            fprintf(pp, " -23309=%d", (int)starts.size());
            for (int i = 0; i < (int)starts.size(); i++)
            {
                fprintf(pp, ",%d", starts[i]);
            }
            fprintf(pp, " -23310=%d", (int)ends.size());
            for (int i = 0; i < (int)ends.size(); i++)
            {
                fprintf(pp, ",%d", ends[i]);
            }
            if (!axes.empty())
            {
                fprintf(pp, " -23311=%d", (int)axes.size());
                for (int i = 0; i < (int)axes.size(); i++)
                {
                    int axis = axes[i];
                    if (axis == 0 || axis > 3 || axis < -3)
                        fprintf(stderr, "Unsupported slice axes !\n");

                    if (axis > 0)
                        axis = axis - 1; // -1 for skip N-dim
                    // negative axes already count from the back, so pass through

                    fprintf(pp, ",%d", axis);
                }
            }
        }
        else if (op == "Softmax")
        {
            // ONNX default axis is 1; subtract 1 to account for the dropped
            // batch dim in ncnn
            int axis = get_node_attr_i(node, "axis", 1);
            fprintf(pp, " 0=%d", axis - 1);
            // param 1: fixed flag emitted for converted models — presumably
            // the "fixbug0" axis-interpretation switch; confirm against ncnn
            // Softmax param docs
            fprintf(pp, " 1=1");
        }
        else if (op == "Split")
        {
            int axis = get_node_attr_i(node, "axis", 0);
            std::vector<int> split = get_node_attr_ai(node, "split");
            // splitting along the batch dim cannot be expressed in ncnn
            if (axis < 1)
                fprintf(stderr, "Unsupported split axis !\n");

            // -23300: per-output slice sizes; -233 means "infer remainder"
            fprintf(pp, " -23300=%d", output_size);
            if (split.empty())
            {
                // no explicit sizes: let every output be inferred
                for (int i = 0; i < output_size; i++)
                {
                    fprintf(pp, ",-233");
                }
            }
            else
            {
                // emit all sizes except the last, which is left for inference
                for (size_t i = 0; i < split.size() - 1; i++)
                {
                    fprintf(pp, ",%d", split[i]);
                }
                fprintf(pp, ",-233");
            }
            fprintf(pp, " 1=%d", axis - 1); // -1 for the dropped batch dim
        }
        else if (op == "Sqrt")
        {
            // UnaryOp with operation type 5 (sqrt)
            int op_type = 5;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Squeeze")
        {
            std::vector<int> axes = get_node_attr_ai(node, "axes");

            if (axes.empty())
            {
                // no axes given: squeeze all three dims (params 0/1/2 — w/h/c)
                fprintf(pp, " 0=1");
                fprintf(pp, " 1=1");
                fprintf(pp, " 2=1");
            }
            else
            {
                // -23303: explicit axes list; batch dim (0) and anything
                // beyond 3-D cannot be represented
                fprintf(pp, " -23303=%zu", axes.size());
                for (int i = 0; i < (int)axes.size(); i++)
                {
                    if (axes[i] == 0 || axes[i] > 3 || axes[i] < -3)
                        fprintf(stderr, "Unsupported squeeze axes !\n");
                    fprintf(pp, ",%d", axes[i]);
                }
            }
        }
        else if (op == "Sub")
        {
            // BinaryOp with operation type 1 (sub)
            int op_type = 1;
            fprintf(pp, " 0=%d", op_type);

            // "with_scalar"/"b" are synthetic attributes (set by an earlier
            // fusion pass) marking subtraction of a scalar constant b
            int with_scalar = get_node_attr_i(node, "with_scalar", 0);
            float b = get_node_attr_f(node, "b", 0.f);
            if (with_scalar)
            {
                fprintf(pp, " 1=%d", with_scalar);
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "Sum")
        {
            // Eltwise with operation type 1 (sum)
            int op_type = 1;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Swish")
        {
            // no param
        }
        else if (op == "Tan")
        {
            // UnaryOp with operation type 11 (tan)
            int op_type = 11;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Tanh")
        {
            // UnaryOp with operation type 16 (tanh)
            int op_type = 16;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Transpose")
        {
            // Map the ONNX permutation (which includes the batch dim) onto
            // an ncnn Permute order code; param 0 comments name the resulting
            // dim order.
            std::vector<int> perm = get_node_attr_ai(node, "perm");

            if (perm.size() == 3)
            {
                // NOTE(review): the first two cases never test perm[0]; they
                // presumably assume perm[0] == 0 — confirm all reachable perms
                if (perm[1] == 1 && perm[2] == 2)
                    fprintf(pp, " 0=0"); // w h
                else if (perm[1] == 2 && perm[2] == 1)
                    fprintf(pp, " 0=1"); // h w
                else if (perm[0] == 1 && perm[1] == 0 && perm[2] == 2)
                    fprintf(pp, " 0=0"); // w h
                else if (perm[0] == 2 && perm[1] == 0 && perm[2] == 1)
                    fprintf(pp, " 0=1"); // h w
            }
            else if (perm.size() == 4)
            {
                // 4-D: perm[0] (batch) assumed fixed; only perm[1..3] decide
                if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3)
                    fprintf(pp, " 0=0"); // w h c
                else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 2)
                    fprintf(pp, " 0=1"); // h w c
                else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3)
                    fprintf(pp, " 0=2"); // w c h
                else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 1)
                    fprintf(pp, " 0=3"); // c w h
                else if (perm[1] == 3 && perm[2] == 1 && perm[3] == 2)
                    fprintf(pp, " 0=4"); // h c w
                else if (perm[1] == 3 && perm[2] == 2 && perm[3] == 1)
                    fprintf(pp, " 0=5"); // c h w
            }
            else if (perm.size() == 5)
            {
                // 5-D: the two innermost dims are treated as a fused "wx"
                if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3 && perm[4] == 4)
                    fprintf(pp, " 0=0"); // wx h c
                else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 4 && perm[4] == 2)
                    fprintf(pp, " 0=1"); // h wx c
                else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3 && perm[4] == 4)
                    fprintf(pp, " 0=2"); // wx c h
                else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 4 && perm[4] == 1)
                    fprintf(pp, " 0=3"); // c wx h
                else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 1 && perm[4] == 2)
                    fprintf(pp, " 0=4"); // h c wx
                else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 2 && perm[4] == 1)
                    fprintf(pp, " 0=5"); // c h wx
                else
                    fprintf(stderr, "Unsupported transpose type !\n");
            }
        }
        else if (op == "Upsample")
        {
            std::string mode = get_node_attr_s(node, "mode");
            std::string align = get_node_attr_s(node, "coordinate_transformation_mode");

            std::vector<float> scales;

            // opset < 9 carries scales as an attribute; opset >= 9 supplies
            // them as a constant input tensor
            if (node.input_size() == 1)
            {
                scales = get_node_attr_af(node, "scales");
            }
            else
            {
                scales = get_node_attr_from_input_af(weights[node.input(1)]);
            }

            // resize_type: 1 = nearest, 2 = bilinear
            int resize_type = 1;
            if (mode == "nearest")
            {
                resize_type = 1;
            }
            else if (mode == "bilinear" || mode == "linear")
            {
                resize_type = 2;
            }
            else if (mode == "trilinear")
            {
                // warns but still falls through with resize_type = 1 (nearest)
                fprintf(stderr, "Unsupported Upsample mode !\n");
            }

            // scales layout depends on rank:
            //   2 -> (n?, w)   3 -> (n, h, w)   4 -> (n, c, h, w), c must be 1
            float h_scale = 1.f;
            float w_scale = 1.f;
            if (scales.size() == 2)
            {
                w_scale = scales[1];
            }
            else if (scales.size() == 3)
            {
                h_scale = scales[1];
                w_scale = scales[2];
            }
            else if (scales.size() == 4)
            {
                h_scale = scales[2];
                w_scale = scales[3];

                // scaling the channel dim is not representable
                if (scales[1] != 1.f)
                    fprintf(stderr, "Unsupported Upsample scales !\n");
            }
            else
            {
                fprintf(stderr, "Unsupported Upsample scales !\n");
            }

            int align_corner = 0;
            if (align == "align_corners")
            {
                align_corner = 1;
            }

            // 0=resize_type 1=height scale 2=width scale 6=align_corner
            fprintf(pp, " 0=%d", resize_type);
            fprintf(pp, " 1=%e", h_scale);
            fprintf(pp, " 2=%e", w_scale);
            fprintf(pp, " 6=%d", align_corner);
        }
        else if (op == "Unsqueeze")
        {
            std::vector<int> axes = get_node_attr_ai(node, "axes");

            // -23303: axes list; batch dim (0) and anything beyond 4-D
            // cannot be represented
            fprintf(pp, " -23303=%zu", axes.size());
            for (int i = 0; i < (int)axes.size(); i++)
            {
                if (axes[i] == 0 || axes[i] > 4 || axes[i] < -4)
                    fprintf(stderr, "Unsupported unsqueeze axes !\n");
                fprintf(pp, ",%d", axes[i]);
            }
        }
        else
        {
            // TODO op specific param
            // Unknown op: dump its attributes to stderr as a porting hint.
            // Type codes follow onnx::AttributeProto: 1=FLOAT, 2=INT, 3=STRING.
            for (int j = 0; j < node.attribute_size(); j++)
            {
                const onnx::AttributeProto& attr = node.attribute(j);
                if (attr.type() == 1)
                {
                    fprintf(stderr, "  # %s=%g\n", attr.name().c_str(), attr.f());
                }
                else if (attr.type() == 2)
                {
                    fprintf(stderr, "  # %s=%lld\n", attr.name().c_str(), (long long)attr.i());
                }
                else if (attr.type() == 3)
                {
                    fprintf(stderr, "  # %s=%s\n", attr.name().c_str(), attr.s().c_str());
                }
                else
                {
                    fprintf(stderr, "  # %s %d\n", attr.name().c_str(), attr.type());
                }
            }
        }

        // terminate this layer's param line
        fprintf(pp, "\n");

        // Any output blob consumed by more than one downstream node gets an
        // explicit ncnn Split layer that fans it out into refcount copies
        // named "<output>_splitncnn_<k>".
        for (int j = 0; j < output_size; j++)
        {
            const std::string& output_name = node.output(j);
            if (node_reference.find(output_name) != node_reference.end())
            {
                int refcount = node_reference[output_name];
                if (refcount > 1)
                {
                    char splitname[256];
                    sprintf(splitname, "splitncnn_%d", internal_split);
                    // Split layer: 1 input blob, refcount output blobs
                    fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);

                    fprintf(pp, " %s", output_name.c_str());

                    for (int k = 0; k < refcount; k++)
                    {
                        fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k);
                    }
                    fprintf(pp, "\n");

                    internal_split++;
                }
            }
        }
    }

    // close the param (pp) and binary weight (bp) output files
    fclose(pp);
    fclose(bp);

    return 0;
}
5895