1 // Tencent is pleased to support the open source community by making ncnn available.
2 //
3 // Copyright (C) 2017 THL A29 Limited, a Tencent company. All rights reserved.
4 //
5 // Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
6 // in compliance with the License. You may obtain a copy of the License at
7 //
8 // https://opensource.org/licenses/BSD-3-Clause
9 //
10 // Unless required by applicable law or agreed to in writing, software distributed
11 // under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
12 // CONDITIONS OF ANY KIND, either express or implied. See the License for the
13 // specific language governing permissions and limitations under the License.
14
15 #include "onnx.pb.h"
16
17 #include <algorithm>
18 #include <float.h>
19 #include <fstream>
20 #include <google/protobuf/io/coded_stream.h>
21 #include <google/protobuf/io/zero_copy_stream_impl.h>
22 #include <google/protobuf/message.h>
23 #include <google/protobuf/text_format.h>
24 #include <iostream>
25 #include <limits.h>
26 #include <limits>
27 #include <set>
28 #include <stdio.h>
29
read_proto_from_binary(const char * filepath,onnx::ModelProto * message)30 static bool read_proto_from_binary(const char* filepath, onnx::ModelProto* message)
31 {
32 std::ifstream fs(filepath, std::ifstream::in | std::ifstream::binary);
33 if (!fs.is_open())
34 {
35 fprintf(stderr, "open failed %s\n", filepath);
36 return false;
37 }
38
39 google::protobuf::io::IstreamInputStream input(&fs);
40 google::protobuf::io::CodedInputStream codedstr(&input);
41
42 #if GOOGLE_PROTOBUF_VERSION >= 3011000
43 codedstr.SetTotalBytesLimit(INT_MAX);
44 #else
45 codedstr.SetTotalBytesLimit(INT_MAX, INT_MAX / 2);
46 #endif
47
48 bool success = message->ParseFromCodedStream(&codedstr);
49
50 fs.close();
51
52 return success;
53 }
54
get_node_attr_ai(const onnx::NodeProto & node,const char * key)55 static std::vector<int> get_node_attr_ai(const onnx::NodeProto& node, const char* key)
56 {
57 std::vector<int> v;
58
59 for (int i = 0; i < node.attribute_size(); i++)
60 {
61 const onnx::AttributeProto& attr = node.attribute(i);
62 if (attr.name() == key)
63 {
64 v.resize(attr.ints_size());
65 for (int j = 0; j < attr.ints_size(); j++)
66 {
67 v[j] = std::max(std::min(attr.ints(j), (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
68 }
69
70 break;
71 }
72 }
73
74 return v;
75 }
76
get_node_attr_af(const onnx::NodeProto & node,const char * key)77 static std::vector<float> get_node_attr_af(const onnx::NodeProto& node, const char* key)
78 {
79 std::vector<float> v;
80
81 for (int i = 0; i < node.attribute_size(); i++)
82 {
83 const onnx::AttributeProto& attr = node.attribute(i);
84 if (attr.name() == key)
85 {
86 v.resize(attr.floats_size());
87 for (int j = 0; j < attr.floats_size(); j++)
88 {
89 v[j] = attr.floats(j);
90 }
91
92 break;
93 }
94 }
95
96 return v;
97 }
98
get_node_attr_i(const onnx::NodeProto & node,const char * key,int def=0)99 static int get_node_attr_i(const onnx::NodeProto& node, const char* key, int def = 0)
100 {
101 for (int i = 0; i < node.attribute_size(); i++)
102 {
103 const onnx::AttributeProto& attr = node.attribute(i);
104 if (attr.name() == key)
105 {
106 return std::max(std::min(attr.i(), (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
107 }
108 }
109
110 return def;
111 }
112
get_node_attr_f(const onnx::NodeProto & node,const char * key,float def=0.f)113 static float get_node_attr_f(const onnx::NodeProto& node, const char* key, float def = 0.f)
114 {
115 for (int i = 0; i < node.attribute_size(); i++)
116 {
117 const onnx::AttributeProto& attr = node.attribute(i);
118 if (attr.name() == key)
119 {
120 return attr.f();
121 }
122 }
123
124 return def;
125 }
126
get_node_attr_s(const onnx::NodeProto & node,const char * key,const std::string & def=std::string ())127 static std::string get_node_attr_s(const onnx::NodeProto& node, const char* key, const std::string& def = std::string())
128 {
129 for (int i = 0; i < node.attribute_size(); i++)
130 {
131 const onnx::AttributeProto& attr = node.attribute(i);
132 if (attr.name() == key)
133 {
134 return attr.s();
135 }
136 }
137
138 return def;
139 }
140
// Fetch the tensor attribute named `key`; returns a default-constructed
// (empty) TensorProto when the attribute is absent.
static onnx::TensorProto get_node_attr_tensor(const onnx::NodeProto& node, const char* key)
{
    const int attr_count = node.attribute_size();
    for (int a = 0; a < attr_count; a++)
    {
        const onnx::AttributeProto& attr = node.attribute(a);
        if (attr.name() != key)
            continue;

        return attr.t();
    }

    return onnx::TensorProto();
}
154
get_node_attr_from_input_f(const onnx::TensorProto & tp)155 static float get_node_attr_from_input_f(const onnx::TensorProto& tp)
156 {
157 float v = 0.f;
158
159 // float
160 if (tp.data_type() == 1)
161 {
162 const float* shape_data = 0;
163 if (tp.has_raw_data())
164 {
165 shape_data = (const float*)tp.raw_data().data();
166 }
167 else
168 {
169 shape_data = tp.float_data().data();
170 }
171 v = shape_data[0];
172 }
173 // double
174 else if (tp.data_type() == 11)
175 {
176 const double* shape_data = 0;
177 if (tp.has_raw_data())
178 {
179 shape_data = (const double*)tp.raw_data().data();
180 }
181 else
182 {
183 shape_data = tp.double_data().data();
184 }
185 v = shape_data[0];
186 }
187 // int64
188 else if (tp.data_type() == 7)
189 {
190 const int64_t* shape_data = 0;
191 if (tp.has_raw_data())
192 {
193 shape_data = (const int64_t*)tp.raw_data().data();
194 }
195 else
196 {
197 shape_data = tp.int64_data().data();
198 }
199 v = std::max(std::min(shape_data[0], (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
200 }
201 // int32
202 else if (tp.data_type() == 6)
203 {
204 const int32_t* shape_data = 0;
205 if (tp.has_raw_data())
206 {
207 shape_data = (const int32_t*)tp.raw_data().data();
208 }
209 else
210 {
211 shape_data = tp.int32_data().data();
212 }
213 v = shape_data[0];
214 }
215 else
216 {
217 fprintf(stderr, "Unknown data type %d\n", tp.data_type());
218 abort();
219 }
220
221 return v;
222 }
223
get_node_attr_from_input_ai(const onnx::TensorProto & tp)224 static std::vector<int> get_node_attr_from_input_ai(const onnx::TensorProto& tp)
225 {
226 int size = 0;
227
228 std::vector<int> v;
229
230 // int64
231 if (tp.data_type() == 7)
232 {
233 const int64_t* shape_data = 0;
234 if (tp.has_raw_data())
235 {
236 shape_data = (const int64_t*)tp.raw_data().data();
237 size = (int)(tp.raw_data().size() / 8);
238 }
239 else
240 {
241 shape_data = tp.int64_data().data();
242 size = tp.int64_data_size();
243 }
244 for (int j = 0; j < size; j++)
245 {
246 int vi = std::max(std::min(shape_data[j], (::google::protobuf::int64)INT_MAX), (::google::protobuf::int64)INT_MIN);
247 v.push_back(vi);
248 }
249 }
250 // int32
251 else if (tp.data_type() == 6)
252 {
253 const int32_t* shape_data = 0;
254 if (tp.has_raw_data())
255 {
256 shape_data = (const int32_t*)tp.raw_data().data();
257 size = (int)(tp.raw_data().size() / 4);
258 }
259 else
260 {
261 shape_data = tp.int32_data().data();
262 size = tp.int32_data_size();
263 }
264 for (int j = 0; j < size; j++)
265 {
266 v.push_back(shape_data[j]);
267 }
268 }
269 else
270 {
271 fprintf(stderr, "Unknown data type %d\n", tp.data_type());
272 }
273
274 return v;
275 }
276
get_node_attr_from_input_af(const onnx::TensorProto & tp)277 static std::vector<float> get_node_attr_from_input_af(const onnx::TensorProto& tp)
278 {
279 int size = 0;
280
281 std::vector<float> v;
282
283 // float
284 if (tp.data_type() == 1)
285 {
286 const float* shape_data = 0;
287 if (tp.has_raw_data())
288 {
289 shape_data = (const float*)tp.raw_data().data();
290 size = (int)(tp.raw_data().size() / 4);
291 }
292 else
293 {
294 shape_data = tp.float_data().data();
295 size = tp.float_data_size();
296 }
297 for (int j = 0; j < size; j++)
298 {
299 v.push_back(shape_data[j]);
300 }
301 }
302 // double
303 else if (tp.data_type() == 11)
304 {
305 const double* shape_data = 0;
306 if (tp.has_raw_data())
307 {
308 shape_data = (const double*)tp.raw_data().data();
309 size = (int)(tp.raw_data().size() / 8);
310 }
311 else
312 {
313 shape_data = tp.double_data().data();
314 size = tp.double_data_size();
315 }
316 for (int j = 0; j < size; j++)
317 {
318 v.push_back((float)shape_data[j]);
319 }
320 }
321 else
322 {
323 fprintf(stderr, "Unknown data type %d\n", tp.data_type());
324 }
325
326 return v;
327 }
328
get_tensor_proto_data_size(const onnx::TensorProto & tp)329 static int get_tensor_proto_data_size(const onnx::TensorProto& tp)
330 {
331 if (tp.has_raw_data())
332 {
333 const std::string& raw_data = tp.raw_data();
334 int size = (int)raw_data.size() / 4;
335 return size;
336 }
337 else if (tp.data_type() == 1)
338 {
339 return tp.float_data_size();
340 }
341
342 return 0;
343 }
344
fwrite_tensor_proto_data(const onnx::TensorProto & tp,FILE * bp)345 static void fwrite_tensor_proto_data(const onnx::TensorProto& tp, FILE* bp)
346 {
347 int size = get_tensor_proto_data_size(tp);
348
349 if (tp.has_raw_data())
350 {
351 const std::string& raw_data = tp.raw_data();
352 fwrite(raw_data.data(), sizeof(float), size, bp);
353 }
354 else if (tp.data_type() == 1)
355 {
356 fwrite(tp.float_data().data(), sizeof(float), size, bp);
357 }
358 }
359
// Fold a Reshape applied to a constant weight into the weight itself:
//   weight <= Reshape(weight)
// The Reshape output name becomes an alias of the input weight with the
// requested dims, and the Reshape node is neutralized ("noop_reducedncnn").
// Bookkeeping: node_reference counts are decremented for consumed inputs and
// reduced_node_count is bumped so the caller can shrink the graph later.
static void fuse_weight_reshape(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // weight <= Reshape(weight)
        if (node->op_type() == "Reshape")
        {
            // check weight
            if (weights.find(node->input(0)) == weights.end())
                continue;

            // alias the reshape output to the same underlying weight tensor
            weights[node->output(0)] = weights[node->input(0)];

            // set weight shape directly
            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                // pre-opset-5 Reshape keeps the target shape as an attribute
                shape = get_node_attr_ai(*node, "shape");
            }
            else if (node->input_size() == 2)
            {
                // opset 5
                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            weights[node->output(0)].clear_dims();
            for (int j = 0; j < shape.size(); j++)
            {
                weights[node->output(0)].add_dims(shape[j]);
            }

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }

            reduced_node_count += 1;
            // step past the node that was part of the fused pattern
            i += 1;
        }
    }
}
408
// Fold a 2D Transpose (perm = (1,0)) applied to a constant weight into the
// weight itself:
//   weight <= Transpose(weight)
// The weight data is physically permuted, its dims swapped, and the
// Transpose node is neutralized ("noop_reducedncnn").
static void fuse_weight_transpose(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // weight <= Transpose(weight)
        if (node->op_type() == "Transpose")
        {
            // check weight
            if (weights.find(node->input(0)) == weights.end())
                continue;

            // only plain 2D matrices are handled here
            if (weights[node->input(0)].dims_size() != 2)
                continue;

            // perm = (1, 0)
            std::vector<int> perm = get_node_attr_ai(*node, "perm");
            if (perm.size() != 2)
                continue;
            if (perm[0] != 1 || perm[1] != 0)
                continue;

            // alias the transpose output to a copy of the input weight
            weights[node->output(0)] = weights[node->input(0)];

            // permute weight
            {
                onnx::TensorProto& B = weights[node->output(0)];

                const int h = B.dims(0);
                const int w = B.dims(1);

                std::vector<float> permuted_data;
                permuted_data.reserve((size_t)h * w);
                // weight data may live in raw bytes or the typed float field
                const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();

                // gather column-major: element (k, j) of the h x w source
                // becomes element (j, k) of the w x h destination
                for (int j = 0; j < w; j++)
                {
                    for (int k = 0; k < h; k++)
                    {
                        float vb = bptr[k * w + j];
                        permuted_data.push_back(vb);
                    }
                }

                B.set_dims(0, w);
                B.set_dims(1, h);

                if (B.has_raw_data())
                {
                    B.set_raw_data(permuted_data.data(), permuted_data.size() * sizeof(float));
                }
                else
                {
                    for (int j = 0; j < (int)permuted_data.size(); j++)
                        B.set_float_data(j, permuted_data[j]);
                }
            }

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;

            reduced_node_count += 1;
            // step past the node that was part of the fused pattern
            i += 1;
        }
    }
}
479
// Recognize the channel-shuffle idiom and collapse it into one node:
//   ShuffleChannel <= Reshape - Transpose - Reshape
//   ShuffleChannel <= Reshape - Transpose - Constant - Reshape
// Both the forward style (1, groups, c/g, h, w) and the "reverse" style
// (c/g, groups, h*w) are matched. The first two nodes are neutralized and
// the final Reshape is rewritten in place into a ShuffleChannel node
// carrying "group" and "reverse" attributes.
static void fuse_shufflechannel(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // ShuffleChannel <= Reshape - Transpose - Reshape
        // ShuffleChannel <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            // the intermediate blob must have exactly one consumer
            if (node_reference[node->output(0)] != 1)
                continue;

            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end())
                    continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // 1 groups channels_per_group, height, width
            // reverse style = channels_per_group, groups, height * width
            if (shape.size() != 5 && shape.size() != 3)
                continue;

            if (shape.size() == 5 && shape[0] != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // tolerate an interleaved Constant node feeding the second Reshape
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // 0 2 1 3 4
            // reverse style = 1 0 2
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 5 && perm.size() != 3)
                continue;

            if (perm.size() == 5 && (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3 || perm[4] != 4))
                continue;

            if (perm.size() == 3 && (perm[0] != 1 || perm[1] != 0 || perm[2] != 2))
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end())
                    continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // 1, -1, height, width
            // reverse style = group, -1, channels_per_group, height, width
            if (shape3.size() != 4 && shape3.size() != 5)
                continue;

            if (shape3.size() == 4 && (shape3[0] != 1 || (shape3[1] != -1 && shape3[1] != shape[1] * shape[2])))
                continue;

            if (shape3.size() == 5 && (shape3[0] != shape[1] || shape3[2] != shape[0] || shape3[3] * shape3[4] != shape[2]))
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            // the intermediate blobs disappear from the graph
            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // rewrite the last Reshape into the fused ShuffleChannel
            node3->set_op_type("ShuffleChannel");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("group");
            attr_group->set_i(shape[1]);

            onnx::AttributeProto* attr_reverse = node3->add_attribute();
            attr_reverse->set_name("reverse");
            attr_reverse->set_i(shape.size() == 3);

            reduced_node_count += 2;
            // skip the two nodes just consumed by the fusion
            i += 2;
        }
    }
}
607
// Turn a pair of Gather nodes selecting slice 0 and slice 1 of a
// reverse-style ShuffleChannel output into a single channel Split:
//   Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1)
// The first Gather is neutralized; the second is rewritten into a Split
// along axis 1 producing both outputs.
static void fuse_shufflechannel_split(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Split <= ShuffleChannel(reverse type) - Gather(0) - Gather(1)
        if (node->op_type() == "ShuffleChannel")
        {
            // reverse = 1
            int reverse = get_node_attr_i(*node, "reverse");
            if (reverse != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            if (node2->op_type() != "Gather" || node3->op_type() != "Gather")
                continue;

            // both Gathers must consume the ShuffleChannel output
            if (node2->input(0) != node->output(0) || node3->input(0) != node->output(0))
                continue;

            // axis = 0
            int gather2_axis = get_node_attr_i(*node2, "axis");
            if (gather2_axis != 0)
                continue;

            // indices = 0
            if (weights.find(node2->input(1)) == weights.end())
                continue;

            std::vector<int> gather2_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
            if (gather2_indices.size() != 1 || gather2_indices[0] != 0)
                continue;

            // axis = 0
            int gather3_axis = get_node_attr_i(*node3, "axis");
            if (gather3_axis != 0)
                continue;

            // indices = 1
            if (weights.find(node3->input(1)) == weights.end())
                continue;

            std::vector<int> gather3_indices = get_node_attr_from_input_ai(weights[node3->input(1)]);
            if (gather3_indices.size() != 1 || gather3_indices[0] != 1)
                continue;

            // reduce
            node2->set_op_type("noop_reducedncnn");

            // the ShuffleChannel output loses both Gather consumers; the
            // index tensors are no longer referenced
            node_reference[node->output(0)] -= 2;
            node_reference[node2->input(1)] -= 1;
            node_reference[node3->input(1)] -= 1;

            // rewrite node3 into Split: input is the ShuffleChannel output,
            // output 0 takes node2's old name, output 1 keeps node3's name
            node3->set_op_type("Split");
            node3->clear_input();
            node3->add_input(node->output(0));
            node3->add_output(node3->output(0));
            node3->set_output(0, node2->output(0));

            node3->clear_attribute();
            onnx::AttributeProto* attr_axis = node3->add_attribute();
            attr_axis->set_name("axis");
            attr_axis->set_i(1);

            reduced_node_count += 1;
            // skip the node just consumed by the fusion
            i += 1;
        }
    }
}
684
// Collapse hard-swish idioms into a single HardSwish node.
// Pass 1 matches the expanded form:
//   HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Div(/6)
//   HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Mul(*(1/6))
//   (optionally with an interleaved Constant before the final node)
// i.e. out = x * F.relu6(x + 3) / 6, emitted with alpha=1/6, beta=3/6.
// Pass 2 matches the compact form HardSigmoid - Mul (out = x * hsigmoid(x)),
// carrying over the HardSigmoid's own alpha/beta.
static void fuse_hardswish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Div(/6)
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Mul(*(1/6))
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Div(/6)
        // HardSwish <= Add(+3) - Clip(0,6) - Mul(X,) - Constant - Mul(*(1/6))
        // out = x * F.relu6(x + 3, inplace=True) / 6
        if (node->op_type() == "Add")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 3 >= node_count)
                continue;

            // the addend must be a constant scalar
            if (weights.find(node->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& add_three = weights[node->input(1)];
            if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1)
                continue;

            float constant_add_three = get_node_attr_from_input_f(add_three);
            if (constant_add_three != 3.f)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);

            // tolerate an interleaved Constant node before the final op
            if (node4->op_type() == "Constant")
            {
                if (i + 4 >= node_count)
                    continue;

                node4 = mutable_graph->mutable_node(i + 4);
            }

            if (node2->op_type() != "Clip" || node3->op_type() != "Mul" || (node4->op_type() != "Div" && node4->op_type() != "Mul"))
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // the Clip must be exactly relu6: min 0, max 6
            float relu6_min;
            float relu6_max;
            if (node2->input_size() == 1)
            {
                // opset < 11 stores min/max as attributes
                relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
                relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
            }
            else
            {
                // opset 11+ stores min/max as constant inputs
                const onnx::TensorProto& min_tp = weights[node2->input(1)];
                const onnx::TensorProto& max_tp = weights[node2->input(2)];

                relu6_min = get_node_attr_from_input_f(min_tp);
                relu6_max = get_node_attr_from_input_f(max_tp);
            }
            if (relu6_min != 0.f || relu6_max != 6.f)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            // the Mul must combine the original x with the clipped branch
            if (node3->input(0) != node->input(0) || node3->input(1) != node2->output(0))
                continue;

            // final scale must be a constant scalar: /6 or *(1/6)
            if (weights.find(node4->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& div_six = weights[node4->input(1)];
            if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1)
                continue;

            float constant_div_six = get_node_attr_from_input_f(div_six);
            if (node4->op_type() == "Div" && constant_div_six != 6.f)
                continue;
            if (node4->op_type() == "Mul" && constant_div_six != 1 / 6.f)
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->input(1)] -= 1;
            node_reference[node->output(0)] -= 1;
            if (node2->input_size() == 3)
            {
                node_reference[node2->input(1)] -= 1;
                node_reference[node2->input(2)] -= 1;
            }
            node_reference[node2->output(0)] -= 1;
            node_reference[node3->output(0)] -= 1;
            node_reference[node4->input(1)] -= 1;

            // intermediate blobs disappear from the graph
            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));

            // rewrite the final node into HardSwish(x)
            node4->set_op_type("HardSwish");
            node4->clear_input();
            node4->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node4->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(1.f / 6.f);

            onnx::AttributeProto* attr_beta = node4->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(3.f / 6.f);

            reduced_node_count += 3;
            // skip the three nodes just consumed by the fusion
            i += 3;
        }
    }

    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSwish <= HardSigmoid - Mul
        // out = x * hsigmoid(x)
        if (node->op_type() == "HardSigmoid")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            // keep the HardSigmoid's own coefficients for the fused node
            float alpha = get_node_attr_f(*node, "alpha", 0.2f);
            float beta = get_node_attr_f(*node, "beta", 0.5f);

            if (i + 1 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);

            if (node2->op_type() != "Mul")
                continue;

            // the Mul must combine the original x with the hsigmoid branch
            if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0))
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->output(0)] -= 1;

            blob_names.erase(node->output(0));

            node2->set_op_type("HardSwish");
            node2->clear_input();
            node2->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node2->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(alpha);

            onnx::AttributeProto* attr_beta = node2->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(beta);

            reduced_node_count += 1;
            // skip the node just consumed by the fusion
            i += 1;
        }
    }
}
859
// Collapse the expanded hard-sigmoid idiom into a single HardSigmoid node:
//   HardSigmoid <= Add(+3) - Clip(0,6) - Div(/6)
//   HardSigmoid <= Add(+3) - Clip(0,6) - Mul(*(1/6))
//   (optionally with an interleaved Constant before the final node)
// i.e. out = F.relu6(x + 3) / 6, emitted with alpha=1/6, beta=3/6.
static void fuse_hardsigmoid(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // HardSigmoid <= Add(+3) - Clip(0,6) - Div(/6)
        // HardSigmoid <= Add(+3) - Clip(0,6) - Mul(*(1/6))
        // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Div(/6)
        // HardSigmoid <= Add(+3) - Clip(0,6) - Constant - Mul(*(1/6))
        // out = F.relu6(x + 3, inplace=True) / 6
        if (node->op_type() == "Add")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            // the addend must be a constant scalar
            if (weights.find(node->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& add_three = weights[node->input(1)];
            if (add_three.dims_size() != 0 || get_tensor_proto_data_size(add_three) != 1)
                continue;

            float constant_add_three = get_node_attr_from_input_f(add_three);
            if (constant_add_three != 3.f)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // tolerate an interleaved Constant node before the final op
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Clip" || (node3->op_type() != "Div" && node3->op_type() != "Mul"))
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // the Clip must be exactly relu6: min 0, max 6
            float relu6_min;
            float relu6_max;
            if (node2->input_size() == 1)
            {
                // opset < 11 stores min/max as attributes
                relu6_min = get_node_attr_f(*node2, "min", -FLT_MAX);
                relu6_max = get_node_attr_f(*node2, "max", FLT_MAX);
            }
            else
            {
                // opset 11+ stores min/max as constant inputs
                const onnx::TensorProto& min_tp = weights[node2->input(1)];
                const onnx::TensorProto& max_tp = weights[node2->input(2)];

                relu6_min = get_node_attr_from_input_f(min_tp);
                relu6_max = get_node_attr_from_input_f(max_tp);
            }
            if (relu6_min != 0.f || relu6_max != 6.f)
                continue;

            // final scale must be a constant scalar: /6 or *(1/6)
            if (weights.find(node3->input(1)) == weights.end())
                continue;

            const onnx::TensorProto& div_six = weights[node3->input(1)];
            if (div_six.dims_size() != 0 || get_tensor_proto_data_size(div_six) != 1)
                continue;

            float constant_div_six = get_node_attr_from_input_f(div_six);
            if (node3->op_type() == "Div" && constant_div_six != 6.f)
                continue;
            if (node3->op_type() == "Mul" && constant_div_six != 1 / 6.f)
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            node_reference[node->input(1)] -= 1;
            node_reference[node->output(0)] -= 1;
            if (node2->input_size() == 3)
            {
                node_reference[node2->input(1)] -= 1;
                node_reference[node2->input(2)] -= 1;
            }
            node_reference[node2->output(0)] -= 1;
            node_reference[node3->input(1)] -= 1;

            // intermediate blobs disappear from the graph
            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // rewrite the final node into HardSigmoid(x)
            node3->set_op_type("HardSigmoid");
            node3->clear_input();
            node3->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node3->add_attribute();
            attr_alpha->set_name("alpha");
            attr_alpha->set_f(1.f / 6.f);

            onnx::AttributeProto* attr_beta = node3->add_attribute();
            attr_beta->set_name("beta");
            attr_beta->set_f(3.f / 6.f);

            reduced_node_count += 2;
            // skip the two nodes just consumed by the fusion
            i += 2;
        }
    }
}
973
// Collapse the swish idiom into a single Swish node:
//   Swish <= Sigmoid - Mul
// i.e. out = x * sigmoid(x). The Sigmoid is neutralized and the Mul is
// rewritten in place into Swish(x).
static void fuse_swish(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Swish <= Sigmoid - Mul
        // x * torch.sigmoid(x)
        if (node->op_type() == "Sigmoid")
        {
            // the sigmoid output must feed only the Mul
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 1 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);

            if (node2->op_type() != "Mul")
                continue;

            // the Mul must combine the original x with the sigmoid branch
            if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0))
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->output(0)] -= 1;

            blob_names.erase(node->output(0));

            node2->set_op_type("Swish");
            node2->clear_input();
            node2->add_input(node->input(0));

            reduced_node_count += 1;
            // skip the node just consumed by the fusion
            i += 1;
        }
    }
}
1016
// Remove the Unsqueeze/Squeeze wrappers PyTorch emits around 1D batchnorm:
//   BatchNormalization <= Unsqueeze - BatchNormalization - Squeeze
// The BatchNormalization node is rewired to consume the original input and
// produce the Squeeze's output; both wrappers are neutralized.
static void fuse_batchnorm1d_squeeze_unsqueeze(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // BatchNormalization <= Unsqueeze - BatchNormalization - Squeeze
        if (node->op_type() == "Unsqueeze")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            if (node2->op_type() != "BatchNormalization" || node3->op_type() != "Squeeze")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // the three nodes must form a straight chain
            if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0))
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");

            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // bypass the wrappers: batchnorm reads the original input and
            // writes directly to the Squeeze's output name
            node2->set_input(0, node->input(0));
            node2->set_output(0, node3->output(0));

            reduced_node_count += 2;
            // skip the two nodes just consumed by the fusion
            i += 2;
        }
    }
}
1063
// Remove the Unsqueeze applied to a 1D PReLU slope weight:
//   PReLU <= Unsqueeze - PReLU
// PyTorch exports the slope as (C) unsqueezed to (C,1,1) via axes (1,2);
// ncnn wants the flat weight, so the PReLU is pointed back at the original
// tensor and the Unsqueeze is neutralized.
static void fuse_unsqueeze_prelu(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // PReLU <= Unsqueeze - PReLU
        if (node->op_type() == "Unsqueeze")
        {
            // check weight
            if (weights.find(node->input(0)) == weights.end())
                continue;

            // the slope weight must be 1D
            onnx::TensorProto& B = weights[node->input(0)];
            if (B.dims_size() != 1)
                continue;

            if (node_reference[node->output(0)] != 1)
                continue;

            // axes = (1, 2)
            std::vector<int> axes = get_node_attr_ai(*node, "axes");
            if (axes.size() != 2)
                continue;
            if (axes[0] != 1 || axes[1] != 2)
                continue;

            if (i + 1 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);

            if (node2->op_type() != "PRelu")
                continue;

            // the unsqueezed tensor must be the PReLU slope input
            if (node2->input(1) != node->output(0))
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");

            node_reference[node->output(0)] -= 1;

            blob_names.erase(node->output(0));

            // feed the original 1D weight straight into PReLU
            node2->set_input(1, node->input(0));

            reduced_node_count += 1;
            // skip the node just consumed by the fusion
            i += 1;
        }
    }
}
1117
// Fuse the explicit L2-normalization subgraph into a single Normalize node.
//
//   pattern A: X - ReduceL2 - Clip - Expand - Div
//   pattern B: X - ReduceL2 - Clip - Shape - Expand - Div
//
// i.e. Div(X, Expand(Clip(ReduceL2(X), min=eps))).  The Clip lower bound
// becomes the Normalize "eps" attribute.
static void fuse_normalize(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Normalize <= X - ReduceL2 - Clip - Expand - Div
        // Normalize <= X - ReduceL2 - Clip - Shape - Expand - Div
        if (node->op_type() == "ReduceL2")
        {
            // the norm blob must have exactly one consumer
            if (node_reference[node->output(0)] != 1)
                continue;

            // axes = (1) -- only normalization across the channel axis is recognized
            std::vector<int> axes = get_node_attr_ai(*node, "axes");
            if (axes.size() != 1)
                continue;
            if (axes[0] != 1)
                continue;

            if (i + 3 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);

            // pattern B carries an extra Shape node between Clip and Expand;
            // shift node3/node4 one slot down when present
            bool has_shape_node = node3->op_type() == "Shape";
            onnx::NodeProto* node_shape = 0;
            if (has_shape_node)
            {
                if (i + 4 >= node_count)
                    continue;

                node_shape = node3;
                node3 = mutable_graph->mutable_node(i + 3);
                node4 = mutable_graph->mutable_node(i + 4);
            }

            if (node2->op_type() != "Clip" || node3->op_type() != "Expand" || node4->op_type() != "Div")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            // verify the wiring: Div takes X and the expanded clipped norm
            if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)
                    || node4->input(0) != node->input(0) || node4->input(1) != node3->output(0))
                continue;

            if (has_shape_node)
            {
                if (node_shape->input(0) != node->input(0) || node3->input(1) != node_shape->output(0))
                    continue;
            }

            // +eps -- the clip lower bound is either a "min" attribute
            // (old opset) or a second input tensor (opset 11+)
            float clip_min;
            if (node2->input_size() == 1)
            {
                clip_min = get_node_attr_f(*node2, "min", -FLT_MAX);
            }
            else
            {
                const onnx::TensorProto& min_tp = weights[node2->input(1)];

                clip_min = get_node_attr_from_input_f(min_tp);
            }

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            if (has_shape_node)
            {
                node_shape->set_op_type("noop_reducedncnn");
            }
            node3->set_op_type("noop_reducedncnn");

            // X loses its ReduceL2 (and optional Shape) consumers;
            // the Div survives as the fused Normalize and keeps consuming X
            node_reference[node->input(0)] -= has_shape_node ? 2 : 1;
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (has_shape_node)
            {
                node_reference[node_shape->output(0)] -= 1;
            }
            node_reference[node3->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                // Expand's explicit shape input is no longer referenced
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            if (has_shape_node)
            {
                blob_names.erase(node_shape->output(0));
            }
            blob_names.erase(node3->output(0));

            // rewrite the Div in place as the fused Normalize
            node4->set_op_type("Normalize");
            node4->clear_input();
            node4->add_input(node->input(0));

            onnx::AttributeProto* attr_alpha = node4->add_attribute();
            attr_alpha->set_name("eps");
            attr_alpha->set_f(clip_min);

            reduced_node_count += has_shape_node ? 4 : 3;
            i += has_shape_node ? 4 : 3;
        }
    }
}
1233
fuse_groupnorm(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1234 static void fuse_groupnorm(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1235 {
1236 int node_count = mutable_graph->node_size();
1237 for (int i = 0; i < node_count; i++)
1238 {
1239 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1240
1241 // GroupNorm <= X - Reshape - InstanceNormalization - Reshape - Mul - Add
1242 if (node->op_type() == "Reshape")
1243 {
1244 if (node_reference[node->output(0)] != 1)
1245 continue;
1246
1247 std::vector<int> shape;
1248 if (node->input_size() == 1)
1249 {
1250 shape = get_node_attr_ai(*node, "shape");
1251 }
1252 else
1253 {
1254 // skip weight reshape
1255 if (weights.find(node->input(1)) == weights.end())
1256 continue;
1257
1258 shape = get_node_attr_from_input_ai(weights[node->input(1)]);
1259 }
1260
1261 // 0, group, -1
1262 if (shape.size() != 3)
1263 continue;
1264
1265 if (shape[0] != 0 || shape[2] != -1)
1266 continue;
1267
1268 int groups = shape[1];
1269
1270 if (i + 4 >= node_count)
1271 continue;
1272
1273 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1274 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
1275 onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
1276 onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
1277
1278 if (node2->op_type() != "InstanceNormalization" || node3->op_type() != "Reshape" || node4->op_type() != "Mul" || node5->op_type() != "Add")
1279 continue;
1280
1281 if (node_reference[node2->output(0)] != 1)
1282 continue;
1283
1284 if (node_reference[node3->output(0)] != 1)
1285 continue;
1286
1287 if (node_reference[node4->output(0)] != 1)
1288 continue;
1289
1290 if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0)
1291 || node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0))
1292 continue;
1293
1294 // +eps
1295 float eps = get_node_attr_f(*node2, "epsilon", 1e-05f);
1296
1297 // InstanceNormalization S=1 B=0
1298 std::vector<float> S = get_node_attr_from_input_af(weights[node2->input(1)]);
1299 std::vector<float> B = get_node_attr_from_input_af(weights[node2->input(2)]);
1300 if ((int)S.size() != groups || (int)B.size() != groups)
1301 continue;
1302
1303 bool instancenorm_affine = false;
1304 for (int j = 0; j < groups; j++)
1305 {
1306 if (S[j] != 1.f || B[j] != 0.f)
1307 {
1308 instancenorm_affine = true;
1309 break;
1310 }
1311 }
1312
1313 if (instancenorm_affine)
1314 continue;
1315
1316 std::vector<int> shape2;
1317 if (node3->input_size() == 1)
1318 {
1319 shape2 = get_node_attr_ai(*node3, "shape");
1320 }
1321 else
1322 {
1323 // skip weight reshape
1324 if (weights.find(node3->input(1)) == weights.end())
1325 continue;
1326
1327 shape2 = get_node_attr_from_input_ai(weights[node3->input(1)]);
1328 }
1329
1330 // 1, channels, w, h
1331 if (shape2.size() != 4)
1332 continue;
1333
1334 if (shape2[0] != 1)
1335 continue;
1336
1337 int channels = shape2[1];
1338
1339 // affine
1340 std::vector<float> affine_S = get_node_attr_from_input_af(weights[node4->input(1)]);
1341 std::vector<float> affine_B = get_node_attr_from_input_af(weights[node5->input(1)]);
1342 if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && affine_B[0] == 0.f)
1343 {
1344 // no affine
1345 }
1346 else if ((int)affine_S.size() != channels && (int)affine_B.size() != channels)
1347 {
1348 // we only allow per-channel affine
1349 continue;
1350 }
1351
1352 // reduce
1353 node->set_op_type("noop_reducedncnn");
1354 node2->set_op_type("noop_reducedncnn");
1355 node3->set_op_type("noop_reducedncnn");
1356 node4->set_op_type("noop_reducedncnn");
1357
1358 if (node->input_size() == 2)
1359 {
1360 node_reference[node->input(1)] -= 1;
1361 }
1362 node_reference[node->output(0)] -= 1;
1363 node_reference[node2->input(1)] -= 1;
1364 node_reference[node2->input(2)] -= 1;
1365 node_reference[node2->output(0)] -= 1;
1366 if (node3->input_size() == 2)
1367 {
1368 node_reference[node3->input(1)] -= 1;
1369 }
1370 node_reference[node3->output(0)] -= 1;
1371 node_reference[node4->output(0)] -= 1;
1372
1373 blob_names.erase(node->output(0));
1374 blob_names.erase(node2->output(0));
1375 blob_names.erase(node3->output(0));
1376 blob_names.erase(node4->output(0));
1377
1378 std::string affine_scale = node4->input(1);
1379 std::string affine_bias = node5->input(1);
1380
1381 node5->set_op_type("GroupNorm");
1382 node5->clear_input();
1383 node5->add_input(node->input(0));
1384 node5->add_input(affine_scale);
1385 node5->add_input(affine_bias);
1386
1387 onnx::AttributeProto* attr_groups = node5->add_attribute();
1388 attr_groups->set_name("groups");
1389 attr_groups->set_i(groups);
1390
1391 onnx::AttributeProto* attr_channels = node5->add_attribute();
1392 attr_channels->set_name("channels");
1393 attr_channels->set_i(channels);
1394
1395 onnx::AttributeProto* attr_eps = node5->add_attribute();
1396 attr_eps->set_name("epsilon");
1397 attr_eps->set_f(eps);
1398
1399 onnx::AttributeProto* attr_affine = node5->add_attribute();
1400 attr_affine->set_name("affine");
1401 attr_affine->set_i(1);
1402
1403 reduced_node_count += 4;
1404 i += 4;
1405 }
1406 }
1407 }
1408
// Fuse the exported layernorm subgraph into a single LayerNorm node.
//
//   without affine: X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div
//   with affine:    ... - Div - Mul - Add
//
// i.e. (X - mean(X)) / sqrt(mean((X - mean)^2) + eps), optionally followed by
// a per-element scale (Mul) and bias (Add).
static void fuse_layernorm(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div
        // LayerNorm <= X - ReduceMean - Sub - Pow - ReduceMean - Add - Sqrt - Div - Mul - Add
        if (node->op_type() == "ReduceMean")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            std::vector<int> axes = get_node_attr_ai(*node, "axes");

            // only reduction over the last one or two axes is recognized:
            // -1
            // -2 -1
            if (axes.size() != 1 && axes.size() != 2)
                continue;

            int normed_axes = (int)axes.size();
            if (normed_axes == 1 && axes[0] != -1)
                continue;
            if (normed_axes == 2 && (axes[0] != -2 || axes[1] != -1))
                continue;

            if (i + 6 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);

            if (node2->op_type() != "Sub" || node3->op_type() != "Pow" || node4->op_type() != "ReduceMean" || node5->op_type() != "Add" || node6->op_type() != "Sqrt" || node7->op_type() != "Div")
                continue;

            // the centered blob (X - mean) feeds both Pow and Div -> two consumers
            if (node_reference[node2->output(0)] != 2)
                continue;

            if (node_reference[node3->output(0)] != 1)
                continue;

            if (node_reference[node4->output(0)] != 1)
                continue;

            if (node_reference[node5->output(0)] != 1)
                continue;

            if (node_reference[node6->output(0)] != 1)
                continue;

            // verify the exact wiring of the variance/normalize chain
            if (node2->input(0) != node->input(0) || node2->input(1) != node->output(0)
                    || node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0)
                    || node5->input(0) != node4->output(0) || node6->input(0) != node5->output(0)
                    || node7->input(0) != node2->output(0) || node7->input(1) != node6->output(0))
                continue;

            if (weights.find(node3->input(1)) == weights.end())
                continue;

            // Pow exponent must be the scalar constant 2
            const onnx::TensorProto& pow_two = weights[node3->input(1)];
            if (pow_two.dims_size() != 0 || get_tensor_proto_data_size(pow_two) != 1)
                continue;

            float constant_pow_two = get_node_attr_from_input_f(pow_two);
            if (constant_pow_two != 2.f)
                continue;

            std::vector<int> axes4 = get_node_attr_ai(*node4, "axes");

            // the second ReduceMean must reduce over the same axes:
            // -1
            // -2 -1
            if ((int)axes4.size() != normed_axes)
                continue;

            if (normed_axes == 1 && axes4[0] != -1)
                continue;
            if (normed_axes == 2 && (axes4[0] != -2 || axes4[1] != -1))
                continue;

            if (weights.find(node5->input(1)) == weights.end())
                continue;

            // the Add operand must be a scalar constant -> the epsilon
            const onnx::TensorProto& add_eps = weights[node5->input(1)];
            if (add_eps.dims_size() != 0 || get_tensor_proto_data_size(add_eps) != 1)
                continue;

            float eps = get_node_attr_from_input_f(add_eps);

            // probe for a trailing Mul - Add affine pair;
            // the while loop is a breakable block, not a real loop
            int affine = 0;
            while (i + 8 < node_count)
            {
                onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
                onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);

                if (node8->op_type() != "Mul" || node9->op_type() != "Add")
                    break;

                if (node_reference[node7->output(0)] != 1)
                    break;

                if (node_reference[node8->output(0)] != 1)
                    break;

                if (node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0))
                    break;

                // affine scale and bias must have matching element counts
                std::vector<float> affine_S = get_node_attr_from_input_af(weights[node8->input(1)]);
                std::vector<float> affine_B = get_node_attr_from_input_af(weights[node9->input(1)]);
                if (affine_S.size() != affine_B.size())
                    break;

                affine = 1;
                break;
            }

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");

            // X is consumed by both ReduceMean (node) and Sub (node2);
            // both references are dropped here and one is re-added below
            // for the fused LayerNorm
            node_reference[node->input(0)] -= 1;
            node_reference[node2->input(0)] -= 1;
            node_reference[node2->input(1)] -= 1;
            node_reference[node3->input(0)] -= 1;
            node_reference[node3->input(1)] -= 1;
            node_reference[node4->input(0)] -= 1;
            node_reference[node5->input(0)] -= 1;
            node_reference[node5->input(1)] -= 1;
            node_reference[node6->input(0)] -= 1;
            node_reference[node7->input(0)] -= 1;
            node_reference[node7->input(1)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));

            // the fused node consumes X once
            node_reference[node->input(0)] += 1;

            if (affine == 0)
            {
                // rewrite the Div in place as LayerNorm without affine
                node7->set_op_type("LayerNorm");
                node7->clear_input();
                node7->add_input(node->input(0));

                onnx::AttributeProto* attr_eps = node7->add_attribute();
                attr_eps->set_name("epsilon");
                attr_eps->set_f(eps);

                onnx::AttributeProto* attr_affine = node7->add_attribute();
                attr_affine->set_name("affine");
                attr_affine->set_i(affine);

                reduced_node_count += 6;
                i += 6;
            }
            else // if (affine == 1)
            {
                // consume Div and Mul too, rewrite the trailing Add as
                // LayerNorm(X, scale, bias)
                onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
                onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);

                node7->set_op_type("noop_reducedncnn");
                node8->set_op_type("noop_reducedncnn");

                node_reference[node8->input(0)] -= 1;
                node_reference[node9->input(0)] -= 1;

                blob_names.erase(node7->output(0));
                blob_names.erase(node8->output(0));

                std::string affine_scale = node8->input(1);
                std::string affine_bias = node9->input(1);

                node9->set_op_type("LayerNorm");
                node9->clear_input();
                node9->add_input(node->input(0));
                node9->add_input(affine_scale);
                node9->add_input(affine_bias);

                onnx::AttributeProto* attr_eps = node9->add_attribute();
                attr_eps->set_name("epsilon");
                attr_eps->set_f(eps);

                onnx::AttributeProto* attr_affine = node9->add_attribute();
                attr_affine->set_name("affine");
                attr_affine->set_i(affine);

                reduced_node_count += 8;
                i += 8;
            }
        }
    }
}
1613
// Fuse the dynamic-batch flatten idiom into a single Flatten node:
//
//   Flatten <= X - Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat - Reshape
//
// i.e. Reshape(X, concat(Shape(X)[0], -1)) -- keep the batch dim, flatten the rest.
static void fuse_flatten(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Flatten <= X - Shape - Gather - Constant - Unsqueeze - Unsqueeze - Concat - Reshape
        if (node->op_type() == "Shape")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            if (i + 6 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);

            if (node2->op_type() != "Gather" || node3->op_type() != "Constant" || node4->op_type() != "Unsqueeze" || node5->op_type() != "Unsqueeze"
                    || node6->op_type() != "Concat" || node7->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // the Constant (node3) is deliberately left alone here -- its
            // output may be shared, so it is not part of the reference checks
            //             if (node_reference[node3->output(0)] != 1)
            //                 continue;

            if (node_reference[node4->output(0)] != 1)
                continue;

            if (node_reference[node5->output(0)] != 1)
                continue;

            if (node_reference[node6->output(0)] != 1)
                continue;

            // verify the wiring: Reshape(X, Concat(Unsqueeze(Gather(Shape(X))), Unsqueeze(Constant)))
            if (node2->input(0) != node->output(0) || node4->input(0) != node2->output(0) || node5->input(0) != node3->output(0)
                    || node6->input(0) != node4->output(0) || node6->input(1) != node5->output(0)
                    || node7->input(0) != node->input(0) || node7->input(1) != node6->output(0))
                continue;

            // axis = 0
            int gather_axis = get_node_attr_i(*node2, "axis");
            if (gather_axis != 0)
                continue;

            // indices = 0 -- must pick the batch dimension
            if (weights.find(node2->input(1)) == weights.end())
                continue;

            std::vector<int> gather_indices = get_node_attr_from_input_ai(weights[node2->input(1)]);
            if (gather_indices.size() != 1 || gather_indices[0] != 0)
                continue;

            // axes = (0)
            std::vector<int> unsqueeze_axes = get_node_attr_ai(*node4, "axes");
            if (unsqueeze_axes.size() != 1)
                continue;
            if (unsqueeze_axes[0] != 0)
                continue;

            // axes = (0)
            std::vector<int> unsqueeze2_axes = get_node_attr_ai(*node5, "axes");
            if (unsqueeze2_axes.size() != 1)
                continue;
            if (unsqueeze2_axes[0] != 0)
                continue;

            // data = -1 -- the constant must be the flatten marker
            if (weights.find(node5->input(0)) == weights.end())
                continue;

            std::vector<int> unsqueeze2_data = get_node_attr_from_input_ai(weights[node5->input(0)]);
            if (unsqueeze2_data.size() != 1 || unsqueeze2_data[0] != -1)
                continue;

            // axis = 0
            int concat_axis = get_node_attr_i(*node6, "axis");
            if (concat_axis != 0)
                continue;

            // reduce (the Constant node3 is kept; later passes handle constants)
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            // node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");

            node_reference[node->input(0)] -= 1;
            node_reference[node->output(0)] -= 1;
            node_reference[node2->input(1)] -= 1;
            node_reference[node2->output(0)] -= 1;
            // node_reference[node3->output(0)] -= 1;
            node_reference[node4->output(0)] -= 1;
            node_reference[node5->input(0)] -= 1;
            node_reference[node5->output(0)] -= 1;
            node_reference[node6->output(0)] -= 1;

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            // blob_names.erase(node3->output(0));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));

            // rewrite the Reshape in place as Flatten(X)
            node7->set_op_type("Flatten");
            node7->clear_input();
            node7->add_input(node->input(0));

            reduced_node_count += 5;
            i += 5;
        }
    }
}
1735
// Fuse the depth-to-space idiom into a PixelShuffle node:
//
//   PixelShuffle <= Reshape - Transpose - Reshape
//   PixelShuffle <= Reshape - Transpose - Constant - Reshape
//
// Reshape to (-1, c, r, r, h, w), permute (0 1 4 2 5 3), then reshape back to
// (-1, c, h*r, w*r) -- r is the upscale factor.
static void fuse_pixelshuffle(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // PixelShuffle <= Reshape - Transpose - Reshape
        // PixelShuffle <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            // reshape target may be an attribute or a weight-supplied input
            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end())
                    continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // -1, 3, upscale_factor, upscale_factor, height, width
            if (shape.size() != 6)
                continue;

            if (shape[0] != 1 && shape[0] != -1)
                continue;

            // both block dims must equal the upscale factor
            if (shape[2] != shape[3])
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // an interleaved Constant shifts the second Reshape one slot down;
            // the Constant itself is left in place for later passes
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // perm must be exactly 0 1 4 2 5 3
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 6)
                continue;

            if (perm[0] != 0 || perm[1] != 1 || perm[2] != 4 || perm[3] != 2 || perm[4] != 5 || perm[5] != 3)
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end())
                    continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // -1, 3, height, width
            if (shape3.size() != 4)
                continue;

            if (shape3[0] != 1 && shape3[0] != -1)
                continue;

            // output dims must be channels, h * r, w * r
            if (shape3[1] != shape[1] || shape3[2] != shape[2] * shape[4] || shape3[3] != shape[3] * shape[5])
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // rewrite the final Reshape in place as PixelShuffle(X)
            node3->set_op_type("PixelShuffle");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("scale_factor");
            attr_group->set_i(shape[2]);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
