1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #include "onnx.pb.h"
16
17 #include <algorithm>
18 #include <float.h>
19 #include <fstream>
20 #include <google/protobuf/io/coded_stream.h>
21 #include <google/protobuf/io/zero_copy_stream_impl.h>
22 #include <google/protobuf/message.h>
23 #include <google/protobuf/text_format.h>
24 #include <iostream>
25 #include <limits.h>
26 #include <limits>
27 #include <set>
28 #include <stdio.h>
29
read_proto_from_binary(const char * filepath,onnx::ModelProto * message)30 static bool read_proto_from_binary(const char* filepath, onnx::ModelProto* message)
31 {
32 std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
33 if (!fs.is_open())
34 {
35 fprintf(stderr, "open failed %s\n", filepath);
36 return false;
37 }
38
39 google::protobuf::io::IstreamInputStream input(&fs);
40 google::protobuf::io::CodedInputStream codedstr(&input);
41
42 #if GOOGLE_PROTOBUF_VERSION >= 3011000
43 codedstr.SetTotalBytesLimit(INT_MAX);
44 #else
45 codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);
46 #endif
47
48 bool success = message->ParseFromCodedStream(&codedstr);
49
50 fs.close();
51
52 return success;
53 }
54
get_node_attr_ai(const onnx::NodeProto & node,const char * key)55 static std::vector<int> get_node_attr_ai(const onnx::NodeProto& node, const char* key)
56 {
57 std::vector<int> v;
58
59 for (int i = 0; i < node.attribute_size(); i++)
60 {
61 const onnx::AttributeProto& attr = node.attribute(i);
62 if (attr.name() == key)
63 {
64 v.resize(attr.ints_size());
65 for (int j = 0; j < attr.ints_size(); j++)
66 {
67 v[j] = std::max(std::min(attr.ints(j), (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
68 }
69
70 break;
71 }
72 }
73
74 return v;
75 }
76
get_node_attr_af(const onnx::NodeProto & node,const char * key)77 static std::vector<float> get_node_attr_af(const onnx::NodeProto& node, const char* key)
78 {
79 std::vector<float> v;
80
81 for (int i = 0; i < node.attribute_size(); i++)
82 {
83 const onnx::AttributeProto& attr = node.attribute(i);
84 if (attr.name() == key)
85 {
86 v.resize(attr.floats_size());
87 for (int j = 0; j < attr.floats_size(); j++)
88 {
89 v[j] = attr.floats(j);
90 }
91
92 break;
93 }
94 }
95
96 return v;
97 }
98
get_node_attr_i(const onnx::NodeProto & node,const char * key,int def=0)99 static int get_node_attr_i(const onnx::NodeProto& node, const char* key, int def = 0)
100 {
101 for (int i = 0; i < node.attribute_size(); i++)
102 {
103 const onnx::AttributeProto& attr = node.attribute(i);
104 if (attr.name() == key)
105 {
106 return std::max(std::min(attr.i(), (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
107 }
108 }
109
110 return def;
111 }
112
get_node_attr_f(const onnx::NodeProto & node,const char * key,float def=0.f)113 static float get_node_attr_f(const onnx::NodeProto& node, const char* key, float def = 0.f)
114 {
115 for (int i = 0; i < node.attribute_size(); i++)
116 {
117 const onnx::AttributeProto& attr = node.attribute(i);
118 if (attr.name() == key)
119 {
120 return attr.f();
121 }
122 }
123
124 return def;
125 }
126
get_node_attr_s(const onnx::NodeProto & node,const char * key,const std::string & def=std::string ())127 static std::string get_node_attr_s(const onnx::NodeProto& node, const char* key, const std::string& def = std::string())
128 {
129 for (int i = 0; i < node.attribute_size(); i++)
130 {
131 const onnx::AttributeProto& attr = node.attribute(i);
132 if (attr.name() == key)
133 {
134 return attr.s();
135 }
136 }
137
138 return def;
139 }
140
get_node_attr_tensor(const onnx::NodeProto & node,const char * key)141 static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char* key)
142 {
143 for (int i = 0; i < node.attribute_size(); i++)
144 {
145 const onnx::AttributeProto& attr = node.attribute(i);
146 if (attr.name() == key)
147 {
148 return attr.t();
149 }
150 }
151
152 return onnx::TensorProto();
153 }
154
get_node_attr_from_input_f(const onnx::TensorProto & tp)155 static float get_node_attr_from_input_f(const onnx::TensorProto& tp)
156 {
157 float v = 0.f;
158
159 // float
160 if (tp.data_type() == 1)
161 {
162 const float* shape_data = 0;
163 if (tp.has_raw_data())
164 {
165 shape_data = (const float*)tp.raw_data().data();
166 }
167 else
168 {
169 shape_data = tp.float_data().data();
170 }
171 v = shape_data[0];
172 }
173 // double
174 else if (tp.data_type() == 11)
175 {
176 const double* shape_data = 0;
177 if (tp.has_raw_data())
178 {
179 shape_data = (const double*)tp.raw_data().data();
180 }
181 else
182 {
183 shape_data = tp.double_data().data();
184 }
185 v = shape_data[0];
186 }
187 // int64
188 else if (tp.data_type() == 7)
189 {
190 const int64_t* shape_data = 0;
191 if (tp.has_raw_data())
192 {
193 shape_data = (const int64_t*)tp.raw_data().data();
194 }
195 else
196 {
197 shape_data = tp.int64_data().data();
198 }
199 v = std::max(std::min(shape_data[0], (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
200 }
201 // int32
202 else if (tp.data_type() == 6)
203 {
204 const int32_t* shape_data = 0;
205 if (tp.has_raw_data())
206 {
207 shape_data = (const int32_t*)tp.raw_data().data();
208 }
209 else
210 {
211 shape_data = tp.int32_data().data();
212 }
213 v = shape_data[0];
214 }
215 else
216 {
217 fprintf(stderr, "Unknown data type %d\n", tp.data_type());
218 abort();
219 }
220
221 return v;
222 }
223
get_node_attr_from_input_ai(const onnx::TensorProto & tp)224 static std::vector<int> get_node_attr_from_input_ai(const onnx::TensorProto& tp)
225 {
226 int size = 0;
227
228 std::vector<int> v;
229
230 // int64
231 if (tp.data_type() == 7)
232 {
233 const int64_t* shape_data = 0;
234 if (tp.has_raw_data())
235 {
236 shape_data = (const int64_t*)tp.raw_data().data();
237 size = (int)(tp.raw_data().size() / 8);
238 }
239 else
240 {
241 shape_data = tp.int64_data().data();
242 size = tp.int64_data_size();
243 }
244 for (int j = 0; j < size; j++)
245 {
246 int vi = std::max(std::min(shape_data[j], (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
247 v.push_back(vi);
248 }
249 }
250 // int32
251 else if (tp.data_type() == 6)
252 {
253 const int32_t* shape_data = 0;
254 if (tp.has_raw_data())
255 {
256 shape_data = (const int32_t*)tp.raw_data().data();
257 size = (int)(tp.raw_data().size() / 4);
258 }
259 else
260 {
261 shape_data = tp.int32_data().data();
262 size = tp.int32_data_size();
263 }
264 for (int j = 0; j < size; j++)
265 {
266 v.push_back(shape_data[j]);
267 }
268 }
269 else
270 {
271 fprintf(stderr, "Unknown data type %d\n", tp.data_type());
272 }
273
274 return v;
275 }
276
get_node_attr_from_input_af(const onnx::TensorProto & tp)277 static std::vector<float> get_node_attr_from_input_af(const onnx::TensorProto& tp)
278 {
279 int size = 0;
280
281 std::vector<float> v;
282
283 // float
284 if (tp.data_type() == 1)
285 {
286 const float* shape_data = 0;
287 if (tp.has_raw_data())
288 {
289 shape_data = (const float*)tp.raw_data().data();
290 size = (int)(tp.raw_data().size() / 4);
291 }
292 else
293 {
294 shape_data = tp.float_data().data();
295 size = tp.float_data_size();
296 }
297 for (int j = 0; j < size; j++)
298 {
299 v.push_back(shape_data[j]);
300 }
301 }
302 // double
303 else if (tp.data_type() == 11)
304 {
305 const double* shape_data = 0;
306 if (tp.has_raw_data())
307 {
308 shape_data = (const double*)tp.raw_data().data();
309 size = (int)(tp.raw_data().size() / 8);
310 }
311 else
312 {
313 shape_data = tp.double_data().data();
314 size = tp.double_data_size();
315 }
316 for (int j = 0; j < size; j++)
317 {
318 v.push_back((float)shape_data[j]);
319 }
320 }
321 else
322 {
323 fprintf(stderr, "Unknown data type %d\n", tp.data_type());
324 }
325
326 return v;
327 }
328
get_tensor_proto_data_size(const onnx::TensorProto & tp)329 static int get_tensor_proto_data_size(const onnx::TensorProto& tp)
330 {
331 if (tp.has_raw_data())
332 {
333 const std::string& raw_data = tp.raw_data();
334 int size = (int)raw_data.size() / 4;
335 return size;
336 }
337 else if (tp.data_type() == 1)
338 {
339 return tp.float_data_size();
340 }
341
342 return 0;
343 }
344
fwrite_tensor_proto_data(const onnx::TensorProto & tp,FILE * bp)345 static void fwrite_tensor_proto_data(const onnx::TensorProto& tp, FILE* bp)
346 {
347 int size = get_tensor_proto_data_size(tp);
348
349 if (tp.has_raw_data())
350 {
351 const std::string& raw_data = tp.raw_data();
352 fwrite(raw_data.data(), sizeof(float), size, bp);
353 }
354 else if (tp.data_type() == 1)
355 {
356 fwrite(tp.float_data().data(), sizeof(float), size, bp);
357 }
358 }
359
fuse_weight_reshape(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)360 static void fuse_weight_reshape(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
361 {
362 int node_count = mutable_graph->node_size();
363 for (int i = 0; i < node_count; i++)
364 {
365 onnx::NodeProto* node = mutable_graph->mutable_node(i);
366
367 // weight <= Reshape(weight)
368 if (node->op_type() == "Reshape")
369 {
370 // check weight
371 if (weights.find(node->input(0)) == weights.end())
372 continue;
373
374 weights[node->output(0)] = weights[node->input(0)];
375
376 // set weight shape directly
377 std::vector<int> shape;
378 if (node->input_size() == 1)
379 {
380 shape = get_node_attr_ai(*node, "shape");
381 }
382 else if (node->input_size() == 2)
383 {
384 // opset 5
385 shape = get_node_attr_from_input_ai(weights[node->input(1)]);
386 }
387
388 weights[node->output(0)].clear_dims();
389 for (int j = 0; j < shape.size(); j++)
390 {
391 weights[node->output(0)].add_dims(shape[j]);
392 }
393
394 // reduce
395 node->set_op_type("noop_reducedncnn");
396
397 node_reference[node->input(0)] -= 1;
398 if (node->input_size() == 2)
399 {
400 node_reference[node->input(1)] -= 1;
401 }
402
403 reduced_node_count += 1;
404 i += 1;
405 }
406 }
407 }
408
fuse_weight_transpose(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)409 static void fuse_weight_transpose(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
410 {
411 int node_count = mutable_graph->node_size();
412 for (int i = 0; i < node_count; i++)
413 {
414 onnx::NodeProto* node = mutable_graph->mutable_node(i);
415
416 // weight <= Transpose(weight)
417 if (node->op_type() == "Transpose")
418 {
419 // check weight
420 if (weights.find(node->input(0)) == weights.end())
421 continue;
422
423 if (weights[node->input(0)].dims_size() != 2)
424 continue;
425
426 // perm = (1, 0)
427 std::vector<int> perm = get_node_attr_ai(*node, "perm");
428 if (perm.size() != 2)
429 continue;
430 if (perm[0] != 1 || perm[1] != 0)
431 continue;
432
433 weights[node->output(0)] = weights[node->input(0)];
434
435 // permute weight
436 {
437 onnx::TensorProto& B = weights[node->output(0)];
438
439 const int h = B.dims(0);
440 const int w = B.dims(1);
441
442 std::vector<float> permuted_data;
443 permuted_data.reserve((size_t)h * w);
444 const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
445
446 for (int j = 0; j < w; j++)
447 {
448 for (int k = 0; k < h; k++)
449 {
450 float vb = bptr[k * w + j];
451 permuted_data.push_back(vb);
452 }
453 }
454
455 B.set_dims(0, w);
456 B.set_dims(1, h);
457
458 if (B.has_raw_data())
459 {
460 B.set_raw_data(permuted_data.data(), permuted_data.size() * sizeof(float));
461 }
462 else
463 {
464 for (int j = 0; j < (int)permuted_data.size(); j++)
465 B.set_float_data(j, permuted_data[j]);
466 }
467 }
468
469 // reduce
470 node->set_op_type("noop_reducedncnn");
471
472 node_reference[node->input(0)] -= 1;
473
474 reduced_node_count += 1;
475 i += 1;
476 }
477 }
478 }
479
fuse_shufflechannel(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)480 static void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
481 {
482 int node_count = mutable_graph->node_size();
483 for (int i = 0; i < node_count; i++)
484 {
485 onnx::NodeProto* node = mutable_graph->mutable_node(i);
486
487 // ShuffleChannel <= Reshape - Transpose - Reshape
488 // ShuffleChannel <= Reshape - Transpose - Constant - Reshape
489 if (node->op_type() == "Reshape")
490 {
491 if (node_reference[node->output(0)] != 1)
492 continue;
493
494 std::vector<int> shape;
495 if (node->input_size() == 1)
496 {
497 shape = get_node_attr_ai(*node, "shape");
498 }
499 else
500 {
501 // skip weight reshape
502 if (weights.find(node->input(1)) == weights.end())
503 continue;
504
505 shape = get_node_attr_from_input_ai(weights[node->input(1)]);
506 }
507
508 // 1 groups channels_per_group, height, width
509 // reverse style = channels_per_group, groups, height * width
510 if (shape.size() != 5 && shape.size() != 3)
511 continue;
512
513 if (shape.size() == 5 && shape[0] != 1)
514 continue;
515
516 if (i + 2 >= node_count)
517 continue;
518
519 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
520 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
521
522 if (node3->op_type() == "Constant")
523 {
524 if (i + 3 >= node_count)
525 continue;
526
527 node3 = mutable_graph->mutable_node(i + 3);
528 }
529
530 if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
531 continue;
532
533 if (node_reference[node2->output(0)] != 1)
534 continue;
535
536 // 0 2 1 3 4
537 // reverse style = 1 0 2
538 std::vector<int> perm = get_node_attr_ai(*node2, "perm");
539 if (perm.size() != 5 && perm.size() != 3)
540 continue;
541
542 if (perm.size() == 5 && (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3 || perm[4] != 4))
543 continue;
544
545 if (perm.size() == 3 && (perm[0] != 1 || perm[1] != 0 || perm[2] != 2))
546 continue;
547
548 std::vector<int> shape3;
549 if (node3->input_size() == 1)
550 {
551 shape3 = get_node_attr_ai(*node3, "shape");
552 }
553 else
554 {
555 // skip weight reshape
556 if (weights.find(node3->input(1)) == weights.end())
557 continue;
558
559 shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
560 }
561
562 // 1, -1, height, width
563 // reverse style = group, -1, channels_per_group, height, width
564 if (shape3.size() != 4 && shape3.size() != 5)
565 continue;
566
567 if (shape3.size() == 4 && (shape3[0] != 1 || (shape3[1] != -1 && shape3[1] != shape[1] * shape[2])))
568 continue;
569
570 if (shape3.size() == 5 && (shape3[0] != shape[1] || shape3[2] != shape[0] || shape3[3] * shape3[4] != shape[2]))
571 continue;
572
573 // reduce
574 node->set_op_type("noop_reducedncnn");
575 node2->set_op_type("noop_reducedncnn");
576
577 if (node->input_size() == 2)
578 {
579 node_reference[node->input(1)] -= 1;
580 }
581 node_reference[node->output(0)] -= 1;
582 node_reference[node2->output(0)] -= 1;
583 if (node3->input_size() == 2)
584 {
585 node_reference[node3->input(1)] -= 1;
586 }
587
588 blob_names.erase(node->output(0));
589 blob_names.erase(node2->output(0));
590
591 node3->set_op_type("ShuffleChannel");
592 node3->set_input(0, node->input(0));
593
594 onnx::AttributeProto* attr_group = node3->add_attribute();
595 attr_group->set_name("group");
596 attr_group->set_i(shape[1]);
597
598 onnx::AttributeProto* attr_reverse = node3->add_attribute();
599 attr_reverse->set_name("reverse");
600 attr_reverse->set_i(shape.size() == 3);
601
602 reduced_node_count += 2;
603 i += 2;
604 }
605 }
606 }
607
fuse_shufflechannel_split(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)608 static void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
609 {
610 int node_count = mutable_graph->node_size();
611 for (int i = 0; i < node_count; i++)
612 {
613 onnx::NodeProto* node = mutable_graph->mutable_node(i);
614
615 // Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1)
616 if (node->op_type() == "ShuffleChannel")
617 {
618 // reverse = 1
619 int reverse = get_node_attr_i(*node, "reverse");
620 if (reverse != 1)
621 continue;
622
623 if (i + 2 >= node_count)
624 continue;
625
626 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
627 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
628
629 if (node2->op_type() != "Gather" || node3->op_type() != "Gather")
630 continue;
631
632 if (node2->input(0) != node->output(0) || node3->input(0) != node->output(0))
633 continue;
634
635 // axis = 0
636 int gather2_axis = get_node_attr_i(*node2, "axis");
637 if (gather2_axis != 0)
638 continue;
639
640 // indices = 0
641 if (weights.find(node2->input(1)) == weights.end())
642 continue;
643
644 std::vector<int> gather2_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
645 if (gather2_indices.size() != 1 || gather2_indices[0] != 0)
646 continue;
647
648 // axis = 0
649 int gather3_axis = get_node_attr_i(*node3, "axis");
650 if (gather3_axis != 0)
651 continue;
652
653 // indices = 1
654 if (weights.find(node3->input(1)) == weights.end())
655 continue;
656
657 std::vector<int> gather3_indices = get_node_attr_from_input_ai(weights[node3->input(1)]);
658 if (gather3_indices.size() != 1 || gather3_indices[0] != 1)
659 continue;
660
661 // reduce
662 node2->set_op_type("noop_reducedncnn");
663
664 node_reference[node->output(0)] -= 2;
665 node_reference[node2->input(1)] -= 1;
666 node_reference[node3->input(1)] -= 1;
667
668 node3->set_op_type("Split");
669 node3->clear_input();
670 node3->add_input(node->output(0));
671 node3->add_output(node3->output(0));
672 node3->set_output(0, node2->output(0));
673
674 node3->clear_attribute();
675 onnx::AttributeProto* attr_axis = node3->add_attribute();
676 attr_axis->set_name("axis");
677 attr_axis->set_i(1);
678
679 reduced_node_count += 1;
680 i += 1;
681 }
682 }
683 }
684
fuse_hardswish(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)685 static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
686 {
687 int node_count = mutable_graph->node_size();
688 for (int i = 0; i < node_count; i++)
689 {
690 onnx::NodeProto* node = mutable_graph->mutable_node(i);
691
692 // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Div(/6)
693 // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Mul(*(1/6))
694 // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Div(/6)
695 // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Mul(*(1/6))
696 // out = x * F.relu6(x + 3, inplace=True) / 6
697 if (node->op_type() == "Add")
698 {
699 if (node_reference[node->output(0)] != 1)
700 continue;
701
702 if (i + 3 >= node_count)
703 continue;
704
705 if (weights.find(node->input(1)) == weights.end())
706 continue;
707
708 const onnx::TensorProto& add_three = weights[node->input(1)];
709 if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1)
710 continue;
711
712 float constant_add_three = get_node_attr_from_input_f(add_three);
713 if (constant_add_three != 3.f)
714 continue;
715
716 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
717 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
718 onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
719
720 if (node4->op_type() == "Constant")
721 {
722 if (i + 4 >= node_count)
723 continue;
724
725 node4 = mutable_graph->mutable_node(i + 4);
726 }
727
728 if (node2->op_type() != "Clip" || node3->op_type() != "Mul" || (node4->op_type() != "Div" && node4->op_type() != "Mul"))
729 continue;
730
731 if (node_reference[node2->output(0)] != 1)
732 continue;
733
734 float relu6_min;
735 float relu6_max;
736 if (node2->input_size() == 1)
737 {
738 relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
739 relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
740 }
741 else
742 {
743 const onnx::TensorProto& min_tp = weights[node2->input(1)];
744 const onnx::TensorProto& max_tp = weights[node2->input(2)];
745
746 relu6_min = get_node_attr_from_input_f(min_tp);
747 relu6_max = get_node_attr_from_input_f(max_tp);
748 }
749 if (relu6_min != 0.f || relu6_max != 6.f)
750 continue;
751
752 if (node_reference[node3->output(0)] != 1)
753 continue;
754
755 if (node3->input(0) != node->input(0) || node3->input(1) != node2->output(0))
756 continue;
757
758 if (weights.find(node4->input(1)) == weights.end())
759 continue;
760
761 const onnx::TensorProto& div_six = weights[node4->input(1)];
762 if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1)
763 continue;
764
765 float constant_div_six = get_node_attr_from_input_f(div_six);
766 if (node4->op_type() == "Div" && constant_div_six != 6.f)
767 continue;
768 if (node4->op_type() == "Mul" && constant_div_six != 1 / 6.f)
769 continue;
770
771 // reduce
772 node->set_op_type("noop_reducedncnn");
773 node2->set_op_type("noop_reducedncnn");
774 node3->set_op_type("noop_reducedncnn");
775
776 node_reference[node->input(0)] -= 1;
777 node_reference[node->input(1)] -= 1;
778 node_reference[node->output(0)] -= 1;
779 if (node2->input_size() == 3)
780 {
781 node_reference[node2->input(1)] -= 1;
782 node_reference[node2->input(2)] -= 1;
783 }
784 node_reference[node2->output(0)] -= 1;
785 node_reference[node3->output(0)] -= 1;
786 node_reference[node4->input(1)] -= 1;
787
788 blob_names.erase(node->output(0));
789 blob_names.erase(node2->output(0));
790 blob_names.erase(node3->output(0));
791
792 node4->set_op_type("HardSwish");
793 node4->clear_input();
794 node4->add_input(node->input(0));
795
796 onnx::AttributeProto* attr_alpha = node4->add_attribute();
797 attr_alpha->set_name("alpha");
798 attr_alpha->set_f(1.f / 6.f);
799
800 onnx::AttributeProto* attr_beta = node4->add_attribute();
801 attr_beta->set_name("beta");
802 attr_beta->set_f(3.f / 6.f);
803
804 reduced_node_count += 3;
805 i += 3;
806 }
807 }
808
809 for (int i = 0; i < node_count; i++)
810 {
811 onnx::NodeProto* node = mutable_graph->mutable_node(i);
812
813 // HardSwish <= HardSigmoid - Mul
814 // out = x * hsigmoid(x)
815 if (node->op_type() == "HardSigmoid")
816 {
817 if (node_reference[node->output(0)] != 1)
818 continue;
819
820 float alpha = get_node_attr_f(*node, "alpha", 0.2f);
821 float beta = get_node_attr_f(*node, "beta", 0.5f);
822
823 if (i + 1 >= node_count)
824 continue;
825
826 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
827
828 if (node2->op_type() != "Mul")
829 continue;
830
831 if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0))
832 continue;
833
834 // reduce
835 node->set_op_type("noop_reducedncnn");
836
837 node_reference[node->input(0)] -= 1;
838 node_reference[node->output(0)] -= 1;
839
840 blob_names.erase(node->output(0));
841
842 node2->set_op_type("HardSwish");
843 node2->clear_input();
844 node2->add_input(node->input(0));
845
846 onnx::AttributeProto* attr_alpha = node2->add_attribute();
847 attr_alpha->set_name("alpha");
848 attr_alpha->set_f(alpha);
849
850 onnx::AttributeProto* attr_beta = node2->add_attribute();
851 attr_beta->set_name("beta");
852 attr_beta->set_f(beta);
853
854 reduced_node_count += 1;
855 i += 1;
856 }
857 }
858 }
859
fuse_hardsigmoid(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)860 static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
861 {
862 int node_count = mutable_graph->node_size();
863 for (int i = 0; i < node_count; i++)
864 {
865 onnx::NodeProto* node = mutable_graph->mutable_node(i);
866
867 // HardSigmoid <= Add(+3) - Clip(0,6) - Div(/6)
868 // HardSigmoid <= Add(+3) - Clip(0,6) - Mul(*(1/6))
869 // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Div(/6)
870 // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Mul(*(1/6))
871 // out = F.relu6(x + 3, inplace=True) / 6
872 if (node->op_type() == "Add")
873 {
874 if (node_reference[node->output(0)] != 1)
875 continue;
876
877 if (i + 2 >= node_count)
878 continue;
879
880 if (weights.find(node->input(1)) == weights.end())
881 continue;
882
883 const onnx::TensorProto& add_three = weights[node->input(1)];
884 if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1)
885 continue;
886
887 float constant_add_three = get_node_attr_from_input_f(add_three);
888 if (constant_add_three != 3.f)
889 continue;
890
891 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
892 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
893
894 if (node3->op_type() == "Constant")
895 {
896 if (i + 3 >= node_count)
897 continue;
898
899 node3 = mutable_graph->mutable_node(i + 3);
900 }
901
902 if (node2->op_type() != "Clip" || (node3->op_type() != "Div" && node3->op_type() != "Mul"))
903 continue;
904
905 if (node_reference[node2->output(0)] != 1)
906 continue;
907
908 float relu6_min;
909 float relu6_max;
910 if (node2->input_size() == 1)
911 {
912 relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
913 relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
914 }
915 else
916 {
917 const onnx::TensorProto& min_tp = weights[node2->input(1)];
918 const onnx::TensorProto& max_tp = weights[node2->input(2)];
919
920 relu6_min = get_node_attr_from_input_f(min_tp);
921 relu6_max = get_node_attr_from_input_f(max_tp);
922 }
923 if (relu6_min != 0.f || relu6_max != 6.f)
924 continue;
925
926 if (weights.find(node3->input(1)) == weights.end())
927 continue;
928
929 const onnx::TensorProto& div_six = weights[node3->input(1)];
930 if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1)
931 continue;
932
933 float constant_div_six = get_node_attr_from_input_f(div_six);
934 if (node3->op_type() == "Div" && constant_div_six != 6.f)
935 continue;
936 if (node3->op_type() == "Mul" && constant_div_six != 1 / 6.f)
937 continue;
938
939 // reduce
940 node->set_op_type("noop_reducedncnn");
941 node2->set_op_type("noop_reducedncnn");
942
943 node_reference[node->input(1)] -= 1;
944 node_reference[node->output(0)] -= 1;
945 if (node2->input_size() == 3)
946 {
947 node_reference[node2->input(1)] -= 1;
948 node_reference[node2->input(2)] -= 1;
949 }
950 node_reference[node2->output(0)] -= 1;
951 node_reference[node3->input(1)] -= 1;
952
953 blob_names.erase(node->output(0));
954 blob_names.erase(node2->output(0));
955
956 node3->set_op_type("HardSigmoid");
957 node3->clear_input();
958 node3->add_input(node->input(0));
959
960 onnx::AttributeProto* attr_alpha = node3->add_attribute();
961 attr_alpha->set_name("alpha");
962 attr_alpha->set_f(1.f / 6.f);
963
964 onnx::AttributeProto* attr_beta = node3->add_attribute();
965 attr_beta->set_name("beta");
966 attr_beta->set_f(3.f / 6.f);
967
968 reduced_node_count += 2;
969 i += 2;
970 }
971 }
972 }
973
fuse_swish(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)974 static void fuse_swish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
975 {
976 int node_count = mutable_graph->node_size();
977 for (int i = 0; i < node_count; i++)
978 {
979 onnx::NodeProto* node = mutable_graph->mutable_node(i);
980
981 // Swish <= Sigmoid - Mul
982 // x * torch.sigmoid(x)
983 if (node->op_type() == "Sigmoid")
984 {
985 if (node_reference[node->output(0)] != 1)
986 continue;
987
988 if (i + 1 >= node_count)
989 continue;
990
991 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
992
993 if (node2->op_type() != "Mul")
994 continue;
995
996 if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0))
997 continue;
998
999 // reduce
1000 node->set_op_type("noop_reducedncnn");
1001
1002 node_reference[node->input(0)] -= 1;
1003 node_reference[node->output(0)] -= 1;
1004
1005 blob_names.erase(node->output(0));
1006
1007 node2->set_op_type("Swish");
1008 node2->clear_input();
1009 node2->add_input(node->input(0));
1010
1011 reduced_node_count += 1;
1012 i += 1;
1013 }
1014 }
1015 }
1016
fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1017 static void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1018 {
1019 int node_count = mutable_graph->node_size();
1020 for (int i = 0; i < node_count; i++)
1021 {
1022 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1023
1024 // BatchNormalization <= Unsqueeze - BatchNormalization - Squeeze
1025 if (node->op_type() == "Unsqueeze")
1026 {
1027 if (node_reference[node->output(0)] != 1)
1028 continue;
1029
1030 if (i + 2 >= node_count)
1031 continue;
1032
1033 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1034 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1035
1036 if (node2->op_type() != "BatchNormalization" || node3->op_type() != "Squeeze")
1037 continue;
1038
1039 if (node_reference[node2->output(0)] != 1)
1040 continue;
1041
1042 if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0))
1043 continue;
1044
1045 // reduce
1046 node->set_op_type("noop_reducedncnn");
1047 node3->set_op_type("noop_reducedncnn");
1048
1049 node_reference[node->output(0)] -= 1;
1050 node_reference[node2->output(0)] -= 1;
1051
1052 blob_names.erase(node->output(0));
1053 blob_names.erase(node2->output(0));
1054
1055 node2->set_input(0, node->input(0));
1056 node2->set_output(0, node3->output(0));
1057
1058 reduced_node_count += 2;
1059 i += 2;
1060 }
1061 }
1062 }
1063
fuse_unsqueeze_prelu(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1064 static void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1065 {
1066 int node_count = mutable_graph->node_size();
1067 for (int i = 0; i < node_count; i++)
1068 {
1069 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1070
1071 // PReLU <= Unsqueeze - PReLU
1072 if (node->op_type() == "Unsqueeze")
1073 {
1074 // check weight
1075 if (weights.find(node->input(0)) == weights.end())
1076 continue;
1077
1078 onnx::TensorProto& B = weights[node->input(0)];
1079 if (B.dims_size() != 1)
1080 continue;
1081
1082 if (node_reference[node->output(0)] != 1)
1083 continue;
1084
1085 // axes = (1, 2)
1086 std::vector<int> axes = get_node_attr_ai(*node, "axes");
1087 if (axes.size() != 2)
1088 continue;
1089 if (axes[0] != 1 || axes[1] != 2)
1090 continue;
1091
1092 if (i + 1 >= node_count)
1093 continue;
1094
1095 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1096
1097 if (node2->op_type() != "PRelu")
1098 continue;
1099
1100 if (node2->input(1) != node->output(0))
1101 continue;
1102
1103 // reduce
1104 node->set_op_type("noop_reducedncnn");
1105
1106 node_reference[node->output(0)] -= 1;
1107
1108 blob_names.erase(node->output(0));
1109
1110 node2->set_input(1, node->input(0));
1111
1112 reduced_node_count += 1;
1113 i += 1;
1114 }
1115 }
1116 }
1117
fuse_normalize(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1118 static void fuse_normalize(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1119 {
1120 int node_count = mutable_graph->node_size();
1121 for (int i = 0; i < node_count; i++)
1122 {
1123 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1124
1125 // Normalize <= X - ReduceL2 - Clip - Expand - Div
1126 // Normalize <= X - ReduceL2 - Clip - Shape - Expand - Div
1127 if (node->op_type() == "ReduceL2")
1128 {
1129 if (node_reference[node->output(0)] != 1)
1130 continue;
1131
1132 // axes = (1)
1133 std::vector<int> axes = get_node_attr_ai(*node, "axes");
1134 if (axes.size() != 1)
1135 continue;
1136 if (axes[0] != 1)
1137 continue;
1138
1139 if (i + 3 >= node_count)
1140 continue;
1141
1142 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1143 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1144 onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
1145
1146 bool has_shape_node = node3->op_type() == "Shape";
1147 onnx::NodeProto* node_shape = 0;
1148 if (has_shape_node)
1149 {
1150 if (i + 4 >= node_count)
1151 continue;
1152
1153 node_shape = node3;
1154 node3 = mutable_graph->mutable_node(i + 3);
1155 node4 = mutable_graph->mutable_node(i + 4);
1156 }
1157
1158 if (node2->op_type() != "Clip" || node3->op_type() != "Expand" || node4->op_type() != "Div")
1159 continue;
1160
1161 if (node_reference[node2->output(0)] != 1)
1162 continue;
1163
1164 if (node_reference[node3->output(0)] != 1)
1165 continue;
1166
1167 if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)
1168 || node4->input(0) != node->input(0) || node4->input(1) != node3->output(0))
1169 continue;
1170
1171 if (has_shape_node)
1172 {
1173 if (node_shape->input(0) != node->input(0) || node3->input(1) != node_shape->output(0))
1174 continue;
1175 }
1176
1177 // +eps
1178 float clip_min;
1179 if (node2->input_size() == 1)
1180 {
1181 clip_min = get_node_attr_f(*node2, "min", -FLT_MAX);
1182 }
1183 else
1184 {
1185 const onnx::TensorProto& min_tp = weights[node2->input(1)];
1186
1187 clip_min = get_node_attr_from_input_f(min_tp);
1188 }
1189
1190 // reduce
1191 node->set_op_type("noop_reducedncnn");
1192 node2->set_op_type("noop_reducedncnn");
1193 if (has_shape_node)
1194 {
1195 node_shape->set_op_type("noop_reducedncnn");
1196 }
1197 node3->set_op_type("noop_reducedncnn");
1198
1199 node_reference[node->input(0)] -= has_shape_node ? 2 : 1;
1200 node_reference[node->output(0)] -= 1;
1201 node_reference[node2->output(0)] -= 1;
1202 if (has_shape_node)
1203 {
1204 node_reference[node_shape->output(0)] -= 1;
1205 }
1206 node_reference[node3->output(0)] -= 1;
1207 if (node3->input_size() == 2)
1208 {
1209 node_reference[node3->input(1)] -= 1;
1210 }
1211
1212 blob_names.erase(node->output(0));
1213 blob_names.erase(node2->output(0));
1214 if (has_shape_node)
1215 {
1216 blob_names.erase(node_shape->output(0));
1217 }
1218 blob_names.erase(node3->output(0));
1219
1220 node4->set_op_type("Normalize");
1221 node4->clear_input();
1222 node4->add_input(node->input(0));
1223
1224 onnx::AttributeProto* attr_alpha = node4->add_attribute();
1225 attr_alpha->set_name("eps");
1226 attr_alpha->set_f(clip_min);
1227
1228 reduced_node_count += has_shape_node ? 4 : 3;
1229 i += has_shape_node ? 4 : 3;
1230 }
1231 }
1232 }
1233
fuse_groupnorm(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1234 static void fuse_groupnorm(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1235 {
1236 int node_count = mutable_graph->node_size();
1237 for (int i = 0; i < node_count; i++)
1238 {
1239 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1240
1241 // GroupNorm <= X - Reshape - InstanceNormalization - Reshape - Mul - Add
1242 if (node->op_type() == "Reshape")
1243 {
1244 if (node_reference[node->output(0)] != 1)
1245 continue;
1246
1247 std::vector<int> shape;
1248 if (node->input_size() == 1)
1249 {
1250 shape = get_node_attr_ai(*node, "shape");
1251 }
1252 else
1253 {
1254 // skip weight reshape
1255 if (weights.find(node->input(1)) == weights.end())
1256 continue;
1257
1258 shape = get_node_attr_from_input_ai(weights[node->input(1)]);
1259 }
1260
1261 // 0, group, -1
1262 if (shape.size() != 3)
1263 continue;
1264
1265 if (shape[0] != 0 || shape[2] != -1)
1266 continue;
1267
1268 int groups = shape[1];
1269
1270 if (i + 4 >= node_count)
1271 continue;
1272
1273 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1274 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1275 onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
1276 onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
1277
1278 if (node2->op_type() != "InstanceNormalization" || node3->op_type() != "Reshape" || node4->op_type() != "Mul" || node5->op_type() != "Add")
1279 continue;
1280
1281 if (node_reference[node2->output(0)] != 1)
1282 continue;
1283
1284 if (node_reference[node3->output(0)] != 1)
1285 continue;
1286
1287 if (node_reference[node4->output(0)] != 1)
1288 continue;
1289
1290 if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)
1291 || node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0))
1292 continue;
1293
1294 // +eps
1295 float eps = get_node_attr_f(*node2, "epsilon", 1e-05f);
1296
1297 // InstanceNormalization S=1 B=0
1298 std::vector<float> S = get_node_attr_from_input_af(weights[node2->input(1)]);
1299 std::vector<float> B = get_node_attr_from_input_af(weights[node2->input(2)]);
1300 if ((int)S.size() != groups || (int)B.size() != groups)
1301 continue;
1302
1303 bool instancenorm_affine = false;
1304 for (int j = 0; j < groups; j++)
1305 {
1306 if (S[j] != 1.f || B[j] != 0.f)
1307 {
1308 instancenorm_affine = true;
1309 break;
1310 }
1311 }
1312
1313 if (instancenorm_affine)
1314 continue;
1315
1316 std::vector<int> shape2;
1317 if (node3->input_size() == 1)
1318 {
1319 shape2 = get_node_attr_ai(*node3, "shape");
1320 }
1321 else
1322 {
1323 // skip weight reshape
1324 if (weights.find(node3->input(1)) == weights.end())
1325 continue;
1326
1327 shape2 = get_node_attr_from_input_ai(weights[node3->input(1)]);
1328 }
1329
1330 // 1, channels, w, h
1331 if (shape2.size() != 4)
1332 continue;
1333
1334 if (shape2[0] != 1)
1335 continue;
1336
1337 int channels = shape2[1];
1338
1339 // affine
1340 std::vector<float> affine_S = get_node_attr_from_input_af(weights[node4->input(1)]);
1341 std::vector<float> affine_B = get_node_attr_from_input_af(weights[node5->input(1)]);
1342 if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && affine_B[0] == 0.f)
1343 {
1344 // no affine
1345 }
1346 else if ((int)affine_S.size() != channels && (int)affine_B.size() != channels)
1347 {
1348 // we only allow per-channel affine
1349 continue;
1350 }
1351
1352 // reduce
1353 node->set_op_type("noop_reducedncnn");
1354 node2->set_op_type("noop_reducedncnn");
1355 node3->set_op_type("noop_reducedncnn");
1356 node4->set_op_type("noop_reducedncnn");
1357
1358 if (node->input_size() == 2)
1359 {
1360 node_reference[node->input(1)] -= 1;
1361 }
1362 node_reference[node->output(0)] -= 1;
1363 node_reference[node2->input(1)] -= 1;
1364 node_reference[node2->input(2)] -= 1;
1365 node_reference[node2->output(0)] -= 1;
1366 if (node3->input_size() == 2)
1367 {
1368 node_reference[node3->input(1)] -= 1;
1369 }
1370 node_reference[node3->output(0)] -= 1;
1371 node_reference[node4->output(0)] -= 1;
1372
1373 blob_names.erase(node->output(0));
1374 blob_names.erase(node2->output(0));
1375 blob_names.erase(node3->output(0));
1376 blob_names.erase(node4->output(0));
1377
1378 std::string affine_scale = node4->input(1);
1379 std::string affine_bias = node5->input(1);
1380
1381 node5->set_op_type("GroupNorm");
1382 node5->clear_input();
1383 node5->add_input(node->input(0));
1384 node5->add_input(affine_scale);
1385 node5->add_input(affine_bias);
1386
1387 onnx::AttributeProto* attr_groups = node5->add_attribute();
1388 attr_groups->set_name("groups");
1389 attr_groups->set_i(groups);
1390
1391 onnx::AttributeProto* attr_channels = node5->add_attribute();
1392 attr_channels->set_name("channels");
1393 attr_channels->set_i(channels);
1394
1395 onnx::AttributeProto* attr_eps = node5->add_attribute();
1396 attr_eps->set_name("epsilon");
1397 attr_eps->set_f(eps);
1398
1399 onnx::AttributeProto* attr_affine = node5->add_attribute();
1400 attr_affine->set_name("affine");
1401 attr_affine->set_i(1);
1402
1403 reduced_node_count += 4;
1404 i += 4;
1405 }
1406 }
1407 }
1408
fuse_layernorm(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1409 static void fuse_layernorm(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1410 {
1411 int node_count = mutable_graph->node_size();
1412 for (int i = 0; i < node_count; i++)
1413 {
1414 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1415
1416 // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div
1417 // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div - Mul - Add
1418 if (node->op_type() == "ReduceMean")
1419 {
1420 if (node_reference[node->output(0)] != 1)
1421 continue;
1422
1423 std::vector<int> axes = get_node_attr_ai(*node, "axes");
1424
1425 // -1
1426 // -2 -1
1427 if (axes.size() != 1 && axes.size() != 2)
1428 continue;
1429
1430 int normed_axes = (int)axes.size();
1431 if (normed_axes == 1 && axes[0] != -1)
1432 continue;
1433 if (normed_axes == 2 && (axes[0] != -2 || axes[1] != -1))
1434 continue;
1435
1436 if (i + 6 >= node_count)
1437 continue;
1438
1439 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1440 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1441 onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
1442 onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
1443 onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
1444 onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
1445
1446 if (node2->op_type() != "Sub" || node3->op_type() != "Pow" || node4->op_type() != "ReduceMean" || node5->op_type() != "Add" || node6->op_type() != "Sqrt" || node7->op_type() != "Div")
1447 continue;
1448
1449 if (node_reference[node2->output(0)] != 2)
1450 continue;
1451
1452 if (node_reference[node3->output(0)] != 1)
1453 continue;
1454
1455 if (node_reference[node4->output(0)] != 1)
1456 continue;
1457
1458 if (node_reference[node5->output(0)] != 1)
1459 continue;
1460
1461 if (node_reference[node6->output(0)] != 1)
1462 continue;
1463
1464 if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0)
1465 || node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0)
1466 || node5->input(0) != node4->output(0) || node6->input(0) != node5->output(0)
1467 || node7->input(0) != node2->output(0) || node7->input(1) != node6->output(0))
1468 continue;
1469
1470 if (weights.find(node3->input(1)) == weights.end())
1471 continue;
1472
1473 const onnx::TensorProto& pow_two = weights[node3->input(1)];
1474 if (pow_two.dims_size() != 0 || get_tensor_proto_data_size(pow_two) != 1)
1475 continue;
1476
1477 float constant_pow_two = get_node_attr_from_input_f(pow_two);
1478 if (constant_pow_two != 2.f)
1479 continue;
1480
1481 std::vector<int> axes4 = get_node_attr_ai(*node4, "axes");
1482
1483 // -1
1484 // -2 -1
1485 if ((int)axes4.size() != normed_axes)
1486 continue;
1487
1488 if (normed_axes == 1 && axes4[0] != -1)
1489 continue;
1490 if (normed_axes == 2 && (axes4[0] != -2 || axes4[1] != -1))
1491 continue;
1492
1493 if (weights.find(node5->input(1)) == weights.end())
1494 continue;
1495
1496 const onnx::TensorProto& add_eps = weights[node5->input(1)];
1497 if (add_eps.dims_size() != 0 || get_tensor_proto_data_size(add_eps) != 1)
1498 continue;
1499
1500 float eps = get_node_attr_from_input_f(add_eps);
1501
1502 int affine = 0;
1503 while (i + 8 < node_count)
1504 {
1505 onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
1506 onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
1507
1508 if (node8->op_type() != "Mul" || node9->op_type() != "Add")
1509 break;
1510
1511 if (node_reference[node7->output(0)] != 1)
1512 break;
1513
1514 if (node_reference[node8->output(0)] != 1)
1515 break;
1516
1517 if (node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0))
1518 break;
1519
1520 // affine
1521 std::vector<float> affine_S = get_node_attr_from_input_af(weights[node8->input(1)]);
1522 std::vector<float> affine_B = get_node_attr_from_input_af(weights[node9->input(1)]);
1523 if (affine_S.size() != affine_B.size())
1524 break;
1525
1526 affine = 1;
1527 break;
1528 }
1529
1530 // reduce
1531 node->set_op_type("noop_reducedncnn");
1532 node2->set_op_type("noop_reducedncnn");
1533 node3->set_op_type("noop_reducedncnn");
1534 node4->set_op_type("noop_reducedncnn");
1535 node5->set_op_type("noop_reducedncnn");
1536 node6->set_op_type("noop_reducedncnn");
1537
1538 node_reference[node->input(0)] -= 1;
1539 node_reference[node2->input(0)] -= 1;
1540 node_reference[node2->input(1)] -= 1;
1541 node_reference[node3->input(0)] -= 1;
1542 node_reference[node3->input(1)] -= 1;
1543 node_reference[node4->input(0)] -= 1;
1544 node_reference[node5->input(0)] -= 1;
1545 node_reference[node5->input(1)] -= 1;
1546 node_reference[node6->input(0)] -= 1;
1547 node_reference[node7->input(0)] -= 1;
1548 node_reference[node7->input(1)] -= 1;
1549
1550 blob_names.erase(node->output(0));
1551 blob_names.erase(node2->output(0));
1552 blob_names.erase(node3->output(0));
1553 blob_names.erase(node4->output(0));
1554 blob_names.erase(node5->output(0));
1555 blob_names.erase(node6->output(0));
1556
1557 node_reference[node->input(0)] += 1;
1558
1559 if (affine == 0)
1560 {
1561 node7->set_op_type("LayerNorm");
1562 node7->clear_input();
1563 node7->add_input(node->input(0));
1564
1565 onnx::AttributeProto* attr_eps = node7->add_attribute();
1566 attr_eps->set_name("epsilon");
1567 attr_eps->set_f(eps);
1568
1569 onnx::AttributeProto* attr_affine = node7->add_attribute();
1570 attr_affine->set_name("affine");
1571 attr_affine->set_i(affine);
1572
1573 reduced_node_count += 6;
1574 i += 6;
1575 }
1576 else // if (affine == 1)
1577 {
1578 onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
1579 onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
1580
1581 node7->set_op_type("noop_reducedncnn");
1582 node8->set_op_type("noop_reducedncnn");
1583
1584 node_reference[node8->input(0)] -= 1;
1585 node_reference[node9->input(0)] -= 1;
1586
1587 blob_names.erase(node7->output(0));
1588 blob_names.erase(node8->output(0));
1589
1590 std::string affine_scale = node8->input(1);
1591 std::string affine_bias = node9->input(1);
1592
1593 node9->set_op_type("LayerNorm");
1594 node9->clear_input();
1595 node9->add_input(node->input(0));
1596 node9->add_input(affine_scale);
1597 node9->add_input(affine_bias);
1598
1599 onnx::AttributeProto* attr_eps = node9->add_attribute();
1600 attr_eps->set_name("epsilon");
1601 attr_eps->set_f(eps);
1602
1603 onnx::AttributeProto* attr_affine = node9->add_attribute();
1604 attr_affine->set_name("affine");
1605 attr_affine->set_i(affine);
1606
1607 reduced_node_count += 8;
1608 i += 8;
1609 }
1610 }
1611 }
1612 }
1613
fuse_flatten(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1614 static void fuse_flatten(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1615 {
1616 int node_count = mutable_graph->node_size();
1617 for (int i = 0; i < node_count; i++)
1618 {
1619 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1620
1621 // Flatten <= X - Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat - Reshape
1622 if (node->op_type() == "Shape")
1623 {
1624 if (node_reference[node->output(0)] != 1)
1625 continue;
1626
1627 if (i + 6 >= node_count)
1628 continue;
1629
1630 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1631 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1632 onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
1633 onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
1634 onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
1635 onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
1636
1637 if (node2->op_type() != "Gather" || node3->op_type() != "Constant" || node4->op_type() != "Unsqueeze" || node5->op_type() != "Unsqueeze"
1638 || node6->op_type() != "Concat" || node7->op_type() != "Reshape")
1639 continue;
1640
1641 if (node_reference[node2->output(0)] != 1)
1642 continue;
1643
1644 // if (node_reference[node3->output(0)] != 1)
1645 // continue;
1646
1647 if (node_reference[node4->output(0)] != 1)
1648 continue;
1649
1650 if (node_reference[node5->output(0)] != 1)
1651 continue;
1652
1653 if (node_reference[node6->output(0)] != 1)
1654 continue;
1655
1656 if (node2->input(0) != node->output(0) || node4->input(0) != node2->output(0) || node5->input(0) != node3->output(0)
1657 || node6->input(0) != node4->output(0) || node6->input(1) != node5->output(0)
1658 || node7->input(0) != node->input(0) || node7->input(1) != node6->output(0))
1659 continue;
1660
1661 // axis = 0
1662 int gather_axis = get_node_attr_i(*node2, "axis");
1663 if (gather_axis != 0)
1664 continue;
1665
1666 // indices = 0
1667 if (weights.find(node2->input(1)) == weights.end())
1668 continue;
1669
1670 std::vector<int> gather_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
1671 if (gather_indices.size() != 1 || gather_indices[0] != 0)
1672 continue;
1673
1674 // axes = (0)
1675 std::vector<int> unsqueeze_axes = get_node_attr_ai(*node4, "axes");
1676 if (unsqueeze_axes.size() != 1)
1677 continue;
1678 if (unsqueeze_axes[0] != 0)
1679 continue;
1680
1681 // axes = (0)
1682 std::vector<int> unsqueeze2_axes = get_node_attr_ai(*node5, "axes");
1683 if (unsqueeze2_axes.size() != 1)
1684 continue;
1685 if (unsqueeze2_axes[0] != 0)
1686 continue;
1687
1688 // data = -1
1689 if (weights.find(node5->input(0)) == weights.end())
1690 continue;
1691
1692 std::vector<int> unsqueeze2_data = get_node_attr_from_input_ai(weights[node5->input(0)]);
1693 if (unsqueeze2_data.size() != 1 || unsqueeze2_data[0] != -1)
1694 continue;
1695
1696 // axis = 0
1697 int concat_axis = get_node_attr_i(*node6, "axis");
1698 if (concat_axis != 0)
1699 continue;
1700
1701 // reduce
1702 node->set_op_type("noop_reducedncnn");
1703 node2->set_op_type("noop_reducedncnn");
1704 // node3->set_op_type("noop_reducedncnn");
1705 node4->set_op_type("noop_reducedncnn");
1706 node5->set_op_type("noop_reducedncnn");
1707 node6->set_op_type("noop_reducedncnn");
1708
1709 node_reference[node->input(0)] -= 1;
1710 node_reference[node->output(0)] -= 1;
1711 node_reference[node2->input(1)] -= 1;
1712 node_reference[node2->output(0)] -= 1;
1713 // node_reference[node3->output(0)] -= 1;
1714 node_reference[node4->output(0)] -= 1;
1715 node_reference[node5->input(0)] -= 1;
1716 node_reference[node5->output(0)] -= 1;
1717 node_reference[node6->output(0)] -= 1;
1718
1719 blob_names.erase(node->output(0));
1720 blob_names.erase(node2->output(0));
1721 // blob_names.erase(node3->output(0));
1722 blob_names.erase(node4->output(0));
1723 blob_names.erase(node5->output(0));
1724 blob_names.erase(node6->output(0));
1725
1726 node7->set_op_type("Flatten");
1727 node7->clear_input();
1728 node7->add_input(node->input(0));
1729
1730 reduced_node_count += 5;
1731 i += 5;
1732 }
1733 }
1734 }
1735
fuse_pixelshuffle(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1736 static void fuse_pixelshuffle(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1737 {
1738 int node_count = mutable_graph->node_size();
1739 for (int i = 0; i < node_count; i++)
1740 {
1741 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1742
1743 // PixelShuffle <= Reshape - Transpose - Reshape
1744 // PixelShuffle <= Reshape - Transpose - Constant - Reshape
1745 if (node->op_type() == "Reshape")
1746 {
1747 if (node_reference[node->output(0)] != 1)
1748 continue;
1749
1750 std::vector<int> shape;
1751 if (node->input_size() == 1)
1752 {
1753 shape = get_node_attr_ai(*node, "shape");
1754 }
1755 else
1756 {
1757 // skip weight reshape
1758 if (weights.find(node->input(1)) == weights.end())
1759 continue;
1760
1761 shape = get_node_attr_from_input_ai(weights[node->input(1)]);
1762 }
1763
1764 // -1, 3, upscale_factor, upscale_factor, height, width
1765 if (shape.size() != 6)
1766 continue;
1767
1768 if (shape[0] != 1 && shape[0] != -1)
1769 continue;
1770
1771 if (shape[2] != shape[3])
1772 continue;
1773
1774 if (i + 2 >= node_count)
1775 continue;
1776
1777 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1778 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1779
1780 if (node3->op_type() == "Constant")
1781 {
1782 if (i + 3 >= node_count)
1783 continue;
1784
1785 node3 = mutable_graph->mutable_node(i + 3);
1786 }
1787
1788 if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
1789 continue;
1790
1791 if (node_reference[node2->output(0)] != 1)
1792 continue;
1793
1794 // 0 1 4 2 5 3
1795 std::vector<int> perm = get_node_attr_ai(*node2, "perm");
1796 if (perm.size() != 6)
1797 continue;
1798
1799 if (perm[0] != 0 || perm[1] != 1 || perm[2] != 4 || perm[3] != 2 || perm[4] != 5 || perm[5] != 3)
1800 continue;
1801
1802 std::vector<int> shape3;
1803 if (node3->input_size() == 1)
1804 {
1805 shape3 = get_node_attr_ai(*node3, "shape");
1806 }
1807 else
1808 {
1809 // skip weight reshape
1810 if (weights.find(node3->input(1)) == weights.end())
1811 continue;
1812
1813 shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
1814 }
1815
1816 // -1, 3, height, width
1817 if (shape3.size() != 4)
1818 continue;
1819
1820 if (shape3[0] != 1 && shape3[0] != -1)
1821 continue;
1822
1823 if (shape3[1] != shape[1] || shape3[2] != shape[2] * shape[4] || shape3[3] != shape[3] * shape[5])
1824 continue;
1825
1826 // reduce
1827 node->set_op_type("noop_reducedncnn");
1828 node2->set_op_type("noop_reducedncnn");
1829
1830 if (node->input_size() == 2)
1831 {
1832 node_reference[node->input(1)] -= 1;
1833 }
1834 node_reference[node->output(0)] -= 1;
1835 node_reference[node2->output(0)] -= 1;
1836 if (node3->input_size() == 2)
1837 {
1838 node_reference[node3->input(1)] -= 1;
1839 }
1840
1841 blob_names.erase(node->output(0));
1842 blob_names.erase(node2->output(0));
1843
1844 node3->set_op_type("PixelShuffle");
1845 node3->set_input(0, node->input(0));
1846
1847 onnx::AttributeProto* attr_group = node3->add_attribute();
1848 attr_group->set_name("scale_factor");
1849 attr_group->set_i(shape[2]);
1850
1851 reduced_node_count += 2;
1852 i += 2;
1853 }
1854 }
1855 }
1856
fuse_reorg(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1857 static void fuse_reorg(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1858 {
1859 int node_count = mutable_graph->node_size();
1860 for (int i = 0; i < node_count; i++)
1861 {
1862 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1863
1864 // PixelShuffle <= Reshape - Transpose - Reshape
1865 // PixelShuffle <= Reshape - Transpose - Constant - Reshape
1866 if (node->op_type() == "Reshape")
1867 {
1868 if (node_reference[node->output(0)] != 1)
1869 continue;
1870
1871 std::vector<int> shape;
1872 if (node->input_size() == 1)
1873 {
1874 shape = get_node_attr_ai(*node, "shape");
1875 }
1876 else
1877 {
1878 // skip weight reshape
1879 if (weights.find(node->input(1)) == weights.end())
1880 continue;
1881
1882 shape = get_node_attr_from_input_ai(weights[node->input(1)]);
1883 }
1884
1885 // -1, 3, out_height, block_size, out_width, block_size
1886 if (shape.size() != 6)
1887 continue;
1888
1889 if (shape[0] != 1 && shape[0] != -1)
1890 continue;
1891
1892 if (shape[3] != shape[5])
1893 continue;
1894
1895 if (i + 2 >= node_count)
1896 continue;
1897
1898 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1899 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1900
1901 if (node3->op_type() == "Constant")
1902 {
1903 if (i + 3 >= node_count)
1904 continue;
1905
1906 node3 = mutable_graph->mutable_node(i + 3);
1907 }
1908
1909 if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
1910 continue;
1911
1912 if (node_reference[node2->output(0)] != 1)
1913 continue;
1914
1915 // 0 1 3 5 2 4
1916 std::vector<int> perm = get_node_attr_ai(*node2, "perm");
1917 if (perm.size() != 6)
1918 continue;
1919
1920 if (perm[0] != 0 || perm[1] != 1 || perm[2] != 3 || perm[3] != 5 || perm[4] != 2 || perm[5] != 4)
1921 continue;
1922
1923 std::vector<int> shape3;
1924 if (node3->input_size() == 1)
1925 {
1926 shape3 = get_node_attr_ai(*node3, "shape");
1927 }
1928 else
1929 {
1930 // skip weight reshape
1931 if (weights.find(node3->input(1)) == weights.end())
1932 continue;
1933
1934 shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
1935 }
1936
1937 // -1, out_channels, out_height, out_width
1938 if (shape3.size() != 4)
1939 continue;
1940
1941 if (shape3[0] != 1 && shape3[0] != -1)
1942 continue;
1943
1944 if (shape3[1] != shape[1] * shape[3] * shape[5] || shape3[2] != shape[2] || shape3[3] != shape[4])
1945 continue;
1946
1947 // reduce
1948 node->set_op_type("noop_reducedncnn");
1949 node2->set_op_type("noop_reducedncnn");
1950
1951 if (node->input_size() == 2)
1952 {
1953 node_reference[node->input(1)] -= 1;
1954 }
1955 node_reference[node->output(0)] -= 1;
1956 node_reference[node2->output(0)] -= 1;
1957 if (node3->input_size() == 2)
1958 {
1959 node_reference[node3->input(1)] -= 1;
1960 }
1961
1962 blob_names.erase(node->output(0));
1963 blob_names.erase(node2->output(0));
1964
1965 node3->set_op_type("Reorg");
1966 node3->set_input(0, node->input(0));
1967
1968 onnx::AttributeProto* attr_group = node3->add_attribute();
1969 attr_group->set_name("stride");
1970 attr_group->set_i(shape[3]);
1971
1972 reduced_node_count += 2;
1973 i += 2;
1974 }
1975 }
1976 }
1977
fuse_expand_broadcast(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1978 static void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1979 {
1980 int node_count = mutable_graph->node_size();
1981 for (int i = 0; i < node_count; i++)
1982 {
1983 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1984
1985 // Add/Sub/Mul/Div/Min/Max <= Expand - Add/Sub/Mul/Div/Min/Max
1986 if (node->op_type() == "Expand")
1987 {
1988 if (node_reference[node->output(0)] != 1)
1989 continue;
1990
1991 if (i + 1 >= node_count)
1992 continue;
1993
1994 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1995
1996 if (node2->op_type() != "Add" && node2->op_type() != "Sub" && node2->op_type() != "Mul" && node2->op_type() != "Div" && node2->op_type() != "Min" && node2->op_type() != "Max")
1997 continue;
1998
1999 if (node2->input(1) != node->output(0) && node2->input(0) != node->output(0))
2000 continue;
2001
2002 // reduce
2003 node->set_op_type("noop_reducedncnn");
2004
2005 node_reference[node->output(0)] -= 1;
2006 if (node->input_size() == 2)
2007 {
2008 node_reference[node->input(1)] -= 1;
2009 }
2010
2011 blob_names.erase(node->output(0));
2012
2013 node2->set_input(1, node->input(0));
2014
2015 reduced_node_count += 1;
2016 i += 1;
2017 }
2018 }
2019 }
2020
fuse_lstm_gru_rnn(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)2021 static void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
2022 {
2023 int node_count = mutable_graph->node_size();
2024 for (int i = 0; i < node_count; i++)
2025 {
2026 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2027
2028 // LSTM(bi) <= LSTM(bi) - Transpose - Reshape - Transpose
2029 if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
2030 {
2031 if (node_reference[node->output(0)] != 1)
2032 continue;
2033
2034 if (i + 2 >= node_count)
2035 continue;
2036
2037 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2038 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
2039
2040 if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
2041 continue;
2042
2043 if (node_reference[node2->output(0)] != 1)
2044 continue;
2045
2046 if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0))
2047 continue;
2048
2049 std::string direction = get_node_attr_s(*node, "direction");
2050 if (direction != "bidirectional")
2051 continue;
2052
2053 // 0 2 1 3
2054 std::vector<int> perm = get_node_attr_ai(*node2, "perm");
2055 if (perm.size() != 4)
2056 continue;
2057
2058 if (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3)
2059 continue;
2060
2061 std::vector<int> shape;
2062 if (node3->input_size() == 1)
2063 {
2064 shape = get_node_attr_ai(*node3, "shape");
2065 }
2066 else
2067 {
2068 // skip weight reshape
2069 if (weights.find(node3->input(1)) == weights.end())
2070 continue;
2071
2072 shape = get_node_attr_from_input_ai(weights[node3->input(1)]);
2073 }
2074
2075 // 0 0 -1
2076 if (shape.size() != 3)
2077 continue;
2078
2079 if (shape[0] != 0 || shape[1] != 0 || shape[2] != -1)
2080 continue;
2081
2082 // reduce
2083 node2->set_op_type("noop_reducedncnn");
2084 node3->set_op_type("noop_reducedncnn");
2085
2086 node_reference[node->output(0)] -= 1;
2087 node_reference[node2->output(0)] -= 1;
2088 if (node3->input_size() == 2)
2089 {
2090 node_reference[node3->input(1)] -= 1;
2091 }
2092
2093 blob_names.erase(node->output(0));
2094 blob_names.erase(node2->output(0));
2095
2096 node->set_output(0, node3->output(0));
2097
2098 reduced_node_count += 2;
2099 i += 2;
2100
2101 if (i + 1 < node_count)
2102 {
2103 if (node_reference[node3->output(0)] != 1)
2104 continue;
2105
2106 onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 1);
2107
2108 if (node4->op_type() != "Transpose")
2109 continue;
2110
2111 if (node4->input(0) != node->output(0))
2112 continue;
2113
2114 // 1 0 2
2115 std::vector<int> perm4 = get_node_attr_ai(*node4, "perm");
2116 if (perm4.size() != 3)
2117 continue;
2118
2119 if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2)
2120 continue;
2121
2122 // reduce
2123 node4->set_op_type("noop_reducedncnn");
2124
2125 node_reference[node->output(0)] -= 1;
2126
2127 blob_names.erase(node->output(0));
2128
2129 node->clear_output();
2130 node->add_output(node4->output(0));
2131
2132 reduced_node_count += 1;
2133 i += 1;
2134 }
2135 }
2136 }
2137
2138 for (int i = 0; i < node_count; i++)
2139 {
2140 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2141
2142 // LSTM(uni) <= LSTM(uni) - Squeeze - Transpose
2143 if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
2144 {
2145 if (node_reference[node->output(0)] != 1)
2146 continue;
2147
2148 if (i + 1 >= node_count)
2149 continue;
2150
2151 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2152
2153 if (node2->op_type() != "Squeeze")
2154 continue;
2155
2156 if (node2->input(0) != node->output(0))
2157 continue;
2158
2159 std::string direction = get_node_attr_s(*node, "direction");
2160 if (direction == "bidirectional")
2161 continue;
2162
2163 // 1
2164 std::vector<int> axes = get_node_attr_ai(*node2, "axes");
2165 if (axes.size() != 1)
2166 continue;
2167
2168 if (axes[0] != 1)
2169 continue;
2170
2171 // reduce
2172 node2->set_op_type("noop_reducedncnn");
2173
2174 node_reference[node->output(0)] -= 1;
2175
2176 blob_names.erase(node->output(0));
2177
2178 node->set_output(0, node2->output(0));
2179
2180 reduced_node_count += 1;
2181 i += 1;
2182
2183 if (i + 1 < node_count)
2184 {
2185 if (node_reference[node2->output(0)] != 1)
2186 continue;
2187
2188 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 1);
2189
2190 if (node3->op_type() != "Transpose")
2191 continue;
2192
2193 if (node3->input(0) != node->output(0))
2194 continue;
2195
2196 // 1 0 2
2197 std::vector<int> perm4 = get_node_attr_ai(*node3, "perm");
2198 if (perm4.size() != 3)
2199 continue;
2200
2201 if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2)
2202 continue;
2203
2204 // reduce
2205 node3->set_op_type("noop_reducedncnn");
2206
2207 node_reference[node->output(0)] -= 1;
2208
2209 blob_names.erase(node->output(0));
2210
2211 node->clear_output();
2212 node->add_output(node3->output(0));
2213
2214 reduced_node_count += 1;
2215 i += 1;
2216 }
2217 }
2218 }
2219
2220 for (int i = 0; i < node_count; i++)
2221 {
2222 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2223
2224 // LSTM <= Transpose - LSTM
2225 if (node->op_type() == "Transpose")
2226 {
2227 if (node_reference[node->output(0)] != 1)
2228 continue;
2229
2230 // 1 0 2
2231 std::vector<int> perm = get_node_attr_ai(*node, "perm");
2232 if (perm.size() != 3)
2233 continue;
2234
2235 if (perm[0] != 1 || perm[1] != 0 || perm[2] != 2)
2236 continue;
2237
2238 if (i + 1 >= node_count)
2239 continue;
2240
2241 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2242
2243 if (node2->op_type() != "LSTM" && node->op_type() != "GRU" && node->op_type() != "RNN")
2244 continue;
2245
2246 if (node2->input(0) != node->output(0))
2247 continue;
2248
2249 // reduce
2250 node->set_op_type("noop_reducedncnn");
2251
2252 node_reference[node->output(0)] -= 1;
2253
2254 blob_names.erase(node->output(0));
2255
2256 node2->set_input(0, node->input(0));
2257
2258 reduced_node_count += 1;
2259 i += 1;
2260 }
2261 }
2262 }
2263
fuse_multiheadattention(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)2264 static void fuse_multiheadattention(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
2265 {
2266 int node_count = mutable_graph->node_size();
2267 for (int i = 0; i < node_count; i++)
2268 {
2269 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2270
2271 // MultiHeadAttention <= MatMul(q) - Add
2272 // - MatMul(k) - Add
2273 // - MatMul(v) - Add
2274 // - Mul
2275 // - Reshape - Transpose
2276 // - Reshape - Reshape - Transpose - Transpose
2277 // - Gemm - Softmax - Gemm - Transpose - Reshape - MatMul - Add
2278 if (node->op_type() == "MatMul")
2279 {
2280 if (i + 19 >= node_count)
2281 continue;
2282
2283 if (node_reference[node->output(0)] != 1)
2284 continue;
2285
2286 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2287 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
2288 onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
2289 onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
2290 onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
2291 onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
2292 onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
2293 onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
2294 onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9);
2295 onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10);
2296 onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11);
2297 onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12);
2298 onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13);
2299 onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14);
2300 onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15);
2301 onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16);
2302 onnx::NodeProto* node18 = mutable_graph->mutable_node(i + 17);
2303 onnx::NodeProto* node19 = mutable_graph->mutable_node(i + 18);
2304 onnx::NodeProto* node20 = mutable_graph->mutable_node(i + 19);
2305
2306 if (node2->op_type() != "Add" || node3->op_type() != "MatMul" || node4->op_type() != "Add" || node5->op_type() != "MatMul" || node6->op_type() != "Add" || node7->op_type() != "Mul" || node8->op_type() != "Reshape" || node9->op_type() != "Transpose" || node10->op_type() != "Reshape" || node11->op_type() != "Reshape" || node12->op_type() != "Transpose" || node13->op_type() != "Transpose" || node14->op_type() != "MatMul" || node15->op_type() != "Softmax" || node16->op_type() != "MatMul" || node17->op_type() != "Transpose" || node18->op_type() != "Reshape" || node19->op_type() != "MatMul" || node20->op_type() != "Add")
2307 continue;
2308
2309 if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 || node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 || node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 || node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 || node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 || node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 || node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 || node_reference[node16->output(0)] != 1 || node_reference[node17->output(0)] != 1 || node_reference[node18->output(0)] != 1 || node_reference[node19->output(0)] != 1)
2310 continue;
2311
2312 if (node2->input(0) != node->output(0) || node4->input(0) != node3->output(0) || node6->input(0) != node5->output(0) || node7->input(0) != node2->output(0) || node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0) || node10->input(0) != node4->output(0) || node11->input(0) != node6->output(0) || node12->input(0) != node11->output(0) || node13->input(0) != node10->output(0) || node14->input(0) != node9->output(0) || node14->input(1) != node13->output(0) || node15->input(0) != node14->output(0) || node16->input(0) != node15->output(0) || node16->input(1) != node12->output(0) || node17->input(0) != node16->output(0) || node18->input(0) != node17->output(0) || node19->input(0) != node18->output(0) || node20->input(0) != node19->output(0))
2313 continue;
2314
2315 std::vector<float> q_B = get_node_attr_from_input_af(weights[node2->input(1)]);
2316 std::vector<float> k_B = get_node_attr_from_input_af(weights[node4->input(1)]);
2317 std::vector<float> v_B = get_node_attr_from_input_af(weights[node6->input(1)]);
2318 std::vector<float> o_B = get_node_attr_from_input_af(weights[node20->input(1)]);
2319
2320 if (q_B.size() != k_B.size() || q_B.size() != v_B.size() || q_B.size() != o_B.size())
2321 continue;
2322
2323 int embed_dim = q_B.size();
2324
2325 // 1 0 2
2326 std::vector<int> perm9 = get_node_attr_ai(*node9, "perm");
2327 std::vector<int> perm12 = get_node_attr_ai(*node12, "perm");
2328 if (perm9.size() != 3 || perm12.size() != 3)
2329 continue;
2330
2331 if (perm9[0] != 1 || perm9[1] != 0 || perm9[2] != 2 || perm12[0] != 1 || perm12[1] != 0 || perm12[2] != 2)
2332 continue;
2333
2334 // 1 2 0
2335 std::vector<int> perm13 = get_node_attr_ai(*node13, "perm");
2336 if (perm13.size() != 3)
2337 continue;
2338
2339 if (perm13[0] != 1 || perm13[1] != 2 || perm13[2] != 0)
2340 continue;
2341
2342 // 1 0 2
2343 std::vector<int> perm17 = get_node_attr_ai(*node17, "perm");
2344 if (perm17.size() != 3)
2345 continue;
2346
2347 if (perm17[0] != 1 || perm17[1] != 0 || perm17[2] != 2)
2348 continue;
2349
2350 int softmax_axis = get_node_attr_i(*node15, "axis");
2351 if (softmax_axis != 2)
2352 continue;
2353
2354 // 1/-1, seqlen * num_heads, embed_dim / num_heads
2355 std::vector<int> shape8;
2356 std::vector<int> shape10;
2357 std::vector<int> shape11;
2358 if (node8->input_size() == 1)
2359 {
2360 shape8 = get_node_attr_ai(*node8, "shape");
2361 }
2362 else
2363 {
2364 // skip weight reshape
2365 if (weights.find(node8->input(1)) == weights.end())
2366 continue;
2367
2368 shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]);
2369 }
2370 if (node10->input_size() == 1)
2371 {
2372 shape10 = get_node_attr_ai(*node10, "shape");
2373 }
2374 else
2375 {
2376 // skip weight reshape
2377 if (weights.find(node10->input(1)) == weights.end())
2378 continue;
2379
2380 shape10 = get_node_attr_from_input_ai(weights[node10->input(1)]);
2381 }
2382 if (node11->input_size() == 1)
2383 {
2384 shape11 = get_node_attr_ai(*node11, "shape");
2385 }
2386 else
2387 {
2388 // skip weight reshape
2389 if (weights.find(node11->input(1)) == weights.end())
2390 continue;
2391
2392 shape11 = get_node_attr_from_input_ai(weights[node11->input(1)]);
2393 }
2394
2395 if (shape8.size() != 3 || shape10.size() != 3 || shape11.size() != 3)
2396 continue;
2397
2398 if (shape8[1] != shape10[1] || shape8[1] != shape11[1] || shape8[2] != shape10[2] || shape8[2] != shape11[2])
2399 continue;
2400
2401 int num_heads = embed_dim / shape8[2];
2402
2403 // 1, seqlen, embed_dim
2404 std::vector<int> shape18;
2405 if (node18->input_size() == 1)
2406 {
2407 shape18 = get_node_attr_ai(*node18, "shape");
2408 }
2409 else
2410 {
2411 // skip weight reshape
2412 if (weights.find(node18->input(1)) == weights.end())
2413 continue;
2414
2415 shape18 = get_node_attr_from_input_ai(weights[node18->input(1)]);
2416 }
2417
2418 if (shape18.size() != 3)
2419 continue;
2420
2421 if (shape18[2] != embed_dim || shape18[1] * num_heads != shape8[1])
2422 continue;
2423
2424 // reduce
2425 node->set_op_type("noop_reducedncnn");
2426 node2->set_op_type("noop_reducedncnn");
2427 node3->set_op_type("noop_reducedncnn");
2428 node4->set_op_type("noop_reducedncnn");
2429 node5->set_op_type("noop_reducedncnn");
2430 node6->set_op_type("noop_reducedncnn");
2431 node7->set_op_type("noop_reducedncnn");
2432 node8->set_op_type("noop_reducedncnn");
2433 node9->set_op_type("noop_reducedncnn");
2434 node10->set_op_type("noop_reducedncnn");
2435 node11->set_op_type("noop_reducedncnn");
2436 node12->set_op_type("noop_reducedncnn");
2437 node13->set_op_type("noop_reducedncnn");
2438 node14->set_op_type("noop_reducedncnn");
2439 node15->set_op_type("noop_reducedncnn");
2440 node16->set_op_type("noop_reducedncnn");
2441 node17->set_op_type("noop_reducedncnn");
2442 node18->set_op_type("noop_reducedncnn");
2443 node19->set_op_type("noop_reducedncnn");
2444
2445 node_reference[node2->input(0)] -= 1;
2446 node_reference[node4->input(0)] -= 1;
2447 node_reference[node6->input(0)] -= 1;
2448 node_reference[node7->input(0)] -= 1;
2449 node_reference[node7->input(1)] -= 1;
2450 node_reference[node8->input(0)] -= 1;
2451 if (node8->input_size() == 2)
2452 {
2453 node_reference[node8->input(1)] -= 1;
2454 }
2455 node_reference[node9->input(0)] -= 1;
2456 node_reference[node10->input(0)] -= 1;
2457 if (node10->input_size() == 2)
2458 {
2459 node_reference[node10->input(1)] -= 1;
2460 }
2461 node_reference[node11->input(0)] -= 1;
2462 if (node11->input_size() == 2)
2463 {
2464 node_reference[node11->input(1)] -= 1;
2465 }
2466 node_reference[node12->input(0)] -= 1;
2467 node_reference[node13->input(0)] -= 1;
2468 node_reference[node14->input(0)] -= 1;
2469 node_reference[node14->input(1)] -= 1;
2470 node_reference[node15->input(0)] -= 1;
2471 node_reference[node16->input(0)] -= 1;
2472 node_reference[node16->input(1)] -= 1;
2473 node_reference[node17->input(0)] -= 1;
2474 node_reference[node18->input(0)] -= 1;
2475 if (node18->input_size() == 2)
2476 {
2477 node_reference[node18->input(1)] -= 1;
2478 }
2479 node_reference[node19->input(0)] -= 1;
2480 node_reference[node20->input(0)] -= 1;
2481
2482 blob_names.erase(node->output(0));
2483 blob_names.erase(node2->output(0));
2484 blob_names.erase(node3->output(0));
2485 blob_names.erase(node4->output(0));
2486 blob_names.erase(node5->output(0));
2487 blob_names.erase(node6->output(0));
2488 blob_names.erase(node7->output(0));
2489 blob_names.erase(node8->output(0));
2490 blob_names.erase(node9->output(0));
2491 blob_names.erase(node10->output(0));
2492 blob_names.erase(node11->output(0));
2493 blob_names.erase(node12->output(0));
2494 blob_names.erase(node13->output(0));
2495 blob_names.erase(node14->output(0));
2496 blob_names.erase(node15->output(0));
2497 blob_names.erase(node16->output(0));
2498 blob_names.erase(node17->output(0));
2499 blob_names.erase(node18->output(0));
2500 blob_names.erase(node19->output(0));
2501
2502 std::string qw = node->input(1);
2503 std::string qb = node2->input(1);
2504 std::string kw = node3->input(1);
2505 std::string kb = node4->input(1);
2506 std::string vw = node5->input(1);
2507 std::string vb = node6->input(1);
2508 std::string ow = node19->input(1);
2509 std::string ob = node20->input(1);
2510
2511 node20->set_op_type("MultiHeadAttention");
2512 node20->clear_input();
2513 node20->add_input(node->input(0));
2514 node20->add_input(node3->input(0));
2515 node20->add_input(node5->input(0));
2516 // q
2517 node20->add_input(qw);
2518 node20->add_input(qb);
2519 // k
2520 node20->add_input(kw);
2521 node20->add_input(kb);
2522 // v
2523 node20->add_input(vw);
2524 node20->add_input(vb);
2525 // out linear
2526 node20->add_input(ow);
2527 node20->add_input(ob);
2528
2529 onnx::AttributeProto* attr_embed_dim = node20->add_attribute();
2530 attr_embed_dim->set_name("embed_dim");
2531 attr_embed_dim->set_i(embed_dim);
2532
2533 onnx::AttributeProto* attr_num_heads = node20->add_attribute();
2534 attr_num_heads->set_name("num_heads");
2535 attr_num_heads->set_i(num_heads);
2536
2537 reduced_node_count += 19;
2538 i += 19;
2539 }
2540 }
2541
2542 for (int i = 0; i < node_count; i++)
2543 {
2544 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2545
2546 // MultiHeadAttention <= MatMul(qkv) - Add - Split
2547 // - Mul
2548 // - Reshape - Transpose
2549 // - Reshape - Reshape - Transpose - Transpose
2550 // - Gemm - Softmax - Gemm - Transpose - Reshape - MatMul - Add
2551 if (node->op_type() == "MatMul")
2552 {
2553 if (i + 16 >= node_count)
2554 continue;
2555
2556 if (node_reference[node->output(0)] != 1)
2557 continue;
2558
2559 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2560 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
2561 onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
2562 onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
2563 onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
2564 onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
2565 onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
2566 onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
2567 onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9);
2568 onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10);
2569 onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11);
2570 onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12);
2571 onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13);
2572 onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14);
2573 onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15);
2574 onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16);
2575
2576 if (node2->op_type() != "Add" || node3->op_type() != "Split" || node4->op_type() != "Mul" || node5->op_type() != "Reshape" || node6->op_type() != "Transpose" || node7->op_type() != "Reshape" || node8->op_type() != "Reshape" || node9->op_type() != "Transpose" || node10->op_type() != "Transpose" || node11->op_type() != "MatMul" || node12->op_type() != "Softmax" || node13->op_type() != "MatMul" || node14->op_type() != "Transpose" || node15->op_type() != "Reshape" || node16->op_type() != "MatMul" || node17->op_type() != "Add")
2577 continue;
2578
2579 if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 || node_reference[node3->output(1)] != 1 || node_reference[node3->output(2)] != 1 || node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 || node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 || node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 || node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 || node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 || node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 || node_reference[node16->output(0)] != 1)
2580 continue;
2581
2582 if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0) || node6->input(0) != node5->output(0) || node7->input(0) != node3->output(1) || node8->input(0) != node3->output(2) || node9->input(0) != node8->output(0) || node10->input(0) != node7->output(0) || node11->input(0) != node6->output(0) || node11->input(1) != node10->output(0) || node12->input(0) != node11->output(0) || node13->input(0) != node12->output(0) || node13->input(1) != node9->output(0) || node14->input(0) != node13->output(0) || node15->input(0) != node14->output(0) || node16->input(0) != node15->output(0) || node17->input(0) != node16->output(0))
2583 continue;
2584
2585 std::vector<float> qkv_B = get_node_attr_from_input_af(weights[node2->input(1)]);
2586 std::vector<float> o_B = get_node_attr_from_input_af(weights[node17->input(1)]);
2587
2588 if (qkv_B.size() != o_B.size() * 3)
2589 continue;
2590
2591 int embed_dim = o_B.size();
2592
2593 // 1 0 2
2594 std::vector<int> perm6 = get_node_attr_ai(*node6, "perm");
2595 std::vector<int> perm9 = get_node_attr_ai(*node9, "perm");
2596 if (perm6.size() != 3 || perm9.size() != 3)
2597 continue;
2598
2599 if (perm6[0] != 1 || perm6[1] != 0 || perm6[2] != 2 || perm9[0] != 1 || perm9[1] != 0 || perm9[2] != 2)
2600 continue;
2601
2602 // 1 2 0
2603 std::vector<int> perm10 = get_node_attr_ai(*node10, "perm");
2604 if (perm10.size() != 3)
2605 continue;
2606
2607 if (perm10[0] != 1 || perm10[1] != 2 || perm10[2] != 0)
2608 continue;
2609
2610 // 1 0 2
2611 std::vector<int> perm14 = get_node_attr_ai(*node14, "perm");
2612 if (perm14.size() != 3)
2613 continue;
2614
2615 if (perm14[0] != 1 || perm14[1] != 0 || perm14[2] != 2)
2616 continue;
2617
2618 int softmax_axis = get_node_attr_i(*node12, "axis");
2619 if (softmax_axis != 2)
2620 continue;
2621
2622 // 1/-1, seqlen * num_heads, embed_dim / num_heads
2623 std::vector<int> shape5;
2624 std::vector<int> shape7;
2625 std::vector<int> shape8;
2626 if (node5->input_size() == 1)
2627 {
2628 shape5 = get_node_attr_ai(*node5, "shape");
2629 }
2630 else
2631 {
2632 // skip weight reshape
2633 if (weights.find(node5->input(1)) == weights.end())
2634 continue;
2635
2636 shape5 = get_node_attr_from_input_ai(weights[node5->input(1)]);
2637 }
2638 if (node7->input_size() == 1)
2639 {
2640 shape7 = get_node_attr_ai(*node7, "shape");
2641 }
2642 else
2643 {
2644 // skip weight reshape
2645 if (weights.find(node7->input(1)) == weights.end())
2646 continue;
2647
2648 shape7 = get_node_attr_from_input_ai(weights[node7->input(1)]);
2649 }
2650 if (node8->input_size() == 1)
2651 {
2652 shape8 = get_node_attr_ai(*node8, "shape");
2653 }
2654 else
2655 {
2656 // skip weight reshape
2657 if (weights.find(node8->input(1)) == weights.end())
2658 continue;
2659
2660 shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]);
2661 }
2662
2663 if (shape5.size() != 3 || shape7.size() != 3 || shape8.size() != 3)
2664 continue;
2665
2666 if (shape5[1] != shape7[1] || shape5[1] != shape8[1] || shape5[2] != shape7[2] || shape5[2] != shape8[2])
2667 continue;
2668
2669 int num_heads = embed_dim / shape5[2];
2670
2671 // 1, seqlen, embed_dim
2672 std::vector<int> shape15;
2673 if (node15->input_size() == 1)
2674 {
2675 shape15 = get_node_attr_ai(*node15, "shape");
2676 }
2677 else
2678 {
2679 // skip weight reshape
2680 if (weights.find(node15->input(1)) == weights.end())
2681 continue;
2682
2683 shape15 = get_node_attr_from_input_ai(weights[node15->input(1)]);
2684 }
2685
2686 if (shape15.size() != 3)
2687 continue;
2688
2689 if (shape15[2] != embed_dim || shape15[1] * num_heads != shape8[1])
2690 continue;
2691
2692 // reduce
2693 node->set_op_type("noop_reducedncnn");
2694 node2->set_op_type("noop_reducedncnn");
2695 node3->set_op_type("noop_reducedncnn");
2696 node4->set_op_type("noop_reducedncnn");
2697 node5->set_op_type("noop_reducedncnn");
2698 node6->set_op_type("noop_reducedncnn");
2699 node7->set_op_type("noop_reducedncnn");
2700 node8->set_op_type("noop_reducedncnn");
2701 node9->set_op_type("noop_reducedncnn");
2702 node10->set_op_type("noop_reducedncnn");
2703 node11->set_op_type("noop_reducedncnn");
2704 node12->set_op_type("noop_reducedncnn");
2705 node13->set_op_type("noop_reducedncnn");
2706 node14->set_op_type("noop_reducedncnn");
2707 node15->set_op_type("noop_reducedncnn");
2708 node16->set_op_type("noop_reducedncnn");
2709
2710 node_reference[node2->input(0)] -= 1;
2711 node_reference[node3->input(0)] -= 1;
2712 node_reference[node4->input(0)] -= 1;
2713 node_reference[node4->input(1)] -= 1;
2714 node_reference[node5->input(0)] -= 1;
2715 if (node5->input_size() == 2)
2716 {
2717 node_reference[node5->input(1)] -= 1;
2718 }
2719 node_reference[node6->input(0)] -= 1;
2720 node_reference[node7->input(0)] -= 1;
2721 if (node7->input_size() == 2)
2722 {
2723 node_reference[node7->input(1)] -= 1;
2724 }
2725 node_reference[node8->input(0)] -= 1;
2726 if (node8->input_size() == 2)
2727 {
2728 node_reference[node8->input(1)] -= 1;
2729 }
2730 node_reference[node9->input(0)] -= 1;
2731 node_reference[node10->input(0)] -= 1;
2732 node_reference[node11->input(0)] -= 1;
2733 node_reference[node11->input(1)] -= 1;
2734 node_reference[node12->input(0)] -= 1;
2735 node_reference[node13->input(0)] -= 1;
2736 node_reference[node13->input(1)] -= 1;
2737 node_reference[node14->input(0)] -= 1;
2738 node_reference[node15->input(0)] -= 1;
2739 if (node15->input_size() == 2)
2740 {
2741 node_reference[node15->input(1)] -= 1;
2742 }
2743 node_reference[node16->input(0)] -= 1;
2744 node_reference[node17->input(0)] -= 1;
2745
2746 blob_names.erase(node->output(0));
2747 blob_names.erase(node2->output(0));
2748 blob_names.erase(node3->output(0));
2749 blob_names.erase(node3->output(1));
2750 blob_names.erase(node3->output(2));
2751 blob_names.erase(node4->output(0));
2752 blob_names.erase(node5->output(0));
2753 blob_names.erase(node6->output(0));
2754 blob_names.erase(node7->output(0));
2755 blob_names.erase(node8->output(0));
2756 blob_names.erase(node9->output(0));
2757 blob_names.erase(node10->output(0));
2758 blob_names.erase(node11->output(0));
2759 blob_names.erase(node12->output(0));
2760 blob_names.erase(node13->output(0));
2761 blob_names.erase(node14->output(0));
2762 blob_names.erase(node15->output(0));
2763 blob_names.erase(node16->output(0));
2764
2765 std::string qkvw = node->input(1);
2766 std::string qkvb = node2->input(1);
2767 std::string ow = node16->input(1);
2768 std::string ob = node17->input(1);
2769
2770 node17->set_op_type("MultiHeadAttention");
2771 node17->clear_input();
2772 node17->add_input(node->input(0));
2773 // qkv
2774 node17->add_input(qkvw);
2775 node17->add_input(qkvb);
2776 // out linear
2777 node17->add_input(ow);
2778 node17->add_input(ob);
2779
2780 onnx::AttributeProto* attr_embed_dim = node17->add_attribute();
2781 attr_embed_dim->set_name("embed_dim");
2782 attr_embed_dim->set_i(embed_dim);
2783
2784 onnx::AttributeProto* attr_num_heads = node17->add_attribute();
2785 attr_num_heads->set_name("num_heads");
2786 attr_num_heads->set_i(num_heads);
2787
2788 reduced_node_count += 16;
2789 i += 16;
2790 }
2791 }
2792 }
2793
fuse_binaryop_with_scalar(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)2794 static void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
2795 {
2796 int node_count = mutable_graph->node_size();
2797 for (int i = 0; i < node_count; i++)
2798 {
2799 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2800
2801 // Add/Sub/Mul/Div/Min/Max/Pow
2802 if (node->op_type() == "Add" || node->op_type() == "Sub" || node->op_type() == "Mul" || node->op_type() == "Div" || node->op_type() == "Max" || node->op_type() == "Min" || node->op_type() == "Pow")
2803 {
2804 if (weights.find(node->input(1)) == weights.end())
2805 continue;
2806
2807 const onnx::TensorProto& scalar_b = weights[node->input(1)];
2808 if (scalar_b.dims_size() != 0 || get_tensor_proto_data_size(scalar_b) != 1)
2809 continue;
2810
2811 float b = get_node_attr_from_input_f(scalar_b);
2812
2813 node_reference[node->input(1)] -= 1;
2814
2815 std::string input = node->input(0);
2816
2817 node->clear_input();
2818 node->add_input(input);
2819
2820 onnx::AttributeProto* attr_with_scalar = node->add_attribute();
2821 attr_with_scalar->set_name("with_scalar");
2822 attr_with_scalar->set_i(1);
2823
2824 onnx::AttributeProto* attr_b = node->add_attribute();
2825 attr_b->set_name("b");
2826 attr_b->set_f(b);
2827 }
2828 }
2829 }
2830
main(int argc,char ** argv)2831 int main(int argc, char** argv)
2832 {
2833 if (!(argc == 2 || argc == 4))
2834 {
2835 fprintf(stderr, "Usage: %s [onnxpb] [ncnnparam] [ncnnbin]\n", argv[0]);
2836 return -1;
2837 }
2838
2839 const char* onnxpb = argv[1];
2840 const char* ncnn_prototxt = argc == 4 ? argv[2] : "ncnn.param";
2841 const char* ncnn_modelbin = argc == 4 ? argv[3] : "ncnn.bin";
2842
2843 onnx::ModelProto model;
2844
2845 // load
2846 bool s1 = read_proto_from_binary(onnxpb, &model);
2847 if (!s1)
2848 {
2849 fprintf(stderr, "read_proto_from_binary failed\n");
2850 return -1;
2851 }
2852
2853 FILE* pp = fopen(ncnn_prototxt, "wb");
2854 FILE* bp = fopen(ncnn_modelbin, "wb");
2855
2856 // magic
2857 fprintf(pp, "7767517\n");
2858
2859 const onnx::GraphProto& graph = model.graph();
2860 onnx::GraphProto* mutable_graph = model.mutable_graph();
2861
2862 int node_count = graph.node_size();
2863
2864 // node reference
2865 std::map<std::string, int> node_reference;
2866
2867 // weight node and weight reshape node
2868 std::map<std::string, onnx::TensorProto> weights;
2869
2870 for (int j = 0; j < graph.initializer_size(); j++)
2871 {
2872 const onnx::TensorProto& initializer = graph.initializer(j);
2873
2874 // fprintf(stderr, "weight = %s %d\n", initializer.name().c_str(), initializer.data_type());
2875
2876 weights[initializer.name()] = initializer;
2877 }
2878
2879 // topological sort
2880 {
2881 // name -> producer node index
2882 std::set<std::string> producers;
2883 for (int j = 0; j < graph.input_size(); j++)
2884 {
2885 const std::string& input_name = graph.input(j).name();
2886 producers.insert(input_name);
2887 }
2888
2889 for (int i = 0; i < node_count;)
2890 {
2891 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2892
2893 bool swapnode = false;
2894 std::string missing_input_name;
2895 for (int j = 0; j < (int)node->input_size(); j++)
2896 {
2897 const std::string& input_name = node->input(j);
2898 if (input_name.empty())
2899 continue;
2900
2901 if (producers.find(input_name) == producers.end() && weights.find(input_name) == weights.end())
2902 {
2903 swapnode = true;
2904 missing_input_name = input_name;
2905 break;
2906 }
2907 }
2908
2909 if (!swapnode)
2910 {
2911 for (int j = 0; j < (int)node->output_size(); j++)
2912 {
2913 const std::string& output_name = node->output(j);
2914 if (output_name.empty())
2915 continue;
2916
2917 producers.insert(output_name);
2918 }
2919
2920 i++;
2921 continue;
2922 }
2923
2924 // find node that produce missing_input_name
2925 int q = i + 1;
2926 for (; q < node_count; q++)
2927 {
2928 onnx::NodeProto* nodeq = mutable_graph->mutable_node(q);
2929 bool found = false;
2930 for (int j = 0; j < (int)nodeq->output_size(); j++)
2931 {
2932 const std::string& output_name = nodeq->output(j);
2933 if (output_name == missing_input_name)
2934 {
2935 found = true;
2936 break;
2937 }
2938 }
2939
2940 if (found)
2941 break;
2942 }
2943
2944 if (q == node_count)
2945 {
2946 fprintf(stderr, "cannot find node produces %s but node %d requires it\n", missing_input_name.c_str(), i);
2947 return -1;
2948 }
2949
2950 // fprintf(stderr, "swap %d %d\n", i, q);
2951 // swap this node with q
2952 onnx::NodeProto* nodeq = mutable_graph->mutable_node(q);
2953 onnx::NodeProto tmp = *node;
2954 *node = *nodeq;
2955 *nodeq = tmp;
2956 }
2957 }
2958
2959 // global definition line
2960 // [layer count] [blob count]
2961 std::set<std::string> blob_names;
2962 for (int i = 0; i < node_count; i++)
2963 {
2964 const onnx::NodeProto& node = graph.node(i);
2965
2966 const std::string& op = node.op_type();
2967
2968 std::string name = node.name();
2969 if (name.empty())
2970 {
2971 name = node.output(0);
2972 }
2973
2974 if (op == "Constant")
2975 {
2976 onnx::TensorProto tensor = get_node_attr_tensor(node, "value");
2977 weights[node.output(0)] = tensor;
2978 }
2979
2980 for (int j = 0; j < (int)node.input_size(); j++)
2981 {
2982 const std::string& input_name = node.input(j);
2983
2984 blob_names.insert(input_name);
2985
2986 if (node_reference.find(input_name) == node_reference.end())
2987 {
2988 node_reference[input_name] = 1;
2989 }
2990 else
2991 {
2992 node_reference[input_name] = node_reference[input_name] + 1;
2993 }
2994 }
2995
2996 if (op == "Dropout")
2997 {
2998 const std::string& output_name = node.output(0);
2999 blob_names.insert(output_name);
3000 node_reference[output_name] = 0;
3001 continue;
3002 }
3003
3004 for (int j = 0; j < (int)node.output_size(); j++)
3005 {
3006 const std::string& output_name = node.output(j);
3007
3008 blob_names.insert(output_name);
3009
3010 node_reference[output_name] = 0;
3011 }
3012 }
3013
3014 // include Input node
3015 int input_node_count = 0;
3016 for (int j = 0; j < graph.input_size(); j++)
3017 {
3018 const std::string& input_name = graph.input(j).name();
3019
3020 // check weight
3021 if (weights.find(input_name) != weights.end())
3022 continue;
3023
3024 blob_names.insert(input_name);
3025
3026 input_node_count++;
3027 }
3028
3029 // for (auto a: node_reference)
3030 // {
3031 // fprintf(stderr, "a = %s %d\n", a.first.c_str(), a.second);
3032 // }
3033
3034 // op chain fusion
3035 int reduced_node_count = 0;
3036 fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3037 fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3038 fuse_shufflechannel(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3039 fuse_shufflechannel_split(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3040 fuse_hardsigmoid(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3041 fuse_hardswish(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3042 fuse_swish(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3043 fuse_batchnorm1d_squeeze_unsqueeze(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3044 fuse_unsqueeze_prelu(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3045 fuse_normalize(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3046 fuse_groupnorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3047 fuse_layernorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3048 fuse_flatten(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3049 fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3050 fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3051 fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3052 fuse_lstm_gru_rnn(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3053 fuse_multiheadattention(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3054 fuse_binaryop_with_scalar(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3055
3056 // reduce common const weight node_reference
3057 for (int i = 0; i < node_count; i++)
3058 {
3059 const onnx::NodeProto& node = graph.node(i);
3060
3061 const std::string& op = node.op_type();
3062
3063 if (op == "BatchNormalization")
3064 {
3065 node_reference[node.input(1)] -= 1;
3066 node_reference[node.input(2)] -= 1;
3067 node_reference[node.input(3)] -= 1;
3068 node_reference[node.input(4)] -= 1;
3069 }
3070 else if (op == "BiasGelu")
3071 {
3072 node_reference[node.input(1)] -= 1;
3073 }
3074 else if (op == "Clip")
3075 {
3076 if (node.input_size() == 3)
3077 {
3078 node_reference[node.input(1)] -= 1;
3079 node_reference[node.input(2)] -= 1;
3080 }
3081 }
3082 else if (op == "Conv")
3083 {
3084 node_reference[node.input(1)] -= 1;
3085 if (node.input_size() == 3)
3086 {
3087 node_reference[node.input(2)] -= 1;
3088 }
3089 }
3090 else if (op == "ConvTranspose")
3091 {
3092 node_reference[node.input(1)] -= 1;
3093 if (node.input_size() == 3)
3094 {
3095 node_reference[node.input(2)] -= 1;
3096 }
3097 }
3098 else if (op == "EmbedLayerNormalization")
3099 {
3100 node_reference[node.input(1)] -= 1;
3101 node_reference[node.input(2)] -= 1;
3102 node_reference[node.input(3)] -= 1;
3103 node_reference[node.input(4)] -= 1;
3104 node_reference[node.input(5)] -= 1;
3105 node_reference[node.input(6)] -= 1;
3106 }
3107 else if (op == "Gemm")
3108 {
3109 float alpha = get_node_attr_f(node, "alpha", 1.f);
3110 float beta = get_node_attr_f(node, "beta", 1.f);
3111 int transA = get_node_attr_i(node, "transA", 0);
3112 int transB = get_node_attr_i(node, "transB", 0);
3113
3114 if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
3115 {
3116 // InnerProduct-like A * B + C
3117 node_reference[node.input(1)] -= 1;
3118 node_reference[node.input(2)] -= 1;
3119 }
3120 }
3121 else if (op == "GroupNorm")
3122 {
3123 int affine = get_node_attr_i(node, "affine", 1);
3124 if (affine)
3125 {
3126 node_reference[node.input(1)] -= 1;
3127 node_reference[node.input(2)] -= 1;
3128 }
3129 }
3130 else if (op == "GRU")
3131 {
3132 for (int j = 1; j < node.input_size(); j++)
3133 {
3134 node_reference[node.input(j)] -= 1;
3135 }
3136 }
3137 else if (op == "InstanceNormalization")
3138 {
3139 node_reference[node.input(1)] -= 1;
3140 node_reference[node.input(2)] -= 1;
3141 }
3142 else if (op == "LayerNorm")
3143 {
3144 int affine = get_node_attr_i(node, "affine", 1);
3145 if (affine)
3146 {
3147 node_reference[node.input(1)] -= 1;
3148 node_reference[node.input(2)] -= 1;
3149 }
3150 }
3151 else if (op == "LSTM")
3152 {
3153 for (int j = 1; j < node.input_size(); j++)
3154 {
3155 node_reference[node.input(j)] -= 1;
3156 }
3157 }
3158 else if (op == "MatMul")
3159 {
3160 if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
3161 {
3162 // InnerProduct
3163 node_reference[node.input(1)] -= 1;
3164 }
3165 }
3166 else if (op == "MultiHeadAttention")
3167 {
3168 if (node.input_size() == 5)
3169 {
3170 node_reference[node.input(1)] -= 1;
3171 node_reference[node.input(2)] -= 1;
3172 node_reference[node.input(3)] -= 1;
3173 node_reference[node.input(4)] -= 1;
3174 }
3175 else
3176 {
3177 node_reference[node.input(3)] -= 1;
3178 node_reference[node.input(4)] -= 1;
3179 node_reference[node.input(5)] -= 1;
3180 node_reference[node.input(6)] -= 1;
3181 node_reference[node.input(7)] -= 1;
3182 node_reference[node.input(8)] -= 1;
3183 node_reference[node.input(9)] -= 1;
3184 node_reference[node.input(10)] -= 1;
3185 }
3186 }
3187 else if (op == "Pad")
3188 {
3189 if (node.input_size() >= 2)
3190 {
3191 node_reference[node.input(1)] -= 1;
3192 }
3193 }
3194 else if (op == "PRelu")
3195 {
3196 node_reference[node.input(1)] -= 1;
3197 }
3198 else if (op == "Reshape")
3199 {
3200 if (node.input_size() >= 2)
3201 {
3202 node_reference[node.input(1)] -= 1;
3203 }
3204 }
3205 else if (op == "Resize")
3206 {
3207 if (node.input_size() == 2)
3208 {
3209 // opset 10
3210 node_reference[node.input(1)] -= 1;
3211 }
3212 else
3213 {
3214 // opset 11+
3215 node_reference[node.input(1)] -= 1;
3216 node_reference[node.input(2)] -= 1;
3217 if (node.input_size() >= 4)
3218 {
3219 node_reference[node.input(3)] -= 1;
3220 }
3221 }
3222 }
3223 else if (op == "RNN")
3224 {
3225 for (int j = 1; j < node.input_size(); j++)
3226 {
3227 node_reference[node.input(j)] -= 1;
3228 }
3229 }
3230 else if (op == "SkipLayerNormalization")
3231 {
3232 node_reference[node.input(2)] -= 1;
3233 node_reference[node.input(3)] -= 1;
3234 node_reference[node.input(4)] -= 1;
3235 }
3236 else if (op == "Slice")
3237 {
3238 if (node.input_size() >= 2)
3239 {
3240 node_reference[node.input(1)] -= 1;
3241 node_reference[node.input(2)] -= 1;
3242 if (node.input_size() >= 4)
3243 node_reference[node.input(3)] -= 1;
3244 if (node.input_size() >= 5)
3245 node_reference[node.input(4)] -= 1;
3246 }
3247 }
3248 else if (op == "Upsample")
3249 {
3250 if (node.input_size() >= 2)
3251 {
3252 node_reference[node.input(1)] -= 1;
3253 }
3254 }
3255 else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
3256 {
3257 if (node.input_size() >= 2)
3258 {
3259 node_reference[node.input(1)] -= 1;
3260 }
3261 }
3262 }
3263
3264 // for (auto a: node_reference)
3265 // {
3266 // fprintf(stderr, "b = %s %d\n", a.first.c_str(), a.second);
3267 // }
3268
3269 // count all weight node with zero reference
3270 int zero_reference_weight_node_count = 0;
3271 for (std::map<std::string, onnx::TensorProto>::iterator it = weights.begin(); it != weights.end(); it++)
3272 {
3273 const std::string& input_name = it->first;
3274
3275 int refcount = node_reference[input_name];
3276 if (refcount == 0)
3277 zero_reference_weight_node_count++;
3278 }
3279
3280 // we always treat constant node as weight or binaryop_weights
3281 // do not count it twice for layer_count
3282 int constant_node_count_moved_to_weight = 0;
3283 for (int i = 0; i < node_count; i++)
3284 {
3285 const onnx::NodeProto& node = graph.node(i);
3286
3287 const std::string& op = node.op_type();
3288
3289 if (op == "Constant")
3290 {
3291 constant_node_count_moved_to_weight++;
3292 }
3293 }
3294
3295 // some op may have anonymous input
3296 // LSTM sequence_lens
3297 blob_names.erase("");
3298 node_reference.erase("");
3299
3300 // remove node_reference entry with reference equals to one
3301 int split_layer_count = 0;
3302 int splitncnn_blob_count = 0;
3303 // split node reference
3304 std::map<std::string, int> split_node_reference;
3305 for (std::map<std::string, int>::iterator it = node_reference.begin(); it != node_reference.end(); it++)
3306 {
3307 if (it->second > 1)
3308 {
3309 split_layer_count++;
3310 splitncnn_blob_count += it->second;
3311
3312 split_node_reference[it->first] = it->second;
3313 }
3314 }
3315
3316 fprintf(pp, "%zu %zu\n", node_count - constant_node_count_moved_to_weight + weights.size() - zero_reference_weight_node_count - reduced_node_count + input_node_count + split_layer_count, blob_names.size() - zero_reference_weight_node_count + splitncnn_blob_count);
3317
3318 int internal_split = 0;
3319
3320 // place Input at the beginning
3321 for (int j = 0; j < graph.input_size(); j++)
3322 {
3323 const std::string& input_name = graph.input(j).name();
3324
3325 // check weight
3326 if (weights.find(input_name) != weights.end())
3327 continue;
3328
3329 fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", input_name.c_str(), input_name.c_str());
3330
3331 int refcount = node_reference[input_name];
3332 if (refcount <= 1)
3333 {
3334 continue;
3335 }
3336
3337 char splitname[256];
3338 sprintf(splitname, "splitncnn_input%d", j);
3339 fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
3340 fprintf(pp, " %s", input_name.c_str());
3341
3342 for (int k = 0; k < refcount; k++)
3343 {
3344 fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
3345 }
3346 fprintf(pp, "\n");
3347 }
3348
3349 // place MemoryData next
3350 for (std::map<std::string, onnx::TensorProto>::iterator weight_it = weights.begin(); weight_it != weights.end(); weight_it++)
3351 {
3352 const std::string& input_name = weight_it->first;
3353
3354 int refcount = node_reference[input_name];
3355 if (refcount == 0)
3356 {
3357 continue;
3358 }
3359
3360 fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", input_name.c_str(), input_name.c_str());
3361
3362 const onnx::TensorProto& M = weights[input_name];
3363
3364 if (M.dims_size() == 0)
3365 {
3366 fprintf(pp, " 0=%d", get_tensor_proto_data_size(M));
3367 }
3368 else if (M.dims_size() == 1)
3369 {
3370 fprintf(pp, " 0=%d", (int)M.dims(0));
3371 }
3372 else if (M.dims_size() == 2)
3373 {
3374 fprintf(pp, " 0=%d", (int)M.dims(1));
3375 if (M.dims(0) != 1)
3376 {
3377 fprintf(pp, " 1=%d", (int)M.dims(0));
3378 }
3379 }
3380 else if (M.dims_size() == 3)
3381 {
3382 fprintf(pp, " 0=%d", (int)M.dims(2));
3383 fprintf(pp, " 1=%d", (int)M.dims(1));
3384 if (M.dims(0) != 1)
3385 {
3386 fprintf(pp, " 2=%d", (int)M.dims(0));
3387 }
3388 }
3389 else if (M.dims_size() == 4)
3390 {
3391 fprintf(pp, " 0=%d", (int)M.dims(3));
3392 fprintf(pp, " 1=%d", (int)M.dims(2));
3393 fprintf(pp, " 2=%d", (int)M.dims(1));
3394 }
3395
3396 fprintf(pp, "\n");
3397
3398 fwrite_tensor_proto_data(M, bp);
3399
3400 if (refcount <= 1)
3401 {
3402 continue;
3403 }
3404
3405 char splitname[256];
3406 sprintf(splitname, "splitncnn_%d", internal_split);
3407 fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
3408
3409 fprintf(pp, " %s", input_name.c_str());
3410
3411 for (int k = 0; k < refcount; k++)
3412 {
3413 fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
3414 }
3415 fprintf(pp, "\n");
3416
3417 internal_split++;
3418 }
3419
3420 for (int i = 0; i < node_count; i++)
3421 {
3422 const onnx::NodeProto& node = graph.node(i);
3423
3424 const std::string& op = node.op_type();
3425
3426 // fprintf(stderr, "op = %s\n", op.c_str());
3427
3428 if (op == "noop_reducedncnn")
3429 {
3430 continue;
3431 }
3432
3433 std::string name = node.name();
3434 if (name.empty())
3435 {
3436 name = node.output(0);
3437 }
3438
3439 int input_size = node.input_size();
3440 int output_size = node.output_size();
3441
3442 for (int j = 0; j < (int)node.input_size(); j++)
3443 {
3444 const std::string& input_name = node.input(j);
3445
3446 // check weight
3447 if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0)
3448 {
3449 input_size--;
3450 }
3451
3452 if (input_name.empty())
3453 {
3454 input_size--;
3455 }
3456
3457 // fprintf(stderr, " input = %s\n", input_name.c_str());
3458 }
3459 /*
3460 for (int j=0; j<(int)node.output_size(); j++)
3461 {
3462 const std::string& output_name = node.output(j);
3463 fprintf(stderr, " output = %s\n", output_name.c_str());
3464 }
3465 */
3466
3467 if (op == "Abs")
3468 {
3469 fprintf(pp, "%-16s", "UnaryOp");
3470 }
3471 else if (op == "Acos")
3472 {
3473 fprintf(pp, "%-16s", "UnaryOp");
3474 }
3475 else if (op == "Add")
3476 {
3477 fprintf(pp, "%-16s", "BinaryOp");
3478 }
3479 else if (op == "Asin")
3480 {
3481 fprintf(pp, "%-16s", "UnaryOp");
3482 }
3483 else if (op == "Atan")
3484 {
3485 fprintf(pp, "%-16s", "UnaryOp");
3486 }
3487 else if (op == "AveragePool" || op == "MaxPool")
3488 {
3489 fprintf(pp, "%-16s", "Pooling");
3490 }
3491 else if (op == "BatchNormalization")
3492 {
3493 fprintf(pp, "%-16s", "BatchNorm");
3494 }
3495 else if (op == "BiasGelu")
3496 {
3497 fprintf(pp, "%-16s", "BiasGelu");
3498 }
3499 else if (op == "Ceil")
3500 {
3501 fprintf(pp, "%-16s", "UnaryOp");
3502 }
3503 else if (op == "Clip")
3504 {
3505 fprintf(pp, "%-16s", "Clip");
3506 }
3507 else if (op == "Concat")
3508 {
3509 fprintf(pp, "%-16s", "Concat");
3510 }
3511 else if (op == "Constant")
3512 {
3513 continue;
3514 }
3515 else if (op == "Conv")
3516 {
3517 int group = get_node_attr_i(node, "group", 1);
3518 if (group > 1)
3519 {
3520 fprintf(pp, "%-16s", "ConvolutionDepthWise");
3521 }
3522 else
3523 {
3524 fprintf(pp, "%-16s", "Convolution");
3525 }
3526 }
3527 else if (op == "ConvTranspose")
3528 {
3529 int group = get_node_attr_i(node, "group", 1);
3530 if (group > 1)
3531 {
3532 fprintf(pp, "%-16s", "DeconvolutionDepthWise");
3533 }
3534 else
3535 {
3536 fprintf(pp, "%-16s", "Deconvolution");
3537 }
3538 }
3539 else if (op == "Cos")
3540 {
3541 fprintf(pp, "%-16s", "UnaryOp");
3542 }
3543 else if (op == "DepthToSpace")
3544 {
3545 fprintf(pp, "%-16s", "PixelShuffle");
3546 }
3547 else if (op == "Div")
3548 {
3549 fprintf(pp, "%-16s", "BinaryOp");
3550 }
3551 else if (op == "Dropout")
3552 {
3553 fprintf(pp, "%-16s", "Dropout");
3554 output_size = 1;
3555 }
3556 else if (op == "Elu")
3557 {
3558 fprintf(pp, "%-16s", "ELU");
3559 }
3560 else if (op == "EmbedLayerNormalization")
3561 {
3562 fprintf(pp, "%-16s", "EmbedLayerNormalization");
3563 }
3564 else if (op == "Exp")
3565 {
3566 fprintf(pp, "%-16s", "UnaryOp");
3567 }
3568 else if (op == "Flatten")
3569 {
3570 fprintf(pp, "%-16s", "Flatten");
3571 }
3572 else if (op == "Floor")
3573 {
3574 fprintf(pp, "%-16s", "UnaryOp");
3575 }
3576 else if (op == "Gemm")
3577 {
3578 float alpha = get_node_attr_f(node, "alpha", 1.f);
3579 float beta = get_node_attr_f(node, "beta", 1.f);
3580 int transA = get_node_attr_i(node, "transA", 0);
3581 int transB = get_node_attr_i(node, "transB", 0);
3582
3583 if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
3584 {
3585 // InnerProduct-like A * B + C
3586 fprintf(pp, "%-16s", "InnerProduct");
3587 }
3588 else
3589 {
3590 fprintf(pp, "%-16s", "Gemm");
3591 }
3592 }
3593 else if (op == "GlobalAveragePool")
3594 {
3595 fprintf(pp, "%-16s", "Pooling");
3596 }
3597 else if (op == "GlobalMaxPool")
3598 {
3599 fprintf(pp, "%-16s", "Pooling");
3600 }
3601 else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
3602 {
3603 fprintf(pp, "%-16s", "Pooling");
3604 }
3605 else if (op == "GroupNorm")
3606 {
3607 fprintf(pp, "%-16s", "GroupNorm");
3608 }
3609 else if (op == "GRU")
3610 {
3611 fprintf(pp, "%-16s", "GRU");
3612 }
3613 else if (op == "HardSigmoid")
3614 {
3615 fprintf(pp, "%-16s", "HardSigmoid");
3616 }
3617 else if (op == "HardSwish")
3618 {
3619 fprintf(pp, "%-16s", "HardSwish");
3620 }
3621 else if (op == "ImageScaler")
3622 {
3623 fprintf(pp, "%-16s", "Scale");
3624 }
3625 else if (op == "InstanceNormalization")
3626 {
3627 fprintf(pp, "%-16s", "InstanceNorm");
3628 }
3629 else if (op == "LayerNorm")
3630 {
3631 fprintf(pp, "%-16s", "LayerNorm");
3632 }
3633 else if (op == "LeakyRelu")
3634 {
3635 fprintf(pp, "%-16s", "ReLU");
3636 }
3637 else if (op == "Log")
3638 {
3639 fprintf(pp, "%-16s", "UnaryOp");
3640 }
3641 else if (op == "LRN")
3642 {
3643 fprintf(pp, "%-16s", "LRN");
3644 }
3645 else if (op == "LSTM")
3646 {
3647 fprintf(pp, "%-16s", "LSTM");
3648 }
3649 else if (op == "MatMul")
3650 {
3651 if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
3652 {
3653 fprintf(pp, "%-16s", "InnerProduct");
3654 }
3655 else
3656 {
3657 fprintf(pp, "%-16s", "Gemm");
3658 }
3659 }
3660 else if (op == "Max")
3661 {
3662 fprintf(pp, "%-16s", "BinaryOp");
3663 }
3664 else if (op == "Min")
3665 {
3666 fprintf(pp, "%-16s", "BinaryOp");
3667 }
3668 else if (op == "Mul")
3669 {
3670 fprintf(pp, "%-16s", "BinaryOp");
3671 }
3672 else if (op == "MultiHeadAttention")
3673 {
3674 fprintf(pp, "%-16s", "MultiHeadAttention");
3675 }
3676 else if (op == "Neg")
3677 {
3678 fprintf(pp, "%-16s", "UnaryOp");
3679 }
3680 else if (op == "Normalize")
3681 {
3682 fprintf(pp, "%-16s", "Normalize");
3683 }
3684 else if (op == "Pad")
3685 {
3686 fprintf(pp, "%-16s", "Padding");
3687 }
3688 else if (op == "PixelShuffle")
3689 {
3690 fprintf(pp, "%-16s", "PixelShuffle");
3691 }
3692 else if (op == "Pow")
3693 {
3694 fprintf(pp, "%-16s", "BinaryOp");
3695 }
3696 else if (op == "PRelu")
3697 {
3698 fprintf(pp, "%-16s", "PReLU");
3699 }
3700 else if (op == "Reciprocal")
3701 {
3702 fprintf(pp, "%-16s", "UnaryOp");
3703 }
3704 else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp")
3705 {
3706 fprintf(pp, "%-16s", "Reduction");
3707 }
3708 else if (op == "Relu")
3709 {
3710 fprintf(pp, "%-16s", "ReLU");
3711 }
3712 else if (op == "Reorg")
3713 {
3714 fprintf(pp, "%-16s", "Reorg");
3715 }
3716 else if (op == "Reshape")
3717 {
3718 fprintf(pp, "%-16s", "Reshape");
3719 }
3720 else if (op == "RNN")
3721 {
3722 fprintf(pp, "%-16s", "RNN");
3723 }
3724 else if (op == "ShuffleChannel")
3725 {
3726 fprintf(pp, "%-16s", "ShuffleChannel");
3727 }
3728 else if (op == "Sigmoid")
3729 {
3730 fprintf(pp, "%-16s", "Sigmoid");
3731 }
3732 else if (op == "Sin")
3733 {
3734 fprintf(pp, "%-16s", "UnaryOp");
3735 }
3736 else if (op == "SkipLayerNormalization")
3737 {
3738 fprintf(pp, "%-16s", "SkipLayerNormalization");
3739 }
3740 else if (op == "Slice")
3741 {
3742 fprintf(pp, "%-16s", "Crop");
3743 }
3744 else if (op == "Softmax")
3745 {
3746 fprintf(pp, "%-16s", "Softmax");
3747 }
3748 else if (op == "Softplus")
3749 {
3750 fprintf(pp, "%-16s", "Softplus");
3751 }
3752 else if (op == "Split")
3753 {
3754 fprintf(pp, "%-16s", "Slice");
3755 }
3756 else if (op == "Sqrt")
3757 {
3758 fprintf(pp, "%-16s", "UnaryOp");
3759 }
3760 else if (op == "Squeeze")
3761 {
3762 fprintf(pp, "%-16s", "Squeeze");
3763 }
3764 else if (op == "Sub")
3765 {
3766 fprintf(pp, "%-16s", "BinaryOp");
3767 }
3768 else if (op == "Sum")
3769 {
3770 fprintf(pp, "%-16s", "Eltwise");
3771 }
3772 else if (op == "Swish")
3773 {
3774 fprintf(pp, "%-16s", "Swish");
3775 }
3776 else if (op == "Tan")
3777 {
3778 fprintf(pp, "%-16s", "UnaryOp");
3779 }
3780 else if (op == "Tanh")
3781 {
3782 fprintf(pp, "%-16s", "UnaryOp");
3783 }
3784 else if (op == "Transpose")
3785 {
3786 fprintf(pp, "%-16s", "Permute");
3787 }
3788 else if (op == "Upsample" || op == "Resize")
3789 {
3790 fprintf(pp, "%-16s", "Interp");
3791 }
3792 else if (op == "Unsqueeze")
3793 {
3794 fprintf(pp, "%-16s", "ExpandDims");
3795 }
3796 else
3797 {
3798 // TODO
3799 fprintf(stderr, "%s not supported yet!\n", op.c_str());
3800 fprintf(pp, "%-16s", op.c_str());
3801 }
3802
3803 fprintf(pp, " %-24s %d %d", name.c_str(), input_size, output_size);
3804
3805 for (int j = 0; j < (int)node.input_size(); j++)
3806 {
3807 std::string input_name = node.input(j);
3808
3809 // check weight
3810 if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0)
3811 {
3812 continue;
3813 }
3814
3815 if (input_name.empty())
3816 {
3817 continue;
3818 }
3819
3820 if (split_node_reference.find(input_name) != split_node_reference.end())
3821 {
3822 int refidx = split_node_reference[input_name] - 1;
3823 split_node_reference[input_name] = refidx;
3824
3825 char splitsuffix[256];
3826 sprintf(splitsuffix, "_splitncnn_%d", refidx);
3827 input_name = input_name + splitsuffix;
3828 }
3829
3830 fprintf(pp, " %s", input_name.c_str());
3831 }
3832
3833 for (int j = 0; j < output_size; j++)
3834 {
3835 const std::string& output_name = node.output(j);
3836
3837 fprintf(pp, " %s", output_name.c_str());
3838 }
3839
3840 if (op == "Abs")
3841 {
3842 int op_type = 0;
3843 fprintf(pp, " 0=%d", op_type);
3844 }
3845 else if (op == "Acos")
3846 {
3847 int op_type = 13;
3848 fprintf(pp, " 0=%d", op_type);
3849 }
3850 else if (op == "Add")
3851 {
3852 int op_type = 0;
3853 fprintf(pp, " 0=%d", op_type);
3854
3855 int with_scalar = get_node_attr_i(node, "with_scalar", 0);
3856 float b = get_node_attr_f(node, "b", 0.f);
3857 if (with_scalar)
3858 {
3859 fprintf(pp, " 1=%d", with_scalar);
3860 fprintf(pp, " 2=%e", b);
3861 }
3862 }
3863 else if (op == "Asin")
3864 {
3865 int op_type = 12;
3866 fprintf(pp, " 0=%d", op_type);
3867 }
3868 else if (op == "Atan")
3869 {
3870 int op_type = 14;
3871 fprintf(pp, " 0=%d", op_type);
3872 }
3873 else if (op == "AveragePool" || op == "MaxPool")
3874 {
3875 std::string auto_pad = get_node_attr_s(node, "auto_pad");
3876 int ceil_mode = get_node_attr_i(node, "ceil_mode", 0);
3877 std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
3878 std::vector<int> strides = get_node_attr_ai(node, "strides");
3879 std::vector<int> pads = get_node_attr_ai(node, "pads");
3880
3881 int pool = op == "AveragePool" ? 1 : 0;
3882 int pad_mode = 1;
3883
3884 if (auto_pad == "SAME_UPPER")
3885 {
3886 pad_mode = 2;
3887 }
3888 else if (auto_pad == "SAME_LOWER")
3889 {
3890 pad_mode = 3;
3891 }
3892
3893 if (ceil_mode == 1)
3894 {
3895 pad_mode = 0;
3896 }
3897
3898 fprintf(pp, " 0=%d", pool);
3899
3900 if (kernel_shape.size() == 1)
3901 {
3902 fprintf(pp, " 1=%d", kernel_shape[0]);
3903 }
3904 else if (kernel_shape.size() == 2)
3905 {
3906 fprintf(pp, " 1=%d", kernel_shape[1]);
3907 fprintf(pp, " 11=%d", kernel_shape[0]);
3908 }
3909
3910 if (strides.size() == 1)
3911 {
3912 fprintf(pp, " 2=%d", strides[0]);
3913 }
3914 else if (strides.size() == 2)
3915 {
3916 fprintf(pp, " 2=%d", strides[1]);
3917 fprintf(pp, " 12=%d", strides[0]);
3918 }
3919
3920 if (pads.size() == 1)
3921 {
3922 fprintf(pp, " 3=%d", pads[0]);
3923 }
3924 else if (pads.size() == 2)
3925 {
3926 fprintf(pp, " 3=%d", pads[1]);
3927 fprintf(pp, " 13=%d", pads[0]);
3928 }
3929 else if (pads.size() == 4)
3930 {
3931 fprintf(pp, " 3=%d", pads[1]);
3932 fprintf(pp, " 13=%d", pads[0]);
3933 fprintf(pp, " 14=%d", pads[3]);
3934 fprintf(pp, " 15=%d", pads[2]);
3935 }
3936
3937 fprintf(pp, " 5=%d", pad_mode);
3938
3939 if (op == "AveragePool")
3940 {
3941 int avgpool_count_include_pad = get_node_attr_i(node, "count_include_pad", 0);
3942 fprintf(pp, " 6=%d", avgpool_count_include_pad);
3943 }
3944 }
3945 else if (op == "BatchNormalization")
3946 {
3947 float epsilon = get_node_attr_f(node, "epsilon", 1e-5f);
3948
3949 const onnx::TensorProto& scale = weights[node.input(1)];
3950 const onnx::TensorProto& B = weights[node.input(2)];
3951 const onnx::TensorProto& mean = weights[node.input(3)];
3952 const onnx::TensorProto& var = weights[node.input(4)];
3953
3954 int channels = get_tensor_proto_data_size(scale);
3955
3956 fprintf(pp, " 0=%d", channels);
3957
3958 fwrite_tensor_proto_data(scale, bp);
3959 fwrite_tensor_proto_data(mean, bp);
3960 // apply epsilon to var
3961 {
3962 const float* v = var.has_raw_data() ? (const float*)var.raw_data().data() : var.float_data().data();
3963
3964 for (int j = 0; j < channels; j++)
3965 {
3966 float ve = v[j] + epsilon;
3967 fwrite(&ve, sizeof(float), 1, bp);
3968 }
3969 }
3970 fwrite_tensor_proto_data(B, bp);
3971 }
3972 else if (op == "BiasGelu")
3973 {
3974 const onnx::TensorProto& B = weights[node.input(1)];
3975
3976 fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
3977
3978 int quantize_tag = 0;
3979 fwrite(&quantize_tag, sizeof(int), 1, bp);
3980
3981 fwrite_tensor_proto_data(B, bp);
3982 }
3983 else if (op == "Ceil")
3984 {
3985 int op_type = 3;
3986 fprintf(pp, " 0=%d", op_type);
3987 }
3988 else if (op == "Clip")
3989 {
3990 float min;
3991 float max;
3992 if (node.input_size() == 1)
3993 {
3994 min = get_node_attr_f(node, "min", -FLT_MAX);
3995 max = get_node_attr_f(node, "max", FLT_MAX);
3996 }
3997 else
3998 {
3999 min = weights.find(node.input(1)) != weights.end() ? get_node_attr_from_input_f(weights[node.input(1)]) : -FLT_MAX;
4000 max = weights.find(node.input(2)) != weights.end() ? get_node_attr_from_input_f(weights[node.input(2)]) : FLT_MAX;
4001 }
4002
4003 fprintf(pp, " 0=%e", min);
4004 fprintf(pp, " 1=%e", max);
4005 }
4006 else if (op == "Concat")
4007 {
4008 int axis = get_node_attr_i(node, "axis", 1);
4009 fprintf(pp, " 0=%d", axis - 1);
4010 }
4011 else if (op == "Constant")
4012 {
4013 // never reach here
4014 }
4015 else if (op == "Conv")
4016 {
4017 const onnx::TensorProto& W = weights[node.input(1)];
4018
4019 int num_filter = W.dims(0);
4020 int has_bias = node.input_size() == 3 ? 1 : 0;
4021
4022 std::string auto_pad = get_node_attr_s(node, "auto_pad");
4023 std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
4024 std::vector<int> dilations = get_node_attr_ai(node, "dilations");
4025 std::vector<int> strides = get_node_attr_ai(node, "strides");
4026 std::vector<int> pads = get_node_attr_ai(node, "pads");
4027 int group = get_node_attr_i(node, "group", 1);
4028
4029 fprintf(pp, " 0=%d", num_filter);
4030
4031 if (kernel_shape.size() == 1)
4032 {
4033 fprintf(pp, " 1=%d", kernel_shape[0]);
4034 }
4035 else if (kernel_shape.size() == 2)
4036 {
4037 fprintf(pp, " 1=%d", kernel_shape[1]);
4038 fprintf(pp, " 11=%d", kernel_shape[0]);
4039 }
4040
4041 if (dilations.size() == 1)
4042 {
4043 fprintf(pp, " 2=%d", dilations[0]);
4044 }
4045 else if (dilations.size() == 2)
4046 {
4047 fprintf(pp, " 2=%d", dilations[1]);
4048 fprintf(pp, " 12=%d", dilations[0]);
4049 }
4050
4051 if (strides.size() == 1)
4052 {
4053 fprintf(pp, " 3=%d", strides[0]);
4054 }
4055 else if (strides.size() == 2)
4056 {
4057 fprintf(pp, " 3=%d", strides[1]);
4058 fprintf(pp, " 13=%d", strides[0]);
4059 }
4060
4061 if (auto_pad == "SAME_UPPER")
4062 {
4063 fprintf(pp, " 4=-233");
4064 }
4065 else if (auto_pad == "SAME_LOWER")
4066 {
4067 fprintf(pp, " 4=-234");
4068 }
4069 else
4070 {
4071 if (pads.size() == 1)
4072 {
4073 fprintf(pp, " 4=%d", pads[0]);
4074 }
4075 else if (pads.size() == 2)
4076 {
4077 fprintf(pp, " 4=%d", pads[1]);
4078 fprintf(pp, " 14=%d", pads[0]);
4079 }
4080 else if (pads.size() == 4)
4081 {
4082 fprintf(pp, " 4=%d", pads[1]);
4083 fprintf(pp, " 14=%d", pads[0]);
4084 fprintf(pp, " 15=%d", pads[3]);
4085 fprintf(pp, " 16=%d", pads[2]);
4086 }
4087 }
4088
4089 fprintf(pp, " 5=%d", has_bias);
4090
4091 fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));
4092
4093 if (group > 1)
4094 {
4095 fprintf(pp, " 7=%d", group);
4096 }
4097
4098 int quantize_tag = 0;
4099 fwrite(&quantize_tag, sizeof(int), 1, bp);
4100
4101 fwrite_tensor_proto_data(W, bp);
4102
4103 if (has_bias)
4104 {
4105 const onnx::TensorProto& B = weights[node.input(2)];
4106 fwrite_tensor_proto_data(B, bp);
4107 }
4108 }
4109 else if (op == "ConvTranspose")
4110 {
4111 const onnx::TensorProto& W = weights[node.input(1)];
4112
4113 int has_bias = node.input_size() == 3 ? 1 : 0;
4114
4115 std::string auto_pad = get_node_attr_s(node, "auto_pad");
4116 std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
4117 std::vector<int> dilations = get_node_attr_ai(node, "dilations");
4118 std::vector<int> strides = get_node_attr_ai(node, "strides");
4119 std::vector<int> output_padding = get_node_attr_ai(node, "output_padding");
4120 std::vector<int> output_shape = get_node_attr_ai(node, "output_shape");
4121 std::vector<int> pads = get_node_attr_ai(node, "pads");
4122 int group = get_node_attr_i(node, "group", 1);
4123 int num_filter = W.dims(1) * group;
4124
4125 fprintf(pp, " 0=%d", num_filter);
4126
4127 if (kernel_shape.size() == 1)
4128 {
4129 fprintf(pp, " 1=%d", kernel_shape[0]);
4130 }
4131 else if (kernel_shape.size() == 2)
4132 {
4133 fprintf(pp, " 1=%d", kernel_shape[1]);
4134 fprintf(pp, " 11=%d", kernel_shape[0]);
4135 }
4136
4137 if (dilations.size() == 1)
4138 {
4139 fprintf(pp, " 2=%d", dilations[0]);
4140 }
4141 else if (dilations.size() == 2)
4142 {
4143 fprintf(pp, " 2=%d", dilations[1]);
4144 fprintf(pp, " 12=%d", dilations[0]);
4145 }
4146
4147 if (strides.size() == 1)
4148 {
4149 fprintf(pp, " 3=%d", strides[0]);
4150 }
4151 else if (strides.size() == 2)
4152 {
4153 fprintf(pp, " 3=%d", strides[1]);
4154 fprintf(pp, " 13=%d", strides[0]);
4155 }
4156
4157 if (auto_pad == "SAME_UPPER")
4158 {
4159 fprintf(pp, " 4=-233");
4160 }
4161 else if (auto_pad == "SAME_LOWER")
4162 {
4163 fprintf(pp, " 4=-234");
4164 }
4165 else
4166 {
4167 if (pads.size() == 1)
4168 {
4169 fprintf(pp, " 4=%d", pads[0]);
4170 }
4171 else if (pads.size() == 2)
4172 {
4173 fprintf(pp, " 4=%d", pads[1]);
4174 fprintf(pp, " 14=%d", pads[0]);
4175 }
4176 else if (pads.size() == 4)
4177 {
4178 fprintf(pp, " 4=%d", pads[1]);
4179 fprintf(pp, " 14=%d", pads[0]);
4180 fprintf(pp, " 15=%d", pads[3]);
4181 fprintf(pp, " 16=%d", pads[2]);
4182 }
4183 }
4184
4185 if (output_padding.size() == 1)
4186 {
4187 fprintf(pp, " 18=%d", output_padding[0]);
4188 }
4189 else if (output_padding.size() == 2)
4190 {
4191 fprintf(pp, " 18=%d", output_padding[1]);
4192 fprintf(pp, " 19=%d", output_padding[0]);
4193 }
4194
4195 if (output_shape.size() == 1)
4196 {
4197 fprintf(pp, " 20=%d", output_shape[0]);
4198 }
4199 else if (output_shape.size() == 2)
4200 {
4201 fprintf(pp, " 20=%d", output_shape[1]);
4202 fprintf(pp, " 21=%d", output_shape[0]);
4203 }
4204
4205 fprintf(pp, " 5=%d", has_bias);
4206
4207 fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));
4208
4209 if (group > 1)
4210 {
4211 fprintf(pp, " 7=%d", group);
4212 }
4213
4214 int quantize_tag = 0;
4215 fwrite(&quantize_tag, sizeof(int), 1, bp);
4216
4217 int maxk = 0;
4218 if (kernel_shape.size() == 2)
4219 {
4220 maxk = kernel_shape[1] * kernel_shape[0];
4221 }
4222 else
4223 {
4224 maxk = kernel_shape[0] * kernel_shape[0];
4225 }
4226 int weight_data_size = get_tensor_proto_data_size(W);
4227 const float* weight_data = 0;
4228 if (W.has_raw_data())
4229 {
4230 weight_data = (const float*)W.raw_data().data();
4231 }
4232 else if (W.data_type() == 1)
4233 {
4234 weight_data = W.float_data().data();
4235 }
4236 for (int g = 0; g < group; g++)
4237 {
4238 // reorder weight from inch-outch to outch-inch
4239 int num_filter_g = num_filter / group;
4240 int num_input = weight_data_size / maxk / num_filter_g / group;
4241 const float* weight_data_ptr = weight_data + g * maxk * num_filter_g * num_input;
4242 for (int k = 0; k < num_filter_g; k++)
4243 {
4244 for (int j = 0; j < num_input; j++)
4245 {
4246 fwrite(weight_data_ptr + (j * num_filter_g + k) * maxk, sizeof(float), maxk, bp);
4247 }
4248 }
4249 }
4250
4251 if (has_bias)
4252 {
4253 const onnx::TensorProto& B = weights[node.input(2)];
4254 fwrite_tensor_proto_data(B, bp);
4255 }
4256 }
4257 else if (op == "Cos")
4258 {
4259 int op_type = 10;
4260 fprintf(pp, " 0=%d", op_type);
4261 }
4262 else if (op == "DepthToSpace")
4263 {
4264 // pixelshuffle
4265 int scale_factor = get_node_attr_i(node, "blocksize", 1);
4266 std::string mode = get_node_attr_s(node, "mode");
4267 fprintf(pp, " 0=%d", scale_factor);
4268 if (mode == "CRD")
4269 {
4270 fprintf(pp, " 1=0");
4271 }
4272 else if (mode == "DCR")
4273 {
4274 fprintf(pp, " 1=1");
4275 }
4276 }
4277 else if (op == "Div")
4278 {
4279 int op_type = 3;
4280 fprintf(pp, " 0=%d", op_type);
4281
4282 int with_scalar = get_node_attr_i(node, "with_scalar", 0);
4283 float b = get_node_attr_f(node, "b", 0.f);
4284 if (with_scalar)
4285 {
4286 fprintf(pp, " 1=%d", with_scalar);
4287 fprintf(pp, " 2=%e", b);
4288 }
4289 }
4290 else if (op == "Dropout")
4291 {
4292 // no-op
4293 }
4294 else if (op == "Elu")
4295 {
4296 float alpha = get_node_attr_f(node, "alpha", 1.f);
4297 fprintf(pp, " 0=%e", alpha);
4298 }
4299 else if (op == "EmbedLayerNormalization")
4300 {
4301 const onnx::TensorProto& words = weights[node.input(2)];
4302 const onnx::TensorProto& positions = weights[node.input(3)];
4303 const onnx::TensorProto& W = weights[node.input(5)];
4304 const onnx::TensorProto& B = weights[node.input(6)];
4305
4306 fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
4307 fprintf(pp, " 1=%d", get_tensor_proto_data_size(words));
4308 fprintf(pp, " 2=%d", get_tensor_proto_data_size(positions));
4309
4310 int quantize_tag = 0;
4311 fwrite(&quantize_tag, sizeof(int), 1, bp);
4312
4313 fwrite_tensor_proto_data(words, bp);
4314
4315 fwrite(&quantize_tag, sizeof(int), 1, bp);
4316
4317 fwrite_tensor_proto_data(positions, bp);
4318
4319 fwrite(&quantize_tag, sizeof(int), 1, bp);
4320
4321 fwrite_tensor_proto_data(W, bp);
4322
4323 fwrite(&quantize_tag, sizeof(int), 1, bp);
4324
4325 fwrite_tensor_proto_data(B, bp);
4326 }
4327 else if (op == "Exp")
4328 {
4329 int op_type = 7;
4330 fprintf(pp, " 0=%d", op_type);
4331 }
4332 else if (op == "Flatten")
4333 {
4334 int axis = get_node_attr_i(node, "axis", 1);
4335 if (axis != 1)
4336 {
4337 fprintf(stderr, "Unsupported Flatten axis %d!\n", axis);
4338 }
4339 }
4340 else if (op == "Floor")
4341 {
4342 int op_type = 2;
4343 fprintf(pp, " 0=%d", op_type);
4344 }
4345 else if (op == "Gemm")
4346 {
4347 float alpha = get_node_attr_f(node, "alpha", 1.f);
4348 float beta = get_node_attr_f(node, "beta", 1.f);
4349 int transA = get_node_attr_i(node, "transA", 0);
4350 int transB = get_node_attr_i(node, "transB", 0);
4351
4352 if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
4353 {
4354 // InnerProduct-like A * B + C
4355 const onnx::TensorProto& B = weights[node.input(1)];
4356 const onnx::TensorProto& C = weights[node.input(2)];
4357
4358 fprintf(pp, " 0=%d", get_tensor_proto_data_size(C));
4359 fprintf(pp, " 1=1");
4360 fprintf(pp, " 2=%d", get_tensor_proto_data_size(B));
4361
4362 int quantize_tag = 0;
4363 fwrite(&quantize_tag, sizeof(int), 1, bp);
4364
4365 fwrite_tensor_proto_data(B, bp);
4366 fwrite_tensor_proto_data(C, bp);
4367 }
4368 else
4369 {
4370 // gemm
4371 fprintf(pp, " 0=%e", alpha);
4372 fprintf(pp, " 1=%e", beta);
4373 fprintf(pp, " 2=%d", transA);
4374 fprintf(pp, " 3=%d", transB);
4375 }
4376 }
4377 else if (op == "GlobalAveragePool")
4378 {
4379 int pool = 1;
4380 int global_pool = 1;
4381
4382 fprintf(pp, " 0=%d", pool);
4383 fprintf(pp, " 4=%d", global_pool);
4384 }
4385 else if (op == "GlobalMaxPool")
4386 {
4387 int pool = 0;
4388 int global_pool = 1;
4389
4390 fprintf(pp, " 0=%d", pool);
4391 fprintf(pp, " 4=%d", global_pool);
4392 }
4393 else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
4394 {
4395 int pool = 0;
4396 if (op == "adaptive_avg_pool2d")
4397 {
4398 pool = 1;
4399 }
4400 int adaptive_pooling = 1;
4401 const onnx::TensorProto& out_shape_tp = weights[node.input(1)];
4402 std::vector<int> out_shape = get_node_attr_from_input_ai(out_shape_tp);
4403
4404 fprintf(pp, " 0=%d", pool);
4405 fprintf(pp, " 7=%d", adaptive_pooling);
4406 if (out_shape.size() == 1)
4407 {
4408 fprintf(pp, " 8=%d", out_shape[0]);
4409 }
4410 else if (out_shape.size() == 2)
4411 {
4412 // out_w
4413 fprintf(pp, " 8=%d", out_shape[1]);
4414 // out_h
4415 fprintf(pp, " 18=%d", out_shape[0]);
4416 }
4417 }
4418 else if (op == "GroupNorm")
4419 {
4420 int groups = get_node_attr_i(node, "groups", 1);
4421 int channels = get_node_attr_i(node, "channels", 1);
4422 float eps = get_node_attr_f(node, "epsilon", 1e-5f);
4423 int affine = get_node_attr_i(node, "affine", 1);
4424
4425 if (affine)
4426 {
4427 // discard affine-less S=1 B=0
4428 std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
4429 std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
4430 if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && affine_B[0] == 0.f)
4431 {
4432 affine = 0;
4433 }
4434 else
4435 {
4436 affine = 0;
4437 {
4438 for (int j = 0; j < channels; j++)
4439 {
4440 if (affine_S[j] != 1.f || affine_B[j] != 0.f)
4441 {
4442 affine = 1;
4443 break;
4444 }
4445 }
4446 }
4447 }
4448 }
4449
4450 fprintf(pp, " 0=%d", groups);
4451 fprintf(pp, " 1=%d", channels);
4452 fprintf(pp, " 2=%e", eps);
4453 fprintf(pp, " 3=%d", affine);
4454 if (affine)
4455 {
4456 const onnx::TensorProto& scale = weights[node.input(1)];
4457 const onnx::TensorProto& B = weights[node.input(2)];
4458
4459 fwrite_tensor_proto_data(scale, bp);
4460 fwrite_tensor_proto_data(B, bp);
4461 }
4462 }
4463 else if (op == "GRU")
4464 {
4465 const onnx::TensorProto& W = weights[node.input(1)];
4466 const onnx::TensorProto& R = weights[node.input(2)];
4467 const onnx::TensorProto& B = weights[node.input(3)];
4468
4469 int hidden_size = get_node_attr_i(node, "hidden_size", 0);
4470 std::string direction = get_node_attr_s(node, "direction");
4471
4472 int direction_type = 0;
4473 if (direction == "forward")
4474 {
4475 direction_type = 0;
4476 }
4477 else if (direction == "reverse")
4478 {
4479 direction_type = 1;
4480 }
4481 else if (direction == "bidirectional")
4482 {
4483 direction_type = 2;
4484 }
4485
4486 int weight_data_size = get_tensor_proto_data_size(W);
4487
4488 fprintf(pp, " 0=%d", hidden_size);
4489 fprintf(pp, " 1=%d", weight_data_size);
4490 fprintf(pp, " 2=%d", direction_type);
4491
4492 int num_directions = direction_type == 2 ? 2 : 1;
4493
4494 int quantize_tag = 0;
4495
4496 // reorder num_directions-URN-hidden-size to num_directions-RUN-hidden-size
4497 {
4498 fwrite(&quantize_tag, sizeof(int), 1, bp);
4499
4500 int weight_data_size_g = get_tensor_proto_data_size(W) / 3 / num_directions;
4501 const float* wptr = W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data();
4502
4503 const float* uptr = wptr;
4504 const float* rptr = wptr + weight_data_size_g;
4505 const float* nptr = wptr + weight_data_size_g * 2;
4506 fwrite(rptr, sizeof(float), weight_data_size_g, bp);
4507 fwrite(uptr, sizeof(float), weight_data_size_g, bp);
4508 fwrite(nptr, sizeof(float), weight_data_size_g, bp);
4509
4510 if (direction_type == 2)
4511 {
4512 uptr += weight_data_size_g * 3;
4513 rptr += weight_data_size_g * 3;
4514 nptr += weight_data_size_g * 3;
4515 fwrite(rptr, sizeof(float), weight_data_size_g, bp);
4516 fwrite(uptr, sizeof(float), weight_data_size_g, bp);
4517 fwrite(nptr, sizeof(float), weight_data_size_g, bp);
4518 }
4519 }
4520
4521 // reduce U and R bias except N
4522 // reorder num_directions-URN-hidden to num_directions-RUN-hidden
4523 {
4524 fwrite(&quantize_tag, sizeof(int), 1, bp);
4525
4526 int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 3 / num_directions;
4527 const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
4528 const float* wuptr = bptr;
4529 const float* wrptr = bptr + bias_data_size_g;
4530 const float* wnptr = bptr + bias_data_size_g * 2;
4531 const float* buptr = bptr + bias_data_size_g * 3;
4532 const float* brptr = bptr + bias_data_size_g * 4;
4533 const float* bnptr = bptr + bias_data_size_g * 5;
4534
4535 for (int j = 0; j < bias_data_size_g; j++)
4536 {
4537 float vb = wrptr[j] + brptr[j];
4538 fwrite(&vb, sizeof(float), 1, bp);
4539 }
4540 for (int j = 0; j < bias_data_size_g; j++)
4541 {
4542 float vb = wuptr[j] + buptr[j];
4543 fwrite(&vb, sizeof(float), 1, bp);
4544 }
4545 fwrite(wnptr, sizeof(float), bias_data_size_g, bp);
4546 fwrite(bnptr, sizeof(float), bias_data_size_g, bp);
4547
4548 if (direction_type == 2)
4549 {
4550 wuptr += bias_data_size_g * 6;
4551 wrptr += bias_data_size_g * 6;
4552 wnptr += bias_data_size_g * 6;
4553 buptr += bias_data_size_g * 6;
4554 brptr += bias_data_size_g * 6;
4555 bnptr += bias_data_size_g * 6;
4556
4557 for (int j = 0; j < bias_data_size_g; j++)
4558 {
4559 float vb = wrptr[j] + brptr[j];
4560 fwrite(&vb, sizeof(float), 1, bp);
4561 }
4562 for (int j = 0; j < bias_data_size_g; j++)
4563 {
4564 float vb = wuptr[j] + buptr[j];
4565 fwrite(&vb, sizeof(float), 1, bp);
4566 }
4567 fwrite(wnptr, sizeof(float), bias_data_size_g, bp);
4568 fwrite(bnptr, sizeof(float), bias_data_size_g, bp);
4569 }
4570 }
4571
4572 // reorder num_directions-URN-hidden-hidden to num_directions-RUN-hidden-hidden
4573 {
4574 fwrite(&quantize_tag, sizeof(int), 1, bp);
4575
4576 int weight_data_size_g = get_tensor_proto_data_size(R) / 3 / num_directions;
4577 const float* Rptr = R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data();
4578
4579 const float* uptr = Rptr;
4580 const float* rptr = Rptr + weight_data_size_g;
4581 const float* nptr = Rptr + weight_data_size_g * 2;
4582 fwrite(rptr, sizeof(float), weight_data_size_g, bp);
4583 fwrite(uptr, sizeof(float), weight_data_size_g, bp);
4584 fwrite(nptr, sizeof(float), weight_data_size_g, bp);
4585
4586 if (direction_type == 2)
4587 {
4588 uptr += weight_data_size_g * 3;
4589 rptr += weight_data_size_g * 3;
4590 nptr += weight_data_size_g * 3;
4591 fwrite(rptr, sizeof(float), weight_data_size_g, bp);
4592 fwrite(uptr, sizeof(float), weight_data_size_g, bp);
4593 fwrite(nptr, sizeof(float), weight_data_size_g, bp);
4594 }
4595 }
4596 }
4597 else if (op == "HardSigmoid")
4598 {
4599 float alpha = get_node_attr_f(node, "alpha", 0.2f);
4600 float beta = get_node_attr_f(node, "beta", 0.5f);
4601
4602 fprintf(pp, " 0=%e", alpha);
4603 fprintf(pp, " 1=%e", beta);
4604 }
4605 else if (op == "HardSwish")
4606 {
4607 float alpha = get_node_attr_f(node, "alpha", 0.2f);
4608 float beta = get_node_attr_f(node, "beta", 0.5f);
4609
4610 fprintf(pp, " 0=%e", alpha);
4611 fprintf(pp, " 1=%e", beta);
4612 }
4613 else if (op == "ImageScaler")
4614 {
4615 std::vector<float> bias = get_node_attr_af(node, "bias");
4616 float scale = get_node_attr_f(node, "scale", 1.f);
4617
4618 int channels = (int)bias.size();
4619
4620 fprintf(pp, " 0=%d", channels);
4621 fprintf(pp, " 1=1");
4622
4623 for (int j = 0; j < channels; j++)
4624 {
4625 fwrite(&scale, sizeof(float), 1, bp);
4626 }
4627 fwrite(&bias[0], sizeof(float), channels, bp);
4628 }
4629 else if (op == "InstanceNormalization")
4630 {
4631 float eps = get_node_attr_f(node, "epsilon", 1e-5f);
4632
4633 // discard affine-less S=1 B=0
4634 std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
4635 std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
4636 int channels = (int)affine_S.size();
4637 int affine = 0;
4638 {
4639 for (int j = 0; j < channels; j++)
4640 {
4641 if (affine_S[j] != 1.f || affine_B[j] != 0.f)
4642 {
4643 affine = 1;
4644 break;
4645 }
4646 }
4647 }
4648
4649 fprintf(pp, " 0=%d", channels);
4650 fprintf(pp, " 1=%e", eps);
4651 fprintf(pp, " 2=%d", affine);
4652 if (affine)
4653 {
4654 const onnx::TensorProto& scale = weights[node.input(1)];
4655 const onnx::TensorProto& B = weights[node.input(2)];
4656
4657 fwrite_tensor_proto_data(scale, bp);
4658 fwrite_tensor_proto_data(B, bp);
4659 }
4660 }
4661 else if (op == "LayerNorm")
4662 {
4663 float eps = get_node_attr_f(node, "epsilon", 1e-5f);
4664 int affine = get_node_attr_i(node, "affine", 1);
4665
4666 if (affine)
4667 {
4668 // discard affine-less S=1 B=0
4669 std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
4670 std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
4671 int affine_size = (int)affine_S.size();
4672 affine = 0;
4673 {
4674 for (int j = 0; j < affine_size; j++)
4675 {
4676 if (affine_S[j] != 1.f || affine_B[j] != 0.f)
4677 {
4678 affine = 1;
4679 break;
4680 }
4681 }
4682 }
4683
4684 if (affine)
4685 {
4686 fprintf(pp, " 0=%d", affine_size);
4687 }
4688 }
4689
4690 fprintf(pp, " 1=%e", eps);
4691 fprintf(pp, " 2=%d", affine);
4692
4693 if (affine)
4694 {
4695 const onnx::TensorProto& scale = weights[node.input(1)];
4696 const onnx::TensorProto& B = weights[node.input(2)];
4697
4698 fwrite_tensor_proto_data(scale, bp);
4699 fwrite_tensor_proto_data(B, bp);
4700 }
4701 }
4702 else if (op == "LeakyRelu")
4703 {
4704 float alpha = get_node_attr_f(node, "alpha", 0.01f);
4705
4706 fprintf(pp, " 0=%e", alpha);
4707 }
4708 else if (op == "Log")
4709 {
4710 int op_type = 8;
4711 fprintf(pp, " 0=%d", op_type);
4712 }
4713 else if (op == "LRN")
4714 {
4715 float alpha = get_node_attr_f(node, "alpha", 1.f);
4716 float beta = get_node_attr_f(node, "beta", 0.5f);
4717 float bias = get_node_attr_f(node, "bias", 1.f);
4718 int size = get_node_attr_i(node, "size", 1);
4719
4720 int norm_region = 0;
4721
4722 fprintf(pp, " 0=%d", norm_region);
4723 fprintf(pp, " 1=%d", size);
4724 fprintf(pp, " 2=%e", alpha);
4725 fprintf(pp, " 3=%e", beta);
4726 fprintf(pp, " 4=%e", bias);
4727 }
4728 else if (op == "LSTM")
4729 {
4730 const onnx::TensorProto& W = weights[node.input(1)];
4731 const onnx::TensorProto& R = weights[node.input(2)];
4732 const onnx::TensorProto& B = weights[node.input(3)];
4733
4734 int hidden_size = get_node_attr_i(node, "hidden_size", 0);
4735 std::string direction = get_node_attr_s(node, "direction");
4736
4737 int direction_type = 0;
4738 if (direction == "forward")
4739 {
4740 direction_type = 0;
4741 }
4742 else if (direction == "reverse")
4743 {
4744 direction_type = 1;
4745 }
4746 else if (direction == "bidirectional")
4747 {
4748 direction_type = 2;
4749 }
4750
4751 int weight_data_size = get_tensor_proto_data_size(W);
4752
4753 fprintf(pp, " 0=%d", hidden_size);
4754 fprintf(pp, " 1=%d", weight_data_size);
4755 fprintf(pp, " 2=%d", direction_type);
4756
4757 int num_directions = direction_type == 2 ? 2 : 1;
4758
4759 int quantize_tag = 0;
4760
4761 // reorder num_directions-IOFG-hidden-size to num_directions-IFOG-hidden-size
4762 {
4763 fwrite(&quantize_tag, sizeof(int), 1, bp);
4764
4765 int weight_data_size_g = get_tensor_proto_data_size(W) / 4 / num_directions;
4766 const float* wptr = W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data();
4767
4768 const float* iptr = wptr;
4769 const float* optr = wptr + weight_data_size_g;
4770 const float* fptr = wptr + weight_data_size_g * 2;
4771 const float* gptr = wptr + weight_data_size_g * 3;
4772 fwrite(iptr, sizeof(float), weight_data_size_g, bp);
4773 fwrite(fptr, sizeof(float), weight_data_size_g, bp);
4774 fwrite(optr, sizeof(float), weight_data_size_g, bp);
4775 fwrite(gptr, sizeof(float), weight_data_size_g, bp);
4776
4777 if (direction_type == 2)
4778 {
4779 iptr += weight_data_size_g * 4;
4780 optr += weight_data_size_g * 4;
4781 fptr += weight_data_size_g * 4;
4782 gptr += weight_data_size_g * 4;
4783 fwrite(iptr, sizeof(float), weight_data_size_g, bp);
4784 fwrite(fptr, sizeof(float), weight_data_size_g, bp);
4785 fwrite(optr, sizeof(float), weight_data_size_g, bp);
4786 fwrite(gptr, sizeof(float), weight_data_size_g, bp);
4787 }
4788 }
4789
4790 // reduce xc and hc bias
4791 // reorder num_directions-IOFG-hidden to num_directions-IFOG-hidden
4792 {
4793 fwrite(&quantize_tag, sizeof(int), 1, bp);
4794
4795 int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 4 / num_directions;
4796 const float* xcbptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
4797 const float* xiptr = xcbptr;
4798 const float* xoptr = xcbptr + bias_data_size_g;
4799 const float* xfptr = xcbptr + bias_data_size_g * 2;
4800 const float* xgptr = xcbptr + bias_data_size_g * 3;
4801 const float* hiptr = xcbptr + bias_data_size_g * 4;
4802 const float* hoptr = xcbptr + bias_data_size_g * 5;
4803 const float* hfptr = xcbptr + bias_data_size_g * 6;
4804 const float* hgptr = xcbptr + bias_data_size_g * 7;
4805
4806 for (int j = 0; j < bias_data_size_g; j++)
4807 {
4808 float vb = xiptr[j] + hiptr[j];
4809 fwrite(&vb, sizeof(float), 1, bp);
4810 }
4811 for (int j = 0; j < bias_data_size_g; j++)
4812 {
4813 float vb = xfptr[j] + hfptr[j];
4814 fwrite(&vb, sizeof(float), 1, bp);
4815 }
4816 for (int j = 0; j < bias_data_size_g; j++)
4817 {
4818 float vb = xoptr[j] + hoptr[j];
4819 fwrite(&vb, sizeof(float), 1, bp);
4820 }
4821 for (int j = 0; j < bias_data_size_g; j++)
4822 {
4823 float vb = xgptr[j] + hgptr[j];
4824 fwrite(&vb, sizeof(float), 1, bp);
4825 }
4826
4827 if (direction_type == 2)
4828 {
4829 xiptr += bias_data_size_g * 8;
4830 xoptr += bias_data_size_g * 8;
4831 xfptr += bias_data_size_g * 8;
4832 xgptr += bias_data_size_g * 8;
4833 hiptr += bias_data_size_g * 8;
4834 hoptr += bias_data_size_g * 8;
4835 hfptr += bias_data_size_g * 8;
4836 hgptr += bias_data_size_g * 8;
4837
4838 for (int j = 0; j < bias_data_size_g; j++)
4839 {
4840 float vb = xiptr[j] + hiptr[j];
4841 fwrite(&vb, sizeof(float), 1, bp);
4842 }
4843 for (int j = 0; j < bias_data_size_g; j++)
4844 {
4845 float vb = xfptr[j] + hfptr[j];
4846 fwrite(&vb, sizeof(float), 1, bp);
4847 }
4848 for (int j = 0; j < bias_data_size_g; j++)
4849 {
4850 float vb = xoptr[j] + hoptr[j];
4851 fwrite(&vb, sizeof(float), 1, bp);
4852 }
4853 for (int j = 0; j < bias_data_size_g; j++)
4854 {
4855 float vb = xgptr[j] + hgptr[j];
4856 fwrite(&vb, sizeof(float), 1, bp);
4857 }
4858 }
4859 }
4860
4861 // reorder num_directions-IOFG-hidden-hidden to num_directions-IFOG-hidden-hidden
4862 {
4863 fwrite(&quantize_tag, sizeof(int), 1, bp);
4864
4865 int weight_data_size_g = get_tensor_proto_data_size(R) / 4 / num_directions;
4866 const float* rptr = R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data();
4867
4868 const float* iptr = rptr;
4869 const float* optr = rptr + weight_data_size_g;
4870 const float* fptr = rptr + weight_data_size_g * 2;
4871 const float* gptr = rptr + weight_data_size_g * 3;
4872 fwrite(iptr, sizeof(float), weight_data_size_g, bp);
4873 fwrite(fptr, sizeof(float), weight_data_size_g, bp);
4874 fwrite(optr, sizeof(float), weight_data_size_g, bp);
4875 fwrite(gptr, sizeof(float), weight_data_size_g, bp);
4876
4877 if (direction_type == 2)
4878 {
4879 iptr += weight_data_size_g * 4;
4880 optr += weight_data_size_g * 4;
4881 fptr += weight_data_size_g * 4;
4882 gptr += weight_data_size_g * 4;
4883 fwrite(iptr, sizeof(float), weight_data_size_g, bp);
4884 fwrite(fptr, sizeof(float), weight_data_size_g, bp);
4885 fwrite(optr, sizeof(float), weight_data_size_g, bp);
4886 fwrite(gptr, sizeof(float), weight_data_size_g, bp);
4887 }
4888 }
4889 }
4890 else if (op == "MatMul")
4891 {
4892 if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
4893 {
4894 // InnerProduct
4895 const onnx::TensorProto& B = weights[node.input(1)];
4896
4897 int weight_data_size = get_tensor_proto_data_size(B);
4898
4899 int num_output = B.dims(B.dims_size() - 1);
4900 int num_input = weight_data_size / num_output;
4901
4902 fprintf(pp, " 0=%d", num_output);
4903 fprintf(pp, " 1=0");
4904 fprintf(pp, " 2=%d", weight_data_size);
4905
4906 int quantize_tag = 0;
4907 fwrite(&quantize_tag, sizeof(int), 1, bp);
4908
4909 // reorder num_input-num_output to num_output-num_input
4910 {
4911 const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
4912
4913 for (int j = 0; j < num_output; j++)
4914 {
4915 for (int k = 0; k < num_input; k++)
4916 {
4917 float vb = bptr[k * num_output + j];
4918 fwrite(&vb, sizeof(float), 1, bp);
4919 }
4920 }
4921 }
4922
4923 // fwrite_tensor_proto_data(B, bp)
4924 }
4925 else
4926 {
4927 // default matrix multiplication
4928 }
4929 }
4930 else if (op == "Max")
4931 {
4932 int op_type = 4;
4933 fprintf(pp, " 0=%d", op_type);
4934
4935 int with_scalar = get_node_attr_i(node, "with_scalar", 0);
4936 float b = get_node_attr_f(node, "b", 0.f);
4937 if (with_scalar)
4938 {
4939 fprintf(pp, " 1=%d", with_scalar);
4940 fprintf(pp, " 2=%e", b);
4941 }
4942 }
4943 else if (op == "Min")
4944 {
4945 int op_type = 5;
4946 fprintf(pp, " 0=%d", op_type);
4947
4948 int with_scalar = get_node_attr_i(node, "with_scalar", 0);
4949 float b = get_node_attr_f(node, "b", 0.f);
4950 if (with_scalar)
4951 {
4952 fprintf(pp, " 1=%d", with_scalar);
4953 fprintf(pp, " 2=%e", b);
4954 }
4955 }
4956 else if (op == "Mul")
4957 {
4958 int op_type = 2;
4959 fprintf(pp, " 0=%d", op_type);
4960
4961 int with_scalar = get_node_attr_i(node, "with_scalar", 0);
4962 float b = get_node_attr_f(node, "b", 0.f);
4963 if (with_scalar)
4964 {
4965 fprintf(pp, " 1=%d", with_scalar);
4966 fprintf(pp, " 2=%e", b);
4967 }
4968 }
4969 else if (op == "MultiHeadAttention")
4970 {
4971 int embed_dim = get_node_attr_i(node, "embed_dim", 0);
4972 int num_heads = get_node_attr_i(node, "num_heads", 0);
4973
4974 fprintf(pp, " 0=%d", embed_dim);
4975 fprintf(pp, " 1=%d", num_heads);
4976
4977 if (node.input_size() == 5)
4978 {
4979 const onnx::TensorProto& qkvw = weights[node.input(1)];
4980 const onnx::TensorProto& qkvb = weights[node.input(2)];
4981 const onnx::TensorProto& ow = weights[node.input(3)];
4982 const onnx::TensorProto& ob = weights[node.input(4)];
4983
4984 int weight_data_size = get_tensor_proto_data_size(ow);
4985
4986 fprintf(pp, " 2=%d", weight_data_size);
4987
4988 int quantize_tag = 0;
4989
4990 fwrite(&quantize_tag, sizeof(int), 1, bp);
4991 // transpose qw
4992 {
4993 const float* wptr = qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
4994 const float* bptr = qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();
4995
4996 for (int j = 0; j < embed_dim; j++)
4997 {
4998 for (int k = 0; k < embed_dim; k++)
4999 {
5000 float vb = wptr[k * embed_dim * 3 + j];
5001 fwrite(&vb, sizeof(float), 1, bp);
5002 }
5003 }
5004
5005 fwrite(bptr, sizeof(float), embed_dim, bp);
5006 }
5007
5008 fwrite(&quantize_tag, sizeof(int), 1, bp);
5009 // transpose kw
5010 {
5011 const float* wptr = qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
5012 const float* bptr = qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();
5013 bptr += embed_dim;
5014
5015 for (int j = 0; j < embed_dim; j++)
5016 {
5017 for (int k = 0; k < embed_dim; k++)
5018 {
5019 float vb = wptr[k * embed_dim * 3 + j + embed_dim];
5020 fwrite(&vb, sizeof(float), 1, bp);
5021 }
5022 }
5023
5024 fwrite(bptr, sizeof(float), embed_dim, bp);
5025 }
5026
5027 fwrite(&quantize_tag, sizeof(int), 1, bp);
5028 // transpose vw
5029 {
5030 const float* wptr = qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
5031 const float* bptr = qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();
5032 bptr += embed_dim * 2;
5033
5034 for (int j = 0; j < embed_dim; j++)
5035 {
5036 for (int k = 0; k < embed_dim; k++)
5037 {
5038 float vb = wptr[k * embed_dim * 3 + j + embed_dim * 2];
5039 fwrite(&vb, sizeof(float), 1, bp);
5040 }
5041 }
5042
5043 fwrite(bptr, sizeof(float), embed_dim, bp);
5044 }
5045
5046 fwrite(&quantize_tag, sizeof(int), 1, bp);
5047 // transpose ow
5048 {
5049 const float* wptr = ow.has_raw_data() ? (const float*)ow.raw_data().data() : ow.float_data().data();
5050
5051 for (int j = 0; j < embed_dim; j++)
5052 {
5053 for (int k = 0; k < embed_dim; k++)
5054 {
5055 float vb = wptr[k * embed_dim + j];
5056 fwrite(&vb, sizeof(float), 1, bp);
5057 }
5058 }
5059 }
5060 fwrite_tensor_proto_data(ob, bp);
5061 }
5062 else
5063 {
5064 const onnx::TensorProto& qw = weights[node.input(3)];
5065 const onnx::TensorProto& qb = weights[node.input(4)];
5066 const onnx::TensorProto& kw = weights[node.input(5)];
5067 const onnx::TensorProto& kb = weights[node.input(6)];
5068 const onnx::TensorProto& vw = weights[node.input(7)];
5069 const onnx::TensorProto& vb = weights[node.input(8)];
5070 const onnx::TensorProto& ow = weights[node.input(9)];
5071 const onnx::TensorProto& ob = weights[node.input(10)];
5072
5073 int weight_data_size = get_tensor_proto_data_size(qw);
5074
5075 fprintf(pp, " 2=%d", weight_data_size);
5076
5077 int quantize_tag = 0;
5078
5079 fwrite(&quantize_tag, sizeof(int), 1, bp);
5080 // transpose qw
5081 {
5082 const float* wptr = qw.has_raw_data() ? (const float*)qw.raw_data().data() : qw.float_data().data();
5083
5084 for (int j = 0; j < embed_dim; j++)
5085 {
5086 for (int k = 0; k < embed_dim; k++)
5087 {
5088 float vb = wptr[k * embed_dim + j];
5089 fwrite(&vb, sizeof(float), 1, bp);
5090 }
5091 }
5092 }
5093 fwrite_tensor_proto_data(qb, bp);
5094
5095 fwrite(&quantize_tag, sizeof(int), 1, bp);
5096 // transpose kw
5097 {
5098 const float* wptr = kw.has_raw_data() ? (const float*)kw.raw_data().data() : kw.float_data().data();
5099
5100 for (int j = 0; j < embed_dim; j++)
5101 {
5102 for (int k = 0; k < embed_dim; k++)
5103 {
5104 float vb = wptr[k * embed_dim + j];
5105 fwrite(&vb, sizeof(float), 1, bp);
5106 }
5107 }
5108 }
5109 fwrite_tensor_proto_data(kb, bp);
5110
5111 fwrite(&quantize_tag, sizeof(int), 1, bp);
5112 // transpose vw
5113 {
5114 const float* wptr = vw.has_raw_data() ? (const float*)vw.raw_data().data() : vw.float_data().data();
5115
5116 for (int j = 0; j < embed_dim; j++)
5117 {
5118 for (int k = 0; k < embed_dim; k++)
5119 {
5120 float vb = wptr[k * embed_dim + j];
5121 fwrite(&vb, sizeof(float), 1, bp);
5122 }
5123 }
5124 }
5125 fwrite_tensor_proto_data(vb, bp);
5126
5127 fwrite(&quantize_tag, sizeof(int), 1, bp);
5128 // transpose ow
5129 {
5130 const float* wptr = ow.has_raw_data() ? (const float*)ow.raw_data().data() : ow.float_data().data();
5131
5132 for (int j = 0; j < embed_dim; j++)
5133 {
5134 for (int k = 0; k < embed_dim; k++)
5135 {
5136 float vb = wptr[k * embed_dim + j];
5137 fwrite(&vb, sizeof(float), 1, bp);
5138 }
5139 }
5140 }
5141 fwrite_tensor_proto_data(ob, bp);
5142 }
5143 }
5144 else if (op == "Neg")
5145 {
5146 int op_type = 1;
5147 fprintf(pp, " 0=%d", op_type);
5148 }
5149 else if (op == "Normalize")
5150 {
5151 float eps = get_node_attr_f(node, "eps", 0.f);
5152 int scale_data_size = 1;
5153
5154 fprintf(pp, " 1=1"); // channel_shared
5155 fprintf(pp, " 2=%e", eps);
5156 fprintf(pp, " 3=%d", scale_data_size);
5157 fprintf(pp, " 9=1"); // TODO hardcode pytorch style
5158
5159 const float scale_data[1] = {1.f};
5160 fwrite(scale_data, sizeof(float), 1, bp);
5161 }
5162 else if (op == "Pad")
5163 {
5164 std::string mode = get_node_attr_s(node, "mode");
5165 float value = get_node_attr_f(node, "value", 0.f);
5166
5167 std::vector<int> pads;
5168 if (node.input_size() == 1)
5169 {
5170 pads = get_node_attr_ai(node, "pads");
5171 }
5172 else
5173 {
5174 pads = get_node_attr_from_input_ai(weights[node.input(1)]);
5175 }
5176
5177 int type = 0;
5178 if (mode == "constant")
5179 {
5180 type = 0;
5181 }
5182 else if (mode == "edge")
5183 {
5184 type = 1;
5185 }
5186 else if (mode == "reflect")
5187 {
5188 type = 2;
5189 }
5190
5191 int pad_size = (int)pads.size();
5192 int top = 0;
5193 int bottom = 0;
5194 int left = 0;
5195 int right = 0;
5196 int front = 0;
5197 int behind = 0;
5198 if (pad_size == 8)
5199 {
5200 //NCHW
5201 top = pads[2];
5202 bottom = pads[6];
5203 left = pads[3];
5204 right = pads[7];
5205 front = pads[1];
5206 behind = pads[5];
5207 }
5208 else if (pad_size == 6)
5209 {
5210 //NHW
5211 top = pads[1];
5212 bottom = pads[4];
5213 left = pads[2];
5214 right = pads[5];
5215 }
5216 else
5217 {
5218 //NW
5219 left = pads[1];
5220 right = pads[3];
5221 }
5222
5223 fprintf(pp, " 0=%d", top);
5224 fprintf(pp, " 1=%d", bottom);
5225 fprintf(pp, " 2=%d", left);
5226 fprintf(pp, " 3=%d", right);
5227 fprintf(pp, " 4=%d", type);
5228 fprintf(pp, " 5=%e", value);
5229 fprintf(pp, " 7=%d", front);
5230 fprintf(pp, " 8=%d", behind);
5231 }
5232 else if (op == "Pow")
5233 {
5234 int op_type = 6;
5235 fprintf(pp, " 0=%d", op_type);
5236
5237 int with_scalar = get_node_attr_i(node, "with_scalar", 0);
5238 float b = get_node_attr_f(node, "b", 0.f);
5239 if (with_scalar)
5240 {
5241 fprintf(pp, " 1=%d", with_scalar);
5242 fprintf(pp, " 2=%e", b);
5243 }
5244 }
5245 else if (op == "PixelShuffle")
5246 {
5247 int scale_factor = get_node_attr_i(node, "scale_factor", 1);
5248 fprintf(pp, " 0=%d", scale_factor);
5249 }
5250 else if (op == "PRelu")
5251 {
5252 const onnx::TensorProto& slope = weights[node.input(1)];
5253
5254 int num_slope = get_tensor_proto_data_size(slope);
5255
5256 fprintf(pp, " 0=%d", num_slope);
5257
5258 fwrite_tensor_proto_data(slope, bp);
5259 }
5260 else if (op == "Reciprocal")
5261 {
5262 int op_type = 15;
5263 fprintf(pp, " 0=%d", op_type);
5264 }
5265 else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp")
5266 {
5267 int op_type = -233;
5268 if (op == "ReduceSum")
5269 op_type = 0;
5270 else if (op == "ReduceSumSquare")
5271 op_type = 2;
5272 else if (op == "ReduceMean")
5273 op_type = 3;
5274 else if (op == "ReduceMax")
5275 op_type = 4;
5276 else if (op == "ReduceMin")
5277 op_type = 5;
5278 else if (op == "ReduceProd")
5279 op_type = 6;
5280 else if (op == "ReduceL1")
5281 op_type = 7;
5282 else if (op == "ReduceL2")
5283 op_type = 8;
5284 else if (op == "ReduceLogSum")
5285 op_type = 9;
5286 else if (op == "ReduceLogSumExp")
5287 op_type = 10;
5288 fprintf(pp, " 0=%d", op_type);
5289
5290 std::vector<int> axes = get_node_attr_ai(node, "axes");
5291 int keepdims = get_node_attr_i(node, "keepdims", 1);
5292
5293 if (axes.size() > 0)
5294 {
5295 // if axes set, reduce according to axes
5296 fprintf(pp, " 1=%d", 0);
5297 fprintf(pp, " -23303=%zu", axes.size());
5298 for (size_t j = 0; j < axes.size(); j++)
5299 {
5300 if (axes[j] == 0 || axes[j] > 3 || axes[j] < -3)
5301 fprintf(stderr, "Unsupported reduction axes !\n");
5302 fprintf(pp, ",%d", axes[j]);
5303 }
5304 }
5305 else
5306 {
5307 // if axes not set, reduce all axes by default
5308 fprintf(pp, " 1=%d", 1);
5309 }
5310 fprintf(pp, " 4=%d", keepdims);
5311 }
5312 else if (op == "Reorg")
5313 {
5314 int stride = get_node_attr_i(node, "stride", 1);
5315 fprintf(pp, " 0=%d", stride);
5316 }
5317 else if (op == "Reshape")
5318 {
5319 std::vector<int> shape;
5320
5321 if (node.input_size() == 1)
5322 {
5323 shape = get_node_attr_ai(node, "shape");
5324 }
5325 else
5326 {
5327 shape = get_node_attr_from_input_ai(weights[node.input(1)]);
5328 }
5329
5330 if (shape.size() == 1)
5331 {
5332 fprintf(pp, " 0=%d", shape[0]); // should never reach here
5333 }
5334 else if (shape.size() == 2)
5335 {
5336 fprintf(pp, " 0=%d", shape[1]);
5337 }
5338 else if (shape.size() == 3)
5339 {
5340 fprintf(pp, " 0=%d", shape[2]);
5341 fprintf(pp, " 1=%d", shape[1]);
5342 }
5343 else if (shape.size() == 4)
5344 {
5345 fprintf(pp, " 0=%d", shape[3]);
5346 fprintf(pp, " 1=%d", shape[2]);
5347 fprintf(pp, " 2=%d", shape[1]);
5348 }
5349 else if (shape.size() == 5)
5350 {
5351 fprintf(pp, " 0=%d", shape[4] * shape[3]);
5352 fprintf(pp, " 1=%d", shape[2]);
5353 fprintf(pp, " 2=%d", shape[1]);
5354 }
5355 }
5356 else if (op == "Resize")
5357 {
5358 std::string mode = get_node_attr_s(node, "mode");
5359 std::string align = get_node_attr_s(node, "coordinate_transformation_mode");
5360
5361 std::vector<float> scales;
5362 std::vector<int> sizes;
5363 if (node.input_size() == 2)
5364 {
5365 // opset 10
5366 scales = get_node_attr_from_input_af(weights[node.input(1)]);
5367 }
5368 else
5369 {
5370 // opset 11+
5371 scales = get_node_attr_from_input_af(weights[node.input(2)]);
5372 if (node.input_size() >= 4)
5373 {
5374 sizes = get_node_attr_from_input_ai(weights[node.input(3)]);
5375 }
5376 }
5377
5378 int resize_type = 1;
5379 if (mode == "nearest")
5380 {
5381 resize_type = 1;
5382 }
5383 else if (mode == "linear")
5384 {
5385 resize_type = 2;
5386 }
5387 else if (mode == "cubic")
5388 {
5389 resize_type = 3;
5390 }
5391
5392 if (scales.empty() && sizes.empty())
5393 {
5394 fprintf(stderr, "Unsupported Resize scales and sizes are all empty!\n");
5395 }
5396
5397 float h_scale = 1.f;
5398 float w_scale = 1.f;
5399 if (scales.size() == 2)
5400 {
5401 w_scale = scales[1];
5402 }
5403 else if (scales.size() == 3)
5404 {
5405 h_scale = scales[1];
5406 w_scale = scales[2];
5407 }
5408 else if (scales.size() == 4)
5409 {
5410 h_scale = scales[2];
5411 w_scale = scales[3];
5412
5413 if (scales[1] != 1.f)
5414 fprintf(stderr, "Unsupported Resize scales !\n");
5415 }
5416
5417 int output_height = 0;
5418 int output_width = 0;
5419 if (sizes.size() == 2)
5420 {
5421 output_width = sizes[1];
5422 }
5423 else if (sizes.size() == 3)
5424 {
5425 output_height = sizes[1];
5426 output_width = sizes[2];
5427 }
5428 else if (sizes.size() == 4)
5429 {
5430 output_height = sizes[2];
5431 output_width = sizes[3];
5432 }
5433
5434 int align_corner = 0;
5435 if (align == "align_corners")
5436 {
5437 align_corner = 1;
5438 }
5439
5440 fprintf(pp, " 0=%d", resize_type);
5441 fprintf(pp, " 1=%e", h_scale);
5442 fprintf(pp, " 2=%e", w_scale);
5443 fprintf(pp, " 3=%d", output_height);
5444 fprintf(pp, " 4=%d", output_width);
5445 fprintf(pp, " 6=%d", align_corner);
5446 }
5447 else if (op == "RNN")
5448 {
5449 const onnx::TensorProto& W = weights[node.input(1)];
5450 const onnx::TensorProto& R = weights[node.input(2)];
5451 const onnx::TensorProto& B = weights[node.input(3)];
5452
5453 int hidden_size = get_node_attr_i(node, "hidden_size", 0);
5454 std::string direction = get_node_attr_s(node, "direction");
5455
5456 int direction_type = 0;
5457 if (direction == "forward")
5458 {
5459 direction_type = 0;
5460 }
5461 else if (direction == "reverse")
5462 {
5463 direction_type = 1;
5464 }
5465 else if (direction == "bidirectional")
5466 {
5467 direction_type = 2;
5468 }
5469
5470 int weight_data_size = get_tensor_proto_data_size(W);
5471
5472 fprintf(pp, " 0=%d", hidden_size);
5473 fprintf(pp, " 1=%d", weight_data_size);
5474 fprintf(pp, " 2=%d", direction_type);
5475
5476 int num_directions = direction_type == 2 ? 2 : 1;
5477
5478 int quantize_tag = 0;
5479
5480 fwrite(&quantize_tag, sizeof(int), 1, bp);
5481 fwrite_tensor_proto_data(W, bp);
5482
5483 // reduce xc and hc bias
5484 {
5485 fwrite(&quantize_tag, sizeof(int), 1, bp);
5486
5487 int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / num_directions;
5488 const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
5489 const float* xiptr = bptr;
5490 const float* hiptr = bptr + bias_data_size_g;
5491
5492 for (int j = 0; j < bias_data_size_g; j++)
5493 {
5494 float vb = xiptr[j] + hiptr[j];
5495 fwrite(&vb, sizeof(float), 1, bp);
5496 }
5497
5498 if (direction_type == 2)
5499 {
5500 xiptr += bias_data_size_g * 2;
5501 hiptr += bias_data_size_g * 2;
5502
5503 for (int j = 0; j < bias_data_size_g; j++)
5504 {
5505 float vb = xiptr[j] + hiptr[j];
5506 fwrite(&vb, sizeof(float), 1, bp);
5507 }
5508 }
5509 }
5510
5511 fwrite(&quantize_tag, sizeof(int), 1, bp);
5512 fwrite_tensor_proto_data(R, bp);
5513 }
5514 else if (op == "ShuffleChannel")
5515 {
5516 int group = get_node_attr_i(node, "group", 1);
5517 int reverse = get_node_attr_i(node, "reverse", 0);
5518 fprintf(pp, " 0=%d", group);
5519 fprintf(pp, " 1=%d", reverse);
5520 }
5521 else if (op == "Sigmoid")
5522 {
5523 // no param
5524 }
5525 else if (op == "Sin")
5526 {
5527 int op_type = 9;
5528 fprintf(pp, " 0=%d", op_type);
5529 }
5530 else if (op == "SkipLayerNormalization")
5531 {
5532 const onnx::TensorProto& W = weights[node.input(2)];
5533 const onnx::TensorProto& B = weights[node.input(3)];
5534 const onnx::TensorProto& B2 = weights[node.input(4)];
5535
5536 fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
5537
5538 int quantize_tag = 0;
5539 fwrite(&quantize_tag, sizeof(int), 1, bp);
5540
5541 fwrite_tensor_proto_data(W, bp);
5542
5543 fwrite(&quantize_tag, sizeof(int), 1, bp);
5544
5545 fwrite_tensor_proto_data(B, bp);
5546
5547 fwrite(&quantize_tag, sizeof(int), 1, bp);
5548
5549 fwrite_tensor_proto_data(B2, bp);
5550 }
5551 else if (op == "Slice")
5552 {
5553 std::vector<int> starts;
5554 std::vector<int> ends;
5555 std::vector<int> axes;
5556 std::vector<int> steps;
5557 if (node.input_size() == 1)
5558 {
5559 starts = get_node_attr_ai(node, "starts");
5560 ends = get_node_attr_ai(node, "ends");
5561 axes = get_node_attr_ai(node, "axes");
5562 steps = get_node_attr_ai(node, "steps"); // TODO
5563 }
5564 else
5565 {
5566 starts = get_node_attr_from_input_ai(weights[node.input(1)]);
5567 ends = get_node_attr_from_input_ai(weights[node.input(2)]);
5568 if (node.input_size() >= 4)
5569 axes = get_node_attr_from_input_ai(weights[node.input(3)]);
5570 if (node.input_size() >= 5)
5571 steps = get_node_attr_from_input_ai(weights[node.input(4)]);
5572 }
5573
5574 // assert step == 1
5575 for (int i = 0; i < (int)steps.size(); i++)
5576 {
5577 if (steps[i] != 1)
5578 fprintf(stderr, "Unsupported slice step !\n");
5579 }
5580
5581 // filter out N-dim axis
5582 if (!axes.empty())
5583 {
5584 for (int i = 0; i < (int)axes.size(); i++)
5585 {
5586 int axis = axes[i];
5587 if (axis == 0)
5588 {
5589 starts.erase(starts.begin() + i);
5590 ends.erase(ends.begin() + i);
5591 axes.erase(axes.begin() + i);
5592 break;
5593 }
5594 }
5595 }
5596
5597 fprintf(pp, " -23309=%d", (int)starts.size());
5598 for (int i = 0; i < (int)starts.size(); i++)
5599 {
5600 fprintf(pp, ",%d", starts[i]);
5601 }
5602 fprintf(pp, " -23310=%d", (int)ends.size());
5603 for (int i = 0; i < (int)ends.size(); i++)
5604 {
5605 fprintf(pp, ",%d", ends[i]);
5606 }
5607 if (!axes.empty())
5608 {
5609 fprintf(pp, " -23311=%d", (int)axes.size());
5610 for (int i = 0; i < (int)axes.size(); i++)
5611 {
5612 int axis = axes[i];
5613 if (axis == 0 || axis > 3 || axis < -3)
5614 fprintf(stderr, "Unsupported slice axes !\n");
5615
5616 if (axis > 0)
5617 axis = axis - 1; // -1 for skip N-dim
5618
5619 fprintf(pp, ",%d", axis);
5620 }
5621 }
5622 }
5623 else if (op == "Softmax")
5624 {
5625 int axis = get_node_attr_i(node, "axis", 1);
5626 fprintf(pp, " 0=%d", axis - 1);
5627 fprintf(pp, " 1=1");
5628 }
5629 else if (op == "Split")
5630 {
5631 int axis = get_node_attr_i(node, "axis", 0);
5632 std::vector<int> split = get_node_attr_ai(node, "split");
5633 if (axis < 1)
5634 fprintf(stderr, "Unsupported split axis !\n");
5635
5636 fprintf(pp, " -23300=%d", output_size);
5637 if (split.empty())
5638 {
5639 for (int i = 0; i < output_size; i++)
5640 {
5641 fprintf(pp, ",-233");
5642 }
5643 }
5644 else
5645 {
5646 for (size_t i = 0; i < split.size() - 1; i++)
5647 {
5648 fprintf(pp, ",%d", split[i]);
5649 }
5650 fprintf(pp, ",-233");
5651 }
5652 fprintf(pp, " 1=%d", axis - 1);
5653 }
5654 else if (op == "Sqrt")
5655 {
5656 int op_type = 5;
5657 fprintf(pp, " 0=%d", op_type);
5658 }
5659 else if (op == "Squeeze")
5660 {
5661 std::vector<int> axes = get_node_attr_ai(node, "axes");
5662
5663 if (axes.empty())
5664 {
5665 fprintf(pp, " 0=1");
5666 fprintf(pp, " 1=1");
5667 fprintf(pp, " 2=1");
5668 }
5669 else
5670 {
5671 fprintf(pp, " -23303=%zu", axes.size());
5672 for (int i = 0; i < (int)axes.size(); i++)
5673 {
5674 if (axes[i] == 0 || axes[i] > 3 || axes[i] < -3)
5675 fprintf(stderr, "Unsupported squeeze axes !\n");
5676 fprintf(pp, ",%d", axes[i]);
5677 }
5678 }
5679 }
5680 else if (op == "Sub")
5681 {
5682 int op_type = 1;
5683 fprintf(pp, " 0=%d", op_type);
5684
5685 int with_scalar = get_node_attr_i(node, "with_scalar", 0);
5686 float b = get_node_attr_f(node, "b", 0.f);
5687 if (with_scalar)
5688 {
5689 fprintf(pp, " 1=%d", with_scalar);
5690 fprintf(pp, " 2=%e", b);
5691 }
5692 }
5693 else if (op == "Sum")
5694 {
5695 int op_type = 1;
5696 fprintf(pp, " 0=%d", op_type);
5697 }
5698 else if (op == "Swish")
5699 {
5700 // no param
5701 }
5702 else if (op == "Tan")
5703 {
5704 int op_type = 11;
5705 fprintf(pp, " 0=%d", op_type);
5706 }
5707 else if (op == "Tanh")
5708 {
5709 int op_type = 16;
5710 fprintf(pp, " 0=%d", op_type);
5711 }
5712 else if (op == "Transpose")
5713 {
5714 std::vector<int> perm = get_node_attr_ai(node, "perm");
5715
5716 if (perm.size() == 3)
5717 {
5718 if (perm[1] == 1 && perm[2] == 2)
5719 fprintf(pp, " 0=0"); // w h
5720 else if (perm[1] == 2 && perm[2] == 1)
5721 fprintf(pp, " 0=1"); // h w
5722 else if (perm[0] == 1 && perm[1] == 0 && perm[2] == 2)
5723 fprintf(pp, " 0=0"); // w h
5724 else if (perm[0] == 2 && perm[1] == 0 && perm[2] == 1)
5725 fprintf(pp, " 0=1"); // h w
5726 }
5727 else if (perm.size() == 4)
5728 {
5729 if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3)
5730 fprintf(pp, " 0=0"); // w h c
5731 else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 2)
5732 fprintf(pp, " 0=1"); // h w c
5733 else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3)
5734 fprintf(pp, " 0=2"); // w c h
5735 else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 1)
5736 fprintf(pp, " 0=3"); // c w h
5737 else if (perm[1] == 3 && perm[2] == 1 && perm[3] == 2)
5738 fprintf(pp, " 0=4"); // h c w
5739 else if (perm[1] == 3 && perm[2] == 2 && perm[3] == 1)
5740 fprintf(pp, " 0=5"); // c h w
5741 }
5742 else if (perm.size() == 5)
5743 {
5744 if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3 && perm[4] == 4)
5745 fprintf(pp, " 0=0"); // wx h c
5746 else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 4 && perm[4] == 2)
5747 fprintf(pp, " 0=1"); // h wx c
5748 else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3 && perm[4] == 4)
5749 fprintf(pp, " 0=2"); // wx c h
5750 else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 4 && perm[4] == 1)
5751 fprintf(pp, " 0=3"); // c wx h
5752 else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 1 && perm[4] == 2)
5753 fprintf(pp, " 0=4"); // h c wx
5754 else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 2 && perm[4] == 1)
5755 fprintf(pp, " 0=5"); // c h wx
5756 else
5757 fprintf(stderr, "Unsupported transpose type !\n");
5758 }
5759 }
5760 else if (op == "Upsample")
5761 {
5762 std::string mode = get_node_attr_s(node, "mode");
5763 std::string align = get_node_attr_s(node, "coordinate_transformation_mode");
5764
5765 std::vector<float> scales;
5766
5767 if (node.input_size() == 1)
5768 {
5769 scales = get_node_attr_af(node, "scales");
5770 }
5771 else
5772 {
5773 scales = get_node_attr_from_input_af(weights[node.input(1)]);
5774 }
5775
5776 int resize_type = 1;
5777 if (mode == "nearest")
5778 {
5779 resize_type = 1;
5780 }
5781 else if (mode == "bilinear" || mode == "linear")
5782 {
5783 resize_type = 2;
5784 }
5785 else if (mode == "trilinear")
5786 {
5787 fprintf(stderr, "Unsupported Upsample mode !\n");
5788 }
5789
5790 float h_scale = 1.f;
5791 float w_scale = 1.f;
5792 if (scales.size() == 2)
5793 {
5794 w_scale = scales[1];
5795 }
5796 else if (scales.size() == 3)
5797 {
5798 h_scale = scales[1];
5799 w_scale = scales[2];
5800 }
5801 else if (scales.size() == 4)
5802 {
5803 h_scale = scales[2];
5804 w_scale = scales[3];
5805
5806 if (scales[1] != 1.f)
5807 fprintf(stderr, "Unsupported Upsample scales !\n");
5808 }
5809 else
5810 {
5811 fprintf(stderr, "Unsupported Upsample scales !\n");
5812 }
5813
5814 int align_corner = 0;
5815 if (align == "align_corners")
5816 {
5817 align_corner = 1;
5818 }
5819
5820 fprintf(pp, " 0=%d", resize_type);
5821 fprintf(pp, " 1=%e", h_scale);
5822 fprintf(pp, " 2=%e", w_scale);
5823 fprintf(pp, " 6=%d", align_corner);
5824 }
5825 else if (op == "Unsqueeze")
5826 {
5827 std::vector<int> axes = get_node_attr_ai(node, "axes");
5828
5829 fprintf(pp, " -23303=%zu", axes.size());
5830 for (int i = 0; i < (int)axes.size(); i++)
5831 {
5832 if (axes[i] == 0 || axes[i] > 4 || axes[i] < -4)
5833 fprintf(stderr, "Unsupported unsqueeze axes !\n");
5834 fprintf(pp, ",%d", axes[i]);
5835 }
5836 }
5837 else
5838 {
5839 // TODO op specific param
5840 for (int j = 0; j < node.attribute_size(); j++)
5841 {
5842 const onnx::AttributeProto& attr = node.attribute(j);
5843 if (attr.type() == 1)
5844 {
5845 fprintf(stderr, " # %s=%g\n", attr.name().c_str(), attr.f());
5846 }
5847 else if (attr.type() == 2)
5848 {
5849 fprintf(stderr, " # %s=%lld\n", attr.name().c_str(), (long long)attr.i());
5850 }
5851 else if (attr.type() == 3)
5852 {
5853 fprintf(stderr, " # %s=%s\n", attr.name().c_str(), attr.s().c_str());
5854 }
5855 else
5856 {
5857 fprintf(stderr, " # %s %d\n", attr.name().c_str(), attr.type());
5858 }
5859 }
5860 }
5861
5862 fprintf(pp, "\n");
5863
5864 for (int j = 0; j < output_size; j++)
5865 {
5866 const std::string& output_name = node.output(j);
5867 if (node_reference.find(output_name) != node_reference.end())
5868 {
5869 int refcount = node_reference[output_name];
5870 if (refcount > 1)
5871 {
5872 char splitname[256];
5873 sprintf(splitname, "splitncnn_%d", internal_split);
5874 fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
5875
5876 fprintf(pp, " %s", output_name.c_str());
5877
5878 for (int k = 0; k < refcount; k++)
5879 {
5880 fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k);
5881 }
5882 fprintf(pp, "\n");
5883
5884 internal_split++;
5885 }
5886 }
5887 }
5888 }
5889
5890 fclose(pp);
5891 fclose(bp);
5892
5893 return 0;
5894 }
5895