1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14 
15 #include "onnx.pb.h"
16 
17 #include <algorithm>
18 #include <float.h>
19 #include <fstream>
20 #include <google/protobuf/io/coded_stream.h>
21 #include <google/protobuf/io/zero_copy_stream_impl.h>
22 #include <google/protobuf/message.h>
23 #include <google/protobuf/text_format.h>
24 #include <iostream>
25 #include <limits.h>
26 #include <limits>
27 #include <set>
28 #include <stdio.h>
29 
// Load an onnx ModelProto from a binary protobuf file on disk.
// Returns false when the file cannot be opened or parsing fails.
static bool read_proto_from_binary(const char* filepath, onnx::ModelProto* message)
{
    std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
    if (!fs.is_open())
    {
        fprintf(stderr, "open failed %s\n", filepath);
        return false;
    }

    google::protobuf::io::IstreamInputStream input(&fs);
    google::protobuf::io::CodedInputStream codedstr(&input);

    // onnx models easily exceed protobuf's default message size limit,
    // so raise it to INT_MAX before parsing
#if GOOGLE_PROTOBUF_VERSION >= 3011000
    codedstr.SetTotalBytesLimit(INT_MAX);
#else
    codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);
#endif

    const bool parsed = message->ParseFromCodedStream(&codedstr);

    fs.close();

    return parsed;
}
54 
get_node_attr_ai(const onnx::NodeProto & node,const char * key)55 static std::vector<int> get_node_attr_ai(const onnx::NodeProto& node, const char* key)
56 {
57     std::vector<int> v;
58 
59     for (int i = 0; i < node.attribute_size(); i++)
60     {
61         const onnx::AttributeProto& attr = node.attribute(i);
62         if (attr.name() == key)
63         {
64             v.resize(attr.ints_size());
65             for (int j = 0; j < attr.ints_size(); j++)
66             {
67                 v[j] = std::max(std::min(attr.ints(j), (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
68             }
69 
70             break;
71         }
72     }
73 
74     return v;
75 }
76 
get_node_attr_af(const onnx::NodeProto & node,const char * key)77 static std::vector<float> get_node_attr_af(const onnx::NodeProto& node, const char* key)
78 {
79     std::vector<float> v;
80 
81     for (int i = 0; i < node.attribute_size(); i++)
82     {
83         const onnx::AttributeProto& attr = node.attribute(i);
84         if (attr.name() == key)
85         {
86             v.resize(attr.floats_size());
87             for (int j = 0; j < attr.floats_size(); j++)
88             {
89                 v[j] = attr.floats(j);
90             }
91 
92             break;
93         }
94     }
95 
96     return v;
97 }
98 
get_node_attr_i(const onnx::NodeProto & node,const char * key,int def=0)99 static int get_node_attr_i(const onnx::NodeProto& node, const char* key, int def = 0)
100 {
101     for (int i = 0; i < node.attribute_size(); i++)
102     {
103         const onnx::AttributeProto& attr = node.attribute(i);
104         if (attr.name() == key)
105         {
106             return std::max(std::min(attr.i(), (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
107         }
108     }
109 
110     return def;
111 }
112 
get_node_attr_f(const onnx::NodeProto & node,const char * key,float def=0.f)113 static float get_node_attr_f(const onnx::NodeProto& node, const char* key, float def = 0.f)
114 {
115     for (int i = 0; i < node.attribute_size(); i++)
116     {
117         const onnx::AttributeProto& attr = node.attribute(i);
118         if (attr.name() == key)
119         {
120             return attr.f();
121         }
122     }
123 
124     return def;
125 }
126 
get_node_attr_s(const onnx::NodeProto & node,const char * key,const std::string & def=std::string ())127 static std::string get_node_attr_s(const onnx::NodeProto& node, const char* key, const std::string& def = std::string())
128 {
129     for (int i = 0; i < node.attribute_size(); i++)
130     {
131         const onnx::AttributeProto& attr = node.attribute(i);
132         if (attr.name() == key)
133         {
134             return attr.s();
135         }
136     }
137 
138     return def;
139 }
140 
// Fetch the tensor attribute named `key` from `node`;
// a default-constructed (empty) TensorProto is returned when absent.
static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char* key)
{
    const int attr_count = node.attribute_size();
    for (int i = 0; i < attr_count; i++)
    {
        const onnx::AttributeProto& attr = node.attribute(i);
        if (attr.name() == key)
            return attr.t();
    }

    return onnx::TensorProto();
}
154 
get_node_attr_from_input_f(const onnx::TensorProto & tp)155 static float get_node_attr_from_input_f(const onnx::TensorProto& tp)
156 {
157     float v = 0.f;
158 
159     // float
160     if (tp.data_type() == 1)
161     {
162         const float* shape_data = 0;
163         if (tp.has_raw_data())
164         {
165             shape_data = (const float*)tp.raw_data().data();
166         }
167         else
168         {
169             shape_data = tp.float_data().data();
170         }
171         v = shape_data[0];
172     }
173     // double
174     else if (tp.data_type() == 11)
175     {
176         const double* shape_data = 0;
177         if (tp.has_raw_data())
178         {
179             shape_data = (const double*)tp.raw_data().data();
180         }
181         else
182         {
183             shape_data = tp.double_data().data();
184         }
185         v = shape_data[0];
186     }
187     // int64
188     else if (tp.data_type() == 7)
189     {
190         const int64_t* shape_data = 0;
191         if (tp.has_raw_data())
192         {
193             shape_data = (const int64_t*)tp.raw_data().data();
194         }
195         else
196         {
197             shape_data = tp.int64_data().data();
198         }
199         v = std::max(std::min(shape_data[0], (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
200     }
201     // int32
202     else if (tp.data_type() == 6)
203     {
204         const int32_t* shape_data = 0;
205         if (tp.has_raw_data())
206         {
207             shape_data = (const int32_t*)tp.raw_data().data();
208         }
209         else
210         {
211             shape_data = tp.int32_data().data();
212         }
213         v = shape_data[0];
214     }
215     else
216     {
217         fprintf(stderr, "Unknown data type %d\n", tp.data_type());
218         abort();
219     }
220 
221     return v;
222 }
223 
get_node_attr_from_input_ai(const onnx::TensorProto & tp)224 static std::vector<int> get_node_attr_from_input_ai(const onnx::TensorProto& tp)
225 {
226     int size = 0;
227 
228     std::vector<int> v;
229 
230     // int64
231     if (tp.data_type() == 7)
232     {
233         const int64_t* shape_data = 0;
234         if (tp.has_raw_data())
235         {
236             shape_data = (const int64_t*)tp.raw_data().data();
237             size = (int)(tp.raw_data().size() / 8);
238         }
239         else
240         {
241             shape_data = tp.int64_data().data();
242             size = tp.int64_data_size();
243         }
244         for (int j = 0; j < size; j++)
245         {
246             int vi = std::max(std::min(shape_data[j], (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
247             v.push_back(vi);
248         }
249     }
250     // int32
251     else if (tp.data_type() == 6)
252     {
253         const int32_t* shape_data = 0;
254         if (tp.has_raw_data())
255         {
256             shape_data = (const int32_t*)tp.raw_data().data();
257             size = (int)(tp.raw_data().size() / 4);
258         }
259         else
260         {
261             shape_data = tp.int32_data().data();
262             size = tp.int32_data_size();
263         }
264         for (int j = 0; j < size; j++)
265         {
266             v.push_back(shape_data[j]);
267         }
268     }
269     else
270     {
271         fprintf(stderr, "Unknown data type %d\n", tp.data_type());
272     }
273 
274     return v;
275 }
276 
get_node_attr_from_input_af(const onnx::TensorProto & tp)277 static std::vector<float> get_node_attr_from_input_af(const onnx::TensorProto& tp)
278 {
279     int size = 0;
280 
281     std::vector<float> v;
282 
283     // float
284     if (tp.data_type() == 1)
285     {
286         const float* shape_data = 0;
287         if (tp.has_raw_data())
288         {
289             shape_data = (const float*)tp.raw_data().data();
290             size = (int)(tp.raw_data().size() / 4);
291         }
292         else
293         {
294             shape_data = tp.float_data().data();
295             size = tp.float_data_size();
296         }
297         for (int j = 0; j < size; j++)
298         {
299             v.push_back(shape_data[j]);
300         }
301     }
302     // double
303     else if (tp.data_type() == 11)
304     {
305         const double* shape_data = 0;
306         if (tp.has_raw_data())
307         {
308             shape_data = (const double*)tp.raw_data().data();
309             size = (int)(tp.raw_data().size() / 8);
310         }
311         else
312         {
313             shape_data = tp.double_data().data();
314             size = tp.double_data_size();
315         }
316         for (int j = 0; j < size; j++)
317         {
318             v.push_back((float)shape_data[j]);
319         }
320     }
321     else
322     {
323         fprintf(stderr, "Unknown data type %d\n", tp.data_type());
324     }
325 
326     return v;
327 }
328 
// Number of elements in a weight tensor.
// NOTE(review): when raw_data is present the count is derived as bytes/4
// regardless of data_type — this assumes 4-byte (float32) storage; confirm
// callers only pass float weights here.
static int get_tensor_proto_data_size(const onnx::TensorProto& tp)
{
    if (tp.has_raw_data())
    {
        const std::string& raw_data = tp.raw_data();
        int size = (int)raw_data.size() / 4;
        return size;
    }
    else if (tp.data_type() == 1)
    {
        // float32 stored in the repeated float_data field
        return tp.float_data_size();
    }

    // unsupported storage: treated as empty
    return 0;
}
344 
// Write the tensor's float payload to the open binary file `bp`.
// Element count comes from get_tensor_proto_data_size(); tensors that are
// neither raw_data nor float32 write nothing.
static void fwrite_tensor_proto_data(const onnx::TensorProto& tp, FILE* bp)
{
    int size = get_tensor_proto_data_size(tp);

    if (tp.has_raw_data())
    {
        const std::string& raw_data = tp.raw_data();
        fwrite(raw_data.data(), sizeof(float), size, bp);
    }
    else if (tp.data_type() == 1)
    {
        fwrite(tp.float_data().data(), sizeof(float), size, bp);
    }
}
359 
// Fold a Reshape whose input is a constant weight: the target shape is applied
// directly to the weight tensor's dims and the Reshape node becomes a no-op.
// Reference counts and reduced_node_count are updated accordingly.
static void fuse_weight_reshape(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // weight <= Reshape(weight)
        if (node->op_type() == "Reshape")
        {
            // check weight
            if (weights.find(node->input(0)) == weights.end())
                continue;

            // alias the reshaped output name to the same weight data
            weights[node->output(0)] = weights[node->input(0)];

            // set weight shape directly
            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                // target shape stored as a node attribute (pre opset 5)
                shape = get_node_attr_ai(*node, "shape");
            }
            else if (node->input_size() == 2)
            {
                // opset 5
                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            weights[node->output(0)].clear_dims();
            for (int j = 0; j < shape.size(); j++)
            {
                weights[node->output(0)].add_dims(shape[j]);
            }

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            if (node->input_size() == 2)
            {
                // the shape tensor input is consumed as well
                node_reference[node->input(1)] -= 1;
            }

            reduced_node_count += 1;
            // NOTE(review): only this one node was consumed, yet i += 1 makes
            // the loop skip the following node — confirm this is intended
            i += 1;
        }
    }
}
408 
// Fold a 2D Transpose with perm=(1,0) whose input is a constant weight:
// the weight data is permuted in place and the Transpose node becomes a no-op.
static void fuse_weight_transpose(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // weight <= Transpose(weight)
        if (node->op_type() == "Transpose")
        {
            // check weight
            if (weights.find(node->input(0)) == weights.end())
                continue;

            // only plain 2D matrices are handled
            if (weights[node->input(0)].dims_size() != 2)
                continue;

            // perm = (1, 0)
            std::vector<int> perm = get_node_attr_ai(*node, "perm");
            if (perm.size() != 2)
                continue;
            if (perm[0] != 1 || perm[1] != 0)
                continue;

            // alias the transposed output name to the same weight data
            weights[node->output(0)] = weights[node->input(0)];

            // permute weight
            {
                onnx::TensorProto& B = weights[node->output(0)];

                const int h = B.dims(0);
                const int w = B.dims(1);

                std::vector<float> permuted_data;
                permuted_data.reserve((size_t)h * w);
                // weight data may live in raw_data or in the float_data field
                const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();

                // column-major walk produces the transposed layout
                for (int j = 0; j < w; j++)
                {
                    for (int k = 0; k < h; k++)
                    {
                        float vb = bptr[k * w + j];
                        permuted_data.push_back(vb);
                    }
                }

                // swap the dims to match the permuted data
                B.set_dims(0, w);
                B.set_dims(1, h);

                if (B.has_raw_data())
                {
                    B.set_raw_data(permuted_data.data(), permuted_data.size() * sizeof(float));
                }
                else
                {
                    for (int j = 0; j < (int)permuted_data.size(); j++)
                        B.set_float_data(j, permuted_data[j]);
                }
            }

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;

            reduced_node_count += 1;
            // NOTE(review): only this one node was consumed, yet i += 1 makes
            // the loop skip the following node — confirm this is intended
            i += 1;
        }
    }
}
479 
// Recognize the Reshape - Transpose - Reshape channel-shuffle pattern and
// collapse it into a single ShuffleChannel node (forward or reverse style).
// The first two nodes become no-ops; the final Reshape is rewritten in place.
static void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // ShuffleChannel <= Reshape - Transpose - Reshape
        // ShuffleChannel <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            // the intermediate blob must have exactly one consumer
            if (node_reference[node->output(0)] != 1)
                continue;

            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end())
                    continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // 1 groups channels_per_group, height, width
            // reverse style = channels_per_group, groups, height * width
            if (shape.size() != 5 && shape.size() != 3)
                continue;

            if (shape.size() == 5 && shape[0] != 1)
                continue;

            // need two following nodes to complete the pattern
            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // an interleaved Constant (e.g. the shape input) is tolerated
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // 0 2 1 3 4
            // reverse style = 1 0 2
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 5 && perm.size() != 3)
                continue;

            if (perm.size() == 5 && (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3 || perm[4] != 4))
                continue;

            if (perm.size() == 3 && (perm[0] != 1 || perm[1] != 0 || perm[2] != 2))
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end())
                    continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // 1, -1, height, width
            // reverse style = group, -1, channels_per_group, height, width
            if (shape3.size() != 4 && shape3.size() != 5)
                continue;

            if (shape3.size() == 4 && (shape3[0] != 1 || (shape3[1] != -1 && shape3[1] != shape[1] * shape[2])))
                continue;

            if (shape3.size() == 5 && (shape3[0] != shape[1] || shape3[2] != shape[0] || shape3[3] * shape3[4] != shape[2]))
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            // the intermediate blobs disappear from the graph
            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // rewrite the final Reshape into the fused ShuffleChannel
            node3->set_op_type("ShuffleChannel");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("group");
            attr_group->set_i(shape[1]);

            onnx::AttributeProto* attr_reverse = node3->add_attribute();
            attr_reverse->set_name("reverse");
            attr_reverse->set_i(shape.size() == 3);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
607 
// Recognize a reverse-style ShuffleChannel feeding two Gather(axis=0) nodes
// with indices 0 and 1, and replace the pair with a single Split node along
// the channel axis.
static void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1)
        if (node->op_type() == "ShuffleChannel")
        {
            // reverse = 1
            int reverse = get_node_attr_i(*node, "reverse");
            if (reverse != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            if (node2->op_type() != "Gather" || node3->op_type() != "Gather")
                continue;

            // both gathers must consume the shuffle output
            if (node2->input(0) != node->output(0) || node3->input(0) != node->output(0))
                continue;

            // axis = 0
            int gather2_axis = get_node_attr_i(*node2, "axis");
            if (gather2_axis != 0)
                continue;

            // indices = 0
            if (weights.find(node2->input(1)) == weights.end())
                continue;

            std::vector<int> gather2_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
            if (gather2_indices.size() != 1 || gather2_indices[0] != 0)
                continue;

            // axis = 0
            int gather3_axis = get_node_attr_i(*node3, "axis");
            if (gather3_axis != 0)
                continue;

            // indices = 1
            if (weights.find(node3->input(1)) == weights.end())
                continue;

            std::vector<int> gather3_indices = get_node_attr_from_input_ai(weights[node3->input(1)]);
            if (gather3_indices.size() != 1 || gather3_indices[0] != 1)
                continue;

            // reduce
            node2->set_op_type("noop_reducedncnn");

            node_reference[node->output(0)] -= 2;
            node_reference[node2->input(1)] -= 1;
            node_reference[node3->input(1)] -= 1;

            // rewrite node3 into a Split with two outputs:
            // output 0 takes node2's old output name, output 1 keeps node3's
            node3->set_op_type("Split");
            node3->clear_input();
            node3->add_input(node->output(0));
            node3->add_output(node3->output(0));
            node3->set_output(0, node2->output(0));

            node3->clear_attribute();
            onnx::AttributeProto* attr_axis = node3->add_attribute();
            attr_axis->set_name("axis");
            attr_axis->set_i(1);

            reduced_node_count += 1;
            i += 1;
        }
    }
}
684 
// Recognize hardswish patterns and collapse them into one HardSwish node:
// first pass matches Add(+3) - Clip(0,6) - Mul(x,) - Div(/6) (or Mul(*1/6),
// with an optional interleaved Constant); second pass matches
// HardSigmoid - Mul. Matched prefix nodes become no-ops and the final node
// is rewritten in place.
static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Div(/6)
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Mul(*(1/6))
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Div(/6)
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Mul(*(1/6))
        //     out = x * F.relu6(x + 3, inplace=True) / 6
        if (node->op_type() == "Add")
        {
            // the intermediate blob must have exactly one consumer
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 3 >= node_count)
                continue;

            // the added constant must be a known weight
            if (weights.find(node->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& add_three = weights[node->input(1)];
            // must be a scalar
            if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1)
                continue;

            float constant_add_three = get_node_attr_from_input_f(add_three);
            if (constant_add_three != 3.f)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);

            // an interleaved Constant (e.g. the divisor) is tolerated
            if (node4->op_type() == "Constant")
            {
                if (i + 4 >= node_count)
                    continue;

                node4 = mutable_graph->mutable_node(i + 4);
            }

            if (node2->op_type() != "Clip" || node3->op_type() != "Mul" || (node4->op_type() != "Div" && node4->op_type() != "Mul"))
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // Clip bounds may be attributes (opset < 11) or inputs (opset 11+)
            float relu6_min;
            float relu6_max;
            if (node2->input_size() == 1)
            {
                relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
                relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
            }
            else
            {
                const onnx::TensorProto& min_tp = weights[node2->input(1)];
                const onnx::TensorProto& max_tp = weights[node2->input(2)];

                relu6_min = get_node_attr_from_input_f(min_tp);
                relu6_max = get_node_attr_from_input_f(max_tp);
            }
            if (relu6_min != 0.f || relu6_max != 6.f)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            // Mul must combine the original input with the clipped branch
            if (node3->input(0) != node->input(0) || node3->input(1) != node2->output(0))
                continue;

            if (weights.find(node4->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& div_six = weights[node4->input(1)];
            // must be a scalar
            if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1)
                continue;

            // Div expects /6, Mul expects *(1/6)
            float constant_div_six = get_node_attr_from_input_f(div_six);
            if (node4->op_type() == "Div" && constant_div_six != 6.f)
                continue;
            if (node4->op_type() == "Mul" && constant_div_six != 1 / 6.f)
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->input(1)] -= 1;
            node_reference[node->output(0)] -= 1;
            if (node2->input_size() == 3)
            {
                node_reference[node2->input(1)] -= 1;
                node_reference[node2->input(2)] -= 1;
            }
            node_reference[node2->output(0)] -= 1;
            node_reference[node3->output(0)] -= 1;
            node_reference[node4->input(1)] -= 1;

            // the intermediate blobs disappear from the graph
            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));

            // rewrite the final node into the fused HardSwish:
            // out = x * clamp(alpha * x + beta, 0, 1) with alpha=1/6, beta=1/2
            node4->set_op_type("HardSwish");
            node4->clear_input();
            node4->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node4->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(1.f / 6.f);

            onnx::AttributeProto* attr_beta = node4->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(3.f / 6.f);

            reduced_node_count += 3;
            i += 3;
        }
    }

    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSwish <= HardSigmoid - Mul
        //     out = x * hsigmoid(x)
        if (node->op_type() == "HardSigmoid")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            // carry the hsigmoid coefficients over to the fused node
            float alpha = get_node_attr_f(*node, "alpha", 0.2f);
            float beta = get_node_attr_f(*node, "beta", 0.5f);

            if (i + 1 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);

            if (node2->op_type() != "Mul")
                continue;

            // Mul must combine the original input with the hsigmoid branch
            if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0))
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->output(0)] -= 1;

            blob_names.erase(node->output(0));

            node2->set_op_type("HardSwish");
            node2->clear_input();
            node2->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node2->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(alpha);

            onnx::AttributeProto* attr_beta = node2->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(beta);

            reduced_node_count += 1;
            i += 1;
        }
    }
}
859 
// Fuse the subgraph  Add(+3) -> Clip(0,6) -> Div(/6)  (or Mul(*(1/6)),
// optionally with an interposed Constant node before the final op) into a
// single HardSigmoid node with alpha = 1/6 and beta = 0.5.
// Matched nodes are renamed to "noop_reducedncnn", their entries in
// node_reference are decremented, and intermediate blob names are erased.
static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSigmoid <= Add(+3) - Clip(0,6) - Div(/6)
        // HardSigmoid <= Add(+3) - Clip(0,6) - Mul(*(1/6))
        // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Div(/6)
        // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Mul(*(1/6))
        //     out = F.relu6(x + 3, inplace=True) / 6
        if (node->op_type() == "Add")
        {
            // the Add output must have exactly one consumer
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            // second Add operand must be a constant initializer
            if (weights.find(node->input(1)) == weights.end())
                continue;

            // and that constant must be a scalar with value 3
            const onnx::TensorProto& add_three = weights[node->input(1)];
            if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1)
                continue;

            float constant_add_three = get_node_attr_from_input_f(add_three);
            if (constant_add_three != 3.f)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // skip over an interposed Constant node before the Div/Mul
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Clip" || (node3->op_type() != "Div" && node3->op_type() != "Mul"))
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // Clip bounds come either from attributes (single-input form) or
            // from constant inputs; they must be exactly [0, 6] (relu6)
            float relu6_min;
            float relu6_max;
            if (node2->input_size() == 1)
            {
                relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
                relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
            }
            else
            {
                const onnx::TensorProto& min_tp = weights[node2->input(1)];
                const onnx::TensorProto& max_tp = weights[node2->input(2)];

                relu6_min = get_node_attr_from_input_f(min_tp);
                relu6_max = get_node_attr_from_input_f(max_tp);
            }
            if (relu6_min != 0.f || relu6_max != 6.f)
                continue;

            // the divisor/multiplier must be a scalar constant: /6 or *(1/6)
            if (weights.find(node3->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& div_six = weights[node3->input(1)];
            if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1)
                continue;

            float constant_div_six = get_node_attr_from_input_f(div_six);
            if (node3->op_type() == "Div" && constant_div_six != 6.f)
                continue;
            if (node3->op_type() == "Mul" && constant_div_six != 1 / 6.f)
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            node_reference[node->input(1)] -= 1;
            node_reference[node->output(0)] -= 1;
            if (node2->input_size() == 3)
            {
                // constant-input Clip also releases its min/max tensors
                node_reference[node2->input(1)] -= 1;
                node_reference[node2->input(2)] -= 1;
            }
            node_reference[node2->output(0)] -= 1;
            node_reference[node3->input(1)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // rewrite the final Div/Mul in place as HardSigmoid(x) = x/6 + 0.5
            node3->set_op_type("HardSigmoid");
            node3->clear_input();
            node3->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node3->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(1.f / 6.f);

            onnx::AttributeProto* attr_beta = node3->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(3.f / 6.f);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
973 
fuse_swish(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)974 static void fuse_swish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
975 {
976     int node_count = mutable_graph->node_size();
977     for (int i = 0; i < node_count; i++)
978     {
979         onnx::NodeProto* node = mutable_graph->mutable_node(i);
980 
981         // Swish <= Sigmoid - Mul
982         //     x * torch.sigmoid(x)
983         if (node->op_type() == "Sigmoid")
984         {
985             if (node_reference[node->output(0)] != 1)
986                 continue;
987 
988             if (i + 1 >= node_count)
989                 continue;
990 
991             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
992 
993             if (node2->op_type() != "Mul")
994                 continue;
995 
996             if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0))
997                 continue;
998 
999             // reduce
1000             node->set_op_type("noop_reducedncnn");
1001 
1002             node_reference[node->input(0)] -= 1;
1003             node_reference[node->output(0)] -= 1;
1004 
1005             blob_names.erase(node->output(0));
1006 
1007             node2->set_op_type("Swish");
1008             node2->clear_input();
1009             node2->add_input(node->input(0));
1010 
1011             reduced_node_count += 1;
1012             i += 1;
1013         }
1014     }
1015 }
1016 
fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1017 static void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1018 {
1019     int node_count = mutable_graph->node_size();
1020     for (int i = 0; i < node_count; i++)
1021     {
1022         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1023 
1024         // BatchNormalization <= Unsqueeze - BatchNormalization - Squeeze
1025         if (node->op_type() == "Unsqueeze")
1026         {
1027             if (node_reference[node->output(0)] != 1)
1028                 continue;
1029 
1030             if (i + 2 >= node_count)
1031                 continue;
1032 
1033             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1034             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1035 
1036             if (node2->op_type() != "BatchNormalization" || node3->op_type() != "Squeeze")
1037                 continue;
1038 
1039             if (node_reference[node2->output(0)] != 1)
1040                 continue;
1041 
1042             if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0))
1043                 continue;
1044 
1045             // reduce
1046             node->set_op_type("noop_reducedncnn");
1047             node3->set_op_type("noop_reducedncnn");
1048 
1049             node_reference[node->output(0)] -= 1;
1050             node_reference[node2->output(0)] -= 1;
1051 
1052             blob_names.erase(node->output(0));
1053             blob_names.erase(node2->output(0));
1054 
1055             node2->set_input(0, node->input(0));
1056             node2->set_output(0, node3->output(0));
1057 
1058             reduced_node_count += 2;
1059             i += 2;
1060         }
1061     }
1062 }
1063 
fuse_unsqueeze_prelu(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1064 static void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1065 {
1066     int node_count = mutable_graph->node_size();
1067     for (int i = 0; i < node_count; i++)
1068     {
1069         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1070 
1071         // PReLU <= Unsqueeze - PReLU
1072         if (node->op_type() == "Unsqueeze")
1073         {
1074             // check weight
1075             if (weights.find(node->input(0)) == weights.end())
1076                 continue;
1077 
1078             onnx::TensorProto& B = weights[node->input(0)];
1079             if (B.dims_size() != 1)
1080                 continue;
1081 
1082             if (node_reference[node->output(0)] != 1)
1083                 continue;
1084 
1085             // axes = (1, 2)
1086             std::vector<int> axes = get_node_attr_ai(*node, "axes");
1087             if (axes.size() != 2)
1088                 continue;
1089             if (axes[0] != 1 || axes[1] != 2)
1090                 continue;
1091 
1092             if (i + 1 >= node_count)
1093                 continue;
1094 
1095             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1096 
1097             if (node2->op_type() != "PRelu")
1098                 continue;
1099 
1100             if (node2->input(1) != node->output(0))
1101                 continue;
1102 
1103             // reduce
1104             node->set_op_type("noop_reducedncnn");
1105 
1106             node_reference[node->output(0)] -= 1;
1107 
1108             blob_names.erase(node->output(0));
1109 
1110             node2->set_input(1, node->input(0));
1111 
1112             reduced_node_count += 1;
1113             i += 1;
1114         }
1115     }
1116 }
1117 
// Fuse L2 normalization:  X / expand(clip(reduce_l2(X, axes=(1,)), min=eps))
// into a single Normalize node carrying the Clip lower bound as "eps".
// An optional Shape node may sit between Clip and Expand (dynamic expand
// shape); the chain is then one node longer.
static void fuse_normalize(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Normalize <= X - ReduceL2 - Clip - Expand - Div
        // Normalize <= X - ReduceL2 - Clip - Shape - Expand - Div
        if (node->op_type() == "ReduceL2")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            // axes = (1) -- only channel-axis reduction is fused
            std::vector<int> axes = get_node_attr_ai(*node, "axes");
            if (axes.size() != 1)
                continue;
            if (axes[0] != 1)
                continue;

            if (i + 3 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);

            // when a Shape node is interposed, shift node3/node4 one slot down
            bool has_shape_node = node3->op_type() == "Shape";
            onnx::NodeProto* node_shape = 0;
            if (has_shape_node)
            {
                if (i + 4 >= node_count)
                    continue;

                node_shape = node3;
                node3 = mutable_graph->mutable_node(i + 3);
                node4 = mutable_graph->mutable_node(i + 4);
            }

            if (node2->op_type() != "Clip" || node3->op_type() != "Expand" || node4->op_type() != "Div")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            // the chain must be straight and Div must divide X by the expanded norm
            if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)
                    || node4->input(0) != node->input(0) || node4->input(1) != node3->output(0))
                continue;

            if (has_shape_node)
            {
                // Shape must read X, and Expand's shape input must come from it
                if (node_shape->input(0) != node->input(0) || node3->input(1) != node_shape->output(0))
                    continue;
            }

            // +eps -- the Clip lower bound becomes the Normalize epsilon;
            // it comes from an attribute (single-input Clip) or a constant input
            float clip_min;
            if (node2->input_size() == 1)
            {
                clip_min = get_node_attr_f(*node2, "min", -FLT_MAX);
            }
            else
            {
                const onnx::TensorProto& min_tp = weights[node2->input(1)];

                clip_min = get_node_attr_from_input_f(min_tp);
            }

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            if (has_shape_node)
            {
                node_shape->set_op_type("noop_reducedncnn");
            }
            node3->set_op_type("noop_reducedncnn");

            // X was consumed by ReduceL2 + Div (+ Shape); the fused Normalize
            // keeps exactly one reference to it
            node_reference[node->input(0)] -= has_shape_node ? 2 : 1;
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (has_shape_node)
            {
                node_reference[node_shape->output(0)] -= 1;
            }
            node_reference[node3->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            if (has_shape_node)
            {
                blob_names.erase(node_shape->output(0));
            }
            blob_names.erase(node3->output(0));

            // rewrite the final Div in place as Normalize(X) with eps attribute
            node4->set_op_type("Normalize");
            node4->clear_input();
            node4->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node4->add_attribute();
            attr_alpha->set_name("eps");
            attr_alpha->set_f(clip_min);

            reduced_node_count += has_shape_node ? 4 : 3;
            i += has_shape_node ? 4 : 3;
        }
    }
}
1233 
// Fuse the GroupNorm pattern emitted by PyTorch:
//   Reshape(0,G,-1) -> InstanceNormalization(S=1,B=0) -> Reshape(1,C,w,h) -> Mul -> Add
// into a single GroupNorm node with attributes groups/channels/epsilon/affine
// and inputs (X, per-channel scale, per-channel bias).
static void fuse_groupnorm(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // GroupNorm <= X - Reshape - InstanceNormalization - Reshape - Mul - Add
        if (node->op_type() == "Reshape")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            // target shape comes from an attribute or from a constant input
            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end())
                    continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // 0, group, -1  -- flatten each group's elements
            if (shape.size() != 3)
                continue;

            if (shape[0] != 0 || shape[2] != -1)
                continue;

            int groups = shape[1];

            if (i + 4 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);

            if (node2->op_type() != "InstanceNormalization" || node3->op_type() != "Reshape" || node4->op_type() != "Mul" || node5->op_type() != "Add")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            if (node_reference[node4->output(0)] != 1)
                continue;

            // the five nodes must form a straight chain
            if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)
                    || node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0))
                continue;

            // +eps
            float eps = get_node_attr_f(*node2, "epsilon", 1e-05f);

            // InstanceNormalization S=1 B=0 -- its own scale/bias must be the
            // identity so the real affine comes from the trailing Mul/Add
            std::vector<float> S = get_node_attr_from_input_af(weights[node2->input(1)]);
            std::vector<float> B = get_node_attr_from_input_af(weights[node2->input(2)]);
            if ((int)S.size() != groups || (int)B.size() != groups)
                continue;

            bool instancenorm_affine = false;
            for (int j = 0; j < groups; j++)
            {
                if (S[j] != 1.f || B[j] != 0.f)
                {
                    instancenorm_affine = true;
                    break;
                }
            }

            if (instancenorm_affine)
                continue;

            // second Reshape target: back to (1, channels, w, h)
            std::vector<int> shape2;
            if (node3->input_size() == 1)
            {
                shape2 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end())
                    continue;

                shape2 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // 1, channels, w, h
            if (shape2.size() != 4)
                continue;

            if (shape2[0] != 1)
                continue;

            int channels = shape2[1];

            // affine -- either the scalar identity (1, 0) or per-channel vectors
            std::vector<float> affine_S = get_node_attr_from_input_af(weights[node4->input(1)]);
            std::vector<float> affine_B = get_node_attr_from_input_af(weights[node5->input(1)]);
            if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && affine_B[0] == 0.f)
            {
                // no affine
            }
            else if ((int)affine_S.size() != channels && (int)affine_B.size() != channels)
            {
                // we only allow per-channel affine
                continue;
            }

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->input(1)] -= 1;
            node_reference[node2->input(2)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }
            node_reference[node3->output(0)] -= 1;
            node_reference[node4->output(0)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));
            blob_names.erase(node4->output(0));

            // keep the Mul/Add weights as the fused node's scale and bias inputs
            std::string affine_scale = node4->input(1);
            std::string affine_bias = node5->input(1);

            node5->set_op_type("GroupNorm");
            node5->clear_input();
            node5->add_input(node->input(0));
            node5->add_input(affine_scale);
            node5->add_input(affine_bias);

            onnx::AttributeProto* attr_groups = node5->add_attribute();
            attr_groups->set_name("groups");
            attr_groups->set_i(groups);

            onnx::AttributeProto* attr_channels = node5->add_attribute();
            attr_channels->set_name("channels");
            attr_channels->set_i(channels);

            onnx::AttributeProto* attr_eps = node5->add_attribute();
            attr_eps->set_name("epsilon");
            attr_eps->set_f(eps);

            onnx::AttributeProto* attr_affine = node5->add_attribute();
            attr_affine->set_name("affine");
            attr_affine->set_i(1);

            reduced_node_count += 4;
            i += 4;
        }
    }
}
1408 
// Fuse the decomposed LayerNorm pattern:
//   (x - mean(x)) / sqrt(mean((x - mean(x))^2) + eps)          [affine = 0]
// optionally followed by  * scale + bias                       [affine = 1]
// into a single LayerNorm node. Reduction axes must be (-1) or (-2, -1).
static void fuse_layernorm(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div
        // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div - Mul - Add
        if (node->op_type() == "ReduceMean")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            std::vector<int> axes = get_node_attr_ai(*node, "axes");

            // -1
            // -2 -1
            if (axes.size() != 1 && axes.size() != 2)
                continue;

            int normed_axes = (int)axes.size();
            if (normed_axes == 1 && axes[0] != -1)
                continue;
            if (normed_axes == 2 && (axes[0] != -2 || axes[1] != -1))
                continue;

            if (i + 6 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);

            if (node2->op_type() != "Sub" || node3->op_type() != "Pow" || node4->op_type() != "ReduceMean" || node5->op_type() != "Add" || node6->op_type() != "Sqrt" || node7->op_type() != "Div")
                continue;

            // Sub's output (x - mean) is consumed twice: by Pow and by Div
            if (node_reference[node2->output(0)] != 2)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            if (node_reference[node4->output(0)] != 1)
                continue;

            if (node_reference[node5->output(0)] != 1)
                continue;

            if (node_reference[node6->output(0)] != 1)
                continue;

            // verify the exact wiring of the seven-node chain
            if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0)
                    || node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0)
                    || node5->input(0) != node4->output(0) || node6->input(0) != node5->output(0)
                    || node7->input(0) != node2->output(0) || node7->input(1) != node6->output(0))
                continue;

            if (weights.find(node3->input(1)) == weights.end())
                continue;

            // the Pow exponent must be the scalar constant 2
            const onnx::TensorProto& pow_two = weights[node3->input(1)];
            if (pow_two.dims_size() != 0 || get_tensor_proto_data_size(pow_two) != 1)
                continue;

            float constant_pow_two = get_node_attr_from_input_f(pow_two);
            if (constant_pow_two != 2.f)
                continue;

            std::vector<int> axes4 = get_node_attr_ai(*node4, "axes");

            // -1
            // -2 -1  -- the inner ReduceMean must use the same axes
            if ((int)axes4.size() != normed_axes)
                continue;

            if (normed_axes == 1 && axes4[0] != -1)
                continue;
            if (normed_axes == 2 && (axes4[0] != -2 || axes4[1] != -1))
                continue;

            if (weights.find(node5->input(1)) == weights.end())
                continue;

            // the Add operand must be a scalar constant: the epsilon
            const onnx::TensorProto& add_eps = weights[node5->input(1)];
            if (add_eps.dims_size() != 0 || get_tensor_proto_data_size(add_eps) != 1)
                continue;

            float eps = get_node_attr_from_input_f(add_eps);

            // probe for a trailing Mul + Add pair (elementwise affine);
            // the while-loop is a breakable if so any failed check falls
            // through to the non-affine fusion
            int affine = 0;
            while (i + 8 < node_count)
            {
                onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
                onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);

                if (node8->op_type() != "Mul" || node9->op_type() != "Add")
                    break;

                if (node_reference[node7->output(0)] != 1)
                    break;

                if (node_reference[node8->output(0)] != 1)
                    break;

                if (node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0))
                    break;

                // affine scale and bias must have matching lengths
                std::vector<float> affine_S = get_node_attr_from_input_af(weights[node8->input(1)]);
                std::vector<float> affine_B = get_node_attr_from_input_af(weights[node9->input(1)]);
                if (affine_S.size() != affine_B.size())
                    break;

                affine = 1;
                break;
            }

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");

            // release every matched edge ...
            node_reference[node->input(0)] -= 1;
            node_reference[node2->input(0)] -= 1;
            node_reference[node2->input(1)] -= 1;
            node_reference[node3->input(0)] -= 1;
            node_reference[node3->input(1)] -= 1;
            node_reference[node4->input(0)] -= 1;
            node_reference[node5->input(0)] -= 1;
            node_reference[node5->input(1)] -= 1;
            node_reference[node6->input(0)] -= 1;
            node_reference[node7->input(0)] -= 1;
            node_reference[node7->input(1)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));

            // ... then re-add the single reference the fused LayerNorm keeps on X
            node_reference[node->input(0)] += 1;

            if (affine == 0)
            {
                // rewrite the final Div in place as LayerNorm(X) without affine
                node7->set_op_type("LayerNorm");
                node7->clear_input();
                node7->add_input(node->input(0));

                onnx::AttributeProto* attr_eps = node7->add_attribute();
                attr_eps->set_name("epsilon");
                attr_eps->set_f(eps);

                onnx::AttributeProto* attr_affine = node7->add_attribute();
                attr_affine->set_name("affine");
                attr_affine->set_i(affine);

                reduced_node_count += 6;
                i += 6;
            }
            else // if (affine == 1)
            {
                onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
                onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);

                node7->set_op_type("noop_reducedncnn");
                node8->set_op_type("noop_reducedncnn");

                node_reference[node8->input(0)] -= 1;
                node_reference[node9->input(0)] -= 1;

                blob_names.erase(node7->output(0));
                blob_names.erase(node8->output(0));

                // keep the Mul/Add weights as the fused node's scale and bias
                std::string affine_scale = node8->input(1);
                std::string affine_bias = node9->input(1);

                // rewrite the final Add in place as the affine LayerNorm
                node9->set_op_type("LayerNorm");
                node9->clear_input();
                node9->add_input(node->input(0));
                node9->add_input(affine_scale);
                node9->add_input(affine_bias);

                onnx::AttributeProto* attr_eps = node9->add_attribute();
                attr_eps->set_name("epsilon");
                attr_eps->set_f(eps);

                onnx::AttributeProto* attr_affine = node9->add_attribute();
                attr_affine->set_name("affine");
                attr_affine->set_i(affine);

                reduced_node_count += 8;
                i += 8;
            }
        }
    }
}
1613 
// Fuse the dynamic-flatten idiom exported by PyTorch for x.view(x.size(0), -1):
//   X - Shape - Gather(axis=0, indices=0) - Constant(-1) - Unsqueeze - Unsqueeze - Concat - Reshape
// into a single ncnn Flatten operator.
//
// Matched interior nodes are renamed to "noop_reducedncnn" so a later pass can
// drop them; node_reference counts and blob_names are updated to match.  The
// Constant node (node3) is deliberately left alone — see the commented lines
// below — because constants are cleaned up by a separate pass.
static void fuse_flatten(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Flatten <= X - Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat - Reshape
        if (node->op_type() == "Shape")
        {
            // the Shape output must feed exactly one consumer
            if (node_reference[node->output(0)] != 1)
                continue;

            // need six consecutive follow-up nodes to complete the pattern
            if (i + 6 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);

            if (node2->op_type() != "Gather" || node3->op_type() != "Constant" || node4->op_type() != "Unsqueeze" || node5->op_type() != "Unsqueeze"
                    || node6->op_type() != "Concat" || node7->op_type() != "Reshape")
                continue;

            // every intermediate blob must be consumed exactly once,
            // otherwise removing these nodes would break other consumers
            if (node_reference[node2->output(0)] != 1)
                continue;

            //             if (node_reference[node3->output(0)] != 1)
            //                 continue;

            if (node_reference[node4->output(0)] != 1)
                continue;

            if (node_reference[node5->output(0)] != 1)
                continue;

            if (node_reference[node6->output(0)] != 1)
                continue;

            // verify the nodes are actually chained together in this order
            // (consecutive position in the graph does not guarantee dataflow)
            if (node2->input(0) != node->output(0) || node4->input(0) != node2->output(0) || node5->input(0) != node3->output(0)
                    || node6->input(0) != node4->output(0) || node6->input(1) != node5->output(0)
                    || node7->input(0) != node->input(0) || node7->input(1) != node6->output(0))
                continue;

            // axis = 0
            int gather_axis = get_node_attr_i(*node2, "axis");
            if (gather_axis != 0)
                continue;

            // indices = 0  (Gather must pick out the batch dimension)
            if (weights.find(node2->input(1)) == weights.end())
                continue;

            std::vector<int> gather_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
            if (gather_indices.size() != 1 || gather_indices[0] != 0)
                continue;

            // axes = (0)
            std::vector<int> unsqueeze_axes = get_node_attr_ai(*node4, "axes");
            if (unsqueeze_axes.size() != 1)
                continue;
            if (unsqueeze_axes[0] != 0)
                continue;

            // axes = (0)
            std::vector<int> unsqueeze2_axes = get_node_attr_ai(*node5, "axes");
            if (unsqueeze2_axes.size() != 1)
                continue;
            if (unsqueeze2_axes[0] != 0)
                continue;

            // data = -1  (the flattened trailing dimension)
            if (weights.find(node5->input(0)) == weights.end())
                continue;

            std::vector<int> unsqueeze2_data = get_node_attr_from_input_ai(weights[node5->input(0)]);
            if (unsqueeze2_data.size() != 1 || unsqueeze2_data[0] != -1)
                continue;

            // axis = 0  (Concat builds the target shape (batch, -1))
            int concat_axis = get_node_attr_i(*node6, "axis");
            if (concat_axis != 0)
                continue;

            // reduce: mark the interior nodes as removable no-ops
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            //             node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");

            // drop one reference from every blob the removed nodes consumed
            node_reference[node->input(0)] -= 1;
            node_reference[node->output(0)] -= 1;
            node_reference[node2->input(1)] -= 1;
            node_reference[node2->output(0)] -= 1;
            //             node_reference[node3->output(0)] -= 1;
            node_reference[node4->output(0)] -= 1;
            node_reference[node5->input(0)] -= 1;
            node_reference[node5->output(0)] -= 1;
            node_reference[node6->output(0)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            //             blob_names.erase(node3->output(0));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));

            // repurpose the trailing Reshape as the fused Flatten,
            // reading directly from the original input X
            node7->set_op_type("Flatten");
            node7->clear_input();
            node7->add_input(node->input(0));

            reduced_node_count += 5;
            i += 5;
        }
    }
}
1735 
fuse_pixelshuffle(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1736 static void fuse_pixelshuffle(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1737 {
1738     int node_count = mutable_graph->node_size();
1739     for (int i = 0; i < node_count; i++)
1740     {
1741         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1742 
1743         // PixelShuffle <= Reshape - Transpose - Reshape
1744         // PixelShuffle <= Reshape - Transpose - Constant - Reshape
1745         if (node->op_type() == "Reshape")
1746         {
1747             if (node_reference[node->output(0)] != 1)
1748                 continue;
1749 
1750             std::vector<int> shape;
1751             if (node->input_size() == 1)
1752             {
1753                 shape = get_node_attr_ai(*node, "shape");
1754             }
1755             else
1756             {
1757                 // skip weight reshape
1758                 if (weights.find(node->input(1)) == weights.end())
1759                     continue;
1760 
1761                 shape = get_node_attr_from_input_ai(weights[node->input(1)]);
1762             }
1763 
1764             // -1, 3, upscale_factor, upscale_factor, height, width
1765             if (shape.size() != 6)
1766                 continue;
1767 
1768             if (shape[0] != 1 && shape[0] != -1)
1769                 continue;
1770 
1771             if (shape[2] != shape[3])
1772                 continue;
1773 
1774             if (i + 2 >= node_count)
1775                 continue;
1776 
1777             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1778             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1779 
1780             if (node3->op_type() == "Constant")
1781             {
1782                 if (i + 3 >= node_count)
1783                     continue;
1784 
1785                 node3 = mutable_graph->mutable_node(i + 3);
1786             }
1787 
1788             if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
1789                 continue;
1790 
1791             if (node_reference[node2->output(0)] != 1)
1792                 continue;
1793 
1794             // 0 1 4 2 5 3
1795             std::vector<int> perm = get_node_attr_ai(*node2, "perm");
1796             if (perm.size() != 6)
1797                 continue;
1798 
1799             if (perm[0] != 0 || perm[1] != 1 || perm[2] != 4 || perm[3] != 2 || perm[4] != 5 || perm[5] != 3)
1800                 continue;
1801 
1802             std::vector<int> shape3;
1803             if (node3->input_size() == 1)
1804             {
1805                 shape3 = get_node_attr_ai(*node3, "shape");
1806             }
1807             else
1808             {
1809                 // skip weight reshape
1810                 if (weights.find(node3->input(1)) == weights.end())
1811                     continue;
1812 
1813                 shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
1814             }
1815 
1816             // -1, 3, height, width
1817             if (shape3.size() != 4)
1818                 continue;
1819 
1820             if (shape3[0] != 1 && shape3[0] != -1)
1821                 continue;
1822 
1823             if (shape3[1] != shape[1] || shape3[2] != shape[2] * shape[4] || shape3[3] != shape[3] * shape[5])
1824                 continue;
1825 
1826             // reduce
1827             node->set_op_type("noop_reducedncnn");
1828             node2->set_op_type("noop_reducedncnn");
1829 
1830             if (node->input_size() == 2)
1831             {
1832                 node_reference[node->input(1)] -= 1;
1833             }
1834             node_reference[node->output(0)] -= 1;
1835             node_reference[node2->output(0)] -= 1;
1836             if (node3->input_size() == 2)
1837             {
1838                 node_reference[node3->input(1)] -= 1;
1839             }
1840 
1841             blob_names.erase(node->output(0));
1842             blob_names.erase(node2->output(0));
1843 
1844             node3->set_op_type("PixelShuffle");
1845             node3->set_input(0, node->input(0));
1846 
1847             onnx::AttributeProto* attr_group = node3->add_attribute();
1848             attr_group->set_name("scale_factor");
1849             attr_group->set_i(shape[2]);
1850 
1851             reduced_node_count += 2;
1852             i += 2;
1853         }
1854     }
1855 }
1856 
// Recognize the Reshape - Transpose - Reshape idiom that corresponds to a
// space-to-depth rearrangement and collapse it into a single ncnn Reorg
// operator (YOLO-style reorg with stride = block_size).  An optional Constant
// node between the Transpose and the second Reshape is tolerated.  Matched
// nodes are renamed to "noop_reducedncnn" for a later removal pass, and
// node_reference / blob_names are kept consistent.
static void fuse_reorg(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Reorg <= Reshape - Transpose - Reshape
        // Reorg <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            // the first reshape output must have a single consumer
            if (node_reference[node->output(0)] != 1)
                continue;

            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end())
                    continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // -1, 3, out_height, block_size, out_width, block_size
            if (shape.size() != 6)
                continue;

            if (shape[0] != 1 && shape[0] != -1)
                continue;

            // both block dimensions must carry the same stride
            if (shape[3] != shape[5])
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // tolerate an interleaved Constant node before the second Reshape
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // 0 1 3 5 2 4
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 6)
                continue;

            if (perm[0] != 0 || perm[1] != 1 || perm[2] != 3 || perm[3] != 5 || perm[4] != 2 || perm[5] != 4)
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end())
                    continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // -1, out_channels, out_height, out_width
            if (shape3.size() != 4)
                continue;

            if (shape3[0] != 1 && shape3[0] != -1)
                continue;

            // channels multiply by stride^2, spatial dims match the block layout
            if (shape3[1] != shape[1] * shape[3] * shape[5] || shape3[2] != shape[2] || shape3[3] != shape[4])
                continue;

            // reduce: mark the first two nodes as removable no-ops
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // repurpose the trailing Reshape as the fused Reorg
            node3->set_op_type("Reorg");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("stride");
            attr_group->set_i(shape[3]);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