1856
// Fuse the space-to-depth idiom into a Reorg node (the inverse of
// PixelShuffle above):
//
//   Reorg <= Reshape - Transpose - Reshape
//   Reorg <= Reshape - Transpose - Constant - Reshape
//
// Reshape to (-1, c, oh, s, ow, s), permute (0 1 3 5 2 4), then reshape back
// to (-1, c*s*s, oh, ow) -- s is the stride (block size).
static void fuse_reorg(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // Reorg <= Reshape - Transpose - Reshape
        // Reorg <= Reshape - Transpose - Constant - Reshape
        if (node->op_type() == "Reshape")
        {
            if (node_reference[node->output(0)] != 1)
                continue;

            // reshape target may be an attribute or a weight-supplied input
            std::vector<int> shape;
            if (node->input_size() == 1)
            {
                shape = get_node_attr_ai(*node, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node->input(1)) == weights.end())
                    continue;

                shape = get_node_attr_from_input_ai(weights[node->input(1)]);
            }

            // -1, 3, out_height, block_size, out_width, block_size
            if (shape.size() != 6)
                continue;

            if (shape[0] != 1 && shape[0] != -1)
                continue;

            // both block dims must equal the stride
            if (shape[3] != shape[5])
                continue;

            if (i + 2 >= node_count)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);

            // an interleaved Constant shifts the second Reshape one slot down;
            // the Constant itself is left in place for later passes
            if (node3->op_type() == "Constant")
            {
                if (i + 3 >= node_count)
                    continue;

                node3 = mutable_graph->mutable_node(i + 3);
            }

            if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
                continue;

            if (node_reference[node2->output(0)] != 1)
                continue;

            // perm must be exactly 0 1 3 5 2 4
            std::vector<int> perm = get_node_attr_ai(*node2, "perm");
            if (perm.size() != 6)
                continue;

            if (perm[0] != 0 || perm[1] != 1 || perm[2] != 3 || perm[3] != 5 || perm[4] != 2 || perm[5] != 4)
                continue;

            std::vector<int> shape3;
            if (node3->input_size() == 1)
            {
                shape3 = get_node_attr_ai(*node3, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node3->input(1)) == weights.end())
                    continue;

                shape3 = get_node_attr_from_input_ai(weights[node3->input(1)]);
            }

            // -1, out_channels, out_height, out_width
            if (shape3.size() != 4)
                continue;

            if (shape3[0] != 1 && shape3[0] != -1)
                continue;

            // output dims must be c * s * s, oh, ow
            if (shape3[1] != shape[1] * shape[3] * shape[5] || shape3[2] != shape[2] || shape3[3] != shape[4])
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");

            if (node->input_size() == 2)
            {
                node_reference[node->input(1)] -= 1;
            }
            node_reference[node->output(0)] -= 1;
            node_reference[node2->output(0)] -= 1;
            if (node3->input_size() == 2)
            {
                node_reference[node3->input(1)] -= 1;
            }

            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));

            // rewrite the final Reshape in place as Reorg(X)
            node3->set_op_type("Reorg");
            node3->set_input(0, node->input(0));

            onnx::AttributeProto* attr_group = node3->add_attribute();
            attr_group->set_name("stride");
            attr_group->set_i(shape[3]);

            reduced_node_count += 2;
            i += 2;
        }
    }
}
1977
fuse_expand_broadcast(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)1978 static void fuse_expand_broadcast(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
1979 {
1980 int node_count = mutable_graph->node_size();
1981 for (int i = 0; i < node_count; i++)
1982 {
1983 onnx::NodeProto* node = mutable_graph->mutable_node(i);
1984
1985 // Add/Sub/Mul/Div/Min/Max <= Expand - Add/Sub/Mul/Div/Min/Max
1986 if (node->op_type() == "Expand")
1987 {
1988 if (node_reference[node->output(0)] != 1)
1989 continue;
1990
1991 if (i + 1 >= node_count)
1992 continue;
1993
1994 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
1995
1996 if (node2->op_type() != "Add" && node2->op_type() != "Sub" && node2->op_type() != "Mul" && node2->op_type() != "Div" && node2->op_type() != "Min" && node2->op_type() != "Max")
1997 continue;
1998
1999 if (node2->input(1) != node->output(0) && node2->input(0) != node->output(0))
2000 continue;
2001
2002 // reduce
2003 node->set_op_type("noop_reducedncnn");
2004
2005 node_reference[node->output(0)] -= 1;
2006 if (node->input_size() == 2)
2007 {
2008 node_reference[node->input(1)] -= 1;
2009 }
2010
2011 blob_names.erase(node->output(0));
2012
2013 if (node2->input(0) == node->output(0))
2014 {
2015 node2->set_input(0, node->input(0));
2016 }
2017 else
2018 {
2019 node2->set_input(1, node->input(0));
2020 }
2021
2022 reduced_node_count += 1;
2023 i += 1;
2024 }
2025 }
2026 }
2027
fuse_lstm_gru_rnn(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)2028 static void fuse_lstm_gru_rnn(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
2029 {
2030 int node_count = mutable_graph->node_size();
2031 for (int i = 0; i < node_count; i++)
2032 {
2033 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2034
2035 // LSTM(bi) <= LSTM(bi) - Transpose - Reshape - Transpose
2036 if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
2037 {
2038 if (node_reference[node->output(0)] != 1)
2039 continue;
2040
2041 if (i + 2 >= node_count)
2042 continue;
2043
2044 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2045 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
2046
2047 if (node2->op_type() != "Transpose" || node3->op_type() != "Reshape")
2048 continue;
2049
2050 if (node_reference[node2->output(0)] != 1)
2051 continue;
2052
2053 if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0))
2054 continue;
2055
2056 std::string direction = get_node_attr_s(*node, "direction");
2057 if (direction != "bidirectional")
2058 continue;
2059
2060 // 0 2 1 3
2061 std::vector<int> perm = get_node_attr_ai(*node2, "perm");
2062 if (perm.size() != 4)
2063 continue;
2064
2065 if (perm[0] != 0 || perm[1] != 2 || perm[2] != 1 || perm[3] != 3)
2066 continue;
2067
2068 std::vector<int> shape;
2069 if (node3->input_size() == 1)
2070 {
2071 shape = get_node_attr_ai(*node3, "shape");
2072 }
2073 else
2074 {
2075 // skip weight reshape
2076 if (weights.find(node3->input(1)) == weights.end())
2077 continue;
2078
2079 shape = get_node_attr_from_input_ai(weights[node3->input(1)]);
2080 }
2081
2082 // 0 0 -1
2083 if (shape.size() != 3)
2084 continue;
2085
2086 if (shape[0] != 0 || shape[1] != 0 || shape[2] != -1)
2087 continue;
2088
2089 // reduce
2090 node2->set_op_type("noop_reducedncnn");
2091 node3->set_op_type("noop_reducedncnn");
2092
2093 node_reference[node->output(0)] -= 1;
2094 node_reference[node2->output(0)] -= 1;
2095 if (node3->input_size() == 2)
2096 {
2097 node_reference[node3->input(1)] -= 1;
2098 }
2099
2100 blob_names.erase(node->output(0));
2101 blob_names.erase(node2->output(0));
2102
2103 node->set_output(0, node3->output(0));
2104
2105 reduced_node_count += 2;
2106 i += 2;
2107
2108 if (i + 1 < node_count)
2109 {
2110 if (node_reference[node3->output(0)] != 1)
2111 continue;
2112
2113 onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 1);
2114
2115 if (node4->op_type() != "Transpose")
2116 continue;
2117
2118 if (node4->input(0) != node->output(0))
2119 continue;
2120
2121 // 1 0 2
2122 std::vector<int> perm4 = get_node_attr_ai(*node4, "perm");
2123 if (perm4.size() != 3)
2124 continue;
2125
2126 if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2)
2127 continue;
2128
2129 // reduce
2130 node4->set_op_type("noop_reducedncnn");
2131
2132 node_reference[node->output(0)] -= 1;
2133
2134 blob_names.erase(node->output(0));
2135
2136 node->set_output(0, node4->output(0));
2137
2138 reduced_node_count += 1;
2139 i += 1;
2140 }
2141 }
2142 }
2143
2144 for (int i = 0; i < node_count; i++)
2145 {
2146 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2147
2148 // LSTM(uni) <= LSTM(uni) - Squeeze - Transpose
2149 if (node->op_type() == "LSTM" || node->op_type() == "GRU" || node->op_type() == "RNN")
2150 {
2151 if (node_reference[node->output(0)] != 1)
2152 continue;
2153
2154 if (i + 1 >= node_count)
2155 continue;
2156
2157 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2158
2159 if (node2->op_type() != "Squeeze")
2160 continue;
2161
2162 if (node2->input(0) != node->output(0))
2163 continue;
2164
2165 std::string direction = get_node_attr_s(*node, "direction");
2166 if (direction == "bidirectional")
2167 continue;
2168
2169 // 1
2170 std::vector<int> axes = get_node_attr_ai(*node2, "axes");
2171 if (axes.size() != 1)
2172 continue;
2173
2174 if (axes[0] != 1)
2175 continue;
2176
2177 // reduce
2178 node2->set_op_type("noop_reducedncnn");
2179
2180 node_reference[node->output(0)] -= 1;
2181
2182 blob_names.erase(node->output(0));
2183
2184 node->set_output(0, node2->output(0));
2185
2186 reduced_node_count += 1;
2187 i += 1;
2188
2189 if (i + 1 < node_count)
2190 {
2191 if (node_reference[node2->output(0)] != 1)
2192 continue;
2193
2194 onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 1);
2195
2196 if (node3->op_type() != "Transpose")
2197 continue;
2198
2199 if (node3->input(0) != node->output(0))
2200 continue;
2201
2202 // 1 0 2
2203 std::vector<int> perm4 = get_node_attr_ai(*node3, "perm");
2204 if (perm4.size() != 3)
2205 continue;
2206
2207 if (perm4[0] != 1 || perm4[1] != 0 || perm4[2] != 2)
2208 continue;
2209
2210 // reduce
2211 node3->set_op_type("noop_reducedncnn");
2212
2213 node_reference[node->output(0)] -= 1;
2214
2215 blob_names.erase(node->output(0));
2216
2217 node->set_output(0, node3->output(0));
2218
2219 reduced_node_count += 1;
2220 i += 1;
2221 }
2222 }
2223 }
2224
2225 for (int i = 0; i < node_count; i++)
2226 {
2227 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2228
2229 // LSTM <= Transpose - LSTM
2230 if (node->op_type() == "Transpose")
2231 {
2232 if (node_reference[node->output(0)] != 1)
2233 continue;
2234
2235 // 1 0 2
2236 std::vector<int> perm = get_node_attr_ai(*node, "perm");
2237 if (perm.size() != 3)
2238 continue;
2239
2240 if (perm[0] != 1 || perm[1] != 0 || perm[2] != 2)
2241 continue;
2242
2243 if (i + 1 >= node_count)
2244 continue;
2245
2246 onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
2247
2248 if (node2->op_type() != "LSTM" && node->op_type() != "GRU" && node->op_type() != "RNN")
2249 continue;
2250
2251 if (node2->input(0) != node->output(0))
2252 continue;
2253
2254 // reduce
2255 node->set_op_type("noop_reducedncnn");
2256
2257 node_reference[node->output(0)] -= 1;
2258
2259 blob_names.erase(node->output(0));
2260
2261 node2->set_input(0, node->input(0));
2262
2263 reduced_node_count += 1;
2264 i += 1;
2265 }
2266 }
2267 }
2268
// Recognize the node chains that PyTorch-style exporters emit for multi-head
// attention and collapse each chain into a single "MultiHeadAttention" node.
// Two pass: pass 1 matches separate q/k/v projections (20 nodes); pass 2
// matches a packed qkv projection followed by Split (17 nodes). Matched
// helper nodes become "noop_reducedncnn" and are pruned later; the final Add
// node is rewritten in place into the fused op carrying embed_dim/num_heads
// attributes. node_reference, blob_names and reduced_node_count are kept in
// sync with the removals.
static void fuse_multiheadattention(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
{
    int node_count = mutable_graph->node_size();
    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // MultiHeadAttention <= MatMul(q) - Add
        //                       - MatMul(k) - Add
        //                       - MatMul(v) - Add
        //                       - Mul
        //                       - Reshape - Transpose
        //                       - Reshape - Reshape - Transpose - Transpose
        //                       - Gemm - Softmax - Gemm - Transpose - Reshape - MatMul - Add
        if (node->op_type() == "MatMul")
        {
            // need 19 successor nodes for the full pattern
            if (i + 19 >= node_count)
                continue;

            if (node_reference[node->output(0)] != 1)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
            onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
            onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
            onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9);
            onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10);
            onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11);
            onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12);
            onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13);
            onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14);
            onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15);
            onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16);
            onnx::NodeProto* node18 = mutable_graph->mutable_node(i + 17);
            onnx::NodeProto* node19 = mutable_graph->mutable_node(i + 18);
            onnx::NodeProto* node20 = mutable_graph->mutable_node(i + 19);

            // op types must match the pattern exactly, in graph order
            if (node2->op_type() != "Add" || node3->op_type() != "MatMul" || node4->op_type() != "Add" || node5->op_type() != "MatMul" || node6->op_type() != "Add" || node7->op_type() != "Mul" || node8->op_type() != "Reshape" || node9->op_type() != "Transpose" || node10->op_type() != "Reshape" || node11->op_type() != "Reshape" || node12->op_type() != "Transpose" || node13->op_type() != "Transpose" || node14->op_type() != "MatMul" || node15->op_type() != "Softmax" || node16->op_type() != "MatMul" || node17->op_type() != "Transpose" || node18->op_type() != "Reshape" || node19->op_type() != "MatMul" || node20->op_type() != "Add")
                continue;

            // every intermediate blob must have exactly one consumer,
            // otherwise removing the nodes would break other users
            if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 || node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 || node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 || node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 || node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 || node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 || node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 || node_reference[node16->output(0)] != 1 || node_reference[node17->output(0)] != 1 || node_reference[node18->output(0)] != 1 || node_reference[node19->output(0)] != 1)
                continue;

            // verify the exact dataflow wiring of the attention subgraph
            if (node2->input(0) != node->output(0) || node4->input(0) != node3->output(0) || node6->input(0) != node5->output(0) || node7->input(0) != node2->output(0) || node8->input(0) != node7->output(0) || node9->input(0) != node8->output(0) || node10->input(0) != node4->output(0) || node11->input(0) != node6->output(0) || node12->input(0) != node11->output(0) || node13->input(0) != node10->output(0) || node14->input(0) != node9->output(0) || node14->input(1) != node13->output(0) || node15->input(0) != node14->output(0) || node16->input(0) != node15->output(0) || node16->input(1) != node12->output(0) || node17->input(0) != node16->output(0) || node18->input(0) != node17->output(0) || node19->input(0) != node18->output(0) || node20->input(0) != node19->output(0))
                continue;

            // projection biases; all must share one length, which is embed_dim
            std::vector<float> q_B = get_node_attr_from_input_af(weights[node2->input(1)]);
            std::vector<float> k_B = get_node_attr_from_input_af(weights[node4->input(1)]);
            std::vector<float> v_B = get_node_attr_from_input_af(weights[node6->input(1)]);
            std::vector<float> o_B = get_node_attr_from_input_af(weights[node20->input(1)]);

            if (q_B.size() != k_B.size() || q_B.size() != v_B.size() || q_B.size() != o_B.size())
                continue;

            int embed_dim = q_B.size();

            // 1 0 2
            std::vector<int> perm9 = get_node_attr_ai(*node9, "perm");
            std::vector<int> perm12 = get_node_attr_ai(*node12, "perm");
            if (perm9.size() != 3 || perm12.size() != 3)
                continue;

            if (perm9[0] != 1 || perm9[1] != 0 || perm9[2] != 2 || perm12[0] != 1 || perm12[1] != 0 || perm12[2] != 2)
                continue;

            // 1 2 0
            std::vector<int> perm13 = get_node_attr_ai(*node13, "perm");
            if (perm13.size() != 3)
                continue;

            if (perm13[0] != 1 || perm13[1] != 2 || perm13[2] != 0)
                continue;

            // 1 0 2
            std::vector<int> perm17 = get_node_attr_ai(*node17, "perm");
            if (perm17.size() != 3)
                continue;

            if (perm17[0] != 1 || perm17[1] != 0 || perm17[2] != 2)
                continue;

            // softmax must run over the last (key) axis of the 3d scores
            int softmax_axis = get_node_attr_i(*node15, "axis");
            if (softmax_axis != 2)
                continue;

            // 1/-1, seqlen * num_heads, embed_dim / num_heads
            std::vector<int> shape8;
            std::vector<int> shape10;
            std::vector<int> shape11;
            if (node8->input_size() == 1)
            {
                shape8 = get_node_attr_ai(*node8, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node8->input(1)) == weights.end())
                    continue;

                shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]);
            }
            if (node10->input_size() == 1)
            {
                shape10 = get_node_attr_ai(*node10, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node10->input(1)) == weights.end())
                    continue;

                shape10 = get_node_attr_from_input_ai(weights[node10->input(1)]);
            }
            if (node11->input_size() == 1)
            {
                shape11 = get_node_attr_ai(*node11, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node11->input(1)) == weights.end())
                    continue;

                shape11 = get_node_attr_from_input_ai(weights[node11->input(1)]);
            }

            if (shape8.size() != 3 || shape10.size() != 3 || shape11.size() != 3)
                continue;

            // q/k/v reshapes must agree on per-head layout
            if (shape8[1] != shape10[1] || shape8[1] != shape11[1] || shape8[2] != shape10[2] || shape8[2] != shape11[2])
                continue;

            // shape8[2] is the per-head dim, so this recovers the head count
            int num_heads = embed_dim / shape8[2];

            // 1, seqlen, embed_dim
            std::vector<int> shape18;
            if (node18->input_size() == 1)
            {
                shape18 = get_node_attr_ai(*node18, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node18->input(1)) == weights.end())
                    continue;

                shape18 = get_node_attr_from_input_ai(weights[node18->input(1)]);
            }

            if (shape18.size() != 3)
                continue;

            // final reshape must restore (seqlen, embed_dim) consistent with the head split
            if (shape18[2] != embed_dim || shape18[1] * num_heads != shape8[1])
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");
            node7->set_op_type("noop_reducedncnn");
            node8->set_op_type("noop_reducedncnn");
            node9->set_op_type("noop_reducedncnn");
            node10->set_op_type("noop_reducedncnn");
            node11->set_op_type("noop_reducedncnn");
            node12->set_op_type("noop_reducedncnn");
            node13->set_op_type("noop_reducedncnn");
            node14->set_op_type("noop_reducedncnn");
            node15->set_op_type("noop_reducedncnn");
            node16->set_op_type("noop_reducedncnn");
            node17->set_op_type("noop_reducedncnn");
            node18->set_op_type("noop_reducedncnn");
            node19->set_op_type("noop_reducedncnn");

            // drop one reference for every consumed input blob
            node_reference[node2->input(0)] -= 1;
            node_reference[node4->input(0)] -= 1;
            node_reference[node6->input(0)] -= 1;
            node_reference[node7->input(0)] -= 1;
            node_reference[node7->input(1)] -= 1;
            node_reference[node8->input(0)] -= 1;
            if (node8->input_size() == 2)
            {
                node_reference[node8->input(1)] -= 1;
            }
            node_reference[node9->input(0)] -= 1;
            node_reference[node10->input(0)] -= 1;
            if (node10->input_size() == 2)
            {
                node_reference[node10->input(1)] -= 1;
            }
            node_reference[node11->input(0)] -= 1;
            if (node11->input_size() == 2)
            {
                node_reference[node11->input(1)] -= 1;
            }
            node_reference[node12->input(0)] -= 1;
            node_reference[node13->input(0)] -= 1;
            node_reference[node14->input(0)] -= 1;
            node_reference[node14->input(1)] -= 1;
            node_reference[node15->input(0)] -= 1;
            node_reference[node16->input(0)] -= 1;
            node_reference[node16->input(1)] -= 1;
            node_reference[node17->input(0)] -= 1;
            node_reference[node18->input(0)] -= 1;
            if (node18->input_size() == 2)
            {
                node_reference[node18->input(1)] -= 1;
            }
            node_reference[node19->input(0)] -= 1;
            node_reference[node20->input(0)] -= 1;

            // all intermediate blobs disappear with the fused nodes
            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));
            blob_names.erase(node7->output(0));
            blob_names.erase(node8->output(0));
            blob_names.erase(node9->output(0));
            blob_names.erase(node10->output(0));
            blob_names.erase(node11->output(0));
            blob_names.erase(node12->output(0));
            blob_names.erase(node13->output(0));
            blob_names.erase(node14->output(0));
            blob_names.erase(node15->output(0));
            blob_names.erase(node16->output(0));
            blob_names.erase(node17->output(0));
            blob_names.erase(node18->output(0));
            blob_names.erase(node19->output(0));

            // capture weight/bias blob names before clearing node20's inputs
            std::string qw = node->input(1);
            std::string qb = node2->input(1);
            std::string kw = node3->input(1);
            std::string kb = node4->input(1);
            std::string vw = node5->input(1);
            std::string vb = node6->input(1);
            std::string ow = node19->input(1);
            std::string ob = node20->input(1);

            // rewrite the trailing Add in place as the fused attention op,
            // keeping its output blob so downstream consumers are untouched
            node20->set_op_type("MultiHeadAttention");
            node20->clear_input();
            node20->add_input(node->input(0));
            node20->add_input(node3->input(0));
            node20->add_input(node5->input(0));
            // q
            node20->add_input(qw);
            node20->add_input(qb);
            // k
            node20->add_input(kw);
            node20->add_input(kb);
            // v
            node20->add_input(vw);
            node20->add_input(vb);
            // out linear
            node20->add_input(ow);
            node20->add_input(ob);

            onnx::AttributeProto* attr_embed_dim = node20->add_attribute();
            attr_embed_dim->set_name("embed_dim");
            attr_embed_dim->set_i(embed_dim);

            onnx::AttributeProto* attr_num_heads = node20->add_attribute();
            attr_num_heads->set_name("num_heads");
            attr_num_heads->set_i(num_heads);

            reduced_node_count += 19;
            i += 19;
        }
    }

    for (int i = 0; i < node_count; i++)
    {
        onnx::NodeProto* node = mutable_graph->mutable_node(i);

        // MultiHeadAttention <= MatMul(qkv) - Add - Split
        //                       - Mul
        //                       - Reshape - Transpose
        //                       - Reshape - Reshape - Transpose - Transpose
        //                       - Gemm - Softmax - Gemm - Transpose - Reshape - MatMul - Add
        if (node->op_type() == "MatMul")
        {
            // need 16 successor nodes for the packed-qkv pattern
            if (i + 16 >= node_count)
                continue;

            if (node_reference[node->output(0)] != 1)
                continue;

            onnx::NodeProto* node2 = mutable_graph->mutable_node(i + 1);
            onnx::NodeProto* node3 = mutable_graph->mutable_node(i + 2);
            onnx::NodeProto* node4 = mutable_graph->mutable_node(i + 3);
            onnx::NodeProto* node5 = mutable_graph->mutable_node(i + 4);
            onnx::NodeProto* node6 = mutable_graph->mutable_node(i + 5);
            onnx::NodeProto* node7 = mutable_graph->mutable_node(i + 6);
            onnx::NodeProto* node8 = mutable_graph->mutable_node(i + 7);
            onnx::NodeProto* node9 = mutable_graph->mutable_node(i + 8);
            onnx::NodeProto* node10 = mutable_graph->mutable_node(i + 9);
            onnx::NodeProto* node11 = mutable_graph->mutable_node(i + 10);
            onnx::NodeProto* node12 = mutable_graph->mutable_node(i + 11);
            onnx::NodeProto* node13 = mutable_graph->mutable_node(i + 12);
            onnx::NodeProto* node14 = mutable_graph->mutable_node(i + 13);
            onnx::NodeProto* node15 = mutable_graph->mutable_node(i + 14);
            onnx::NodeProto* node16 = mutable_graph->mutable_node(i + 15);
            onnx::NodeProto* node17 = mutable_graph->mutable_node(i + 16);

            // op types must match the pattern exactly, in graph order
            if (node2->op_type() != "Add" || node3->op_type() != "Split" || node4->op_type() != "Mul" || node5->op_type() != "Reshape" || node6->op_type() != "Transpose" || node7->op_type() != "Reshape" || node8->op_type() != "Reshape" || node9->op_type() != "Transpose" || node10->op_type() != "Transpose" || node11->op_type() != "MatMul" || node12->op_type() != "Softmax" || node13->op_type() != "MatMul" || node14->op_type() != "Transpose" || node15->op_type() != "Reshape" || node16->op_type() != "MatMul" || node17->op_type() != "Add")
                continue;

            // single consumer for every intermediate blob, incl. the 3 Split outputs
            if (node_reference[node2->output(0)] != 1 || node_reference[node3->output(0)] != 1 || node_reference[node3->output(1)] != 1 || node_reference[node3->output(2)] != 1 || node_reference[node4->output(0)] != 1 || node_reference[node5->output(0)] != 1 || node_reference[node6->output(0)] != 1 || node_reference[node7->output(0)] != 1 || node_reference[node8->output(0)] != 1 || node_reference[node9->output(0)] != 1 || node_reference[node10->output(0)] != 1 || node_reference[node11->output(0)] != 1 || node_reference[node12->output(0)] != 1 || node_reference[node13->output(0)] != 1 || node_reference[node14->output(0)] != 1 || node_reference[node15->output(0)] != 1 || node_reference[node16->output(0)] != 1)
                continue;

            // verify the exact dataflow wiring (Split outputs 0/1/2 feed q/k/v paths)
            if (node2->input(0) != node->output(0) || node3->input(0) != node2->output(0) || node4->input(0) != node3->output(0) || node5->input(0) != node4->output(0) || node6->input(0) != node5->output(0) || node7->input(0) != node3->output(1) || node8->input(0) != node3->output(2) || node9->input(0) != node8->output(0) || node10->input(0) != node7->output(0) || node11->input(0) != node6->output(0) || node11->input(1) != node10->output(0) || node12->input(0) != node11->output(0) || node13->input(0) != node12->output(0) || node13->input(1) != node9->output(0) || node14->input(0) != node13->output(0) || node15->input(0) != node14->output(0) || node16->input(0) != node15->output(0) || node17->input(0) != node16->output(0))
                continue;

            // packed qkv bias is 3x the output bias length; o_B length is embed_dim
            std::vector<float> qkv_B = get_node_attr_from_input_af(weights[node2->input(1)]);
            std::vector<float> o_B = get_node_attr_from_input_af(weights[node17->input(1)]);

            if (qkv_B.size() != o_B.size() * 3)
                continue;

            int embed_dim = o_B.size();

            // 1 0 2
            std::vector<int> perm6 = get_node_attr_ai(*node6, "perm");
            std::vector<int> perm9 = get_node_attr_ai(*node9, "perm");
            if (perm6.size() != 3 || perm9.size() != 3)
                continue;

            if (perm6[0] != 1 || perm6[1] != 0 || perm6[2] != 2 || perm9[0] != 1 || perm9[1] != 0 || perm9[2] != 2)
                continue;

            // 1 2 0
            std::vector<int> perm10 = get_node_attr_ai(*node10, "perm");
            if (perm10.size() != 3)
                continue;

            if (perm10[0] != 1 || perm10[1] != 2 || perm10[2] != 0)
                continue;

            // 1 0 2
            std::vector<int> perm14 = get_node_attr_ai(*node14, "perm");
            if (perm14.size() != 3)
                continue;

            if (perm14[0] != 1 || perm14[1] != 0 || perm14[2] != 2)
                continue;

            // softmax must run over the last (key) axis of the 3d scores
            int softmax_axis = get_node_attr_i(*node12, "axis");
            if (softmax_axis != 2)
                continue;

            // 1/-1, seqlen * num_heads, embed_dim / num_heads
            std::vector<int> shape5;
            std::vector<int> shape7;
            std::vector<int> shape8;
            if (node5->input_size() == 1)
            {
                shape5 = get_node_attr_ai(*node5, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node5->input(1)) == weights.end())
                    continue;

                shape5 = get_node_attr_from_input_ai(weights[node5->input(1)]);
            }
            if (node7->input_size() == 1)
            {
                shape7 = get_node_attr_ai(*node7, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node7->input(1)) == weights.end())
                    continue;

                shape7 = get_node_attr_from_input_ai(weights[node7->input(1)]);
            }
            if (node8->input_size() == 1)
            {
                shape8 = get_node_attr_ai(*node8, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node8->input(1)) == weights.end())
                    continue;

                shape8 = get_node_attr_from_input_ai(weights[node8->input(1)]);
            }

            if (shape5.size() != 3 || shape7.size() != 3 || shape8.size() != 3)
                continue;

            // q/k/v reshapes must agree on per-head layout
            if (shape5[1] != shape7[1] || shape5[1] != shape8[1] || shape5[2] != shape7[2] || shape5[2] != shape8[2])
                continue;

            // shape5[2] is the per-head dim, so this recovers the head count
            int num_heads = embed_dim / shape5[2];

            // 1, seqlen, embed_dim
            std::vector<int> shape15;
            if (node15->input_size() == 1)
            {
                shape15 = get_node_attr_ai(*node15, "shape");
            }
            else
            {
                // skip weight reshape
                if (weights.find(node15->input(1)) == weights.end())
                    continue;

                shape15 = get_node_attr_from_input_ai(weights[node15->input(1)]);
            }

            if (shape15.size() != 3)
                continue;

            // final reshape must restore (seqlen, embed_dim) consistent with the head split
            if (shape15[2] != embed_dim || shape15[1] * num_heads != shape8[1])
                continue;

            // reduce
            node->set_op_type("noop_reducedncnn");
            node2->set_op_type("noop_reducedncnn");
            node3->set_op_type("noop_reducedncnn");
            node4->set_op_type("noop_reducedncnn");
            node5->set_op_type("noop_reducedncnn");
            node6->set_op_type("noop_reducedncnn");
            node7->set_op_type("noop_reducedncnn");
            node8->set_op_type("noop_reducedncnn");
            node9->set_op_type("noop_reducedncnn");
            node10->set_op_type("noop_reducedncnn");
            node11->set_op_type("noop_reducedncnn");
            node12->set_op_type("noop_reducedncnn");
            node13->set_op_type("noop_reducedncnn");
            node14->set_op_type("noop_reducedncnn");
            node15->set_op_type("noop_reducedncnn");
            node16->set_op_type("noop_reducedncnn");

            // drop one reference for every consumed input blob
            node_reference[node2->input(0)] -= 1;
            node_reference[node3->input(0)] -= 1;
            node_reference[node4->input(0)] -= 1;
            node_reference[node4->input(1)] -= 1;
            node_reference[node5->input(0)] -= 1;
            if (node5->input_size() == 2)
            {
                node_reference[node5->input(1)] -= 1;
            }
            node_reference[node6->input(0)] -= 1;
            node_reference[node7->input(0)] -= 1;
            if (node7->input_size() == 2)
            {
                node_reference[node7->input(1)] -= 1;
            }
            node_reference[node8->input(0)] -= 1;
            if (node8->input_size() == 2)
            {
                node_reference[node8->input(1)] -= 1;
            }
            node_reference[node9->input(0)] -= 1;
            node_reference[node10->input(0)] -= 1;
            node_reference[node11->input(0)] -= 1;
            node_reference[node11->input(1)] -= 1;
            node_reference[node12->input(0)] -= 1;
            node_reference[node13->input(0)] -= 1;
            node_reference[node13->input(1)] -= 1;
            node_reference[node14->input(0)] -= 1;
            node_reference[node15->input(0)] -= 1;
            if (node15->input_size() == 2)
            {
                node_reference[node15->input(1)] -= 1;
            }
            node_reference[node16->input(0)] -= 1;
            node_reference[node17->input(0)] -= 1;

            // all intermediate blobs disappear with the fused nodes
            blob_names.erase(node->output(0));
            blob_names.erase(node2->output(0));
            blob_names.erase(node3->output(0));
            blob_names.erase(node3->output(1));
            blob_names.erase(node3->output(2));
            blob_names.erase(node4->output(0));
            blob_names.erase(node5->output(0));
            blob_names.erase(node6->output(0));
            blob_names.erase(node7->output(0));
            blob_names.erase(node8->output(0));
            blob_names.erase(node9->output(0));
            blob_names.erase(node10->output(0));
            blob_names.erase(node11->output(0));
            blob_names.erase(node12->output(0));
            blob_names.erase(node13->output(0));
            blob_names.erase(node14->output(0));
            blob_names.erase(node15->output(0));
            blob_names.erase(node16->output(0));

            // capture weight/bias blob names before clearing node17's inputs
            std::string qkvw = node->input(1);
            std::string qkvb = node2->input(1);
            std::string ow = node16->input(1);
            std::string ob = node17->input(1);

            // rewrite the trailing Add in place as the fused attention op,
            // keeping its output blob so downstream consumers are untouched
            node17->set_op_type("MultiHeadAttention");
            node17->clear_input();
            node17->add_input(node->input(0));
            // qkv
            node17->add_input(qkvw);
            node17->add_input(qkvb);
            // out linear
            node17->add_input(ow);
            node17->add_input(ob);

            onnx::AttributeProto* attr_embed_dim = node17->add_attribute();
            attr_embed_dim->set_name("embed_dim");
            attr_embed_dim->set_i(embed_dim);

            onnx::AttributeProto* attr_num_heads = node17->add_attribute();
            attr_num_heads->set_name("num_heads");
            attr_num_heads->set_i(num_heads);

            reduced_node_count += 16;
            i += 16;
        }
    }
}
2798
fuse_binaryop_with_scalar(onnx::GraphProto * mutable_graph,std::map<std::string,onnx::TensorProto> & weights,std::map<std::string,int> & node_reference,std::set<std::string> & blob_names,int & reduced_node_count)2799 static void fuse_binaryop_with_scalar(onnx::GraphProto* mutable_graph, std::map<std::string, onnx::TensorProto>& weights, std::map<std::string, int>& node_reference, std::set<std::string>& blob_names, int& reduced_node_count)
2800 {
2801 int node_count = mutable_graph->node_size();
2802 for (int i = 0; i < node_count; i++)
2803 {
2804 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2805
2806 // Add/Sub/Mul/Div/Min/Max/Pow(a, x)
2807 if (node->op_type() == "Add" || node->op_type() == "Sub" || node->op_type() == "Mul" || node->op_type() == "Div" || node->op_type() == "Max" || node->op_type() == "Min" || node->op_type() == "Pow")
2808 {
2809 if (weights.find(node->input(0)) == weights.end())
2810 continue;
2811
2812 const onnx::TensorProto& scalar_b = weights[node->input(0)];
2813 if (scalar_b.dims_size() != 0 || get_tensor_proto_data_size(scalar_b) != 1)
2814 continue;
2815
2816 if (node->op_type() == "Sub")
2817 {
2818 node->set_op_type("RSub");
2819 }
2820 else if (node->op_type() == "Div")
2821 {
2822 node->set_op_type("RDiv");
2823 }
2824
2825 float b = get_node_attr_from_input_f(scalar_b);
2826
2827 node_reference[node->input(0)] -= 1;
2828
2829 std::string input = node->input(1);
2830
2831 node->clear_input();
2832 node->add_input(input);
2833
2834 onnx::AttributeProto* attr_with_scalar = node->add_attribute();
2835 attr_with_scalar->set_name("with_scalar");
2836 attr_with_scalar->set_i(1);
2837
2838 onnx::AttributeProto* attr_b = node->add_attribute();
2839 attr_b->set_name("b");
2840 attr_b->set_f(b);
2841 }
2842 }
2843
2844 for (int i = 0; i < node_count; i++)
2845 {
2846 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2847
2848 // Add/Sub/Mul/Div/Min/Max/Pow(x, b)
2849 if (node->op_type() == "Add" || node->op_type() == "Sub" || node->op_type() == "Mul" || node->op_type() == "Div" || node->op_type() == "Max" || node->op_type() == "Min" || node->op_type() == "Pow")
2850 {
2851 if (weights.find(node->input(1)) == weights.end())
2852 continue;
2853
2854 const onnx::TensorProto& scalar_b = weights[node->input(1)];
2855 if (scalar_b.dims_size() != 0 || get_tensor_proto_data_size(scalar_b) != 1)
2856 continue;
2857
2858 float b = get_node_attr_from_input_f(scalar_b);
2859
2860 node_reference[node->input(1)] -= 1;
2861
2862 std::string input = node->input(0);
2863
2864 node->clear_input();
2865 node->add_input(input);
2866
2867 onnx::AttributeProto* attr_with_scalar = node->add_attribute();
2868 attr_with_scalar->set_name("with_scalar");
2869 attr_with_scalar->set_i(1);
2870
2871 onnx::AttributeProto* attr_b = node->add_attribute();
2872 attr_b->set_name("b");
2873 attr_b->set_f(b);
2874 }
2875 }
2876 }
2877
main(int argc,char ** argv)2878 int main(int argc, char** argv)
2879 {
2880 if (!(argc == 2 || argc == 4))
2881 {
2882 fprintf(stderr, "Usage: %s [onnxpb] [ncnnparam] [ncnnbin]\n", argv[0]);
2883 return -1;
2884 }
2885
2886 const char* onnxpb = argv[1];
2887 const char* ncnn_prototxt = argc == 4 ? argv[2] : "ncnn.param";
2888 const char* ncnn_modelbin = argc == 4 ? argv[3] : "ncnn.bin";
2889
2890 onnx::ModelProto model;
2891
2892 // load
2893 bool s1 = read_proto_from_binary(onnxpb, &model);
2894 if (!s1)
2895 {
2896 fprintf(stderr, "read_proto_from_binary failed\n");
2897 return -1;
2898 }
2899
2900 FILE* pp = fopen(ncnn_prototxt, "wb");
2901 FILE* bp = fopen(ncnn_modelbin, "wb");
2902
2903 // magic
2904 fprintf(pp, "7767517\n");
2905
2906 const onnx::GraphProto& graph = model.graph();
2907 onnx::GraphProto* mutable_graph = model.mutable_graph();
2908
2909 int node_count = graph.node_size();
2910
2911 // node reference
2912 std::map<std::string, int> node_reference;
2913
2914 // weight node and weight reshape node
2915 std::map<std::string, onnx::TensorProto> weights;
2916
2917 for (int j = 0; j < graph.initializer_size(); j++)
2918 {
2919 const onnx::TensorProto& initializer = graph.initializer(j);
2920
2921 // fprintf(stderr, "weight = %s %d\n", initializer.name().c_str(), initializer.data_type());
2922
2923 weights[initializer.name()] = initializer;
2924 }
2925
2926 // topological sort
2927 {
2928 // name -> producer node index
2929 std::set<std::string> producers;
2930 for (int j = 0; j < graph.input_size(); j++)
2931 {
2932 const std::string& input_name = graph.input(j).name();
2933 producers.insert(input_name);
2934 }
2935
2936 for (int i = 0; i < node_count;)
2937 {
2938 onnx::NodeProto* node = mutable_graph->mutable_node(i);
2939
2940 bool swapnode = false;
2941 std::string missing_input_name;
2942 for (int j = 0; j < (int)node->input_size(); j++)
2943 {
2944 const std::string& input_name = node->input(j);
2945 if (input_name.empty())
2946 continue;
2947
2948 if (producers.find(input_name) == producers.end() && weights.find(input_name) == weights.end())
2949 {
2950 swapnode = true;
2951 missing_input_name = input_name;
2952 break;
2953 }
2954 }
2955
2956 if (!swapnode)
2957 {
2958 for (int j = 0; j < (int)node->output_size(); j++)
2959 {
2960 const std::string& output_name = node->output(j);
2961 if (output_name.empty())
2962 continue;
2963
2964 producers.insert(output_name);
2965 }
2966
2967 i++;
2968 continue;
2969 }
2970
2971 // find node that produce missing_input_name
2972 int q = i + 1;
2973 for (; q < node_count; q++)
2974 {
2975 onnx::NodeProto* nodeq = mutable_graph->mutable_node(q);
2976 bool found = false;
2977 for (int j = 0; j < (int)nodeq->output_size(); j++)
2978 {
2979 const std::string& output_name = nodeq->output(j);
2980 if (output_name == missing_input_name)
2981 {
2982 found = true;
2983 break;
2984 }
2985 }
2986
2987 if (found)
2988 break;
2989 }
2990
2991 if (q == node_count)
2992 {
2993 fprintf(stderr, "cannot find node produces %s but node %d requires it\n", missing_input_name.c_str(), i);
2994 return -1;
2995 }
2996
2997 // fprintf(stderr, "swap %d %d\n", i, q);
2998 // swap this node with q
2999 onnx::NodeProto* nodeq = mutable_graph->mutable_node(q);
3000 onnx::NodeProto tmp = *node;
3001 *node = *nodeq;
3002 *nodeq = tmp;
3003 }
3004 }
3005
3006 // global definition line
3007 // [layer count] [blob count]
3008 std::set<std::string> blob_names;
3009 for (int i = 0; i < node_count; i++)
3010 {
3011 const onnx::NodeProto& node = graph.node(i);
3012
3013 const std::string& op = node.op_type();
3014
3015 std::string name = node.name();
3016 if (name.empty())
3017 {
3018 name = node.output(0);
3019 }
3020
3021 if (op == "Constant")
3022 {
3023 onnx::TensorProto tensor = get_node_attr_tensor(node, "value");
3024 weights[node.output(0)] = tensor;
3025 }
3026
3027 for (int j = 0; j < (int)node.input_size(); j++)
3028 {
3029 const std::string& input_name = node.input(j);
3030
3031 blob_names.insert(input_name);
3032
3033 if (node_reference.find(input_name) == node_reference.end())
3034 {
3035 node_reference[input_name] = 1;
3036 }
3037 else
3038 {
3039 node_reference[input_name] = node_reference[input_name] + 1;
3040 }
3041 }
3042
3043 if (op == "Dropout")
3044 {
3045 const std::string& output_name = node.output(0);
3046 blob_names.insert(output_name);
3047 node_reference[output_name] = 0;
3048 continue;
3049 }
3050
3051 for (int j = 0; j < (int)node.output_size(); j++)
3052 {
3053 const std::string& output_name = node.output(j);
3054
3055 blob_names.insert(output_name);
3056
3057 node_reference[output_name] = 0;
3058 }
3059 }
3060
3061 // include Input node
3062 int input_node_count = 0;
3063 for (int j = 0; j < graph.input_size(); j++)
3064 {
3065 const std::string& input_name = graph.input(j).name();
3066
3067 // check weight
3068 if (weights.find(input_name) != weights.end())
3069 continue;
3070
3071 blob_names.insert(input_name);
3072
3073 input_node_count++;
3074 }
3075
3076 // for (auto a: node_reference)
3077 // {
3078 // fprintf(stderr, "a = %s %d\n", a.first.c_str(), a.second);
3079 // }
3080
3081 // op chain fusion
3082 int reduced_node_count = 0;
3083 fuse_weight_reshape(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3084 fuse_weight_transpose(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3085 fuse_shufflechannel(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3086 fuse_shufflechannel_split(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3087 fuse_hardsigmoid(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3088 fuse_hardswish(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3089 fuse_swish(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3090 fuse_batchnorm1d_squeeze_unsqueeze(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3091 fuse_unsqueeze_prelu(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3092 fuse_normalize(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3093 fuse_groupnorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3094 fuse_layernorm(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3095 fuse_flatten(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3096 fuse_pixelshuffle(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3097 fuse_reorg(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3098 fuse_expand_broadcast(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3099 fuse_lstm_gru_rnn(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3100 fuse_multiheadattention(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3101 fuse_binaryop_with_scalar(mutable_graph, weights, node_reference, blob_names, reduced_node_count);
3102
3103 // reduce common const weight node_reference
3104 for (int i = 0; i < node_count; i++)
3105 {
3106 const onnx::NodeProto& node = graph.node(i);
3107
3108 const std::string& op = node.op_type();
3109
3110 if (op == "BatchNormalization")
3111 {
3112 node_reference[node.input(1)] -= 1;
3113 node_reference[node.input(2)] -= 1;
3114 node_reference[node.input(3)] -= 1;
3115 node_reference[node.input(4)] -= 1;
3116 }
3117 else if (op == "BiasGelu")
3118 {
3119 node_reference[node.input(1)] -= 1;
3120 }
3121 else if (op == "Clip")
3122 {
3123 if (node.input_size() == 3)
3124 {
3125 node_reference[node.input(1)] -= 1;
3126 node_reference[node.input(2)] -= 1;
3127 }
3128 }
3129 else if (op == "Conv")
3130 {
3131 node_reference[node.input(1)] -= 1;
3132 if (node.input_size() == 3)
3133 {
3134 node_reference[node.input(2)] -= 1;
3135 }
3136 }
3137 else if (op == "ConvTranspose")
3138 {
3139 node_reference[node.input(1)] -= 1;
3140 if (node.input_size() == 3)
3141 {
3142 node_reference[node.input(2)] -= 1;
3143 }
3144 }
3145 else if (op == "EmbedLayerNormalization")
3146 {
3147 node_reference[node.input(1)] -= 1;
3148 node_reference[node.input(2)] -= 1;
3149 node_reference[node.input(3)] -= 1;
3150 node_reference[node.input(4)] -= 1;
3151 node_reference[node.input(5)] -= 1;
3152 node_reference[node.input(6)] -= 1;
3153 }
3154 else if (op == "Gemm")
3155 {
3156 float alpha = get_node_attr_f(node, "alpha", 1.f);
3157 float beta = get_node_attr_f(node, "beta", 1.f);
3158 int transA = get_node_attr_i(node, "transA", 0);
3159 int transB = get_node_attr_i(node, "transB", 0);
3160
3161 if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
3162 {
3163 // InnerProduct-like A * B + C
3164 node_reference[node.input(1)] -= 1;
3165 node_reference[node.input(2)] -= 1;
3166 }
3167 }
3168 else if (op == "GroupNorm")
3169 {
3170 int affine = get_node_attr_i(node, "affine", 1);
3171 if (affine)
3172 {
3173 node_reference[node.input(1)] -= 1;
3174 node_reference[node.input(2)] -= 1;
3175 }
3176 }
3177 else if (op == "GRU")
3178 {
3179 for (int j = 1; j < node.input_size(); j++)
3180 {
3181 node_reference[node.input(j)] -= 1;
3182 }
3183 }
3184 else if (op == "InstanceNormalization")
3185 {
3186 node_reference[node.input(1)] -= 1;
3187 node_reference[node.input(2)] -= 1;
3188 }
3189 else if (op == "LayerNorm")
3190 {
3191 int affine = get_node_attr_i(node, "affine", 1);
3192 if (affine)
3193 {
3194 node_reference[node.input(1)] -= 1;
3195 node_reference[node.input(2)] -= 1;
3196 }
3197 }
3198 else if (op == "LSTM")
3199 {
3200 for (int j = 1; j < node.input_size(); j++)
3201 {
3202 node_reference[node.input(j)] -= 1;
3203 }
3204 }
3205 else if (op == "MatMul")
3206 {
3207 if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
3208 {
3209 // InnerProduct
3210 node_reference[node.input(1)] -= 1;
3211 }
3212 }
3213 else if (op == "MultiHeadAttention")
3214 {
3215 if (node.input_size() == 5)
3216 {
3217 node_reference[node.input(1)] -= 1;
3218 node_reference[node.input(2)] -= 1;
3219 node_reference[node.input(3)] -= 1;
3220 node_reference[node.input(4)] -= 1;
3221 }
3222 else
3223 {
3224 node_reference[node.input(3)] -= 1;
3225 node_reference[node.input(4)] -= 1;
3226 node_reference[node.input(5)] -= 1;
3227 node_reference[node.input(6)] -= 1;
3228 node_reference[node.input(7)] -= 1;
3229 node_reference[node.input(8)] -= 1;
3230 node_reference[node.input(9)] -= 1;
3231 node_reference[node.input(10)] -= 1;
3232 }
3233 }
3234 else if (op == "Pad")
3235 {
3236 if (node.input_size() >= 2)
3237 {
3238 node_reference[node.input(1)] -= 1;
3239 }
3240 }
3241 else if (op == "PRelu")
3242 {
3243 node_reference[node.input(1)] -= 1;
3244 }
3245 else if (op == "Reshape")
3246 {
3247 if (node.input_size() >= 2)
3248 {
3249 node_reference[node.input(1)] -= 1;
3250 }
3251 }
3252 else if (op == "Resize")
3253 {
3254 if (node.input_size() == 2)
3255 {
3256 // opset 10
3257 node_reference[node.input(1)] -= 1;
3258 }
3259 else
3260 {
3261 // opset 11+
3262 node_reference[node.input(1)] -= 1;
3263 node_reference[node.input(2)] -= 1;
3264 if (node.input_size() >= 4)
3265 {
3266 node_reference[node.input(3)] -= 1;
3267 }
3268 }
3269 }
3270 else if (op == "RNN")
3271 {
3272 for (int j = 1; j < node.input_size(); j++)
3273 {
3274 node_reference[node.input(j)] -= 1;
3275 }
3276 }
3277 else if (op == "SkipLayerNormalization")
3278 {
3279 node_reference[node.input(2)] -= 1;
3280 node_reference[node.input(3)] -= 1;
3281 node_reference[node.input(4)] -= 1;
3282 }
3283 else if (op == "Slice")
3284 {
3285 if (node.input_size() >= 2)
3286 {
3287 node_reference[node.input(1)] -= 1;
3288 node_reference[node.input(2)] -= 1;
3289 if (node.input_size() >= 4)
3290 node_reference[node.input(3)] -= 1;
3291 if (node.input_size() >= 5)
3292 node_reference[node.input(4)] -= 1;
3293 }
3294 }
3295 else if (op == "Upsample")
3296 {
3297 if (node.input_size() >= 2)
3298 {
3299 node_reference[node.input(1)] -= 1;
3300 }
3301 }
3302 else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
3303 {
3304 if (node.input_size() >= 2)
3305 {
3306 node_reference[node.input(1)] -= 1;
3307 }
3308 }
3309 }
3310
3311 // for (auto a: node_reference)
3312 // {
3313 // fprintf(stderr, "b = %s %d\n", a.first.c_str(), a.second);
3314 // }
3315
3316 // count all weight node with zero reference
3317 int zero_reference_weight_node_count = 0;
3318 for (std::map<std::string, onnx::TensorProto>::iterator it = weights.begin(); it != weights.end(); it++)
3319 {
3320 const std::string& input_name = it->first;
3321
3322 // there may be some weight nodes in initializer but none of the graph node use them
3323 // add them to blob_names so we could get proper blob count later
3324 blob_names.insert(input_name);
3325
3326 int refcount = node_reference[input_name];
3327 if (refcount == 0)
3328 zero_reference_weight_node_count++;
3329 }
3330
3331 // we always treat constant node as weight or binaryop_weights
3332 // do not count it twice for layer_count
3333 int constant_node_count_moved_to_weight = 0;
3334 for (int i = 0; i < node_count; i++)
3335 {
3336 const onnx::NodeProto& node = graph.node(i);
3337
3338 const std::string& op = node.op_type();
3339
3340 if (op == "Constant")
3341 {
3342 constant_node_count_moved_to_weight++;
3343 }
3344 }
3345
3346 // some op may have anonymous input
3347 // LSTM sequence_lens
3348 blob_names.erase("");
3349 node_reference.erase("");
3350
3351 // remove node_reference entry with reference equals to one
3352 int split_layer_count = 0;
3353 int splitncnn_blob_count = 0;
3354 // split node reference
3355 std::map<std::string, int> split_node_reference;
3356 for (std::map<std::string, int>::iterator it = node_reference.begin(); it != node_reference.end(); it++)
3357 {
3358 if (it->second > 1)
3359 {
3360 split_layer_count++;
3361 splitncnn_blob_count += it->second;
3362
3363 split_node_reference[it->first] = it->second;
3364 }
3365 }
3366
3367 fprintf(pp, "%zu %zu\n", node_count - constant_node_count_moved_to_weight + weights.size() - zero_reference_weight_node_count - reduced_node_count + input_node_count + split_layer_count, blob_names.size() - zero_reference_weight_node_count + splitncnn_blob_count);
3368
3369 int internal_split = 0;
3370
3371 // place Input at the beginning
3372 for (int j = 0; j < graph.input_size(); j++)
3373 {
3374 const std::string& input_name = graph.input(j).name();
3375
3376 // check weight
3377 if (weights.find(input_name) != weights.end())
3378 continue;
3379
3380 fprintf(pp, "%-16s %-24s 0 1 %s\n", "Input", input_name.c_str(), input_name.c_str());
3381
3382 int refcount = node_reference[input_name];
3383 if (refcount <= 1)
3384 {
3385 continue;
3386 }
3387
3388 char splitname[256];
3389 sprintf(splitname, "splitncnn_input%d", j);
3390 fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
3391 fprintf(pp, " %s", input_name.c_str());
3392
3393 for (int k = 0; k < refcount; k++)
3394 {
3395 fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
3396 }
3397 fprintf(pp, "\n");
3398 }
3399
3400 // place MemoryData next
3401 for (std::map<std::string, onnx::TensorProto>::iterator weight_it = weights.begin(); weight_it != weights.end(); weight_it++)
3402 {
3403 const std::string& input_name = weight_it->first;
3404
3405 int refcount = node_reference[input_name];
3406 if (refcount == 0)
3407 {
3408 continue;
3409 }
3410
3411 fprintf(pp, "%-16s %-24s 0 1 %s", "MemoryData", input_name.c_str(), input_name.c_str());
3412
3413 const onnx::TensorProto& M = weights[input_name];
3414
3415 if (M.dims_size() == 0)
3416 {
3417 fprintf(pp, " 0=%d", get_tensor_proto_data_size(M));
3418 }
3419 else if (M.dims_size() == 1)
3420 {
3421 fprintf(pp, " 0=%d", (int)M.dims(0));
3422 }
3423 else if (M.dims_size() == 2)
3424 {
3425 fprintf(pp, " 0=%d", (int)M.dims(1));
3426 if (M.dims(0) != 1)
3427 {
3428 fprintf(pp, " 1=%d", (int)M.dims(0));
3429 }
3430 }
3431 else if (M.dims_size() == 3)
3432 {
3433 fprintf(pp, " 0=%d", (int)M.dims(2));
3434 fprintf(pp, " 1=%d", (int)M.dims(1));
3435 if (M.dims(0) != 1)
3436 {
3437 fprintf(pp, " 2=%d", (int)M.dims(0));
3438 }
3439 }
3440 else if (M.dims_size() == 4)
3441 {
3442 fprintf(pp, " 0=%d", (int)M.dims(3));
3443 fprintf(pp, " 1=%d", (int)M.dims(2));
3444 fprintf(pp, " 2=%d", (int)M.dims(1));
3445 }
3446
3447 fprintf(pp, "\n");
3448
3449 fwrite_tensor_proto_data(M, bp);
3450
3451 if (refcount <= 1)
3452 {
3453 continue;
3454 }
3455
3456 char splitname[256];
3457 sprintf(splitname, "splitncnn_%d", internal_split);
3458 fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);
3459
3460 fprintf(pp, " %s", input_name.c_str());
3461
3462 for (int k = 0; k < refcount; k++)
3463 {
3464 fprintf(pp, " %s_splitncnn_%d", input_name.c_str(), k);
3465 }
3466 fprintf(pp, "\n");
3467
3468 internal_split++;
3469 }
3470
3471 for (int i = 0; i < node_count; i++)
3472 {
3473 const onnx::NodeProto& node = graph.node(i);
3474
3475 const std::string& op = node.op_type();
3476
3477 // fprintf(stderr, "op = %s\n", op.c_str());
3478
3479 if (op == "noop_reducedncnn")
3480 {
3481 continue;
3482 }
3483
3484 std::string name = node.name();
3485 if (name.empty())
3486 {
3487 name = node.output(0);
3488 }
3489
3490 int input_size = node.input_size();
3491 int output_size = node.output_size();
3492
3493 for (int j = 0; j < (int)node.input_size(); j++)
3494 {
3495 const std::string& input_name = node.input(j);
3496
3497 // check weight
3498 if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0)
3499 {
3500 input_size--;
3501 }
3502
3503 if (input_name.empty())
3504 {
3505 input_size--;
3506 }
3507
3508 // fprintf(stderr, " input = %s\n", input_name.c_str());
3509 }
3510 /*
3511 for (int j=0; j<(int)node.output_size(); j++)
3512 {
3513 const std::string& output_name = node.output(j);
3514 fprintf(stderr, " output = %s\n", output_name.c_str());
3515 }
3516 */
3517
3518 if (op == "Abs")
3519 {
3520 fprintf(pp, "%-16s", "UnaryOp");
3521 }
3522 else if (op == "Acos")
3523 {
3524 fprintf(pp, "%-16s", "UnaryOp");
3525 }
3526 else if (op == "Add")
3527 {
3528 fprintf(pp, "%-16s", "BinaryOp");
3529 }
3530 else if (op == "Asin")
3531 {
3532 fprintf(pp, "%-16s", "UnaryOp");
3533 }
3534 else if (op == "Atan")
3535 {
3536 fprintf(pp, "%-16s", "UnaryOp");
3537 }
3538 else if (op == "AveragePool" || op == "MaxPool")
3539 {
3540 std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
3541 if (kernel_shape.size() == 1)
3542 {
3543 fprintf(pp, "%-16s", "Pooling1D");
3544 }
3545 else
3546 {
3547 fprintf(pp, "%-16s", "Pooling");
3548 }
3549 }
3550 else if (op == "BatchNormalization")
3551 {
3552 fprintf(pp, "%-16s", "BatchNorm");
3553 }
3554 else if (op == "BiasGelu")
3555 {
3556 fprintf(pp, "%-16s", "BiasGelu");
3557 }
3558 else if (op == "Ceil")
3559 {
3560 fprintf(pp, "%-16s", "UnaryOp");
3561 }
3562 else if (op == "Clip")
3563 {
3564 fprintf(pp, "%-16s", "Clip");
3565 }
3566 else if (op == "Concat")
3567 {
3568 fprintf(pp, "%-16s", "Concat");
3569 }
3570 else if (op == "Constant")
3571 {
3572 continue;
3573 }
3574 else if (op == "Conv")
3575 {
3576 std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
3577 if (kernel_shape.size() == 1)
3578 {
3579 fprintf(pp, "%-16s", "Convolution1D");
3580 }
3581 else
3582 {
3583 int group = get_node_attr_i(node, "group", 1);
3584 if (group > 1)
3585 {
3586 fprintf(pp, "%-16s", "ConvolutionDepthWise");
3587 }
3588 else
3589 {
3590 fprintf(pp, "%-16s", "Convolution");
3591 }
3592 }
3593 }
3594 else if (op == "ConvTranspose")
3595 {
3596 int group = get_node_attr_i(node, "group", 1);
3597 if (group > 1)
3598 {
3599 fprintf(pp, "%-16s", "DeconvolutionDepthWise");
3600 }
3601 else
3602 {
3603 fprintf(pp, "%-16s", "Deconvolution");
3604 }
3605 }
3606 else if (op == "Cos")
3607 {
3608 fprintf(pp, "%-16s", "UnaryOp");
3609 }
3610 else if (op == "DepthToSpace")
3611 {
3612 fprintf(pp, "%-16s", "PixelShuffle");
3613 }
3614 else if (op == "Div")
3615 {
3616 fprintf(pp, "%-16s", "BinaryOp");
3617 }
3618 else if (op == "Dropout")
3619 {
3620 fprintf(pp, "%-16s", "Dropout");
3621 output_size = 1;
3622 }
3623 else if (op == "Elu")
3624 {
3625 fprintf(pp, "%-16s", "ELU");
3626 }
3627 else if (op == "EmbedLayerNormalization")
3628 {
3629 fprintf(pp, "%-16s", "EmbedLayerNormalization");
3630 }
3631 else if (op == "Exp")
3632 {
3633 fprintf(pp, "%-16s", "UnaryOp");
3634 }
3635 else if (op == "Flatten")
3636 {
3637 fprintf(pp, "%-16s", "Flatten");
3638 }
3639 else if (op == "Floor")
3640 {
3641 fprintf(pp, "%-16s", "UnaryOp");
3642 }
3643 else if (op == "Gemm")
3644 {
3645 float alpha = get_node_attr_f(node, "alpha", 1.f);
3646 float beta = get_node_attr_f(node, "beta", 1.f);
3647 int transA = get_node_attr_i(node, "transA", 0);
3648 int transB = get_node_attr_i(node, "transB", 0);
3649
3650 if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
3651 {
3652 // InnerProduct-like A * B + C
3653 fprintf(pp, "%-16s", "InnerProduct");
3654 }
3655 else
3656 {
3657 fprintf(pp, "%-16s", "Gemm");
3658 }
3659 }
3660 else if (op == "GlobalAveragePool")
3661 {
3662 fprintf(pp, "%-16s", "Pooling");
3663 }
3664 else if (op == "GlobalMaxPool")
3665 {
3666 fprintf(pp, "%-16s", "Pooling");
3667 }
3668 else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
3669 {
3670 fprintf(pp, "%-16s", "Pooling");
3671 }
3672 else if (op == "GroupNorm")
3673 {
3674 fprintf(pp, "%-16s", "GroupNorm");
3675 }
3676 else if (op == "GRU")
3677 {
3678 fprintf(pp, "%-16s", "GRU");
3679 }
3680 else if (op == "HardSigmoid")
3681 {
3682 fprintf(pp, "%-16s", "HardSigmoid");
3683 }
3684 else if (op == "HardSwish")
3685 {
3686 fprintf(pp, "%-16s", "HardSwish");
3687 }
3688 else if (op == "ImageScaler")
3689 {
3690 fprintf(pp, "%-16s", "Scale");
3691 }
3692 else if (op == "InstanceNormalization")
3693 {
3694 fprintf(pp, "%-16s", "InstanceNorm");
3695 }
3696 else if (op == "LayerNorm")
3697 {
3698 fprintf(pp, "%-16s", "LayerNorm");
3699 }
3700 else if (op == "LeakyRelu")
3701 {
3702 fprintf(pp, "%-16s", "ReLU");
3703 }
3704 else if (op == "Log")
3705 {
3706 fprintf(pp, "%-16s", "UnaryOp");
3707 }
3708 else if (op == "LRN")
3709 {
3710 fprintf(pp, "%-16s", "LRN");
3711 }
3712 else if (op == "LSTM")
3713 {
3714 fprintf(pp, "%-16s", "LSTM");
3715 }
3716 else if (op == "MatMul")
3717 {
3718 if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
3719 {
3720 fprintf(pp, "%-16s", "InnerProduct");
3721 }
3722 else
3723 {
3724 fprintf(pp, "%-16s", "Gemm");
3725 }
3726 }
3727 else if (op == "Max")
3728 {
3729 fprintf(pp, "%-16s", "BinaryOp");
3730 }
3731 else if (op == "Min")
3732 {
3733 fprintf(pp, "%-16s", "BinaryOp");
3734 }
3735 else if (op == "Mul")
3736 {
3737 fprintf(pp, "%-16s", "BinaryOp");
3738 }
3739 else if (op == "MultiHeadAttention")
3740 {
3741 fprintf(pp, "%-16s", "MultiHeadAttention");
3742 }
3743 else if (op == "Neg")
3744 {
3745 fprintf(pp, "%-16s", "UnaryOp");
3746 }
3747 else if (op == "Normalize")
3748 {
3749 fprintf(pp, "%-16s", "Normalize");
3750 }
3751 else if (op == "Pad")
3752 {
3753 fprintf(pp, "%-16s", "Padding");
3754 }
3755 else if (op == "PixelShuffle")
3756 {
3757 fprintf(pp, "%-16s", "PixelShuffle");
3758 }
3759 else if (op == "Pow")
3760 {
3761 fprintf(pp, "%-16s", "BinaryOp");
3762 }
3763 else if (op == "PRelu")
3764 {
3765 fprintf(pp, "%-16s", "PReLU");
3766 }
3767 else if (op == "Reciprocal")
3768 {
3769 fprintf(pp, "%-16s", "UnaryOp");
3770 }
3771 else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp")
3772 {
3773 fprintf(pp, "%-16s", "Reduction");
3774 }
3775 else if (op == "Relu")
3776 {
3777 fprintf(pp, "%-16s", "ReLU");
3778 }
3779 else if (op == "Reorg")
3780 {
3781 fprintf(pp, "%-16s", "Reorg");
3782 }
3783 else if (op == "Reshape")
3784 {
3785 fprintf(pp, "%-16s", "Reshape");
3786 }
3787 else if (op == "RNN")
3788 {
3789 fprintf(pp, "%-16s", "RNN");
3790 }
3791 else if (op == "RDiv")
3792 {
3793 fprintf(pp, "%-16s", "BinaryOp");
3794 }
3795 else if (op == "RSub")
3796 {
3797 fprintf(pp, "%-16s", "BinaryOp");
3798 }
3799 else if (op == "ShuffleChannel")
3800 {
3801 fprintf(pp, "%-16s", "ShuffleChannel");
3802 }
3803 else if (op == "Sigmoid")
3804 {
3805 fprintf(pp, "%-16s", "Sigmoid");
3806 }
3807 else if (op == "Sin")
3808 {
3809 fprintf(pp, "%-16s", "UnaryOp");
3810 }
3811 else if (op == "SkipLayerNormalization")
3812 {
3813 fprintf(pp, "%-16s", "SkipLayerNormalization");
3814 }
3815 else if (op == "Slice")
3816 {
3817 fprintf(pp, "%-16s", "Crop");
3818 }
3819 else if (op == "Softmax")
3820 {
3821 fprintf(pp, "%-16s", "Softmax");
3822 }
3823 else if (op == "Softplus")
3824 {
3825 fprintf(pp, "%-16s", "Softplus");
3826 }
3827 else if (op == "Split")
3828 {
3829 fprintf(pp, "%-16s", "Slice");
3830 }
3831 else if (op == "Sqrt")
3832 {
3833 fprintf(pp, "%-16s", "UnaryOp");
3834 }
3835 else if (op == "Squeeze")
3836 {
3837 fprintf(pp, "%-16s", "Squeeze");
3838 }
3839 else if (op == "Sub")
3840 {
3841 fprintf(pp, "%-16s", "BinaryOp");
3842 }
3843 else if (op == "Sum")
3844 {
3845 fprintf(pp, "%-16s", "Eltwise");
3846 }
3847 else if (op == "Swish")
3848 {
3849 fprintf(pp, "%-16s", "Swish");
3850 }
3851 else if (op == "Tan")
3852 {
3853 fprintf(pp, "%-16s", "UnaryOp");
3854 }
3855 else if (op == "Tanh")
3856 {
3857 fprintf(pp, "%-16s", "UnaryOp");
3858 }
3859 else if (op == "Transpose")
3860 {
3861 fprintf(pp, "%-16s", "Permute");
3862 }
3863 else if (op == "Upsample" || op == "Resize")
3864 {
3865 fprintf(pp, "%-16s", "Interp");
3866 }
3867 else if (op == "Unsqueeze")
3868 {
3869 fprintf(pp, "%-16s", "ExpandDims");
3870 }
3871 else
3872 {
3873 // TODO
3874 fprintf(stderr, "%s not supported yet!\n", op.c_str());
3875 fprintf(pp, "%-16s", op.c_str());
3876 }
3877
3878 fprintf(pp, " %-24s %d %d", name.c_str(), input_size, output_size);
3879
3880 for (int j = 0; j < (int)node.input_size(); j++)
3881 {
3882 std::string input_name = node.input(j);
3883
3884 // check weight
3885 if (weights.find(input_name) != weights.end() && node_reference[input_name] == 0)
3886 {
3887 continue;
3888 }
3889
3890 if (input_name.empty())
3891 {
3892 continue;
3893 }
3894
3895 if (split_node_reference.find(input_name) != split_node_reference.end())
3896 {
3897 int refidx = split_node_reference[input_name] - 1;
3898 split_node_reference[input_name] = refidx;
3899
3900 char splitsuffix[256];
3901 sprintf(splitsuffix, "_splitncnn_%d", refidx);
3902 input_name = input_name + splitsuffix;
3903 }
3904
3905 fprintf(pp, " %s", input_name.c_str());
3906 }
3907
3908 for (int j = 0; j < output_size; j++)
3909 {
3910 const std::string& output_name = node.output(j);
3911
3912 fprintf(pp, " %s", output_name.c_str());
3913 }
3914
3915 if (op == "Abs")
3916 {
3917 int op_type = 0;
3918 fprintf(pp, " 0=%d", op_type);
3919 }
3920 else if (op == "Acos")
3921 {
3922 int op_type = 13;
3923 fprintf(pp, " 0=%d", op_type);
3924 }
3925 else if (op == "Add")
3926 {
3927 int op_type = 0;
3928 fprintf(pp, " 0=%d", op_type);
3929
3930 int with_scalar = get_node_attr_i(node, "with_scalar", 0);
3931 float b = get_node_attr_f(node, "b", 0.f);
3932 if (with_scalar)
3933 {
3934 fprintf(pp, " 1=%d", with_scalar);
3935 fprintf(pp, " 2=%e", b);
3936 }
3937 }
3938 else if (op == "Asin")
3939 {
3940 int op_type = 12;
3941 fprintf(pp, " 0=%d", op_type);
3942 }
3943 else if (op == "Atan")
3944 {
3945 int op_type = 14;
3946 fprintf(pp, " 0=%d", op_type);
3947 }
3948 else if (op == "AveragePool" || op == "MaxPool")
3949 {
3950 std::string auto_pad = get_node_attr_s(node, "auto_pad");
3951 int ceil_mode = get_node_attr_i(node, "ceil_mode", 0);
3952 std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
3953 std::vector<int> strides = get_node_attr_ai(node, "strides");
3954 std::vector<int> pads = get_node_attr_ai(node, "pads");
3955
3956 int pool = op == "AveragePool" ? 1 : 0;
3957 int pad_mode = 1;
3958
3959 if (auto_pad == "SAME_UPPER")
3960 {
3961 pad_mode = 2;
3962 }
3963 else if (auto_pad == "SAME_LOWER")
3964 {
3965 pad_mode = 3;
3966 }
3967
3968 if (ceil_mode == 1)
3969 {
3970 pad_mode = 0;
3971 }
3972
3973 fprintf(pp, " 0=%d", pool);
3974
3975 if (kernel_shape.size() == 1)
3976 {
3977 fprintf(pp, " 1=%d", kernel_shape[0]);
3978 }
3979 else if (kernel_shape.size() == 2)
3980 {
3981 fprintf(pp, " 1=%d", kernel_shape[1]);
3982 fprintf(pp, " 11=%d", kernel_shape[0]);
3983 }
3984
3985 if (strides.size() == 1)
3986 {
3987 fprintf(pp, " 2=%d", strides[0]);
3988 }
3989 else if (strides.size() == 2)
3990 {
3991 fprintf(pp, " 2=%d", strides[1]);
3992 fprintf(pp, " 12=%d", strides[0]);
3993 }
3994
3995 if (pads.size() == 1)
3996 {
3997 fprintf(pp, " 3=%d", pads[0]);
3998 }
3999 else if (pads.size() == 2)
4000 {
4001 fprintf(pp, " 3=%d", pads[1]);
4002 fprintf(pp, " 13=%d", pads[0]);
4003 }
4004 else if (pads.size() == 4)
4005 {
4006 fprintf(pp, " 3=%d", pads[1]);
4007 fprintf(pp, " 13=%d", pads[0]);
4008 fprintf(pp, " 14=%d", pads[3]);
4009 fprintf(pp, " 15=%d", pads[2]);
4010 }
4011
4012 fprintf(pp, " 5=%d", pad_mode);
4013
4014 if (op == "AveragePool")
4015 {
4016 int avgpool_count_include_pad = get_node_attr_i(node, "count_include_pad", 0);
4017 fprintf(pp, " 6=%d", avgpool_count_include_pad);
4018 }
4019 }
4020 else if (op == "BatchNormalization")
4021 {
4022 float epsilon = get_node_attr_f(node, "epsilon", 1e-5f);
4023
4024 const onnx::TensorProto& scale = weights[node.input(1)];
4025 const onnx::TensorProto& B = weights[node.input(2)];
4026 const onnx::TensorProto& mean = weights[node.input(3)];
4027 const onnx::TensorProto& var = weights[node.input(4)];
4028
4029 int channels = get_tensor_proto_data_size(scale);
4030
4031 fprintf(pp, " 0=%d", channels);
4032
4033 fwrite_tensor_proto_data(scale, bp);
4034 fwrite_tensor_proto_data(mean, bp);
4035 // apply epsilon to var
4036 {
4037 const float* v = var.has_raw_data() ? (const float*)var.raw_data().data() : var.float_data().data();
4038
4039 for (int j = 0; j < channels; j++)
4040 {
4041 float ve = v[j] + epsilon;
4042 fwrite(&ve, sizeof(float), 1, bp);
4043 }
4044 }
4045 fwrite_tensor_proto_data(B, bp);
4046 }
4047 else if (op == "BiasGelu")
4048 {
4049 const onnx::TensorProto& B = weights[node.input(1)];
4050
4051 fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
4052
4053 int quantize_tag = 0;
4054 fwrite(&quantize_tag, sizeof(int), 1, bp);
4055
4056 fwrite_tensor_proto_data(B, bp);
4057 }
4058 else if (op == "Ceil")
4059 {
4060 int op_type = 3;
4061 fprintf(pp, " 0=%d", op_type);
4062 }
4063 else if (op == "Clip")
4064 {
4065 float min;
4066 float max;
4067 if (node.input_size() == 1)
4068 {
4069 min = get_node_attr_f(node, "min", -FLT_MAX);
4070 max = get_node_attr_f(node, "max", FLT_MAX);
4071 }
4072 else
4073 {
4074 min = weights.find(node.input(1)) != weights.end() ? get_node_attr_from_input_f(weights[node.input(1)]) : -FLT_MAX;
4075 max = weights.find(node.input(2)) != weights.end() ? get_node_attr_from_input_f(weights[node.input(2)]) : FLT_MAX;
4076 }
4077
4078 fprintf(pp, " 0=%e", min);
4079 fprintf(pp, " 1=%e", max);
4080 }
4081 else if (op == "Concat")
4082 {
4083 int axis = get_node_attr_i(node, "axis", 1);
4084 fprintf(pp, " 0=%d", axis > 0 ? axis - 1 : axis);
4085 }
4086 else if (op == "Constant")
4087 {
4088 // never reach here
4089 }
4090 else if (op == "Conv")
4091 {
4092 const onnx::TensorProto& W = weights[node.input(1)];
4093
4094 int num_filter = W.dims(0);
4095 int has_bias = node.input_size() == 3 ? 1 : 0;
4096
4097 std::string auto_pad = get_node_attr_s(node, "auto_pad");
4098 std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
4099 std::vector<int> dilations = get_node_attr_ai(node, "dilations");
4100 std::vector<int> strides = get_node_attr_ai(node, "strides");
4101 std::vector<int> pads = get_node_attr_ai(node, "pads");
4102 int group = get_node_attr_i(node, "group", 1);
4103
4104 fprintf(pp, " 0=%d", num_filter);
4105
4106 if (kernel_shape.size() == 1)
4107 {
4108 fprintf(pp, " 1=%d", kernel_shape[0]);
4109 }
4110 else if (kernel_shape.size() == 2)
4111 {
4112 fprintf(pp, " 1=%d", kernel_shape[1]);
4113 fprintf(pp, " 11=%d", kernel_shape[0]);
4114 }
4115
4116 if (dilations.size() == 1)
4117 {
4118 fprintf(pp, " 2=%d", dilations[0]);
4119 }
4120 else if (dilations.size() == 2)
4121 {
4122 fprintf(pp, " 2=%d", dilations[1]);
4123 fprintf(pp, " 12=%d", dilations[0]);
4124 }
4125
4126 if (strides.size() == 1)
4127 {
4128 fprintf(pp, " 3=%d", strides[0]);
4129 }
4130 else if (strides.size() == 2)
4131 {
4132 fprintf(pp, " 3=%d", strides[1]);
4133 fprintf(pp, " 13=%d", strides[0]);
4134 }
4135
4136 if (auto_pad == "SAME_UPPER")
4137 {
4138 fprintf(pp, " 4=-233");
4139 }
4140 else if (auto_pad == "SAME_LOWER")
4141 {
4142 fprintf(pp, " 4=-234");
4143 }
4144 else
4145 {
4146 if (pads.size() == 1)
4147 {
4148 fprintf(pp, " 4=%d", pads[0]);
4149 }
4150 else if (pads.size() == 2)
4151 {
4152 fprintf(pp, " 4=%d", pads[1]);
4153 fprintf(pp, " 14=%d", pads[0]);
4154 }
4155 else if (pads.size() == 4)
4156 {
4157 fprintf(pp, " 4=%d", pads[1]);
4158 fprintf(pp, " 14=%d", pads[0]);
4159 fprintf(pp, " 15=%d", pads[3]);
4160 fprintf(pp, " 16=%d", pads[2]);
4161 }
4162 }
4163
4164 fprintf(pp, " 5=%d", has_bias);
4165
4166 fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));
4167
4168 if (group > 1)
4169 {
4170 fprintf(pp, " 7=%d", group);
4171 }
4172
4173 int quantize_tag = 0;
4174 fwrite(&quantize_tag, sizeof(int), 1, bp);
4175
4176 fwrite_tensor_proto_data(W, bp);
4177
4178 if (has_bias)
4179 {
4180 const onnx::TensorProto& B = weights[node.input(2)];
4181 fwrite_tensor_proto_data(B, bp);
4182 }
4183 }
        else if (op == "ConvTranspose")
        {
            // ONNX ConvTranspose -> ncnn Deconvolution / DeconvolutionDepthWise.
            // ONNX stores the weight as [num_input, num_output/group, kh, kw],
            // so the layer's output channel count is dims(1) * group.
            const onnx::TensorProto& W = weights[node.input(1)];

            // a third input, when present, is the bias tensor
            int has_bias = node.input_size() == 3 ? 1 : 0;

            std::string auto_pad = get_node_attr_s(node, "auto_pad");
            std::vector<int> kernel_shape = get_node_attr_ai(node, "kernel_shape");
            std::vector<int> dilations = get_node_attr_ai(node, "dilations");
            std::vector<int> strides = get_node_attr_ai(node, "strides");
            std::vector<int> output_padding = get_node_attr_ai(node, "output_padding");
            std::vector<int> output_shape = get_node_attr_ai(node, "output_shape");
            std::vector<int> pads = get_node_attr_ai(node, "pads");
            int group = get_node_attr_i(node, "group", 1);
            int num_filter = W.dims(1) * group;

            fprintf(pp, " 0=%d", num_filter);

            // param ids: 1/11 = kernel w/h, 2/12 = dilation w/h, 3/13 = stride w/h
            // (ONNX lists h before w, hence index [1] goes to the w slot)
            if (kernel_shape.size() == 1)
            {
                fprintf(pp, " 1=%d", kernel_shape[0]);
            }
            else if (kernel_shape.size() == 2)
            {
                fprintf(pp, " 1=%d", kernel_shape[1]);
                fprintf(pp, " 11=%d", kernel_shape[0]);
            }

            if (dilations.size() == 1)
            {
                fprintf(pp, " 2=%d", dilations[0]);
            }
            else if (dilations.size() == 2)
            {
                fprintf(pp, " 2=%d", dilations[1]);
                fprintf(pp, " 12=%d", dilations[0]);
            }

            if (strides.size() == 1)
            {
                fprintf(pp, " 3=%d", strides[0]);
            }
            else if (strides.size() == 2)
            {
                fprintf(pp, " 3=%d", strides[1]);
                fprintf(pp, " 13=%d", strides[0]);
            }

            // -233 / -234 are ncnn's magic pad values for SAME_UPPER / SAME_LOWER
            if (auto_pad == "SAME_UPPER")
            {
                fprintf(pp, " 4=-233");
            }
            else if (auto_pad == "SAME_LOWER")
            {
                fprintf(pp, " 4=-234");
            }
            else
            {
                // explicit pads: ONNX 2d order is [top, left, bottom, right]
                if (pads.size() == 1)
                {
                    fprintf(pp, " 4=%d", pads[0]);
                }
                else if (pads.size() == 2)
                {
                    fprintf(pp, " 4=%d", pads[1]);
                    fprintf(pp, " 14=%d", pads[0]);
                }
                else if (pads.size() == 4)
                {
                    fprintf(pp, " 4=%d", pads[1]);
                    fprintf(pp, " 14=%d", pads[0]);
                    fprintf(pp, " 15=%d", pads[3]);
                    fprintf(pp, " 16=%d", pads[2]);
                }
            }

            // 18/19 = output_padding w/h, 20/21 = explicit output_shape w/h
            if (output_padding.size() == 1)
            {
                fprintf(pp, " 18=%d", output_padding[0]);
            }
            else if (output_padding.size() == 2)
            {
                fprintf(pp, " 18=%d", output_padding[1]);
                fprintf(pp, " 19=%d", output_padding[0]);
            }

            if (output_shape.size() == 1)
            {
                fprintf(pp, " 20=%d", output_shape[0]);
            }
            else if (output_shape.size() == 2)
            {
                fprintf(pp, " 20=%d", output_shape[1]);
                fprintf(pp, " 21=%d", output_shape[0]);
            }

            fprintf(pp, " 5=%d", has_bias);

            fprintf(pp, " 6=%d", get_tensor_proto_data_size(W));

            if (group > 1)
            {
                fprintf(pp, " 7=%d", group);
            }

            // weight blob is prefixed with a 0 tag meaning raw fp32 data
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            int maxk = 0;
            if (kernel_shape.size() == 2)
            {
                maxk = kernel_shape[1] * kernel_shape[0];
            }
            else
            {
                // NOTE(review): for a 1-element kernel_shape this squares
                // kernel_shape[0]; presumably 1-d ConvTranspose never reaches
                // here - verify before relying on this path.
                maxk = kernel_shape[0] * kernel_shape[0];
            }
            int weight_data_size = get_tensor_proto_data_size(W);
            const float* weight_data = 0;
            if (W.has_raw_data())
            {
                weight_data = (const float*)W.raw_data().data();
            }
            else if (W.data_type() == 1)
            {
                // data_type 1 == FLOAT in the ONNX TensorProto enum
                weight_data = W.float_data().data();
            }
            for (int g = 0; g < group; g++)
            {
                // reorder weight from inch-outch to outch-inch
                int num_filter_g = num_filter / group;
                int num_input = weight_data_size / maxk / num_filter_g / group;
                const float* weight_data_ptr = weight_data + g * maxk * num_filter_g * num_input;
                for (int k = 0; k < num_filter_g; k++)
                {
                    for (int j = 0; j < num_input; j++)
                    {
                        fwrite(weight_data_ptr + (j * num_filter_g + k) * maxk, sizeof(float), maxk, bp);
                    }
                }
            }

            if (has_bias)
            {
                const onnx::TensorProto& B = weights[node.input(2)];
                fwrite_tensor_proto_data(B, bp);
            }
        }
4332 else if (op == "Cos")
4333 {
4334 int op_type = 10;
4335 fprintf(pp, " 0=%d", op_type);
4336 }
4337 else if (op == "DepthToSpace")
4338 {
4339 // pixelshuffle
4340 int scale_factor = get_node_attr_i(node, "blocksize", 1);
4341 std::string mode = get_node_attr_s(node, "mode");
4342 fprintf(pp, " 0=%d", scale_factor);
4343 if (mode == "CRD")
4344 {
4345 fprintf(pp, " 1=0");
4346 }
4347 else if (mode == "DCR")
4348 {
4349 fprintf(pp, " 1=1");
4350 }
4351 }
4352 else if (op == "Div")
4353 {
4354 int op_type = 3;
4355 fprintf(pp, " 0=%d", op_type);
4356
4357 int with_scalar = get_node_attr_i(node, "with_scalar", 0);
4358 float b = get_node_attr_f(node, "b", 0.f);
4359 if (with_scalar)
4360 {
4361 fprintf(pp, " 1=%d", with_scalar);
4362 fprintf(pp, " 2=%e", b);
4363 }
4364 }
4365 else if (op == "Dropout")
4366 {
4367 // no-op
4368 }
4369 else if (op == "Elu")
4370 {
4371 float alpha = get_node_attr_f(node, "alpha", 1.f);
4372 fprintf(pp, " 0=%e", alpha);
4373 }
        else if (op == "EmbedLayerNormalization")
        {
            // word embedding table, position embedding table, then the
            // layernorm gamma (W) and beta (B) tensors
            const onnx::TensorProto& words = weights[node.input(2)];
            const onnx::TensorProto& positions = weights[node.input(3)];
            const onnx::TensorProto& W = weights[node.input(5)];
            const onnx::TensorProto& B = weights[node.input(6)];

            fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));
            fprintf(pp, " 1=%d", get_tensor_proto_data_size(words));
            fprintf(pp, " 2=%d", get_tensor_proto_data_size(positions));

            // each blob is prefixed with a 0 tag meaning raw fp32 data
            int quantize_tag = 0;
            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(words, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(positions, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(W, bp);

            fwrite(&quantize_tag, sizeof(int), 1, bp);

            fwrite_tensor_proto_data(B, bp);
        }
4402 else if (op == "Exp")
4403 {
4404 int op_type = 7;
4405 fprintf(pp, " 0=%d", op_type);
4406 }
4407 else if (op == "Flatten")
4408 {
4409 int axis = get_node_attr_i(node, "axis", 1);
4410 if (axis != 1)
4411 {
4412 fprintf(stderr, "Unsupported Flatten axis %d!\n", axis);
4413 }
4414 }
4415 else if (op == "Floor")
4416 {
4417 int op_type = 2;
4418 fprintf(pp, " 0=%d", op_type);
4419 }
4420 else if (op == "Gemm")
4421 {
4422 float alpha = get_node_attr_f(node, "alpha", 1.f);
4423 float beta = get_node_attr_f(node, "beta", 1.f);
4424 int transA = get_node_attr_i(node, "transA", 0);
4425 int transB = get_node_attr_i(node, "transB", 0);
4426
4427 if (alpha == 1.f && beta == 1.f && transA == 0 && transB == 1)
4428 {
4429 // InnerProduct-like A * B + C
4430 const onnx::TensorProto& B = weights[node.input(1)];
4431 const onnx::TensorProto& C = weights[node.input(2)];
4432
4433 fprintf(pp, " 0=%d", get_tensor_proto_data_size(C));
4434 fprintf(pp, " 1=1");
4435 fprintf(pp, " 2=%d", get_tensor_proto_data_size(B));
4436
4437 int quantize_tag = 0;
4438 fwrite(&quantize_tag, sizeof(int), 1, bp);
4439
4440 fwrite_tensor_proto_data(B, bp);
4441 fwrite_tensor_proto_data(C, bp);
4442 }
4443 else
4444 {
4445 // gemm
4446 fprintf(pp, " 0=%e", alpha);
4447 fprintf(pp, " 1=%e", beta);
4448 fprintf(pp, " 2=%d", transA);
4449 fprintf(pp, " 3=%d", transB);
4450 }
4451 }
4452 else if (op == "GlobalAveragePool")
4453 {
4454 int pool = 1;
4455 int global_pool = 1;
4456
4457 fprintf(pp, " 0=%d", pool);
4458 fprintf(pp, " 4=%d", global_pool);
4459 }
4460 else if (op == "GlobalMaxPool")
4461 {
4462 int pool = 0;
4463 int global_pool = 1;
4464
4465 fprintf(pp, " 0=%d", pool);
4466 fprintf(pp, " 4=%d", global_pool);
4467 }
4468 else if (op == "adaptive_avg_pool2d" || op == "adaptive_max_pool2d")
4469 {
4470 int pool = 0;
4471 if (op == "adaptive_avg_pool2d")
4472 {
4473 pool = 1;
4474 }
4475 int adaptive_pooling = 1;
4476 const onnx::TensorProto& out_shape_tp = weights[node.input(1)];
4477 std::vector<int> out_shape = get_node_attr_from_input_ai(out_shape_tp);
4478
4479 fprintf(pp, " 0=%d", pool);
4480 fprintf(pp, " 7=%d", adaptive_pooling);
4481 if (out_shape.size() == 1)
4482 {
4483 fprintf(pp, " 8=%d", out_shape[0]);
4484 }
4485 else if (out_shape.size() == 2)
4486 {
4487 // out_w
4488 fprintf(pp, " 8=%d", out_shape[1]);
4489 // out_h
4490 fprintf(pp, " 18=%d", out_shape[0]);
4491 }
4492 }
        else if (op == "GroupNorm")
        {
            int groups = get_node_attr_i(node, "groups", 1);
            int channels = get_node_attr_i(node, "channels", 1);
            float eps = get_node_attr_f(node, "epsilon", 1e-5f);
            int affine = get_node_attr_i(node, "affine", 1);

            if (affine)
            {
                // discard affine-less S=1 B=0
                std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
                std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
                if (affine_S.size() == 1 && affine_S[0] == 1.f && affine_B.size() == 1 && affine_B[0] == 0.f)
                {
                    affine = 0;
                }
                else
                {
                    // re-enable affine only if some scale/bias deviates from identity
                    affine = 0;
                    {
                        // NOTE(review): this loop assumes affine_S/affine_B hold at
                        // least `channels` elements; a mismatched channels attribute
                        // would read out of bounds - verify against the model.
                        for (int j = 0; j < channels; j++)
                        {
                            if (affine_S[j] != 1.f || affine_B[j] != 0.f)
                            {
                                affine = 1;
                                break;
                            }
                        }
                    }
                }
            }

            fprintf(pp, " 0=%d", groups);
            fprintf(pp, " 1=%d", channels);
            fprintf(pp, " 2=%e", eps);
            fprintf(pp, " 3=%d", affine);
            if (affine)
            {
                const onnx::TensorProto& scale = weights[node.input(1)];
                const onnx::TensorProto& B = weights[node.input(2)];

                fwrite_tensor_proto_data(scale, bp);
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "GRU")
        {
            // ONNX GRU packs the gates as U(update), R(reset), N(new);
            // ncnn's GRU expects R, U, N, so every blob below is reordered.
            const onnx::TensorProto& W = weights[node.input(1)];
            const onnx::TensorProto& R = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];

            int hidden_size = get_node_attr_i(node, "hidden_size", 0);
            std::string direction = get_node_attr_s(node, "direction");

            // 0 = forward, 1 = reverse, 2 = bidirectional
            int direction_type = 0;
            if (direction == "forward")
            {
                direction_type = 0;
            }
            else if (direction == "reverse")
            {
                direction_type = 1;
            }
            else if (direction == "bidirectional")
            {
                direction_type = 2;
            }

            int weight_data_size = get_tensor_proto_data_size(W);

            fprintf(pp, " 0=%d", hidden_size);
            fprintf(pp, " 1=%d", weight_data_size);
            fprintf(pp, " 2=%d", direction_type);

            int num_directions = direction_type == 2 ? 2 : 1;

            // 0 tag before every blob means raw fp32 data
            int quantize_tag = 0;

            // reorder num_directions-URN-hidden-size to num_directions-RUN-hidden-size
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                int weight_data_size_g = get_tensor_proto_data_size(W) / 3 / num_directions;
                const float* wptr = W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data();

                const float* uptr = wptr;
                const float* rptr = wptr + weight_data_size_g;
                const float* nptr = wptr + weight_data_size_g * 2;
                fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                fwrite(nptr, sizeof(float), weight_data_size_g, bp);

                // the reverse direction's gates follow the forward ones
                if (direction_type == 2)
                {
                    uptr += weight_data_size_g * 3;
                    rptr += weight_data_size_g * 3;
                    nptr += weight_data_size_g * 3;
                    fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(nptr, sizeof(float), weight_data_size_g, bp);
                }
            }

            // reduce U and R bias except N
            // reorder num_directions-URN-hidden to num_directions-RUN-hidden
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // ONNX concatenates input biases (Wb) and recurrent biases (Rb),
                // three gates each, hence the / 2 / 3 split
                int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 3 / num_directions;
                const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
                const float* wuptr = bptr;
                const float* wrptr = bptr + bias_data_size_g;
                const float* wnptr = bptr + bias_data_size_g * 2;
                const float* buptr = bptr + bias_data_size_g * 3;
                const float* brptr = bptr + bias_data_size_g * 4;
                const float* bnptr = bptr + bias_data_size_g * 5;

                // R and U biases can be pre-summed; N's two biases must stay
                // separate because the reset gate is applied between them
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = wrptr[j] + brptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = wuptr[j] + buptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                fwrite(wnptr, sizeof(float), bias_data_size_g, bp);
                fwrite(bnptr, sizeof(float), bias_data_size_g, bp);

                if (direction_type == 2)
                {
                    wuptr += bias_data_size_g * 6;
                    wrptr += bias_data_size_g * 6;
                    wnptr += bias_data_size_g * 6;
                    buptr += bias_data_size_g * 6;
                    brptr += bias_data_size_g * 6;
                    bnptr += bias_data_size_g * 6;

                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = wrptr[j] + brptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = wuptr[j] + buptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    fwrite(wnptr, sizeof(float), bias_data_size_g, bp);
                    fwrite(bnptr, sizeof(float), bias_data_size_g, bp);
                }
            }

            // reorder num_directions-URN-hidden-hidden to num_directions-RUN-hidden-hidden
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                int weight_data_size_g = get_tensor_proto_data_size(R) / 3 / num_directions;
                const float* Rptr = R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data();

                const float* uptr = Rptr;
                const float* rptr = Rptr + weight_data_size_g;
                const float* nptr = Rptr + weight_data_size_g * 2;
                fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                fwrite(nptr, sizeof(float), weight_data_size_g, bp);

                if (direction_type == 2)
                {
                    uptr += weight_data_size_g * 3;
                    rptr += weight_data_size_g * 3;
                    nptr += weight_data_size_g * 3;
                    fwrite(rptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(uptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(nptr, sizeof(float), weight_data_size_g, bp);
                }
            }
        }
4672 else if (op == "HardSigmoid")
4673 {
4674 float alpha = get_node_attr_f(node, "alpha", 0.2f);
4675 float beta = get_node_attr_f(node, "beta", 0.5f);
4676
4677 fprintf(pp, " 0=%e", alpha);
4678 fprintf(pp, " 1=%e", beta);
4679 }
4680 else if (op == "HardSwish")
4681 {
4682 float alpha = get_node_attr_f(node, "alpha", 0.2f);
4683 float beta = get_node_attr_f(node, "beta", 0.5f);
4684
4685 fprintf(pp, " 0=%e", alpha);
4686 fprintf(pp, " 1=%e", beta);
4687 }
4688 else if (op == "ImageScaler")
4689 {
4690 std::vector<float> bias = get_node_attr_af(node, "bias");
4691 float scale = get_node_attr_f(node, "scale", 1.f);
4692
4693 int channels = (int)bias.size();
4694
4695 fprintf(pp, " 0=%d", channels);
4696 fprintf(pp, " 1=1");
4697
4698 for (int j = 0; j < channels; j++)
4699 {
4700 fwrite(&scale, sizeof(float), 1, bp);
4701 }
4702 fwrite(&bias[0], sizeof(float), channels, bp);
4703 }
        else if (op == "InstanceNormalization")
        {
            float eps = get_node_attr_f(node, "epsilon", 1e-5f);

            // discard affine-less S=1 B=0
            std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
            std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
            // channel count comes from the scale tensor's length
            int channels = (int)affine_S.size();
            // enable affine only if some scale/bias pair deviates from identity
            int affine = 0;
            {
                for (int j = 0; j < channels; j++)
                {
                    if (affine_S[j] != 1.f || affine_B[j] != 0.f)
                    {
                        affine = 1;
                        break;
                    }
                }
            }

            fprintf(pp, " 0=%d", channels);
            fprintf(pp, " 1=%e", eps);
            fprintf(pp, " 2=%d", affine);
            if (affine)
            {
                const onnx::TensorProto& scale = weights[node.input(1)];
                const onnx::TensorProto& B = weights[node.input(2)];

                fwrite_tensor_proto_data(scale, bp);
                fwrite_tensor_proto_data(B, bp);
            }
        }
        else if (op == "LayerNorm")
        {
            float eps = get_node_attr_f(node, "epsilon", 1e-5f);
            int affine = get_node_attr_i(node, "affine", 1);

            if (affine)
            {
                // discard affine-less S=1 B=0
                std::vector<float> affine_S = get_node_attr_from_input_af(weights[node.input(1)]);
                std::vector<float> affine_B = get_node_attr_from_input_af(weights[node.input(2)]);
                int affine_size = (int)affine_S.size();
                // re-enable affine only if some scale/bias deviates from identity
                affine = 0;
                {
                    for (int j = 0; j < affine_size; j++)
                    {
                        if (affine_S[j] != 1.f || affine_B[j] != 0.f)
                        {
                            affine = 1;
                            break;
                        }
                    }
                }

                // param 0 (normalized size) is only written when affine survives
                if (affine)
                {
                    fprintf(pp, " 0=%d", affine_size);
                }
            }

            fprintf(pp, " 1=%e", eps);
            fprintf(pp, " 2=%d", affine);

            if (affine)
            {
                const onnx::TensorProto& scale = weights[node.input(1)];
                const onnx::TensorProto& B = weights[node.input(2)];

                fwrite_tensor_proto_data(scale, bp);
                fwrite_tensor_proto_data(B, bp);
            }
        }
4777 else if (op == "LeakyRelu")
4778 {
4779 float alpha = get_node_attr_f(node, "alpha", 0.01f);
4780
4781 fprintf(pp, " 0=%e", alpha);
4782 }
4783 else if (op == "Log")
4784 {
4785 int op_type = 8;
4786 fprintf(pp, " 0=%d", op_type);
4787 }
4788 else if (op == "LRN")
4789 {
4790 float alpha = get_node_attr_f(node, "alpha", 1.f);
4791 float beta = get_node_attr_f(node, "beta", 0.5f);
4792 float bias = get_node_attr_f(node, "bias", 1.f);
4793 int size = get_node_attr_i(node, "size", 1);
4794
4795 int norm_region = 0;
4796
4797 fprintf(pp, " 0=%d", norm_region);
4798 fprintf(pp, " 1=%d", size);
4799 fprintf(pp, " 2=%e", alpha);
4800 fprintf(pp, " 3=%e", beta);
4801 fprintf(pp, " 4=%e", bias);
4802 }
        else if (op == "LSTM")
        {
            // ONNX LSTM packs the gates as I, O, F, G; ncnn's LSTM expects
            // I, F, O, G, so every blob below is reordered on the way out.
            const onnx::TensorProto& W = weights[node.input(1)];
            const onnx::TensorProto& R = weights[node.input(2)];
            const onnx::TensorProto& B = weights[node.input(3)];

            int hidden_size = get_node_attr_i(node, "hidden_size", 0);
            std::string direction = get_node_attr_s(node, "direction");

            // 0 = forward, 1 = reverse, 2 = bidirectional
            int direction_type = 0;
            if (direction == "forward")
            {
                direction_type = 0;
            }
            else if (direction == "reverse")
            {
                direction_type = 1;
            }
            else if (direction == "bidirectional")
            {
                direction_type = 2;
            }

            int weight_data_size = get_tensor_proto_data_size(W);

            fprintf(pp, " 0=%d", hidden_size);
            fprintf(pp, " 1=%d", weight_data_size);
            fprintf(pp, " 2=%d", direction_type);

            int num_directions = direction_type == 2 ? 2 : 1;

            // 0 tag before every blob means raw fp32 data
            int quantize_tag = 0;

            // reorder num_directions-IOFG-hidden-size to num_directions-IFOG-hidden-size
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                int weight_data_size_g = get_tensor_proto_data_size(W) / 4 / num_directions;
                const float* wptr = W.has_raw_data() ? (const float*)W.raw_data().data() : W.float_data().data();

                const float* iptr = wptr;
                const float* optr = wptr + weight_data_size_g;
                const float* fptr = wptr + weight_data_size_g * 2;
                const float* gptr = wptr + weight_data_size_g * 3;
                fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                fwrite(optr, sizeof(float), weight_data_size_g, bp);
                fwrite(gptr, sizeof(float), weight_data_size_g, bp);

                // the reverse direction's gates follow the forward ones
                if (direction_type == 2)
                {
                    iptr += weight_data_size_g * 4;
                    optr += weight_data_size_g * 4;
                    fptr += weight_data_size_g * 4;
                    gptr += weight_data_size_g * 4;
                    fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(optr, sizeof(float), weight_data_size_g, bp);
                    fwrite(gptr, sizeof(float), weight_data_size_g, bp);
                }
            }

            // reduce xc and hc bias
            // reorder num_directions-IOFG-hidden to num_directions-IFOG-hidden
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // ONNX concatenates input biases (Wb) and recurrent biases (Rb),
                // four gates each, hence the / 2 / 4 split
                int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / 4 / num_directions;
                const float* xcbptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
                const float* xiptr = xcbptr;
                const float* xoptr = xcbptr + bias_data_size_g;
                const float* xfptr = xcbptr + bias_data_size_g * 2;
                const float* xgptr = xcbptr + bias_data_size_g * 3;
                const float* hiptr = xcbptr + bias_data_size_g * 4;
                const float* hoptr = xcbptr + bias_data_size_g * 5;
                const float* hfptr = xcbptr + bias_data_size_g * 6;
                const float* hgptr = xcbptr + bias_data_size_g * 7;

                // each gate's input and recurrent biases are summed into one
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xiptr[j] + hiptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xfptr[j] + hfptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xoptr[j] + hoptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }
                for (int j = 0; j < bias_data_size_g; j++)
                {
                    float vb = xgptr[j] + hgptr[j];
                    fwrite(&vb, sizeof(float), 1, bp);
                }

                if (direction_type == 2)
                {
                    xiptr += bias_data_size_g * 8;
                    xoptr += bias_data_size_g * 8;
                    xfptr += bias_data_size_g * 8;
                    xgptr += bias_data_size_g * 8;
                    hiptr += bias_data_size_g * 8;
                    hoptr += bias_data_size_g * 8;
                    hfptr += bias_data_size_g * 8;
                    hgptr += bias_data_size_g * 8;

                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xiptr[j] + hiptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xfptr[j] + hfptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xoptr[j] + hoptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                    for (int j = 0; j < bias_data_size_g; j++)
                    {
                        float vb = xgptr[j] + hgptr[j];
                        fwrite(&vb, sizeof(float), 1, bp);
                    }
                }
            }

            // reorder num_directions-IOFG-hidden-hidden to num_directions-IFOG-hidden-hidden
            {
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                int weight_data_size_g = get_tensor_proto_data_size(R) / 4 / num_directions;
                const float* rptr = R.has_raw_data() ? (const float*)R.raw_data().data() : R.float_data().data();

                const float* iptr = rptr;
                const float* optr = rptr + weight_data_size_g;
                const float* fptr = rptr + weight_data_size_g * 2;
                const float* gptr = rptr + weight_data_size_g * 3;
                fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                fwrite(optr, sizeof(float), weight_data_size_g, bp);
                fwrite(gptr, sizeof(float), weight_data_size_g, bp);

                if (direction_type == 2)
                {
                    iptr += weight_data_size_g * 4;
                    optr += weight_data_size_g * 4;
                    fptr += weight_data_size_g * 4;
                    gptr += weight_data_size_g * 4;
                    fwrite(iptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(fptr, sizeof(float), weight_data_size_g, bp);
                    fwrite(optr, sizeof(float), weight_data_size_g, bp);
                    fwrite(gptr, sizeof(float), weight_data_size_g, bp);
                }
            }
        }
        else if (op == "MatMul")
        {
            // a constant 2d right-hand side turns MatMul into InnerProduct
            if (weights.find(node.input(1)) != weights.end() && weights[node.input(1)].dims_size() == 2)
            {
                // InnerProduct
                const onnx::TensorProto& B = weights[node.input(1)];

                int weight_data_size = get_tensor_proto_data_size(B);

                // the last dim is the output width, the rest is the input size
                int num_output = B.dims(B.dims_size() - 1);
                int num_input = weight_data_size / num_output;

                fprintf(pp, " 0=%d", num_output);
                fprintf(pp, " 1=0"); // no bias
                fprintf(pp, " 2=%d", weight_data_size);

                // 0 tag means raw fp32 data
                int quantize_tag = 0;
                fwrite(&quantize_tag, sizeof(int), 1, bp);

                // reorder num_input-num_output to num_output-num_input
                {
                    const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();

                    for (int j = 0; j < num_output; j++)
                    {
                        for (int k = 0; k < num_input; k++)
                        {
                            float vb = bptr[k * num_output + j];
                            fwrite(&vb, sizeof(float), 1, bp);
                        }
                    }
                }

                // fwrite_tensor_proto_data(B, bp)
            }
            else
            {
                // default matrix multiplication
            }
        }
5005 else if (op == "Max")
5006 {
5007 int op_type = 4;
5008 fprintf(pp, " 0=%d", op_type);
5009
5010 int with_scalar = get_node_attr_i(node, "with_scalar", 0);
5011 float b = get_node_attr_f(node, "b", 0.f);
5012 if (with_scalar)
5013 {
5014 fprintf(pp, " 1=%d", with_scalar);
5015 fprintf(pp, " 2=%e", b);
5016 }
5017 }
5018 else if (op == "Min")
5019 {
5020 int op_type = 5;
5021 fprintf(pp, " 0=%d", op_type);
5022
5023 int with_scalar = get_node_attr_i(node, "with_scalar", 0);
5024 float b = get_node_attr_f(node, "b", 0.f);
5025 if (with_scalar)
5026 {
5027 fprintf(pp, " 1=%d", with_scalar);
5028 fprintf(pp, " 2=%e", b);
5029 }
5030 }
5031 else if (op == "Mul")
5032 {
5033 int op_type = 2;
5034 fprintf(pp, " 0=%d", op_type);
5035
5036 int with_scalar = get_node_attr_i(node, "with_scalar", 0);
5037 float b = get_node_attr_f(node, "b", 0.f);
5038 if (with_scalar)
5039 {
5040 fprintf(pp, " 1=%d", with_scalar);
5041 fprintf(pp, " 2=%e", b);
5042 }
5043 }
5044 else if (op == "MultiHeadAttention")
5045 {
5046 int embed_dim = get_node_attr_i(node, "embed_dim", 0);
5047 int num_heads = get_node_attr_i(node, "num_heads", 0);
5048
5049 fprintf(pp, " 0=%d", embed_dim);
5050 fprintf(pp, " 1=%d", num_heads);
5051
5052 if (node.input_size() == 5)
5053 {
5054 const onnx::TensorProto& qkvw = weights[node.input(1)];
5055 const onnx::TensorProto& qkvb = weights[node.input(2)];
5056 const onnx::TensorProto& ow = weights[node.input(3)];
5057 const onnx::TensorProto& ob = weights[node.input(4)];
5058
5059 int weight_data_size = get_tensor_proto_data_size(ow);
5060
5061 fprintf(pp, " 2=%d", weight_data_size);
5062
5063 int quantize_tag = 0;
5064
5065 fwrite(&quantize_tag, sizeof(int), 1, bp);
5066 // transpose qw
5067 {
5068 const float* wptr = qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
5069 const float* bptr = qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();
5070
5071 for (int j = 0; j < embed_dim; j++)
5072 {
5073 for (int k = 0; k < embed_dim; k++)
5074 {
5075 float vb = wptr[k * embed_dim * 3 + j];
5076 fwrite(&vb, sizeof(float), 1, bp);
5077 }
5078 }
5079
5080 fwrite(bptr, sizeof(float), embed_dim, bp);
5081 }
5082
5083 fwrite(&quantize_tag, sizeof(int), 1, bp);
5084 // transpose kw
5085 {
5086 const float* wptr = qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
5087 const float* bptr = qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();
5088 bptr += embed_dim;
5089
5090 for (int j = 0; j < embed_dim; j++)
5091 {
5092 for (int k = 0; k < embed_dim; k++)
5093 {
5094 float vb = wptr[k * embed_dim * 3 + j + embed_dim];
5095 fwrite(&vb, sizeof(float), 1, bp);
5096 }
5097 }
5098
5099 fwrite(bptr, sizeof(float), embed_dim, bp);
5100 }
5101
5102 fwrite(&quantize_tag, sizeof(int), 1, bp);
5103 // transpose vw
5104 {
5105 const float* wptr = qkvw.has_raw_data() ? (const float*)qkvw.raw_data().data() : qkvw.float_data().data();
5106 const float* bptr = qkvb.has_raw_data() ? (const float*)qkvb.raw_data().data() : qkvb.float_data().data();
5107 bptr += embed_dim * 2;
5108
5109 for (int j = 0; j < embed_dim; j++)
5110 {
5111 for (int k = 0; k < embed_dim; k++)
5112 {
5113 float vb = wptr[k * embed_dim * 3 + j + embed_dim * 2];
5114 fwrite(&vb, sizeof(float), 1, bp);
5115 }
5116 }
5117
5118 fwrite(bptr, sizeof(float), embed_dim, bp);
5119 }
5120
5121 fwrite(&quantize_tag, sizeof(int), 1, bp);
5122 // transpose ow
5123 {
5124 const float* wptr = ow.has_raw_data() ? (const float*)ow.raw_data().data() : ow.float_data().data();
5125
5126 for (int j = 0; j < embed_dim; j++)
5127 {
5128 for (int k = 0; k < embed_dim; k++)
5129 {
5130 float vb = wptr[k * embed_dim + j];
5131 fwrite(&vb, sizeof(float), 1, bp);
5132 }
5133 }
5134 }
5135 fwrite_tensor_proto_data(ob, bp);
5136 }
5137 else
5138 {
5139 const onnx::TensorProto& qw = weights[node.input(3)];
5140 const onnx::TensorProto& qb = weights[node.input(4)];
5141 const onnx::TensorProto& kw = weights[node.input(5)];
5142 const onnx::TensorProto& kb = weights[node.input(6)];
5143 const onnx::TensorProto& vw = weights[node.input(7)];
5144 const onnx::TensorProto& vb = weights[node.input(8)];
5145 const onnx::TensorProto& ow = weights[node.input(9)];
5146 const onnx::TensorProto& ob = weights[node.input(10)];
5147
5148 int weight_data_size = get_tensor_proto_data_size(qw);
5149
5150 fprintf(pp, " 2=%d", weight_data_size);
5151
5152 int quantize_tag = 0;
5153
5154 fwrite(&quantize_tag, sizeof(int), 1, bp);
5155 // transpose qw
5156 {
5157 const float* wptr = qw.has_raw_data() ? (const float*)qw.raw_data().data() : qw.float_data().data();
5158
5159 for (int j = 0; j < embed_dim; j++)
5160 {
5161 for (int k = 0; k < embed_dim; k++)
5162 {
5163 float vb = wptr[k * embed_dim + j];
5164 fwrite(&vb, sizeof(float), 1, bp);
5165 }
5166 }
5167 }
5168 fwrite_tensor_proto_data(qb, bp);
5169
5170 fwrite(&quantize_tag, sizeof(int), 1, bp);
5171 // transpose kw
5172 {
5173 const float* wptr = kw.has_raw_data() ? (const float*)kw.raw_data().data() : kw.float_data().data();
5174
5175 for (int j = 0; j < embed_dim; j++)
5176 {
5177 for (int k = 0; k < embed_dim; k++)
5178 {
5179 float vb = wptr[k * embed_dim + j];
5180 fwrite(&vb, sizeof(float), 1, bp);
5181 }
5182 }
5183 }
5184 fwrite_tensor_proto_data(kb, bp);
5185
5186 fwrite(&quantize_tag, sizeof(int), 1, bp);
5187 // transpose vw
5188 {
5189 const float* wptr = vw.has_raw_data() ? (const float*)vw.raw_data().data() : vw.float_data().data();
5190
5191 for (int j = 0; j < embed_dim; j++)
5192 {
5193 for (int k = 0; k < embed_dim; k++)
5194 {
5195 float vb = wptr[k * embed_dim + j];
5196 fwrite(&vb, sizeof(float), 1, bp);
5197 }
5198 }
5199 }
5200 fwrite_tensor_proto_data(vb, bp);
5201
5202 fwrite(&quantize_tag, sizeof(int), 1, bp);
5203 // transpose ow
5204 {
5205 const float* wptr = ow.has_raw_data() ? (const float*)ow.raw_data().data() : ow.float_data().data();
5206
5207 for (int j = 0; j < embed_dim; j++)
5208 {
5209 for (int k = 0; k < embed_dim; k++)
5210 {
5211 float vb = wptr[k * embed_dim + j];
5212 fwrite(&vb, sizeof(float), 1, bp);
5213 }
5214 }
5215 }
5216 fwrite_tensor_proto_data(ob, bp);
5217 }
5218 }
5219 else if (op == "Neg")
5220 {
5221 int op_type = 1;
5222 fprintf(pp, " 0=%d", op_type);
5223 }
5224 else if (op == "Normalize")
5225 {
5226 float eps = get_node_attr_f(node, "eps", 0.f);
5227 int scale_data_size = 1;
5228
5229 fprintf(pp, " 1=1"); // channel_shared
5230 fprintf(pp, " 2=%e", eps);
5231 fprintf(pp, " 3=%d", scale_data_size);
5232 fprintf(pp, " 9=1"); // TODO hardcode pytorch style
5233
5234 const float scale_data[1] = {1.f};
5235 fwrite(scale_data, sizeof(float), 1, bp);
5236 }
        else if (op == "Pad")
        {
            std::string mode = get_node_attr_s(node, "mode");
            float value = get_node_attr_f(node, "value", 0.f);

            // older opsets keep pads as an attribute; newer ones pass a tensor input
            std::vector<int> pads;
            if (node.input_size() == 1)
            {
                pads = get_node_attr_ai(node, "pads");
            }
            else
            {
                pads = get_node_attr_from_input_ai(weights[node.input(1)]);
            }

            // ncnn Padding type: 0 = constant, 1 = replicate (edge), 2 = reflect
            int type = 0;
            if (mode == "constant")
            {
                type = 0;
            }
            else if (mode == "edge")
            {
                type = 1;
            }
            else if (mode == "reflect")
            {
                type = 2;
            }

            // ONNX pads layout: all begin values first, then all end values
            int pad_size = (int)pads.size();
            int top = 0;
            int bottom = 0;
            int left = 0;
            int right = 0;
            int front = 0;
            int behind = 0;
            if (pad_size == 8)
            {
                //NCHW
                top = pads[2];
                bottom = pads[6];
                left = pads[3];
                right = pads[7];
                front = pads[1];
                behind = pads[5];
            }
            else if (pad_size == 6)
            {
                //NHW
                top = pads[1];
                bottom = pads[4];
                left = pads[2];
                right = pads[5];
            }
            else
            {
                //NW
                // NOTE(review): indexes pads[1] and pads[3], so this branch
                // assumes pad_size == 4; a shorter pads array would read out of
                // bounds - verify callers never supply fewer entries.
                left = pads[1];
                right = pads[3];
            }

            fprintf(pp, " 0=%d", top);
            fprintf(pp, " 1=%d", bottom);
            fprintf(pp, " 2=%d", left);
            fprintf(pp, " 3=%d", right);
            fprintf(pp, " 4=%d", type);
            fprintf(pp, " 5=%e", value);
            fprintf(pp, " 7=%d", front);
            fprintf(pp, " 8=%d", behind);
        }
5307 else if (op == "Pow")
5308 {
5309 int op_type = 6;
5310 fprintf(pp, " 0=%d", op_type);
5311
5312 int with_scalar = get_node_attr_i(node, "with_scalar", 0);
5313 float b = get_node_attr_f(node, "b", 0.f);
5314 if (with_scalar)
5315 {
5316 fprintf(pp, " 1=%d", with_scalar);
5317 fprintf(pp, " 2=%e", b);
5318 }
5319 }
5320 else if (op == "PixelShuffle")
5321 {
5322 int scale_factor = get_node_attr_i(node, "scale_factor", 1);
5323 fprintf(pp, " 0=%d", scale_factor);
5324 }
else if (op == "PRelu")
{
    // slope tensor is the second input; its element count drives param 0
    const onnx::TensorProto& slope = weights[node.input(1)];

    int num_slope = get_tensor_proto_data_size(slope);

    fprintf(pp, " 0=%d", num_slope);

    // raw slope values go straight into the weight bin
    fwrite_tensor_proto_data(slope, bp);
}
else if (op == "Reciprocal")
{
    // UnaryOp: reciprocal (op_type 15)
    int op_type = 15;
    fprintf(pp, " 0=%d", op_type);
}
else if (op == "ReduceMax" || op == "ReduceMin" || op == "ReduceMean" || op == "ReduceProd" || op == "ReduceSum" || op == "ReduceSumSquare" || op == "ReduceL1" || op == "ReduceL2" || op == "ReduceLogSum" || op == "ReduceLogSumExp")
{
    // map the ONNX reduce op name to ncnn Reduction op_type id;
    // -233 is the "unset" sentinel (should never survive the chain below)
    int op_type = -233;
    if (op == "ReduceSum")
        op_type = 0;
    else if (op == "ReduceSumSquare")
        op_type = 2;
    else if (op == "ReduceMean")
        op_type = 3;
    else if (op == "ReduceMax")
        op_type = 4;
    else if (op == "ReduceMin")
        op_type = 5;
    else if (op == "ReduceProd")
        op_type = 6;
    else if (op == "ReduceL1")
        op_type = 7;
    else if (op == "ReduceL2")
        op_type = 8;
    else if (op == "ReduceLogSum")
        op_type = 9;
    else if (op == "ReduceLogSumExp")
        op_type = 10;
    fprintf(pp, " 0=%d", op_type);

    std::vector<int> axes = get_node_attr_ai(node, "axes");
    int keepdims = get_node_attr_i(node, "keepdims", 1);

    if (axes.size() > 0)
    {
        // if axes set, reduce according to axes
        fprintf(pp, " 1=%d", 0);
        fprintf(pp, " -23303=%zu", axes.size());
        for (size_t j = 0; j < axes.size(); j++)
        {
            // axis 0 (batch) and anything beyond +/-3 is not representable;
            // warn but still emit the value as-is
            if (axes[j] == 0 || axes[j] > 3 || axes[j] < -3)
                fprintf(stderr, "Unsupported reduction axes !\n");
            fprintf(pp, ",%d", axes[j]);
        }
    }
    else
    {
        // if axes not set, reduce all axes by default
        fprintf(pp, " 1=%d", 1);
    }
    fprintf(pp, " 4=%d", keepdims);
}
else if (op == "Reorg")
{
    int stride = get_node_attr_i(node, "stride", 1);
    fprintf(pp, " 0=%d", stride);
}
else if (op == "Reshape")
{
    // target shape: from the "shape" attribute for single-input nodes,
    // otherwise from the constant second input tensor
    std::vector<int> shape;

    if (node.input_size() == 1)
    {
        shape = get_node_attr_ai(node, "shape");
    }
    else
    {
        shape = get_node_attr_from_input_ai(weights[node.input(1)]);
    }

    // ONNX shape includes the batch dim; ncnn params are emitted without it
    // as 0=w 1=h 2=c (innermost ONNX dim maps to w)
    if (shape.size() == 1)
    {
        fprintf(pp, " 0=%d", shape[0]); // should never reach here
    }
    else if (shape.size() == 2)
    {
        fprintf(pp, " 0=%d", shape[1]);
    }
    else if (shape.size() == 3)
    {
        fprintf(pp, " 0=%d", shape[2]);
        fprintf(pp, " 1=%d", shape[1]);
    }
    else if (shape.size() == 4)
    {
        fprintf(pp, " 0=%d", shape[3]);
        fprintf(pp, " 1=%d", shape[2]);
        fprintf(pp, " 2=%d", shape[1]);
    }
    else if (shape.size() == 5)
    {
        // 5-d: fold the two innermost dims into w
        fprintf(pp, " 0=%d", shape[4] * shape[3]);
        fprintf(pp, " 1=%d", shape[2]);
        fprintf(pp, " 2=%d", shape[1]);
    }
}
else if (op == "Resize")
{
    std::string mode = get_node_attr_s(node, "mode");
    std::string align = get_node_attr_s(node, "coordinate_transformation_mode");

    std::vector<float> scales;
    std::vector<int> sizes;
    if (node.input_size() == 2)
    {
        // opset 10
        scales = get_node_attr_from_input_af(weights[node.input(1)]);
    }
    else
    {
        // opset 11+
        // input layout: X, roi, scales, sizes (sizes optional)
        scales = get_node_attr_from_input_af(weights[node.input(2)]);
        if (node.input_size() >= 4)
        {
            sizes = get_node_attr_from_input_ai(weights[node.input(3)]);
        }
    }

    // map ONNX interpolation mode to ncnn Interp resize_type
    int resize_type = 1;
    if (mode == "nearest")
    {
        resize_type = 1;
    }
    else if (mode == "linear")
    {
        resize_type = 2;
    }
    else if (mode == "cubic")
    {
        resize_type = 3;
    }

    if (scales.empty() && sizes.empty())
    {
        fprintf(stderr, "Unsupported Resize scales and sizes are all empty!\n");
    }

    // extract per-axis scale factors; layouts NW / NHW / NCHW by length
    float h_scale = 1.f;
    float w_scale = 1.f;
    if (scales.size() == 2)
    {
        w_scale = scales[1];
    }
    else if (scales.size() == 3)
    {
        h_scale = scales[1];
        w_scale = scales[2];
    }
    else if (scales.size() == 4)
    {
        h_scale = scales[2];
        w_scale = scales[3];

        // channel-dim scaling cannot be expressed by Interp
        if (scales[1] != 1.f)
            fprintf(stderr, "Unsupported Resize scales !\n");
    }

    // explicit output sizes (0 means "not specified, use scale")
    int output_height = 0;
    int output_width = 0;
    if (sizes.size() == 2)
    {
        output_width = sizes[1];
    }
    else if (sizes.size() == 3)
    {
        output_height = sizes[1];
        output_width = sizes[2];
    }
    else if (sizes.size() == 4)
    {
        output_height = sizes[2];
        output_width = sizes[3];
    }

    int align_corner = 0;
    if (align == "align_corners")
    {
        align_corner = 1;
    }

    fprintf(pp, " 0=%d", resize_type);
    fprintf(pp, " 1=%e", h_scale);
    fprintf(pp, " 2=%e", w_scale);
    fprintf(pp, " 3=%d", output_height);
    fprintf(pp, " 4=%d", output_width);
    fprintf(pp, " 6=%d", align_corner);
}
else if (op == "RNN")
{
    // ONNX RNN weight inputs: W (input-hidden), R (hidden-hidden), B (biases)
    const onnx::TensorProto& W = weights[node.input(1)];
    const onnx::TensorProto& R = weights[node.input(2)];
    const onnx::TensorProto& B = weights[node.input(3)];

    int hidden_size = get_node_attr_i(node, "hidden_size", 0);
    std::string direction = get_node_attr_s(node, "direction");

    int direction_type = 0;
    if (direction == "forward")
    {
        direction_type = 0;
    }
    else if (direction == "reverse")
    {
        direction_type = 1;
    }
    else if (direction == "bidirectional")
    {
        direction_type = 2;
    }

    int weight_data_size = get_tensor_proto_data_size(W);

    fprintf(pp, " 0=%d", hidden_size);
    fprintf(pp, " 1=%d", weight_data_size);
    fprintf(pp, " 2=%d", direction_type);

    int num_directions = direction_type == 2 ? 2 : 1;

    // 0 = no quantization; tag precedes each weight blob in the bin
    int quantize_tag = 0;

    fwrite(&quantize_tag, sizeof(int), 1, bp);
    fwrite_tensor_proto_data(W, bp);

    // reduce xc and hc bias
    // ncnn uses a single combined bias, so sum ONNX's Wb and Rb element-wise
    {
        fwrite(&quantize_tag, sizeof(int), 1, bp);

        // per-direction bias length: B holds [Wb, Rb] per direction
        int bias_data_size_g = get_tensor_proto_data_size(B) / 2 / num_directions;
        const float* bptr = B.has_raw_data() ? (const float*)B.raw_data().data() : B.float_data().data();
        const float* xiptr = bptr;                    // Wb of direction 0
        const float* hiptr = bptr + bias_data_size_g; // Rb of direction 0

        for (int j = 0; j < bias_data_size_g; j++)
        {
            float vb = xiptr[j] + hiptr[j];
            fwrite(&vb, sizeof(float), 1, bp);
        }

        if (direction_type == 2)
        {
            // advance both pointers past direction 0's [Wb, Rb] pair
            xiptr += bias_data_size_g * 2;
            hiptr += bias_data_size_g * 2;

            for (int j = 0; j < bias_data_size_g; j++)
            {
                float vb = xiptr[j] + hiptr[j];
                fwrite(&vb, sizeof(float), 1, bp);
            }
        }
    }

    fwrite(&quantize_tag, sizeof(int), 1, bp);
    fwrite_tensor_proto_data(R, bp);
}
5589 else if (op == "RDiv")
5590 {
5591 int op_type = 8;
5592 fprintf(pp, " 0=%d", op_type);
5593
5594 int with_scalar = get_node_attr_i(node, "with_scalar", 0);
5595 float b = get_node_attr_f(node, "b", 0.f);
5596 if (with_scalar)
5597 {
5598 fprintf(pp, " 1=%d", with_scalar);
5599 fprintf(pp, " 2=%e", b);
5600 }
5601 }
5602 else if (op == "RSub")
5603 {
5604 int op_type = 7;
5605 fprintf(pp, " 0=%d", op_type);
5606
5607 int with_scalar = get_node_attr_i(node, "with_scalar", 0);
5608 float b = get_node_attr_f(node, "b", 0.f);
5609 if (with_scalar)
5610 {
5611 fprintf(pp, " 1=%d", with_scalar);
5612 fprintf(pp, " 2=%e", b);
5613 }
5614 }
else if (op == "ShuffleChannel")
{
    int group = get_node_attr_i(node, "group", 1);
    int reverse = get_node_attr_i(node, "reverse", 0);
    fprintf(pp, " 0=%d", group);
    fprintf(pp, " 1=%d", reverse);
}
else if (op == "Sigmoid")
{
    // no param
}
else if (op == "Sin")
{
    // UnaryOp: sin (op_type 9)
    int op_type = 9;
    fprintf(pp, " 0=%d", op_type);
}
else if (op == "SkipLayerNormalization")
{
    // inputs 2/3/4 are constant tensors; presumably gamma, beta and the
    // skip-add bias of the com.microsoft SkipLayerNormalization op — confirm
    const onnx::TensorProto& W = weights[node.input(2)];
    const onnx::TensorProto& B = weights[node.input(3)];
    const onnx::TensorProto& B2 = weights[node.input(4)];

    // param 0 = normalized channel count, taken from B's element count
    fprintf(pp, " 0=%d", get_tensor_proto_data_size(B));

    // 0 = no quantization; tag precedes each blob in the bin
    int quantize_tag = 0;
    fwrite(&quantize_tag, sizeof(int), 1, bp);

    fwrite_tensor_proto_data(W, bp);

    fwrite(&quantize_tag, sizeof(int), 1, bp);

    fwrite_tensor_proto_data(B, bp);

    fwrite(&quantize_tag, sizeof(int), 1, bp);

    fwrite_tensor_proto_data(B2, bp);
}
else if (op == "Slice")
{
    // starts/ends/axes/steps come from attributes (single-input form)
    // or from constant input tensors 1..4 (newer opsets)
    std::vector<int> starts;
    std::vector<int> ends;
    std::vector<int> axes;
    std::vector<int> steps;
    if (node.input_size() == 1)
    {
        starts = get_node_attr_ai(node, "starts");
        ends = get_node_attr_ai(node, "ends");
        axes = get_node_attr_ai(node, "axes");
        steps = get_node_attr_ai(node, "steps"); // TODO
    }
    else
    {
        starts = get_node_attr_from_input_ai(weights[node.input(1)]);
        ends = get_node_attr_from_input_ai(weights[node.input(2)]);
        if (node.input_size() >= 4)
            axes = get_node_attr_from_input_ai(weights[node.input(3)]);
        if (node.input_size() >= 5)
            steps = get_node_attr_from_input_ai(weights[node.input(4)]);
    }

    // assert step == 1
    // non-unit steps are not supported; warn but continue
    for (int i = 0; i < (int)steps.size(); i++)
    {
        if (steps[i] != 1)
            fprintf(stderr, "Unsupported slice step !\n");
    }

    // filter out N-dim axis
    // drop at most one batch-axis (axis 0) entry from all three lists
    if (!axes.empty())
    {
        for (int i = 0; i < (int)axes.size(); i++)
        {
            int axis = axes[i];
            if (axis == 0)
            {
                starts.erase(starts.begin() + i);
                ends.erase(ends.begin() + i);
                axes.erase(axes.begin() + i);
                break;
            }
        }
    }

    // Crop array params: -23309 starts, -23310 ends, -23311 axes
    fprintf(pp, " -23309=%d", (int)starts.size());
    for (int i = 0; i < (int)starts.size(); i++)
    {
        fprintf(pp, ",%d", starts[i]);
    }
    fprintf(pp, " -23310=%d", (int)ends.size());
    for (int i = 0; i < (int)ends.size(); i++)
    {
        fprintf(pp, ",%d", ends[i]);
    }
    if (!axes.empty())
    {
        fprintf(pp, " -23311=%d", (int)axes.size());
        for (int i = 0; i < (int)axes.size(); i++)
        {
            int axis = axes[i];
            if (axis == 0 || axis > 3 || axis < -3)
                fprintf(stderr, "Unsupported slice axes !\n");

            if (axis > 0)
                axis = axis - 1; // -1 for skip N-dim

            fprintf(pp, ",%d", axis);
        }
    }
}
else if (op == "Softmax")
{
    // axis - 1 drops the batch dim; param 1=1 selects fixed-axis behavior
    int axis = get_node_attr_i(node, "axis", 1);
    fprintf(pp, " 0=%d", axis - 1);
    fprintf(pp, " 1=1");
}
else if (op == "Split")
{
    int axis = get_node_attr_i(node, "axis", 0);
    std::vector<int> split = get_node_attr_ai(node, "split");
    // batch-axis splits cannot be expressed; warn but continue
    if (axis < 1)
        fprintf(stderr, "Unsupported split axis !\n");

    // -23300 = slices array; output_size is this node's output count,
    // computed earlier in the enclosing loop
    fprintf(pp, " -23300=%d", output_size);
    if (split.empty())
    {
        // no explicit sizes: every slice is -233 ("divide evenly")
        for (int i = 0; i < output_size; i++)
        {
            fprintf(pp, ",-233");
        }
    }
    else
    {
        // explicit sizes for all but the last slice; last gets -233
        // ("take the remainder")
        for (size_t i = 0; i < split.size() - 1; i++)
        {
            fprintf(pp, ",%d", split[i]);
        }
        fprintf(pp, ",-233");
    }
    fprintf(pp, " 1=%d", axis - 1);
}
else if (op == "Sqrt")
{
    // UnaryOp: sqrt (op_type 5)
    int op_type = 5;
    fprintf(pp, " 0=%d", op_type);
}
else if (op == "Squeeze")
{
    std::vector<int> axes = get_node_attr_ai(node, "axes");

    if (axes.empty())
    {
        // no axes given: squeeze w, h and c
        fprintf(pp, " 0=1");
        fprintf(pp, " 1=1");
        fprintf(pp, " 2=1");
    }
    else
    {
        // -23303 = axes array param
        fprintf(pp, " -23303=%zu", axes.size());
        for (int i = 0; i < (int)axes.size(); i++)
        {
            // batch axis and |axis| > 3 are not representable; warn only
            if (axes[i] == 0 || axes[i] > 3 || axes[i] < -3)
                fprintf(stderr, "Unsupported squeeze axes !\n");
            fprintf(pp, ",%d", axes[i]);
        }
    }
}
else if (op == "Sub")
{
    // BinaryOp: subtract (op_type 1)
    int op_type = 1;
    fprintf(pp, " 0=%d", op_type);

    // optional fused scalar operand
    int with_scalar = get_node_attr_i(node, "with_scalar", 0);
    float b = get_node_attr_f(node, "b", 0.f);
    if (with_scalar)
    {
        fprintf(pp, " 1=%d", with_scalar);
        fprintf(pp, " 2=%e", b);
    }
}
else if (op == "Sum")
{
    // Eltwise: sum (op_type 1)
    int op_type = 1;
    fprintf(pp, " 0=%d", op_type);
}
else if (op == "Swish")
{
    // no param
}
else if (op == "Tan")
{
    // UnaryOp: tan (op_type 11)
    int op_type = 11;
    fprintf(pp, " 0=%d", op_type);
}
else if (op == "Tanh")
{
    // UnaryOp: tanh (op_type 16)
    int op_type = 16;
    fprintf(pp, " 0=%d", op_type);
}
else if (op == "Transpose")
{
    std::vector<int> perm = get_node_attr_ai(node, "perm");

    // map the ONNX permutation (batch dim included) to a Permute order_type.
    // NOTE(review): in the 3-d and 4-d cases, permutations matching none of
    // the patterns below fall through without emitting a param or a warning
    // (only the 5-d case warns) — confirm this is intentional.
    if (perm.size() == 3)
    {
        if (perm[1] == 1 && perm[2] == 2)
            fprintf(pp, " 0=0"); // w h
        else if (perm[1] == 2 && perm[2] == 1)
            fprintf(pp, " 0=1"); // h w
        else if (perm[0] == 1 && perm[1] == 0 && perm[2] == 2)
            fprintf(pp, " 0=0"); // w h
        else if (perm[0] == 2 && perm[1] == 0 && perm[2] == 1)
            fprintf(pp, " 0=1"); // h w
    }
    else if (perm.size() == 4)
    {
        if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3)
            fprintf(pp, " 0=0"); // w h c
        else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 2)
            fprintf(pp, " 0=1"); // h w c
        else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3)
            fprintf(pp, " 0=2"); // w c h
        else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 1)
            fprintf(pp, " 0=3"); // c w h
        else if (perm[1] == 3 && perm[2] == 1 && perm[3] == 2)
            fprintf(pp, " 0=4"); // h c w
        else if (perm[1] == 3 && perm[2] == 2 && perm[3] == 1)
            fprintf(pp, " 0=5"); // c h w
    }
    else if (perm.size() == 5)
    {
        // 5-d: the two innermost dims are treated as a fused "wx" axis
        if (perm[1] == 1 && perm[2] == 2 && perm[3] == 3 && perm[4] == 4)
            fprintf(pp, " 0=0"); // wx h c
        else if (perm[1] == 1 && perm[2] == 3 && perm[3] == 4 && perm[4] == 2)
            fprintf(pp, " 0=1"); // h wx c
        else if (perm[1] == 2 && perm[2] == 1 && perm[3] == 3 && perm[4] == 4)
            fprintf(pp, " 0=2"); // wx c h
        else if (perm[1] == 2 && perm[2] == 3 && perm[3] == 4 && perm[4] == 1)
            fprintf(pp, " 0=3"); // c wx h
        else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 1 && perm[4] == 2)
            fprintf(pp, " 0=4"); // h c wx
        else if (perm[1] == 3 && perm[2] == 4 && perm[3] == 2 && perm[4] == 1)
            fprintf(pp, " 0=5"); // c h wx
        else
            fprintf(stderr, "Unsupported transpose type !\n");
    }
}
else if (op == "Upsample")
{
    std::string mode = get_node_attr_s(node, "mode");
    std::string align = get_node_attr_s(node, "coordinate_transformation_mode");

    // scales from the attribute (single input) or constant second input
    std::vector<float> scales;

    if (node.input_size() == 1)
    {
        scales = get_node_attr_af(node, "scales");
    }
    else
    {
        scales = get_node_attr_from_input_af(weights[node.input(1)]);
    }

    // map interpolation mode to Interp resize_type
    int resize_type = 1;
    if (mode == "nearest")
    {
        resize_type = 1;
    }
    else if (mode == "bilinear" || mode == "linear")
    {
        resize_type = 2;
    }
    else if (mode == "trilinear")
    {
        fprintf(stderr, "Unsupported Upsample mode !\n");
    }

    // extract per-axis scale factors; layouts NW / NHW / NCHW by length
    float h_scale = 1.f;
    float w_scale = 1.f;
    if (scales.size() == 2)
    {
        w_scale = scales[1];
    }
    else if (scales.size() == 3)
    {
        h_scale = scales[1];
        w_scale = scales[2];
    }
    else if (scales.size() == 4)
    {
        h_scale = scales[2];
        w_scale = scales[3];

        // channel-dim scaling is not supported
        if (scales[1] != 1.f)
            fprintf(stderr, "Unsupported Upsample scales !\n");
    }
    else
    {
        fprintf(stderr, "Unsupported Upsample scales !\n");
    }

    int align_corner = 0;
    if (align == "align_corners")
    {
        align_corner = 1;
    }

    fprintf(pp, " 0=%d", resize_type);
    fprintf(pp, " 1=%e", h_scale);
    fprintf(pp, " 2=%e", w_scale);
    fprintf(pp, " 6=%d", align_corner);
}
else if (op == "Unsqueeze")
{
    std::vector<int> axes = get_node_attr_ai(node, "axes");

    // -23303 = axes array param
    fprintf(pp, " -23303=%zu", axes.size());
    for (int i = 0; i < (int)axes.size(); i++)
    {
        // batch axis and |axis| > 4 are not representable; warn only
        if (axes[i] == 0 || axes[i] > 4 || axes[i] < -4)
            fprintf(stderr, "Unsupported unsqueeze axes !\n");
        fprintf(pp, ",%d", axes[i]);
    }
}
else
{
    // TODO op specific param
    // unknown op: dump its attributes to stderr for diagnostics
    // (attribute type codes: 1=FLOAT, 2=INT, 3=STRING)
    for (int j = 0; j < node.attribute_size(); j++)
    {
        const onnx::AttributeProto& attr = node.attribute(j);
        if (attr.type() == 1)
        {
            fprintf(stderr, "  # %s=%g\n", attr.name().c_str(), attr.f());
        }
        else if (attr.type() == 2)
        {
            fprintf(stderr, "  # %s=%lld\n", attr.name().c_str(), (long long)attr.i());
        }
        else if (attr.type() == 3)
        {
            fprintf(stderr, "  # %s=%s\n", attr.name().c_str(), attr.s().c_str());
        }
        else
        {
            fprintf(stderr, "  # %s %d\n", attr.name().c_str(), attr.type());
        }
    }
}
        // end of the per-op param line
        fprintf(pp, "\n");

        // for every output consumed by more than one downstream node,
        // insert an explicit ncnn Split layer fanning the blob out to
        // refcount renamed copies (name_splitncnn_k).
        // node_reference / internal_split are maintained by the enclosing
        // loop above this chunk.
        for (int j = 0; j < output_size; j++)
        {
            const std::string& output_name = node.output(j);
            if (node_reference.find(output_name) != node_reference.end())
            {
                int refcount = node_reference[output_name];
                if (refcount > 1)
                {
                    char splitname[256];
                    sprintf(splitname, "splitncnn_%d", internal_split);
                    fprintf(pp, "%-16s %-24s %d %d", "Split", splitname, 1, refcount);

                    fprintf(pp, " %s", output_name.c_str());

                    for (int k = 0; k < refcount; k++)
                    {
                        fprintf(pp, " %s_splitncnn_%d", output_name.c_str(), k);
                    }
                    fprintf(pp, "\n");

                    internal_split++;
                }
            }
        }
    }

    // pp = param file, bp = bin file (opened before this chunk)
    fclose(pp);
    fclose(bp);

    return 0;
}
5996