// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

#include "onnx.pb.h"

#include <algorithm>
#include <float.h>
#include <fstream>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl.h>
#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>
#include <iostream>
#include <limits.h>
#include <limits>
#include <map>
#include <set>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string>
#include <vector>
29
read_proto_from_binary(const char * filepath,onnx::ModelProto * message)30 static bool read_proto_from_binary(const char* filepath, onnx::ModelProto* message)
31 {
32 std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
33 if (!fs.is_open())
34 {
35 fprintf(stderr, "open failed %s\n", filepath);
36 return false;
37 }
38
39 google::protobuf::io::IstreamInputStream input(&fs);
40 google::protobuf::io::CodedInputStream codedstr(&input);
41
42 #if GOOGLE_PROTOBUF_VERSION >= 3011000
43 codedstr.SetTotalBytesLimit(INT_MAX);
44 #else
45 codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);
46 #endif
47
48 bool success = message->ParseFromCodedStream(&codedstr);
49
50 fs.close();
51
52 return success;
53 }
54
get_node_attr_ai(const onnx::NodeProto & node,const char * key)55 static std::vector<int> get_node_attr_ai(const onnx::NodeProto& node, const char* key)
56 {
57 std::vector<int> v;
58
59 for (int i = 0; i < node.attribute_size(); i++)
60 {
61 const onnx::AttributeProto& attr = node.attribute(i);
62 if (attr.name() == key)
63 {
64 v.resize(attr.ints_size());
65 for (int j = 0; j < attr.ints_size(); j++)
66 {
67 v[j] = std::max(std::min(attr.ints(j), (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
68 }
69
70 break;
71 }
72 }
73
74 return v;
75 }
76
get_node_attr_af(const onnx::NodeProto & node,const char * key)77 static std::vector<float> get_node_attr_af(const onnx::NodeProto& node, const char* key)
78 {
79 std::vector<float> v;
80
81 for (int i = 0; i < node.attribute_size(); i++)
82 {
83 const onnx::AttributeProto& attr = node.attribute(i);
84 if (attr.name() == key)
85 {
86 v.resize(attr.floats_size());
87 for (int j = 0; j < attr.floats_size(); j++)
88 {
89 v[j] = attr.floats(j);
90 }
91
92 break;
93 }
94 }
95
96 return v;
97 }
98
get_node_attr_i(const onnx::NodeProto & node,const char * key,int def=0)99 static int get_node_attr_i(const onnx::NodeProto& node, const char* key, int def = 0)
100 {
101 for (int i = 0; i < node.attribute_size(); i++)
102 {
103 const onnx::AttributeProto& attr = node.attribute(i);
104 if (attr.name() == key)
105 {
106 return std::max(std::min(attr.i(), (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
107 }
108 }
109
110 return def;
111 }
112
get_node_attr_f(const onnx::NodeProto & node,const char * key,float def=0.f)113 static float get_node_attr_f(const onnx::NodeProto& node, const char* key, float def = 0.f)
114 {
115 for (int i = 0; i < node.attribute_size(); i++)
116 {
117 const onnx::AttributeProto& attr = node.attribute(i);
118 if (attr.name() == key)
119 {
120 return attr.f();
121 }
122 }
123
124 return def;
125 }
126
get_node_attr_s(const onnx::NodeProto & node,const char * key,const std::string & def=std::string ())127 static std::string get_node_attr_s(const onnx::NodeProto& node, const char* key, const std::string& def = std::string())
128 {
129 for (int i = 0; i < node.attribute_size(); i++)
130 {
131 const onnx::AttributeProto& attr = node.attribute(i);
132 if (attr.name() == key)
133 {
134 return attr.s();
135 }
136 }
137
138 return def;
139 }
140
// Fetch the tensor attribute named `key`.
// Returns a default-constructed (empty) TensorProto when absent.
static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char* key)
{
    const int attr_count = node.attribute_size();
    for (int a = 0; a < attr_count; a++)
    {
        const onnx::AttributeProto& attr = node.attribute(a);
        if (attr.name() == key)
            return attr.t();
    }

    return onnx::TensorProto();
}
154
get_node_attr_from_input_f(const onnx::TensorProto & tp)155 static float get_node_attr_from_input_f(const onnx::TensorProto& tp)
156 {
157 float v = 0.f;
158
159 // float
160 if (tp.data_type() == 1)
161 {
162 const float* shape_data = 0;
163 if (tp.has_raw_data())
164 {
165 shape_data = (const float*)tp.raw_data().data();
166 }
167 else
168 {
169 shape_data = tp.float_data().data();
170 }
171 v = shape_data[0];
172 }
173 // double
174 else if (tp.data_type() == 11)
175 {
176 const double* shape_data = 0;
177 if (tp.has_raw_data())
178 {
179 shape_data = (const double*)tp.raw_data().data();
180 }
181 else
182 {
183 shape_data = tp.double_data().data();
184 }
185 v = shape_data[0];
186 }
187 // int64
188 else if (tp.data_type() == 7)
189 {
190 const int64_t* shape_data = 0;
191 if (tp.has_raw_data())
192 {
193 shape_data = (const int64_t*)tp.raw_data().data();
194 }
195 else
196 {
197 shape_data = tp.int64_data().data();
198 }
199 v = std::max(std::min(shape_data[0], (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
200 }
201 // int32
202 else if (tp.data_type() == 6)
203 {
204 const int32_t* shape_data = 0;
205 if (tp.has_raw_data())
206 {
207 shape_data = (const int32_t*)tp.raw_data().data();
208 }
209 else
210 {
211 shape_data = tp.int32_data().data();
212 }
213 v = shape_data[0];
214 }
215 else
216 {
217 fprintf(stderr, "Unknown data type %d\n", tp.data_type());
218 abort();
219 }
220
221 return v;
222 }
223
get_node_attr_from_input_ai(const onnx::TensorProto & tp)224 static std::vector<int> get_node_attr_from_input_ai(const onnx::TensorProto& tp)
225 {
226 int size = 0;
227
228 std::vector<int> v;
229
230 // int64
231 if (tp.data_type() == 7)
232 {
233 const int64_t* shape_data = 0;
234 if (tp.has_raw_data())
235 {
236 shape_data = (const int64_t*)tp.raw_data().data();
237 size = (int)(tp.raw_data().size() / 8);
238 }
239 else
240 {
241 shape_data = tp.int64_data().data();
242 size = tp.int64_data_size();
243 }
244 for (int j = 0; j < size; j++)
245 {
246 int vi = std::max(std::min(shape_data[j], (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
247 v.push_back(vi);
248 }
249 }
250 // int32
251 else if (tp.data_type() == 6)
252 {
253 const int32_t* shape_data = 0;
254 if (tp.has_raw_data())
255 {
256 shape_data = (const int32_t*)tp.raw_data().data();
257 size = (int)(tp.raw_data().size() / 4);
258 }
259 else
260 {
261 shape_data = tp.int32_data().data();
262 size = tp.int32_data_size();
263 }
264 for (int j = 0; j < size; j++)
265 {
266 v.push_back(shape_data[j]);
267 }
268 }
269 else
270 {
271 fprintf(stderr, "Unknown data type %d\n", tp.data_type());
272 }
273
274 return v;
275 }
276
get_node_attr_from_input_af(const onnx::TensorProto & tp)277 static std::vector<float> get_node_attr_from_input_af(const onnx::TensorProto& tp)
278 {
279 int size = 0;
280
281 std::vector<float> v;
282
283 // float
284 if (tp.data_type() == 1)
285 {
286 const float* shape_data = 0;
287 if (tp.has_raw_data())
288 {
289 shape_data = (const float*)tp.raw_data().data();
290 size = (int)(tp.raw_data().size() / 4);
291 }
292 else
293 {
294 shape_data = tp.float_data().data();
295 size = tp.float_data_size();
296 }
297 for (int j = 0; j < size; j++)
298 {
299 v.push_back(shape_data[j]);
300 }
301 }
302 // double
303 else if (tp.data_type() == 11)
304 {
305 const double* shape_data = 0;
306 if (tp.has_raw_data())
307 {
308 shape_data = (const double*)tp.raw_data().data();
309 size = (int)(tp.raw_data().size() / 8);
310 }
311 else
312 {
313 shape_data = tp.double_data().data();
314 size = tp.double_data_size();
315 }
316 for (int j = 0; j < size; j++)
317 {
318 v.push_back((float)shape_data[j]);
319 }
320 }
321 else
322 {
323 fprintf(stderr, "Unknown data type %d\n", tp.data_type());
324 }
325
326 return v;
327 }
328
get_tensor_proto_data_size(const onnx::TensorProto & tp)329 static int get_tensor_proto_data_size(const onnx::TensorProto& tp)
330 {
331 if (tp.has_raw_data())
332 {
333 const std::string& raw_data = tp.raw_data();
334 int size = (int)raw_data.size() / 4;
335 return size;
336 }
337 else if (tp.data_type() == 1)
338 {
339 return tp.float_data_size();
340 }
341
342 return 0;
343 }
344
fwrite_tensor_proto_data(const onnx::TensorProto & tp,FILE * bp)345 static void fwrite_tensor_proto_data(const onnx::TensorProto& tp, FILE* bp)
346 {
347 int size = get_tensor_proto_data_size(tp);
348
349 if (tp.has_raw_data())
350 {
351 const std::string& raw_data = tp.raw_data();
352 fwrite(raw_data.data(), sizeof(float), size, bp);
353 }
354 else if (tp.data_type() == 1)
355 {
356 fwrite(tp.float_data().data(), sizeof(float), size, bp);
357 }
358 }
359
// Fold a "Reshape" applied to a constant weight directly into the weight map:
// the reshaped output becomes a weight entry with the requested dims, and the
// Reshape node is neutralized as "noop_reducedncnn".
// Mutates the graph, the weights map, and the per-blob reference counts.
static void fuse_weight_reshape(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // weight <= Reshape(weight)
        if (node->op_type() == "Reshape")
        {
            // check weight
            if (weights.find(node->input(0)) == weights.end())
                continue;

            // alias the output name to the same tensor data
            weights[node->output(0)] = weights[node->input(0)];

            // set weight shape directly
            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                // shape carried as a node attribute
                shape = get_node_attr_ai(*node, "shape");
            }
            else if (node->input_size() == 2)
            {
                // opset 5: shape carried as a second (constant) input
                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            weights[node->output(0)].clear_dims();
            for (int j = 0; j < shape.size(); j++)
            {
                weights[node->output(0)].add_dims(shape[j]);
            }

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }

            reduced_node_count += 1;
            i += 1;
        }
    }
}
408
// Fold a 2D "Transpose" (perm = (1,0)) of a constant weight into the weight
// itself: the data is permuted (h x w -> w x h), the dims are swapped, and the
// Transpose node is neutralized as "noop_reducedncnn".
static void fuse_weight_transpose(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // weight <= Transpose(weight)
        if (node->op_type() == "Transpose")
        {
            // check weight
            if (weights.find(node->input(0)) == weights.end())
                continue;

            // only 2D weights are handled
            if (weights[node->input(0)].dims_size() != 2)
                continue;

            // perm = (1, 0)
            std::vector<int> perm = get_node_attr_ai(*node, "perm");
            if (perm.size() != 2)
                continue;
            if (perm[0] != 1 || perm[1] != 0)
                continue;

            weights[node->output(0)] = weights[node->input(0)];

            // permute weight
            {
                onnx::TensorProto& B = weights[node->output(0)];

                const int h = B.dims(0);
                const int w = B.dims(1);

                std::vector<float> permuted_data;
                permuted_data.reserve((size_t)h * w);
                // float payload may live in raw bytes or the typed field
                const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();

                // column-major walk over the original produces the transpose
                for (int j = 0; j < w; j++)
                {
                    for (int k = 0; k < h; k++)
                    {
                        float vb = bptr[k * w + j];
                        permuted_data.push_back(vb);
                    }
                }

                B.set_dims(0, w);
                B.set_dims(1, h);

                // write back through the same storage the data came from
                if (B.has_raw_data())
                {
                    B.set_raw_data(permuted_data.data(), permuted_data.size() * sizeof(float));
                }
                else
                {
                    for (int j = 0; j < (int)permuted_data.size(); j++)
                        B.set_float_data(j, permuted_data[j]);
                }
            }

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;

            reduced_node_count += 1;
            i += 1;
        }
    }
}
479
// Rewrite the Reshape - Transpose - Reshape channel-shuffle pattern (with an
// optionally interleaved Constant node) into a single ShuffleChannel node.
// Two layouts are recognized:
//   regular: reshape to (1, groups, channels_per_group, h, w),
//            permute (0 2 1 3 4), reshape back to (1, -1, h, w)
//   reverse: reshape to (channels_per_group, groups, h*w),
//            permute (1 0 2), reshape back to (groups, -1, cpg, h, w)
static void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // ShuffleChannel <= Reshape - Transpose - Reshape
        // ShuffleChannel <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            // the intermediate blob must have exactly one consumer
            if (node_reference[node->output(0)] != 1)
                continue;

            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end())
                    continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // 1 groups channels_per_group, height, width
            // reverse style = channels_per_group, groups, height * width
            if (shape.size() != 5 && shape.size() != 3)
                continue;

            if (shape.size() == 5 && shape[0] != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // tolerate a Constant node between Transpose and the final Reshape
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // 0 2 1 3 4
            // reverse style = 1 0 2
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 5 && perm.size() != 3)
                continue;

            if (perm.size() == 5 && (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3 || perm[4] != 4))
                continue;

            if (perm.size() == 3 && (perm[0] != 1 || perm[1] != 0 || perm[2] != 2))
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end())
                    continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // 1, -1, height, width
            // reverse style = group, -1, channels_per_group, height, width
            if (shape3.size() != 4 && shape3.size() != 5)
                continue;

            if (shape3.size() == 4 && (shape3[0] != 1 || (shape3[1] != -1 && shape3[1] != shape[1] * shape[2])))
                continue;

            if (shape3.size() == 5 && (shape3[0] != shape[1] || shape3[2] != shape[0] || shape3[3] * shape3[4] != shape[2]))
                continue;

            // reduce: neutralize the first two nodes and drop their bookkeeping
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // the last Reshape becomes the fused ShuffleChannel node
            node3->set_op_type("ShuffleChannel");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("group");
            attr_group->set_i(shape[1]);

            // 3-element first reshape marks the reverse-style shuffle
            onnx::AttributeProto* attr_reverse = node3->add_attribute();
            attr_reverse->set_name("reverse");
            attr_reverse->set_i(shape.size() == 3);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
607
// Rewrite a reverse-style ShuffleChannel followed by two Gathers (indices 0
// and 1 along axis 0) into a single two-output Split along the channel axis.
static void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1)
        if (node->op_type() == "ShuffleChannel")
        {
            // reverse = 1
            int reverse = get_node_attr_i(*node, "reverse");
            if (reverse != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            if (node2->op_type() != "Gather" || node3->op_type() != "Gather")
                continue;

            // both Gathers must read the ShuffleChannel output
            if (node2->input(0) != node->output(0) || node3->input(0) != node->output(0))
                continue;

            // axis = 0
            int gather2_axis = get_node_attr_i(*node2, "axis");
            if (gather2_axis != 0)
                continue;

            // indices = 0
            if (weights.find(node2->input(1)) == weights.end())
                continue;

            std::vector<int> gather2_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
            if (gather2_indices.size() != 1 || gather2_indices[0] != 0)
                continue;

            // axis = 0
            int gather3_axis = get_node_attr_i(*node3, "axis");
            if (gather3_axis != 0)
                continue;

            // indices = 1
            if (weights.find(node3->input(1)) == weights.end())
                continue;

            std::vector<int> gather3_indices = get_node_attr_from_input_ai(weights[node3->input(1)]);
            if (gather3_indices.size() != 1 || gather3_indices[0] != 1)
                continue;

            // reduce
            node2->set_op_type("noop_reducedncnn");

            node_reference[node->output(0)] -= 2;
            node_reference[node2->input(1)] -= 1;
            node_reference[node3->input(1)] -= 1;

            // node3 becomes the Split: output(0) takes node2's blob name,
            // output(1) keeps node3's original blob name
            node3->set_op_type("Split");
            node3->clear_input();
            node3->add_input(node->output(0));
            node3->add_output(node3->output(0));
            node3->set_output(0, node2->output(0));

            node3->clear_attribute();
            onnx::AttributeProto* attr_axis = node3->add_attribute();
            attr_axis->set_name("axis");
            attr_axis->set_i(1);

            reduced_node_count += 1;
            i += 1;
        }
    }
}
684
// Fuse the expanded hard-swish expression x * relu6(x + 3) / 6 into a single
// HardSwish node (first pass), and HardSigmoid - Mul (x * hsigmoid(x)) into
// HardSwish as well (second pass). An interleaved Constant node before the
// final Div/Mul is tolerated.
static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Div(/6)
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Mul(*(1/6))
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Div(/6)
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Mul(*(1/6))
        // out = x * F.relu6(x + 3, inplace=True) / 6
        if (node->op_type() == "Add")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 3 >= node_count)
                continue;

            // the added constant must be a known weight...
            if (weights.find(node->input(1)) == weights.end())
                continue;

            // ...and a scalar equal to 3
            const onnx::TensorProto& add_three = weights[node->input(1)];
            if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1)
                continue;

            float constant_add_three = get_node_attr_from_input_f(add_three);
            if (constant_add_three != 3.f)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);

            // tolerate a Constant node before the final Div/Mul
            if (node4->op_type() == "Constant")
            {
                if (i + 4 >= node_count)
                    continue;

                node4 = mutable_graph->mutable_node(i + 4);
            }

            if (node2->op_type() != "Clip" || node3->op_type() != "Mul" || (node4->op_type() != "Div" && node4->op_type() != "Mul"))
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // Clip bounds come either from attributes or from constant inputs
            float relu6_min;
            float relu6_max;
            if (node2->input_size() == 1)
            {
                relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
                relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
            }
            else
            {
                const onnx::TensorProto& min_tp = weights[node2->input(1)];
                const onnx::TensorProto& max_tp = weights[node2->input(2)];

                relu6_min = get_node_attr_from_input_f(min_tp);
                relu6_max = get_node_attr_from_input_f(max_tp);
            }
            if (relu6_min != 0.f || relu6_max != 6.f)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            // Mul must combine the original input with the clipped branch
            if (node3->input(0) != node->input(0) || node3->input(1) != node2->output(0))
                continue;

            if (weights.find(node4->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& div_six = weights[node4->input(1)];
            if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1)
                continue;

            // Div by 6 and Mul by 1/6 are the two accepted final scalings
            float constant_div_six = get_node_attr_from_input_f(div_six);
            if (node4->op_type() == "Div" && constant_div_six != 6.f)
                continue;
            if (node4->op_type() == "Mul" && constant_div_six != 1 / 6.f)
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->input(1)] -= 1;
            node_reference[node->output(0)] -= 1;
            if (node2->input_size() == 3)
            {
                node_reference[node2->input(1)] -= 1;
                node_reference[node2->input(2)] -= 1;
            }
            node_reference[node2->output(0)] -= 1;
            node_reference[node3->output(0)] -= 1;
            node_reference[node4->input(1)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));

            node4->set_op_type("HardSwish");
            node4->clear_input();
            node4->add_input(node->input(0));

            // alpha = 1/6 and beta = 3/6 encode the (x + 3) / 6 scaling
            onnx::AttributeProto* attr_alpha = node4->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(1.f / 6.f);

            onnx::AttributeProto* attr_beta = node4->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(3.f / 6.f);

            reduced_node_count += 3;
            i += 3;
        }
    }

    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSwish <= HardSigmoid - Mul
        // out = x * hsigmoid(x)
        if (node->op_type() == "HardSigmoid")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            float alpha = get_node_attr_f(*node, "alpha", 0.2f);
            float beta = get_node_attr_f(*node, "beta", 0.5f);

            if (i + 1 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);

            if (node2->op_type() != "Mul")
                continue;

            if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0))
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->output(0)] -= 1;

            blob_names.erase(node->output(0));

            node2->set_op_type("HardSwish");
            node2->clear_input();
            node2->add_input(node->input(0));

            // carry the HardSigmoid coefficients over to the fused node
            onnx::AttributeProto* attr_alpha = node2->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(alpha);

            onnx::AttributeProto* attr_beta = node2->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(beta);

            reduced_node_count += 1;
            i += 1;
        }
    }
}
859
// Fuse the expanded hard-sigmoid expression relu6(x + 3) / 6 into a single
// HardSigmoid node. An interleaved Constant node before the final Div/Mul is
// tolerated; Div by 6 and Mul by 1/6 are both accepted.
static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSigmoid <= Add(+3) - Clip(0,6) - Div(/6)
        // HardSigmoid <= Add(+3) - Clip(0,6) - Mul(*(1/6))
        // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Div(/6)
        // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Mul(*(1/6))
        // out = F.relu6(x + 3, inplace=True) / 6
        if (node->op_type() == "Add")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            // the added constant must be a known scalar weight equal to 3
            if (weights.find(node->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& add_three = weights[node->input(1)];
            if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1)
                continue;

            float constant_add_three = get_node_attr_from_input_f(add_three);
            if (constant_add_three != 3.f)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // tolerate a Constant node before the final Div/Mul
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Clip" || (node3->op_type() != "Div" && node3->op_type() != "Mul"))
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // Clip bounds come either from attributes or from constant inputs
            float relu6_min;
            float relu6_max;
            if (node2->input_size() == 1)
            {
                relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
                relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
            }
            else
            {
                const onnx::TensorProto& min_tp = weights[node2->input(1)];
                const onnx::TensorProto& max_tp = weights[node2->input(2)];

                relu6_min = get_node_attr_from_input_f(min_tp);
                relu6_max = get_node_attr_from_input_f(max_tp);
            }
            if (relu6_min != 0.f || relu6_max != 6.f)
                continue;

            if (weights.find(node3->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& div_six = weights[node3->input(1)];
            if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1)
                continue;

            float constant_div_six = get_node_attr_from_input_f(div_six);
            if (node3->op_type() == "Div" && constant_div_six != 6.f)
                continue;
            if (node3->op_type() == "Mul" && constant_div_six != 1 / 6.f)
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            node_reference[node->input(1)] -= 1;
            node_reference[node->output(0)] -= 1;
            if (node2->input_size() == 3)
            {
                node_reference[node2->input(1)] -= 1;
                node_reference[node2->input(2)] -= 1;
            }
            node_reference[node2->output(0)] -= 1;
            node_reference[node3->input(1)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            node3->set_op_type("HardSigmoid");
            node3->clear_input();
            node3->add_input(node->input(0));

            // alpha = 1/6 and beta = 3/6 encode the (x + 3) / 6 scaling
            onnx::AttributeProto* attr_alpha = node3->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(1.f / 6.f);

            onnx::AttributeProto* attr_beta = node3->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(3.f / 6.f);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
973
fuse_swish(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)974 static void fuse_swish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
975 {
976 int node_count = mutable_graph->node_size();
977 for (int i = 0; i < node_count; i++)
978 {
979 onnx::NodeProto* node = mutable_graph->mutable_node(i);
980
981 // Swish <= Sigmoid - Mul
982 // x * torch.sigmoid(x)
983 if (node->op_type() == "Sigmoid")
984 {
985 if (node_reference[node->output(0)] != 1)
986 continue;
987
988 if (i + 1 >= node_count)
989 continue;
990
991 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
992
993 if (node2->op_type() != "Mul")
994 continue;
995
996 if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0))
997 continue;
998
999 // reduce
1000 node->set_op_type("noop_reducedncnn");
1001
1002 node_reference[node->input(0)] -= 1;
1003 node_reference[node->output(0)] -= 1;
1004
1005 blob_names.erase(node->output(0));
1006
1007 node2->set_op_type("Swish");
1008 node2->clear_input();
1009 node2->add_input(node->input(0));
1010
1011 reduced_node_count += 1;
1012 i += 1;
1013 }
1014 }
1015 }
1016
fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1017 static void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1018 {
1019 int node_count = mutable_graph->node_size();
1020 for (int i = 0; i < node_count; i++)
1021 {
1022 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1023
1024 // BatchNormalization <= Unsqueeze - BatchNormalization - Squeeze
1025 if (node->op_type() == "Unsqueeze")
1026 {
1027 if (node_reference[node->output(0)] != 1)
1028 continue;
1029
1030 if (i + 2 >= node_count)
1031 continue;
1032
1033 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1034 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1035
1036 if (node2->op_type() != "BatchNormalization" || node3->op_type() != "Squeeze")
1037 continue;
1038
1039 if (node_reference[node2->output(0)] != 1)
1040 continue;
1041
1042 if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0))
1043 continue;
1044
1045 // reduce
1046 node->set_op_type("noop_reducedncnn");
1047 node3->set_op_type("noop_reducedncnn");
1048
1049 node_reference[node->output(0)] -= 1;
1050 node_reference[node2->output(0)] -= 1;
1051
1052 blob_names.erase(node->output(0));
1053 blob_names.erase(node2->output(0));
1054
1055 node2->set_input(0, node->input(0));
1056 node2->set_output(0, node3->output(0));
1057
1058 reduced_node_count += 2;
1059 i += 2;
1060 }
1061 }
1062 }
1063
fuse_unsqueeze_prelu(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1064 static void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1065 {
1066 int node_count = mutable_graph->node_size();
1067 for (int i = 0; i < node_count; i++)
1068 {
1069 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1070
1071 // PReLU <= Unsqueeze - PReLU
1072 if (node->op_type() == "Unsqueeze")
1073 {
1074 // check weight
1075 if (weights.find(node->input(0)) == weights.end())
1076 continue;
1077
1078 onnx::TensorProto& B = weights[node->input(0)];
1079 if (B.dims_size() != 1)
1080 continue;
1081
1082 if (node_reference[node->output(0)] != 1)
1083 continue;
1084
1085 // axes = (1, 2)
1086 std::vector<int> axes = get_node_attr_ai(*node, "axes");
1087 if (axes.size() != 2)
1088 continue;
1089 if (axes[0] != 1 || axes[1] != 2)
1090 continue;
1091
1092 if (i + 1 >= node_count)
1093 continue;
1094
1095 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1096
1097 if (node2->op_type() != "PRelu")
1098 continue;
1099
1100 if (node2->input(1) != node->output(0))
1101 continue;
1102
1103 // reduce
1104 node->set_op_type("noop_reducedncnn");
1105
1106 node_reference[node->output(0)] -= 1;
1107
1108 blob_names.erase(node->output(0));
1109
1110 node2->set_input(1, node->input(0));
1111
1112 reduced_node_count += 1;
1113 i += 1;
1114 }
1115 }
1116 }
1117
// Fuse an L2-normalization subgraph into a single ncnn Normalize layer.
// The Clip lower bound becomes the Normalize epsilon.
static void fuse_normalize(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Normalize <= X - ReduceL2 - Clip - Expand - Div
        // Normalize <= X - ReduceL2 - Clip - Shape - Expand - Div
        if (node->op_type() == "ReduceL2")
        {
            // the norm blob must have exactly one consumer (the Clip)
            if (node_reference[node->output(0)] != 1)
                continue;

            // axes = (1) : reduce along axis 1 only
            std::vector<int> axes = get_node_attr_ai(*node, "axes");
            if (axes.size() != 1)
                continue;
            if (axes[0] != 1)
                continue;

            if (i + 3 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);

            // optional Shape variant: a Shape node computes Expand's target shape
            bool has_shape_node = node3->op_type() == "Shape";
            onnx::NodeProto* node_shape = 0;
            if (has_shape_node)
            {
                if (i + 4 >= node_count)
                    continue;

                // shift so node3/node4 refer to Expand/Div again
                node_shape = node3;
                node3 = mutable_graph->mutable_node(i + 3);
                node4 = mutable_graph->mutable_node(i + 4);
            }

            if (node2->op_type() != "Clip" || node3->op_type() != "Expand" || node4->op_type() != "Div")
                continue;

            // all intermediate blobs must be single-consumer
            if (node_reference[node2->output(0)] != 1)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            // verify wiring: Clip(norm), Expand(clipped), Div(X, expanded)
            if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)
                || node4->input(0) != node->input(0) || node4->input(1) != node3->output(0))
                continue;

            if (has_shape_node)
            {
                // the Shape node must read X and feed Expand's shape input
                if (node_shape->input(0) != node->input(0) || node3->input(1) != node_shape->output(0))
                    continue;
            }

            // +eps
            float clip_min;
            if (node2->input_size() == 1)
            {
                // min supplied as an attribute (pre-opset-11 Clip)
                clip_min = get_node_attr_f(*node2, "min", -FLT_MAX);
            }
            else
            {
                // min supplied as a constant input tensor (opset 11+)
                const onnx::TensorProto& min_tp = weights[node2->input(1)];

                clip_min = get_node_attr_from_input_f(min_tp);
            }

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            if (has_shape_node)
            {
                node_shape->set_op_type("noop_reducedncnn");
            }
            node3->set_op_type("noop_reducedncnn");

            // X loses its ReduceL2 (and Shape) consumers; the Div keeps using it
            node_reference[node->input(0)] -= has_shape_node ? 2 : 1;
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (has_shape_node)
            {
                node_reference[node_shape->output(0)] -= 1;
            }
            node_reference[node3->output(0)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            if (has_shape_node)
            {
                blob_names.erase(node_shape->output(0));
            }
            blob_names.erase(node3->output(0));

            // the Div becomes the fused Normalize, reading X directly
            node4->set_op_type("Normalize");
            node4->clear_input();
            node4->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node4->add_attribute();
            attr_alpha->set_name("eps");
            attr_alpha->set_f(clip_min);

            // skip over the consumed nodes
            reduced_node_count += has_shape_node ? 4 : 3;
            i += has_shape_node ? 4 : 3;
        }
    }
}
1229
fuse_groupnorm(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1230 static void fuse_groupnorm(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1231 {
1232 int node_count = mutable_graph->node_size();
1233 for (int i = 0; i < node_count; i++)
1234 {
1235 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1236
1237 // GroupNorm <= X - Reshape - InstanceNormalization - Reshape - Mul - Add
1238 if (node->op_type() == "Reshape")
1239 {
1240 if (node_reference[node->output(0)] != 1)
1241 continue;
1242
1243 std::vector<int> shape;
1244 if (node->input_size() == 1)
1245 {
1246 shape = get_node_attr_ai(*node, "shape");
1247 }
1248 else
1249 {
1250 // skip weight reshape
1251 if (weights.find(node->input(1)) == weights.end())
1252 continue;
1253
1254 shape = get_node_attr_from_input_ai(weights[node->input(1)]);
1255 }
1256
1257 // 0, group, -1
1258 if (shape.size() != 3)
1259 continue;
1260
1261 if (shape[0] != 0 || shape[2] != -1)
1262 continue;
1263
1264 int groups = shape[1];
1265
1266 if (i + 4 >= node_count)
1267 continue;
1268
1269 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1270 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1271 onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
1272 onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
1273
1274 if (node2->op_type() != "InstanceNormalization" || node3->op_type() != "Reshape" || node4->op_type() != "Mul" || node5->op_type() != "Add")
1275 continue;
1276
1277 if (node_reference[node2->output(0)] != 1)
1278 continue;
1279
1280 if (node_reference[node3->output(0)] != 1)
1281 continue;
1282
1283 if (node_reference[node4->output(0)] != 1)
1284 continue;
1285
1286 if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)
1287 || node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0))
1288 continue;
1289
1290 // +eps
1291 float eps = get_node_attr_f(*node2, "epsilon", 1e-05f);
1292
1293 // InstanceNormalization S=1 B=0
1294 std::vector<float> S = get_node_attr_from_input_af(weights[node2->input(1)]);
1295 std::vector<float> B = get_node_attr_from_input_af(weights[node2->input(2)]);
1296 if ((int)S.size() != groups || (int)B.size() != groups)
1297 continue;
1298
1299 bool instancenorm_affine = false;
1300 for (int j = 0; j < groups; j++)
1301 {
1302 if (S[j] != 1.f || B[j] != 0.f)
1303 {
1304 instancenorm_affine = true;
1305 break;
1306 }
1307 }
1308
1309 if (instancenorm_affine)
1310 continue;
1311
1312 std::vector<int> shape2;
1313 if (node3->input_size() == 1)
1314 {
1315 shape2 = get_node_attr_ai(*node3, "shape");
1316 }
1317 else
1318 {
1319 // skip weight reshape
1320 if (weights.find(node3->input(1)) == weights.end())
1321 continue;
1322
1323 shape2 = get_node_attr_from_input_ai(weights[node3->input(1)]);
1324 }
1325
1326 // 1, channels, w, h
1327 if (shape2.size() != 4)
1328 continue;
1329
1330 if (shape2[0] != 1)
1331 continue;
1332
1333 int channels = shape2[1];
1334
1335 // affine
1336 int affine = 0;
1337 std::vector<float> affine_S = get_node_attr_from_input_af(weights[node4->input(1)]);
1338 std::vector<float> affine_B = get_node_attr_from_input_af(weights[node5->input(1)]);
1339 if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && affine_B[0] == 0.f)
1340 {
1341 affine = 0;
1342 }
1343 else if ((int)affine_S.size() != channels && (int)affine_B.size() != channels)
1344 {
1345 // we only allow per-channel affine
1346 continue;
1347 }
1348
1349 for (int j = 0; j < channels; j++)
1350 {
1351 if (affine_S[j] != 1.f || affine_B[j] != 0.f)
1352 {
1353 affine = 1;
1354 break;
1355 }
1356 }
1357
1358 // reduce
1359 node->set_op_type("noop_reducedncnn");
1360 node2->set_op_type("noop_reducedncnn");
1361 node3->set_op_type("noop_reducedncnn");
1362 node4->set_op_type("noop_reducedncnn");
1363
1364 if (node->input_size() == 2)
1365 {
1366 node_reference[node->input(1)] -= 1;
1367 }
1368 node_reference[node->output(0)] -= 1;
1369 node_reference[node2->input(1)] -= 1;
1370 node_reference[node2->input(2)] -= 1;
1371 node_reference[node2->output(0)] -= 1;
1372 if (node3->input_size() == 2)
1373 {
1374 node_reference[node3->input(1)] -= 1;
1375 }
1376 node_reference[node3->output(0)] -= 1;
1377 node_reference[node4->output(0)] -= 1;
1378
1379 std::string affine_scale = node4->input(1);
1380 std::string affine_bias = node5->input(1);
1381
1382 node_reference[affine_scale] -= 1;
1383 node_reference[affine_bias] -= 1;
1384
1385 blob_names.erase(node->output(0));
1386 blob_names.erase(node2->output(0));
1387 blob_names.erase(node3->output(0));
1388 blob_names.erase(node4->output(0));
1389
1390 node5->set_op_type("GroupNorm");
1391 node5->clear_input();
1392 node5->add_input(node->input(0));
1393 if (affine)
1394 {
1395 node5->add_input(affine_scale);
1396 node5->add_input(affine_bias);
1397 }
1398
1399 onnx::AttributeProto* attr_groups = node5->add_attribute();
1400 attr_groups->set_name("groups");
1401 attr_groups->set_i(groups);
1402
1403 onnx::AttributeProto* attr_channels = node5->add_attribute();
1404 attr_channels->set_name("channels");
1405 attr_channels->set_i(channels);
1406
1407 onnx::AttributeProto* attr_eps = node5->add_attribute();
1408 attr_eps->set_name("epsilon");
1409 attr_eps->set_f(eps);
1410
1411 onnx::AttributeProto* attr_affine = node5->add_attribute();
1412 attr_affine->set_name("affine");
1413 attr_affine->set_i(affine);
1414
1415 reduced_node_count += 4;
1416 i += 4;
1417 }
1418 }
1419 }
1420
// Fuse the dynamic-batch Reshape(X, [batch, -1]) subgraph (where batch is
// taken from Shape(X)[0]) into a single Flatten node.
static void fuse_flatten(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Flatten <= X - Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat - Reshape
        if (node->op_type() == "Shape")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 6 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);

            if (node2->op_type() != "Gather" || node3->op_type() != "Constant" || node4->op_type() != "Unsqueeze" || node5->op_type() != "Unsqueeze"
                || node6->op_type() != "Concat" || node7->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // the Constant is intentionally kept alive; its output may be
            // referenced as a weight elsewhere
            // if (node_reference[node3->output(0)] != 1)
            //     continue;

            if (node_reference[node4->output(0)] != 1)
                continue;

            if (node_reference[node5->output(0)] != 1)
                continue;

            if (node_reference[node6->output(0)] != 1)
                continue;

            // verify wiring: Gather(shape), Unsqueeze(batch), Unsqueeze(const),
            // Concat(batch, -1), Reshape(X, concat)
            if (node2->input(0) != node->output(0) || node4->input(0) != node2->output(0) || node5->input(0) != node3->output(0)
                || node6->input(0) != node4->output(0) || node6->input(1) != node5->output(0)
                || node7->input(0) != node->input(0) || node7->input(1) != node6->output(0))
                continue;

            // axis = 0
            int gather_axis = get_node_attr_i(*node2, "axis");
            if (gather_axis != 0)
                continue;

            // indices = 0 : gather picks the batch dimension from the shape
            if (weights.find(node2->input(1)) == weights.end())
                continue;

            std::vector<int> gather_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
            if (gather_indices.size() != 1 || gather_indices[0] != 0)
                continue;

            // axes = (0)
            std::vector<int> unsqueeze_axes = get_node_attr_ai(*node4, "axes");
            if (unsqueeze_axes.size() != 1)
                continue;
            if (unsqueeze_axes[0] != 0)
                continue;

            // axes = (0)
            std::vector<int> unsqueeze2_axes = get_node_attr_ai(*node5, "axes");
            if (unsqueeze2_axes.size() != 1)
                continue;
            if (unsqueeze2_axes[0] != 0)
                continue;

            // data = -1 : the second reshape dimension is the flatten remainder
            if (weights.find(node5->input(0)) == weights.end())
                continue;

            std::vector<int> unsqueeze2_data = get_node_attr_from_input_ai(weights[node5->input(0)]);
            if (unsqueeze2_data.size() != 1 || unsqueeze2_data[0] != -1)
                continue;

            // axis = 0
            int concat_axis = get_node_attr_i(*node6, "axis");
            if (concat_axis != 0)
                continue;

            // reduce (the Constant node is left in place, see above)
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            // node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");

            // X loses its Shape consumer; the fused Flatten keeps using it
            node_reference[node->input(0)] -= 1;
            node_reference[node->output(0)] -= 1;
            node_reference[node2->input(1)] -= 1;
            node_reference[node2->output(0)] -= 1;
            // node_reference[node3->output(0)] -= 1;
            node_reference[node4->output(0)] -= 1;
            node_reference[node5->input(0)] -= 1;
            node_reference[node5->output(0)] -= 1;
            node_reference[node6->output(0)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            // blob_names.erase(node3->output(0));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));

            // the final Reshape becomes the fused Flatten, reading X directly
            node7->set_op_type("Flatten");
            node7->clear_input();
            node7->add_input(node->input(0));

            reduced_node_count += 5;
            i += 5;
        }
    }
}
1542
fuse_pixelshuffle(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1543 static void fuse_pixelshuffle(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1544 {
1545 int node_count = mutable_graph->node_size();
1546 for (int i = 0; i < node_count; i++)
1547 {
1548 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1549
1550 // PixelShuffle <= Reshape - Transpose - Reshape
1551 // PixelShuffle <= Reshape - Transpose - Constant - Reshape
1552 if (node->op_type() == "Reshape")
1553 {
1554 if (node_reference[node->output(0)] != 1)
1555 continue;
1556
1557 std::vector<int> shape;
1558 if (node->input_size() == 1)
1559 {
1560 shape = get_node_attr_ai(*node, "shape");
1561 }
1562 else
1563 {
1564 // skip weight reshape
1565 if (weights.find(node->input(1)) == weights.end())
1566 continue;
1567
1568 shape = get_node_attr_from_input_ai(weights[node->input(1)]);
1569 }
1570
1571 // -1, 3, upscale_factor, upscale_factor, height, width
1572 if (shape.size() != 6)
1573 continue;
1574
1575 if (shape[0] != 1 && shape[0] != -1)
1576 continue;
1577
1578 if (shape[2] != shape[3])
1579 continue;
1580
1581 if (i + 2 >= node_count)
1582 continue;
1583
1584 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1585 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1586
1587 if (node3->op_type() == "Constant")
1588 {
1589 if (i + 3 >= node_count)
1590 continue;
1591
1592 node3 = mutable_graph->mutable_node(i + 3);
1593 }
1594
1595 if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
1596 continue;
1597
1598 if (node_reference[node2->output(0)] != 1)
1599 continue;
1600
1601 // 0 1 4 2 5 3
1602 std::vector<int> perm = get_node_attr_ai(*node2, "perm");
1603 if (perm.size() != 6)
1604 continue;
1605
1606 if (perm[0] != 0 || perm[1] != 1 || perm[2] != 4 || perm[3] != 2 || perm[4] != 5 || perm[5] != 3)
1607 continue;
1608
1609 std::vector<int> shape3;
1610 if (node3->input_size() == 1)
1611 {
1612 shape3 = get_node_attr_ai(*node3, "shape");
1613 }
1614 else
1615 {
1616 // skip weight reshape
1617 if (weights.find(node3->input(1)) == weights.end())
1618 continue;
1619
1620 shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
1621 }
1622
1623 // -1, 3, height, width
1624 if (shape3.size() != 4)
1625 continue;
1626
1627 if (shape3[0] != 1 && shape3[0] != -1)
1628 continue;
1629
1630 if (shape3[1] != shape[1] || shape3[2] != shape[2] * shape[4] || shape3[3] != shape[3] * shape[5])
1631 continue;
1632
1633 // reduce
1634 node->set_op_type("noop_reducedncnn");
1635 node2->set_op_type("noop_reducedncnn");
1636
1637 if (node->input_size() == 2)
1638 {
1639 node_reference[node->input(1)] -= 1;
1640 }
1641 node_reference[node->output(0)] -= 1;
1642 node_reference[node2->output(0)] -= 1;
1643 if (node3->input_size() == 2)
1644 {
1645 node_reference[node3->input(1)] -= 1;
1646 }
1647
1648 blob_names.erase(node->output(0));
1649 blob_names.erase(node2->output(0));
1650
1651 node3->set_op_type("PixelShuffle");
1652 node3->set_input(0, node->input(0));
1653
1654 onnx::AttributeProto* attr_group = node3->add_attribute();
1655 attr_group->set_name("scale_factor");
1656 attr_group->set_i(shape[2]);
1657
1658 reduced_node_count += 2;
1659 i += 2;
1660 }
1661 }
1662 }
1663
// Fuse the Reshape - Transpose - Reshape space-to-depth pattern into a
// single Reorg node (the inverse layout of fuse_pixelshuffle above).
static void fuse_reorg(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Reorg <= Reshape - Transpose - Reshape
        // Reorg <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end())
                    continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // -1, 3, out_height, block_size, out_width, block_size
            if (shape.size() != 6)
                continue;

            if (shape[0] != 1 && shape[0] != -1)
                continue;

            // both block sizes must agree
            if (shape[3] != shape[5])
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // an interleaved Constant (target shape) pushes the second Reshape one slot
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // perm = 0 1 3 5 2 4
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 6)
                continue;

            if (perm[0] != 0 || perm[1] != 1 || perm[2] != 3 || perm[3] != 5 || perm[4] != 2 || perm[5] != 4)
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end())
                    continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // -1, out_channels, out_height, out_width
            if (shape3.size() != 4)
                continue;

            if (shape3[0] != 1 && shape3[0] != -1)
                continue;

            // channels must absorb both block_size factors
            if (shape3[1] != shape[1] * shape[3] * shape[5] || shape3[2] != shape[2] || shape3[3] != shape[4])
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // the second Reshape becomes the fused Reorg
            node3->set_op_type("Reorg");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("stride");
            attr_group->set_i(shape[3]);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
1784
fuse_expand_broadcast(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1785 static void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1786 {
1787 int node_count = mutable_graph->node_size();
1788 for (int i = 0; i < node_count; i++)
1789 {
1790 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1791
1792 // Add/Sub/Mul/Div/Min/Max <= Expand - Add/Sub/Mul/Div/Min/Max
1793 if (node->op_type() == "Expand")
1794 {
1795 if (node_reference[node->output(0)] != 1)
1796 continue;
1797
1798 if (i + 1 >= node_count)
1799 continue;
1800
1801 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1802
1803 if (node2->op_type() != "Add" && node2->op_type() != "Sub" && node2->op_type() != "Mul" && node2->op_type() != "Div" && node2->op_type() != "Min" && node2->op_type() != "Max")
1804 continue;
1805
1806 if (node2->input(1) != node->output(0) && node2->input(0) != node->output(0))
1807 continue;
1808
1809 // reduce
1810 node->set_op_type("noop_reducedncnn");
1811
1812 node_reference[node->output(0)] -= 1;
1813 if (node->input_size() == 2)
1814 {
1815 node_reference[node->input(1)] -= 1;
1816 }
1817
1818 blob_names.erase(node->output(0));
1819
1820 node2->set_input(1, node->input(0));
1821
1822 reduced_node_count += 1;
1823 i += 1;
1824 }
1825 }
1826 }
1827
fuse_lstm_gru_rnn(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1828 static void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1829 {
1830 int node_count = mutable_graph->node_size();
1831 for (int i = 0; i < node_count; i++)
1832 {
1833 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1834
1835 // LSTM(bi) <= LSTM(bi) - Transpose - Reshape - Transpose
1836 if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
1837 {
1838 if (node_reference[node->output(0)] != 1)
1839 continue;
1840
1841 if (i + 2 >= node_count)
1842 continue;
1843
1844 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1845 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1846
1847 if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
1848 continue;
1849
1850 if (node_reference[node2->output(0)] != 1)
1851 continue;
1852
1853 if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0))
1854 continue;
1855
1856 std::string direction = get_node_attr_s(*node, "direction");
1857 if (direction != "bidirectional")
1858 continue;
1859
1860 // 0 2 1 3
1861 std::vector<int> perm = get_node_attr_ai(*node2, "perm");
1862 if (perm.size() != 4)
1863 continue;
1864
1865 if (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3)
1866 continue;
1867
1868 std::vector<int> shape;
1869 if (node3->input_size() == 1)
1870 {
1871 shape = get_node_attr_ai(*node3, "shape");
1872 }
1873 else
1874 {
1875 // skip weight reshape
1876 if (weights.find(node3->input(1)) == weights.end())
1877 continue;
1878
1879 shape = get_node_attr_from_input_ai(weights[node3->input(1)]);
1880 }
1881
1882 // 0 0 -1
1883 if (shape.size() != 3)
1884 continue;
1885
1886 if (shape[0] != 0 || shape[1] != 0 || shape[2] != -1)
1887 continue;
1888
1889 // reduce
1890 node2->set_op_type("noop_reducedncnn");
1891 node3->set_op_type("noop_reducedncnn");
1892
1893 node_reference[node->output(0)] -= 1;
1894 node_reference[node2->output(0)] -= 1;
1895 if (node3->input_size() == 2)
1896 {
1897 node_reference[node3->input(1)] -= 1;
1898 }
1899
1900 blob_names.erase(node->output(0));
1901 if (node->output_size() > 1)
1902 {
1903 for (int j = 1; j < node->output_size(); j++)
1904 {
1905 blob_names.erase(node->output(j));
1906 }
1907 }
1908 blob_names.erase(node2->output(0));
1909
1910 node->clear_output();
1911 node->add_output(node3->output(0));
1912
1913 reduced_node_count += 2;
1914 i += 2;
1915
1916 if (i + 1 < node_count)
1917 {
1918 if (node_reference[node3->output(0)] != 1)
1919 continue;
1920
1921 onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 1);
1922
1923 if (node4->op_type() != "Transpose")
1924 continue;
1925
1926 if (node4->input(0) != node->output(0))
1927 continue;
1928
1929 // 1 0 2
1930 std::vector<int> perm4 = get_node_attr_ai(*node4, "perm");
1931 if (perm4.size() != 3)
1932 continue;
1933
1934 if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2)
1935 continue;
1936
1937 // reduce
1938 node4->set_op_type("noop_reducedncnn");
1939
1940 node_reference[node->output(0)] -= 1;
1941
1942 blob_names.erase(node->output(0));
1943
1944 node->clear_output();
1945 node->add_output(node4->output(0));
1946
1947 reduced_node_count += 1;
1948 i += 1;
1949 }
1950 }
1951 }
1952
1953 for (int i = 0; i < node_count; i++)
1954 {
1955 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1956
1957 // LSTM(uni) <= LSTM(uni) - Squeeze - Transpose
1958 if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
1959 {
1960 if (node_reference[node->output(0)] != 1)
1961 continue;
1962
1963 if (i + 1 >= node_count)
1964 continue;
1965
1966 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1967
1968 if (node2->op_type() != "Squeeze")
1969 continue;
1970
1971 if (node2->input(0) != node->output(0))
1972 continue;
1973
1974 std::string direction = get_node_attr_s(*node, "direction");
1975 if (direction == "bidirectional")
1976 continue;
1977
1978 // 1
1979 std::vector<int> axes = get_node_attr_ai(*node2, "axes");
1980 if (axes.size() != 1)
1981 continue;
1982
1983 if (axes[0] != 1)
1984 continue;
1985
1986 // reduce
1987 node2->set_op_type("noop_reducedncnn");
1988
1989 node_reference[node->output(0)] -= 1;
1990
1991 blob_names.erase(node->output(0));
1992 if (node->output_size() > 1)
1993 {
1994 for (int j = 1; j < node->output_size(); j++)
1995 {
1996 blob_names.erase(node->output(j));
1997 }
1998 }
1999
2000 node->clear_output();
2001 node->add_output(node2->output(0));
2002
2003 reduced_node_count += 1;
2004 i += 1;
2005
2006 if (i + 1 < node_count)
2007 {
2008 if (node_reference[node2->output(0)] != 1)
2009 continue;
2010
2011 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 1);
2012
2013 if (node3->op_type() != "Transpose")
2014 continue;
2015
2016 if (node3->input(0) != node->output(0))
2017 continue;
2018
2019 // 1 0 2
2020 std::vector<int> perm4 = get_node_attr_ai(*node3, "perm");
2021 if (perm4.size() != 3)
2022 continue;
2023
2024 if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2)
2025 continue;
2026
2027 // reduce
2028 node3->set_op_type("noop_reducedncnn");
2029
2030 node_reference[node->output(0)] -= 1;
2031
2032 blob_names.erase(node->output(0));
2033
2034 node->clear_output();
2035 node->add_output(node3->output(0));
2036
2037 reduced_node_count += 1;
2038 i += 1;
2039 }
2040 }
2041 }
2042
2043 for (int i = 0; i < node_count; i++)
2044 {
2045 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2046
2047 // LSTM <= Transpose - LSTM
2048 if (node->op_type() == "Transpose")
2049 {
2050 if (node_reference[node->output(0)] != 1)
2051 continue;
2052
2053 // 1 0 2
2054 std::vector<int> perm = get_node_attr_ai(*node, "perm");
2055 if (perm.size() != 3)
2056 continue;
2057
2058 if (perm[0] != 1 || perm[1] != 0 || perm[2] != 2)
2059 continue;
2060
2061 if (i + 1 >= node_count)
2062 continue;
2063
2064 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2065
2066 if (node2->op_type() != "LSTM" && node->op_type() != "GRU" && node->op_type() != "RNN")
2067 continue;
2068
2069 if (node2->input(0) != node->output(0))
2070 continue;
2071
2072 // reduce
2073 node->set_op_type("noop_reducedncnn");
2074
2075 node_reference[node->output(0)] -= 1;
2076
2077 blob_names.erase(node->output(0));
2078
2079 node2->set_input(0, node->input(0));
2080
2081 reduced_node_count += 1;
2082 i += 1;
2083 }
2084 }
2085 }
2086
main(int argc,char ** argv)2087 int main(int argc, char** argv)
2088 {
2089 const char* onnxpb = argv[1];
2090 const char* ncnn_prototxt = argc >= 4 ? argv[2] : "ncnn.param";
2091 const char* ncnn_modelbin = argc >= 4 ? argv[3] : "ncnn.bin";
2092
2093 onnx::ModelProto model;
2094
2095 // load
2096 bool s1 = read_proto_from_binary(onnxpb, &model);
2097 if (!s1)
2098 {
2099 fprintf(stderr, "read_proto_from_binary failed\n");
2100 return -1;
2101 }
2102
2103 FILE* pp = fopen(ncnn_prototxt, "wb");
2104 FILE* bp = fopen(ncnn_modelbin, "wb");
2105
2106 // magic
2107 fprintf(pp, "7767517\n");
2108
2109 const onnx::GraphProto& graph = model.graph();
2110 onnx::GraphProto* mutable_graph = model.mutable_graph();
2111
2112 int node_count = graph.node_size();
2113
2114 // node reference
2115 std::map<std::string, int> node_reference;
2116
2117 // weight node and weight reshape node
2118 std::map<std::string, onnx::TensorProto> weights;
2119
2120 for (int j = 0; j < graph.initializer_size(); j++)
2121 {
2122 const onnx::TensorProto& initializer = graph.initializer(j);
2123
2124 // fprintf(stderr, "weight = %s %d\n", initializer.name().c_str(), initializer.data_type());
2125
2126 weights[initializer.name()] = initializer;
2127 }
2128
2129 // topological sort
2130 {
2131 // name -> producer node index
2132 std::set<std::string> producers;
2133 for (int j = 0; j < graph.input_size(); j++)
2134 {
2135 const std::string& input_name = graph.input(j).name();
2136 producers.insert(input_name);
2137 }
2138
2139 for (int i = 0; i < node_count;)
2140 {
2141 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2142
2143 bool swapnode = false;
2144 std::string missing_input_name;
2145 for (int j = 0; j < (int)node->input_size(); j++)
2146 {
2147 const std::string& input_name = node->input(j);
2148 if (input_name.empty())
2149 continue;
2150
2151 if (producers.find(input_name) == producers.end() && weights.find(input_name) == weights.end())
2152 {
2153 swapnode = true;
2154 missing_input_name = input_name;
2155 break;
2156 }
2157 }
2158
2159 if (!swapnode)
2160 {
2161 for (int j = 0; j < (int)node->output_size(); j++)
2162 {
2163 const std::string& output_name = node->output(j);
2164 if (output_name.empty())
2165 continue;
2166
2167 producers.insert(output_name);
2168 }
2169
2170 i++;
2171 continue;
2172 }
2173
2174 // find node that produce missing_input_name
2175 int q = i + 1;
2176 for (; q < node_count; q++)
2177 {
2178 onnx::NodeProto* nodeq = mutable_graph->mutable_node(q);
2179 bool found = false;
2180 for (int j = 0; j < (int)nodeq->output_size(); j++)
2181 {
2182 const std::string& output_name = nodeq->output(j);
2183 if (output_name == missing_input_name)
2184 {
2185 found = true;
2186 break;
2187 }
2188 }
2189
2190 if (found)
2191 break;
2192 }
2193
2194 if (q == node_count)
2195 {
2196 fprintf(stderr, "cannot find node produces %s but node %d requires it\n", missing_input_name.c_str(), i);
2197 return -1;
2198 }
2199
2200 // fprintf(stderr, "swap %d %d\n", i, q);
2201 // swap this node with q
2202 onnx::NodeProto* nodeq = mutable_graph->mutable_node(q);
2203 onnx::NodeProto tmp = *node;
2204 *node = *nodeq;
2205 *nodeq = tmp;
2206 }
2207 }
2208
2209 // global definition line
2210 // [layer count] [blob count]
2211 std::set<std::string> blob_names;
2212 for (int i = 0; i < node_count; i++)
2213 {
2214 const onnx::NodeProto& node = graph.node(i);
2215
2216 const std::string& op = node.op_type();
2217
2218 std::string name = node.name();
2219 if (name.empty())
2220 {
2221 name = node.output(0);
2222 }
2223
2224 if (op == "Constant")
2225 {
2226 onnx::TensorProto tensor = get_node_attr_tensor(node, "value");
2227 weights[node.output(0)] = tensor;
2228 }
2229
2230 for (int j = 0; j < (int)node.input_size(); j++)
2231 {
2232 const std::string& input_name = node.input(j);
2233
2234 blob_names.insert(input_name);
2235
2236 if (node_reference.find(input_name) == node_reference.end())
2237 {
2238 node_reference[input_name] = 1;
2239 }
2240 else
2241 {
2242 node_reference[input_name] = node_reference[input_name] + 1;
2243 }
2244 }
2245
2246 if (op == "Dropout")
2247 {
2248 const std::string& output_name = node.output(0);
2249 blob_names.insert(output_name);
2250 node_reference[output_name] = 0;
2251 continue;
2252 }
2253
2254 for (int j = 0; j < (int)node.output_size(); j++)
2255 {
2256 const std::string& output_name = node.output(j);
2257
2258 blob_names.insert(output_name);
2259
2260 node_reference[output_name] = 0;
2261 }
2262 }
2263
2264 // include Input node
2265 int input_node_count = 0;
2266 for (int j = 0; j < graph.input_size(); j++)
2267 {
2268 const std::string& input_name = graph.input(j).name();
2269
2270 // check weight
2271 if (weights.find(input_name) != weights.end())
2272 continue;
2273
2274 blob_names.insert(input_name);
2275
2276 input_node_count++;
2277 }
2278
2279 // for (auto a: node_reference)
2280 // {
2281 // fprintf(stderr, "a = %s %d\n", a.first.c_str(), a.second);
2282 // }
2283
2284 // op chain fusion
2285 int reduced_node_count = 0;
2286 fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2287 fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2288 fuse_shufflechannel(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2289 fuse_shufflechannel_split(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2290 fuse_hardsigmoid(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2291 fuse_hardswish(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2292 fuse_swish(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2293 fuse_batchnorm1d_squeeze_unsqueeze(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2294 fuse_unsqueeze_prelu(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2295 fuse_normalize(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2296 fuse_groupnorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2297 fuse_flatten(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2298 fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2299 fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2300 fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2301 fuse_lstm_gru_rnn(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
2302
2303 // reduce common const weight node_reference
2304 for (int i = 0; i < node_count; i++)
2305 {
2306 const onnx::NodeProto& node = graph.node(i);
2307
2308 const std::string& op = node.op_type();
2309
2310 if (op == "Add" || op == "Sub" || op == "Mul" || op == "Div" || op == "Max" || op == "Min" || op == "Pow")
2311 {
2312 // binaryop with scalar
2313 if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0)
2314 {
2315 node_reference[node.input(1)] -= 1;
2316 }
2317 }
2318 else if (op == "Attention")
2319 {
2320 node_reference[node.input(1)] -= 1;
2321 node_reference[node.input(2)] -= 1;
2322 node_reference[node.input(3)] -= 1;
2323 }
2324 else if (op == "BatchNormalization")
2325 {
2326 node_reference[node.input(1)] -= 1;
2327 node_reference[node.input(2)] -= 1;
2328 node_reference[node.input(3)] -= 1;
2329 node_reference[node.input(4)] -= 1;
2330 }
2331 else if (op == "BiasGelu")
2332 {
2333 node_reference[node.input(1)] -= 1;
2334 }
2335 else if (op == "Clip")
2336 {
2337 if (node.input_size() == 3)
2338 {
2339 node_reference[node.input(1)] -= 1;
2340 node_reference[node.input(2)] -= 1;
2341 }
2342 }
2343 else if (op == "Conv")
2344 {
2345 node_reference[node.input(1)] -= 1;
2346 if (node.input_size() == 3)
2347 {
2348 node_reference[node.input(2)] -= 1;
2349 }
2350 }
2351 else if (op == "ConvTranspose")
2352 {
2353 node_reference[node.input(1)] -= 1;
2354 if (node.input_size() == 3)
2355 {
2356 node_reference[node.input(2)] -= 1;
2357 }
2358 }
2359 else if (op == "EmbedLayerNormalization")
2360 {
2361 node_reference[node.input(1)] -= 1;
2362 node_reference[node.input(2)] -= 1;
2363 node_reference[node.input(3)] -= 1;
2364 node_reference[node.input(4)] -= 1;
2365 node_reference[node.input(5)] -= 1;
2366 node_reference[node.input(6)] -= 1;
2367 }
2368 else if (op == "Gemm")
2369 {
2370 float alpha = get_node_attr_f(node, "alpha", 1.f);
2371 float beta = get_node_attr_f(node, "beta", 1.f);
2372 int transA = get_node_attr_i(node, "transA", 0);
2373 int transB = get_node_attr_i(node, "transB", 0);
2374
2375 if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
2376 {
2377 // InnerProduct-like A * B + C
2378 node_reference[node.input(1)] -= 1;
2379 node_reference[node.input(2)] -= 1;
2380 }
2381 }
2382 else if (op == "GroupNorm")
2383 {
2384 int affine = get_node_attr_i(node, "affine", 1);
2385 if (affine)
2386 {
2387 node_reference[node.input(1)] -= 1;
2388 node_reference[node.input(2)] -= 1;
2389 }
2390 }
2391 else if (op == "GRU")
2392 {
2393 for (int j = 1; j < node.input_size(); j++)
2394 {
2395 node_reference[node.input(j)] -= 1;
2396 }
2397 }
2398 else if (op == "InstanceNormalization")
2399 {
2400 node_reference[node.input(1)] -= 1;
2401 node_reference[node.input(2)] -= 1;
2402 }
2403 else if (op == "LSTM")
2404 {
2405 for (int j = 1; j < node.input_size(); j++)
2406 {
2407 node_reference[node.input(j)] -= 1;
2408 }
2409 }
2410 else if (op == "MatMul")
2411 {
2412 if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
2413 {
2414 // InnerProduct
2415 node_reference[node.input(1)] -= 1;
2416 }
2417 }
2418 else if (op == "Pad")
2419 {
2420 if (node.input_size() >= 2)
2421 {
2422 node_reference[node.input(1)] -= 1;
2423 }
2424 }
2425 else if (op == "PRelu")
2426 {
2427 node_reference[node.input(1)] -= 1;
2428 }
2429 else if (op == "Reshape")
2430 {
2431 if (node.input_size() >= 2)
2432 {
2433 node_reference[node.input(1)] -= 1;
2434 }
2435 }
2436 else if (op == "Resize")
2437 {
2438 if (node.input_size() == 2)
2439 {
2440 // opset 10
2441 node_reference[node.input(1)] -= 1;
2442 }
2443 else
2444 {
2445 // opset 11+
2446 node_reference[node.input(1)] -= 1;
2447 node_reference[node.input(2)] -= 1;
2448 if (node.input_size() >= 4)
2449 {
2450 node_reference[node.input(3)] -= 1;
2451 }
2452 }
2453 }
2454 else if (op == "RNN")
2455 {
2456 for (int j = 1; j < node.input_size(); j++)
2457 {
2458 node_reference[node.input(j)] -= 1;
2459 }
2460 }
2461 else if (op == "SkipLayerNormalization")
2462 {
2463 node_reference[node.input(2)] -= 1;
2464 node_reference[node.input(3)] -= 1;
2465 node_reference[node.input(4)] -= 1;
2466 }
2467 else if (op == "Slice")
2468 {
2469 if (node.input_size() >= 2)
2470 {
2471 node_reference[node.input(1)] -= 1;
2472 node_reference[node.input(2)] -= 1;
2473 if (node.input_size() >= 4)
2474 node_reference[node.input(3)] -= 1;
2475 if (node.input_size() >= 5)
2476 node_reference[node.input(4)] -= 1;
2477 }
2478 }
2479 else if (op == "Upsample")
2480 {
2481 if (node.input_size() >= 2)
2482 {
2483 node_reference[node.input(1)] -= 1;
2484 }
2485 }
2486 else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
2487 {
2488 if (node.input_size() >= 2)
2489 {
2490 node_reference[node.input(1)] -= 1;
2491 }
2492 }
2493 }
2494
2495 // for (auto a: node_reference)
2496 // {
2497 // fprintf(stderr, "b = %s %d\n", a.first.c_str(), a.second);
2498 // }
2499
2500 // count all weight node with zero reference
2501 int zero_reference_weight_node_count = 0;
2502 for (std::map<std::string, onnx::TensorProto>::iterator it = weights.begin(); it != weights.end(); it++)
2503 {
2504 const std::string& input_name = it->first;
2505
2506 int refcount = node_reference[input_name];
2507 if (refcount == 0)
2508 zero_reference_weight_node_count++;
2509 }
2510
2511 // we always treat constant node as weight or binaryop_weights
2512 // do not count it twice for layer_count
2513 int constant_node_count_moved_to_weight = 0;
2514 for (int i = 0; i < node_count; i++)
2515 {
2516 const onnx::NodeProto& node = graph.node(i);
2517
2518 const std::string& op = node.op_type();
2519
2520 if (op == "Constant")
2521 {
2522 constant_node_count_moved_to_weight++;
2523 }
2524 }
2525
2526 // some op may have anonymous input
2527 // LSTM sequence_lens
2528 blob_names.erase("");
2529 node_reference.erase("");
2530
2531 // remove node_reference entry with reference equals to one
2532 int split_layer_count = 0;
2533 int splitncnn_blob_count = 0;
2534 // split node reference
2535 std::map<std::string, int> split_node_reference;
2536 for (std::map<std::string, int>::iterator it = node_reference.begin(); it != node_reference.end(); it++)
2537 {
2538 if (it->second > 1)
2539 {
2540 split_layer_count++;
2541 splitncnn_blob_count += it->second;
2542
2543 split_node_reference[it->first] = it->second;
2544 }
2545 }
2546
2547 fprintf(pp, "%zu %zu\n", node_count - constant_node_count_moved_to_weight + weights.size() - zero_reference_weight_node_count - reduced_node_count + input_node_count + split_layer_count, blob_names.size() - zero_reference_weight_node_count + splitncnn_blob_count);
2548
2549 int internal_split = 0;
2550
2551 // place Input at the beginning
2552 for (int j = 0; j < graph.input_size(); j++)
2553 {
2554 const std::string& input_name = graph.input(j).name();
2555
2556 // check weight
2557 if (weights.find(input_name) != weights.end())
2558 continue;
2559
2560 fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", input_name.c_str(), input_name.c_str());
2561
2562 int refcount = node_reference[input_name];
2563 if (refcount <= 1)
2564 {
2565 continue;
2566 }
2567
2568 char splitname[256];
2569 sprintf(splitname, "splitncnn_input%d", j);
2570 fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
2571 fprintf(pp, " %s", input_name.c_str());
2572
2573 for (int k = 0; k < refcount; k++)
2574 {
2575 fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
2576 }
2577 fprintf(pp, "\n");
2578 }
2579
2580 // place MemoryData next
2581 for (std::map<std::string, onnx::TensorProto>::iterator weight_it = weights.begin(); weight_it != weights.end(); weight_it++)
2582 {
2583 const std::string& input_name = weight_it->first;
2584
2585 int refcount = node_reference[input_name];
2586 if (refcount == 0)
2587 {
2588 continue;
2589 }
2590
2591 fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", input_name.c_str(), input_name.c_str());
2592
2593 const onnx::TensorProto& M = weights[input_name];
2594
2595 if (M.dims_size() == 0)
2596 {
2597 fprintf(pp, " 0=%d", get_tensor_proto_data_size(M));
2598 }
2599 else if (M.dims_size() == 1)
2600 {
2601 fprintf(pp, " 0=%d", (int)M.dims(0));
2602 }
2603 else if (M.dims_size() == 2)
2604 {
2605 fprintf(pp, " 0=%d", (int)M.dims(1));
2606 if (M.dims(0) != 1)
2607 {
2608 fprintf(pp, " 1=%d", (int)M.dims(0));
2609 }
2610 }
2611 else if (M.dims_size() == 3)
2612 {
2613 fprintf(pp, " 0=%d", (int)M.dims(2));
2614 fprintf(pp, " 1=%d", (int)M.dims(1));
2615 if (M.dims(0) != 1)
2616 {
2617 fprintf(pp, " 2=%d", (int)M.dims(0));
2618 }
2619 }
2620 else if (M.dims_size() == 4)
2621 {
2622 fprintf(pp, " 0=%d", (int)M.dims(3));
2623 fprintf(pp, " 1=%d", (int)M.dims(2));
2624 fprintf(pp, " 2=%d", (int)M.dims(1));
2625 }
2626
2627 fprintf(pp, "\n");
2628
2629 fwrite_tensor_proto_data(M, bp);
2630
2631 if (refcount <= 1)
2632 {
2633 continue;
2634 }
2635
2636 char splitname[256];
2637 sprintf(splitname, "splitncnn_%d", internal_split);
2638 fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
2639
2640 fprintf(pp, " %s", input_name.c_str());
2641
2642 for (int k = 0; k < refcount; k++)
2643 {
2644 fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
2645 }
2646 fprintf(pp, "\n");
2647
2648 internal_split++;
2649 }
2650
2651 for (int i = 0; i < node_count; i++)
2652 {
2653 const onnx::NodeProto& node = graph.node(i);
2654
2655 const std::string& op = node.op_type();
2656
2657 // fprintf(stderr, "op = %s\n", op.c_str());
2658
2659 if (op == "noop_reducedncnn")
2660 {
2661 continue;
2662 }
2663
2664 std::string name = node.name();
2665 if (name.empty())
2666 {
2667 name = node.output(0);
2668 }
2669
2670 int input_size = node.input_size();
2671 int output_size = node.output_size();
2672
2673 for (int j = 0; j < (int)node.input_size(); j++)
2674 {
2675 const std::string& input_name = node.input(j);
2676
2677 // check weight
2678 if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0)
2679 {
2680 input_size--;
2681 }
2682
2683 if (input_name.empty())
2684 {
2685 input_size--;
2686 }
2687
2688 // fprintf(stderr, " input = %s\n", input_name.c_str());
2689 }
2690 /*
2691 for (int j=0; j<(int)node.output_size(); j++)
2692 {
2693 const std::string& output_name = node.output(j);
2694 fprintf(stderr, " output = %s\n", output_name.c_str());
2695 }
2696 */
2697
2698 if (op == "Abs")
2699 {
2700 fprintf(pp, "%-16s", "UnaryOp");
2701 }
2702 else if (op == "Acos")
2703 {
2704 fprintf(pp, "%-16s", "UnaryOp");
2705 }
2706 else if (op == "Add")
2707 {
2708 fprintf(pp, "%-16s", "BinaryOp");
2709 }
2710 else if (op == "Asin")
2711 {
2712 fprintf(pp, "%-16s", "UnaryOp");
2713 }
2714 else if (op == "Atan")
2715 {
2716 fprintf(pp, "%-16s", "UnaryOp");
2717 }
2718 else if (op == "Attention")
2719 {
2720 fprintf(pp, "%-16s", "Attention");
2721 }
2722 else if (op == "AveragePool" || op == "MaxPool")
2723 {
2724 fprintf(pp, "%-16s", "Pooling");
2725 }
2726 else if (op == "BatchNormalization")
2727 {
2728 fprintf(pp, "%-16s", "BatchNorm");
2729 }
2730 else if (op == "BiasGelu")
2731 {
2732 fprintf(pp, "%-16s", "BiasGelu");
2733 }
2734 else if (op == "Ceil")
2735 {
2736 fprintf(pp, "%-16s", "UnaryOp");
2737 }
2738 else if (op == "Clip")
2739 {
2740 fprintf(pp, "%-16s", "Clip");
2741 }
2742 else if (op == "Concat")
2743 {
2744 fprintf(pp, "%-16s", "Concat");
2745 }
2746 else if (op == "Constant")
2747 {
2748 continue;
2749 }
2750 else if (op == "Conv")
2751 {
2752 int group = get_node_attr_i(node, "group", 1);
2753 if (group > 1)
2754 {
2755 fprintf(pp, "%-16s", "ConvolutionDepthWise");
2756 }
2757 else
2758 {
2759 fprintf(pp, "%-16s", "Convolution");
2760 }
2761 }
2762 else if (op == "ConvTranspose")
2763 {
2764 int group = get_node_attr_i(node, "group", 1);
2765 if (group > 1)
2766 {
2767 fprintf(pp, "%-16s", "DeconvolutionDepthWise");
2768 }
2769 else
2770 {
2771 fprintf(pp, "%-16s", "Deconvolution");
2772 }
2773 }
2774 else if (op == "Cos")
2775 {
2776 fprintf(pp, "%-16s", "UnaryOp");
2777 }
2778 else if (op == "DepthToSpace")
2779 {
2780 fprintf(pp, "%-16s", "PixelShuffle");
2781 }
2782 else if (op == "Div")
2783 {
2784 fprintf(pp, "%-16s", "BinaryOp");
2785 }
2786 else if (op == "Dropout")
2787 {
2788 fprintf(pp, "%-16s", "Dropout");
2789 output_size = 1;
2790 }
2791 else if (op == "Elu")
2792 {
2793 fprintf(pp, "%-16s", "ELU");
2794 }
2795 else if (op == "EmbedLayerNormalization")
2796 {
2797 fprintf(pp, "%-16s", "EmbedLayerNormalization");
2798 }
2799 else if (op == "Exp")
2800 {
2801 fprintf(pp, "%-16s", "UnaryOp");
2802 }
2803 else if (op == "Flatten")
2804 {
2805 fprintf(pp, "%-16s", "Flatten");
2806 }
2807 else if (op == "Floor")
2808 {
2809 fprintf(pp, "%-16s", "UnaryOp");
2810 }
2811 else if (op == "Gemm")
2812 {
2813 float alpha = get_node_attr_f(node, "alpha", 1.f);
2814 float beta = get_node_attr_f(node, "beta", 1.f);
2815 int transA = get_node_attr_i(node, "transA", 0);
2816 int transB = get_node_attr_i(node, "transB", 0);
2817
2818 if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
2819 {
2820 // InnerProduct-like A * B + C
2821 fprintf(pp, "%-16s", "InnerProduct");
2822 }
2823 else
2824 {
2825 fprintf(pp, "%-16s", "Gemm");
2826 }
2827 }
2828 else if (op == "GlobalAveragePool")
2829 {
2830 fprintf(pp, "%-16s", "Pooling");
2831 }
2832 else if (op == "GlobalMaxPool")
2833 {
2834 fprintf(pp, "%-16s", "Pooling");
2835 }
2836 else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
2837 {
2838 fprintf(pp, "%-16s", "Pooling");
2839 }
2840 else if (op == "GroupNorm")
2841 {
2842 fprintf(pp, "%-16s", "GroupNorm");
2843 }
2844 else if (op == "GRU")
2845 {
2846 fprintf(pp, "%-16s", "GRU");
2847 }
2848 else if (op == "HardSigmoid")
2849 {
2850 fprintf(pp, "%-16s", "HardSigmoid");
2851 }
2852 else if (op == "HardSwish")
2853 {
2854 fprintf(pp, "%-16s", "HardSwish");
2855 }
2856 else if (op == "ImageScaler")
2857 {
2858 fprintf(pp, "%-16s", "Scale");
2859 }
2860 else if (op == "InstanceNormalization")
2861 {
2862 fprintf(pp, "%-16s", "InstanceNorm");
2863 }
2864 else if (op == "LeakyRelu")
2865 {
2866 fprintf(pp, "%-16s", "ReLU");
2867 }
2868 else if (op == "Log")
2869 {
2870 fprintf(pp, "%-16s", "UnaryOp");
2871 }
2872 else if (op == "LRN")
2873 {
2874 fprintf(pp, "%-16s", "LRN");
2875 }
2876 else if (op == "LSTM")
2877 {
2878 fprintf(pp, "%-16s", "LSTM");
2879 }
2880 else if (op == "MatMul")
2881 {
2882 if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
2883 {
2884 fprintf(pp, "%-16s", "InnerProduct");
2885 }
2886 else
2887 {
2888 fprintf(pp, "%-16s", "Gemm");
2889 }
2890 }
2891 else if (op == "Max")
2892 {
2893 fprintf(pp, "%-16s", "BinaryOp");
2894 }
2895 else if (op == "Min")
2896 {
2897 fprintf(pp, "%-16s", "BinaryOp");
2898 }
2899 else if (op == "Mul")
2900 {
2901 fprintf(pp, "%-16s", "BinaryOp");
2902 }
2903 else if (op == "Neg")
2904 {
2905 fprintf(pp, "%-16s", "UnaryOp");
2906 }
2907 else if (op == "Normalize")
2908 {
2909 fprintf(pp, "%-16s", "Normalize");
2910 }
2911 else if (op == "Pad")
2912 {
2913 fprintf(pp, "%-16s", "Padding");
2914 }
2915 else if (op == "PixelShuffle")
2916 {
2917 fprintf(pp, "%-16s", "PixelShuffle");
2918 }
2919 else if (op == "Pow")
2920 {
2921 fprintf(pp, "%-16s", "BinaryOp");
2922 }
2923 else if (op == "PRelu")
2924 {
2925 fprintf(pp, "%-16s", "PReLU");
2926 }
2927 else if (op == "Reciprocal")
2928 {
2929 fprintf(pp, "%-16s", "UnaryOp");
2930 }
2931 else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp")
2932 {
2933 fprintf(pp, "%-16s", "Reduction");
2934 }
2935 else if (op == "Relu")
2936 {
2937 fprintf(pp, "%-16s", "ReLU");
2938 }
2939 else if (op == "Reorg")
2940 {
2941 fprintf(pp, "%-16s", "Reorg");
2942 }
2943 else if (op == "Reshape")
2944 {
2945 fprintf(pp, "%-16s", "Reshape");
2946 }
2947 else if (op == "RNN")
2948 {
2949 fprintf(pp, "%-16s", "RNN");
2950 }
2951 else if (op == "ShuffleChannel")
2952 {
2953 fprintf(pp, "%-16s", "ShuffleChannel");
2954 }
2955 else if (op == "Sigmoid")
2956 {
2957 fprintf(pp, "%-16s", "Sigmoid");
2958 }
2959 else if (op == "Sin")
2960 {
2961 fprintf(pp, "%-16s", "UnaryOp");
2962 }
2963 else if (op == "SkipLayerNormalization")
2964 {
2965 fprintf(pp, "%-16s", "SkipLayerNormalization");
2966 }
2967 else if (op == "Slice")
2968 {
2969 fprintf(pp, "%-16s", "Crop");
2970 }
2971 else if (op == "Softmax")
2972 {
2973 fprintf(pp, "%-16s", "Softmax");
2974 }
2975 else if (op == "Softplus")
2976 {
2977 fprintf(pp, "%-16s", "Softplus");
2978 }
2979 else if (op == "Split")
2980 {
2981 fprintf(pp, "%-16s", "Slice");
2982 }
2983 else if (op == "Sqrt")
2984 {
2985 fprintf(pp, "%-16s", "UnaryOp");
2986 }
2987 else if (op == "Squeeze")
2988 {
2989 fprintf(pp, "%-16s", "Squeeze");
2990 }
2991 else if (op == "Sub")
2992 {
2993 fprintf(pp, "%-16s", "BinaryOp");
2994 }
2995 else if (op == "Sum")
2996 {
2997 fprintf(pp, "%-16s", "Eltwise");
2998 }
2999 else if (op == "Swish")
3000 {
3001 fprintf(pp, "%-16s", "Swish");
3002 }
3003 else if (op == "Tan")
3004 {
3005 fprintf(pp, "%-16s", "UnaryOp");
3006 }
3007 else if (op == "Tanh")
3008 {
3009 fprintf(pp, "%-16s", "UnaryOp");
3010 }
3011 else if (op == "Transpose")
3012 {
3013 fprintf(pp, "%-16s", "Permute");
3014 }
3015 else if (op == "Upsample" || op == "Resize")
3016 {
3017 fprintf(pp, "%-16s", "Interp");
3018 }
3019 else if (op == "Unsqueeze")
3020 {
3021 fprintf(pp, "%-16s", "ExpandDims");
3022 }
3023 else
3024 {
3025 // TODO
3026 fprintf(stderr, "%s not supported yet!\n", op.c_str());
3027 fprintf(pp, "%-16s", op.c_str());
3028 }
3029
3030 fprintf(pp, " %-24s %d %d", name.c_str(), input_size, output_size);
3031
3032 for (int j = 0; j < (int)node.input_size(); j++)
3033 {
3034 std::string input_name = node.input(j);
3035
3036 // check weight
3037 if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0)
3038 {
3039 continue;
3040 }
3041
3042 if (input_name.empty())
3043 {
3044 continue;
3045 }
3046
3047 if (split_node_reference.find(input_name) != split_node_reference.end())
3048 {
3049 int refidx = split_node_reference[input_name] - 1;
3050 split_node_reference[input_name] = refidx;
3051
3052 char splitsuffix[256];
3053 sprintf(splitsuffix, "_splitncnn_%d", refidx);
3054 input_name = input_name + splitsuffix;
3055 }
3056
3057 fprintf(pp, " %s", input_name.c_str());
3058 }
3059
3060 for (int j = 0; j < output_size; j++)
3061 {
3062 const std::string& output_name = node.output(j);
3063
3064 fprintf(pp, " %s", output_name.c_str());
3065 }
3066
3067 if (op == "Abs")
3068 {
3069 int op_type = 0;
3070 fprintf(pp, " 0=%d", op_type);
3071 }
3072 else if (op == "Acos")
3073 {
3074 int op_type = 13;
3075 fprintf(pp, " 0=%d", op_type);
3076 }
3077 else if (op == "Add")
3078 {
3079 int op_type = 0;
3080 fprintf(pp, " 0=%d", op_type);
3081
3082 if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0)
3083 {
3084 float b = get_node_attr_from_input_f(weights[node.input(1)]);
3085 fprintf(pp, " 1=1");
3086 fprintf(pp, " 2=%e", b);
3087 }
3088 }
3089 else if (op == "Asin")
3090 {
3091 int op_type = 12;
3092 fprintf(pp, " 0=%d", op_type);
3093 }
3094 else if (op == "Atan")
3095 {
3096 int op_type = 14;
3097 fprintf(pp, " 0=%d", op_type);
3098 }
3099 else if (op == "Attention")
3100 {
3101 int num_heads = get_node_attr_i(node, "num_heads", 1);
3102
3103 const onnx::TensorProto& W = weights[node.input(1)];
3104 const onnx::TensorProto& B = weights[node.input(2)];
3105
3106 fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
3107 fprintf(pp, " 1=%d", num_heads);
3108 fprintf(pp, " 2=%d", get_tensor_proto_data_size(W));
3109
3110 int quantize_tag = 0;
3111 fwrite(&quantize_tag, sizeof(int), 1, bp);
3112
3113 fwrite_tensor_proto_data(W, bp);
3114
3115 fwrite(&quantize_tag, sizeof(int), 1, bp);
3116
3117 fwrite_tensor_proto_data(B, bp);
3118 }
3119 else if (op == "AveragePool" || op == "MaxPool")
3120 {
3121 std::string auto_pad = get_node_attr_s(node, "auto_pad");
3122 int ceil_mode = get_node_attr_i(node, "ceil_mode", 0);
3123 std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
3124 std::vector<int> strides = get_node_attr_ai(node, "strides");
3125 std::vector<int> pads = get_node_attr_ai(node, "pads");
3126
3127 int pool = op == "AveragePool" ? 1 : 0;
3128 int pad_mode = 1;
3129
3130 if (auto_pad == "SAME_UPPER")
3131 {
3132 pad_mode = 2;
3133 }
3134 else if (auto_pad == "SAME_LOWER")
3135 {
3136 pad_mode = 3;
3137 }
3138
3139 if (ceil_mode == 1)
3140 {
3141 pad_mode = 0;
3142 }
3143
3144 fprintf(pp, " 0=%d", pool);
3145
3146 if (kernel_shape.size() == 1)
3147 {
3148 fprintf(pp, " 1=%d", kernel_shape[0]);
3149 }
3150 else if (kernel_shape.size() == 2)
3151 {
3152 fprintf(pp, " 1=%d", kernel_shape[1]);
3153 fprintf(pp, " 11=%d", kernel_shape[0]);
3154 }
3155
3156 if (strides.size() == 1)
3157 {
3158 fprintf(pp, " 2=%d", strides[0]);
3159 }
3160 else if (strides.size() == 2)
3161 {
3162 fprintf(pp, " 2=%d", strides[1]);
3163 fprintf(pp, " 12=%d", strides[0]);
3164 }
3165
3166 if (pads.size() == 1)
3167 {
3168 fprintf(pp, " 3=%d", pads[0]);
3169 }
3170 else if (pads.size() == 2)
3171 {
3172 fprintf(pp, " 3=%d", pads[1]);
3173 fprintf(pp, " 13=%d", pads[0]);
3174 }
3175 else if (pads.size() == 4)
3176 {
3177 fprintf(pp, " 3=%d", pads[1]);
3178 fprintf(pp, " 13=%d", pads[0]);
3179 fprintf(pp, " 14=%d", pads[3]);
3180 fprintf(pp, " 15=%d", pads[2]);
3181 }
3182
3183 fprintf(pp, " 5=%d", pad_mode);
3184
3185 if (op == "AveragePool")
3186 {
3187 int avgpool_count_include_pad = get_node_attr_i(node, "count_include_pad", 0);
3188 fprintf(pp, " 6=%d", avgpool_count_include_pad);
3189 }
3190 }
3191 else if (op == "BatchNormalization")
3192 {
3193 float epsilon = get_node_attr_f(node, "epsilon", 1e-5f);
3194
3195 const onnx::TensorProto& scale = weights[node.input(1)];
3196 const onnx::TensorProto& B = weights[node.input(2)];
3197 const onnx::TensorProto& mean = weights[node.input(3)];
3198 const onnx::TensorProto& var = weights[node.input(4)];
3199
3200 int channels = get_tensor_proto_data_size(scale);
3201
3202 fprintf(pp, " 0=%d", channels);
3203
3204 fwrite_tensor_proto_data(scale, bp);
3205 fwrite_tensor_proto_data(mean, bp);
3206 // apply epsilon to var
3207 {
3208 const float* v = var.has_raw_data() ? (const float*)var.raw_data().data() : var.float_data().data();
3209
3210 for (int j = 0; j < channels; j++)
3211 {
3212 float ve = v[j] + epsilon;
3213 fwrite(&ve, sizeof(float), 1, bp);
3214 }
3215 }
3216 fwrite_tensor_proto_data(B, bp);
3217 }
3218 else if (op == "BiasGelu")
3219 {
3220 const onnx::TensorProto& B = weights[node.input(1)];
3221
3222 fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
3223
3224 int quantize_tag = 0;
3225 fwrite(&quantize_tag, sizeof(int), 1, bp);
3226
3227 fwrite_tensor_proto_data(B, bp);
3228 }
3229 else if (op == "Ceil")
3230 {
3231 int op_type = 3;
3232 fprintf(pp, " 0=%d", op_type);
3233 }
3234 else if (op == "Clip")
3235 {
3236 float min;
3237 float max;
3238 if (node.input_size() == 1)
3239 {
3240 min = get_node_attr_f(node, "min", -FLT_MAX);
3241 max = get_node_attr_f(node, "max", FLT_MAX);
3242 }
3243 else
3244 {
3245 const onnx::TensorProto& min_tp = weights[node.input(1)];
3246 const onnx::TensorProto& max_tp = weights[node.input(2)];
3247
3248 min = get_node_attr_from_input_f(min_tp);
3249 max = get_node_attr_from_input_f(max_tp);
3250 }
3251
3252 fprintf(pp, " 0=%e", min);
3253 fprintf(pp, " 1=%e", max);
3254 }
3255 else if (op == "Concat")
3256 {
3257 int axis = get_node_attr_i(node, "axis", 1);
3258 fprintf(pp, " 0=%d", axis - 1);
3259 }
3260 else if (op == "Constant")
3261 {
3262 // never reach here
3263 }
3264 else if (op == "Conv")
3265 {
3266 const onnx::TensorProto& W = weights[node.input(1)];
3267
3268 int num_filter = W.dims(0);
3269 int has_bias = node.input_size() == 3 ? 1 : 0;
3270
3271 std::string auto_pad = get_node_attr_s(node, "auto_pad");
3272 std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
3273 std::vector<int> dilations = get_node_attr_ai(node, "dilations");
3274 std::vector<int> strides = get_node_attr_ai(node, "strides");
3275 std::vector<int> pads = get_node_attr_ai(node, "pads");
3276 int group = get_node_attr_i(node, "group", 1);
3277
3278 fprintf(pp, " 0=%d", num_filter);
3279
3280 if (kernel_shape.size() == 1)
3281 {
3282 fprintf(pp, " 1=%d", kernel_shape[0]);
3283 }
3284 else if (kernel_shape.size() == 2)
3285 {
3286 fprintf(pp, " 1=%d", kernel_shape[1]);
3287 fprintf(pp, " 11=%d", kernel_shape[0]);
3288 }
3289
3290 if (dilations.size() == 1)
3291 {
3292 fprintf(pp, " 2=%d", dilations[0]);
3293 }
3294 else if (dilations.size() == 2)
3295 {
3296 fprintf(pp, " 2=%d", dilations[1]);
3297 fprintf(pp, " 12=%d", dilations[0]);
3298 }
3299
3300 if (strides.size() == 1)
3301 {
3302 fprintf(pp, " 3=%d", strides[0]);
3303 }
3304 else if (strides.size() == 2)
3305 {
3306 fprintf(pp, " 3=%d", strides[1]);
3307 fprintf(pp, " 13=%d", strides[0]);
3308 }
3309
3310 if (auto_pad == "SAME_UPPER")
3311 {
3312 fprintf(pp, " 4=-233");
3313 }
3314 else if (auto_pad == "SAME_LOWER")
3315 {
3316 fprintf(pp, " 4=-234");
3317 }
3318 else
3319 {
3320 if (pads.size() == 1)
3321 {
3322 fprintf(pp, " 4=%d", pads[0]);
3323 }
3324 else if (pads.size() == 2)
3325 {
3326 fprintf(pp, " 4=%d", pads[1]);
3327 fprintf(pp, " 14=%d", pads[0]);
3328 }
3329 else if (pads.size() == 4)
3330 {
3331 fprintf(pp, " 4=%d", pads[1]);
3332 fprintf(pp, " 14=%d", pads[0]);
3333 fprintf(pp, " 15=%d", pads[3]);
3334 fprintf(pp, " 16=%d", pads[2]);
3335 }
3336 }
3337
3338 fprintf(pp, " 5=%d", has_bias);
3339
3340 fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));
3341
3342 if (group > 1)
3343 {
3344 fprintf(pp, " 7=%d", group);
3345 }
3346
3347 int quantize_tag = 0;
3348 fwrite(&quantize_tag, sizeof(int), 1, bp);
3349
3350 fwrite_tensor_proto_data(W, bp);
3351
3352 if (has_bias)
3353 {
3354 const onnx::TensorProto& B = weights[node.input(2)];
3355 fwrite_tensor_proto_data(B, bp);
3356 }
3357 }
else if (op == "ConvTranspose")
{
    // ONNX ConvTranspose -> ncnn Deconvolution(-DepthWise) params.
    // ONNX deconvolution weight W is laid out (num_input, num_output/group, kh, kw),
    // so the output channel count is dims(1) * group and a per-group
    // inch-outch -> outch-inch reorder is required when writing the blob below.
    const onnx::TensorProto& W = weights[node.input(1)];

    // optional third input is the bias
    int has_bias = node.input_size() == 3 ? 1 : 0;

    std::string auto_pad = get_node_attr_s(node, "auto_pad");
    std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
    std::vector<int> dilations = get_node_attr_ai(node, "dilations");
    std::vector<int> strides = get_node_attr_ai(node, "strides");
    std::vector<int> output_padding = get_node_attr_ai(node, "output_padding");
    std::vector<int> output_shape = get_node_attr_ai(node, "output_shape");
    std::vector<int> pads = get_node_attr_ai(node, "pads");
    int group = get_node_attr_i(node, "group", 1);
    int num_filter = W.dims(1) * group;

    // 0 = num_output
    fprintf(pp, " 0=%d", num_filter);

    // 1/11 = kernel_w/kernel_h (ONNX order is (h, w), ncnn keys w first)
    if (kernel_shape.size() == 1)
    {
        fprintf(pp, " 1=%d", kernel_shape[0]);
    }
    else if (kernel_shape.size() == 2)
    {
        fprintf(pp, " 1=%d", kernel_shape[1]);
        fprintf(pp, " 11=%d", kernel_shape[0]);
    }

    // 2/12 = dilation_w/dilation_h
    if (dilations.size() == 1)
    {
        fprintf(pp, " 2=%d", dilations[0]);
    }
    else if (dilations.size() == 2)
    {
        fprintf(pp, " 2=%d", dilations[1]);
        fprintf(pp, " 12=%d", dilations[0]);
    }

    // 3/13 = stride_w/stride_h
    if (strides.size() == 1)
    {
        fprintf(pp, " 3=%d", strides[0]);
    }
    else if (strides.size() == 2)
    {
        fprintf(pp, " 3=%d", strides[1]);
        fprintf(pp, " 13=%d", strides[0]);
    }

    // 4/14/15/16 = pad left/top/right/bottom;
    // the magic values -233/-234 request SAME_UPPER/SAME_LOWER auto padding
    if (auto_pad == "SAME_UPPER")
    {
        fprintf(pp, " 4=-233");
    }
    else if (auto_pad == "SAME_LOWER")
    {
        fprintf(pp, " 4=-234");
    }
    else
    {
        if (pads.size() == 1)
        {
            fprintf(pp, " 4=%d", pads[0]);
        }
        else if (pads.size() == 2)
        {
            fprintf(pp, " 4=%d", pads[1]);
            fprintf(pp, " 14=%d", pads[0]);
        }
        else if (pads.size() == 4)
        {
            // ONNX pads order is (top, left, bottom, right)
            fprintf(pp, " 4=%d", pads[1]);
            fprintf(pp, " 14=%d", pads[0]);
            fprintf(pp, " 15=%d", pads[3]);
            fprintf(pp, " 16=%d", pads[2]);
        }
    }

    // 18/19 = output_padding w/h
    if (output_padding.size() == 1)
    {
        fprintf(pp, " 18=%d", output_padding[0]);
    }
    else if (output_padding.size() == 2)
    {
        fprintf(pp, " 18=%d", output_padding[1]);
        fprintf(pp, " 19=%d", output_padding[0]);
    }

    // 20/21 = explicit output w/h
    if (output_shape.size() == 1)
    {
        fprintf(pp, " 20=%d", output_shape[0]);
    }
    else if (output_shape.size() == 2)
    {
        fprintf(pp, " 20=%d", output_shape[1]);
        fprintf(pp, " 21=%d", output_shape[0]);
    }

    fprintf(pp, " 5=%d", has_bias);

    // 6 = total weight element count
    fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));

    if (group > 1)
    {
        fprintf(pp, " 7=%d", group);
    }

    // leading 0 tag marks the weight blob as raw fp32 (not quantized)
    int quantize_tag = 0;
    fwrite(&quantize_tag, sizeof(int), 1, bp);

    int maxk = 0;
    if (kernel_shape.size() == 2)
    {
        maxk = kernel_shape[1] * kernel_shape[0];
    }
    else
    {
        // NOTE(review): for a 1-D kernel_shape this squares the single
        // dimension; presumably a square kernel is assumed here — verify
        // against 1-D ConvTranspose inputs.
        maxk = kernel_shape[0] * kernel_shape[0];
    }
    int weight_data_size = get_tensor_proto_data_size(W);
    // weight data may live in raw_data bytes or in the typed float_data field
    const float* weight_data = 0;
    if (W.has_raw_data())
    {
        weight_data = (const float*)W.raw_data().data();
    }
    else if (W.data_type() == 1)
    {
        weight_data = W.float_data().data();
    }
    for (int g = 0; g < group; g++)
    {
        // reorder weight from inch-outch to outch-inch
        int num_filter_g = num_filter / group;
        int num_input = weight_data_size / maxk / num_filter_g / group;
        const float* weight_data_ptr = weight_data + g * maxk * num_filter_g * num_input;
        for (int k = 0; k < num_filter_g; k++)
        {
            for (int j = 0; j < num_input; j++)
            {
                fwrite(weight_data_ptr + (j * num_filter_g + k) * maxk, sizeof(float), maxk, bp);
            }
        }
    }

    if (has_bias)
    {
        const onnx::TensorProto& B = weights[node.input(2)];
        fwrite_tensor_proto_data(B, bp);
    }
}
else if (op == "Cos")
{
    // UnaryOp operation type 10 = cos
    int op_type = 10;
    fprintf(pp, " 0=%d", op_type);
}
else if (op == "DepthToSpace")
{
    // pixelshuffle
    // 0 = upscale factor, 1 = channel layout mode (0 for CRD, 1 for DCR)
    int scale_factor = get_node_attr_i(node, "blocksize", 1);
    std::string mode = get_node_attr_s(node, "mode");
    fprintf(pp, " 0=%d", scale_factor);
    if (mode == "CRD")
    {
        fprintf(pp, " 1=0");
    }
    else if (mode == "DCR")
    {
        fprintf(pp, " 1=1");
    }
}
else if (op == "Div")
{
    // BinaryOp operation type 3 = div
    int op_type = 3;
    fprintf(pp, " 0=%d", op_type);

    // fold a constant scalar divisor into the layer: 1=with_scalar, 2=value
    if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0)
    {
        float b = get_node_attr_from_input_f(weights[node.input(1)]);
        fprintf(pp, " 1=1");
        fprintf(pp, " 2=%e", b);
    }
}
else if (op == "Dropout")
{
    // no-op at inference time, no params written
}
else if (op == "Elu")
{
    // 0 = alpha coefficient of the negative branch
    float alpha = get_node_attr_f(node, "alpha", 1.f);
    fprintf(pp, " 0=%e", alpha);
}
else if (op == "EmbedLayerNormalization")
{
    // Inputs 2/3 look like the word and position embedding tables, and
    // inputs 5/6 the LayerNorm gamma/beta; the segment embedding at
    // input(4) is not consumed here — TODO confirm against the
    // com.microsoft EmbedLayerNormalization contract.
    const onnx::TensorProto& words = weights[node.input(2)];
    const onnx::TensorProto& positions = weights[node.input(3)];
    const onnx::TensorProto& W = weights[node.input(5)];
    const onnx::TensorProto& B = weights[node.input(6)];

    // 0 = embedding dim (beta size), 1/2 = word/position table sizes
    fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
    fprintf(pp, " 1=%d", get_tensor_proto_data_size(words));
    fprintf(pp, " 2=%d", get_tensor_proto_data_size(positions));

    // each blob is prefixed with a 0 tag = raw fp32
    int quantize_tag = 0;
    fwrite(&quantize_tag, sizeof(int), 1, bp);

    fwrite_tensor_proto_data(words, bp);

    fwrite(&quantize_tag, sizeof(int), 1, bp);

    fwrite_tensor_proto_data(positions, bp);

    fwrite(&quantize_tag, sizeof(int), 1, bp);

    fwrite_tensor_proto_data(W, bp);

    fwrite(&quantize_tag, sizeof(int), 1, bp);

    fwrite_tensor_proto_data(B, bp);
}
else if (op == "Exp")
{
    // UnaryOp operation type 7 = exp
    int op_type = 7;
    fprintf(pp, " 0=%d", op_type);
}
else if (op == "Flatten")
{
    // only axis==1 maps onto the parameterless ncnn Flatten; anything else
    // is reported but still emitted as-is
    int axis = get_node_attr_i(node, "axis", 1);
    if (axis != 1)
    {
        fprintf(stderr, "Unsupported Flatten axis %d!\n", axis);
    }
}
else if (op == "Floor")
{
    // UnaryOp operation type 2 = floor
    int op_type = 2;
    fprintf(pp, " 0=%d", op_type);
}
else if (op == "Gemm")
{
    float alpha = get_node_attr_f(node, "alpha", 1.f);
    float beta = get_node_attr_f(node, "beta", 1.f);
    int transA = get_node_attr_i(node, "transA", 0);
    int transB = get_node_attr_i(node, "transB", 0);

    // the plain A * B^T + C case maps onto InnerProduct with constant
    // weight B and bias C; anything else stays a generic Gemm layer
    if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
    {
        // InnerProduct-like A * B + C
        const onnx::TensorProto& B = weights[node.input(1)];
        const onnx::TensorProto& C = weights[node.input(2)];

        // 0 = num_output, 1 = bias present, 2 = weight element count
        fprintf(pp, " 0=%d", get_tensor_proto_data_size(C));
        fprintf(pp, " 1=1");
        fprintf(pp, " 2=%d", get_tensor_proto_data_size(B));

        // 0 tag = raw fp32 weights
        int quantize_tag = 0;
        fwrite(&quantize_tag, sizeof(int), 1, bp);

        fwrite_tensor_proto_data(B, bp);
        fwrite_tensor_proto_data(C, bp);
    }
    else
    {
        // gemm
        fprintf(pp, " 0=%e", alpha);
        fprintf(pp, " 1=%e", beta);
        fprintf(pp, " 2=%d", transA);
        fprintf(pp, " 3=%d", transB);
    }
}
else if (op == "GlobalAveragePool")
{
    // Pooling: 0 = pooling type (1 avg), 4 = global flag
    int pool = 1;
    int global_pool = 1;

    fprintf(pp, " 0=%d", pool);
    fprintf(pp, " 4=%d", global_pool);
}
else if (op == "GlobalMaxPool")
{
    // Pooling: 0 = pooling type (0 max), 4 = global flag
    int pool = 0;
    int global_pool = 1;

    fprintf(pp, " 0=%d", pool);
    fprintf(pp, " 4=%d", global_pool);
}
else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
{
    // torch-exported adaptive pooling; target spatial size comes from
    // the constant second input, not from an attribute
    int pool = 0;
    if (op == "adaptive_avg_pool2d")
    {
        pool = 1;
    }
    int adaptive_pooling = 1;
    const onnx::TensorProto& out_shape_tp = weights[node.input(1)];
    std::vector<int> out_shape = get_node_attr_from_input_ai(out_shape_tp);

    fprintf(pp, " 0=%d", pool);
    fprintf(pp, " 7=%d", adaptive_pooling);
    if (out_shape.size() == 1)
    {
        fprintf(pp, " 8=%d", out_shape[0]);
    }
    else if (out_shape.size() == 2)
    {
        // out_w
        fprintf(pp, " 8=%d", out_shape[1]);
        // out_h
        fprintf(pp, " 18=%d", out_shape[0]);
    }
}
else if (op == "GroupNorm")
{
    // 0 = group count, 1 = channels, 2 = eps, 3 = affine flag;
    // gamma/beta blobs are written only when affine is enabled
    int groups = get_node_attr_i(node, "groups", 1);
    int channels = get_node_attr_i(node, "channels", 1);
    float eps = get_node_attr_f(node, "epsilon", 1e-5f);
    int affine = get_node_attr_i(node, "affine", 1);

    fprintf(pp, " 0=%d", groups);
    fprintf(pp, " 1=%d", channels);
    fprintf(pp, " 2=%e", eps);
    fprintf(pp, " 3=%d", affine);
    if (affine)
    {
        const onnx::TensorProto& scale = weights[node.input(1)];
        const onnx::TensorProto& B = weights[node.input(2)];

        fwrite_tensor_proto_data(scale, bp);
        fwrite_tensor_proto_data(B, bp);
    }
}
else if (op == "GRU")
{
    // ONNX GRU stores gates in U,R,N order; ncnn expects R,U,N, so all
    // three blobs (input weights W, bias B, recurrent weights R) are
    // chunk-reordered while streaming, once per direction.
    const onnx::TensorProto& W = weights[node.input(1)];
    const onnx::TensorProto& R = weights[node.input(2)];
    const onnx::TensorProto& B = weights[node.input(3)];

    int hidden_size = get_node_attr_i(node, "hidden_size", 0);
    std::string direction = get_node_attr_s(node, "direction");

    // 0 forward, 1 reverse, 2 bidirectional
    int direction_type = 0;
    if (direction == "forward")
    {
        direction_type = 0;
    }
    else if (direction == "reverse")
    {
        direction_type = 1;
    }
    else if (direction == "bidirectional")
    {
        direction_type = 2;
    }

    int weight_data_size = get_tensor_proto_data_size(W);

    fprintf(pp, " 0=%d", hidden_size);
    fprintf(pp, " 1=%d", weight_data_size);
    fprintf(pp, " 2=%d", direction_type);

    int num_directions = direction_type == 2 ? 2 : 1;

    // 0 tag before each blob = raw fp32
    int quantize_tag = 0;

    // reorder num_directions-URN-hidden-size to num_directions-RUN-hidden-size
    {
        fwrite(&quantize_tag, sizeof(int), 1, bp);

        // per-gate chunk size within one direction
        int weight_data_size_g = get_tensor_proto_data_size(W) / 3 / num_directions;
        const float* wptr = W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data();

        const float* uptr = wptr;
        const float* rptr = wptr + weight_data_size_g;
        const float* nptr = wptr + weight_data_size_g * 2;
        fwrite(rptr, sizeof(float), weight_data_size_g, bp);
        fwrite(uptr, sizeof(float), weight_data_size_g, bp);
        fwrite(nptr, sizeof(float), weight_data_size_g, bp);

        if (direction_type == 2)
        {
            // advance past the forward direction and repeat for reverse
            uptr += weight_data_size_g * 3;
            rptr += weight_data_size_g * 3;
            nptr += weight_data_size_g * 3;
            fwrite(rptr, sizeof(float), weight_data_size_g, bp);
            fwrite(uptr, sizeof(float), weight_data_size_g, bp);
            fwrite(nptr, sizeof(float), weight_data_size_g, bp);
        }
    }

    // reduce U and R bias except N
    // reorder num_directions-URN-hidden to num_directions-RUN-hidden
    {
        fwrite(&quantize_tag, sizeof(int), 1, bp);

        // ONNX packs Wb and Rb back to back, 3 gates each, per direction;
        // U and R biases are summed, N keeps Wbn and Rbn separate
        int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 3 / num_directions;
        const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
        const float* wuptr = bptr;
        const float* wrptr = bptr + bias_data_size_g;
        const float* wnptr = bptr + bias_data_size_g * 2;
        const float* buptr = bptr + bias_data_size_g * 3;
        const float* brptr = bptr + bias_data_size_g * 4;
        const float* bnptr = bptr + bias_data_size_g * 5;

        for (int j = 0; j < bias_data_size_g; j++)
        {
            float vb = wrptr[j] + brptr[j];
            fwrite(&vb, sizeof(float), 1, bp);
        }
        for (int j = 0; j < bias_data_size_g; j++)
        {
            float vb = wuptr[j] + buptr[j];
            fwrite(&vb, sizeof(float), 1, bp);
        }
        fwrite(wnptr, sizeof(float), bias_data_size_g, bp);
        fwrite(bnptr, sizeof(float), bias_data_size_g, bp);

        if (direction_type == 2)
        {
            wuptr += bias_data_size_g * 6;
            wrptr += bias_data_size_g * 6;
            wnptr += bias_data_size_g * 6;
            buptr += bias_data_size_g * 6;
            brptr += bias_data_size_g * 6;
            bnptr += bias_data_size_g * 6;

            for (int j = 0; j < bias_data_size_g; j++)
            {
                float vb = wrptr[j] + brptr[j];
                fwrite(&vb, sizeof(float), 1, bp);
            }
            for (int j = 0; j < bias_data_size_g; j++)
            {
                float vb = wuptr[j] + buptr[j];
                fwrite(&vb, sizeof(float), 1, bp);
            }
            fwrite(wnptr, sizeof(float), bias_data_size_g, bp);
            fwrite(bnptr, sizeof(float), bias_data_size_g, bp);
        }
    }

    // reorder num_directions-URN-hidden-hidden to num_directions-RUN-hidden-hidden
    {
        fwrite(&quantize_tag, sizeof(int), 1, bp);

        int weight_data_size_g = get_tensor_proto_data_size(R) / 3 / num_directions;
        const float* Rptr = R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data();

        const float* uptr = Rptr;
        const float* rptr = Rptr + weight_data_size_g;
        const float* nptr = Rptr + weight_data_size_g * 2;
        fwrite(rptr, sizeof(float), weight_data_size_g, bp);
        fwrite(uptr, sizeof(float), weight_data_size_g, bp);
        fwrite(nptr, sizeof(float), weight_data_size_g, bp);

        if (direction_type == 2)
        {
            uptr += weight_data_size_g * 3;
            rptr += weight_data_size_g * 3;
            nptr += weight_data_size_g * 3;
            fwrite(rptr, sizeof(float), weight_data_size_g, bp);
            fwrite(uptr, sizeof(float), weight_data_size_g, bp);
            fwrite(nptr, sizeof(float), weight_data_size_g, bp);
        }
    }
}
else if (op == "HardSigmoid")
{
    // y = clamp(alpha * x + beta, 0, 1); 0 = alpha, 1 = beta
    float alpha = get_node_attr_f(node, "alpha", 0.2f);
    float beta = get_node_attr_f(node, "beta", 0.5f);

    fprintf(pp, " 0=%e", alpha);
    fprintf(pp, " 1=%e", beta);
}
else if (op == "HardSwish")
{
    // same alpha/beta parameterization as HardSigmoid
    float alpha = get_node_attr_f(node, "alpha", 0.2f);
    float beta = get_node_attr_f(node, "beta", 0.5f);

    fprintf(pp, " 0=%e", alpha);
    fprintf(pp, " 1=%e", beta);
}
else if (op == "ImageScaler")
{
    // per-channel y = scale * x + bias -> Scale layer with bias;
    // the single scalar scale is replicated once per channel
    std::vector<float> bias = get_node_attr_af(node, "bias");
    float scale = get_node_attr_f(node, "scale", 1.f);

    int channels = (int)bias.size();

    fprintf(pp, " 0=%d", channels);
    fprintf(pp, " 1=1");

    for (int j = 0; j < channels; j++)
    {
        fwrite(&scale, sizeof(float), 1, bp);
    }
    fwrite(&bias[0], sizeof(float), channels, bp);
}
else if (op == "InstanceNormalization")
{
    float eps = get_node_attr_f(node, "epsilon", 1e-5f);

    // discard affine-less S=1 B=0
    // only emit the affine blobs when some channel deviates from identity
    std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
    std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
    int channels = (int)affine_S.size();
    int affine = 0;
    {
        for (int j = 0; j < channels; j++)
        {
            if (affine_S[j] != 1.f || affine_B[j] != 0.f)
            {
                affine = 1;
                break;
            }
        }
    }

    // 0 = channels, 1 = eps, 2 = affine flag
    fprintf(pp, " 0=%d", channels);
    fprintf(pp, " 1=%e", eps);
    fprintf(pp, " 2=%d", affine);
    if (affine)
    {
        const onnx::TensorProto& scale = weights[node.input(1)];
        const onnx::TensorProto& B = weights[node.input(2)];

        fwrite_tensor_proto_data(scale, bp);
        fwrite_tensor_proto_data(B, bp);
    }
}
else if (op == "LeakyRelu")
{
    // 0 = negative-slope coefficient
    float alpha = get_node_attr_f(node, "alpha", 0.01f);

    fprintf(pp, " 0=%e", alpha);
}
else if (op == "Log")
{
    // UnaryOp operation type 8 = log
    int op_type = 8;
    fprintf(pp, " 0=%d", op_type);
}
else if (op == "LRN")
{
    // 0 = region (0 across-channels), 1 = local size, 2/3/4 = alpha/beta/bias
    float alpha = get_node_attr_f(node, "alpha", 1.f);
    float beta = get_node_attr_f(node, "beta", 0.5f);
    float bias = get_node_attr_f(node, "bias", 1.f);
    int size = get_node_attr_i(node, "size", 1);

    int norm_region = 0;

    fprintf(pp, " 0=%d", norm_region);
    fprintf(pp, " 1=%d", size);
    fprintf(pp, " 2=%e", alpha);
    fprintf(pp, " 3=%e", beta);
    fprintf(pp, " 4=%e", bias);
}
else if (op == "LSTM")
{
    // ONNX LSTM stores gates in I,O,F,G order; ncnn expects I,F,O,G.
    // W (input weights), B (bias) and R (recurrent weights) are each
    // chunk-reordered while streaming, per direction; the two ONNX bias
    // halves (Wb and Rb) are summed into a single bias blob.
    const onnx::TensorProto& W = weights[node.input(1)];
    const onnx::TensorProto& R = weights[node.input(2)];
    const onnx::TensorProto& B = weights[node.input(3)];

    int hidden_size = get_node_attr_i(node, "hidden_size", 0);
    std::string direction = get_node_attr_s(node, "direction");

    // 0 forward, 1 reverse, 2 bidirectional
    int direction_type = 0;
    if (direction == "forward")
    {
        direction_type = 0;
    }
    else if (direction == "reverse")
    {
        direction_type = 1;
    }
    else if (direction == "bidirectional")
    {
        direction_type = 2;
    }

    int weight_data_size = get_tensor_proto_data_size(W);

    fprintf(pp, " 0=%d", hidden_size);
    fprintf(pp, " 1=%d", weight_data_size);
    fprintf(pp, " 2=%d", direction_type);

    int num_directions = direction_type == 2 ? 2 : 1;

    // 0 tag before each blob = raw fp32
    int quantize_tag = 0;

    // reorder num_directions-IOFG-hidden-size to num_directions-IFOG-hidden-size
    {
        fwrite(&quantize_tag, sizeof(int), 1, bp);

        // per-gate chunk size within one direction
        int weight_data_size_g = get_tensor_proto_data_size(W) / 4 / num_directions;
        const float* wptr = W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data();

        const float* iptr = wptr;
        const float* optr = wptr + weight_data_size_g;
        const float* fptr = wptr + weight_data_size_g * 2;
        const float* gptr = wptr + weight_data_size_g * 3;
        fwrite(iptr, sizeof(float), weight_data_size_g, bp);
        fwrite(fptr, sizeof(float), weight_data_size_g, bp);
        fwrite(optr, sizeof(float), weight_data_size_g, bp);
        fwrite(gptr, sizeof(float), weight_data_size_g, bp);

        if (direction_type == 2)
        {
            // advance past the forward direction and repeat for reverse
            iptr += weight_data_size_g * 4;
            optr += weight_data_size_g * 4;
            fptr += weight_data_size_g * 4;
            gptr += weight_data_size_g * 4;
            fwrite(iptr, sizeof(float), weight_data_size_g, bp);
            fwrite(fptr, sizeof(float), weight_data_size_g, bp);
            fwrite(optr, sizeof(float), weight_data_size_g, bp);
            fwrite(gptr, sizeof(float), weight_data_size_g, bp);
        }
    }

    // reduce xc and hc bias
    // reorder num_directions-IOFG-hidden to num_directions-IFOG-hidden
    {
        fwrite(&quantize_tag, sizeof(int), 1, bp);

        // x* = input-side bias (Wb), h* = hidden-side bias (Rb)
        int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 4 / num_directions;
        const float* xcbptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
        const float* xiptr = xcbptr;
        const float* xoptr = xcbptr + bias_data_size_g;
        const float* xfptr = xcbptr + bias_data_size_g * 2;
        const float* xgptr = xcbptr + bias_data_size_g * 3;
        const float* hiptr = xcbptr + bias_data_size_g * 4;
        const float* hoptr = xcbptr + bias_data_size_g * 5;
        const float* hfptr = xcbptr + bias_data_size_g * 6;
        const float* hgptr = xcbptr + bias_data_size_g * 7;

        for (int j = 0; j < bias_data_size_g; j++)
        {
            float vb = xiptr[j] + hiptr[j];
            fwrite(&vb, sizeof(float), 1, bp);
        }
        for (int j = 0; j < bias_data_size_g; j++)
        {
            float vb = xfptr[j] + hfptr[j];
            fwrite(&vb, sizeof(float), 1, bp);
        }
        for (int j = 0; j < bias_data_size_g; j++)
        {
            float vb = xoptr[j] + hoptr[j];
            fwrite(&vb, sizeof(float), 1, bp);
        }
        for (int j = 0; j < bias_data_size_g; j++)
        {
            float vb = xgptr[j] + hgptr[j];
            fwrite(&vb, sizeof(float), 1, bp);
        }

        if (direction_type == 2)
        {
            xiptr += bias_data_size_g * 8;
            xoptr += bias_data_size_g * 8;
            xfptr += bias_data_size_g * 8;
            xgptr += bias_data_size_g * 8;
            hiptr += bias_data_size_g * 8;
            hoptr += bias_data_size_g * 8;
            hfptr += bias_data_size_g * 8;
            hgptr += bias_data_size_g * 8;

            for (int j = 0; j < bias_data_size_g; j++)
            {
                float vb = xiptr[j] + hiptr[j];
                fwrite(&vb, sizeof(float), 1, bp);
            }
            for (int j = 0; j < bias_data_size_g; j++)
            {
                float vb = xfptr[j] + hfptr[j];
                fwrite(&vb, sizeof(float), 1, bp);
            }
            for (int j = 0; j < bias_data_size_g; j++)
            {
                float vb = xoptr[j] + hoptr[j];
                fwrite(&vb, sizeof(float), 1, bp);
            }
            for (int j = 0; j < bias_data_size_g; j++)
            {
                float vb = xgptr[j] + hgptr[j];
                fwrite(&vb, sizeof(float), 1, bp);
            }
        }
    }

    // reorder num_directions-IOFG-hidden-hidden to num_directions-IFOG-hidden-hidden
    {
        fwrite(&quantize_tag, sizeof(int), 1, bp);

        int weight_data_size_g = get_tensor_proto_data_size(R) / 4 / num_directions;
        const float* rptr = R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data();

        const float* iptr = rptr;
        const float* optr = rptr + weight_data_size_g;
        const float* fptr = rptr + weight_data_size_g * 2;
        const float* gptr = rptr + weight_data_size_g * 3;
        fwrite(iptr, sizeof(float), weight_data_size_g, bp);
        fwrite(fptr, sizeof(float), weight_data_size_g, bp);
        fwrite(optr, sizeof(float), weight_data_size_g, bp);
        fwrite(gptr, sizeof(float), weight_data_size_g, bp);

        if (direction_type == 2)
        {
            iptr += weight_data_size_g * 4;
            optr += weight_data_size_g * 4;
            fptr += weight_data_size_g * 4;
            gptr += weight_data_size_g * 4;
            fwrite(iptr, sizeof(float), weight_data_size_g, bp);
            fwrite(fptr, sizeof(float), weight_data_size_g, bp);
            fwrite(optr, sizeof(float), weight_data_size_g, bp);
            fwrite(gptr, sizeof(float), weight_data_size_g, bp);
        }
    }
}
else if (op == "MatMul")
{
    // MatMul with a constant 2-D right-hand side becomes InnerProduct;
    // otherwise it stays a runtime matrix multiplication with no params
    if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
    {
        // InnerProduct
        const onnx::TensorProto& B = weights[node.input(1)];

        int weight_data_size = get_tensor_proto_data_size(B);

        // B is (num_input, num_output); last dim is the output size
        int num_output = B.dims(B.dims_size() - 1);
        int num_input = weight_data_size / num_output;

        // 0 = num_output, 1 = no bias, 2 = weight element count
        fprintf(pp, " 0=%d", num_output);
        fprintf(pp, " 1=0");
        fprintf(pp, " 2=%d", weight_data_size);

        // 0 tag = raw fp32 weights
        int quantize_tag = 0;
        fwrite(&quantize_tag, sizeof(int), 1, bp);

        // reorder num_input-num_output to num_output-num_input
        {
            const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();

            for (int j = 0; j < num_output; j++)
            {
                for (int k = 0; k < num_input; k++)
                {
                    float vb = bptr[k * num_output + j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
            }
        }

        // fwrite_tensor_proto_data(B, bp)
    }
    else
    {
        // default matrix multiplication
    }
}
else if (op == "Max")
{
    // BinaryOp operation type 4 = max
    int op_type = 4;
    fprintf(pp, " 0=%d", op_type);

    // fold a constant scalar operand into the layer: 1=with_scalar, 2=value
    if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0)
    {
        float b = get_node_attr_from_input_f(weights[node.input(1)]);
        fprintf(pp, " 1=1");
        fprintf(pp, " 2=%e", b);
    }
}
else if (op == "Min")
{
    // BinaryOp operation type 5 = min
    int op_type = 5;
    fprintf(pp, " 0=%d", op_type);

    if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0)
    {
        float b = get_node_attr_from_input_f(weights[node.input(1)]);
        fprintf(pp, " 1=1");
        fprintf(pp, " 2=%e", b);
    }
}
else if (op == "Mul")
{
    // BinaryOp operation type 2 = mul
    int op_type = 2;
    fprintf(pp, " 0=%d", op_type);

    if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0)
    {
        float b = get_node_attr_from_input_f(weights[node.input(1)]);
        fprintf(pp, " 1=1");
        fprintf(pp, " 2=%e", b);
    }
}
else if (op == "Neg")
{
    // UnaryOp operation type 1 = neg
    int op_type = 1;
    fprintf(pp, " 0=%d", op_type);
}
else if (op == "Normalize")
{
    // L2 normalization with a fixed identity scale of 1.0
    float eps = get_node_attr_f(node, "eps", 0.f);
    int scale_data_size = 1;

    fprintf(pp, " 1=1"); // channel_shared
    fprintf(pp, " 2=%e", eps);
    fprintf(pp, " 3=%d", scale_data_size);
    fprintf(pp, " 9=1"); // TODO hardcode pytorch style

    const float scale_data[1] = {1.f};
    fwrite(scale_data, sizeof(float), 1, bp);
}
else if (op == "Pad")
{
    std::string mode = get_node_attr_s(node, "mode");
    // NOTE(review): opset-11 Pad carries the constant value as input(2),
    // which is not read here — only the opset<=10 "value" attribute is.
    // Verify against models exported with newer opsets.
    float value = get_node_attr_f(node, "value", 0.f);

    // pads come from the attribute (opset<=10) or the constant second
    // input (opset 11+)
    std::vector<int> pads;
    if (node.input_size() == 1)
    {
        pads = get_node_attr_ai(node, "pads");
    }
    else
    {
        pads = get_node_attr_from_input_ai(weights[node.input(1)]);
    }

    // ncnn Padding type: 0 constant, 1 replicate (edge), 2 reflect
    int type = 0;
    if (mode == "constant")
    {
        type = 0;
    }
    else if (mode == "edge")
    {
        type = 1;
    }
    else if (mode == "reflect")
    {
        type = 2;
    }

    // ONNX pads layout is [begin_0..begin_{r-1}, end_0..end_{r-1}];
    // the batch axis (index 0 of each half) is always skipped
    int pad_size = (int)pads.size();
    int top = 0;
    int bottom = 0;
    int left = 0;
    int right = 0;
    int front = 0;
    int behind = 0;
    if (pad_size == 8)
    {
        //NCHW
        top = pads[2];
        bottom = pads[6];
        left = pads[3];
        right = pads[7];
        front = pads[1];
        behind = pads[5];
    }
    else if (pad_size == 6)
    {
        //NHW
        top = pads[1];
        bottom = pads[4];
        left = pads[2];
        right = pads[5];
    }
    else
    {
        //NW
        // assumes pad_size == 4 here; a 2-element pads would index out of
        // range — TODO confirm no 1-D-only exporters hit this path
        left = pads[1];
        right = pads[3];
    }

    fprintf(pp, " 0=%d", top);
    fprintf(pp, " 1=%d", bottom);
    fprintf(pp, " 2=%d", left);
    fprintf(pp, " 3=%d", right);
    fprintf(pp, " 4=%d", type);
    fprintf(pp, " 5=%e", value);
    fprintf(pp, " 7=%d", front);
    fprintf(pp, " 8=%d", behind);
}
else if (op == "Pow")
{
    // BinaryOp operation type 6 = pow
    int op_type = 6;
    fprintf(pp, " 0=%d", op_type);

    // fold a constant scalar exponent into the layer: 1=with_scalar, 2=value
    if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0)
    {
        float b = get_node_attr_from_input_f(weights[node.input(1)]);
        fprintf(pp, " 1=1");
        fprintf(pp, " 2=%e", b);
    }
}
else if (op == "PixelShuffle")
{
    // 0 = upscale factor
    int scale_factor = get_node_attr_i(node, "scale_factor", 1);
    fprintf(pp, " 0=%d", scale_factor);
}
else if (op == "PRelu")
{
    // 0 = number of slope coefficients; slope blob follows
    const onnx::TensorProto& slope = weights[node.input(1)];

    int num_slope = get_tensor_proto_data_size(slope);

    fprintf(pp, " 0=%d", num_slope);

    fwrite_tensor_proto_data(slope, bp);
}
else if (op == "Reciprocal")
{
    // UnaryOp operation type 15 = reciprocal
    int op_type = 15;
    fprintf(pp, " 0=%d", op_type);
}
else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp")
{
    // map each ONNX reduction onto its ncnn Reduction operation id;
    // -233 is a sentinel that should never survive the chain below
    int op_type = -233;
    if (op == "ReduceSum")
        op_type = 0;
    else if (op == "ReduceSumSquare")
        op_type = 2;
    else if (op == "ReduceMean")
        op_type = 3;
    else if (op == "ReduceMax")
        op_type = 4;
    else if (op == "ReduceMin")
        op_type = 5;
    else if (op == "ReduceProd")
        op_type = 6;
    else if (op == "ReduceL1")
        op_type = 7;
    else if (op == "ReduceL2")
        op_type = 8;
    else if (op == "ReduceLogSum")
        op_type = 9;
    else if (op == "ReduceLogSumExp")
        op_type = 10;
    fprintf(pp, " 0=%d", op_type);

    std::vector<int> axes = get_node_attr_ai(node, "axes");
    int keepdims = get_node_attr_i(node, "keepdims", 1);

    if (axes.size() > 0)
    {
        // if axes set, reduce according to axes
        // 1=0 disables reduce-all; -23303 introduces the axes array
        fprintf(pp, " 1=%d", 0);
        fprintf(pp, " -23303=%zu", axes.size());
        for (size_t j = 0; j < axes.size(); j++)
        {
            // axis 0 is the batch dimension which ncnn does not expose;
            // magnitudes beyond 3 exceed the ncnn blob rank
            if (axes[j] == 0 || axes[j] > 3 || axes[j] < -3)
                fprintf(stderr, "Unsupported reduction axes !\n");
            fprintf(pp, ",%d", axes[j]);
        }
    }
    else
    {
        // if axes not set, reduce all axes by default
        fprintf(pp, " 1=%d", 1);
    }
    fprintf(pp, " 4=%d", keepdims);
}
else if (op == "Reorg")
{
    // 0 = reorg stride
    int stride = get_node_attr_i(node, "stride", 1);
    fprintf(pp, " 0=%d", stride);
}
else if (op == "Reshape")
{
    std::vector<int> shape;

    // target shape comes from the attribute (opset<=4) or the constant
    // second input (opset 5+)
    if (node.input_size() == 1)
    {
        shape = get_node_attr_ai(node, "shape");
    }
    else
    {
        shape = get_node_attr_from_input_ai(weights[node.input(1)]);
    }

    // the leading batch dimension shape[0] is dropped; ncnn keys are
    // 0=w, 1=h, 2=c (innermost first)
    if (shape.size() == 1)
    {
        fprintf(pp, " 0=%d", shape[0]); // should never reach here
    }
    else if (shape.size() == 2)
    {
        fprintf(pp, " 0=%d", shape[1]);
    }
    else if (shape.size() == 3)
    {
        fprintf(pp, " 0=%d", shape[2]);
        fprintf(pp, " 1=%d", shape[1]);
    }
    else if (shape.size() == 4)
    {
        fprintf(pp, " 0=%d", shape[3]);
        fprintf(pp, " 1=%d", shape[2]);
        fprintf(pp, " 2=%d", shape[1]);
    }
    else if (shape.size() == 5)
    {
        // 5-D collapses the two innermost dims into w
        fprintf(pp, " 0=%d", shape[4] * shape[3]);
        fprintf(pp, " 1=%d", shape[2]);
        fprintf(pp, " 2=%d", shape[1]);
    }
}
else if (op == "Resize")
{
    std::string mode = get_node_attr_s(node, "mode");
    std::string align = get_node_attr_s(node, "coordinate_transformation_mode");

    std::vector<float> scales;
    std::vector<int> sizes;
    if (node.input_size() == 2)
    {
        // opset 10
        scales = get_node_attr_from_input_af(weights[node.input(1)]);
    }
    else
    {
        // opset 11+
        // input(1) is roi (ignored); scales at input(2), sizes at input(3)
        scales = get_node_attr_from_input_af(weights[node.input(2)]);
        if (node.input_size() >= 4)
        {
            sizes = get_node_attr_from_input_ai(weights[node.input(3)]);
        }
    }

    // Interp resize_type: 1 nearest, 2 bilinear, 3 bicubic
    int resize_type = 1;
    if (mode == "nearest")
    {
        resize_type = 1;
    }
    else if (mode == "linear")
    {
        resize_type = 2;
    }
    else if (mode == "cubic")
    {
        resize_type = 3;
    }

    if (scales.empty() && sizes.empty())
    {
        fprintf(stderr, "Unsupported Resize scales and sizes are all empty!\n");
    }

    // per-axis scale factors; index 0 is the batch axis and must stay 1
    float h_scale = 1.f;
    float w_scale = 1.f;
    if (scales.size() == 2)
    {
        w_scale = scales[1];
    }
    else if (scales.size() == 3)
    {
        h_scale = scales[1];
        w_scale = scales[2];
    }
    else if (scales.size() == 4)
    {
        h_scale = scales[2];
        w_scale = scales[3];

        // a channel-axis scale cannot be expressed by Interp
        if (scales[1] != 1.f)
            fprintf(stderr, "Unsupported Resize scales !\n");
    }

    // explicit output size takes the same axis layout as scales
    int output_height = 0;
    int output_width = 0;
    if (sizes.size() == 2)
    {
        output_width = sizes[1];
    }
    else if (sizes.size() == 3)
    {
        output_height = sizes[1];
        output_width = sizes[2];
    }
    else if (sizes.size() == 4)
    {
        output_height = sizes[2];
        output_width = sizes[3];
    }

    int align_corner = 0;
    if (align == "align_corners")
    {
        align_corner = 1;
    }

    fprintf(pp, " 0=%d", resize_type);
    fprintf(pp, " 1=%e", h_scale);
    fprintf(pp, " 2=%e", w_scale);
    fprintf(pp, " 3=%d", output_height);
    fprintf(pp, " 4=%d", output_width);
    fprintf(pp, " 6=%d", align_corner);
}
else if (op == "RNN")
{
    // Vanilla RNN has a single gate, so unlike GRU/LSTM no gate reorder
    // is needed — only the Wb+Rb bias reduction per direction.
    const onnx::TensorProto& W = weights[node.input(1)];
    const onnx::TensorProto& R = weights[node.input(2)];
    const onnx::TensorProto& B = weights[node.input(3)];

    int hidden_size = get_node_attr_i(node, "hidden_size", 0);
    std::string direction = get_node_attr_s(node, "direction");

    // 0 forward, 1 reverse, 2 bidirectional
    int direction_type = 0;
    if (direction == "forward")
    {
        direction_type = 0;
    }
    else if (direction == "reverse")
    {
        direction_type = 1;
    }
    else if (direction == "bidirectional")
    {
        direction_type = 2;
    }

    int weight_data_size = get_tensor_proto_data_size(W);

    fprintf(pp, " 0=%d", hidden_size);
    fprintf(pp, " 1=%d", weight_data_size);
    fprintf(pp, " 2=%d", direction_type);

    int num_directions = direction_type == 2 ? 2 : 1;

    // 0 tag before each blob = raw fp32
    int quantize_tag = 0;

    fwrite(&quantize_tag, sizeof(int), 1, bp);
    fwrite_tensor_proto_data(W, bp);

    // reduce xc and hc bias
    {
        fwrite(&quantize_tag, sizeof(int), 1, bp);

        // sum the input-side (Wb) and hidden-side (Rb) bias halves
        int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / num_directions;
        const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
        const float* xiptr = bptr;
        const float* hiptr = bptr + bias_data_size_g;

        for (int j = 0; j < bias_data_size_g; j++)
        {
            float vb = xiptr[j] + hiptr[j];
            fwrite(&vb, sizeof(float), 1, bp);
        }

        if (direction_type == 2)
        {
            // advance to the reverse-direction bias pair
            xiptr += bias_data_size_g * 2;
            hiptr += bias_data_size_g * 2;

            for (int j = 0; j < bias_data_size_g; j++)
            {
                float vb = xiptr[j] + hiptr[j];
                fwrite(&vb, sizeof(float), 1, bp);
            }
        }
    }

    fwrite(&quantize_tag, sizeof(int), 1, bp);
    fwrite_tensor_proto_data(R, bp);
}
        else if (op == "ShuffleChannel")
        {
            // 0 = group count, 1 = whether to apply the inverse (un-shuffle) permutation.
            int group = get_node_attr_i(node, "group", 1);
            int reverse = get_node_attr_i(node, "reverse", 0);
            fprintf(pp, " 0=%d", group);
            fprintf(pp, " 1=%d", reverse);
        }
        else if (op == "Sigmoid")
        {
            // ncnn Sigmoid takes no parameters.
        }
        else if (op == "Sin")
        {
            // Unary-op selector written to param 0 (9 = sin; cf. Sqrt=5, Tan=11,
            // Tanh=16 in the branches below).
            int op_type = 9;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "SkipLayerNormalization")
        {
            // W = layernorm gamma, B = layernorm beta, B2 = the bias added to the
            // skip connection (inputs 2/3/4 of the contrib op).
            const onnx::TensorProto& W = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];
            const onnx::TensorProto& B2 = weights[node.input(4)];

            // Param 0 = normalized channel count, taken from beta's element count.
            fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));

            // Each blob is prefixed with a 0 tag (raw fp32 storage).
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(W, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B2, bp);
        }
        else if (op == "Slice")
        {
            // Slice parameters come either from attributes (opset < 10) or from
            // constant input tensors (opset >= 10).
            std::vector<int> starts;
            std::vector<int> ends;
            std::vector<int> axes;
            std::vector<int> steps;
            if (node.input_size() == 1)
            {
                starts = get_node_attr_ai(node, "starts");
                ends = get_node_attr_ai(node, "ends");
                axes = get_node_attr_ai(node, "axes");
                steps = get_node_attr_ai(node, "steps"); // TODO
            }
            else
            {
                starts = get_node_attr_from_input_ai(weights[node.input(1)]);
                ends = get_node_attr_from_input_ai(weights[node.input(2)]);
                if (node.input_size() >= 4)
                    axes = get_node_attr_from_input_ai(weights[node.input(3)]);
                if (node.input_size() >= 5)
                    steps = get_node_attr_from_input_ai(weights[node.input(4)]);
            }

            // assert step == 1
            // NOTE(review): non-unit steps are only warned about, not rejected;
            // the emitted layer silently behaves as step 1.
            for (int i = 0; i < (int)steps.size(); i++)
            {
                if (steps[i] != 1)
                    fprintf(stderr, "Unsupported slice step !\n");
            }

            // filter out N-dim axis
            // ncnn blobs have no batch dimension, so a slice along axis 0 is dropped.
            // NOTE(review): only a literal 0 is matched — a negative axis that
            // resolves to the batch dim would slip through; and only the first
            // occurrence is removed (break).
            if (!axes.empty())
            {
                for (int i = 0; i < (int)axes.size(); i++)
                {
                    int axis = axes[i];
                    if (axis == 0)
                    {
                        starts.erase(starts.begin() + i);
                        ends.erase(ends.begin() + i);
                        axes.erase(axes.begin() + i);
                        break;
                    }
                }
            }

            // Array params: -23309 = starts, -23310 = ends, -23311 = axes.
            fprintf(pp, " -23309=%d", (int)starts.size());
            for (int i = 0; i < (int)starts.size(); i++)
            {
                fprintf(pp, ",%d", starts[i]);
            }
            fprintf(pp, " -23310=%d", (int)ends.size());
            for (int i = 0; i < (int)ends.size(); i++)
            {
                fprintf(pp, ",%d", ends[i]);
            }
            if (!axes.empty())
            {
                fprintf(pp, " -23311=%d", (int)axes.size());
                for (int i = 0; i < (int)axes.size(); i++)
                {
                    int axis = axes[i];
                    if (axis == 0 || axis > 3 || axis < -3)
                        fprintf(stderr, "Unsupported slice axes !\n");

                    if (axis > 0)
                        axis = axis - 1; // -1 for skip N-dim

                    fprintf(pp, ",%d", axis);
                }
            }
        }
        else if (op == "Softmax")
        {
            // axis - 1: shift from ONNX (batch included) to ncnn (batch omitted).
            // Param 1=1 selects the fixed-axis interpretation.
            int axis = get_node_attr_i(node, "axis", 1);
            fprintf(pp, " 0=%d", axis - 1);
            fprintf(pp, " 1=1");
        }
        else if (op == "Split")
        {
            int axis = get_node_attr_i(node, "axis", 0);
            std::vector<int> split = get_node_attr_ai(node, "split");
            // Splitting the batch axis (or a negative axis) is not supported;
            // NOTE(review): warning only — conversion still proceeds.
            if (axis < 1)
                fprintf(stderr, "Unsupported split axis !\n");

            // -23300 = per-output slice sizes; -233 means "infer remaining size".
            fprintf(pp, " -23300=%d", output_size);
            if (split.empty())
            {
                // No explicit sizes: let ncnn split evenly / infer every piece.
                for (int i = 0; i < output_size; i++)
                {
                    fprintf(pp, ",-233");
                }
            }
            else
            {
                // Emit explicit sizes for all but the last output, which is inferred.
                for (size_t i = 0; i < split.size() - 1; i++)
                {
                    fprintf(pp, ",%d", split[i]);
                }
                fprintf(pp, ",-233");
            }
            fprintf(pp, " 1=%d", axis - 1); // -1 for skip N-dim
        }
        else if (op == "Sqrt")
        {
            // Unary-op selector: 5 = sqrt.
            int op_type = 5;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Squeeze")
        {
            std::vector<int> axes = get_node_attr_ai(node, "axes");

            if (axes.empty())
            {
                // No axes given: squeeze all of w (0), h (1) and c (2).
                fprintf(pp, " 0=1");
                fprintf(pp, " 1=1");
                fprintf(pp, " 2=1");
            }
            else
            {
                // -23303 = explicit axis list; batch axis and |axis| > 3 unsupported
                // (warning only — the axis is still written out unchanged).
                fprintf(pp, " -23303=%zu", axes.size());
                for (int i = 0; i < (int)axes.size(); i++)
                {
                    if (axes[i] == 0 || axes[i] > 3 || axes[i] < -3)
                        fprintf(stderr, "Unsupported squeeze axes !\n");
                    fprintf(pp, ",%d", axes[i]);
                }
            }
        }
        else if (op == "Sub")
        {
            // Binary-op selector: 1 = subtraction.
            int op_type = 1;
            fprintf(pp, " 0=%d", op_type);

            // Scalar second operand stored as a 0-dim weight tensor: fold it into
            // the layer (1=1 enables scalar mode, 2 = the scalar value).
            if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 0)
            {
                float b = get_node_attr_from_input_f(weights[node.input(1)]);
                fprintf(pp, " 1=1");
                fprintf(pp, " 2=%e", b);
            }
        }
        else if (op == "Sum")
        {
            // Eltwise operation selector: 1 = sum.
            int op_type = 1;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Swish")
        {
            // ncnn Swish takes no parameters.
        }
        else if (op == "Tan")
        {
            // Unary-op selector: 11 = tan.
            int op_type = 11;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Tanh")
        {
            // Unary-op selector: 16 = tanh.
            int op_type = 16;
            fprintf(pp, " 0=%d", op_type);
        }
        else if (op == "Transpose")
        {
            // Map an ONNX permutation (batch dim included) onto one of ncnn's fixed
            // Permute order codes (batch dim dropped).  Param 0 = order code.
            std::vector<int> perm = get_node_attr_ai(node, "perm");

            if (perm.size() == 3)
            {
                // 3-d input: first two cases assume perm[0] is the untouched batch dim.
                if (perm[1] == 1 && perm[2] == 2)
                    fprintf(pp, " 0=0"); // w h
                else if (perm[1] == 2 && perm[2] == 1)
                    fprintf(pp, " 0=1"); // h w
                else if (perm[0] == 1 && perm[1] == 0 && perm[2] == 2)
                    fprintf(pp, " 0=0"); // w h
                else if (perm[0] == 2 && perm[1] == 0 && perm[2] == 1)
                    fprintf(pp, " 0=1"); // h w
            }
            else if (perm.size() == 4)
            {
                // 4-d input (N,C,H,W): perm[0] assumed to be the batch dim.
                if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3)
                    fprintf(pp, " 0=0"); // w h c
                else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 2)
                    fprintf(pp, " 0=1"); // h w c
                else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3)
                    fprintf(pp, " 0=2"); // w c h
                else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 1)
                    fprintf(pp, " 0=3"); // c w h
                else if (perm[1] == 3 && perm[2] == 1 && perm[3] == 2)
                    fprintf(pp, " 0=4"); // h c w
                else if (perm[1] == 3 && perm[2] == 2 && perm[3] == 1)
                    fprintf(pp, " 0=5"); // c h w
                // NOTE(review): unrecognized 4-d perms are silently ignored
                // (no param 0 written), unlike the 5-d case below which warns.
            }
            else if (perm.size() == 5)
            {
                // 5-d input: the two innermost dims are treated as a fused "wx".
                if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3 && perm[4] == 4)
                    fprintf(pp, " 0=0"); // wx h c
                else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 4 && perm[4] == 2)
                    fprintf(pp, " 0=1"); // h wx c
                else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3 && perm[4] == 4)
                    fprintf(pp, " 0=2"); // wx c h
                else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 4 && perm[4] == 1)
                    fprintf(pp, " 0=3"); // c wx h
                else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 1 && perm[4] == 2)
                    fprintf(pp, " 0=4"); // h c wx
                else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 2 && perm[4] == 1)
                    fprintf(pp, " 0=5"); // c h wx
                else
                    fprintf(stderr, "Unsupported transpose type !\n");
            }
        }
        else if (op == "Upsample")
        {
            std::string mode = get_node_attr_s(node, "mode");
            std::string align = get_node_attr_s(node, "coordinate_transformation_mode");

            // Scales come from the attribute (opset < 9) or from a constant input
            // tensor (opset >= 9).
            std::vector<float> scales;

            if (node.input_size() == 1)
            {
                scales = get_node_attr_af(node, "scales");
            }
            else
            {
                scales = get_node_attr_from_input_af(weights[node.input(1)]);
            }

            // ncnn Interp resize_type: 1 = nearest, 2 = bilinear.
            // NOTE(review): "trilinear" only warns and falls through with
            // resize_type still 1 (nearest).
            int resize_type = 1;
            if (mode == "nearest")
            {
                resize_type = 1;
            }
            else if (mode == "bilinear" || mode == "linear")
            {
                resize_type = 2;
            }
            else if (mode == "trilinear")
            {
                fprintf(stderr, "Unsupported Upsample mode !\n");
            }

            // Interpret scales by rank: [N,W], [N,H,W] or [N,C,H,W];
            // channel scaling (scales[1] in the 4-d case) must be 1.
            float h_scale = 1.f;
            float w_scale = 1.f;
            if (scales.size() == 2)
            {
                w_scale = scales[1];
            }
            else if (scales.size() == 3)
            {
                h_scale = scales[1];
                w_scale = scales[2];
            }
            else if (scales.size() == 4)
            {
                h_scale = scales[2];
                w_scale = scales[3];

                if (scales[1] != 1.f)
                    fprintf(stderr, "Unsupported Upsample scales !\n");
            }
            else
            {
                fprintf(stderr, "Unsupported Upsample scales !\n");
            }

            int align_corner = 0;
            if (align == "align_corners")
            {
                align_corner = 1;
            }

            // Param fields: 0=resize type, 1=height scale, 2=width scale,
            // 6=align-corner flag.
            fprintf(pp, " 0=%d", resize_type);
            fprintf(pp, " 1=%e", h_scale);
            fprintf(pp, " 2=%e", w_scale);
            fprintf(pp, " 6=%d", align_corner);
        }
        else if (op == "Unsqueeze")
        {
            std::vector<int> axes = get_node_attr_ai(node, "axes");

            // -23303 = axis list; batch axis and |axis| > 4 unsupported
            // (warning only — the axis is still written out unchanged).
            fprintf(pp, " -23303=%zu", axes.size());
            for (int i = 0; i < (int)axes.size(); i++)
            {
                if (axes[i] == 0 || axes[i] > 4 || axes[i] < -4)
                    fprintf(stderr, "Unsupported unsqueeze axes !\n");
                fprintf(pp, ",%d", axes[i]);
            }
        }
        else
        {
            // TODO op specific param
            // Unrecognized op: dump its attributes to stderr as a conversion aid.
            // Numeric type codes follow onnx::AttributeProto::AttributeType:
            // 1 = FLOAT, 2 = INT, 3 = STRING.
            for (int j = 0; j < node.attribute_size(); j++)
            {
                const onnx::AttributeProto& attr = node.attribute(j);
                if (attr.type() == 1)
                {
                    fprintf(stderr, "  # %s=%g\n", attr.name().c_str(), attr.f());
                }
                else if (attr.type() == 2)
                {
                    fprintf(stderr, "  # %s=%lld\n", attr.name().c_str(), (long long)attr.i());
                }
                else if (attr.type() == 3)
                {
                    fprintf(stderr, "  # %s=%s\n", attr.name().c_str(), attr.s().c_str());
                }
                else
                {
                    // Other attribute kinds: print only name and raw type code.
                    fprintf(stderr, "  # %s %d\n", attr.name().c_str(), attr.type());
                }
            }
        }

        // Terminate this layer's line in the .param file.
        fprintf(pp, "\n");

        // For every output consumed by more than one downstream node, insert an
        // ncnn Split layer so each consumer gets its own blob copy
        // (blob names: <output>_splitncnn_<k>).
        for (int j = 0; j < output_size; j++)
        {
            const std::string& output_name = node.output(j);
            if (node_reference.find(output_name) != node_reference.end())
            {
                int refcount = node_reference[output_name];
                if (refcount > 1)
                {
                    // NOTE(review): fixed 256-byte buffer with sprintf — safe for
                    // the "splitncnn_%d" pattern, but snprintf would be more robust.
                    char splitname[256];
                    sprintf(splitname, "splitncnn_%d", internal_split);
                    fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);

                    fprintf(pp, " %s", output_name.c_str());

                    for (int k = 0; k < refcount; k++)
                    {
                        fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k);
                    }
                    fprintf(pp, "\n");

                    // Global counter keeps generated Split layer names unique.
                    internal_split++;
                }
            }
        }
    }

    // All nodes emitted: close the .param and .bin output files.
    fclose(pp);
    fclose(bp);

    return 0;
}
4895