1977 
fuse_expand_broadcast(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1978 static void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1979 {
1980     int node_count = mutable_graph->node_size();
1981     for (int i = 0; i < node_count; i++)
1982     {
1983         onnx::NodeProto* node = mutable_graph->mutable_node(i);
1984 
1985         // Add/Sub/Mul/Div/Min/Max <= Expand - Add/Sub/Mul/Div/Min/Max
1986         if (node->op_type() == "Expand")
1987         {
1988             if (node_reference[node->output(0)] != 1)
1989                 continue;
1990 
1991             if (i + 1 >= node_count)
1992                 continue;
1993 
1994             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1995 
1996             if (node2->op_type() != "Add" && node2->op_type() != "Sub" && node2->op_type() != "Mul" && node2->op_type() != "Div" && node2->op_type() != "Min" && node2->op_type() != "Max")
1997                 continue;
1998 
1999             if (node2->input(1) != node->output(0) && node2->input(0) != node->output(0))
2000                 continue;
2001 
2002             // reduce
2003             node->set_op_type("noop_reducedncnn");
2004 
2005             node_reference[node->output(0)] -= 1;
2006             if (node->input_size() == 2)
2007             {
2008                 node_reference[node->input(1)] -= 1;
2009             }
2010 
2011             blob_names.erase(node->output(0));
2012 
2013             node2->set_input(1, node->input(0));
2014 
2015             reduced_node_count += 1;
2016             i += 1;
2017         }
2018     }
2019 }
2020 
fuse_lstm_gru_rnn(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)2021 static void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
2022 {
2023     int node_count = mutable_graph->node_size();
2024     for (int i = 0; i < node_count; i++)
2025     {
2026         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2027 
2028         // LSTM(bi) <= LSTM(bi) - Transpose - Reshape - Transpose
2029         if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
2030         {
2031             if (node_reference[node->output(0)] != 1)
2032                 continue;
2033 
2034             if (i + 2 >= node_count)
2035                 continue;
2036 
2037             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2038             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
2039 
2040             if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
2041                 continue;
2042 
2043             if (node_reference[node2->output(0)] != 1)
2044                 continue;
2045 
2046             if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0))
2047                 continue;
2048 
2049             std::string direction = get_node_attr_s(*node, "direction");
2050             if (direction != "bidirectional")
2051                 continue;
2052 
2053             // 0 2 1 3
2054             std::vector<int> perm = get_node_attr_ai(*node2, "perm");
2055             if (perm.size() != 4)
2056                 continue;
2057 
2058             if (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3)
2059                 continue;
2060 
2061             std::vector<int> shape;
2062             if (node3->input_size() == 1)
2063             {
2064                 shape = get_node_attr_ai(*node3, "shape");
2065             }
2066             else
2067             {
2068                 // skip weight reshape
2069                 if (weights.find(node3->input(1)) == weights.end())
2070                     continue;
2071 
2072                 shape = get_node_attr_from_input_ai(weights[node3->input(1)]);
2073             }
2074 
2075             // 0 0 -1
2076             if (shape.size() != 3)
2077                 continue;
2078 
2079             if (shape[0] != 0 || shape[1] != 0 || shape[2] != -1)
2080                 continue;
2081 
2082             // reduce
2083             node2->set_op_type("noop_reducedncnn");
2084             node3->set_op_type("noop_reducedncnn");
2085 
2086             node_reference[node->output(0)] -= 1;
2087             node_reference[node2->output(0)] -= 1;
2088             if (node3->input_size() == 2)
2089             {
2090                 node_reference[node3->input(1)] -= 1;
2091             }
2092 
2093             blob_names.erase(node->output(0));
2094             blob_names.erase(node2->output(0));
2095 
2096             node->set_output(0, node3->output(0));
2097 
2098             reduced_node_count += 2;
2099             i += 2;
2100 
2101             if (i + 1 < node_count)
2102             {
2103                 if (node_reference[node3->output(0)] != 1)
2104                     continue;
2105 
2106                 onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 1);
2107 
2108                 if (node4->op_type() != "Transpose")
2109                     continue;
2110 
2111                 if (node4->input(0) != node->output(0))
2112                     continue;
2113 
2114                 // 1 0 2
2115                 std::vector<int> perm4 = get_node_attr_ai(*node4, "perm");
2116                 if (perm4.size() != 3)
2117                     continue;
2118 
2119                 if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2)
2120                     continue;
2121 
2122                 // reduce
2123                 node4->set_op_type("noop_reducedncnn");
2124 
2125                 node_reference[node->output(0)] -= 1;
2126 
2127                 blob_names.erase(node->output(0));
2128 
2129                 node->clear_output();
2130                 node->add_output(node4->output(0));
2131 
2132                 reduced_node_count += 1;
2133                 i += 1;
2134             }
2135         }
2136     }
2137 
2138     for (int i = 0; i < node_count; i++)
2139     {
2140         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2141 
2142         // LSTM(uni) <= LSTM(uni) - Squeeze - Transpose
2143         if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
2144         {
2145             if (node_reference[node->output(0)] != 1)
2146                 continue;
2147 
2148             if (i + 1 >= node_count)
2149                 continue;
2150 
2151             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2152 
2153             if (node2->op_type() != "Squeeze")
2154                 continue;
2155 
2156             if (node2->input(0) != node->output(0))
2157                 continue;
2158 
2159             std::string direction = get_node_attr_s(*node, "direction");
2160             if (direction == "bidirectional")
2161                 continue;
2162 
2163             // 1
2164             std::vector<int> axes = get_node_attr_ai(*node2, "axes");
2165             if (axes.size() != 1)
2166                 continue;
2167 
2168             if (axes[0] != 1)
2169                 continue;
2170 
2171             // reduce
2172             node2->set_op_type("noop_reducedncnn");
2173 
2174             node_reference[node->output(0)] -= 1;
2175 
2176             blob_names.erase(node->output(0));
2177 
2178             node->set_output(0, node2->output(0));
2179 
2180             reduced_node_count += 1;
2181             i += 1;
2182 
2183             if (i + 1 < node_count)
2184             {
2185                 if (node_reference[node2->output(0)] != 1)
2186                     continue;
2187 
2188                 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 1);
2189 
2190                 if (node3->op_type() != "Transpose")
2191                     continue;
2192 
2193                 if (node3->input(0) != node->output(0))
2194                     continue;
2195 
2196                 // 1 0 2
2197                 std::vector<int> perm4 = get_node_attr_ai(*node3, "perm");
2198                 if (perm4.size() != 3)
2199                     continue;
2200 
2201                 if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2)
2202                     continue;
2203 
2204                 // reduce
2205                 node3->set_op_type("noop_reducedncnn");
2206 
2207                 node_reference[node->output(0)] -= 1;
2208 
2209                 blob_names.erase(node->output(0));
2210 
2211                 node->clear_output();
2212                 node->add_output(node3->output(0));
2213 
2214                 reduced_node_count += 1;
2215                 i += 1;
2216             }
2217         }
2218     }
2219 
2220     for (int i = 0; i < node_count; i++)
2221     {
2222         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2223 
2224         // LSTM <= Transpose - LSTM
2225         if (node->op_type() == "Transpose")
2226         {
2227             if (node_reference[node->output(0)] != 1)
2228                 continue;
2229 
2230             // 1 0 2
2231             std::vector<int> perm = get_node_attr_ai(*node, "perm");
2232             if (perm.size() != 3)
2233                 continue;
2234 
2235             if (perm[0] != 1 || perm[1] != 0 || perm[2] != 2)
2236                 continue;
2237 
2238             if (i + 1 >= node_count)
2239                 continue;
2240 
2241             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2242 
2243             if (node2->op_type() != "LSTM" && node->op_type() != "GRU" && node->op_type() != "RNN")
2244                 continue;
2245 
2246             if (node2->input(0) != node->output(0))
2247                 continue;
2248 
2249             // reduce
2250             node->set_op_type("noop_reducedncnn");
2251 
2252             node_reference[node->output(0)] -= 1;
2253 
2254             blob_names.erase(node->output(0));
2255 
2256             node2->set_input(0, node->input(0));
2257 
2258             reduced_node_count += 1;
2259             i += 1;
2260         }
2261     }
2262 }
2263 
fuse_multiheadattention(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)2264 static void fuse_multiheadattention(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
2265 {
2266     int node_count = mutable_graph->node_size();
2267     for (int i = 0; i < node_count; i++)
2268     {
2269         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2270 
2271         // MultiHeadAttention <= MatMul(q) - Add
2272         //                      - MatMul(k) - Add
2273         //                      - MatMul(v) - Add
2274         //                      - Mul
2275         //                      - Reshape - Transpose
2276         //                      - Reshape - Reshape - Transpose - Transpose
2277         //                      - Gemm - Softmax - Gemm - Transpose - Reshape - MatMul - Add
2278         if (node->op_type() == "MatMul")
2279         {
2280             if (i + 19 >= node_count)
2281                 continue;
2282 
2283             if (node_reference[node->output(0)] != 1)
2284                 continue;
2285 
2286             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2287             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
2288             onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
2289             onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
2290             onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
2291             onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
2292             onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
2293             onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
2294             onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9);
2295             onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10);
2296             onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11);
2297             onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12);
2298             onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13);
2299             onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14);
2300             onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15);
2301             onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16);
2302             onnx::NodeProto* node18 = mutable_graph->mutable_node(i + 17);
2303             onnx::NodeProto* node19 = mutable_graph->mutable_node(i + 18);
2304             onnx::NodeProto* node20 = mutable_graph->mutable_node(i + 19);
2305 
2306             if (node2->op_type() != "Add" || node3->op_type() != "MatMul" || node4->op_type() != "Add" || node5->op_type() != "MatMul" || node6->op_type() != "Add" || node7->op_type() != "Mul" || node8->op_type() != "Reshape" || node9->op_type() != "Transpose" || node10->op_type() != "Reshape" || node11->op_type() != "Reshape" || node12->op_type() != "Transpose" || node13->op_type() != "Transpose" || node14->op_type() != "MatMul" || node15->op_type() != "Softmax" || node16->op_type() != "MatMul" || node17->op_type() != "Transpose" || node18->op_type() != "Reshape" || node19->op_type() != "MatMul" || node20->op_type() != "Add")
2307                 continue;
2308 
2309             if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 || node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 || node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 || node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 || node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 || node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 || node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 || node_reference[node16->output(0)] != 1 || node_reference[node17->output(0)] != 1 || node_reference[node18->output(0)] != 1 || node_reference[node19->output(0)] != 1)
2310                 continue;
2311 
2312             if (node2->input(0) != node->output(0) || node4->input(0) != node3->output(0) || node6->input(0) != node5->output(0) || node7->input(0) != node2->output(0) || node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0) || node10->input(0) != node4->output(0) || node11->input(0) != node6->output(0) || node12->input(0) != node11->output(0) || node13->input(0) != node10->output(0) || node14->input(0) != node9->output(0) || node14->input(1) != node13->output(0) || node15->input(0) != node14->output(0) || node16->input(0) != node15->output(0) || node16->input(1) != node12->output(0) || node17->input(0) != node16->output(0) || node18->input(0) != node17->output(0) || node19->input(0) != node18->output(0) || node20->input(0) != node19->output(0))
2313                 continue;
2314 
2315             std::vector<float> q_B = get_node_attr_from_input_af(weights[node2->input(1)]);
2316             std::vector<float> k_B = get_node_attr_from_input_af(weights[node4->input(1)]);
2317             std::vector<float> v_B = get_node_attr_from_input_af(weights[node6->input(1)]);
2318             std::vector<float> o_B = get_node_attr_from_input_af(weights[node20->input(1)]);
2319 
2320             if (q_B.size() != k_B.size() || q_B.size() != v_B.size() || q_B.size() != o_B.size())
2321                 continue;
2322 
2323             int embed_dim = q_B.size();
2324 
2325             // 1 0 2
2326             std::vector<int> perm9 = get_node_attr_ai(*node9, "perm");
2327             std::vector<int> perm12 = get_node_attr_ai(*node12, "perm");
2328             if (perm9.size() != 3 || perm12.size() != 3)
2329                 continue;
2330 
2331             if (perm9[0] != 1 || perm9[1] != 0 || perm9[2] != 2 || perm12[0] != 1 || perm12[1] != 0 || perm12[2] != 2)
2332                 continue;
2333 
2334             // 1 2 0
2335             std::vector<int> perm13 = get_node_attr_ai(*node13, "perm");
2336             if (perm13.size() != 3)
2337                 continue;
2338 
2339             if (perm13[0] != 1 || perm13[1] != 2 || perm13[2] != 0)
2340                 continue;
2341 
2342             // 1 0 2
2343             std::vector<int> perm17 = get_node_attr_ai(*node17, "perm");
2344             if (perm17.size() != 3)
2345                 continue;
2346 
2347             if (perm17[0] != 1 || perm17[1] != 0 || perm17[2] != 2)
2348                 continue;
2349 
2350             int softmax_axis = get_node_attr_i(*node15, "axis");
2351             if (softmax_axis != 2)
2352                 continue;
2353 
2354             // 1/-1, seqlen * num_heads, embed_dim / num_heads
2355             std::vector<int> shape8;
2356             std::vector<int> shape10;
2357             std::vector<int> shape11;
2358             if (node8->input_size() == 1)
2359             {
2360                 shape8 = get_node_attr_ai(*node8, "shape");
2361             }
2362             else
2363             {
2364                 // skip weight reshape
2365                 if (weights.find(node8->input(1)) == weights.end())
2366                     continue;
2367 
2368                 shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]);
2369             }
2370             if (node10->input_size() == 1)
2371             {
2372                 shape10 = get_node_attr_ai(*node10, "shape");
2373             }
2374             else
2375             {
2376                 // skip weight reshape
2377                 if (weights.find(node10->input(1)) == weights.end())
2378                     continue;
2379 
2380                 shape10 = get_node_attr_from_input_ai(weights[node10->input(1)]);
2381             }
2382             if (node11->input_size() == 1)
2383             {
2384                 shape11 = get_node_attr_ai(*node11, "shape");
2385             }
2386             else
2387             {
2388                 // skip weight reshape
2389                 if (weights.find(node11->input(1)) == weights.end())
2390                     continue;
2391 
2392                 shape11 = get_node_attr_from_input_ai(weights[node11->input(1)]);
2393             }
2394 
2395             if (shape8.size() != 3 || shape10.size() != 3 || shape11.size() != 3)
2396                 continue;
2397 
2398             if (shape8[1] != shape10[1] || shape8[1] != shape11[1] || shape8[2] != shape10[2] || shape8[2] != shape11[2])
2399                 continue;
2400 
2401             int num_heads = embed_dim / shape8[2];
2402 
2403             // 1, seqlen, embed_dim
2404             std::vector<int> shape18;
2405             if (node18->input_size() == 1)
2406             {
2407                 shape18 = get_node_attr_ai(*node18, "shape");
2408             }
2409             else
2410             {
2411                 // skip weight reshape
2412                 if (weights.find(node18->input(1)) == weights.end())
2413                     continue;
2414 
2415                 shape18 = get_node_attr_from_input_ai(weights[node18->input(1)]);
2416             }
2417 
2418             if (shape18.size() != 3)
2419                 continue;
2420 
2421             if (shape18[2] != embed_dim || shape18[1] * num_heads != shape8[1])
2422                 continue;
2423 
2424             // reduce
2425             node->set_op_type("noop_reducedncnn");
2426             node2->set_op_type("noop_reducedncnn");
2427             node3->set_op_type("noop_reducedncnn");
2428             node4->set_op_type("noop_reducedncnn");
2429             node5->set_op_type("noop_reducedncnn");
2430             node6->set_op_type("noop_reducedncnn");
2431             node7->set_op_type("noop_reducedncnn");
2432             node8->set_op_type("noop_reducedncnn");
2433             node9->set_op_type("noop_reducedncnn");
2434             node10->set_op_type("noop_reducedncnn");
2435             node11->set_op_type("noop_reducedncnn");
2436             node12->set_op_type("noop_reducedncnn");
2437             node13->set_op_type("noop_reducedncnn");
2438             node14->set_op_type("noop_reducedncnn");
2439             node15->set_op_type("noop_reducedncnn");
2440             node16->set_op_type("noop_reducedncnn");
2441             node17->set_op_type("noop_reducedncnn");
2442             node18->set_op_type("noop_reducedncnn");
2443             node19->set_op_type("noop_reducedncnn");
2444 
2445             node_reference[node2->input(0)] -= 1;
2446             node_reference[node4->input(0)] -= 1;
2447             node_reference[node6->input(0)] -= 1;
2448             node_reference[node7->input(0)] -= 1;
2449             node_reference[node7->input(1)] -= 1;
2450             node_reference[node8->input(0)] -= 1;
2451             if (node8->input_size() == 2)
2452             {
2453                 node_reference[node8->input(1)] -= 1;
2454             }
2455             node_reference[node9->input(0)] -= 1;
2456             node_reference[node10->input(0)] -= 1;
2457             if (node10->input_size() == 2)
2458             {
2459                 node_reference[node10->input(1)] -= 1;
2460             }
2461             node_reference[node11->input(0)] -= 1;
2462             if (node11->input_size() == 2)
2463             {
2464                 node_reference[node11->input(1)] -= 1;
2465             }
2466             node_reference[node12->input(0)] -= 1;
2467             node_reference[node13->input(0)] -= 1;
2468             node_reference[node14->input(0)] -= 1;
2469             node_reference[node14->input(1)] -= 1;
2470             node_reference[node15->input(0)] -= 1;
2471             node_reference[node16->input(0)] -= 1;
2472             node_reference[node16->input(1)] -= 1;
2473             node_reference[node17->input(0)] -= 1;
2474             node_reference[node18->input(0)] -= 1;
2475             if (node18->input_size() == 2)
2476             {
2477                 node_reference[node18->input(1)] -= 1;
2478             }
2479             node_reference[node19->input(0)] -= 1;
2480             node_reference[node20->input(0)] -= 1;
2481 
2482             blob_names.erase(node->output(0));
2483             blob_names.erase(node2->output(0));
2484             blob_names.erase(node3->output(0));
2485             blob_names.erase(node4->output(0));
2486             blob_names.erase(node5->output(0));
2487             blob_names.erase(node6->output(0));
2488             blob_names.erase(node7->output(0));
2489             blob_names.erase(node8->output(0));
2490             blob_names.erase(node9->output(0));
2491             blob_names.erase(node10->output(0));
2492             blob_names.erase(node11->output(0));
2493             blob_names.erase(node12->output(0));
2494             blob_names.erase(node13->output(0));
2495             blob_names.erase(node14->output(0));
2496             blob_names.erase(node15->output(0));
2497             blob_names.erase(node16->output(0));
2498             blob_names.erase(node17->output(0));
2499             blob_names.erase(node18->output(0));
2500             blob_names.erase(node19->output(0));
2501 
2502             std::string qw = node->input(1);
2503             std::string qb = node2->input(1);
2504             std::string kw = node3->input(1);
2505             std::string kb = node4->input(1);
2506             std::string vw = node5->input(1);
2507             std::string vb = node6->input(1);
2508             std::string ow = node19->input(1);
2509             std::string ob = node20->input(1);
2510 
2511             node20->set_op_type("MultiHeadAttention");
2512             node20->clear_input();
2513             node20->add_input(node->input(0));
2514             node20->add_input(node3->input(0));
2515             node20->add_input(node5->input(0));
2516             // q
2517             node20->add_input(qw);
2518             node20->add_input(qb);
2519             // k
2520             node20->add_input(kw);
2521             node20->add_input(kb);
2522             // v
2523             node20->add_input(vw);
2524             node20->add_input(vb);
2525             // out linear
2526             node20->add_input(ow);
2527             node20->add_input(ob);
2528 
2529             onnx::AttributeProto* attr_embed_dim = node20->add_attribute();
2530             attr_embed_dim->set_name("embed_dim");
2531             attr_embed_dim->set_i(embed_dim);
2532 
2533             onnx::AttributeProto* attr_num_heads = node20->add_attribute();
2534             attr_num_heads->set_name("num_heads");
2535             attr_num_heads->set_i(num_heads);
2536 
2537             reduced_node_count += 19;
2538             i += 19;
2539         }
2540     }
2541 
2542     for (int i = 0; i < node_count; i++)
2543     {
2544         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2545 
2546         // MultiHeadAttention <= MatMul(qkv) - Add - Split
2547         //                      - Mul
2548         //                      - Reshape - Transpose
2549         //                      - Reshape - Reshape - Transpose - Transpose
2550         //                      - Gemm - Softmax - Gemm - Transpose - Reshape - MatMul - Add
2551         if (node->op_type() == "MatMul")
2552         {
2553             if (i + 16 >= node_count)
2554                 continue;
2555 
2556             if (node_reference[node->output(0)] != 1)
2557                 continue;
2558 
2559             onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2560             onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
2561             onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
2562             onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
2563             onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
2564             onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
2565             onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
2566             onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
2567             onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9);
2568             onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10);
2569             onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11);
2570             onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12);
2571             onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13);
2572             onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14);
2573             onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15);
2574             onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16);
2575 
2576             if (node2->op_type() != "Add" || node3->op_type() != "Split" || node4->op_type() != "Mul" || node5->op_type() != "Reshape" || node6->op_type() != "Transpose" || node7->op_type() != "Reshape" || node8->op_type() != "Reshape" || node9->op_type() != "Transpose" || node10->op_type() != "Transpose" || node11->op_type() != "MatMul" || node12->op_type() != "Softmax" || node13->op_type() != "MatMul" || node14->op_type() != "Transpose" || node15->op_type() != "Reshape" || node16->op_type() != "MatMul" || node17->op_type() != "Add")
2577                 continue;
2578 
2579             if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 || node_reference[node3->output(1)] != 1 || node_reference[node3->output(2)] != 1 || node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 || node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 || node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 || node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 || node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 || node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 || node_reference[node16->output(0)] != 1)
2580                 continue;
2581 
2582             if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0) || node6->input(0) != node5->output(0) || node7->input(0) != node3->output(1) || node8->input(0) != node3->output(2) || node9->input(0) != node8->output(0) || node10->input(0) != node7->output(0) || node11->input(0) != node6->output(0) || node11->input(1) != node10->output(0) || node12->input(0) != node11->output(0) || node13->input(0) != node12->output(0) || node13->input(1) != node9->output(0) || node14->input(0) != node13->output(0) || node15->input(0) != node14->output(0) || node16->input(0) != node15->output(0) || node17->input(0) != node16->output(0))
2583                 continue;
2584 
2585             std::vector<float> qkv_B = get_node_attr_from_input_af(weights[node2->input(1)]);
2586             std::vector<float> o_B = get_node_attr_from_input_af(weights[node17->input(1)]);
2587 
2588             if (qkv_B.size() != o_B.size() * 3)
2589                 continue;
2590 
2591             int embed_dim = o_B.size();
2592 
2593             // 1 0 2
2594             std::vector<int> perm6 = get_node_attr_ai(*node6, "perm");
2595             std::vector<int> perm9 = get_node_attr_ai(*node9, "perm");
2596             if (perm6.size() != 3 || perm9.size() != 3)
2597                 continue;
2598 
2599             if (perm6[0] != 1 || perm6[1] != 0 || perm6[2] != 2 || perm9[0] != 1 || perm9[1] != 0 || perm9[2] != 2)
2600                 continue;
2601 
2602             // 1 2 0
2603             std::vector<int> perm10 = get_node_attr_ai(*node10, "perm");
2604             if (perm10.size() != 3)
2605                 continue;
2606 
2607             if (perm10[0] != 1 || perm10[1] != 2 || perm10[2] != 0)
2608                 continue;
2609 
2610             // 1 0 2
2611             std::vector<int> perm14 = get_node_attr_ai(*node14, "perm");
2612             if (perm14.size() != 3)
2613                 continue;
2614 
2615             if (perm14[0] != 1 || perm14[1] != 0 || perm14[2] != 2)
2616                 continue;
2617 
2618             int softmax_axis = get_node_attr_i(*node12, "axis");
2619             if (softmax_axis != 2)
2620                 continue;
2621 
2622             // 1/-1, seqlen * num_heads, embed_dim / num_heads
2623             std::vector<int> shape5;
2624             std::vector<int> shape7;
2625             std::vector<int> shape8;
2626             if (node5->input_size() == 1)
2627             {
2628                 shape5 = get_node_attr_ai(*node5, "shape");
2629             }
2630             else
2631             {
2632                 // skip weight reshape
2633                 if (weights.find(node5->input(1)) == weights.end())
2634                     continue;
2635 
2636                 shape5 = get_node_attr_from_input_ai(weights[node5->input(1)]);
2637             }
2638             if (node7->input_size() == 1)
2639             {
2640                 shape7 = get_node_attr_ai(*node7, "shape");
2641             }
2642             else
2643             {
2644                 // skip weight reshape
2645                 if (weights.find(node7->input(1)) == weights.end())
2646                     continue;
2647 
2648                 shape7 = get_node_attr_from_input_ai(weights[node7->input(1)]);
2649             }
2650             if (node8->input_size() == 1)
2651             {
2652                 shape8 = get_node_attr_ai(*node8, "shape");
2653             }
2654             else
2655             {
2656                 // skip weight reshape
2657                 if (weights.find(node8->input(1)) == weights.end())
2658                     continue;
2659 
2660                 shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]);
2661             }
2662 
2663             if (shape5.size() != 3 || shape7.size() != 3 || shape8.size() != 3)
2664                 continue;
2665 
2666             if (shape5[1] != shape7[1] || shape5[1] != shape8[1] || shape5[2] != shape7[2] || shape5[2] != shape8[2])
2667                 continue;
2668 
2669             int num_heads = embed_dim / shape5[2];
2670 
2671             // 1, seqlen, embed_dim
2672             std::vector<int> shape15;
2673             if (node15->input_size() == 1)
2674             {
2675                 shape15 = get_node_attr_ai(*node15, "shape");
2676             }
2677             else
2678             {
2679                 // skip weight reshape
2680                 if (weights.find(node15->input(1)) == weights.end())
2681                     continue;
2682 
2683                 shape15 = get_node_attr_from_input_ai(weights[node15->input(1)]);
2684             }
2685 
2686             if (shape15.size() != 3)
2687                 continue;
2688 
2689             if (shape15[2] != embed_dim || shape15[1] * num_heads != shape8[1])
2690                 continue;
2691 
2692             // reduce
2693             node->set_op_type("noop_reducedncnn");
2694             node2->set_op_type("noop_reducedncnn");
2695             node3->set_op_type("noop_reducedncnn");
2696             node4->set_op_type("noop_reducedncnn");
2697             node5->set_op_type("noop_reducedncnn");
2698             node6->set_op_type("noop_reducedncnn");
2699             node7->set_op_type("noop_reducedncnn");
2700             node8->set_op_type("noop_reducedncnn");
2701             node9->set_op_type("noop_reducedncnn");
2702             node10->set_op_type("noop_reducedncnn");
2703             node11->set_op_type("noop_reducedncnn");
2704             node12->set_op_type("noop_reducedncnn");
2705             node13->set_op_type("noop_reducedncnn");
2706             node14->set_op_type("noop_reducedncnn");
2707             node15->set_op_type("noop_reducedncnn");
2708             node16->set_op_type("noop_reducedncnn");
2709 
2710             node_reference[node2->input(0)] -= 1;
2711             node_reference[node3->input(0)] -= 1;
2712             node_reference[node4->input(0)] -= 1;
2713             node_reference[node4->input(1)] -= 1;
2714             node_reference[node5->input(0)] -= 1;
2715             if (node5->input_size() == 2)
2716             {
2717                 node_reference[node5->input(1)] -= 1;
2718             }
2719             node_reference[node6->input(0)] -= 1;
2720             node_reference[node7->input(0)] -= 1;
2721             if (node7->input_size() == 2)
2722             {
2723                 node_reference[node7->input(1)] -= 1;
2724             }
2725             node_reference[node8->input(0)] -= 1;
2726             if (node8->input_size() == 2)
2727             {
2728                 node_reference[node8->input(1)] -= 1;
2729             }
2730             node_reference[node9->input(0)] -= 1;
2731             node_reference[node10->input(0)] -= 1;
2732             node_reference[node11->input(0)] -= 1;
2733             node_reference[node11->input(1)] -= 1;
2734             node_reference[node12->input(0)] -= 1;
2735             node_reference[node13->input(0)] -= 1;
2736             node_reference[node13->input(1)] -= 1;
2737             node_reference[node14->input(0)] -= 1;
2738             node_reference[node15->input(0)] -= 1;
2739             if (node15->input_size() == 2)
2740             {
2741                 node_reference[node15->input(1)] -= 1;
2742             }
2743             node_reference[node16->input(0)] -= 1;
2744             node_reference[node17->input(0)] -= 1;
2745 
2746             blob_names.erase(node->output(0));
2747             blob_names.erase(node2->output(0));
2748             blob_names.erase(node3->output(0));
2749             blob_names.erase(node3->output(1));
2750             blob_names.erase(node3->output(2));
2751             blob_names.erase(node4->output(0));
2752             blob_names.erase(node5->output(0));
2753             blob_names.erase(node6->output(0));
2754             blob_names.erase(node7->output(0));
2755             blob_names.erase(node8->output(0));
2756             blob_names.erase(node9->output(0));
2757             blob_names.erase(node10->output(0));
2758             blob_names.erase(node11->output(0));
2759             blob_names.erase(node12->output(0));
2760             blob_names.erase(node13->output(0));
2761             blob_names.erase(node14->output(0));
2762             blob_names.erase(node15->output(0));
2763             blob_names.erase(node16->output(0));
2764 
2765             std::string qkvw = node->input(1);
2766             std::string qkvb = node2->input(1);
2767             std::string ow = node16->input(1);
2768             std::string ob = node17->input(1);
2769 
2770             node17->set_op_type("MultiHeadAttention");
2771             node17->clear_input();
2772             node17->add_input(node->input(0));
2773             // qkv
2774             node17->add_input(qkvw);
2775             node17->add_input(qkvb);
2776             // out linear
2777             node17->add_input(ow);
2778             node17->add_input(ob);
2779 
2780             onnx::AttributeProto* attr_embed_dim = node17->add_attribute();
2781             attr_embed_dim->set_name("embed_dim");
2782             attr_embed_dim->set_i(embed_dim);
2783 
2784             onnx::AttributeProto* attr_num_heads = node17->add_attribute();
2785             attr_num_heads->set_name("num_heads");
2786             attr_num_heads->set_i(num_heads);
2787 
2788             reduced_node_count += 16;
2789             i += 16;
2790         }
2791     }
2792 }
2793 
fuse_binaryop_with_scalar(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)2794 static void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
2795 {
2796     int node_count = mutable_graph->node_size();
2797     for (int i = 0; i < node_count; i++)
2798     {
2799         onnx::NodeProto* node = mutable_graph->mutable_node(i);
2800 
2801         // Add/Sub/Mul/Div/Min/Max/Pow
2802         if (node->op_type() == "Add" || node->op_type() == "Sub" || node->op_type() == "Mul" || node->op_type() == "Div" || node->op_type() == "Max" || node->op_type() == "Min" || node->op_type() == "Pow")
2803         {
2804             if (weights.find(node->input(1)) == weights.end())
2805                 continue;
2806 
2807             const onnx::TensorProto& scalar_b = weights[node->input(1)];
2808             if (scalar_b.dims_size() != 0 || get_tensor_proto_data_size(scalar_b) != 1)
2809                 continue;
2810 
2811             float b = get_node_attr_from_input_f(scalar_b);
2812 
2813             node_reference[node->input(1)] -= 1;
2814 
2815             std::string input = node->input(0);
2816 
2817             node->clear_input();
2818             node->add_input(input);
2819 
2820             onnx::AttributeProto* attr_with_scalar = node->add_attribute();
2821             attr_with_scalar->set_name("with_scalar");
2822             attr_with_scalar->set_i(1);
2823 
2824             onnx::AttributeProto* attr_b = node->add_attribute();
2825             attr_b->set_name("b");
2826             attr_b->set_f(b);
2827         }
2828     }
2829 }
2830 
main(int argc,char ** argv)2831 int main(int argc, char** argv)
2832 {
2833     const char* onnxpb = argv[1];
2834     const char* ncnn_prototxt = argc >= 4 ? argv[2] : "ncnn.param";
2835     const char* ncnn_modelbin = argc >= 4 ? argv[3] : "ncnn.bin";
2836 
2837     onnx::ModelProto model;
2838 
2839     // load
2840     bool s1 = read_proto_from_binary(onnxpb, &model);
2841     if (!s1)
2842     {
2843         fprintf(stderr, "read_proto_from_binary failed\n");
2844         return -1;
2845     }
2846 
2847     FILE* pp = fopen(ncnn_prototxt, "wb");
2848     FILE* bp = fopen(ncnn_modelbin, "wb");
2849 
2850     // magic
2851     fprintf(pp, "7767517\n");
2852 
2853     const onnx::GraphProto& graph = model.graph();
2854     onnx::GraphProto* mutable_graph = model.mutable_graph();
2855 
2856     int node_count = graph.node_size();
2857 
2858     // node reference
2859     std::map<std::string, int> node_reference;
2860 
2861     // weight node and weight reshape node
2862     std::map<std::string, onnx::TensorProto> weights;
2863 
2864     for (int j = 0; j < graph.initializer_size(); j++)
2865     {
2866         const onnx::TensorProto& initializer = graph.initializer(j);
2867 
2868         //         fprintf(stderr, "weight = %s %d\n", initializer.name().c_str(), initializer.data_type());
2869 
2870         weights[initializer.name()] = initializer;
2871     }
2872 
2873     // topological sort
2874     {
2875         // name -> producer node index
2876         std::set<std::string> producers;
2877         for (int j = 0; j < graph.input_size(); j++)
2878         {
2879             const std::string& input_name = graph.input(j).name();
2880             producers.insert(input_name);
2881         }
2882 
2883         for (int i = 0; i < node_count;)
2884         {
2885             onnx::NodeProto* node = mutable_graph->mutable_node(i);
2886 
2887             bool swapnode = false;
2888             std::string missing_input_name;
2889             for (int j = 0; j < (int)node->input_size(); j++)
2890             {
2891                 const std::string& input_name = node->input(j);
2892                 if (input_name.empty())
2893                     continue;
2894 
2895                 if (producers.find(input_name) == producers.end() && weights.find(input_name) == weights.end())
2896                 {
2897                     swapnode = true;
2898                     missing_input_name = input_name;
2899                     break;
2900                 }
2901             }
2902 
2903             if (!swapnode)
2904             {
2905                 for (int j = 0; j < (int)node->output_size(); j++)
2906                 {
2907                     const std::string& output_name = node->output(j);
2908                     if (output_name.empty())
2909                         continue;
2910 
2911                     producers.insert(output_name);
2912                 }
2913 
2914                 i++;
2915                 continue;
2916             }
2917 
2918             // find node that produce missing_input_name
2919             int q = i + 1;
2920             for (; q < node_count; q++)
2921             {
2922                 onnx::NodeProto* nodeq = mutable_graph->mutable_node(q);
2923                 bool found = false;
2924                 for (int j = 0; j < (int)nodeq->output_size(); j++)
2925                 {
2926                     const std::string& output_name = nodeq->output(j);
2927                     if (output_name == missing_input_name)
2928                     {
2929                         found = true;
2930                         break;
2931                     }
2932                 }
2933 
2934                 if (found)
2935                     break;
2936             }
2937 
2938             if (q == node_count)
2939             {
2940                 fprintf(stderr, "cannot find node produces %s but node %d requires it\n", missing_input_name.c_str(), i);
2941                 return -1;
2942             }
2943 
2944             // fprintf(stderr, "swap %d %d\n", i, q);
2945             // swap this node with q
2946             onnx::NodeProto* nodeq = mutable_graph->mutable_node(q);
2947             onnx::NodeProto tmp = *node;
2948             *node = *nodeq;
2949             *nodeq = tmp;
2950         }
2951     }
2952 
2953     // global definition line
2954     // [layer count] [blob count]
2955     std::set<std::string> blob_names;
2956     for (int i = 0; i < node_count; i++)
2957     {
2958         const onnx::NodeProto& node = graph.node(i);
2959 
2960         const std::string& op = node.op_type();
2961 
2962         std::string name = node.name();
2963         if (name.empty())
2964         {
2965             name = node.output(0);
2966         }
2967 
2968         if (op == "Constant")
2969         {
2970             onnx::TensorProto tensor = get_node_attr_tensor(node, "value");
2971             weights[node.output(0)] = tensor;
2972         }
2973 
2974         for (int j = 0; j < (int)node.input_size(); j++)
2975         {
2976             const std::string& input_name = node.input(j);
2977 
2978             blob_names.insert(input_name);
2979 
2980             if (node_reference.find(input_name) == node_reference.end())
2981             {
2982                 node_reference[input_name] = 1;
2983             }
2984             else
2985             {
2986                 node_reference[input_name] = node_reference[input_name] + 1;
2987             }
2988         }
2989 
2990         if (op == "Dropout")
2991         {
2992             const std::string& output_name = node.output(0);
2993             blob_names.insert(output_name);
2994             node_reference[output_name] = 0;
2995             continue;
2996         }
2997 
2998         for (int j = 0; j < (int)node.output_size(); j++)
2999         {
3000             const std::string& output_name = node.output(j);
3001 
3002             blob_names.insert(output_name);
3003 
3004             node_reference[output_name] = 0;
3005         }
3006     }
3007 
3008     // include Input node
3009     int input_node_count = 0;
3010     for (int j = 0; j < graph.input_size(); j++)
3011     {
3012         const std::string& input_name = graph.input(j).name();
3013 
3014         // check weight
3015         if (weights.find(input_name) != weights.end())
3016             continue;
3017 
3018         blob_names.insert(input_name);
3019 
3020         input_node_count++;
3021     }
3022 
3023     //     for (auto a: node_reference)
3024     //     {
3025     //         fprintf(stderr, "a = %s %d\n", a.first.c_str(), a.second);
3026     //     }
3027 
3028     // op chain fusion
3029     int reduced_node_count = 0;
3030     fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3031     fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3032     fuse_shufflechannel(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3033     fuse_shufflechannel_split(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3034     fuse_hardsigmoid(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3035     fuse_hardswish(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3036     fuse_swish(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3037     fuse_batchnorm1d_squeeze_unsqueeze(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3038     fuse_unsqueeze_prelu(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3039     fuse_normalize(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3040     fuse_groupnorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3041     fuse_layernorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3042     fuse_flatten(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3043     fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3044     fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3045     fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3046     fuse_lstm_gru_rnn(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3047     fuse_multiheadattention(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3048     fuse_binaryop_with_scalar(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3049 
    // reduce common const weight node_reference
    //
    // Weight tensors that a layer consumes directly as parameter data (e.g.
    // Conv filters, BatchNorm statistics) are written into the ncnn weight
    // file rather than fed in as input blobs, so each such consumption is
    // un-counted here.  Weights whose refcount drops to zero are excluded
    // from the layer/blob totals below and never emitted as MemoryData.
    for (int i = 0; i < node_count; i++)
    {
        const onnx::NodeProto& node = graph.node(i);

        const std::string& op = node.op_type();

        if (op == "BatchNormalization")
        {
            // inputs 1..4: scale, B, mean, var
            node_reference[node.input(1)] -= 1;
            node_reference[node.input(2)] -= 1;
            node_reference[node.input(3)] -= 1;
            node_reference[node.input(4)] -= 1;
        }
        else if (op == "BiasGelu")
        {
            node_reference[node.input(1)] -= 1;
        }
        else if (op == "Clip")
        {
            // opset 11+ passes min/max as constant inputs instead of attributes
            if (node.input_size() == 3)
            {
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
            }
        }
        else if (op == "Conv")
        {
            node_reference[node.input(1)] -= 1;
            if (node.input_size() == 3)
            {
                node_reference[node.input(2)] -= 1;
            }
        }
        else if (op == "ConvTranspose")
        {
            node_reference[node.input(1)] -= 1;
            if (node.input_size() == 3)
            {
                node_reference[node.input(2)] -= 1;
            }
        }
        else if (op == "EmbedLayerNormalization")
        {
            node_reference[node.input(1)] -= 1;
            node_reference[node.input(2)] -= 1;
            node_reference[node.input(3)] -= 1;
            node_reference[node.input(4)] -= 1;
            node_reference[node.input(5)] -= 1;
            node_reference[node.input(6)] -= 1;
        }
        else if (op == "Gemm")
        {
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            float beta = get_node_attr_f(node, "beta", 1.f);
            int transA = get_node_attr_i(node, "transA", 0);
            int transB = get_node_attr_i(node, "transB", 0);

            // only the InnerProduct-like form embeds B and C as weight data;
            // a general Gemm keeps them as input blobs
            if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
            {
                // InnerProduct-like A * B + C
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
            }
        }
        else if (op == "GroupNorm")
        {
            int affine = get_node_attr_i(node, "affine", 1);
            if (affine)
            {
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
            }
        }
        else if (op == "GRU")
        {
            // all non-data inputs (W, R, B, ...) are taken as constant weights
            for (int j = 1; j < node.input_size(); j++)
            {
                node_reference[node.input(j)] -= 1;
            }
        }
        else if (op == "InstanceNormalization")
        {
            node_reference[node.input(1)] -= 1;
            node_reference[node.input(2)] -= 1;
        }
        else if (op == "LayerNorm")
        {
            int affine = get_node_attr_i(node, "affine", 1);
            if (affine)
            {
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
            }
        }
        else if (op == "LSTM")
        {
            for (int j = 1; j < node.input_size(); j++)
            {
                node_reference[node.input(j)] -= 1;
            }
        }
        else if (op == "MatMul")
        {
            // only a constant 2D B makes this an InnerProduct with baked weights
            if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
            {
                // InnerProduct
                node_reference[node.input(1)] -= 1;
            }
        }
        else if (op == "MultiHeadAttention")
        {
            // NOTE(review): the two input layouts (5-input vs 11-input) are
            // produced by fuse_multiheadattention — confirm index meanings there
            if (node.input_size() == 5)
            {
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
                node_reference[node.input(3)] -= 1;
                node_reference[node.input(4)] -= 1;
            }
            else
            {
                node_reference[node.input(3)] -= 1;
                node_reference[node.input(4)] -= 1;
                node_reference[node.input(5)] -= 1;
                node_reference[node.input(6)] -= 1;
                node_reference[node.input(7)] -= 1;
                node_reference[node.input(8)] -= 1;
                node_reference[node.input(9)] -= 1;
                node_reference[node.input(10)] -= 1;
            }
        }
        else if (op == "Pad")
        {
            if (node.input_size() >= 2)
            {
                node_reference[node.input(1)] -= 1;
            }
        }
        else if (op == "PRelu")
        {
            node_reference[node.input(1)] -= 1;
        }
        else if (op == "Reshape")
        {
            if (node.input_size() >= 2)
            {
                node_reference[node.input(1)] -= 1;
            }
        }
        else if (op == "Resize")
        {
            if (node.input_size() == 2)
            {
                // opset 10
                node_reference[node.input(1)] -= 1;
            }
            else
            {
                // opset 11+
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
                if (node.input_size() >= 4)
                {
                    node_reference[node.input(3)] -= 1;
                }
            }
        }
        else if (op == "RNN")
        {
            for (int j = 1; j < node.input_size(); j++)
            {
                node_reference[node.input(j)] -= 1;
            }
        }
        else if (op == "SkipLayerNormalization")
        {
            node_reference[node.input(2)] -= 1;
            node_reference[node.input(3)] -= 1;
            node_reference[node.input(4)] -= 1;
        }
        else if (op == "Slice")
        {
            // opset 10+: starts/ends (and optional axes/steps) come as inputs
            if (node.input_size() >= 2)
            {
                node_reference[node.input(1)] -= 1;
                node_reference[node.input(2)] -= 1;
                if (node.input_size() >= 4)
                    node_reference[node.input(3)] -= 1;
                if (node.input_size() >= 5)
                    node_reference[node.input(4)] -= 1;
            }
        }
        else if (op == "Upsample")
        {
            if (node.input_size() >= 2)
            {
                node_reference[node.input(1)] -= 1;
            }
        }
        else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
        {
            if (node.input_size() >= 2)
            {
                node_reference[node.input(1)] -= 1;
            }
        }
    }
3257 
    //         for (auto a: node_reference)
    //         {
    //             fprintf(stderr, "b = %s %d\n", a.first.c_str(), a.second);
    //         }

    // count all weight node with zero reference
    // (these are fully consumed as layer parameter data and will not be
    // emitted as MemoryData layers or counted as blobs)
    int zero_reference_weight_node_count = 0;
    for (std::map<std::string, onnx::TensorProto>::iterator it = weights.begin(); it != weights.end(); it++)
    {
        const std::string& input_name = it->first;

        int refcount = node_reference[input_name];
        if (refcount == 0)
            zero_reference_weight_node_count++;
    }

    // we always treat constant node as weight or binaryop_weights
    // do not count it twice for layer_count
    int constant_node_count_moved_to_weight = 0;
    for (int i = 0; i < node_count; i++)
    {
        const onnx::NodeProto& node = graph.node(i);

        const std::string& op = node.op_type();

        if (op == "Constant")
        {
            constant_node_count_moved_to_weight++;
        }
    }

    // some op may have anonymous input
    // LSTM sequence_lens
    blob_names.erase("");
    node_reference.erase("");

    // remove node_reference entry with reference equals to one
    int split_layer_count = 0;
    int splitncnn_blob_count = 0;
    // split node reference
    // every blob consumed more than once needs an ncnn Split layer producing
    // one clone per consumer
    std::map<std::string, int> split_node_reference;
    for (std::map<std::string, int>::iterator it = node_reference.begin(); it != node_reference.end(); it++)
    {
        if (it->second > 1)
        {
            split_layer_count++;
            splitncnn_blob_count += it->second;

            split_node_reference[it->first] = it->second;
        }
    }

    // param file header: "<layer_count> <blob_count>"
    // layer count = graph nodes
    //             - Constant nodes (already folded into weights)
    //             + weights emitted as MemoryData (weights minus zero-ref ones)
    //             - nodes removed by the fuse_* passes
    //             + Input layers (input_node_count, computed before this chunk)
    //             + inserted Split layers
    // blob count  = all blobs minus zero-ref weight blobs, plus the extra
    //               splitncnn_* clone blobs
    fprintf(pp, "%zu %zu\n", node_count - constant_node_count_moved_to_weight + weights.size() - zero_reference_weight_node_count - reduced_node_count + input_node_count + split_layer_count, blob_names.size() - zero_reference_weight_node_count + splitncnn_blob_count);

    int internal_split = 0;
3313 
    // place Input at the beginning
    for (int j = 0; j < graph.input_size(); j++)
    {
        const std::string& input_name = graph.input(j).name();

        // check weight
        // graph "inputs" that are really initializers become MemoryData, not Input
        if (weights.find(input_name) != weights.end())
            continue;

        fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", input_name.c_str(), input_name.c_str());

        int refcount = node_reference[input_name];
        if (refcount <= 1)
        {
            continue;
        }

        // input feeds multiple consumers: insert a Split so each consumer
        // gets its own <name>_splitncnn_<k> clone blob
        char splitname[256];
        sprintf(splitname, "splitncnn_input%d", j);
        fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
        fprintf(pp, " %s", input_name.c_str());

        for (int k = 0; k < refcount; k++)
        {
            fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
        }
        fprintf(pp, "\n");
    }
3342 
    // place MemoryData next
    // weights still referenced as blobs (not baked into a layer) are emitted
    // as MemoryData layers, with their tensor data written to the bin file
    for (std::map<std::string, onnx::TensorProto>::iterator weight_it = weights.begin(); weight_it != weights.end(); weight_it++)
    {
        const std::string& input_name = weight_it->first;

        int refcount = node_reference[input_name];
        if (refcount == 0)
        {
            continue;
        }

        fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", input_name.c_str(), input_name.c_str());

        const onnx::TensorProto& M = weights[input_name];

        // MemoryData shape params 0/1/2 = w/h/c; onnx dims are outermost-first,
        // so the innermost (last) onnx dim maps to param 0
        if (M.dims_size() == 0)
        {
            // scalar / shapeless: fall back to the raw element count
            fprintf(pp, " 0=%d", get_tensor_proto_data_size(M));
        }
        else if (M.dims_size() == 1)
        {
            fprintf(pp, " 0=%d", (int)M.dims(0));
        }
        else if (M.dims_size() == 2)
        {
            fprintf(pp, " 0=%d", (int)M.dims(1));
            // a leading dim of 1 is dropped, leaving a 1D blob
            if (M.dims(0) != 1)
            {
                fprintf(pp, " 1=%d", (int)M.dims(0));
            }
        }
        else if (M.dims_size() == 3)
        {
            fprintf(pp, " 0=%d", (int)M.dims(2));
            fprintf(pp, " 1=%d", (int)M.dims(1));
            if (M.dims(0) != 1)
            {
                fprintf(pp, " 2=%d", (int)M.dims(0));
            }
        }
        else if (M.dims_size() == 4)
        {
            // dims(0) is ignored — assumes a leading batch dim of 1; TODO confirm
            fprintf(pp, " 0=%d", (int)M.dims(3));
            fprintf(pp, " 1=%d", (int)M.dims(2));
            fprintf(pp, " 2=%d", (int)M.dims(1));
        }

        fprintf(pp, "\n");

        fwrite_tensor_proto_data(M, bp);

        if (refcount <= 1)
        {
            continue;
        }

        // weight blob feeds multiple consumers: insert a Split (same scheme
        // as for Input blobs above)
        char splitname[256];
        sprintf(splitname, "splitncnn_%d", internal_split);
        fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);

        fprintf(pp, " %s", input_name.c_str());

        for (int k = 0; k < refcount; k++)
        {
            fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
        }
        fprintf(pp, "\n");

        internal_split++;
    }
3413 
3414     for (int i = 0; i < node_count; i++)
3415     {
3416         const onnx::NodeProto& node = graph.node(i);
3417 
3418         const std::string& op = node.op_type();
3419 
3420         //         fprintf(stderr, "op = %s\n", op.c_str());
3421 
3422         if (op == "noop_reducedncnn")
3423         {
3424             continue;
3425         }
3426 
3427         std::string name = node.name();
3428         if (name.empty())
3429         {
3430             name = node.output(0);
3431         }
3432 
3433         int input_size = node.input_size();
3434         int output_size = node.output_size();
3435 
3436         for (int j = 0; j < (int)node.input_size(); j++)
3437         {
3438             const std::string& input_name = node.input(j);
3439 
3440             // check weight
3441             if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0)
3442             {
3443                 input_size--;
3444             }
3445 
3446             if (input_name.empty())
3447             {
3448                 input_size--;
3449             }
3450 
3451             //             fprintf(stderr, "  input = %s\n", input_name.c_str());
3452         }
3453         /*
3454         for (int j=0; j<(int)node.output_size(); j++)
3455         {
3456             const std::string& output_name = node.output(j);
3457             fprintf(stderr, "  output = %s\n", output_name.c_str());
3458         }
3459         */
3460 
3461         if (op == "Abs")
3462         {
3463             fprintf(pp, "%-16s", "UnaryOp");
3464         }
3465         else if (op == "Acos")
3466         {
3467             fprintf(pp, "%-16s", "UnaryOp");
3468         }
3469         else if (op == "Add")
3470         {
3471             fprintf(pp, "%-16s", "BinaryOp");
3472         }
3473         else if (op == "Asin")
3474         {
3475             fprintf(pp, "%-16s", "UnaryOp");
3476         }
3477         else if (op == "Atan")
3478         {
3479             fprintf(pp, "%-16s", "UnaryOp");
3480         }
3481         else if (op == "AveragePool" || op == "MaxPool")
3482         {
3483             fprintf(pp, "%-16s", "Pooling");
3484         }
3485         else if (op == "BatchNormalization")
3486         {
3487             fprintf(pp, "%-16s", "BatchNorm");
3488         }
3489         else if (op == "BiasGelu")
3490         {
3491             fprintf(pp, "%-16s", "BiasGelu");
3492         }
3493         else if (op == "Ceil")
3494         {
3495             fprintf(pp, "%-16s", "UnaryOp");
3496         }
3497         else if (op == "Clip")
3498         {
3499             fprintf(pp, "%-16s", "Clip");
3500         }
3501         else if (op == "Concat")
3502         {
3503             fprintf(pp, "%-16s", "Concat");
3504         }
3505         else if (op == "Constant")
3506         {
3507             continue;
3508         }
3509         else if (op == "Conv")
3510         {
3511             int group = get_node_attr_i(node, "group", 1);
3512             if (group > 1)
3513             {
3514                 fprintf(pp, "%-16s", "ConvolutionDepthWise");
3515             }
3516             else
3517             {
3518                 fprintf(pp, "%-16s", "Convolution");
3519             }
3520         }
3521         else if (op == "ConvTranspose")
3522         {
3523             int group = get_node_attr_i(node, "group", 1);
3524             if (group > 1)
3525             {
3526                 fprintf(pp, "%-16s", "DeconvolutionDepthWise");
3527             }
3528             else
3529             {
3530                 fprintf(pp, "%-16s", "Deconvolution");
3531             }
3532         }
3533         else if (op == "Cos")
3534         {
3535             fprintf(pp, "%-16s", "UnaryOp");
3536         }
3537         else if (op == "DepthToSpace")
3538         {
3539             fprintf(pp, "%-16s", "PixelShuffle");
3540         }
3541         else if (op == "Div")
3542         {
3543             fprintf(pp, "%-16s", "BinaryOp");
3544         }
3545         else if (op == "Dropout")
3546         {
3547             fprintf(pp, "%-16s", "Dropout");
3548             output_size = 1;
3549         }
3550         else if (op == "Elu")
3551         {
3552             fprintf(pp, "%-16s", "ELU");
3553         }
3554         else if (op == "EmbedLayerNormalization")
3555         {
3556             fprintf(pp, "%-16s", "EmbedLayerNormalization");
3557         }
3558         else if (op == "Exp")
3559         {
3560             fprintf(pp, "%-16s", "UnaryOp");
3561         }
3562         else if (op == "Flatten")
3563         {
3564             fprintf(pp, "%-16s", "Flatten");
3565         }
3566         else if (op == "Floor")
3567         {
3568             fprintf(pp, "%-16s", "UnaryOp");
3569         }
3570         else if (op == "Gemm")
3571         {
3572             float alpha = get_node_attr_f(node, "alpha", 1.f);
3573             float beta = get_node_attr_f(node, "beta", 1.f);
3574             int transA = get_node_attr_i(node, "transA", 0);
3575             int transB = get_node_attr_i(node, "transB", 0);
3576 
3577             if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
3578             {
3579                 // InnerProduct-like A * B + C
3580                 fprintf(pp, "%-16s", "InnerProduct");
3581             }
3582             else
3583             {
3584                 fprintf(pp, "%-16s", "Gemm");
3585             }
3586         }
3587         else if (op == "GlobalAveragePool")
3588         {
3589             fprintf(pp, "%-16s", "Pooling");
3590         }
3591         else if (op == "GlobalMaxPool")
3592         {
3593             fprintf(pp, "%-16s", "Pooling");
3594         }
3595         else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
3596         {
3597             fprintf(pp, "%-16s", "Pooling");
3598         }
3599         else if (op == "GroupNorm")
3600         {
3601             fprintf(pp, "%-16s", "GroupNorm");
3602         }
3603         else if (op == "GRU")
3604         {
3605             fprintf(pp, "%-16s", "GRU");
3606         }
3607         else if (op == "HardSigmoid")
3608         {
3609             fprintf(pp, "%-16s", "HardSigmoid");
3610         }
3611         else if (op == "HardSwish")
3612         {
3613             fprintf(pp, "%-16s", "HardSwish");
3614         }
3615         else if (op == "ImageScaler")
3616         {
3617             fprintf(pp, "%-16s", "Scale");
3618         }
3619         else if (op == "InstanceNormalization")
3620         {
3621             fprintf(pp, "%-16s", "InstanceNorm");
3622         }
3623         else if (op == "LayerNorm")
3624         {
3625             fprintf(pp, "%-16s", "LayerNorm");
3626         }
3627         else if (op == "LeakyRelu")
3628         {
3629             fprintf(pp, "%-16s", "ReLU");
3630         }
3631         else if (op == "Log")
3632         {
3633             fprintf(pp, "%-16s", "UnaryOp");
3634         }
3635         else if (op == "LRN")
3636         {
3637             fprintf(pp, "%-16s", "LRN");
3638         }
3639         else if (op == "LSTM")
3640         {
3641             fprintf(pp, "%-16s", "LSTM");
3642         }
3643         else if (op == "MatMul")
3644         {
3645             if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
3646             {
3647                 fprintf(pp, "%-16s", "InnerProduct");
3648             }
3649             else
3650             {
3651                 fprintf(pp, "%-16s", "Gemm");
3652             }
3653         }
3654         else if (op == "Max")
3655         {
3656             fprintf(pp, "%-16s", "BinaryOp");
3657         }
3658         else if (op == "Min")
3659         {
3660             fprintf(pp, "%-16s", "BinaryOp");
3661         }
3662         else if (op == "Mul")
3663         {
3664             fprintf(pp, "%-16s", "BinaryOp");
3665         }
3666         else if (op == "MultiHeadAttention")
3667         {
3668             fprintf(pp, "%-16s", "MultiHeadAttention");
3669         }
3670         else if (op == "Neg")
3671         {
3672             fprintf(pp, "%-16s", "UnaryOp");
3673         }
3674         else if (op == "Normalize")
3675         {
3676             fprintf(pp, "%-16s", "Normalize");
3677         }
3678         else if (op == "Pad")
3679         {
3680             fprintf(pp, "%-16s", "Padding");
3681         }
3682         else if (op == "PixelShuffle")
3683         {
3684             fprintf(pp, "%-16s", "PixelShuffle");
3685         }
3686         else if (op == "Pow")
3687         {
3688             fprintf(pp, "%-16s", "BinaryOp");
3689         }
3690         else if (op == "PRelu")
3691         {
3692             fprintf(pp, "%-16s", "PReLU");
3693         }
3694         else if (op == "Reciprocal")
3695         {
3696             fprintf(pp, "%-16s", "UnaryOp");
3697         }
3698         else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp")
3699         {
3700             fprintf(pp, "%-16s", "Reduction");
3701         }
3702         else if (op == "Relu")
3703         {
3704             fprintf(pp, "%-16s", "ReLU");
3705         }
3706         else if (op == "Reorg")
3707         {
3708             fprintf(pp, "%-16s", "Reorg");
3709         }
3710         else if (op == "Reshape")
3711         {
3712             fprintf(pp, "%-16s", "Reshape");
3713         }
3714         else if (op == "RNN")
3715         {
3716             fprintf(pp, "%-16s", "RNN");
3717         }
3718         else if (op == "ShuffleChannel")
3719         {
3720             fprintf(pp, "%-16s", "ShuffleChannel");
3721         }
3722         else if (op == "Sigmoid")
3723         {
3724             fprintf(pp, "%-16s", "Sigmoid");
3725         }
3726         else if (op == "Sin")
3727         {
3728             fprintf(pp, "%-16s", "UnaryOp");
3729         }
3730         else if (op == "SkipLayerNormalization")
3731         {
3732             fprintf(pp, "%-16s", "SkipLayerNormalization");
3733         }
3734         else if (op == "Slice")
3735         {
3736             fprintf(pp, "%-16s", "Crop");
3737         }
3738         else if (op == "Softmax")
3739         {
3740             fprintf(pp, "%-16s", "Softmax");
3741         }
3742         else if (op == "Softplus")
3743         {
3744             fprintf(pp, "%-16s", "Softplus");
3745         }
3746         else if (op == "Split")
3747         {
3748             fprintf(pp, "%-16s", "Slice");
3749         }
3750         else if (op == "Sqrt")
3751         {
3752             fprintf(pp, "%-16s", "UnaryOp");
3753         }
3754         else if (op == "Squeeze")
3755         {
3756             fprintf(pp, "%-16s", "Squeeze");
3757         }
3758         else if (op == "Sub")
3759         {
3760             fprintf(pp, "%-16s", "BinaryOp");
3761         }
3762         else if (op == "Sum")
3763         {
3764             fprintf(pp, "%-16s", "Eltwise");
3765         }
3766         else if (op == "Swish")
3767         {
3768             fprintf(pp, "%-16s", "Swish");
3769         }
3770         else if (op == "Tan")
3771         {
3772             fprintf(pp, "%-16s", "UnaryOp");
3773         }
3774         else if (op == "Tanh")
3775         {
3776             fprintf(pp, "%-16s", "UnaryOp");
3777         }
3778         else if (op == "Transpose")
3779         {
3780             fprintf(pp, "%-16s", "Permute");
3781         }
3782         else if (op == "Upsample" || op == "Resize")
3783         {
3784             fprintf(pp, "%-16s", "Interp");
3785         }
3786         else if (op == "Unsqueeze")
3787         {
3788             fprintf(pp, "%-16s", "ExpandDims");
3789         }
3790         else
3791         {
3792             // TODO
3793             fprintf(stderr, "%s not supported yet!\n", op.c_str());
3794             fprintf(pp, "%-16s", op.c_str());
3795         }
3796 
3797         fprintf(pp, " %-24s %d %d", name.c_str(), input_size, output_size);
3798 
3799         for (int j = 0; j < (int)node.input_size(); j++)
3800         {
3801             std::string input_name = node.input(j);
3802 
3803             // check weight
3804             if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0)
3805             {
3806                 continue;
3807             }
3808 
3809             if (input_name.empty())
3810             {
3811                 continue;
3812             }
3813 
3814             if (split_node_reference.find(input_name) != split_node_reference.end())
3815             {
3816                 int refidx = split_node_reference[input_name] - 1;
3817                 split_node_reference[input_name] = refidx;
3818 
3819                 char splitsuffix[256];
3820                 sprintf(splitsuffix, "_splitncnn_%d", refidx);
3821                 input_name = input_name + splitsuffix;
3822             }
3823 
3824             fprintf(pp, " %s", input_name.c_str());
3825         }
3826 
3827         for (int j = 0; j < output_size; j++)
3828         {
3829             const std::string& output_name = node.output(j);
3830 
3831             fprintf(pp, " %s", output_name.c_str());
3832         }
3833 
3834         if (op == "Abs")
3835         {
3836             int op_type = 0;
3837             fprintf(pp, " 0=%d", op_type);
3838         }
3839         else if (op == "Acos")
3840         {
3841             int op_type = 13;
3842             fprintf(pp, " 0=%d", op_type);
3843         }
3844         else if (op == "Add")
3845         {
3846             int op_type = 0;
3847             fprintf(pp, " 0=%d", op_type);
3848 
3849             int with_scalar = get_node_attr_i(node, "with_scalar", 0);
3850             float b = get_node_attr_f(node, "b", 0.f);
3851             if (with_scalar)
3852             {
3853                 fprintf(pp, " 1=%d", with_scalar);
3854                 fprintf(pp, " 2=%e", b);
3855             }
3856         }
3857         else if (op == "Asin")
3858         {
3859             int op_type = 12;
3860             fprintf(pp, " 0=%d", op_type);
3861         }
3862         else if (op == "Atan")
3863         {
3864             int op_type = 14;
3865             fprintf(pp, " 0=%d", op_type);
3866         }
3867         else if (op == "AveragePool" || op == "MaxPool")
3868         {
3869             std::string auto_pad = get_node_attr_s(node, "auto_pad");
3870             int ceil_mode = get_node_attr_i(node, "ceil_mode", 0);
3871             std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
3872             std::vector<int> strides = get_node_attr_ai(node, "strides");
3873             std::vector<int> pads = get_node_attr_ai(node, "pads");
3874 
3875             int pool = op == "AveragePool" ? 1 : 0;
3876             int pad_mode = 1;
3877 
3878             if (auto_pad == "SAME_UPPER")
3879             {
3880                 pad_mode = 2;
3881             }
3882             else if (auto_pad == "SAME_LOWER")
3883             {
3884                 pad_mode = 3;
3885             }
3886 
3887             if (ceil_mode == 1)
3888             {
3889                 pad_mode = 0;
3890             }
3891 
3892             fprintf(pp, " 0=%d", pool);
3893 
3894             if (kernel_shape.size() == 1)
3895             {
3896                 fprintf(pp, " 1=%d", kernel_shape[0]);
3897             }
3898             else if (kernel_shape.size() == 2)
3899             {
3900                 fprintf(pp, " 1=%d", kernel_shape[1]);
3901                 fprintf(pp, " 11=%d", kernel_shape[0]);
3902             }
3903 
3904             if (strides.size() == 1)
3905             {
3906                 fprintf(pp, " 2=%d", strides[0]);
3907             }
3908             else if (strides.size() == 2)
3909             {
3910                 fprintf(pp, " 2=%d", strides[1]);
3911                 fprintf(pp, " 12=%d", strides[0]);
3912             }
3913 
3914             if (pads.size() == 1)
3915             {
3916                 fprintf(pp, " 3=%d", pads[0]);
3917             }
3918             else if (pads.size() == 2)
3919             {
3920                 fprintf(pp, " 3=%d", pads[1]);
3921                 fprintf(pp, " 13=%d", pads[0]);
3922             }
3923             else if (pads.size() == 4)
3924             {
3925                 fprintf(pp, " 3=%d", pads[1]);
3926                 fprintf(pp, " 13=%d", pads[0]);
3927                 fprintf(pp, " 14=%d", pads[3]);
3928                 fprintf(pp, " 15=%d", pads[2]);
3929             }
3930 
3931             fprintf(pp, " 5=%d", pad_mode);
3932 
3933             if (op == "AveragePool")
3934             {
3935                 int avgpool_count_include_pad = get_node_attr_i(node, "count_include_pad", 0);
3936                 fprintf(pp, " 6=%d", avgpool_count_include_pad);
3937             }
3938         }
3939         else if (op == "BatchNormalization")
3940         {
3941             float epsilon = get_node_attr_f(node, "epsilon", 1e-5f);
3942 
3943             const onnx::TensorProto& scale = weights[node.input(1)];
3944             const onnx::TensorProto& B = weights[node.input(2)];
3945             const onnx::TensorProto& mean = weights[node.input(3)];
3946             const onnx::TensorProto& var = weights[node.input(4)];
3947 
3948             int channels = get_tensor_proto_data_size(scale);
3949 
3950             fprintf(pp, " 0=%d", channels);
3951 
3952             fwrite_tensor_proto_data(scale, bp);
3953             fwrite_tensor_proto_data(mean, bp);
3954             // apply epsilon to var
3955             {
3956                 const float* v = var.has_raw_data() ? (const float*)var.raw_data().data() : var.float_data().data();
3957 
3958                 for (int j = 0; j < channels; j++)
3959                 {
3960                     float ve = v[j] + epsilon;
3961                     fwrite(&ve, sizeof(float), 1, bp);
3962                 }
3963             }
3964             fwrite_tensor_proto_data(B, bp);
3965         }
        else if (op == "BiasGelu")
        {
            // Fused bias-add + GeLU: emit bias element count, then the raw bias blob.
            const onnx::TensorProto& B = weights[node.input(1)];

            fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));

            // leading 0 tag marks the following weight blob as plain float32 (not quantized)
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B, bp);
        }
        else if (op == "Ceil")
        {
            // unary op selector: 3 = ceil
            int op_type = 3;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Clip")
        {
            float min;
            float max;
            if (node.input_size() == 1)
            {
                // opset <= 10 style: min/max carried as node attributes
                min = get_node_attr_f(node, "min", -FLT_MAX);
                max = get_node_attr_f(node, "max", FLT_MAX);
            }
            else
            {
                // opset 11+ style: min/max are optional constant input tensors;
                // fall back to +/-FLT_MAX when an input is absent from the weight map
                min = weights.find(node.input(1)) != weights.end() ? get_node_attr_from_input_f(weights[node.input(1)]) : -FLT_MAX;
                max = weights.find(node.input(2)) != weights.end() ? get_node_attr_from_input_f(weights[node.input(2)]) : FLT_MAX;
            }

            fprintf(pp, " 0=%e", min);
            fprintf(pp, " 1=%e", max);
        }
        else if (op == "Concat")
        {
            // ONNX axis counts the batch dimension; ncnn blobs do not, hence axis - 1
            int axis = get_node_attr_i(node, "axis", 1);
            fprintf(pp, " 0=%d", axis - 1);
        }
        else if (op == "Constant")
        {
            // Constant nodes are folded into the weight map earlier in the converter,
            // so this branch is intentionally empty.
            // never reach here
        }
        else if (op == "Conv")
        {
            // Convolution: input(1) is the weight tensor W with dims [out_ch, in_ch/group, kh, kw].
            const onnx::TensorProto& W = weights[node.input(1)];

            int num_filter = W.dims(0);
            // optional third input is the bias
            int has_bias = node.input_size() == 3 ? 1 : 0;

            std::string auto_pad = get_node_attr_s(node, "auto_pad");
            std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
            std::vector<int> dilations = get_node_attr_ai(node, "dilations");
            std::vector<int> strides = get_node_attr_ai(node, "strides");
            std::vector<int> pads = get_node_attr_ai(node, "pads");
            int group = get_node_attr_i(node, "group", 1);

            fprintf(pp, " 0=%d", num_filter);

            // ONNX attribute vectors are ordered [h, w]; ncnn writes w in the base
            // param id and h in the companion 1x-series id.
            if (kernel_shape.size() == 1)
            {
                fprintf(pp, " 1=%d", kernel_shape[0]);
            }
            else if (kernel_shape.size() == 2)
            {
                fprintf(pp, " 1=%d", kernel_shape[1]);
                fprintf(pp, " 11=%d", kernel_shape[0]);
            }

            if (dilations.size() == 1)
            {
                fprintf(pp, " 2=%d", dilations[0]);
            }
            else if (dilations.size() == 2)
            {
                fprintf(pp, " 2=%d", dilations[1]);
                fprintf(pp, " 12=%d", dilations[0]);
            }

            if (strides.size() == 1)
            {
                fprintf(pp, " 3=%d", strides[0]);
            }
            else if (strides.size() == 2)
            {
                fprintf(pp, " 3=%d", strides[1]);
                fprintf(pp, " 13=%d", strides[0]);
            }

            // -233 / -234 are ncnn's sentinel pad values for SAME_UPPER / SAME_LOWER
            if (auto_pad == "SAME_UPPER")
            {
                fprintf(pp, " 4=-233");
            }
            else if (auto_pad == "SAME_LOWER")
            {
                fprintf(pp, " 4=-234");
            }
            else
            {
                // explicit pads; ONNX order for 2D is [begin_h, begin_w, end_h, end_w]
                if (pads.size() == 1)
                {
                    fprintf(pp, " 4=%d", pads[0]);
                }
                else if (pads.size() == 2)
                {
                    fprintf(pp, " 4=%d", pads[1]);
                    fprintf(pp, " 14=%d", pads[0]);
                }
                else if (pads.size() == 4)
                {
                    fprintf(pp, " 4=%d", pads[1]);
                    fprintf(pp, " 14=%d", pads[0]);
                    fprintf(pp, " 15=%d", pads[3]);
                    fprintf(pp, " 16=%d", pads[2]);
                }
            }

            fprintf(pp, " 5=%d", has_bias);

            // 6 = total number of float weights in W
            fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));

            if (group > 1)
            {
                fprintf(pp, " 7=%d", group);
            }

            // unquantized float32 marker followed by raw weight (and optional bias) data
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(W, bp);

            if (has_bias)
            {
                const onnx::TensorProto& B = weights[node.input(2)];
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "ConvTranspose")
        {
            // Transposed convolution: ONNX stores W as [in_ch, out_ch/group, kh, kw],
            // so the output channel count is dims(1) * group (see num_filter below).
            const onnx::TensorProto& W = weights[node.input(1)];

            int has_bias = node.input_size() == 3 ? 1 : 0;

            std::string auto_pad = get_node_attr_s(node, "auto_pad");
            std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
            std::vector<int> dilations = get_node_attr_ai(node, "dilations");
            std::vector<int> strides = get_node_attr_ai(node, "strides");
            std::vector<int> output_padding = get_node_attr_ai(node, "output_padding");
            std::vector<int> output_shape = get_node_attr_ai(node, "output_shape");
            std::vector<int> pads = get_node_attr_ai(node, "pads");
            int group = get_node_attr_i(node, "group", 1);
            int num_filter = W.dims(1) * group;

            fprintf(pp, " 0=%d", num_filter);

            // ONNX attribute vectors are ordered [h, w]; w goes in the base id,
            // h in the companion 1x-series id.
            if (kernel_shape.size() == 1)
            {
                fprintf(pp, " 1=%d", kernel_shape[0]);
            }
            else if (kernel_shape.size() == 2)
            {
                fprintf(pp, " 1=%d", kernel_shape[1]);
                fprintf(pp, " 11=%d", kernel_shape[0]);
            }

            if (dilations.size() == 1)
            {
                fprintf(pp, " 2=%d", dilations[0]);
            }
            else if (dilations.size() == 2)
            {
                fprintf(pp, " 2=%d", dilations[1]);
                fprintf(pp, " 12=%d", dilations[0]);
            }

            if (strides.size() == 1)
            {
                fprintf(pp, " 3=%d", strides[0]);
            }
            else if (strides.size() == 2)
            {
                fprintf(pp, " 3=%d", strides[1]);
                fprintf(pp, " 13=%d", strides[0]);
            }

            // -233 / -234 are ncnn's sentinel pad values for SAME_UPPER / SAME_LOWER
            if (auto_pad == "SAME_UPPER")
            {
                fprintf(pp, " 4=-233");
            }
            else if (auto_pad == "SAME_LOWER")
            {
                fprintf(pp, " 4=-234");
            }
            else
            {
                // explicit pads; ONNX order for 2D is [begin_h, begin_w, end_h, end_w]
                if (pads.size() == 1)
                {
                    fprintf(pp, " 4=%d", pads[0]);
                }
                else if (pads.size() == 2)
                {
                    fprintf(pp, " 4=%d", pads[1]);
                    fprintf(pp, " 14=%d", pads[0]);
                }
                else if (pads.size() == 4)
                {
                    fprintf(pp, " 4=%d", pads[1]);
                    fprintf(pp, " 14=%d", pads[0]);
                    fprintf(pp, " 15=%d", pads[3]);
                    fprintf(pp, " 16=%d", pads[2]);
                }
            }

            if (output_padding.size() == 1)
            {
                fprintf(pp, " 18=%d", output_padding[0]);
            }
            else if (output_padding.size() == 2)
            {
                fprintf(pp, " 18=%d", output_padding[1]);
                fprintf(pp, " 19=%d", output_padding[0]);
            }

            if (output_shape.size() == 1)
            {
                fprintf(pp, " 20=%d", output_shape[0]);
            }
            else if (output_shape.size() == 2)
            {
                fprintf(pp, " 20=%d", output_shape[1]);
                fprintf(pp, " 21=%d", output_shape[0]);
            }

            fprintf(pp, " 5=%d", has_bias);

            fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));

            if (group > 1)
            {
                fprintf(pp, " 7=%d", group);
            }

            // unquantized float32 marker
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            int maxk = 0;
            if (kernel_shape.size() == 2)
            {
                maxk = kernel_shape[1] * kernel_shape[0];
            }
            else
            {
                // NOTE(review): this squares kernel_shape[0]; for a genuine 1D kernel
                // one would expect maxk = kernel_shape[0]. Looks correct only if this
                // path is meant for square kernels expressed with one dim — TODO confirm.
                maxk = kernel_shape[0] * kernel_shape[0];
            }
            int weight_data_size = get_tensor_proto_data_size(W);
            const float* weight_data = 0;
            if (W.has_raw_data())
            {
                weight_data = (const float*)W.raw_data().data();
            }
            else if (W.data_type() == 1)
            {
                // data_type 1 == onnx::TensorProto::FLOAT stored in float_data
                weight_data = W.float_data().data();
            }
            for (int g = 0; g < group; g++)
            {
                // reorder weight from inch-outch to outch-inch
                // (ONNX deconv weight layout is [in_ch, out_ch/group, k...];
                //  ncnn expects output channels outermost within each group)
                int num_filter_g = num_filter / group;
                int num_input = weight_data_size / maxk / num_filter_g / group;
                const float* weight_data_ptr = weight_data + g * maxk * num_filter_g * num_input;
                for (int k = 0; k < num_filter_g; k++)
                {
                    for (int j = 0; j < num_input; j++)
                    {
                        fwrite(weight_data_ptr + (j * num_filter_g + k) * maxk, sizeof(float), maxk, bp);
                    }
                }
            }

            if (has_bias)
            {
                const onnx::TensorProto& B = weights[node.input(2)];
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "Cos")
        {
            // unary op selector: 10 = cos
            int op_type = 10;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "DepthToSpace")
        {
            // pixelshuffle
            int scale_factor = get_node_attr_i(node, "blocksize", 1);
            std::string mode = get_node_attr_s(node, "mode");
            fprintf(pp, " 0=%d", scale_factor);
            // param 1 selects the channel ordering: 0 for CRD, 1 for DCR;
            // an unrecognized/absent mode writes nothing and leaves the layer default
            if (mode == "CRD")
            {
                fprintf(pp, " 1=0");
            }
            else if (mode == "DCR")
            {
                fprintf(pp, " 1=1");
            }
        }
        else if (op == "Div")
        {
            // binary op selector: 3 = div
            int op_type = 3;
            fprintf(pp, " 0=%d", op_type);

            // "with_scalar"/"b" are synthetic attributes set by earlier graph fusion
            // when the divisor is a constant scalar
            int with_scalar = get_node_attr_i(node, "with_scalar", 0);
            float b = get_node_attr_f(node, "b", 0.f);
            if (with_scalar)
            {
                fprintf(pp, " 1=%d", with_scalar);
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "Dropout")
        {
            // inference-time dropout is an identity; emit no parameters
            // no-op
        }
        else if (op == "Elu")
        {
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            fprintf(pp, " 0=%e", alpha);
        }
        else if (op == "EmbedLayerNormalization")
        {
            // Fused embedding + layernorm (com.microsoft domain):
            // input(2) = word embedding table, input(3) = position embedding table,
            // input(5)/input(6) = layernorm gamma/beta.
            const onnx::TensorProto& words = weights[node.input(2)];
            const onnx::TensorProto& positions = weights[node.input(3)];
            const onnx::TensorProto& W = weights[node.input(5)];
            const onnx::TensorProto& B = weights[node.input(6)];

            fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
            fprintf(pp, " 1=%d", get_tensor_proto_data_size(words));
            fprintf(pp, " 2=%d", get_tensor_proto_data_size(positions));

            // each weight blob is prefixed with a 0 tag (unquantized float32)
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(words, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(positions, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(W, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B, bp);
        }
        else if (op == "Exp")
        {
            // unary op selector: 7 = exp
            int op_type = 7;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Flatten")
        {
            // only axis == 1 maps onto the target layer; anything else is reported
            // but still emitted with no parameters
            int axis = get_node_attr_i(node, "axis", 1);
            if (axis != 1)
            {
                fprintf(stderr, "Unsupported Flatten axis %d!\n", axis);
            }
        }
        else if (op == "Floor")
        {
            // unary op selector: 2 = floor
            int op_type = 2;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Gemm")
        {
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            float beta = get_node_attr_f(node, "beta", 1.f);
            int transA = get_node_attr_i(node, "transA", 0);
            int transB = get_node_attr_i(node, "transB", 0);

            if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
            {
                // InnerProduct-like A * B + C
                // (the plain fully-connected pattern: emit as an inner-product layer
                //  with B as weights and C as bias)
                const onnx::TensorProto& B = weights[node.input(1)];
                const onnx::TensorProto& C = weights[node.input(2)];

                fprintf(pp, " 0=%d", get_tensor_proto_data_size(C)); // output count
                fprintf(pp, " 1=1");                                 // bias present
                fprintf(pp, " 2=%d", get_tensor_proto_data_size(B)); // weight count

                // unquantized float32 marker
                int quantize_tag = 0;
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                fwrite_tensor_proto_data(B, bp);
                fwrite_tensor_proto_data(C, bp);
            }
            else
            {
                // gemm
                // general case: keep alpha/beta/transA/transB and let the runtime
                // Gemm layer compute alpha*A'*B' + beta*C
                fprintf(pp, " 0=%e", alpha);
                fprintf(pp, " 1=%e", beta);
                fprintf(pp, " 2=%d", transA);
                fprintf(pp, " 3=%d", transB);
            }
        }
        else if (op == "GlobalAveragePool")
        {
            // pooling type 1 = average, with the global-pool flag set
            int pool = 1;
            int global_pool = 1;

            fprintf(pp, " 0=%d", pool);
            fprintf(pp, " 4=%d", global_pool);
        }
        else if (op == "GlobalMaxPool")
        {
            // pooling type 0 = max, with the global-pool flag set
            int pool = 0;
            int global_pool = 1;

            fprintf(pp, " 0=%d", pool);
            fprintf(pp, " 4=%d", global_pool);
        }
        else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
        {
            // torch-exported adaptive pooling: target output size comes from input(1)
            int pool = 0;
            if (op == "adaptive_avg_pool2d")
            {
                pool = 1;
            }
            int adaptive_pooling = 1;
            const onnx::TensorProto& out_shape_tp = weights[node.input(1)];
            std::vector<int> out_shape = get_node_attr_from_input_ai(out_shape_tp);

            fprintf(pp, " 0=%d", pool);
            fprintf(pp, " 7=%d", adaptive_pooling);
            if (out_shape.size() == 1)
            {
                fprintf(pp, " 8=%d", out_shape[0]);
            }
            else if (out_shape.size() == 2)
            {
                // out_w
                fprintf(pp, " 8=%d", out_shape[1]);
                // out_h
                fprintf(pp, " 18=%d", out_shape[0]);
            }
        }
        else if (op == "GroupNorm")
        {
            int groups = get_node_attr_i(node, "groups", 1);
            int channels = get_node_attr_i(node, "channels", 1);
            float eps = get_node_attr_f(node, "epsilon", 1e-5f);
            int affine = get_node_attr_i(node, "affine", 1);

            if (affine)
            {
                // discard affine-less S=1 B=0
                // (demote affine to 0 when every scale is exactly 1 and every bias 0,
                //  so no weight blobs need to be written)
                std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
                std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
                if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && affine_B[0] == 0.f)
                {
                    affine = 0;
                }
                else
                {
                    affine = 0;
                    {
                        // NOTE(review): iterates up to the "channels" attribute, not
                        // affine_S.size()/affine_B.size(); if those vectors are shorter
                        // than channels this reads out of bounds — TODO confirm the
                        // attribute always matches the tensor lengths.
                        for (int j = 0; j < channels; j++)
                        {
                            if (affine_S[j] != 1.f || affine_B[j] != 0.f)
                            {
                                affine = 1;
                                break;
                            }
                        }
                    }
                }
            }

            fprintf(pp, " 0=%d", groups);
            fprintf(pp, " 1=%d", channels);
            fprintf(pp, " 2=%e", eps);
            fprintf(pp, " 3=%d", affine);
            if (affine)
            {
                // gamma then beta, written raw (no quantize tag for this layer)
                const onnx::TensorProto& scale = weights[node.input(1)];
                const onnx::TensorProto& B = weights[node.input(2)];

                fwrite_tensor_proto_data(scale, bp);
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "GRU")
        {
            // ONNX GRU inputs: W = input-hidden weights, R = recurrent weights,
            // B = concatenated input/recurrent gate biases.
            const onnx::TensorProto& W = weights[node.input(1)];
            const onnx::TensorProto& R = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];

            int hidden_size = get_node_attr_i(node, "hidden_size", 0);
            std::string direction = get_node_attr_s(node, "direction");

            // 0 = forward, 1 = reverse, 2 = bidirectional
            int direction_type = 0;
            if (direction == "forward")
            {
                direction_type = 0;
            }
            else if (direction == "reverse")
            {
                direction_type = 1;
            }
            else if (direction == "bidirectional")
            {
                direction_type = 2;
            }

            int weight_data_size = get_tensor_proto_data_size(W);

            fprintf(pp, " 0=%d", hidden_size);
            fprintf(pp, " 1=%d", weight_data_size);
            fprintf(pp, " 2=%d", direction_type);

            int num_directions = direction_type == 2 ? 2 : 1;

            int quantize_tag = 0;

            // reorder num_directions-URN-hidden-size to num_directions-RUN-hidden-size
            // (ONNX gate order is z(update), r(reset), n(new); ncnn wants r, z, n)
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // per-gate slice size within one direction
                int weight_data_size_g = get_tensor_proto_data_size(W) / 3 / num_directions;
                const float* wptr = W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data();

                const float* uptr = wptr;
                const float* rptr = wptr + weight_data_size_g;
                const float* nptr = wptr + weight_data_size_g * 2;
                fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                fwrite(nptr, sizeof(float), weight_data_size_g, bp);

                if (direction_type == 2)
                {
                    // advance past the forward direction's 3 gates, repeat for reverse
                    uptr += weight_data_size_g * 3;
                    rptr += weight_data_size_g * 3;
                    nptr += weight_data_size_g * 3;
                    fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(nptr, sizeof(float), weight_data_size_g, bp);
                }
            }

            // reduce U and R bias except N
            // reorder num_directions-URN-hidden to num_directions-RUN-hidden
            // (input bias Wb and recurrent bias Rb are summed for the U and R gates;
            //  the N gate keeps them separate because the recurrent term is gated)
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // B holds 2 bias sets (Wb, Rb) x 3 gates per direction
                int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 3 / num_directions;
                const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
                const float* wuptr = bptr;
                const float* wrptr = bptr + bias_data_size_g;
                const float* wnptr = bptr + bias_data_size_g * 2;
                const float* buptr = bptr + bias_data_size_g * 3;
                const float* brptr = bptr + bias_data_size_g * 4;
                const float* bnptr = bptr + bias_data_size_g * 5;

                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = wrptr[j] + brptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = wuptr[j] + buptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                fwrite(wnptr, sizeof(float), bias_data_size_g, bp);
                fwrite(bnptr, sizeof(float), bias_data_size_g, bp);

                if (direction_type == 2)
                {
                    // advance past the forward direction's 6 bias slices
                    wuptr += bias_data_size_g * 6;
                    wrptr += bias_data_size_g * 6;
                    wnptr += bias_data_size_g * 6;
                    buptr += bias_data_size_g * 6;
                    brptr += bias_data_size_g * 6;
                    bnptr += bias_data_size_g * 6;

                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = wrptr[j] + brptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = wuptr[j] + buptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    fwrite(wnptr, sizeof(float), bias_data_size_g, bp);
                    fwrite(bnptr, sizeof(float), bias_data_size_g, bp);
                }
            }

            // reorder num_directions-URN-hidden-hidden to num_directions-RUN-hidden-hidden
            // (same gate reordering applied to the recurrent weight matrix R)
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                int weight_data_size_g = get_tensor_proto_data_size(R) / 3 / num_directions;
                const float* Rptr = R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data();

                const float* uptr = Rptr;
                const float* rptr = Rptr + weight_data_size_g;
                const float* nptr = Rptr + weight_data_size_g * 2;
                fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                fwrite(nptr, sizeof(float), weight_data_size_g, bp);

                if (direction_type == 2)
                {
                    uptr += weight_data_size_g * 3;
                    rptr += weight_data_size_g * 3;
                    nptr += weight_data_size_g * 3;
                    fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(nptr, sizeof(float), weight_data_size_g, bp);
                }
            }
        }
        else if (op == "HardSigmoid")
        {
            // y = clamp(alpha * x + beta, 0, 1); note the 0.2 default here differs
            // from the ONNX spec default of 0.25 — presumably matched to the fused
            // producer earlier in this converter. TODO confirm.
            float alpha = get_node_attr_f(node, "alpha", 0.2f);
            float beta = get_node_attr_f(node, "beta", 0.5f);

            fprintf(pp, " 0=%e", alpha);
            fprintf(pp, " 1=%e", beta);
        }
        else if (op == "HardSwish")
        {
            // same alpha/beta parameterization as HardSigmoid
            float alpha = get_node_attr_f(node, "alpha", 0.2f);
            float beta = get_node_attr_f(node, "beta", 0.5f);

            fprintf(pp, " 0=%e", alpha);
            fprintf(pp, " 1=%e", beta);
        }
        else if (op == "ImageScaler")
        {
            // per-channel y = scale * x + bias[c]; channel count is inferred from
            // the bias vector length
            std::vector<float> bias = get_node_attr_af(node, "bias");
            float scale = get_node_attr_f(node, "scale", 1.f);

            int channels = (int)bias.size();

            fprintf(pp, " 0=%d", channels);
            fprintf(pp, " 1=1"); // bias term present

            // the single scalar scale is replicated once per channel
            for (int j = 0; j < channels; j++)
            {
                fwrite(&scale, sizeof(float), 1, bp);
            }
            fwrite(&bias[0], sizeof(float), channels, bp);
        }
        else if (op == "InstanceNormalization")
        {
            float eps = get_node_attr_f(node, "epsilon", 1e-5f);

            // discard affine-less S=1 B=0
            // (affine stays 0 when every scale is exactly 1 and every bias 0,
            //  so no weight blobs need to be written)
            std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
            std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
            int channels = (int)affine_S.size();
            int affine = 0;
            {
                for (int j = 0; j < channels; j++)
                {
                    if (affine_S[j] != 1.f || affine_B[j] != 0.f)
                    {
                        affine = 1;
                        break;
                    }
                }
            }

            fprintf(pp, " 0=%d", channels);
            fprintf(pp, " 1=%e", eps);
            fprintf(pp, " 2=%d", affine);
            if (affine)
            {
                // gamma then beta, written raw
                const onnx::TensorProto& scale = weights[node.input(1)];
                const onnx::TensorProto& B = weights[node.input(2)];

                fwrite_tensor_proto_data(scale, bp);
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "LayerNorm")
        {
            float eps = get_node_attr_f(node, "epsilon", 1e-5f);
            int affine = get_node_attr_i(node, "affine", 1);

            if (affine)
            {
                // discard affine-less S=1 B=0
                // (demote affine to 0 when every scale is 1 and every bias 0)
                std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
                std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
                int affine_size = (int)affine_S.size();
                affine = 0;
                {
                    for (int j = 0; j < affine_size; j++)
                    {
                        if (affine_S[j] != 1.f || affine_B[j] != 0.f)
                        {
                            affine = 1;
                            break;
                        }
                    }
                }

                // param 0 (normalized size) is only written when affine survives
                if (affine)
                {
                    fprintf(pp, " 0=%d", affine_size);
                }
            }

            fprintf(pp, " 1=%e", eps);
            fprintf(pp, " 2=%d", affine);

            if (affine)
            {
                // gamma then beta, written raw
                const onnx::TensorProto& scale = weights[node.input(1)];
                const onnx::TensorProto& B = weights[node.input(2)];

                fwrite_tensor_proto_data(scale, bp);
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "LeakyRelu")
        {
            // negative-slope coefficient; 0.01 matches the ONNX default
            float alpha = get_node_attr_f(node, "alpha", 0.01f);

            fprintf(pp, " 0=%e", alpha);
        }
        else if (op == "Log")
        {
            // unary op selector: 8 = log
            int op_type = 8;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "LRN")
        {
            float alpha = get_node_attr_f(node, "alpha", 1.f);
            float beta = get_node_attr_f(node, "beta", 0.5f);
            float bias = get_node_attr_f(node, "bias", 1.f);
            int size = get_node_attr_i(node, "size", 1);

            // 0 = normalize across channels (the only region ONNX LRN defines)
            int norm_region = 0;

            fprintf(pp, " 0=%d", norm_region);
            fprintf(pp, " 1=%d", size);
            fprintf(pp, " 2=%e", alpha);
            fprintf(pp, " 3=%e", beta);
            fprintf(pp, " 4=%e", bias);
        }
4722         else if (op == "LSTM")
4723         {
4724             const onnx::TensorProto& W = weights[node.input(1)];
4725             const onnx::TensorProto& R = weights[node.input(2)];
4726             const onnx::TensorProto& B = weights[node.input(3)];
4727 
4728             int hidden_size = get_node_attr_i(node, "hidden_size", 0);
4729             std::string direction = get_node_attr_s(node, "direction");
4730 
4731             int direction_type = 0;
4732             if (direction == "forward")
4733             {
4734                 direction_type = 0;
4735             }
4736             else if (direction == "reverse")
4737             {
4738                 direction_type = 1;
4739             }
4740             else if (direction == "bidirectional")
4741             {
4742                 direction_type = 2;
4743             }
4744 
4745             int weight_data_size = get_tensor_proto_data_size(W);
4746 
4747             fprintf(pp, " 0=%d", hidden_size);
4748             fprintf(pp, " 1=%d", weight_data_size);
4749             fprintf(pp, " 2=%d", direction_type);
4750 
4751             int num_directions = direction_type == 2 ? 2 : 1;
4752 
4753             int quantize_tag = 0;
4754 
4755             // reorder num_directions-IOFG-hidden-size to num_directions-IFOG-hidden-size
4756             {
4757                 fwrite(&quantize_tag, sizeof(int), 1, bp);
4758 
4759                 int weight_data_size_g = get_tensor_proto_data_size(W) / 4 / num_directions;
4760                 const float* wptr = W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data();
4761 
4762                 const float* iptr = wptr;
4763                 const float* optr = wptr + weight_data_size_g;
4764                 const float* fptr = wptr + weight_data_size_g * 2;
4765                 const float* gptr = wptr + weight_data_size_g * 3;
4766                 fwrite(iptr, sizeof(float), weight_data_size_g, bp);
4767                 fwrite(fptr, sizeof(float), weight_data_size_g, bp);
4768                 fwrite(optr, sizeof(float), weight_data_size_g, bp);
4769                 fwrite(gptr, sizeof(float), weight_data_size_g, bp);
4770 
4771                 if (direction_type == 2)
4772                 {
4773                     iptr += weight_data_size_g * 4;
4774                     optr += weight_data_size_g * 4;
4775                     fptr += weight_data_size_g * 4;
4776                     gptr += weight_data_size_g * 4;
4777                     fwrite(iptr, sizeof(float), weight_data_size_g, bp);
4778                     fwrite(fptr, sizeof(float), weight_data_size_g, bp);
4779                     fwrite(optr, sizeof(float), weight_data_size_g, bp);
4780                     fwrite(gptr, sizeof(float), weight_data_size_g, bp);
4781                 }
4782             }
4783 
4784             // reduce xc and hc bias
4785             // reorder num_directions-IOFG-hidden to num_directions-IFOG-hidden
4786             {
4787                 fwrite(&quantize_tag, sizeof(int), 1, bp);
4788 
4789                 int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 4 / num_directions;
4790                 const float* xcbptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
4791                 const float* xiptr = xcbptr;
4792                 const float* xoptr = xcbptr + bias_data_size_g;
4793                 const float* xfptr = xcbptr + bias_data_size_g * 2;
4794                 const float* xgptr = xcbptr + bias_data_size_g * 3;
4795                 const float* hiptr = xcbptr + bias_data_size_g * 4;
4796                 const float* hoptr = xcbptr + bias_data_size_g * 5;
4797                 const float* hfptr = xcbptr + bias_data_size_g * 6;
4798                 const float* hgptr = xcbptr + bias_data_size_g * 7;
4799 
4800                 for (int j = 0; j < bias_data_size_g; j++)
4801                 {
4802                     float vb = xiptr[j] + hiptr[j];
4803                     fwrite(&vb, sizeof(float), 1, bp);
4804                 }
4805                 for (int j = 0; j < bias_data_size_g; j++)
4806                 {
4807                     float vb = xfptr[j] + hfptr[j];
4808                     fwrite(&vb, sizeof(float), 1, bp);
4809                 }
4810                 for (int j = 0; j < bias_data_size_g; j++)
4811                 {
4812                     float vb = xoptr[j] + hoptr[j];
4813                     fwrite(&vb, sizeof(float), 1, bp);
4814                 }
4815                 for (int j = 0; j < bias_data_size_g; j++)
4816                 {
4817                     float vb = xgptr[j] + hgptr[j];
4818                     fwrite(&vb, sizeof(float), 1, bp);
4819                 }
4820 
4821                 if (direction_type == 2)
4822                 {
4823                     xiptr += bias_data_size_g * 8;
4824                     xoptr += bias_data_size_g * 8;
4825                     xfptr += bias_data_size_g * 8;
4826                     xgptr += bias_data_size_g * 8;
4827                     hiptr += bias_data_size_g * 8;
4828                     hoptr += bias_data_size_g * 8;
4829                     hfptr += bias_data_size_g * 8;
4830                     hgptr += bias_data_size_g * 8;
4831 
4832                     for (int j = 0; j < bias_data_size_g; j++)
4833                     {
4834                         float vb = xiptr[j] + hiptr[j];
4835                         fwrite(&vb, sizeof(float), 1, bp);
4836                     }
4837                     for (int j = 0; j < bias_data_size_g; j++)
4838                     {
4839                         float vb = xfptr[j] + hfptr[j];
4840                         fwrite(&vb, sizeof(float), 1, bp);
4841                     }
4842                     for (int j = 0; j < bias_data_size_g; j++)
4843                     {
4844                         float vb = xoptr[j] + hoptr[j];
4845                         fwrite(&vb, sizeof(float), 1, bp);
4846                     }
4847                     for (int j = 0; j < bias_data_size_g; j++)
4848                     {
4849                         float vb = xgptr[j] + hgptr[j];
4850                         fwrite(&vb, sizeof(float), 1, bp);
4851                     }
4852                 }
4853             }
4854 
4855             // reorder num_directions-IOFG-hidden-hidden to num_directions-IFOG-hidden-hidden
4856             {
4857                 fwrite(&quantize_tag, sizeof(int), 1, bp);
4858 
4859                 int weight_data_size_g = get_tensor_proto_data_size(R) / 4 / num_directions;
4860                 const float* rptr = R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data();
4861 
4862                 const float* iptr = rptr;
4863                 const float* optr = rptr + weight_data_size_g;
4864                 const float* fptr = rptr + weight_data_size_g * 2;
4865                 const float* gptr = rptr + weight_data_size_g * 3;
4866                 fwrite(iptr, sizeof(float), weight_data_size_g, bp);
4867                 fwrite(fptr, sizeof(float), weight_data_size_g, bp);
4868                 fwrite(optr, sizeof(float), weight_data_size_g, bp);
4869                 fwrite(gptr, sizeof(float), weight_data_size_g, bp);
4870 
4871                 if (direction_type == 2)
4872                 {
4873                     iptr += weight_data_size_g * 4;
4874                     optr += weight_data_size_g * 4;
4875                     fptr += weight_data_size_g * 4;
4876                     gptr += weight_data_size_g * 4;
4877                     fwrite(iptr, sizeof(float), weight_data_size_g, bp);
4878                     fwrite(fptr, sizeof(float), weight_data_size_g, bp);
4879                     fwrite(optr, sizeof(float), weight_data_size_g, bp);
4880                     fwrite(gptr, sizeof(float), weight_data_size_g, bp);
4881                 }
4882             }
4883         }
        else if (op == "MatMul")
        {
            // A MatMul whose second operand is a constant 2D weight is emitted
            // as an InnerProduct layer; anything else falls through and keeps
            // the default matrix-multiplication behavior.
            if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
            {
                // InnerProduct
                const onnx::TensorProto& B = weights[node.input(1)];

                int weight_data_size = get_tensor_proto_data_size(B);

                // ONNX stores B as (num_input, num_output), so the last dim is
                // the output count
                int num_output = B.dims(B.dims_size() - 1);
                int num_input = weight_data_size / num_output;

                fprintf(pp, " 0=%d", num_output);
                fprintf(pp, " 1=0"); // bias_term = off
                fprintf(pp, " 2=%d", weight_data_size);

                int quantize_tag = 0; // 0 = raw fp32 weights follow
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // reorder num_input-num_output to num_output-num_input
                {
                    const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();

                    // transpose: element (k, j) of the row-major
                    // (num_input x num_output) matrix is written in
                    // output-major order, one float at a time
                    for (int j = 0; j < num_output; j++)
                    {
                        for (int k = 0; k < num_input; k++)
                        {
                            float vb = bptr[k * num_output + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }

                // fwrite_tensor_proto_data(B, bp)
            }
            else
            {
                // default matrix multiplication
            }
        }
4924         else if (op == "Max")
4925         {
4926             int op_type = 4;
4927             fprintf(pp, " 0=%d", op_type);
4928 
4929             int with_scalar = get_node_attr_i(node, "with_scalar", 0);
4930             float b = get_node_attr_f(node, "b", 0.f);
4931             if (with_scalar)
4932             {
4933                 fprintf(pp, " 1=%d", with_scalar);
4934                 fprintf(pp, " 2=%e", b);
4935             }
4936         }
4937         else if (op == "Min")
4938         {
4939             int op_type = 5;
4940             fprintf(pp, " 0=%d", op_type);
4941 
4942             int with_scalar = get_node_attr_i(node, "with_scalar", 0);
4943             float b = get_node_attr_f(node, "b", 0.f);
4944             if (with_scalar)
4945             {
4946                 fprintf(pp, " 1=%d", with_scalar);
4947                 fprintf(pp, " 2=%e", b);
4948             }
4949         }
4950         else if (op == "Mul")
4951         {
4952             int op_type = 2;
4953             fprintf(pp, " 0=%d", op_type);
4954 
4955             int with_scalar = get_node_attr_i(node, "with_scalar", 0);
4956             float b = get_node_attr_f(node, "b", 0.f);
4957             if (with_scalar)
4958             {
4959                 fprintf(pp, " 1=%d", with_scalar);
4960                 fprintf(pp, " 2=%e", b);
4961             }
4962         }
        else if (op == "MultiHeadAttention")
        {
            int embed_dim = get_node_attr_i(node, "embed_dim", 0);
            int num_heads = get_node_attr_i(node, "num_heads", 0);

            fprintf(pp, " 0=%d", embed_dim);
            fprintf(pp, " 1=%d", num_heads);

            // Two weight layouts are handled:
            //  - 5 inputs : fused q/k/v in-projection weight + bias
            //    (input 1/2) and out-projection weight + bias (input 3/4)
            //  - otherwise: separate q/k/v/out weights and biases in
            //    inputs 3..10
            if (node.input_size() == 5)
            {
                const onnx::TensorProto& qkvw = weights[node.input(1)];
                const onnx::TensorProto& qkvb = weights[node.input(2)];
                const onnx::TensorProto& ow = weights[node.input(3)];
                const onnx::TensorProto& ob = weights[node.input(4)];

                int weight_data_size = get_tensor_proto_data_size(ow);

                fprintf(pp, " 2=%d", weight_data_size);

                int quantize_tag = 0; // 0 = raw fp32 weights follow

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose qw
                {
                    const float* wptr = qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
                    const float* bptr = qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();

                    // indexing assumes qkvw is row-major (embed_dim, 3*embed_dim)
                    // with q in columns [0, embed_dim) — TODO confirm against
                    // the exporting framework
                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim * 3 + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }

                    fwrite(bptr, sizeof(float), embed_dim, bp);
                }

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose kw
                {
                    const float* wptr = qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
                    const float* bptr = qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();
                    bptr += embed_dim; // k bias is the second embed_dim chunk

                    // k occupies columns [embed_dim, 2*embed_dim)
                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim * 3 + j + embed_dim];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }

                    fwrite(bptr, sizeof(float), embed_dim, bp);
                }

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose vw
                {
                    const float* wptr = qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
                    const float* bptr = qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();
                    bptr += embed_dim * 2; // v bias is the third embed_dim chunk

                    // v occupies columns [2*embed_dim, 3*embed_dim)
                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim * 3 + j + embed_dim * 2];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }

                    fwrite(bptr, sizeof(float), embed_dim, bp);
                }

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose ow
                {
                    const float* wptr = ow.has_raw_data() ? (const float*)ow.raw_data().data() : ow.float_data().data();

                    // plain (embed_dim x embed_dim) transpose
                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(ob, bp);
            }
            else
            {
                // NOTE(review): the v-bias tensor is named `vb` and is shadowed
                // by the `float vb` temporaries in the transpose loops below;
                // the final fwrite_tensor_proto_data(vb, bp) refers to the
                // tensor again. Confusing but correct.
                const onnx::TensorProto& qw = weights[node.input(3)];
                const onnx::TensorProto& qb = weights[node.input(4)];
                const onnx::TensorProto& kw = weights[node.input(5)];
                const onnx::TensorProto& kb = weights[node.input(6)];
                const onnx::TensorProto& vw = weights[node.input(7)];
                const onnx::TensorProto& vb = weights[node.input(8)];
                const onnx::TensorProto& ow = weights[node.input(9)];
                const onnx::TensorProto& ob = weights[node.input(10)];

                int weight_data_size = get_tensor_proto_data_size(qw);

                fprintf(pp, " 2=%d", weight_data_size);

                int quantize_tag = 0; // 0 = raw fp32 weights follow

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose qw
                {
                    const float* wptr = qw.has_raw_data() ? (const float*)qw.raw_data().data() : qw.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(qb, bp);

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose kw
                {
                    const float* wptr = kw.has_raw_data() ? (const float*)kw.raw_data().data() : kw.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(kb, bp);

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose vw
                {
                    const float* wptr = vw.has_raw_data() ? (const float*)vw.raw_data().data() : vw.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(vb, bp);

                fwrite(&quantize_tag, sizeof(int), 1, bp);
                // transpose ow
                {
                    const float* wptr = ow.has_raw_data() ? (const float*)ow.raw_data().data() : ow.float_data().data();

                    for (int j = 0; j < embed_dim; j++)
                    {
                        for (int k = 0; k < embed_dim; k++)
                        {
                            float vb = wptr[k * embed_dim + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }
                fwrite_tensor_proto_data(ob, bp);
            }
        }
5138         else if (op == "Neg")
5139         {
5140             int op_type = 1;
5141             fprintf(pp, " 0=%d", op_type);
5142         }
5143         else if (op == "Normalize")
5144         {
5145             float eps = get_node_attr_f(node, "eps", 0.f);
5146             int scale_data_size = 1;
5147 
5148             fprintf(pp, " 1=1"); // channel_shared
5149             fprintf(pp, " 2=%e", eps);
5150             fprintf(pp, " 3=%d", scale_data_size);
5151             fprintf(pp, " 9=1"); // TODO hardcode pytorch style
5152 
5153             const float scale_data[1] = {1.f};
5154             fwrite(scale_data, sizeof(float), 1, bp);
5155         }
        else if (op == "Pad")
        {
            std::string mode = get_node_attr_s(node, "mode");
            float value = get_node_attr_f(node, "value", 0.f);

            // pads come from the attribute (opset <= 10) or from the second
            // input tensor (opset 11+)
            std::vector<int> pads;
            if (node.input_size() == 1)
            {
                pads = get_node_attr_ai(node, "pads");
            }
            else
            {
                pads = get_node_attr_from_input_ai(weights[node.input(1)]);
            }

            // ncnn Padding type: 0=constant, 1=replicate(edge), 2=reflect
            int type = 0;
            if (mode == "constant")
            {
                type = 0;
            }
            else if (mode == "edge")
            {
                type = 1;
            }
            else if (mode == "reflect")
            {
                type = 2;
            }

            // ONNX pads layout is [dim0_begin..dimN_begin, dim0_end..dimN_end];
            // the batch dim (index 0) padding is dropped
            int pad_size = (int)pads.size();
            int top = 0;
            int bottom = 0;
            int left = 0;
            int right = 0;
            int front = 0;
            int behind = 0;
            if (pad_size == 8)
            {
                //NCHW
                top = pads[2];    // H begin
                bottom = pads[6]; // H end
                left = pads[3];   // W begin
                right = pads[7];  // W end
                front = pads[1];  // C begin
                behind = pads[5]; // C end
            }
            else if (pad_size == 6)
            {
                //NHW
                top = pads[1];
                bottom = pads[4];
                left = pads[2];
                right = pads[5];
            }
            else
            {
                //NW
                // NOTE(review): this assumes pad_size == 4; a 1-D input without
                // a batch dim (pad_size == 2) would read out of range — confirm
                // callers always include the batch axis
                left = pads[1];
                right = pads[3];
            }

            fprintf(pp, " 0=%d", top);
            fprintf(pp, " 1=%d", bottom);
            fprintf(pp, " 2=%d", left);
            fprintf(pp, " 3=%d", right);
            fprintf(pp, " 4=%d", type);
            fprintf(pp, " 5=%e", value);
            fprintf(pp, " 7=%d", front);
            fprintf(pp, " 8=%d", behind);
        }
5226         else if (op == "Pow")
5227         {
5228             int op_type = 6;
5229             fprintf(pp, " 0=%d", op_type);
5230 
5231             int with_scalar = get_node_attr_i(node, "with_scalar", 0);
5232             float b = get_node_attr_f(node, "b", 0.f);
5233             if (with_scalar)
5234             {
5235                 fprintf(pp, " 1=%d", with_scalar);
5236                 fprintf(pp, " 2=%e", b);
5237             }
5238         }
5239         else if (op == "PixelShuffle")
5240         {
5241             int scale_factor = get_node_attr_i(node, "scale_factor", 1);
5242             fprintf(pp, " 0=%d", scale_factor);
5243         }
5244         else if (op == "PRelu")
5245         {
5246             const onnx::TensorProto& slope = weights[node.input(1)];
5247 
5248             int num_slope = get_tensor_proto_data_size(slope);
5249 
5250             fprintf(pp, " 0=%d", num_slope);
5251 
5252             fwrite_tensor_proto_data(slope, bp);
5253         }
5254         else if (op == "Reciprocal")
5255         {
5256             int op_type = 15;
5257             fprintf(pp, " 0=%d", op_type);
5258         }
        else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp")
        {
            // map the ONNX reduction op onto the ncnn Reduction op_type enum;
            // -233 marks "unset" and should never survive the chain below
            int op_type = -233;
            if (op == "ReduceSum")
                op_type = 0;
            else if (op == "ReduceSumSquare")
                op_type = 2;
            else if (op == "ReduceMean")
                op_type = 3;
            else if (op == "ReduceMax")
                op_type = 4;
            else if (op == "ReduceMin")
                op_type = 5;
            else if (op == "ReduceProd")
                op_type = 6;
            else if (op == "ReduceL1")
                op_type = 7;
            else if (op == "ReduceL2")
                op_type = 8;
            else if (op == "ReduceLogSum")
                op_type = 9;
            else if (op == "ReduceLogSumExp")
                op_type = 10;
            fprintf(pp, " 0=%d", op_type);

            std::vector<int> axes = get_node_attr_ai(node, "axes");
            int keepdims = get_node_attr_i(node, "keepdims", 1);

            if (axes.size() > 0)
            {
                // if axes set, reduce according to axes
                fprintf(pp, " 1=%d", 0);
                fprintf(pp, " -23303=%zu", axes.size());
                for (size_t j = 0; j < axes.size(); j++)
                {
                    // the batch axis (0) cannot be reduced and |axis| must be <= 3;
                    // the warning is printed but the axis is still emitted
                    if (axes[j] == 0 || axes[j] > 3 || axes[j] < -3)
                        fprintf(stderr, "Unsupported reduction axes !\n");
                    fprintf(pp, ",%d", axes[j]);
                }
            }
            else
            {
                // if axes not set, reduce all axes by default
                fprintf(pp, " 1=%d", 1);
            }
            fprintf(pp, " 4=%d", keepdims);
        }
5306         else if (op == "Reorg")
5307         {
5308             int stride = get_node_attr_i(node, "stride", 1);
5309             fprintf(pp, " 0=%d", stride);
5310         }
        else if (op == "Reshape")
        {
            // target shape comes from the attribute (opset <= 4) or from the
            // second input tensor (opset 5+)
            std::vector<int> shape;

            if (node.input_size() == 1)
            {
                shape = get_node_attr_ai(node, "shape");
            }
            else
            {
                shape = get_node_attr_from_input_ai(weights[node.input(1)]);
            }

            // ONNX shape includes the batch dim at index 0; ncnn drops it and
            // writes the remaining dims as 0=w 1=h 2=c
            if (shape.size() == 1)
            {
                fprintf(pp, " 0=%d", shape[0]); // should never reach here
            }
            else if (shape.size() == 2)
            {
                fprintf(pp, " 0=%d", shape[1]);
            }
            else if (shape.size() == 3)
            {
                fprintf(pp, " 0=%d", shape[2]);
                fprintf(pp, " 1=%d", shape[1]);
            }
            else if (shape.size() == 4)
            {
                fprintf(pp, " 0=%d", shape[3]);
                fprintf(pp, " 1=%d", shape[2]);
                fprintf(pp, " 2=%d", shape[1]);
            }
            else if (shape.size() == 5)
            {
                // 5-D: collapse the two innermost dims into w
                fprintf(pp, " 0=%d", shape[4] * shape[3]);
                fprintf(pp, " 1=%d", shape[2]);
                fprintf(pp, " 2=%d", shape[1]);
            }
        }
        else if (op == "Resize")
        {
            std::string mode = get_node_attr_s(node, "mode");
            std::string align = get_node_attr_s(node, "coordinate_transformation_mode");

            // opset 10 puts scales in input(1); opset 11+ moved scales to
            // input(2) (input(1) is roi) and added optional sizes in input(3)
            std::vector<float> scales;
            std::vector<int> sizes;
            if (node.input_size() == 2)
            {
                // opset 10
                scales = get_node_attr_from_input_af(weights[node.input(1)]);
            }
            else
            {
                // opset 11+
                scales = get_node_attr_from_input_af(weights[node.input(2)]);
                if (node.input_size() >= 4)
                {
                    sizes = get_node_attr_from_input_ai(weights[node.input(3)]);
                }
            }

            // Interp resize_type: 1=nearest, 2=bilinear, 3=bicubic
            int resize_type = 1;
            if (mode == "nearest")
            {
                resize_type = 1;
            }
            else if (mode == "linear")
            {
                resize_type = 2;
            }
            else if (mode == "cubic")
            {
                resize_type = 3;
            }

            if (scales.empty() && sizes.empty())
            {
                fprintf(stderr, "Unsupported Resize scales and sizes are all empty!\n");
            }

            // scales/sizes follow the input rank with batch first, e.g.
            // [N, C, H, W] for 4-D; only H/W scaling is supported
            float h_scale = 1.f;
            float w_scale = 1.f;
            if (scales.size() == 2)
            {
                w_scale = scales[1];
            }
            else if (scales.size() == 3)
            {
                h_scale = scales[1];
                w_scale = scales[2];
            }
            else if (scales.size() == 4)
            {
                h_scale = scales[2];
                w_scale = scales[3];

                // channel scaling cannot be expressed
                if (scales[1] != 1.f)
                    fprintf(stderr, "Unsupported Resize scales !\n");
            }

            int output_height = 0;
            int output_width = 0;
            if (sizes.size() == 2)
            {
                output_width = sizes[1];
            }
            else if (sizes.size() == 3)
            {
                output_height = sizes[1];
                output_width = sizes[2];
            }
            else if (sizes.size() == 4)
            {
                output_height = sizes[2];
                output_width = sizes[3];
            }

            int align_corner = 0;
            if (align == "align_corners")
            {
                align_corner = 1;
            }

            fprintf(pp, " 0=%d", resize_type);
            fprintf(pp, " 1=%e", h_scale);
            fprintf(pp, " 2=%e", w_scale);
            fprintf(pp, " 3=%d", output_height);
            fprintf(pp, " 4=%d", output_width);
            fprintf(pp, " 6=%d", align_corner);
        }
        else if (op == "RNN")
        {
            // W: input-hidden weights, R: hidden-hidden weights,
            // B: concatenated input and hidden biases
            const onnx::TensorProto& W = weights[node.input(1)];
            const onnx::TensorProto& R = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];

            int hidden_size = get_node_attr_i(node, "hidden_size", 0);
            std::string direction = get_node_attr_s(node, "direction");

            // ncnn direction: 0=forward, 1=reverse, 2=bidirectional
            int direction_type = 0;
            if (direction == "forward")
            {
                direction_type = 0;
            }
            else if (direction == "reverse")
            {
                direction_type = 1;
            }
            else if (direction == "bidirectional")
            {
                direction_type = 2;
            }

            int weight_data_size = get_tensor_proto_data_size(W);

            fprintf(pp, " 0=%d", hidden_size);
            fprintf(pp, " 1=%d", weight_data_size);
            fprintf(pp, " 2=%d", direction_type);

            int num_directions = direction_type == 2 ? 2 : 1;

            int quantize_tag = 0; // 0 = raw fp32 weights follow

            fwrite(&quantize_tag, sizeof(int), 1, bp);
            fwrite_tensor_proto_data(W, bp);

            // reduce xc and hc bias
            // ONNX keeps the input bias (Wb) and hidden bias (Rb) separate per
            // direction; ncnn uses a single summed bias, so add them here
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / num_directions;
                const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
                const float* xiptr = bptr;
                const float* hiptr = bptr + bias_data_size_g;

                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xiptr[j] + hiptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }

                if (direction_type == 2)
                {
                    // advance to the reverse direction's bias pair
                    xiptr += bias_data_size_g * 2;
                    hiptr += bias_data_size_g * 2;

                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xiptr[j] + hiptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                }
            }

            fwrite(&quantize_tag, sizeof(int), 1, bp);
            fwrite_tensor_proto_data(R, bp);
        }
5508         else if (op == "ShuffleChannel")
5509         {
5510             int group = get_node_attr_i(node, "group", 1);
5511             int reverse = get_node_attr_i(node, "reverse", 0);
5512             fprintf(pp, " 0=%d", group);
5513             fprintf(pp, " 1=%d", reverse);
5514         }
        else if (op == "Sigmoid")
        {
            // Sigmoid layer takes no parameters; nothing to emit
        }
5519         else if (op == "Sin")
5520         {
5521             int op_type = 9;
5522             fprintf(pp, " 0=%d", op_type);
5523         }
        else if (op == "SkipLayerNormalization")
        {
            // presumably W = gamma, B = beta, B2 = trailing bias, per the
            // onnxruntime contrib-op input order (input, skip, gamma, beta,
            // bias) — TODO confirm against the exporter
            const onnx::TensorProto& W = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];
            const onnx::TensorProto& B2 = weights[node.input(4)];

            fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));

            int quantize_tag = 0; // 0 = raw fp32 weights follow
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(W, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B2, bp);
        }
        else if (op == "Slice")
        {
            // starts/ends/axes/steps come from attributes (opset <= 9) or
            // from inputs 1..4 (opset 10+)
            std::vector<int> starts;
            std::vector<int> ends;
            std::vector<int> axes;
            std::vector<int> steps;
            if (node.input_size() == 1)
            {
                starts = get_node_attr_ai(node, "starts");
                ends = get_node_attr_ai(node, "ends");
                axes = get_node_attr_ai(node, "axes");
                steps = get_node_attr_ai(node, "steps"); // TODO
            }
            else
            {
                starts = get_node_attr_from_input_ai(weights[node.input(1)]);
                ends = get_node_attr_from_input_ai(weights[node.input(2)]);
                if (node.input_size() >= 4)
                    axes = get_node_attr_from_input_ai(weights[node.input(3)]);
                if (node.input_size() >= 5)
                    steps = get_node_attr_from_input_ai(weights[node.input(4)]);
            }

            // assert step == 1
            for (int i = 0; i < (int)steps.size(); i++)
            {
                if (steps[i] != 1)
                    fprintf(stderr, "Unsupported slice step !\n");
            }

            // filter out N-dim axis
            // drop at most one batch-axis entry; the break keeps the
            // erase-inside-loop safe
            if (!axes.empty())
            {
                for (int i = 0; i < (int)axes.size(); i++)
                {
                    int axis = axes[i];
                    if (axis == 0)
                    {
                        starts.erase(starts.begin() + i);
                        ends.erase(ends.begin() + i);
                        axes.erase(axes.begin() + i);
                        break;
                    }
                }
            }

            fprintf(pp, " -23309=%d", (int)starts.size());
            for (int i = 0; i < (int)starts.size(); i++)
            {
                fprintf(pp, ",%d", starts[i]);
            }
            fprintf(pp, " -23310=%d", (int)ends.size());
            for (int i = 0; i < (int)ends.size(); i++)
            {
                fprintf(pp, ",%d", ends[i]);
            }
            if (!axes.empty())
            {
                fprintf(pp, " -23311=%d", (int)axes.size());
                for (int i = 0; i < (int)axes.size(); i++)
                {
                    int axis = axes[i];
                    if (axis == 0 || axis > 3 || axis < -3)
                        fprintf(stderr, "Unsupported slice axes !\n");

                    // positive axes are shifted down because ncnn blobs have
                    // no batch dimension
                    if (axis > 0)
                        axis = axis - 1; // -1 for skip N-dim
                    fprintf(pp, ",%d", axis);
                }
            }
        }
        else if (op == "Softmax")
        {
            // the target format has no batch dim: shift the onnx axis down by one
            int axis = get_node_attr_i(node, "axis", 1);
            fprintf(pp, " 0=%d", axis - 1);
            fprintf(pp, " 1=1"); // NOTE(review): presumably the Softmax "fixbug0" flag — confirm against the layer param table
        }
        else if (op == "Split")
        {
            int axis = get_node_attr_i(node, "axis", 0);
            std::vector<int> split = get_node_attr_ai(node, "split");
            // axis 0 is the implicit batch dim; splitting along it is unsupported
            if (axis < 1)
                fprintf(stderr, "Unsupported split axis !\n");

            // -23300 = per-output slice sizes; -233 means "take the remainder"
            fprintf(pp, " -23300=%d", output_size);
            if (split.empty())
            {
                // no explicit sizes: every output takes an equal share
                for (int i = 0; i < output_size; i++)
                {
                    fprintf(pp, ",-233");
                }
            }
            else
            {
                // explicit sizes for all but the last output; last takes the rest
                for (size_t i = 0; i < split.size() - 1; i++)
                {
                    fprintf(pp, ",%d", split[i]);
                }
                fprintf(pp, ",-233");
            }
            fprintf(pp, " 1=%d", axis - 1); // -1 to skip the implicit batch dim
        }
        else if (op == "Sqrt")
        {
            // unary op type 5 = sqrt
            int op_type = 5;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Squeeze")
        {
            std::vector<int> axes = get_node_attr_ai(node, "axes");

            if (axes.empty())
            {
                // no axes given: squeeze all three dims (params 0/1/2)
                fprintf(pp, " 0=1");
                fprintf(pp, " 1=1");
                fprintf(pp, " 2=1");
            }
            else
            {
                // -23303 = explicit axes array
                fprintf(pp, " -23303=%zu", axes.size());
                for (int i = 0; i < (int)axes.size(); i++)
                {
                    // axis 0 (batch) and rank > 3 are unsupported
                    if (axes[i] == 0 || axes[i] > 3 || axes[i] < -3)
                        fprintf(stderr, "Unsupported squeeze axes !\n");
                    fprintf(pp, ",%d", axes[i]);
                }
            }
        }
        else if (op == "Sub")
        {
            // binary op type 1 = sub
            int op_type = 1;
            fprintf(pp, " 0=%d", op_type);

            // with_scalar/b are not standard onnx attributes — presumably
            // injected by an earlier fusion pass in this converter (verify)
            int with_scalar = get_node_attr_i(node, "with_scalar", 0);
            float b = get_node_attr_f(node, "b", 0.f);
            if (with_scalar)
            {
                fprintf(pp, " 1=%d", with_scalar);
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "Sum")
        {
            // op type 1 — NOTE(review): the layer type was chosen earlier
            // (outside this view), so 1 here is presumably Eltwise SUM — confirm
            int op_type = 1;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Swish")
        {
            // no param
        }
        else if (op == "Tan")
        {
            // unary op type 11 = tan
            int op_type = 11;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Tanh")
        {
            // unary op type 16 = tanh
            int op_type = 16;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Transpose")
        {
            std::vector<int> perm = get_node_attr_ai(node, "perm");

            // map the onnx permutation (which includes the leading batch axis)
            // onto a Permute order_type; perm[0]==0 is assumed except where
            // explicitly checked below
            if (perm.size() == 3)
            {
                if (perm[1] == 1 && perm[2] == 2)
                    fprintf(pp, " 0=0"); // w h
                else if (perm[1] == 2 && perm[2] == 1)
                    fprintf(pp, " 0=1"); // h w
                else if (perm[0] == 1 && perm[1] == 0 && perm[2] == 2)
                    fprintf(pp, " 0=0"); // w h
                else if (perm[0] == 2 && perm[1] == 0 && perm[2] == 1)
                    fprintf(pp, " 0=1"); // h w
            }
            else if (perm.size() == 4)
            {
                if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3)
                    fprintf(pp, " 0=0"); // w h c
                else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 2)
                    fprintf(pp, " 0=1"); // h w c
                else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3)
                    fprintf(pp, " 0=2"); // w c h
                else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 1)
                    fprintf(pp, " 0=3"); // c w h
                else if (perm[1] == 3 && perm[2] == 1 && perm[3] == 2)
                    fprintf(pp, " 0=4"); // h c w
                else if (perm[1] == 3 && perm[2] == 2 && perm[3] == 1)
                    fprintf(pp, " 0=5"); // c h w
            }
            else if (perm.size() == 5)
            {
                if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3 && perm[4] == 4)
                    fprintf(pp, " 0=0"); // wx h c
                else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 4 && perm[4] == 2)
                    fprintf(pp, " 0=1"); // h wx c
                else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3 && perm[4] == 4)
                    fprintf(pp, " 0=2"); // wx c h
                else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 4 && perm[4] == 1)
                    fprintf(pp, " 0=3"); // c wx h
                else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 1 && perm[4] == 2)
                    fprintf(pp, " 0=4"); // h c wx
                else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 2 && perm[4] == 1)
                    fprintf(pp, " 0=5"); // c h wx
                else
                    fprintf(stderr, "Unsupported transpose type !\n");
            }
            // NOTE(review): unmatched 3-/4-dim perms fall through silently and
            // emit no param at all, unlike the 5-dim case which warns — consider
            // adding the same warning for consistency
        }
        else if (op == "Upsample")
        {
            std::string mode = get_node_attr_s(node, "mode");
            std::string align = get_node_attr_s(node, "coordinate_transformation_mode");

            std::vector<float> scales;

            // scales come from a node attribute (single-input form) or from a
            // constant input tensor (multi-input form)
            if (node.input_size() == 1)
            {
                scales = get_node_attr_af(node, "scales");
            }
            else
            {
                scales = get_node_attr_from_input_af(weights[node.input(1)]);
            }

            // resize_type: 1 = nearest, 2 = bilinear; trilinear is unsupported
            int resize_type = 1;
            if (mode == "nearest")
            {
                resize_type = 1;
            }
            else if (mode == "bilinear" || mode == "linear")
            {
                resize_type = 2;
            }
            else if (mode == "trilinear")
            {
                fprintf(stderr, "Unsupported Upsample mode !\n");
            }

            // extract h/w scale factors depending on how many dims scales
            // covers; the leading factors (batch, and channel for 4 dims) must
            // be 1
            float h_scale = 1.f;
            float w_scale = 1.f;
            if (scales.size() == 2)
            {
                w_scale = scales[1];
            }
            else if (scales.size() == 3)
            {
                h_scale = scales[1];
                w_scale = scales[2];
            }
            else if (scales.size() == 4)
            {
                h_scale = scales[2];
                w_scale = scales[3];

                if (scales[1] != 1.f)
                    fprintf(stderr, "Unsupported Upsample scales !\n");
            }
            else
            {
                fprintf(stderr, "Unsupported Upsample scales !\n");
            }

            int align_corner = 0;
            if (align == "align_corners")
            {
                align_corner = 1;
            }

            fprintf(pp, " 0=%d", resize_type);
            fprintf(pp, " 1=%e", h_scale);
            fprintf(pp, " 2=%e", w_scale);
            fprintf(pp, " 6=%d", align_corner);
        }
        else if (op == "Unsqueeze")
        {
            std::vector<int> axes = get_node_attr_ai(node, "axes");

            // -23303 = axes array
            fprintf(pp, " -23303=%zu", axes.size());
            for (int i = 0; i < (int)axes.size(); i++)
            {
                // axis 0 (batch) and resulting rank > 4 are unsupported
                if (axes[i] == 0 || axes[i] > 4 || axes[i] < -4)
                    fprintf(stderr, "Unsupported unsqueeze axes !\n");
                fprintf(pp, ",%d", axes[i]);
            }
        }
        else
        {
            // unrecognized op: dump its attributes to stderr as a porting hint
            // TODO op specific param
            for (int j = 0; j < node.attribute_size(); j++)
            {
                const onnx::AttributeProto& attr = node.attribute(j);
                if (attr.type() == 1) // AttributeProto::FLOAT
                {
                    fprintf(stderr, "  # %s=%g\n", attr.name().c_str(), attr.f());
                }
                else if (attr.type() == 2) // AttributeProto::INT
                {
                    fprintf(stderr, "  # %s=%lld\n", attr.name().c_str(), (long long)attr.i());
                }
                else if (attr.type() == 3) // AttributeProto::STRING
                {
                    fprintf(stderr, "  # %s=%s\n", attr.name().c_str(), attr.s().c_str());
                }
                else
                {
                    // other attribute kinds: print the raw type id only
                    fprintf(stderr, "  # %s %d\n", attr.name().c_str(), attr.type());
                }
            }
        }
5855 
        fprintf(pp, "\n");

        // when an output blob is consumed by more than one downstream node,
        // append a Split layer so every consumer gets its own copy; consumers
        // are presumably rewired to <name>_splitncnn_<k> elsewhere — verify
        for (int j = 0; j < output_size; j++)
        {
            const std::string& output_name = node.output(j);
            if (node_reference.find(output_name) != node_reference.end())
            {
                int refcount = node_reference[output_name];
                if (refcount > 1)
                {
                    char splitname[256];
                    sprintf(splitname, "splitncnn_%d", internal_split);
                    // layer line: type, name, 1 input, refcount outputs
                    fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);

                    fprintf(pp, " %s", output_name.c_str());

                    for (int k = 0; k < refcount; k++)
                    {
                        fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k);
                    }
                    fprintf(pp, "\n");

                    internal_split++;
                }
            }
        }
5882     }
5883 
5884     fclose(pp);
5885     fclose(bp);
5886 
5887     return 0;
5888 }
5889